diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2024-12-12 16:30:42 +0100 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2024-12-12 16:30:42 +0100 |
commit | f323076ddf6fbea9f7a1a4dfeec98629459c49fc (patch) | |
tree | d19137df436ac388c79bffe55df4a1b9b29316a8 | |
parent | fad10d5a218f935d47e8b9dc41256a30b4ec540d (diff) |
Somewhat working Compile
-rw-r--r-- | src/AST.hs | 19 | ||||
-rw-r--r-- | src/Compile.hs | 392 |
2 files changed, 351 insertions, 60 deletions
@@ -175,6 +175,25 @@ data SOp a t where OIDiv :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) deriving instance Show (SOp a t) +opt1 :: SOp a t -> STy a +opt1 = \case + OAdd t -> STPair (STScal t) (STScal t) + OMul t -> STPair (STScal t) (STScal t) + ONeg t -> STScal t + OLt t -> STPair (STScal t) (STScal t) + OLe t -> STPair (STScal t) (STScal t) + OEq t -> STPair (STScal t) (STScal t) + ONot -> STScal STBool + OAnd -> STPair (STScal STBool) (STScal STBool) + OOr -> STPair (STScal STBool) (STScal STBool) + OIf -> STScal STBool + ORound64 -> STScal STF64 + OToFl64 -> STScal STI64 + ORecip t -> STScal t + OExp t -> STScal t + OLog t -> STScal t + OIDiv t -> STPair (STScal t) (STScal t) + opt2 :: SOp a t -> STy t opt2 = \case OAdd t -> STScal t diff --git a/src/Compile.hs b/src/Compile.hs index b65e643..83c25c3 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -1,11 +1,18 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE TypeApplications #-} module Compile where +import Control.Monad.Trans.State.Strict +import Data.Foldable (toList) +import Data.Functor.Const +import Data.List (intersperse, intercalate) import qualified Data.Map.Strict as Map import Data.Map.Strict (Map) import AST +import AST.Pretty (ppTy) import Data @@ -13,12 +20,63 @@ import Data data StructDecl = StructDecl - String -- ^ name - String -- ^ contents + String -- ^ name + String -- ^ contents + String -- ^ comment + deriving (Show) + +data Stmt + = SVarDecl String String CExpr -- ^ type, variable name, right-hand side + | SVarDeclUninit String String -- ^ type, variable name (no initialiser) + | SAsg String CExpr -- ^ variable name, right-hand side + | SBlock [Stmt] + | SIf CExpr [Stmt] [Stmt] + | SVerbatim String + deriving (Show) + +data CExpr + = CELit String + | CEStruct String [(String, CExpr)] + | CEProj CExpr String + | CECall String [CExpr] + | CEBinop CExpr String CExpr + | CEIf CExpr CExpr CExpr + deriving (Show) printStructDecl :: StructDecl -> ShowS -printStructDecl (StructDecl name contents) = - showString "typedef struct { " . showString contents . showString " }" . showString name . showString ";\n" +printStructDecl (StructDecl name contents comment) = + showString "typedef struct { " . showString contents . showString " } " . showString name + . showString ("; // " ++ comment) + +printStmt :: Int -> Stmt -> ShowS +printStmt indent = \case + SVarDecl typ name rhs -> showString (typ ++ " " ++ name ++ " = ") . printCExpr rhs . showString ";" + SVarDeclUninit typ name -> showString (typ ++ " " ++ name ++ ";") + SAsg name rhs -> showString (name ++ " = ") . printCExpr rhs . showString ";" + SBlock stmts -> + showString "{" + . compose [showString ("\n" ++ replicate (2*indent+2) ' ') . printStmt (indent+1) stmt | stmt <- stmts] + . showString ("\n" ++ replicate (2*indent) ' ' ++ "}") + SIf cond b1 b2 -> + showString "if (" . printCExpr cond . showString ") " + . printStmt indent (SBlock b1) . showString " else " . printStmt indent (SBlock b2) + SVerbatim s -> showString s + +printCExpr :: CExpr -> ShowS +printCExpr = \case + CELit s -> showString s + CEStruct name pairs -> + showString ("(" ++ name ++ "){") + . compose (intersperse (showString ", ") [showString ("." ++ n ++ " = ") . printCExpr e + | (n, e) <- pairs]) + . showString "}" + CEProj e name -> showString "(" . printCExpr e . showString (")." ++ name) + CECall n es -> + showString (n ++ "(") . compose (intersperse (showString ", ") (map printCExpr es)) . showString ")" + CEBinop e1 n e2 -> + showString "(" . printCExpr e1 . showString (") " ++ n ++ " (") . printCExpr e2 . showString ")" + CEIf e1 e2 e3 -> + printCExpr e1 . showString " ? " . printCExpr e2 . showString " : " . printCExpr e3 repTy :: STy t -> String repTy (STScal st) = case st of @@ -31,12 +89,13 @@ repTy t = genStructName t genStructName :: STy t -> String genStructName = \t -> "ty_" ++ gen t where + -- all tags start with a letter, so the array mangling is unambiguous. gen :: STy t -> String gen STNil = "n" gen (STPair a b) = 'p' : gen a ++ gen b gen (STEither a b) = 'e' : gen a ++ gen b gen (STMaybe t) = 'm' : gen t - gen (STArr n t) = "A[" ++ show (fromSNat n) ++ "]" ++ gen t + gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t gen (STScal st) = case st of STI32 -> "i4" STI64 -> "i8" @@ -46,61 +105,274 @@ genStructName = \t -> "ty_" ++ gen t where gen (STAccum t) = 'C' : gen t genStruct :: STy t -> Map String StructDecl -genStruct STNil = - Map.singleton (genStructName STNil) (StructDecl (genStructName STNil) "") -genStruct (STPair a b) = - let name = genStructName (STPair a b) - in Map.singleton name (StructDecl name (genStructName a ++ " a; " ++ genStructName b ++ " b;")) -genStruct (STEither a b) = - let name = genStructName (STEither a b) - in Map.singleton name (StructDecl name ("uint8_t tag; union { " ++ genStructName a ++ " a; " ++ genStructName b ++ " b; };")) -genStruct (STMaybe t) = - let name = genStructName (STMaybe t) - in Map.singleton name (StructDecl name ("uint8_t tag; " ++ genStructName t ++ " a;")) -genStruct (STArr n t) = - let name = genStructName (STArr n t) - in Map.singleton name (StructDecl name ("size_t sh[" ++ show (fromSNat n) ++ "]; " ++ genStructName t ++ " *a;")) -genStruct (STScal _) = mempty -genStruct (STAccum t) = - let name = genStructName (STAccum t) - in Map.singleton name (StructDecl name (genStructName t ++ " a;")) - <> genStruct t - -compile :: Ex env t -> (Map String StructDecl, ()) -compile = \case - EVar _ _ _ -> mempty - ELet _ rhs body -> compile rhs <> compile body - EPair _ a b -> genStruct (STPair (typeOf a) (typeOf b)) <> compile a <> compile b - EFst _ e -> compile e - ESnd _ e -> compile e - ENil _ -> mempty - EInl _ t e -> genStruct (STEither (typeOf e) t) <> compile e - EInr _ t e -> genStruct (STEither t (typeOf e)) <> compile e - ECase _ e a b -> compile e <> compile a <> compile b - ENothing _ _ -> mempty - EJust _ e -> compile e - EMaybe _ a b e -> compile a <> compile b <> compile e - EConstArr _ n t _ -> genStruct (STArr n (STScal t)) - EBuild _ n a b -> genStruct (STArr n (typeOf b)) <> EBuild ext n (compile a) (compile b) - EFold1Inner _ a b c -> EFold1Inner ext (compile a) (compile b) (compile c) - ESum1Inner _ e -> ESum1Inner ext (compile e) - EUnit _ e -> EUnit ext (compile e) - EReplicate1Inner _ a b -> EReplicate1Inner ext (compile a) (compile b) - EMaximum1Inner _ e -> EMaximum1Inner ext (compile e) - EMinimum1Inner _ e -> EMinimum1Inner ext (compile e) - EConst _ t x -> EConst ext t x - EIdx0 _ e -> EIdx0 ext (compile e) - EIdx1 _ a b -> EIdx1 ext (compile a) (compile b) - EIdx _ a b -> EIdx ext (compile a) (compile b) - EShape _ e -> EShape ext (compile e) - EOp _ op e -> EOp ext op (compile e) - ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (compile a) (compile b) (compile c) (compile e1) (compile e2) - EWith a b -> EWith (compile a) (compile b) - EAccum n a b e -> EAccum n (compile a) (compile b) (compile e) - EZero t -> zero t - EPlus t a b -> plus t a b - EOneHot t i a b -> onehot t i a b - EError t s -> EError t s +genStruct topty = case topty of + STNil -> + Map.singleton (genStructName STNil) (StructDecl (genStructName STNil) "" com) + STPair a b -> + let name = genStructName (STPair a b) + in Map.singleton name (StructDecl name (genStructName a ++ " a; " ++ genStructName b ++ " b;") com) + STEither a b -> + let name = genStructName (STEither a b) -- 0 -> a, 1 -> b + in Map.singleton name (StructDecl name ("uint8_t tag; union { " ++ genStructName a ++ " a; " ++ genStructName b ++ " b; };") com) + STMaybe t -> + let name = genStructName (STMaybe t) -- 0 -> nothing, 1 -> just + in Map.singleton name (StructDecl name ("uint8_t tag; " ++ genStructName t ++ " a;") com) + STArr n t -> + let name = genStructName (STArr n t) + in Map.singleton name (StructDecl name ("size_t sh[" ++ show (fromSNat n) ++ "]; " ++ genStructName t ++ " *a;") com) + STScal _ -> mempty + STAccum t -> + let name = genStructName (STAccum t) + in Map.singleton name (StructDecl name (genStructName t ++ " a;") com) + <> genStruct t + where + com = ppTy 0 topty + +data CompState = CompState + { csStructs :: Map String StructDecl + , csStmts :: Bag Stmt + , csNextId :: Int } + deriving (Show) + +type CompM a = State CompState a + +genId :: CompM Int +genId = state $ \s -> (csNextId s, s { csNextId = csNextId s + 1 }) + +genName :: CompM String +genName = ('x' :) . show <$> genId + +emit :: Stmt -> CompM () +emit stmt = modify $ \s -> s { csStmts = csStmts s <> pure stmt } + +scope :: CompM a -> CompM (a, [Stmt]) +scope m = do + stmts <- state $ \s -> (csStmts s, s { csStmts = mempty }) + res <- m + innerStmts <- state $ \s -> (csStmts s, s { csStmts = stmts }) + return (res, toList innerStmts) + +emitStruct :: STy t -> CompM String +emitStruct ty = do + modify $ \s -> s { csStructs = genStruct ty <> csStructs s } + return (genStructName ty) + +compile :: SList (Const String) env -> Ex env t -> String +compile env expr = + let (res, s) = runState (compile' env expr) (CompState mempty mempty 1) + in ($ "") $ compose + [compose $ map (\sd -> printStructDecl sd . showString "\n") (Map.elems (csStructs s)) + ,showString "\n" + ,showString (genStructName (typeOf expr) ++ " kernel(" ++ intercalate ", " (reverse (unSList getConst env)) ++ ") {\n") + ,compose $ map (\st -> showString " " . printStmt 1 st . showString "\n") (toList (csStmts s)) + ,showString (" return ") . printCExpr res . showString ";\n}\n"] + +compile' :: SList (Const String) env -> Ex env t -> CompM CExpr +compile' env = \case + EVar _ _ i -> return $ CELit (getConst (slistIdx env i)) + + ELet _ rhs body -> do + e <- compile' env rhs + var <- genName + emit $ SVarDecl (genStructName (typeOf rhs)) var e + compile' (Const var `SCons` env) body + + EPair _ a b -> do + name <- emitStruct (STPair (typeOf a) (typeOf b)) + e1 <- compile' env a + e2 <- compile' env b + return $ CEStruct name [("a", e1), ("b", e2)] + + EFst _ e -> CEProj <$> compile' env e <*> pure "a" + + ESnd _ e -> CEProj <$> compile' env e <*> pure "b" + + ENil _ -> do + name <- emitStruct STNil + return $ CEStruct name [] + + EInl _ t e -> do + name <- emitStruct (STEither (typeOf e) t) + e1 <- compile' env e + return $ CEStruct name [("tag", CELit "0"), ("a", e1)] + + EInr _ t e -> do + name <- emitStruct (STEither t (typeOf e)) + e2 <- compile' env e + return $ CEStruct name [("tag", CELit "1"), ("b", e2)] + + ECase _ e a b -> do + let STEither t1 t2 = typeOf e + e1 <- compile' env e + var <- genName + fieldvar <- genName + (e2, stmts2) <- scope $ compile' (Const fieldvar `SCons` env) a + (e3, stmts3) <- scope $ compile' (Const fieldvar `SCons` env) b + retvar <- genName + emit $ SVarDeclUninit (genStructName (typeOf a)) retvar + emit $ SBlock (pure (SVarDecl (genStructName (typeOf e)) var e1) + <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) + (pure (SVarDecl (genStructName t1) fieldvar + (CEProj (CELit var) "a")) + <> stmts2 + <> pure (SAsg retvar e2)) + (pure (SVarDecl (genStructName t2) fieldvar + (CEProj (CELit var) "b")) + <> stmts3 + <> pure (SAsg retvar e3)))) + return (CELit retvar) + + ENothing _ t -> do + name <- emitStruct (STMaybe t) + return $ CEStruct name [("tag", CELit "0")] + + EJust _ e -> do + name <- emitStruct (STMaybe (typeOf e)) + e1 <- compile' env e + return $ CEStruct name [("tag", CELit "1"), ("a", e1)] + + EMaybe _ a b e -> do + e1 <- compile' env e + var <- genName + fieldvar <- genName + (e2, stmts2) <- scope $ compile' env a + (e3, stmts3) <- scope $ compile' (Const fieldvar `SCons` env) b + retvar <- genName + emit $ SVarDeclUninit (genStructName (typeOf a)) retvar + emit $ SBlock (pure (SVarDecl (genStructName (typeOf e)) var e1) + <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) + (stmts2 + <> pure (SAsg retvar e2)) + (pure (SVarDecl (genStructName (typeOf b)) fieldvar + (CEProj (CELit var) "a")) + <> stmts3 + <> pure (SAsg retvar e3)))) + return (CELit retvar) + + EConstArr _ n t arr -> do + name <- emitStruct (STArr n (STScal t)) + error "TODO" + + EBuild _ n a b -> error "TODO" -- genStruct (STArr n (typeOf b)) <> EBuild ext n (compile' a) (compile' b) + + EFold1Inner _ a b c -> error "TODO" -- EFold1Inner ext (compile' a) (compile' b) (compile' c) + + ESum1Inner _ e -> error "TODO" -- ESum1Inner ext (compile' e) + + EUnit _ e -> error "TODO" -- EUnit ext (compile' e) + + EReplicate1Inner _ a b -> error "TODO" -- EReplicate1Inner ext (compile' a) (compile' b) + + EMaximum1Inner _ e -> error "TODO" -- EMaximum1Inner ext (compile' e) + + EMinimum1Inner _ e -> error "TODO" -- EMinimum1Inner ext (compile' e) + + EConst _ t x -> case t of + STI32 -> return $ CELit $ "(int32_t)" ++ show x + STI64 -> return $ CELit $ "(int64_t)" ++ show x + STF32 -> return $ CELit $ show x ++ "f" + STF64 -> return $ CELit $ show x + STBool -> return $ CELit $ if x then "true" else "false" + + EIdx0 _ e -> error "TODO" -- EIdx0 ext (compile' e) + + EIdx1 _ a b -> error "TODO" -- EIdx1 ext (compile' a) (compile' b) + + EIdx _ a b -> error "TODO" -- EIdx ext (compile' a) (compile' b) + + EShape _ e -> error "TODO" -- EShape ext (compile' e) + + EOp _ op e -> do + e1 <- compile' env e + let unary cop = return @(State CompState) $ CECall cop [e1] + let binary cop = do + name <- genName + emit $ SVarDecl (genStructName (typeOf e)) name e1 + return $ CEBinop (CEProj (CELit name) "a") cop (CEProj (CELit name) "b") + case op of + OAdd _ -> binary "+" + OMul _ -> binary "*" + ONeg _ -> unary "-" + OLt _ -> binary "<" + OLe _ -> binary "<=" + OEq _ -> binary "==" + ONot -> unary "!" + OAnd -> binary "&&" + OOr -> binary "||" + OIf -> do + name <- emitStruct (STEither STNil STNil) + _ <- emitStruct STNil + return $ CEIf e1 (CEStruct name [("tag", CELit "0")]) + (CEStruct name [("tag", CELit "1")]) + ORound64 -> unary "(int64_t)round" -- ew + OToFl64 -> unary "(double)" + ORecip _ -> return $ CEBinop (CELit "1.0") "/" e1 + OExp STF32 -> unary "expf" + OExp STF64 -> unary "exp" + OLog STF32 -> unary "logf" + OLog STF64 -> unary "log" + OIDiv _ -> binary "/" + + ECustom _ t1 t2 t3 a b c e1 e2 -> error "TODO" -- ECustom ext t1 t2 t3 (compile' a) (compile' b) (compile' c) (compile' e1) (compile' e2) + + EWith a b -> error "TODO" -- EWith (compile' a) (compile' b) + + EAccum n a b e -> error "TODO" -- EAccum n (compile' a) (compile' b) (compile' e) + + EError t s -> do + name <- emitStruct t + -- using 'show' here is wrong, but it's good enough for me. + emit $ SVerbatim $ "fprintf(stderr, \"ERROR: %s\\n\", " ++ show s ++ "); exit(1);" + return $ CEStruct name [] + + EZero{} -> error "Compile: monoid operations should have been eliminated" + EPlus{} -> error "Compile: monoid operations should have been eliminated" + EOneHot{} -> error "Compile: monoid operations should have been eliminated" + +compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr +compileOpGeneral op e1 = do + let unary cop = return @(State CompState) $ CECall cop [e1] + let binary cop = do + name <- genName + emit $ SVarDecl (genStructName (opt1 op)) name e1 + return $ CEBinop (CEProj (CELit name) "a") cop (CEProj (CELit name) "b") + case op of + OAdd _ -> binary "+" + OMul _ -> binary "*" + ONeg _ -> unary "-" + OLt _ -> binary "<" + OLe _ -> binary "<=" + OEq _ -> binary "==" + ONot -> unary "!" + OAnd -> binary "&&" + OOr -> binary "||" + OIf -> do + name <- emitStruct (STEither STNil STNil) + _ <- emitStruct STNil + return $ CEIf e1 (CEStruct name [("tag", CELit "0")]) + (CEStruct name [("tag", CELit "1")]) + ORound64 -> unary "(int64_t)round" -- ew + OToFl64 -> unary "(double)" + ORecip _ -> return $ CEBinop (CELit "1.0") "/" e1 + OExp STF32 -> unary "expf" + OExp STF64 -> unary "exp" + OLog STF32 -> unary "logf" + OLog STF64 -> unary "log" + OIDiv _ -> binary "/" + +compileOpPair :: SOp a b -> CExpr -> CExpr -> CompM CExpr +compileOpPair op e1 e2 = do + let binary cop = return @(State CompState) $ CEBinop e1 cop e2 + case op of + OAdd _ -> binary "+" + OMul _ -> binary "*" + OLt _ -> binary "<" + OLe _ -> binary "<=" + OEq _ -> binary "==" + OAnd -> binary "&&" + OOr -> binary "||" + OIDiv _ -> binary "/" + _ -> error "compileOpPair: got unary operator" compose :: Foldable t => t (a -> a) -> a -> a compose = foldr (.) id |