summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-12-06 21:23:21 +0100
committerTom Smeding <tom@tomsmeding.com>2024-12-06 21:23:21 +0100
commita2d7ddd2230b7f42fe46eb33ea6dee8eb7080fdc (patch)
tree990e51813d4d2438a3234f8a6dac67236b2d0c1e
parent263a3cef7543dfd447d5d75cc759fe95f7864105 (diff)
Start WIP compile-to-C
-rw-r--r--chad-fast.cabal1
-rw-r--r--src/Compile.hs106
-rw-r--r--src/Data.hs13
3 files changed, 120 insertions, 0 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index ad611e8..6635f6c 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -25,6 +25,7 @@ library
CHAD.EnvDescr
CHAD.Top
CHAD.Types
+ Compile
-- CompileCu
Data
Example
diff --git a/src/Compile.hs b/src/Compile.hs
new file mode 100644
index 0000000..b65e643
--- /dev/null
+++ b/src/Compile.hs
@@ -0,0 +1,106 @@
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE GADTs #-}
+module Compile where
+
+import qualified Data.Map.Strict as Map
+import Data.Map.Strict (Map)
+
+import AST
+import Data
+
+
+-- In shape and index arrays, the innermost dimension is on the right (last index).
+
+
+data StructDecl = StructDecl
+ String -- ^ name
+ String -- ^ contents
+
+printStructDecl :: StructDecl -> ShowS
+printStructDecl (StructDecl name contents) =
+ showString "typedef struct { " . showString contents . showString " }" . showString name . showString ";\n"
+
+repTy :: STy t -> String
+repTy (STScal st) = case st of
+ STI32 -> "int32_t"
+ STI64 -> "int64_t"
+ STF32 -> "float"
+ STF64 -> "double"
+ STBool -> "bool"
+repTy t = genStructName t
+
+genStructName :: STy t -> String
+genStructName = \t -> "ty_" ++ gen t where
+ gen :: STy t -> String
+ gen STNil = "n"
+ gen (STPair a b) = 'p' : gen a ++ gen b
+ gen (STEither a b) = 'e' : gen a ++ gen b
+ gen (STMaybe t) = 'm' : gen t
+ gen (STArr n t) = "A[" ++ show (fromSNat n) ++ "]" ++ gen t
+ gen (STScal st) = case st of
+ STI32 -> "i4"
+ STI64 -> "i8"
+ STF32 -> "f4"
+ STF64 -> "f8"
+ STBool -> "b"
+ gen (STAccum t) = 'C' : gen t
+
+genStruct :: STy t -> Map String StructDecl
+genStruct STNil =
+ Map.singleton (genStructName STNil) (StructDecl (genStructName STNil) "")
+genStruct (STPair a b) =
+ let name = genStructName (STPair a b)
+ in Map.singleton name (StructDecl name (genStructName a ++ " a; " ++ genStructName b ++ " b;"))
+genStruct (STEither a b) =
+ let name = genStructName (STEither a b)
+ in Map.singleton name (StructDecl name ("uint8_t tag; union { " ++ genStructName a ++ " a; " ++ genStructName b ++ " b; };"))
+genStruct (STMaybe t) =
+ let name = genStructName (STMaybe t)
+ in Map.singleton name (StructDecl name ("uint8_t tag; " ++ genStructName t ++ " a;"))
+genStruct (STArr n t) =
+ let name = genStructName (STArr n t)
+ in Map.singleton name (StructDecl name ("size_t sh[" ++ show (fromSNat n) ++ "]; " ++ genStructName t ++ " *a;"))
+genStruct (STScal _) = mempty
+genStruct (STAccum t) =
+ let name = genStructName (STAccum t)
+ in Map.singleton name (StructDecl name (genStructName t ++ " a;"))
+ <> genStruct t
+
+compile :: Ex env t -> (Map String StructDecl, ())
+compile = \case
+ EVar _ _ _ -> mempty
+ ELet _ rhs body -> compile rhs <> compile body
+ EPair _ a b -> genStruct (STPair (typeOf a) (typeOf b)) <> compile a <> compile b
+ EFst _ e -> compile e
+ ESnd _ e -> compile e
+ ENil _ -> mempty
+ EInl _ t e -> genStruct (STEither (typeOf e) t) <> compile e
+ EInr _ t e -> genStruct (STEither t (typeOf e)) <> compile e
+ ECase _ e a b -> compile e <> compile a <> compile b
+ ENothing _ _ -> mempty
+ EJust _ e -> compile e
+ EMaybe _ a b e -> compile a <> compile b <> compile e
+ EConstArr _ n t _ -> genStruct (STArr n (STScal t))
+ EBuild _ n a b -> genStruct (STArr n (typeOf b)) <> EBuild ext n (compile a) (compile b)
+ EFold1Inner _ a b c -> EFold1Inner ext (compile a) (compile b) (compile c)
+ ESum1Inner _ e -> ESum1Inner ext (compile e)
+ EUnit _ e -> EUnit ext (compile e)
+ EReplicate1Inner _ a b -> EReplicate1Inner ext (compile a) (compile b)
+ EMaximum1Inner _ e -> EMaximum1Inner ext (compile e)
+ EMinimum1Inner _ e -> EMinimum1Inner ext (compile e)
+ EConst _ t x -> EConst ext t x
+ EIdx0 _ e -> EIdx0 ext (compile e)
+ EIdx1 _ a b -> EIdx1 ext (compile a) (compile b)
+ EIdx _ a b -> EIdx ext (compile a) (compile b)
+ EShape _ e -> EShape ext (compile e)
+ EOp _ op e -> EOp ext op (compile e)
+ ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (compile a) (compile b) (compile c) (compile e1) (compile e2)
+ EWith a b -> EWith (compile a) (compile b)
+ EAccum n a b e -> EAccum n (compile a) (compile b) (compile e)
+ EZero t -> zero t
+ EPlus t a b -> plus t a b
+ EOneHot t i a b -> onehot t i a b
+ EError t s -> EError t s
+
+compose :: Foldable t => t (a -> a) -> a -> a
+compose = foldr (.) id
diff --git a/src/Data.hs b/src/Data.hs
index 1371902..fc39814 100644
--- a/src/Data.hs
+++ b/src/Data.hs
@@ -111,3 +111,16 @@ vecGenerate = \n f -> go n f SZ
unsafeCoerceRefl :: a :~: b
unsafeCoerceRefl = unsafeCoerce Refl
+
+data Bag t = BNone | BOne t | BTwo (Bag t) (Bag t) | BMany [Bag t]
+ deriving (Show, Functor, Foldable, Traversable)
+
+instance Applicative Bag where
+ pure = BOne
+ BNone <*> _ = BNone
+ BOne f <*> b = f <$> b
+ BTwo b1 b2 <*> b = BTwo (b1 <*> b) (b2 <*> b)
+ BMany bs <*> b = BMany (map (<*> b) bs)
+
+instance Semigroup (Bag t) where (<>) = BTwo
+instance Monoid (Bag t) where mempty = BNone