From a2d7ddd2230b7f42fe46eb33ea6dee8eb7080fdc Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Fri, 6 Dec 2024 21:23:21 +0100
Subject: Start WIP compile-to-C

---
 src/Compile.hs | 106 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 106 insertions(+)
 create mode 100644 src/Compile.hs

(limited to 'src/Compile.hs')

diff --git a/src/Compile.hs b/src/Compile.hs
new file mode 100644
index 0000000..b65e643
--- /dev/null
+++ b/src/Compile.hs
@@ -0,0 +1,106 @@
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE GADTs #-}
+module Compile where
+
+import qualified Data.Map.Strict as Map
+import Data.Map.Strict (Map)
+
+import AST
+import Data
+
+
+-- In shape and index arrays, the innermost dimension is on the right (last index).
+
+
+data StructDecl = StructDecl
+  String  -- ^ name
+  String  -- ^ contents
+
+printStructDecl :: StructDecl -> ShowS
+printStructDecl (StructDecl name contents) =
+  showString "typedef struct { " . showString contents . showString " }" . showString name . showString ";\n"
+
+repTy :: STy t -> String
+repTy (STScal st) = case st of
+  STI32 -> "int32_t"
+  STI64 -> "int64_t"
+  STF32 -> "float"
+  STF64 -> "double"
+  STBool -> "bool"
+repTy t = genStructName t
+
+genStructName :: STy t -> String
+genStructName = \t -> "ty_" ++ gen t where
+  gen :: STy t -> String
+  gen STNil = "n"
+  gen (STPair a b) = 'p' : gen a ++ gen b
+  gen (STEither a b) = 'e' : gen a ++ gen b
+  gen (STMaybe t) = 'm' : gen t
+  gen (STArr n t) = "A[" ++ show (fromSNat n) ++ "]" ++ gen t
+  gen (STScal st) = case st of
+    STI32 -> "i4"
+    STI64 -> "i8"
+    STF32 -> "f4"
+    STF64 -> "f8"
+    STBool -> "b"
+  gen (STAccum t) = 'C' : gen t
+
+genStruct :: STy t -> Map String StructDecl
+genStruct STNil =
+  Map.singleton (genStructName STNil) (StructDecl (genStructName STNil) "")
+genStruct (STPair a b) =
+  let name = genStructName (STPair a b)
+  in Map.singleton name (StructDecl name (genStructName a ++ " a; " ++ genStructName b ++ " b;"))
+genStruct (STEither a b) =
+  let name = genStructName (STEither a b)
+  in Map.singleton name (StructDecl name ("uint8_t tag; union { " ++ genStructName a ++ " a; " ++ genStructName b ++ " b; };"))
+genStruct (STMaybe t) =
+  let name = genStructName (STMaybe t)
+  in Map.singleton name (StructDecl name ("uint8_t tag; " ++ genStructName t ++ " a;"))
+genStruct (STArr n t) =
+  let name = genStructName (STArr n t)
+  in Map.singleton name (StructDecl name ("size_t sh[" ++ show (fromSNat n) ++ "]; " ++ genStructName t ++ " *a;"))
+genStruct (STScal _) = mempty
+genStruct (STAccum t) =
+  let name = genStructName (STAccum t)
+  in Map.singleton name (StructDecl name (genStructName t ++ " a;"))
+       <> genStruct t
+
+compile :: Ex env t -> (Map String StructDecl, ())
+compile = \case
+  EVar _ _ _ -> mempty
+  ELet _ rhs body -> compile rhs <> compile body
+  EPair _ a b -> genStruct (STPair (typeOf a) (typeOf b)) <> compile a <> compile b
+  EFst _ e -> compile e
+  ESnd _ e -> compile e
+  ENil _ -> mempty
+  EInl _ t e -> genStruct (STEither (typeOf e) t) <> compile e
+  EInr _ t e -> genStruct (STEither t (typeOf e)) <> compile e
+  ECase _ e a b -> compile e <> compile a <> compile b
+  ENothing _ _ -> mempty
+  EJust _ e -> compile e
+  EMaybe _ a b e -> compile a <> compile b <> compile e
+  EConstArr _ n t _ -> genStruct (STArr n (STScal t))
+  EBuild _ n a b -> genStruct (STArr n (typeOf b)) <> EBuild ext n (compile a) (compile b)
+  EFold1Inner _ a b c -> EFold1Inner ext (compile a) (compile b) (compile c)
+  ESum1Inner _ e -> ESum1Inner ext (compile e)
+  EUnit _ e -> EUnit ext (compile e)
+  EReplicate1Inner _ a b -> EReplicate1Inner ext (compile a) (compile b)
+  EMaximum1Inner _ e -> EMaximum1Inner ext (compile e)
+  EMinimum1Inner _ e -> EMinimum1Inner ext (compile e)
+  EConst _ t x -> EConst ext t x
+  EIdx0 _ e -> EIdx0 ext (compile e)
+  EIdx1 _ a b -> EIdx1 ext (compile a) (compile b)
+  EIdx _ a b -> EIdx ext (compile a) (compile b)
+  EShape _ e -> EShape ext (compile e)
+  EOp _ op e -> EOp ext op (compile e)
+  ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (compile a) (compile b) (compile c) (compile e1) (compile e2)
+  EWith a b -> EWith (compile a) (compile b)
+  EAccum n a b e -> EAccum n (compile a) (compile b) (compile e)
+  EZero t -> zero t
+  EPlus t a b -> plus t a b
+  EOneHot t i a b -> onehot t i a b
+  EError t s -> EError t s
+
+compose :: Foldable t => t (a -> a) -> a -> a
+compose = foldr (.) id
-- 
cgit v1.2.3-70-g09d2