diff options
Diffstat (limited to 'src/Compile.hs')
-rw-r--r-- | src/Compile.hs | 207 |
1 files changed, 132 insertions, 75 deletions
diff --git a/src/Compile.hs b/src/Compile.hs index 05d51c1..95004b8 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -6,12 +6,14 @@ module Compile where import Control.Monad.Trans.State.Strict +import Data.Bifunctor (first, second) import Data.Foldable (toList) import Data.Functor.Const import qualified Data.Functor.Product as Product import Data.List (intersperse, intercalate) import qualified Data.Map.Strict as Map -import Data.Map.Strict (Map) +import qualified Data.Set as Set +import Data.Set (Set) import AST import AST.Pretty (ppTy) @@ -52,86 +54,139 @@ printStructDecl (StructDecl name contents comment) = printStmt :: Int -> Stmt -> ShowS printStmt indent = \case - SVarDecl cnst typ name rhs -> showString ((if cnst then "const " else "") ++ typ ++ " " ++ name ++ " = ") . printCExpr rhs . showString ";" + SVarDecl cnst typ name rhs -> showString ((if cnst then "const " else "") ++ typ ++ " " ++ name ++ " = ") . printCExpr 0 rhs . showString ";" SVarDeclUninit typ name -> showString (typ ++ " " ++ name ++ ";") - SAsg name rhs -> showString (name ++ " = ") . printCExpr rhs . showString ";" + SAsg name rhs -> showString (name ++ " = ") . printCExpr 0 rhs . showString ";" SBlock stmts -> showString "{" . compose [showString ("\n" ++ replicate (2*indent+2) ' ') . printStmt (indent+1) stmt | stmt <- stmts] . showString ("\n" ++ replicate (2*indent) ' ' ++ "}") SIf cond b1 b2 -> - showString "if (" . printCExpr cond . showString ") " + showString "if (" . printCExpr 0 cond . showString ") " . printStmt indent (SBlock b1) . showString " else " . printStmt indent (SBlock b2) SVerbatim s -> showString s -printCExpr :: CExpr -> ShowS -printCExpr = \case +-- d values: +-- * 0: top level +-- * 1: in 1st or 2nd component of a ternary operator (technically same as top level, but readability) +-- * 2-...: various operators (see precTable) +-- * 98: inside unknown operator +-- * 99: left of a field projection +-- Unlisted operators are conservatively written with full parentheses. +printCExpr :: Int -> CExpr -> ShowS +printCExpr d = \case CELit s -> showString s CEStruct name pairs -> - showString ("(" ++ name ++ "){") - . compose (intersperse (showString ", ") [showString ("." ++ n ++ " = ") . printCExpr e - | (n, e) <- pairs]) - . showString "}" - CEProj e name -> showString "(" . printCExpr e . showString (")." ++ name) + showParen (d >= 99) $ + showString ("(" ++ name ++ "){") + . compose (intersperse (showString ", ") [showString ("." ++ n ++ " = ") . printCExpr 0 e + | (n, e) <- pairs]) + . showString "}" + CEProj e name -> printCExpr 99 e . showString ("." ++ name) CECall n es -> - showString (n ++ "(") . compose (intersperse (showString ", ") (map printCExpr es)) . showString ")" + showString (n ++ "(") . compose (intersperse (showString ", ") (map (printCExpr 0) es)) . showString ")" CEBinop e1 n e2 -> - showString "(" . printCExpr e1 . showString (") " ++ n ++ " (") . printCExpr e2 . showString ")" + let mprec = Map.lookup n precTable + p = maybe (-1) fst mprec -- precedence of this operator + (d1, d2) = maybe (98, 98) snd mprec -- precedences for the arguments + in showParen (d > p) $ + printCExpr d1 e1 . showString (" " ++ n ++ " ") . printCExpr d2 e2 CEIf e1 e2 e3 -> - printCExpr e1 . showString " ? " . printCExpr e2 . showString " : " . printCExpr e3 - -repTy :: STy t -> String -repTy (STScal st) = case st of - STI32 -> "int32_t" - STI64 -> "int64_t" - STF32 -> "float" - STF64 -> "double" - STBool -> "bool" + showParen (d > 0) $ + printCExpr 1 e1 . showString " ? " . printCExpr 1 e2 . showString " : " . printCExpr 0 e3 + where + precTable = Map.fromList + [("||", (2, (2, 2))) + ,("&&", (3, (3, 3))) + ,("==", (4, (5, 5))) + ,("!=", (4, (5, 5))) + ,("<", (5, (6, 6))) + ,(">", (5, (6, 6))) + ,("<=", (5, (6, 6))) + ,(">=", (5, (6, 6))) + ,("+", (6, (6, 6))) + ,("-", (6, (6, 7))) + ,("*", (7, (7, 7))) + ,("/", (7, (7, 8))) + ,("%", (7, (7, 8)))] + +repTy :: Ty -> String +repTy (TScal st) = case st of + TI32 -> "int32_t" + TI64 -> "int64_t" + TF32 -> "float" + TF64 -> "double" + TBool -> "bool" repTy t = genStructName t -genStructName :: STy t -> String +repSTy :: STy t -> String +repSTy = repTy . unSTy + +genStructName :: Ty -> String genStructName = \t -> "ty_" ++ gen t where -- all tags start with a letter, so the array mangling is unambiguous. - gen :: STy t -> String - gen STNil = "n" - gen (STPair a b) = 'P' : gen a ++ gen b - gen (STEither a b) = 'E' : gen a ++ gen b - gen (STMaybe t) = 'M' : gen t - gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t - gen (STScal st) = case st of - STI32 -> "i" - STI64 -> "j" - STF32 -> "f" - STF64 -> "d" - STBool -> "b" - gen (STAccum t) = 'C' : gen t - -genStruct :: STy t -> Map String StructDecl -genStruct topty = case topty of - STNil -> - Map.singleton (genStructName STNil) (StructDecl (genStructName STNil) "" com) - STPair a b -> - let name = genStructName (STPair a b) - in Map.singleton name (StructDecl name (repTy a ++ " a; " ++ repTy b ++ " b;") com) - STEither a b -> - let name = genStructName (STEither a b) -- 0 -> a, 1 -> b - in Map.singleton name (StructDecl name ("uint8_t tag; union { " ++ repTy a ++ " a; " ++ repTy b ++ " b; };") com) - STMaybe t -> - let name = genStructName (STMaybe t) -- 0 -> nothing, 1 -> just - in Map.singleton name (StructDecl name ("uint8_t tag; " ++ repTy t ++ " a;") com) - STArr n t -> - let name = genStructName (STArr n t) - in Map.singleton name (StructDecl name ("size_t sh[" ++ show (fromSNat n) ++ "]; " ++ repTy t ++ " *a;") com) - STScal _ -> mempty - STAccum t -> - let name = genStructName (STAccum t) - in Map.singleton name (StructDecl name (repTy t ++ " a;") com) - <> genStruct t + gen :: Ty -> String + gen TNil = "n" + gen (TPair a b) = 'P' : gen a ++ gen b + gen (TEither a b) = 'E' : gen a ++ gen b + gen (TMaybe t) = 'M' : gen t + gen (TArr n t) = "A" ++ show (fromNat n) ++ gen t + gen (TScal st) = case st of + TI32 -> "i" + TI64 -> "j" + TF32 -> "f" + TF64 -> "d" + TBool -> "b" + gen (TAccum t) = 'C' : gen t + +genStruct :: String -> Ty -> Maybe StructDecl +genStruct name topty = case topty of + TNil -> + Just $ StructDecl name "" com + TPair a b -> + Just $ StructDecl name (repTy a ++ " a; " ++ repTy b ++ " b;") com + TEither a b -> -- 0 -> a, 1 -> b + Just $ StructDecl name ("uint8_t tag; union { " ++ repTy a ++ " a; " ++ repTy b ++ " b; };") com + TMaybe t -> -- 0 -> nothing, 1 -> just + Just $ StructDecl name ("uint8_t tag; " ++ repTy t ++ " a;") com + TArr n t -> + Just $ StructDecl name ("size_t sh[" ++ show (fromNat n) ++ "]; " ++ repTy t ++ " *a;") com + TScal _ -> + Nothing + TAccum t -> + Just $ StructDecl name (repTy t ++ " a;") com where com = ppTy 0 topty +-- State: (already-generated (skippable) struct names, the structs in declaration order) +genStructs :: Ty -> State (Set String, Bag StructDecl) () +genStructs ty = do + let name = genStructName ty + seen <- gets ((name `Set.member`) . fst) + + case (if seen then Nothing else genStruct name ty) of + Nothing -> pure () + + Just decl -> do + -- already mark this struct as generated now, so we don't generate it twice + modify (first (Set.insert name)) + + case ty of + TNil -> pure () + TPair a b -> genStructs a >> genStructs b + TEither a b -> genStructs a >> genStructs b + TMaybe t -> genStructs t + TArr _ t -> genStructs t + TScal _ -> pure () + TAccum t -> genStructs t + + modify (second (<> pure decl)) + +genAllStructs :: Foldable t => t Ty -> [StructDecl] +genAllStructs tys = toList . snd $ execState (mapM_ genStructs tys) (mempty, mempty) + data CompState = CompState - { csStructs :: Map String StructDecl + { csStructs :: Set Ty , csStmts :: Bag Stmt , csNextId :: Int } deriving (Show) @@ -156,8 +211,9 @@ scope m = do emitStruct :: STy t -> CompM String emitStruct ty = do - modify $ \s -> s { csStructs = genStruct ty <> csStructs s } - return (genStructName ty) + let ty' = unSTy ty + modify $ \s -> s { csStructs = Set.insert ty' (csStructs s) } + return (genStructName ty') nameEnv :: SList f env -> SList (Const String) env nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1)) @@ -166,15 +222,16 @@ compile :: SList STy env -> Ex env t -> String compile env expr = let args = nameEnv env (res, s) = runState (compile' args expr) (CompState mempty mempty 1) + structs = genAllStructs (csStructs s <> Set.fromList (unSList unSTy env)) in ($ "") $ compose - [compose $ map (\sd -> printStructDecl sd . showString "\n") (Map.elems (csStructs s)) + [compose $ map (\sd -> printStructDecl sd . showString "\n") structs ,showString "\n" ,showString $ - repTy (typeOf expr) ++ " kernel(" ++ - intercalate ", " (reverse (unSList (\(Product.Pair t n) -> repTy t ++ " " ++ getConst n) (slistZip env args))) ++ + repSTy (typeOf expr) ++ " kernel(" ++ + intercalate ", " (reverse (unSList (\(Product.Pair t n) -> repSTy t ++ " " ++ getConst n) (slistZip env args))) ++ ") {\n" ,compose $ map (\st -> showString " " . printStmt 1 st . showString "\n") (toList (csStmts s)) - ,showString (" return ") . printCExpr res . showString ";\n}\n"] + ,showString (" return ") . printCExpr 0 res . showString ";\n}\n"] compile' :: SList (Const String) env -> Ex env t -> CompM CExpr compile' env = \case @@ -183,7 +240,7 @@ compile' env = \case ELet _ rhs body -> do e <- compile' env rhs var <- genName - emit $ SVarDecl True (repTy (typeOf rhs)) var e + emit $ SVarDecl True (repSTy (typeOf rhs)) var e compile' (Const var `SCons` env) body EPair _ a b -> do @@ -215,7 +272,7 @@ compile' env = \case (e2, stmts2) <- scope $ compile' (Const undefined `SCons` env) a -- don't access that nil, stupid you (e3, stmts3) <- scope $ compile' (Const undefined `SCons` env) b retvar <- genName - emit $ SVarDeclUninit (repTy (typeOf a)) retvar + emit $ SVarDeclUninit (repSTy (typeOf a)) retvar emit $ SIf e1 (stmts2 <> pure (SAsg retvar e2)) (stmts3 <> pure (SAsg retvar e3)) @@ -229,14 +286,14 @@ compile' env = \case (e2, stmts2) <- scope $ compile' (Const fieldvar `SCons` env) a (e3, stmts3) <- scope $ compile' (Const fieldvar `SCons` env) b retvar <- genName - emit $ SVarDeclUninit (repTy (typeOf a)) retvar - emit $ SBlock (pure (SVarDecl True (repTy (typeOf e)) var e1) + emit $ SVarDeclUninit (repSTy (typeOf a)) retvar + emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1) <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) - (pure (SVarDecl True (repTy t1) fieldvar + (pure (SVarDecl True (repSTy t1) fieldvar (CEProj (CELit var) "a")) <> stmts2 <> pure (SAsg retvar e2)) - (pure (SVarDecl True (repTy t2) fieldvar + (pure (SVarDecl True (repSTy t2) fieldvar (CEProj (CELit var) "b")) <> stmts3 <> pure (SAsg retvar e3)))) @@ -258,12 +315,12 @@ compile' env = \case (e2, stmts2) <- scope $ compile' env a (e3, stmts3) <- scope $ compile' (Const fieldvar `SCons` env) b retvar <- genName - emit $ SVarDeclUninit (repTy (typeOf a)) retvar - emit $ SBlock (pure (SVarDecl True (repTy (typeOf e)) var e1) + emit $ SVarDeclUninit (repSTy (typeOf a)) retvar + emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1) <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) (stmts2 <> pure (SAsg retvar e2)) - (pure (SVarDecl True (repTy (typeOf b)) fieldvar + (pure (SVarDecl True (repSTy (typeOf b)) fieldvar (CEProj (CELit var) "a")) <> stmts3 <> pure (SAsg retvar e3)))) @@ -332,7 +389,7 @@ compileOpGeneral op e1 = do let unary cop = return @(State CompState) $ CECall cop [e1] let binary cop = do name <- genName - emit $ SVarDecl True (repTy (opt1 op)) name e1 + emit $ SVarDecl True (repSTy (opt1 op)) name e1 return $ CEBinop (CEProj (CELit name) "a") cop (CEProj (CELit name) "b") case op of OAdd _ -> binary "+" |