diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-02 22:22:11 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-02 22:22:11 +0100 |
commit | 15cdfa67937ce20fc90ade59437be0a9e4d7a481 (patch) | |
tree | d1a21646d887eb36bb115c45a47636f5203683ce | |
parent | dccb5f2a0e92a568961e60e3e2ba3dfb4316c663 (diff) |
Compile: Support EShape and EBuild
-rw-r--r-- | src/Compile.hs | 107 |
1 files changed, 95 insertions, 12 deletions
diff --git a/src/Compile.hs b/src/Compile.hs index 4a6b9f9..8c77cd6 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -4,7 +4,7 @@ {-# LANGUAGE TypeApplications #-} module Compile (compile) where -import Control.Monad (when) +import Control.Monad (forM_, when, replicateM) import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.State.Strict import Control.Monad.Trans.Writer.CPS @@ -36,20 +36,19 @@ import Interpreter.Rep {- :m *Example Compile AST.UnMonoid :seti -XOverloadedLabels -XGADTs -(($ SCons (Value 2) SNil) =<<) $ compile knownEnv $ fromNamed $ lambda @(TScal TF64) #x $ body $ constArr_ @TF64 (arrayGenerate (ShNil `ShCons` 10) (\(IxNil `IxCons` i) -> fromIntegral i)) +let array = arrayGenerate (ShNil `ShCons` 10) (\(IxNil `IxCons` i) -> fromIntegral i :: Double) in (($ SCons (Value array) SNil) =<<) $ compile knownEnv $ fromNamed $ lambda @(TArr N1 (TScal TF64)) #x $ body $ #x ! pair nil (round_ (#x ! pair nil 3)) +(($ SNil) =<<) $ compile knownEnv $ fromNamed $ body $ build2 5 3 (#i :-> #j :-> 10 * #i + #j) -} -- In shape and index arrays, the innermost dimension is on the right (last index). --- TODO: array lifetimes in C? - compile :: SList STy env -> Ex env t -> IO (SList Value env -> IO (Rep t)) compile = \env expr -> do let source = compileToString env expr - hPutStrLn stderr $ "Generated C source: <<<\n" ++ source ++ ">>>" + hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ source ++ "\x1B[0m>>>" lib <- buildKernel source ["kernel"] let arg_metrics = reverse (unSList metricsSTy env) @@ -83,6 +82,7 @@ data Stmt | SAsg String CExpr -- ^ variable name, right-hand side | SBlock (Bag Stmt) | SIf CExpr (Bag Stmt) (Bag Stmt) + | SLoop String String CExpr CExpr (Bag Stmt) -- ^ for (<type> <name> = <expr>; name < <expr>; name++) {<stmts>} | SVerbatim String -- ^ no implicit ';', just printed as-is deriving (Show) @@ -95,6 +95,7 @@ data CExpr | CECall String [CExpr] -- ^ function(arg1, ..., argn) | CEBinop CExpr String CExpr -- ^ expr + expr | CEIf CExpr CExpr CExpr -- ^ expr ? expr : expr + | CECast String CExpr -- ^ (<type)<expr> deriving (Show) printStructDecl :: StructDecl -> ShowS @@ -104,7 +105,7 @@ printStructDecl (StructDecl name contents comment) = printStmt :: Int -> Stmt -> ShowS printStmt indent = \case - SVarDecl cnst typ name rhs -> showString ((if cnst then "const " else "") ++ typ ++ " " ++ name ++ " = ") . printCExpr 0 rhs . showString ";" + SVarDecl cnst typ name rhs -> showString (typ ++ " " ++ (if cnst then "const " else "") ++ name ++ " = ") . printCExpr 0 rhs . showString ";" SVarDeclUninit typ name -> showString (typ ++ " " ++ name ++ ";") SAsg name rhs -> showString (name ++ " = ") . printCExpr 0 rhs . showString ";" SBlock stmts @@ -117,6 +118,11 @@ printStmt indent = \case showString "if (" . printCExpr 0 cond . showString ") " . printStmt indent (SBlock b1) . (if null b2 then id else showString " else " . printStmt indent (SBlock b2)) + SLoop typ name e1 e2 stmts -> + showString ("for (" ++ typ ++ " " ++ name ++ " = ") + . printCExpr 0 e1 . showString ("; " ++ name ++ " < ") . printCExpr 6 e2 + . showString ("; " ++ name ++ "++) ") + . printStmt indent (SBlock stmts) SVerbatim s -> showString s -- d values: @@ -150,13 +156,15 @@ printCExpr d = \case CEIf e1 e2 e3 -> showParen (d > 0) $ printCExpr 1 e1 . showString " ? " . printCExpr 1 e2 . showString " : " . printCExpr 0 e3 + CECast typ e -> + showParen (d > 98) $ showString ("(" ++ typ ++ ")") . printCExpr 98 e where precTable = Map.fromList [("||", (2, (2, 2))) ,("&&", (3, (3, 3))) ,("==", (4, (5, 5))) ,("!=", (4, (5, 5))) - ,("<", (5, (6, 6))) + ,("<", (5, (6, 6))) -- Note: this precedence is used in the printing of SLoop ,(">", (5, (6, 6))) ,("<=", (5, (6, 6))) ,(">=", (5, (6, 6))) @@ -585,7 +593,44 @@ compile' env = \case ".refc = (size_t)1<<63, .a = {" ++ intercalate "," (map (compileScal False t) (toList vec)) ++ "}};" return (CEStruct strname [("buf", CEAddrOf (CELit tldname))]) - -- EBuild _ n a b -> error "TODO" -- genStruct (STArr n (typeOf b)) <> EBuild ext n (compile' a) (compile' b) + EBuild _ n esh efun -> do + let arrty = STArr n (typeOf efun) + strname <- emitStruct arrty + + shname <- genName' "sh" + emit . SVarDecl True (repSTy (typeOf esh)) shname =<< compile' env esh + shsizename <- genName' "shsz" + emit $ SVarDecl True "size_t" shsizename (compileShapeSize n shname) + + bufpname <- genName' "bufp" + emit $ SVarDecl True (repSTy arrty ++ "_buf *") bufpname + (CECall "malloc" [CEBinop (CELit (show (fromSNat n * 8 + 8))) + "+" + (CEBinop (CELit shsizename) + "*" (CELit (show (sizeofSTy (typeOf efun)))))]) + forM_ (zip (compileShapeTupIntoArray n shname) [0::Int ..]) $ \(rhs, i) -> + emit $ SAsg (bufpname ++ "->sh[" ++ show i ++ "]") rhs + emit $ SAsg (bufpname ++ "->refc") (CELit "1") + + arrname <- genName' "arr" + emit $ SVarDecl True (repSTy arrty) arrname (CEStruct strname [("buf", CELit bufpname)]) + + idxargname <- genName' "ix" + (funretval, funstmts) <- scope $ compile' (Const idxargname `SCons` env) efun + + linivar <- genName' "li" + ivars <- replicateM (fromSNat n) (genName' "i") + emit $ SBlock $ + pure (SVarDecl False "size_t" linivar (CELit "0")) + <> compose [pure . SLoop (repSTy tIx) ivar (CELit "0") + (CECast (repSTy tIx) (CEIndex (CELit (bufpname ++ "->sh")) (CELit (show dimidx)))) + | (ivar, dimidx) <- zip ivars [0::Int ..]] + (pure (SVarDecl True (repSTy (typeOf esh)) idxargname + (shapeTupFromLitVars n ivars)) + <> BList funstmts + <> pure (SAsg (bufpname ++ "->a[" ++ linivar ++ "++]") funretval)) + + return (CELit arrname) -- EFold1Inner _ a b c -> error "TODO" -- EFold1Inner ext (compile' a) (compile' b) (compile' c) @@ -607,16 +652,25 @@ compile' env = \case EIdx _ earr eidx -> do let STArr n t = typeOf earr - arrname <- genName - idxname <- genName + arrname <- genName' "ixarr" + idxname <- genName' "ixix" emit . SVarDecl True (repSTy (typeOf earr)) arrname =<< compile' env earr emit . SVarDecl True (repSTy (typeOf eidx)) idxname =<< compile' env eidx - resname <- genName + resname <- genName' "ixres" emit $ SVarDecl True (repSTy t) resname (CEIndex (CELit (arrname ++ ".buf->a")) (toLinearIdx n arrname idxname)) incrementVarAlways Decrement (STArr n t) arrname return (CELit resname) - -- EShape _ e -> error "TODO" -- EShape ext (compile' e) + EShape _ e -> do + let STArr n _ = typeOf e + t = tTup (sreplicate n tIx) + _ <- emitStruct t + name <- genName + emit . SVarDecl True (repSTy (typeOf e)) name =<< compile' env e + resname <- genName + emit $ SVarDecl True (repSTy t) resname (compileShapeQuery n name) + incrementVarAlways Decrement (typeOf e) name + return (CELit resname) EOp _ op (EPair _ e1 e2) -> do e1' <- compile' env e1 @@ -717,6 +771,35 @@ toLinearIdx (SS n) arrvar idxvar = -- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))) -- _ +compileShapeQuery :: SNat n -> String -> CExpr +compileShapeQuery SZ _ = CEStruct (repSTy STNil) [] +compileShapeQuery (SS n) var = + CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) + [("a", compileShapeQuery n var) + ,("b", CEIndex (CELit (var ++ ".buf->sh")) (CELit (show (fromSNat n))))] + +compileShapeSize :: SNat n -> String -> CExpr +compileShapeSize SZ _ = CELit "0" +compileShapeSize (SS SZ) var = CELit (var ++ ".b") +compileShapeSize (SS n) var = CEBinop (compileShapeSize n (var ++ ".a")) "*" (CELit (var ++ ".b")) + +compileShapeTupIntoArray :: SNat n -> String -> [CExpr] +compileShapeTupIntoArray = \n var -> map CELit (toList (go n var)) + where + go :: SNat n -> String -> Bag String + go SZ _ = mempty + go (SS n) var = go n (var ++ ".a") <> pure (var ++ ".b") + +-- | Takes variable names with the innermost dimension on the right. +shapeTupFromLitVars :: SNat n -> [String] -> CExpr +shapeTupFromLitVars = \n -> go n . reverse + where + -- takes variables with the innermost dimension at the _head_ + go :: SNat n -> [String] -> CExpr + go SZ [] = CEStruct (repSTy STNil) [] + go (SS n) (var : vars) = CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) [("a", go n vars), ("b", CELit var)] + go _ _ = error "shapeTupFromLitVars: SNat and list do not correspond" + compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr compileOpGeneral op e1 = do let unary cop = return @(State CompState) $ CECall cop [e1] |