summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-02 22:22:11 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-02 22:22:11 +0100
commit15cdfa67937ce20fc90ade59437be0a9e4d7a481 (patch)
treed1a21646d887eb36bb115c45a47636f5203683ce
parentdccb5f2a0e92a568961e60e3e2ba3dfb4316c663 (diff)
Compile: Support EShape and EBuild
-rw-r--r--src/Compile.hs107
1 files changed, 95 insertions, 12 deletions
diff --git a/src/Compile.hs b/src/Compile.hs
index 4a6b9f9..8c77cd6 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -4,7 +4,7 @@
{-# LANGUAGE TypeApplications #-}
module Compile (compile) where
-import Control.Monad (when)
+import Control.Monad (forM_, when, replicateM)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict
import Control.Monad.Trans.Writer.CPS
@@ -36,20 +36,19 @@ import Interpreter.Rep
{-
:m *Example Compile AST.UnMonoid
:seti -XOverloadedLabels -XGADTs
-(($ SCons (Value 2) SNil) =<<) $ compile knownEnv $ fromNamed $ lambda @(TScal TF64) #x $ body $ constArr_ @TF64 (arrayGenerate (ShNil `ShCons` 10) (\(IxNil `IxCons` i) -> fromIntegral i))
+let array = arrayGenerate (ShNil `ShCons` 10) (\(IxNil `IxCons` i) -> fromIntegral i :: Double) in (($ SCons (Value array) SNil) =<<) $ compile knownEnv $ fromNamed $ lambda @(TArr N1 (TScal TF64)) #x $ body $ #x ! pair nil (round_ (#x ! pair nil 3))
+(($ SNil) =<<) $ compile knownEnv $ fromNamed $ body $ build2 5 3 (#i :-> #j :-> 10 * #i + #j)
-}
-- In shape and index arrays, the innermost dimension is on the right (last index).
--- TODO: array lifetimes in C?
-
compile :: SList STy env -> Ex env t
-> IO (SList Value env -> IO (Rep t))
compile = \env expr -> do
let source = compileToString env expr
- hPutStrLn stderr $ "Generated C source: <<<\n" ++ source ++ ">>>"
+ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ source ++ "\x1B[0m>>>"
lib <- buildKernel source ["kernel"]
let arg_metrics = reverse (unSList metricsSTy env)
@@ -83,6 +82,7 @@ data Stmt
| SAsg String CExpr -- ^ variable name, right-hand side
| SBlock (Bag Stmt)
| SIf CExpr (Bag Stmt) (Bag Stmt)
+ | SLoop String String CExpr CExpr (Bag Stmt) -- ^ for (<type> <name> = <expr>; name < <expr>; name++) {<stmts>}
| SVerbatim String -- ^ no implicit ';', just printed as-is
deriving (Show)
@@ -95,6 +95,7 @@ data CExpr
| CECall String [CExpr] -- ^ function(arg1, ..., argn)
| CEBinop CExpr String CExpr -- ^ expr + expr
| CEIf CExpr CExpr CExpr -- ^ expr ? expr : expr
+ | CECast String CExpr -- ^ (<type)<expr>
deriving (Show)
printStructDecl :: StructDecl -> ShowS
@@ -104,7 +105,7 @@ printStructDecl (StructDecl name contents comment) =
printStmt :: Int -> Stmt -> ShowS
printStmt indent = \case
- SVarDecl cnst typ name rhs -> showString ((if cnst then "const " else "") ++ typ ++ " " ++ name ++ " = ") . printCExpr 0 rhs . showString ";"
+ SVarDecl cnst typ name rhs -> showString (typ ++ " " ++ (if cnst then "const " else "") ++ name ++ " = ") . printCExpr 0 rhs . showString ";"
SVarDeclUninit typ name -> showString (typ ++ " " ++ name ++ ";")
SAsg name rhs -> showString (name ++ " = ") . printCExpr 0 rhs . showString ";"
SBlock stmts
@@ -117,6 +118,11 @@ printStmt indent = \case
showString "if (" . printCExpr 0 cond . showString ") "
. printStmt indent (SBlock b1)
. (if null b2 then id else showString " else " . printStmt indent (SBlock b2))
+ SLoop typ name e1 e2 stmts ->
+ showString ("for (" ++ typ ++ " " ++ name ++ " = ")
+ . printCExpr 0 e1 . showString ("; " ++ name ++ " < ") . printCExpr 6 e2
+ . showString ("; " ++ name ++ "++) ")
+ . printStmt indent (SBlock stmts)
SVerbatim s -> showString s
-- d values:
@@ -150,13 +156,15 @@ printCExpr d = \case
CEIf e1 e2 e3 ->
showParen (d > 0) $
printCExpr 1 e1 . showString " ? " . printCExpr 1 e2 . showString " : " . printCExpr 0 e3
+ CECast typ e ->
+ showParen (d > 98) $ showString ("(" ++ typ ++ ")") . printCExpr 98 e
where
precTable = Map.fromList
[("||", (2, (2, 2)))
,("&&", (3, (3, 3)))
,("==", (4, (5, 5)))
,("!=", (4, (5, 5)))
- ,("<", (5, (6, 6)))
+ ,("<", (5, (6, 6))) -- Note: this precedence is used in the printing of SLoop
,(">", (5, (6, 6)))
,("<=", (5, (6, 6)))
,(">=", (5, (6, 6)))
@@ -585,7 +593,44 @@ compile' env = \case
".refc = (size_t)1<<63, .a = {" ++ intercalate "," (map (compileScal False t) (toList vec)) ++ "}};"
return (CEStruct strname [("buf", CEAddrOf (CELit tldname))])
- -- EBuild _ n a b -> error "TODO" -- genStruct (STArr n (typeOf b)) <> EBuild ext n (compile' a) (compile' b)
+ EBuild _ n esh efun -> do
+ let arrty = STArr n (typeOf efun)
+ strname <- emitStruct arrty
+
+ shname <- genName' "sh"
+ emit . SVarDecl True (repSTy (typeOf esh)) shname =<< compile' env esh
+ shsizename <- genName' "shsz"
+ emit $ SVarDecl True "size_t" shsizename (compileShapeSize n shname)
+
+ bufpname <- genName' "bufp"
+ emit $ SVarDecl True (repSTy arrty ++ "_buf *") bufpname
+ (CECall "malloc" [CEBinop (CELit (show (fromSNat n * 8 + 8)))
+ "+"
+ (CEBinop (CELit shsizename)
+ "*" (CELit (show (sizeofSTy (typeOf efun)))))])
+ forM_ (zip (compileShapeTupIntoArray n shname) [0::Int ..]) $ \(rhs, i) ->
+ emit $ SAsg (bufpname ++ "->sh[" ++ show i ++ "]") rhs
+ emit $ SAsg (bufpname ++ "->refc") (CELit "1")
+
+ arrname <- genName' "arr"
+ emit $ SVarDecl True (repSTy arrty) arrname (CEStruct strname [("buf", CELit bufpname)])
+
+ idxargname <- genName' "ix"
+ (funretval, funstmts) <- scope $ compile' (Const idxargname `SCons` env) efun
+
+ linivar <- genName' "li"
+ ivars <- replicateM (fromSNat n) (genName' "i")
+ emit $ SBlock $
+ pure (SVarDecl False "size_t" linivar (CELit "0"))
+ <> compose [pure . SLoop (repSTy tIx) ivar (CELit "0")
+ (CECast (repSTy tIx) (CEIndex (CELit (bufpname ++ "->sh")) (CELit (show dimidx))))
+ | (ivar, dimidx) <- zip ivars [0::Int ..]]
+ (pure (SVarDecl True (repSTy (typeOf esh)) idxargname
+ (shapeTupFromLitVars n ivars))
+ <> BList funstmts
+ <> pure (SAsg (bufpname ++ "->a[" ++ linivar ++ "++]") funretval))
+
+ return (CELit arrname)
-- EFold1Inner _ a b c -> error "TODO" -- EFold1Inner ext (compile' a) (compile' b) (compile' c)
@@ -607,16 +652,25 @@ compile' env = \case
EIdx _ earr eidx -> do
let STArr n t = typeOf earr
- arrname <- genName
- idxname <- genName
+ arrname <- genName' "ixarr"
+ idxname <- genName' "ixix"
emit . SVarDecl True (repSTy (typeOf earr)) arrname =<< compile' env earr
emit . SVarDecl True (repSTy (typeOf eidx)) idxname =<< compile' env eidx
- resname <- genName
+ resname <- genName' "ixres"
emit $ SVarDecl True (repSTy t) resname (CEIndex (CELit (arrname ++ ".buf->a")) (toLinearIdx n arrname idxname))
incrementVarAlways Decrement (STArr n t) arrname
return (CELit resname)
- -- EShape _ e -> error "TODO" -- EShape ext (compile' e)
+ EShape _ e -> do
+ let STArr n _ = typeOf e
+ t = tTup (sreplicate n tIx)
+ _ <- emitStruct t
+ name <- genName
+ emit . SVarDecl True (repSTy (typeOf e)) name =<< compile' env e
+ resname <- genName
+ emit $ SVarDecl True (repSTy t) resname (compileShapeQuery n name)
+ incrementVarAlways Decrement (typeOf e) name
+ return (CELit resname)
EOp _ op (EPair _ e1 e2) -> do
e1' <- compile' env e1
@@ -717,6 +771,35 @@ toLinearIdx (SS n) arrvar idxvar =
-- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")))
-- _
+compileShapeQuery :: SNat n -> String -> CExpr
+compileShapeQuery SZ _ = CEStruct (repSTy STNil) []
+compileShapeQuery (SS n) var =
+ CEStruct (repSTy (tTup (sreplicate (SS n) tIx)))
+ [("a", compileShapeQuery n var)
+ ,("b", CEIndex (CELit (var ++ ".buf->sh")) (CELit (show (fromSNat n))))]
+
+compileShapeSize :: SNat n -> String -> CExpr
+compileShapeSize SZ _ = CELit "0"
+compileShapeSize (SS SZ) var = CELit (var ++ ".b")
+compileShapeSize (SS n) var = CEBinop (compileShapeSize n (var ++ ".a")) "*" (CELit (var ++ ".b"))
+
+compileShapeTupIntoArray :: SNat n -> String -> [CExpr]
+compileShapeTupIntoArray = \n var -> map CELit (toList (go n var))
+ where
+ go :: SNat n -> String -> Bag String
+ go SZ _ = mempty
+ go (SS n) var = go n (var ++ ".a") <> pure (var ++ ".b")
+
+-- | Takes variable names with the innermost dimension on the right.
+shapeTupFromLitVars :: SNat n -> [String] -> CExpr
+shapeTupFromLitVars = \n -> go n . reverse
+ where
+ -- takes variables with the innermost dimension at the _head_
+ go :: SNat n -> [String] -> CExpr
+ go SZ [] = CEStruct (repSTy STNil) []
+ go (SS n) (var : vars) = CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) [("a", go n vars), ("b", CELit var)]
+ go _ _ = error "shapeTupFromLitVars: SNat and list do not correspond"
+
compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr
compileOpGeneral op e1 = do
let unary cop = return @(State CompState) $ CECall cop [e1]