summaryrefslogtreecommitdiff
path: root/src/AST
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-09-13 23:07:04 +0200
committerTom Smeding <tom@tomsmeding.com>2024-09-13 23:07:04 +0200
commit94938d648e021d2ace0f3b7bf383d256449d619f (patch)
treeef077de27b08027c7117761c3efc7d29b7ad3d56 /src/AST
parent3d8a6cca424fc5279c43a266900160feb28aa715 (diff)
WIP better zero/plus, fixing Accum (...)
The accumulator implementation was wrong because it forgot (in accumAdd) to take into account that values may be variably-sized. Furthermore, it was also complexity-inefficient because it did not build up a sparse value. Thus let's go for the Haskell-interpreter-equivalent of what a real, fast, compiled implementation would do: just a tree with mutable variables. In practice one can decide to indeed flatten parts of that tree, i.e. using a tree representation for nested pairs is bad, but that should have been done _before_ execution and for _all_ occurrences of that type fragment, not live at runtime by the accumulator implementation.
Diffstat (limited to 'src/AST')
-rw-r--r--src/AST/Count.hs3
-rw-r--r--src/AST/Pretty.hs7
-rw-r--r--src/AST/Types.hs95
3 files changed, 105 insertions, 0 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index 6a00e83..40a46f6 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -46,6 +46,7 @@ Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2)
-- | This code is executed many times
scaleMany :: Occ -> Occ
+scaleMany (Occ l Zero) = Occ l Zero
scaleMany (Occ l _) = Occ l Many
occCount :: Idx env a -> Expr x env t -> Occ
@@ -124,6 +125,8 @@ occCountGeneral onehot unpush alter many = go WId
EOp _ _ e -> re e
EWith a b -> re a <> re1 b
EAccum _ a b e -> re a <> re b <> re e
+ EZero _ -> mempty
+ EPlus _ a b -> re a <> re b
EError{} -> mempty
where
re :: Monoid (r env') => Expr x env' t'' -> r env'
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 2ce883b..f5e681a 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -181,6 +181,13 @@ ppExpr' d val = \case
return $ showParen (d > 10) $
showString ("accum " ++ show (unSNat i) ++ " ") . e1' . showString " " . e2' . showString " " . e3'
+ EZero _ -> return $ showString "zero"
+
+ EPlus _ a b -> do
+ a' <- ppExpr' 11 val a
+ b' <- ppExpr' 11 val b
+ return $ showParen (d > 10) $ showString "plus " . a' . showString " " . b'
+
EError _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s)
ppExprLet :: Int -> SVal env -> Expr x env t -> M ShowS
diff --git a/src/AST/Types.hs b/src/AST/Types.hs
new file mode 100644
index 0000000..a3e5080
--- /dev/null
+++ b/src/AST/Types.hs
@@ -0,0 +1,95 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeFamilies #-}
+module AST.Types where
+
+import Data.Int (Int32, Int64)
+import Data.Kind (Type)
+import Data.Type.Equality
+
+import Data
+
+
+data Ty
+ = TNil
+ | TPair Ty Ty
+ | TEither Ty Ty
+ | TMaybe Ty
+ | TArr Nat Ty -- ^ rank, element type
+ | TScal ScalTy
+ | TAccum Ty
+ deriving (Show, Eq, Ord)
+
+data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
+ deriving (Show, Eq, Ord)
+
+type STy :: Ty -> Type
+data STy t where
+ STNil :: STy TNil
+ STPair :: STy a -> STy b -> STy (TPair a b)
+ STEither :: STy a -> STy b -> STy (TEither a b)
+ STMaybe :: STy a -> STy (TMaybe a)
+ STArr :: SNat n -> STy t -> STy (TArr n t)
+ STScal :: SScalTy t -> STy (TScal t)
+ STAccum :: STy t -> STy (TAccum t)
+deriving instance Show (STy t)
+
+instance TestEquality STy where
+ testEquality STNil STNil = Just Refl
+ testEquality STNil _ = Nothing
+ testEquality (STPair a b) (STPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl
+ testEquality STPair{} _ = Nothing
+ testEquality (STEither a b) (STEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl
+ testEquality STEither{} _ = Nothing
+ testEquality (STMaybe a) (STMaybe a') | Just Refl <- testEquality a a' = Just Refl
+ testEquality STMaybe{} _ = Nothing
+ testEquality (STArr a b) (STArr a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl
+ testEquality STArr{} _ = Nothing
+ testEquality (STScal a) (STScal a') | Just Refl <- testEquality a a' = Just Refl
+ testEquality STScal{} _ = Nothing
+ testEquality (STAccum a) (STAccum a') | Just Refl <- testEquality a a' = Just Refl
+ testEquality STAccum{} _ = Nothing
+
+data SScalTy t where
+ STI32 :: SScalTy TI32
+ STI64 :: SScalTy TI64
+ STF32 :: SScalTy TF32
+ STF64 :: SScalTy TF64
+ STBool :: SScalTy TBool
+deriving instance Show (SScalTy t)
+
+instance TestEquality SScalTy where
+ testEquality STI32 STI32 = Just Refl
+ testEquality STI64 STI64 = Just Refl
+ testEquality STF32 STF32 = Just Refl
+ testEquality STF64 STF64 = Just Refl
+ testEquality STBool STBool = Just Refl
+ testEquality _ _ = Nothing
+
+scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t))
+scalRepIsShow STI32 = Dict
+scalRepIsShow STI64 = Dict
+scalRepIsShow STF32 = Dict
+scalRepIsShow STF64 = Dict
+scalRepIsShow STBool = Dict
+
+type TIx = TScal TI64
+
+tIx :: STy TIx
+tIx = STScal STI64
+
+type family ScalRep t where
+ ScalRep TI32 = Int32
+ ScalRep TI64 = Int64
+ ScalRep TF32 = Float
+ ScalRep TF64 = Double
+ ScalRep TBool = Bool
+
+type family ScalIsNumeric t where
+ ScalIsNumeric TI32 = True
+ ScalIsNumeric TI64 = True
+ ScalIsNumeric TF32 = True
+ ScalIsNumeric TF64 = True
+ ScalIsNumeric TBool = False