{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE TypeApplications #-}
-- | Translation of 'Ex' expressions into the text of a C function named
-- @kernel@, preceded by the struct typedefs that were registered while
-- compiling.
module Compile where

import Control.Monad.Trans.State.Strict
import Data.Foldable (toList)
import Data.Functor.Const
import Data.List (intersperse, intercalate)
import qualified Data.Map.Strict as Map
import Data.Map.Strict (Map)

import AST
import AST.Pretty (ppTy)
import Data


-- In shape and index arrays, the innermost dimension is on the right (last index).


-- | A C struct typedef, printed by 'printStructDecl'.
data StructDecl = StructDecl
    String  -- ^ name
    String  -- ^ contents
    String  -- ^ comment
  deriving (Show)

-- | The subset of C statements that the compiler emits.
data Stmt
  = SVarDecl String String CExpr  -- ^ type, variable name, right-hand side
  | SVarDeclUninit String String  -- ^ type, variable name (no initialiser)
  | SAsg String CExpr  -- ^ variable name, right-hand side
  | SBlock [Stmt]
  | SIf CExpr [Stmt] [Stmt]
  | SVerbatim String
  deriving (Show)

-- | The subset of C expressions that the compiler emits.
data CExpr
  = CELit String
  | CEStruct String [(String, CExpr)]
  | CEProj CExpr String
  | CECall String [CExpr]
  | CEBinop CExpr String CExpr
  | CEIf CExpr CExpr CExpr
  deriving (Show)

-- | Render a struct declaration as a single-line @typedef@, with the
-- declaration's comment in a trailing @//@ comment.
printStructDecl :: StructDecl -> ShowS
printStructDecl (StructDecl name contents comment) =
  showString "typedef struct { " . showString contents . showString " } " . showString name
  . showString ("; // " ++ comment)

-- | Render a statement. The 'Int' is the nesting level of the statement
-- itself: the contents of a block are placed on fresh lines at
-- @2 * indent + 2@ spaces and its closing brace at @2 * indent@ spaces.
-- No indentation is printed before the statement's own first line.
printStmt :: Int -> Stmt -> ShowS
printStmt indent = \case
  SVarDecl typ name rhs -> showString (typ ++ " " ++ name ++ " = ") . printCExpr rhs . showString ";"
  SVarDeclUninit typ name -> showString (typ ++ " " ++ name ++ ";")
  SAsg name rhs -> showString (name ++ " = ") . printCExpr rhs . showString ";"
  SBlock stmts ->
    showString "{"
    . compose [showString ("\n" ++ replicate (2*indent+2) ' ') . printStmt (indent+1) stmt
              | stmt <- stmts]
    . showString ("\n" ++ replicate (2*indent) ' ' ++ "}")
  SIf cond b1 b2 ->
    showString "if (" . printCExpr cond . showString ") "
    . printStmt indent (SBlock b1) . showString " else " . printStmt indent (SBlock b2)
  SVerbatim s -> showString s

-- | Render an expression. Operands of projections and binary operators are
-- parenthesised, and so is the conditional expression as a whole, so the
-- output does not depend on C operator precedence.
printCExpr :: CExpr -> ShowS
printCExpr = \case
  CELit s -> showString s
  CEStruct name pairs ->
    showString ("(" ++ name ++ "){")
    . compose (intersperse (showString ", ") [showString ("." ++ n ++ " = ") . printCExpr e
                                             | (n, e) <- pairs])
    . showString "}"
  CEProj e name -> showString "(" . printCExpr e . showString (")." ++ name)
  CECall n es ->
    showString (n ++ "(") . compose (intersperse (showString ", ") (map printCExpr es)) . showString ")"
  CEBinop e1 n e2 ->
    showString "(" . printCExpr e1 . showString (") " ++ n ++ " (") . printCExpr e2 . showString ")"
  -- Parenthesised as a whole: without the parentheses a conditional used as
  -- the condition of another conditional would re-associate in C.
  CEIf e1 e2 e3 ->
    showString "(" . printCExpr e1 . showString " ? " . printCExpr e2
    . showString " : " . printCExpr e3 . showString ")"

-- | The C type used to store a value of the given type: the C scalar type
-- for scalars, and the generated struct name ('genStructName') otherwise.
repTy :: STy t -> String
repTy (STScal st) = case st of
  STI32 -> "int32_t"
  STI64 -> "int64_t"
  STF32 -> "float"
  STF64 -> "double"
  STBool -> "bool"
repTy t = genStructName t

-- | The mangled struct name for a type: @ty_@ followed by a tag string that
-- spells out the structure of the type. Scalars have tags too (they occur
-- inside the names of composite types), but 'genStruct' generates no struct
-- for them; use 'repTy' to name the C type of a value.
genStructName :: STy t -> String
genStructName = \t -> "ty_" ++ gen t
  where
    -- all tags start with a letter, so the array mangling is unambiguous.
    gen :: STy t -> String
    gen STNil = "n"
    gen (STPair a b) = 'p' : gen a ++ gen b
    gen (STEither a b) = 'e' : gen a ++ gen b
    gen (STMaybe t) = 'm' : gen t
    gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t
    gen (STScal st) = case st of
      STI32 -> "i4"
      STI64 -> "i8"
      STF32 -> "f4"
      STF64 -> "f8"
      STBool -> "b"
    gen (STAccum t) = 'C' : gen t

-- | The struct declarations for a type, keyed by struct name. Scalars yield
-- no declaration. Field types are named with 'repTy', so that scalar fields
-- get their C scalar type instead of a struct name that is never declared.
--
-- NOTE(review): only the 'STAccum' case also includes the declarations of
-- its component type; for the other composite types the structs of the
-- components are not included here. Also, the caller prints the map in key
-- order, which is not necessarily dependency order. Confirm both are
-- handled elsewhere or intended.
genStruct :: STy t -> Map String StructDecl
genStruct topty = case topty of
  STNil ->
    Map.singleton (genStructName STNil) (StructDecl (genStructName STNil) "" com)
  STPair a b ->
    let name = genStructName (STPair a b)
    in Map.singleton name (StructDecl name (repTy a ++ " a; " ++ repTy b ++ " b;") com)
  STEither a b ->
    let name = genStructName (STEither a b)  -- 0 -> a, 1 -> b
    in Map.singleton name (StructDecl name ("uint8_t tag; union { " ++ repTy a ++ " a; " ++ repTy b ++ " b; };") com)
  STMaybe t ->
    let name = genStructName (STMaybe t)  -- 0 -> nothing, 1 -> just
    in Map.singleton name (StructDecl name ("uint8_t tag; " ++ repTy t ++ " a;") com)
  STArr n t ->
    let name = genStructName (STArr n t)
    in Map.singleton name (StructDecl name ("size_t sh[" ++ show (fromSNat n) ++ "]; " ++ repTy t ++ " *a;") com)
  STScal _ -> mempty
  STAccum t ->
    let name = genStructName (STAccum t)
    in Map.singleton name (StructDecl name (repTy t ++ " a;") com)
       <> genStruct t
  where
    -- The pretty-printed source type, used as the comment of the declaration.
    com = ppTy 0 topty

-- | Compiler state: the struct declarations registered so far, the
-- statements emitted in the current scope, and the next fresh identifier.
data CompState = CompState
  { csStructs :: Map String StructDecl
  , csStmts :: Bag Stmt
  , csNextId :: Int }
  deriving (Show)

type CompM a = State CompState a

-- | Return the next fresh identifier and advance the counter.
genId :: CompM Int
genId = state $ \s -> (csNextId s, s { csNextId = csNextId s + 1 })

-- | Generate a fresh variable name of the form @x\<id\>@.
genName :: CompM String
genName = ('x' :) . show <$> genId

-- | Append a statement to the statements of the current scope.
emit :: Stmt -> CompM ()
emit stmt = modify $ \s -> s { csStmts = csStmts s <> pure stmt }

-- | Run a computation with an initially empty statement list and return
-- the statements it emitted alongside its result. The statement list of
-- the enclosing scope is restored afterwards; struct declarations and the
-- identifier counter are not reset.
scope :: CompM a -> CompM (a, [Stmt])
scope m = do
  stmts <- state $ \s -> (csStmts s, s { csStmts = mempty })
  res <- m
  innerStmts <- state $ \s -> (csStmts s, s { csStmts = stmts })
  return (res, toList innerStmts)

-- | Register the struct declarations of a type ('genStruct') and return
-- the C type name to use for it ('repTy'). For a scalar type nothing is
-- registered and the C scalar type name is returned.
emitStruct :: STy t -> CompM String
emitStruct ty = do
  modify $ \s -> s { csStructs = genStruct ty <> csStructs s }
  return (repTy ty)

-- | Compile an expression to C source text: the registered struct
-- typedefs (in key order of the map), a blank line, and a function
-- @kernel@ whose body is the emitted statements followed by a @return@ of
-- the result expression. The parameter list consists of the names in the
-- environment, in reverse order.
compile :: SList (Const String) env -> Ex env t -> String
compile env expr =
  let (res, s) = runState (compile' env expr) (CompState mempty mempty 1)
  in ($ "") $ compose
       [compose $ map (\sd -> printStructDecl sd . showString "\n") (Map.elems (csStructs s))
       ,showString "\n"
       -- repTy, not genStructName: a scalar result has no generated struct.
       ,showString (repTy (typeOf expr) ++ " kernel(" ++ intercalate ", " (reverse (unSList getConst env)) ++ ") {\n")
       ,compose $ map (\st -> showString " " . printStmt 1 st . showString "\n") (toList (csStmts s))
       ,showString (" return ") . printCExpr res . showString ";\n}\n"]

-- | Compile an expression to a C expression, emitting any statements that
-- have to run first into the current scope and registering the structs
-- that the generated code mentions. The environment gives the C variable
-- name for each variable in scope.
--
-- Many constructors are not implemented yet and call 'error'.
compile' :: SList (Const String) env -> Ex env t -> CompM CExpr
compile' env = \case
  EVar _ _ i -> return $ CELit (getConst (slistIdx env i))

  ELet _ rhs body -> do
    e <- compile' env rhs
    var <- genName
    emit $ SVarDecl (repTy (typeOf rhs)) var e
    compile' (Const var `SCons` env) body

  EPair _ a b -> do
    name <- emitStruct (STPair (typeOf a) (typeOf b))
    e1 <- compile' env a
    e2 <- compile' env b
    return $ CEStruct name [("a", e1), ("b", e2)]

  EFst _ e -> CEProj <$> compile' env e <*> pure "a"

  ESnd _ e -> CEProj <$> compile' env e <*> pure "b"

  ENil _ -> do
    name <- emitStruct STNil
    return $ CEStruct name []

  EInl _ t e -> do
    name <- emitStruct (STEither (typeOf e) t)
    e1 <- compile' env e
    return $ CEStruct name [("tag", CELit "0"), ("a", e1)]

  EInr _ t e -> do
    name <- emitStruct (STEither t (typeOf e))
    e2 <- compile' env e
    return $ CEStruct name [("tag", CELit "1"), ("b", e2)]

  ECase _ e a b -> do
    let STEither t1 t2 = typeOf e
    e1 <- compile' env e
    var <- genName
    fieldvar <- genName
    -- Each branch is compiled in its own scope so that its statements end
    -- up inside the corresponding arm of the generated 'if'.
    (e2, stmts2) <- scope $ compile' (Const fieldvar `SCons` env) a
    (e3, stmts3) <- scope $ compile' (Const fieldvar `SCons` env) b
    retvar <- genName
    -- The result variable is declared outside the block so that it is
    -- still in scope for the code that follows.
    emit $ SVarDeclUninit (repTy (typeOf a)) retvar
    emit $ SBlock (pure (SVarDecl (repTy (typeOf e)) var e1)
                   <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0"))
                             (pure (SVarDecl (repTy t1) fieldvar (CEProj (CELit var) "a"))
                              <> stmts2
                              <> pure (SAsg retvar e2))
                             (pure (SVarDecl (repTy t2) fieldvar (CEProj (CELit var) "b"))
                              <> stmts3
                              <> pure (SAsg retvar e3))))
    return (CELit retvar)

  ENothing _ t -> do
    name <- emitStruct (STMaybe t)
    return $ CEStruct name [("tag", CELit "0")]

  EJust _ e -> do
    name <- emitStruct (STMaybe (typeOf e))
    e1 <- compile' env e
    return $ CEStruct name [("tag", CELit "1"), ("a", e1)]

  EMaybe _ a b e -> do
    let STMaybe t = typeOf e
    e1 <- compile' env e
    var <- genName
    fieldvar <- genName
    (e2, stmts2) <- scope $ compile' env a
    (e3, stmts3) <- scope $ compile' (Const fieldvar `SCons` env) b
    retvar <- genName
    emit $ SVarDeclUninit (repTy (typeOf a)) retvar
    emit $ SBlock (pure (SVarDecl (repTy (typeOf e)) var e1)
                   <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0"))
                             (stmts2
                              <> pure (SAsg retvar e2))
                             -- fieldvar holds the payload of the Maybe, so it
                             -- is declared at the payload type, not at the
                             -- result type of the branch.
                             (pure (SVarDecl (repTy t) fieldvar (CEProj (CELit var) "a"))
                              <> stmts3
                              <> pure (SAsg retvar e3))))
    return (CELit retvar)

  EConstArr _ n t arr -> do
    name <- emitStruct (STArr n (STScal t))
    error "TODO"

  EBuild _ n a b -> error "TODO" -- genStruct (STArr n (typeOf b)) <> EBuild ext n (compile' a) (compile' b)

  EFold1Inner _ a b c -> error "TODO" -- EFold1Inner ext (compile' a) (compile' b) (compile' c)

  ESum1Inner _ e -> error "TODO" -- ESum1Inner ext (compile' e)

  EUnit _ e -> error "TODO" -- EUnit ext (compile' e)

  EReplicate1Inner _ a b -> error "TODO" -- EReplicate1Inner ext (compile' a) (compile' b)

  EMaximum1Inner _ e -> error "TODO" -- EMaximum1Inner ext (compile' e)

  EMinimum1Inner _ e -> error "TODO" -- EMinimum1Inner ext (compile' e)

  EConst _ t x -> case t of
    STI32 -> return $ CELit $ "(int32_t)" ++ show x
    STI64 -> return $ CELit $ "(int64_t)" ++ show x
    STF32 -> return $ CELit $ show x ++ "f"
    STF64 -> return $ CELit $ show x
    STBool -> return $ CELit $ if x then "true" else "false"

  EIdx0 _ e -> error "TODO" -- EIdx0 ext (compile' e)

  EIdx1 _ a b -> error "TODO" -- EIdx1 ext (compile' a) (compile' b)

  EIdx _ a b -> error "TODO" -- EIdx ext (compile' a) (compile' b)

  EShape _ e -> error "TODO" -- EShape ext (compile' e)

  -- The operator translation lives in 'compileOpGeneral'.
  EOp _ op e -> compile' env e >>= compileOpGeneral op

  ECustom _ t1 t2 t3 a b c e1 e2 -> error "TODO" -- ECustom ext t1 t2 t3 (compile' a) (compile' b) (compile' c) (compile' e1) (compile' e2)

  EWith a b -> error "TODO" -- EWith (compile' a) (compile' b)

  EAccum n a b e -> error "TODO" -- EAccum n (compile' a) (compile' b) (compile' e)

  EError t s -> do
    name <- emitStruct t
    -- using 'show' here is wrong, but it's good enough for me.
    emit $ SVerbatim $ "fprintf(stderr, \"ERROR: %s\\n\", " ++ show s ++ "); exit(1);"
    -- The generated code exits before this value is used; it is only there
    -- so that an expression of the right C type can be returned.
    return $ CEStruct name []

  EZero{} -> error "Compile: monoid operations should have been eliminated"
  EPlus{} -> error "Compile: monoid operations should have been eliminated"
  EOneHot{} -> error "Compile: monoid operations should have been eliminated"

-- | Compile the application of an operator to an already-compiled argument
-- expression. For binary operators the argument is first stored in a fresh
-- variable, and the operator is applied to its @a@ and @b@ fields.
compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr
compileOpGeneral op e1 = do
  let unary cop = return @(State CompState) $ CECall cop [e1]
  let binary cop = do
        name <- genName
        emit $ SVarDecl (repTy (opt1 op)) name e1
        return $ CEBinop (CEProj (CELit name) "a") cop (CEProj (CELit name) "b")
  case op of
    OAdd _ -> binary "+"
    OMul _ -> binary "*"
    ONeg _ -> unary "-"
    OLt _ -> binary "<"
    OLe _ -> binary "<="
    OEq _ -> binary "=="
    ONot -> unary "!"
    OAnd -> binary "&&"
    OOr -> binary "||"
    OIf -> do
      name <- emitStruct (STEither STNil STNil)
      -- The either struct has fields of the nil struct type, so that struct
      -- is registered as well.
      _ <- emitStruct STNil
      return $ CEIf e1 (CEStruct name [("tag", CELit "0")])
                       (CEStruct name [("tag", CELit "1")])
    ORound64 -> unary "(int64_t)round"  -- ew
    OToFl64 -> unary "(double)"
    ORecip _ -> return $ CEBinop (CELit "1.0") "/" e1
    OExp STF32 -> unary "expf"
    OExp STF64 -> unary "exp"
    OLog STF32 -> unary "logf"
    OLog STF64 -> unary "log"
    OIDiv _ -> binary "/"

-- | Compile the application of a binary operator to two already-compiled
-- operand expressions, without going through a pair struct. Calls 'error'
-- for every operator that is not listed below.
compileOpPair :: SOp a b -> CExpr -> CExpr -> CompM CExpr
compileOpPair op e1 e2 = do
  let binary cop = return @(State CompState) $ CEBinop e1 cop e2
  case op of
    OAdd _ -> binary "+"
    OMul _ -> binary "*"
    OLt _ -> binary "<"
    OLe _ -> binary "<="
    OEq _ -> binary "=="
    OAnd -> binary "&&"
    OOr -> binary "||"
    OIDiv _ -> binary "/"
    _ -> error "compileOpPair: got unary operator"

-- | Compose a container of functions; the first element is applied last
-- (outermost).
compose :: Foldable t => t (a -> a) -> a -> a
compose = foldr (.) id