summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2024-12-12 16:30:42 +0100
committerTom Smeding <t.j.smeding@uu.nl>2024-12-12 16:30:42 +0100
commitf323076ddf6fbea9f7a1a4dfeec98629459c49fc (patch)
treed19137df436ac388c79bffe55df4a1b9b29316a8
parentfad10d5a218f935d47e8b9dc41256a30b4ec540d (diff)
Somewhat working Compile
-rw-r--r--src/AST.hs19
-rw-r--r--src/Compile.hs392
2 files changed, 351 insertions, 60 deletions
diff --git a/src/AST.hs b/src/AST.hs
index fff290a..333f306 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -175,6 +175,25 @@ data SOp a t where
OIDiv :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
deriving instance Show (SOp a t)
+opt1 :: SOp a t -> STy a
+opt1 = \case
+ OAdd t -> STPair (STScal t) (STScal t)
+ OMul t -> STPair (STScal t) (STScal t)
+ ONeg t -> STScal t
+ OLt t -> STPair (STScal t) (STScal t)
+ OLe t -> STPair (STScal t) (STScal t)
+ OEq t -> STPair (STScal t) (STScal t)
+ ONot -> STScal STBool
+ OAnd -> STPair (STScal STBool) (STScal STBool)
+ OOr -> STPair (STScal STBool) (STScal STBool)
+ OIf -> STScal STBool
+ ORound64 -> STScal STF64
+ OToFl64 -> STScal STI64
+ ORecip t -> STScal t
+ OExp t -> STScal t
+ OLog t -> STScal t
+ OIDiv t -> STPair (STScal t) (STScal t)
+
opt2 :: SOp a t -> STy t
opt2 = \case
OAdd t -> STScal t
diff --git a/src/Compile.hs b/src/Compile.hs
index b65e643..83c25c3 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -1,11 +1,18 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE TypeApplications #-}
module Compile where
+import Control.Monad.Trans.State.Strict
+import Data.Foldable (toList)
+import Data.Functor.Const
+import Data.List (intersperse, intercalate)
import qualified Data.Map.Strict as Map
import Data.Map.Strict (Map)
import AST
+import AST.Pretty (ppTy)
import Data
@@ -13,12 +20,63 @@ import Data
data StructDecl = StructDecl
- String -- ^ name
- String -- ^ contents
+ String -- ^ name
+ String -- ^ contents
+ String -- ^ comment
+ deriving (Show)
+
+data Stmt
+ = SVarDecl String String CExpr -- ^ type, variable name, right-hand side
+ | SVarDeclUninit String String -- ^ type, variable name (no initialiser)
+ | SAsg String CExpr -- ^ variable name, right-hand side
+ | SBlock [Stmt]
+ | SIf CExpr [Stmt] [Stmt]
+ | SVerbatim String
+ deriving (Show)
+
+data CExpr
+ = CELit String
+ | CEStruct String [(String, CExpr)]
+ | CEProj CExpr String
+ | CECall String [CExpr]
+ | CEBinop CExpr String CExpr
+ | CEIf CExpr CExpr CExpr
+ deriving (Show)
printStructDecl :: StructDecl -> ShowS
-printStructDecl (StructDecl name contents) =
- showString "typedef struct { " . showString contents . showString " }" . showString name . showString ";\n"
+printStructDecl (StructDecl name contents comment) =
+ showString "typedef struct { " . showString contents . showString " } " . showString name
+ . showString ("; // " ++ comment)
+
+printStmt :: Int -> Stmt -> ShowS
+printStmt indent = \case
+ SVarDecl typ name rhs -> showString (typ ++ " " ++ name ++ " = ") . printCExpr rhs . showString ";"
+ SVarDeclUninit typ name -> showString (typ ++ " " ++ name ++ ";")
+ SAsg name rhs -> showString (name ++ " = ") . printCExpr rhs . showString ";"
+ SBlock stmts ->
+ showString "{"
+ . compose [showString ("\n" ++ replicate (2*indent+2) ' ') . printStmt (indent+1) stmt | stmt <- stmts]
+ . showString ("\n" ++ replicate (2*indent) ' ' ++ "}")
+ SIf cond b1 b2 ->
+ showString "if (" . printCExpr cond . showString ") "
+ . printStmt indent (SBlock b1) . showString " else " . printStmt indent (SBlock b2)
+ SVerbatim s -> showString s
+
+printCExpr :: CExpr -> ShowS
+printCExpr = \case
+ CELit s -> showString s
+ CEStruct name pairs ->
+ showString ("(" ++ name ++ "){")
+ . compose (intersperse (showString ", ") [showString ("." ++ n ++ " = ") . printCExpr e
+ | (n, e) <- pairs])
+ . showString "}"
+ CEProj e name -> showString "(" . printCExpr e . showString (")." ++ name)
+ CECall n es ->
+ showString (n ++ "(") . compose (intersperse (showString ", ") (map printCExpr es)) . showString ")"
+ CEBinop e1 n e2 ->
+ showString "(" . printCExpr e1 . showString (") " ++ n ++ " (") . printCExpr e2 . showString ")"
+ CEIf e1 e2 e3 ->
+ printCExpr e1 . showString " ? " . printCExpr e2 . showString " : " . printCExpr e3
repTy :: STy t -> String
repTy (STScal st) = case st of
@@ -31,12 +89,13 @@ repTy t = genStructName t
genStructName :: STy t -> String
genStructName = \t -> "ty_" ++ gen t where
+ -- all tags start with a letter, so the array mangling is unambiguous.
gen :: STy t -> String
gen STNil = "n"
gen (STPair a b) = 'p' : gen a ++ gen b
gen (STEither a b) = 'e' : gen a ++ gen b
gen (STMaybe t) = 'm' : gen t
- gen (STArr n t) = "A[" ++ show (fromSNat n) ++ "]" ++ gen t
+ gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t
gen (STScal st) = case st of
STI32 -> "i4"
STI64 -> "i8"
@@ -46,61 +105,274 @@ genStructName = \t -> "ty_" ++ gen t where
gen (STAccum t) = 'C' : gen t
genStruct :: STy t -> Map String StructDecl
-genStruct STNil =
- Map.singleton (genStructName STNil) (StructDecl (genStructName STNil) "")
-genStruct (STPair a b) =
- let name = genStructName (STPair a b)
- in Map.singleton name (StructDecl name (genStructName a ++ " a; " ++ genStructName b ++ " b;"))
-genStruct (STEither a b) =
- let name = genStructName (STEither a b)
- in Map.singleton name (StructDecl name ("uint8_t tag; union { " ++ genStructName a ++ " a; " ++ genStructName b ++ " b; };"))
-genStruct (STMaybe t) =
- let name = genStructName (STMaybe t)
- in Map.singleton name (StructDecl name ("uint8_t tag; " ++ genStructName t ++ " a;"))
-genStruct (STArr n t) =
- let name = genStructName (STArr n t)
- in Map.singleton name (StructDecl name ("size_t sh[" ++ show (fromSNat n) ++ "]; " ++ genStructName t ++ " *a;"))
-genStruct (STScal _) = mempty
-genStruct (STAccum t) =
- let name = genStructName (STAccum t)
- in Map.singleton name (StructDecl name (genStructName t ++ " a;"))
- <> genStruct t
-
-compile :: Ex env t -> (Map String StructDecl, ())
-compile = \case
- EVar _ _ _ -> mempty
- ELet _ rhs body -> compile rhs <> compile body
- EPair _ a b -> genStruct (STPair (typeOf a) (typeOf b)) <> compile a <> compile b
- EFst _ e -> compile e
- ESnd _ e -> compile e
- ENil _ -> mempty
- EInl _ t e -> genStruct (STEither (typeOf e) t) <> compile e
- EInr _ t e -> genStruct (STEither t (typeOf e)) <> compile e
- ECase _ e a b -> compile e <> compile a <> compile b
- ENothing _ _ -> mempty
- EJust _ e -> compile e
- EMaybe _ a b e -> compile a <> compile b <> compile e
- EConstArr _ n t _ -> genStruct (STArr n (STScal t))
- EBuild _ n a b -> genStruct (STArr n (typeOf b)) <> EBuild ext n (compile a) (compile b)
- EFold1Inner _ a b c -> EFold1Inner ext (compile a) (compile b) (compile c)
- ESum1Inner _ e -> ESum1Inner ext (compile e)
- EUnit _ e -> EUnit ext (compile e)
- EReplicate1Inner _ a b -> EReplicate1Inner ext (compile a) (compile b)
- EMaximum1Inner _ e -> EMaximum1Inner ext (compile e)
- EMinimum1Inner _ e -> EMinimum1Inner ext (compile e)
- EConst _ t x -> EConst ext t x
- EIdx0 _ e -> EIdx0 ext (compile e)
- EIdx1 _ a b -> EIdx1 ext (compile a) (compile b)
- EIdx _ a b -> EIdx ext (compile a) (compile b)
- EShape _ e -> EShape ext (compile e)
- EOp _ op e -> EOp ext op (compile e)
- ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (compile a) (compile b) (compile c) (compile e1) (compile e2)
- EWith a b -> EWith (compile a) (compile b)
- EAccum n a b e -> EAccum n (compile a) (compile b) (compile e)
- EZero t -> zero t
- EPlus t a b -> plus t a b
- EOneHot t i a b -> onehot t i a b
- EError t s -> EError t s
+genStruct topty = case topty of
+ STNil ->
+ Map.singleton (genStructName STNil) (StructDecl (genStructName STNil) "" com)
+ STPair a b ->
+ let name = genStructName (STPair a b)
+ in Map.singleton name (StructDecl name (genStructName a ++ " a; " ++ genStructName b ++ " b;") com)
+ STEither a b ->
+ let name = genStructName (STEither a b) -- 0 -> a, 1 -> b
+ in Map.singleton name (StructDecl name ("uint8_t tag; union { " ++ genStructName a ++ " a; " ++ genStructName b ++ " b; };") com)
+ STMaybe t ->
+ let name = genStructName (STMaybe t) -- 0 -> nothing, 1 -> just
+ in Map.singleton name (StructDecl name ("uint8_t tag; " ++ genStructName t ++ " a;") com)
+ STArr n t ->
+ let name = genStructName (STArr n t)
+ in Map.singleton name (StructDecl name ("size_t sh[" ++ show (fromSNat n) ++ "]; " ++ genStructName t ++ " *a;") com)
+ STScal _ -> mempty
+ STAccum t ->
+ let name = genStructName (STAccum t)
+ in Map.singleton name (StructDecl name (genStructName t ++ " a;") com)
+ <> genStruct t
+ where
+ com = ppTy 0 topty
+
+data CompState = CompState
+ { csStructs :: Map String StructDecl
+ , csStmts :: Bag Stmt
+ , csNextId :: Int }
+ deriving (Show)
+
+type CompM a = State CompState a
+
+genId :: CompM Int
+genId = state $ \s -> (csNextId s, s { csNextId = csNextId s + 1 })
+
+genName :: CompM String
+genName = ('x' :) . show <$> genId
+
+emit :: Stmt -> CompM ()
+emit stmt = modify $ \s -> s { csStmts = csStmts s <> pure stmt }
+
+scope :: CompM a -> CompM (a, [Stmt])
+scope m = do
+ stmts <- state $ \s -> (csStmts s, s { csStmts = mempty })
+ res <- m
+ innerStmts <- state $ \s -> (csStmts s, s { csStmts = stmts })
+ return (res, toList innerStmts)
+
+emitStruct :: STy t -> CompM String
+emitStruct ty = do
+ modify $ \s -> s { csStructs = genStruct ty <> csStructs s }
+ return (genStructName ty)
+
+compile :: SList (Const String) env -> Ex env t -> String
+compile env expr =
+ let (res, s) = runState (compile' env expr) (CompState mempty mempty 1)
+ in ($ "") $ compose
+ [compose $ map (\sd -> printStructDecl sd . showString "\n") (Map.elems (csStructs s))
+ ,showString "\n"
+ ,showString (genStructName (typeOf expr) ++ " kernel(" ++ intercalate ", " (reverse (unSList getConst env)) ++ ") {\n")
+ ,compose $ map (\st -> showString " " . printStmt 1 st . showString "\n") (toList (csStmts s))
+ ,showString (" return ") . printCExpr res . showString ";\n}\n"]
+
+compile' :: SList (Const String) env -> Ex env t -> CompM CExpr
+compile' env = \case
+ EVar _ _ i -> return $ CELit (getConst (slistIdx env i))
+
+ ELet _ rhs body -> do
+ e <- compile' env rhs
+ var <- genName
+ emit $ SVarDecl (genStructName (typeOf rhs)) var e
+ compile' (Const var `SCons` env) body
+
+ EPair _ a b -> do
+ name <- emitStruct (STPair (typeOf a) (typeOf b))
+ e1 <- compile' env a
+ e2 <- compile' env b
+ return $ CEStruct name [("a", e1), ("b", e2)]
+
+ EFst _ e -> CEProj <$> compile' env e <*> pure "a"
+
+ ESnd _ e -> CEProj <$> compile' env e <*> pure "b"
+
+ ENil _ -> do
+ name <- emitStruct STNil
+ return $ CEStruct name []
+
+ EInl _ t e -> do
+ name <- emitStruct (STEither (typeOf e) t)
+ e1 <- compile' env e
+ return $ CEStruct name [("tag", CELit "0"), ("a", e1)]
+
+ EInr _ t e -> do
+ name <- emitStruct (STEither t (typeOf e))
+ e2 <- compile' env e
+ return $ CEStruct name [("tag", CELit "1"), ("b", e2)]
+
+ ECase _ e a b -> do
+ let STEither t1 t2 = typeOf e
+ e1 <- compile' env e
+ var <- genName
+ fieldvar <- genName
+ (e2, stmts2) <- scope $ compile' (Const fieldvar `SCons` env) a
+ (e3, stmts3) <- scope $ compile' (Const fieldvar `SCons` env) b
+ retvar <- genName
+ emit $ SVarDeclUninit (genStructName (typeOf a)) retvar
+ emit $ SBlock (pure (SVarDecl (genStructName (typeOf e)) var e1)
+ <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0"))
+ (pure (SVarDecl (genStructName t1) fieldvar
+ (CEProj (CELit var) "a"))
+ <> stmts2
+ <> pure (SAsg retvar e2))
+ (pure (SVarDecl (genStructName t2) fieldvar
+ (CEProj (CELit var) "b"))
+ <> stmts3
+ <> pure (SAsg retvar e3))))
+ return (CELit retvar)
+
+ ENothing _ t -> do
+ name <- emitStruct (STMaybe t)
+ return $ CEStruct name [("tag", CELit "0")]
+
+ EJust _ e -> do
+ name <- emitStruct (STMaybe (typeOf e))
+ e1 <- compile' env e
+ return $ CEStruct name [("tag", CELit "1"), ("a", e1)]
+
+ EMaybe _ a b e -> do
+ e1 <- compile' env e
+ var <- genName
+ fieldvar <- genName
+ (e2, stmts2) <- scope $ compile' env a
+ (e3, stmts3) <- scope $ compile' (Const fieldvar `SCons` env) b
+ retvar <- genName
+ emit $ SVarDeclUninit (genStructName (typeOf a)) retvar
+ emit $ SBlock (pure (SVarDecl (genStructName (typeOf e)) var e1)
+ <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0"))
+ (stmts2
+ <> pure (SAsg retvar e2))
+ (pure (SVarDecl (genStructName (typeOf b)) fieldvar
+ (CEProj (CELit var) "a"))
+ <> stmts3
+ <> pure (SAsg retvar e3))))
+ return (CELit retvar)
+
+ EConstArr _ n t arr -> do
+ name <- emitStruct (STArr n (STScal t))
+ error "TODO"
+
+ EBuild _ n a b -> error "TODO" -- genStruct (STArr n (typeOf b)) <> EBuild ext n (compile' a) (compile' b)
+
+ EFold1Inner _ a b c -> error "TODO" -- EFold1Inner ext (compile' a) (compile' b) (compile' c)
+
+ ESum1Inner _ e -> error "TODO" -- ESum1Inner ext (compile' e)
+
+ EUnit _ e -> error "TODO" -- EUnit ext (compile' e)
+
+ EReplicate1Inner _ a b -> error "TODO" -- EReplicate1Inner ext (compile' a) (compile' b)
+
+ EMaximum1Inner _ e -> error "TODO" -- EMaximum1Inner ext (compile' e)
+
+ EMinimum1Inner _ e -> error "TODO" -- EMinimum1Inner ext (compile' e)
+
+ EConst _ t x -> case t of
+ STI32 -> return $ CELit $ "(int32_t)" ++ show x
+ STI64 -> return $ CELit $ "(int64_t)" ++ show x
+ STF32 -> return $ CELit $ show x ++ "f"
+ STF64 -> return $ CELit $ show x
+ STBool -> return $ CELit $ if x then "true" else "false"
+
+ EIdx0 _ e -> error "TODO" -- EIdx0 ext (compile' e)
+
+ EIdx1 _ a b -> error "TODO" -- EIdx1 ext (compile' a) (compile' b)
+
+ EIdx _ a b -> error "TODO" -- EIdx ext (compile' a) (compile' b)
+
+ EShape _ e -> error "TODO" -- EShape ext (compile' e)
+
+ EOp _ op e -> do
+ e1 <- compile' env e
+ let unary cop = return @(State CompState) $ CECall cop [e1]
+ let binary cop = do
+ name <- genName
+ emit $ SVarDecl (genStructName (typeOf e)) name e1
+ return $ CEBinop (CEProj (CELit name) "a") cop (CEProj (CELit name) "b")
+ case op of
+ OAdd _ -> binary "+"
+ OMul _ -> binary "*"
+ ONeg _ -> unary "-"
+ OLt _ -> binary "<"
+ OLe _ -> binary "<="
+ OEq _ -> binary "=="
+ ONot -> unary "!"
+ OAnd -> binary "&&"
+ OOr -> binary "||"
+ OIf -> do
+ name <- emitStruct (STEither STNil STNil)
+ _ <- emitStruct STNil
+ return $ CEIf e1 (CEStruct name [("tag", CELit "0")])
+ (CEStruct name [("tag", CELit "1")])
+ ORound64 -> unary "(int64_t)round" -- ew
+ OToFl64 -> unary "(double)"
+ ORecip _ -> return $ CEBinop (CELit "1.0") "/" e1
+ OExp STF32 -> unary "expf"
+ OExp STF64 -> unary "exp"
+ OLog STF32 -> unary "logf"
+ OLog STF64 -> unary "log"
+ OIDiv _ -> binary "/"
+
+ ECustom _ t1 t2 t3 a b c e1 e2 -> error "TODO" -- ECustom ext t1 t2 t3 (compile' a) (compile' b) (compile' c) (compile' e1) (compile' e2)
+
+ EWith a b -> error "TODO" -- EWith (compile' a) (compile' b)
+
+ EAccum n a b e -> error "TODO" -- EAccum n (compile' a) (compile' b) (compile' e)
+
+ EError t s -> do
+ name <- emitStruct t
+ -- using 'show' here is wrong, but it's good enough for me.
+ emit $ SVerbatim $ "fprintf(stderr, \"ERROR: %s\\n\", " ++ show s ++ "); exit(1);"
+ return $ CEStruct name []
+
+ EZero{} -> error "Compile: monoid operations should have been eliminated"
+ EPlus{} -> error "Compile: monoid operations should have been eliminated"
+ EOneHot{} -> error "Compile: monoid operations should have been eliminated"
+
+compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr
+compileOpGeneral op e1 = do
+ let unary cop = return @(State CompState) $ CECall cop [e1]
+ let binary cop = do
+ name <- genName
+ emit $ SVarDecl (genStructName (opt1 op)) name e1
+ return $ CEBinop (CEProj (CELit name) "a") cop (CEProj (CELit name) "b")
+ case op of
+ OAdd _ -> binary "+"
+ OMul _ -> binary "*"
+ ONeg _ -> unary "-"
+ OLt _ -> binary "<"
+ OLe _ -> binary "<="
+ OEq _ -> binary "=="
+ ONot -> unary "!"
+ OAnd -> binary "&&"
+ OOr -> binary "||"
+ OIf -> do
+ name <- emitStruct (STEither STNil STNil)
+ _ <- emitStruct STNil
+ return $ CEIf e1 (CEStruct name [("tag", CELit "0")])
+ (CEStruct name [("tag", CELit "1")])
+ ORound64 -> unary "(int64_t)round" -- ew
+ OToFl64 -> unary "(double)"
+ ORecip _ -> return $ CEBinop (CELit "1.0") "/" e1
+ OExp STF32 -> unary "expf"
+ OExp STF64 -> unary "exp"
+ OLog STF32 -> unary "logf"
+ OLog STF64 -> unary "log"
+ OIDiv _ -> binary "/"
+
+compileOpPair :: SOp a b -> CExpr -> CExpr -> CompM CExpr
+compileOpPair op e1 e2 = do
+ let binary cop = return @(State CompState) $ CEBinop e1 cop e2
+ case op of
+ OAdd _ -> binary "+"
+ OMul _ -> binary "*"
+ OLt _ -> binary "<"
+ OLe _ -> binary "<="
+ OEq _ -> binary "=="
+ OAnd -> binary "&&"
+ OOr -> binary "||"
+ OIDiv _ -> binary "/"
+ _ -> error "compileOpPair: got unary operator"
compose :: Foldable t => t (a -> a) -> a -> a
compose = foldr (.) id