diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-12-06 21:23:21 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-12-06 21:23:21 +0100 |
commit | a2d7ddd2230b7f42fe46eb33ea6dee8eb7080fdc (patch) | |
tree | 990e51813d4d2438a3234f8a6dac67236b2d0c1e | |
parent | 263a3cef7543dfd447d5d75cc759fe95f7864105 (diff) |
Start WIP compile-to-C
-rw-r--r-- | chad-fast.cabal | 1 | ||||
-rw-r--r-- | src/Compile.hs | 106 | ||||
-rw-r--r-- | src/Data.hs | 13 |
3 files changed, 120 insertions, 0 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index ad611e8..6635f6c 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -25,6 +25,7 @@ library CHAD.EnvDescr CHAD.Top CHAD.Types + Compile -- CompileCu Data Example diff --git a/src/Compile.hs b/src/Compile.hs new file mode 100644 index 0000000..b65e643 --- /dev/null +++ b/src/Compile.hs @@ -0,0 +1,106 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE GADTs #-} +module Compile where + +import qualified Data.Map.Strict as Map +import Data.Map.Strict (Map) + +import AST +import Data + + +-- In shape and index arrays, the innermost dimension is on the right (last index). + + +data StructDecl = StructDecl + String -- ^ name + String -- ^ contents + +printStructDecl :: StructDecl -> ShowS +printStructDecl (StructDecl name contents) = + showString "typedef struct { " . showString contents . showString " }" . showString name . showString ";\n" + +repTy :: STy t -> String +repTy (STScal st) = case st of + STI32 -> "int32_t" + STI64 -> "int64_t" + STF32 -> "float" + STF64 -> "double" + STBool -> "bool" +repTy t = genStructName t + +genStructName :: STy t -> String +genStructName = \t -> "ty_" ++ gen t where + gen :: STy t -> String + gen STNil = "n" + gen (STPair a b) = 'p' : gen a ++ gen b + gen (STEither a b) = 'e' : gen a ++ gen b + gen (STMaybe t) = 'm' : gen t + gen (STArr n t) = "A[" ++ show (fromSNat n) ++ "]" ++ gen t + gen (STScal st) = case st of + STI32 -> "i4" + STI64 -> "i8" + STF32 -> "f4" + STF64 -> "f8" + STBool -> "b" + gen (STAccum t) = 'C' : gen t + +genStruct :: STy t -> Map String StructDecl +genStruct STNil = + Map.singleton (genStructName STNil) (StructDecl (genStructName STNil) "") +genStruct (STPair a b) = + let name = genStructName (STPair a b) + in Map.singleton name (StructDecl name (genStructName a ++ " a; " ++ genStructName b ++ " b;")) +genStruct (STEither a b) = + let name = genStructName (STEither a b) + in Map.singleton name (StructDecl name ("uint8_t tag; union { " ++ genStructName a ++ " a; " ++ genStructName b ++ " b; };")) +genStruct (STMaybe t) = + let name = genStructName (STMaybe t) + in Map.singleton name (StructDecl name ("uint8_t tag; " ++ genStructName t ++ " a;")) +genStruct (STArr n t) = + let name = genStructName (STArr n t) + in Map.singleton name (StructDecl name ("size_t sh[" ++ show (fromSNat n) ++ "]; " ++ genStructName t ++ " *a;")) +genStruct (STScal _) = mempty +genStruct (STAccum t) = + let name = genStructName (STAccum t) + in Map.singleton name (StructDecl name (genStructName t ++ " a;")) + <> genStruct t + +compile :: Ex env t -> (Map String StructDecl, ()) +compile = \case + EVar _ _ _ -> mempty + ELet _ rhs body -> compile rhs <> compile body + EPair _ a b -> genStruct (STPair (typeOf a) (typeOf b)) <> compile a <> compile b + EFst _ e -> compile e + ESnd _ e -> compile e + ENil _ -> mempty + EInl _ t e -> genStruct (STEither (typeOf e) t) <> compile e + EInr _ t e -> genStruct (STEither t (typeOf e)) <> compile e + ECase _ e a b -> compile e <> compile a <> compile b + ENothing _ _ -> mempty + EJust _ e -> compile e + EMaybe _ a b e -> compile a <> compile b <> compile e + EConstArr _ n t _ -> genStruct (STArr n (STScal t)) + EBuild _ n a b -> genStruct (STArr n (typeOf b)) <> EBuild ext n (compile a) (compile b) + EFold1Inner _ a b c -> EFold1Inner ext (compile a) (compile b) (compile c) + ESum1Inner _ e -> ESum1Inner ext (compile e) + EUnit _ e -> EUnit ext (compile e) + EReplicate1Inner _ a b -> EReplicate1Inner ext (compile a) (compile b) + EMaximum1Inner _ e -> EMaximum1Inner ext (compile e) + EMinimum1Inner _ e -> EMinimum1Inner ext (compile e) + EConst _ t x -> EConst ext t x + EIdx0 _ e -> EIdx0 ext (compile e) + EIdx1 _ a b -> EIdx1 ext (compile a) (compile b) + EIdx _ a b -> EIdx ext (compile a) (compile b) + EShape _ e -> EShape ext (compile e) + EOp _ op e -> EOp ext op (compile e) + ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (compile a) (compile b) (compile c) (compile e1) (compile e2) + EWith a b -> EWith (compile a) (compile b) + EAccum n a b e -> EAccum n (compile a) (compile b) (compile e) + EZero t -> zero t + EPlus t a b -> plus t a b + EOneHot t i a b -> onehot t i a b + EError t s -> EError t s + +compose :: Foldable t => t (a -> a) -> a -> a +compose = foldr (.) id diff --git a/src/Data.hs b/src/Data.hs index 1371902..fc39814 100644 --- a/src/Data.hs +++ b/src/Data.hs @@ -111,3 +111,16 @@ vecGenerate = \n f -> go n f SZ unsafeCoerceRefl :: a :~: b unsafeCoerceRefl = unsafeCoerce Refl + +data Bag t = BNone | BOne t | BTwo (Bag t) (Bag t) | BMany [Bag t] + deriving (Show, Functor, Foldable, Traversable) + +instance Applicative Bag where + pure = BOne + BNone <*> _ = BNone + BOne f <*> b = f <$> b + BTwo b1 b2 <*> b = BTwo (b1 <*> b) (b2 <*> b) + BMany bs <*> b = BMany (map (<*> b) bs) + +instance Semigroup (Bag t) where (<>) = BTwo +instance Monoid (Bag t) where mempty = BNone |