diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-02 11:39:37 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-02 11:39:37 +0100 |
commit | 54762c98901b21468fa9ff4205107360c3096cd3 (patch) | |
tree | a765a9af9a19eaae5c4857f9ff4a9d3188708e52 | |
parent | 0fffb5731271a551afcf08878cb021ead8dd1dae (diff) |
compile: Compile constant array literals
-rw-r--r-- | src/Compile.hs | 163 | ||||
-rw-r--r-- | src/Compile/Exec.hs | 2 | ||||
-rw-r--r-- | src/Language.hs | 2 |
3 files changed, 94 insertions, 73 deletions
diff --git a/src/Compile.hs b/src/Compile.hs index 564f697..2a23561 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -4,6 +4,7 @@ {-# LANGUAGE TypeApplications #-} module Compile (compile) where +import Control.Monad (when) import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.State.Strict import Control.Monad.Trans.Writer.CPS @@ -19,6 +20,10 @@ import Data.Set (Set) import Data.Some import qualified Data.Vector as V import Foreign +import System.IO (hPutStrLn, stderr) + +import Prelude hiding ((^)) +import qualified Prelude import Array import AST @@ -28,6 +33,11 @@ import Data import Interpreter.Rep +-- :m *Example Compile AST.UnMonoid +-- :seti -XOverloadedLabels -XGADTs +-- (($ SCons (Value 2) SNil) =<<) $ compile knownEnv $ fromNamed $ lambda @(TScal TF64) #x $ body $ constArr_ @TF64 (arrayGenerate (ShNil `ShCons` 10) (\(IxNil `IxCons` i) -> fromIntegral i)) + + -- In shape and index arrays, the innermost dimension is on the right (last index). -- TODO: array lifetimes in C? @@ -36,7 +46,9 @@ import Interpreter.Rep compile :: SList STy env -> Ex env t -> IO (SList Value env -> IO (Rep t)) compile = \env expr -> do - lib <- buildKernel (compileToString env expr) ["kernel"] + let source = compileToString env expr + hPutStrLn stderr $ "Generated C source: <<<\n" ++ source ++ ">>>" + lib <- buildKernel source ["kernel"] let arg_metrics = reverse (unSList metricsSTy env) (arg_offsets, result_offset) = computeStructOffsets arg_metrics @@ -85,7 +97,7 @@ data CExpr printStructDecl :: StructDecl -> ShowS printStructDecl (StructDecl name contents comment) = showString "typedef struct { " . showString contents . showString " } " . showString name - . showString ("; // " ++ comment) + . showString ";" . (if null comment then id else showString (" // " ++ comment)) printStmt :: Int -> Stmt -> ShowS printStmt indent = \case @@ -187,7 +199,8 @@ genStruct name topty = case topty of TMaybe t -> -- 0 -> nothing, 1 -> just [StructDecl name ("uint8_t tag; " ++ repTy t ++ " a;") com] TArr n t -> - [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromNat n) ++ "]; size_t refc; " ++ repTy t ++ " *a;") com + -- The buffer is trailed by a VLA for the actual array data. + [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromNat n) ++ "]; size_t refc; " ++ repTy t ++ " a[];") "" ,StructDecl name (name ++ "_buf *buf;") com] TScal _ -> [] @@ -289,7 +302,7 @@ compileToString env expr = ,showString (" return ") . printCExpr 0 res . showString ";\n}\n\n" ,showString "void kernel(void *data) {\n" -- Some code here assumes that we're on a 64-bit system, so let's check that - ,showString " if (sizeof(void*) != 8) { abort(); }\n" + ,showString " if (sizeof(void*) != 8 || sizeof(size_t) != 8) { abort(); }\n" ,showString $ " *(" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ") = typed_kernel(" ++ concat (map (\((arg, typ), off, idx) -> "\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")" @@ -372,11 +385,14 @@ deserialise topty ptr off = then return Nothing else Just <$> deserialise t ptr (off + alignmentSTy t) STArr n t -> do - _ <- error "TODO deserialisation of arrays is wrong after refcount introduction" - sh <- peekShape ptr off n - let off1 = off + 8 * fromSNat n + bufptr <- peekByteOff @(Ptr ()) ptr off + sh <- peekShape bufptr 0 n + refc <- peekByteOff @Word64 bufptr (8 * fromSNat n) + let off1 = 8 * fromSNat n + 8 eltsz = sizeofSTy t - Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t ptr (off1 + i * eltsz)) + arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (off1 + i * eltsz)) + when (refc < 2 ^ 62) $ free bufptr + return arr STScal sty -> case sty of STI32 -> peekByteOff @Int32 ptr off STI64 -> peekByteOff @Int64 ptr off @@ -438,7 +454,7 @@ compile' env = \case EVar _ t i -> do let Const var = slistIdx env i case t of - STArr{} -> return $ CELit ("(++" ++ var ++ ".refc, " ++ var ++ ")") + STArr{} -> return $ CELit ("(++" ++ var ++ "->buf.refc, " ++ var ++ ")") _ -> return $ CELit var ELet _ rhs body -> do @@ -446,7 +462,7 @@ compile' env = \case var <- genName emit $ SVarDecl True (repSTy (typeOf rhs)) var e rete <- compile' (Const var `SCons` env) body - releaseVarAlways (typeOf rhs) var + incrementVarAlways Decrement (typeOf rhs) var return rete EPair _ a b -> do @@ -458,7 +474,7 @@ compile' env = \case EFst _ e -> do let STPair _ t2 = typeOf e e' <- compile' env e - case releaseVar t2 of + case incrementVar Decrement t2 of Nothing -> return $ CEProj e' "a" Just f -> do var <- genName emit $ SVarDecl True (repSTy (typeOf e)) var e' @@ -468,7 +484,7 @@ compile' env = \case ESnd _ e -> do let STPair t1 _ = typeOf e e' <- compile' env e - case releaseVar t1 of + case incrementVar Decrement t1 of Nothing -> return $ CEProj e' "b" Just f -> do var <- genName emit $ SVarDecl True (repSTy (typeOf e)) var e' @@ -507,8 +523,8 @@ compile' env = \case -- I know those are not variable names, but it's fine, probably (e2, stmts2) <- scope $ compile' (Const (var ++ ".a") `SCons` env) a (e3, stmts3) <- scope $ compile' (Const (var ++ ".b") `SCons` env) b - ((), stmtsRel1) <- scope $ releaseVarAlways t1 (var ++ ".a") - ((), stmtsRel2) <- scope $ releaseVarAlways t2 (var ++ ".b") + ((), stmtsRel1) <- scope $ incrementVarAlways Decrement t1 (var ++ ".a") + ((), stmtsRel2) <- scope $ incrementVarAlways Decrement t2 (var ++ ".b") retvar <- genName emit $ SVarDeclUninit (repSTy (typeOf a)) retvar emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1) @@ -536,7 +552,7 @@ compile' env = \case var <- genName (e2, stmts2) <- scope $ compile' env a (e3, stmts3) <- scope $ compile' (Const (var ++ ".a") `SCons` env) b - ((), stmtsRel) <- scope $ releaseVarAlways t (var ++ ".a") + ((), stmtsRel) <- scope $ incrementVarAlways Decrement t (var ++ ".a") retvar <- genName emit $ SVarDeclUninit (repSTy (typeOf a)) retvar emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1) @@ -550,18 +566,14 @@ compile' env = \case EConstArr _ n t (Array sh vec) -> do strname <- emitStruct (STArr n (STScal t)) - tldname <- genName' "carray" - emitTLD $ "static const " ++ repSTy (STScal t) ++ " " ++ - tldname ++ "[" ++ show (shapeSize sh) ++ "] = {" ++ - intercalate "," (map (compileScal False t) (toList vec)) ++ - "};" + tldname <- genName' "carraybuf" -- Give it a refcount of _half_ the size_t max, so that it can be -- incremented and decremented at will and will "never" reach anything -- where something happens - emitTLD $ "static " ++ strname ++ "_buf " ++ tldname ++ "_buf = " ++ + emitTLD $ "static " ++ strname ++ "_buf " ++ tldname ++ " = " ++ "(" ++ strname ++ "_buf){.sh = {" ++ intercalate "," (map show (shapeToList sh)) ++ "}, " ++ - ".refc = SIZE_MAX/2, .a = " ++ tldname ++ "};" - return (CEStruct strname [("buf", CEAddrOf (CELit (tldname ++ "_buf")))]) + ".refc = (size_t)1<<63, .a = {" ++ intercalate "," (map (compileScal False t) (toList vec)) ++ "}};" + return (CEStruct strname [("buf", CEAddrOf (CELit tldname))]) -- EBuild _ n a b -> error "TODO" -- genStruct (STArr n (typeOf b)) <> EBuild ext n (compile' a) (compile' b) @@ -614,55 +626,61 @@ compile' env = \case _ -> error "Compile: not implemented" --- | Decrement reference counts in the components of the given variable. -releaseVar :: STy a -> Maybe (String -> CompM ()) -releaseVar ty = - let tree = makeReleaseTree ty - in case tree of RTNoop -> Nothing - _ -> Just $ \var -> releaseVar' var tree - -releaseVarAlways :: STy a -> String -> CompM () -releaseVarAlways ty var = maybe (pure ()) ($ var) (releaseVar ty) - -data ReleaseTree = RTArray -- ^ we've arrived at an array we need to decrement the refcount of - | RTNoop -- ^ don't do anything here - | RTProj String ReleaseTree -- ^ descend one field deeper - | RTCondTag ReleaseTree ReleaseTree -- ^ if tag is 0, first; if 1, second - | RTBoth ReleaseTree ReleaseTree -- ^ do both these paths - -smartRTProj :: String -> ReleaseTree -> ReleaseTree -smartRTProj _ RTNoop = RTNoop -smartRTProj field t = RTProj field t - -smartRTCondTag :: ReleaseTree -> ReleaseTree -> ReleaseTree -smartRTCondTag RTNoop RTNoop = RTNoop -smartRTCondTag t t' = RTCondTag t t' - -smartRTBoth :: ReleaseTree -> ReleaseTree -> ReleaseTree -smartRTBoth RTNoop t = t -smartRTBoth t RTNoop = t -smartRTBoth t t' = RTBoth t t' - -makeReleaseTree :: STy a -> ReleaseTree -makeReleaseTree STNil = RTNoop -makeReleaseTree (STPair a b) = smartRTBoth (smartRTProj "a" (makeReleaseTree a)) - (smartRTProj "b" (makeReleaseTree b)) -makeReleaseTree (STEither a b) = smartRTCondTag (smartRTProj "a" (makeReleaseTree a)) - (smartRTProj "b" (makeReleaseTree b)) -makeReleaseTree (STMaybe t) = smartRTCondTag RTNoop (makeReleaseTree t) -makeReleaseTree (STArr _ _) = RTArray -makeReleaseTree (STScal _) = RTNoop -makeReleaseTree (STAccum _) = RTNoop - -releaseVar' :: String -> ReleaseTree -> CompM () -releaseVar' path RTArray = emit $ SVerbatim (path ++ "--;") -releaseVar' _ RTNoop = pure () -releaseVar' path (RTProj field t) = releaseVar' (path ++ "." ++ field) t -releaseVar' path (RTCondTag t1 t2) = do - ((), stmts1) <- scope $ releaseVar' path t1 - ((), stmts2) <- scope $ releaseVar' path t2 +data Increment = Increment | Decrement + deriving (Show) + +-- | Increment reference counts in the components of the given variable. +incrementVar :: Increment -> STy a -> Maybe (String -> CompM ()) +incrementVar inc ty = + let tree = makeArrayTree ty + in case tree of ATNoop -> Nothing + _ -> Just $ \var -> incrementVar' inc var tree + +incrementVarAlways :: Increment -> STy a -> String -> CompM () +incrementVarAlways inc ty var = maybe (pure ()) ($ var) (incrementVar inc ty) + +data ArrayTree = ATArray -- ^ we've arrived at an array we need to decrement the refcount of + | ATNoop -- ^ don't do anything here + | ATProj String ArrayTree -- ^ descend one field deeper + | ATCondTag ArrayTree ArrayTree -- ^ if tag is 0, first; if 1, second + | ATBoth ArrayTree ArrayTree -- ^ do both these paths + +smartATProj :: String -> ArrayTree -> ArrayTree +smartATProj _ ATNoop = ATNoop +smartATProj field t = ATProj field t + +smartATCondTag :: ArrayTree -> ArrayTree -> ArrayTree +smartATCondTag ATNoop ATNoop = ATNoop +smartATCondTag t t' = ATCondTag t t' + +smartATBoth :: ArrayTree -> ArrayTree -> ArrayTree +smartATBoth ATNoop t = t +smartATBoth t ATNoop = t +smartATBoth t t' = ATBoth t t' + +makeArrayTree :: STy a -> ArrayTree +makeArrayTree STNil = ATNoop +makeArrayTree (STPair a b) = smartATBoth (smartATProj "a" (makeArrayTree a)) + (smartATProj "b" (makeArrayTree b)) +makeArrayTree (STEither a b) = smartATCondTag (smartATProj "a" (makeArrayTree a)) + (smartATProj "b" (makeArrayTree b)) +makeArrayTree (STMaybe t) = smartATCondTag ATNoop (makeArrayTree t) +makeArrayTree (STArr _ _) = ATArray +makeArrayTree (STScal _) = ATNoop +makeArrayTree (STAccum _) = ATNoop + +incrementVar' :: Increment -> String -> ArrayTree -> CompM () +incrementVar' inc path ATArray = + let op = case inc of Increment -> "++" + Decrement -> "--" + in emit $ SVerbatim (path ++ "->buf.refc" ++ op ++ ";") +incrementVar' _ _ ATNoop = pure () +incrementVar' inc path (ATProj field t) = incrementVar' inc (path ++ "." ++ field) t +incrementVar' inc path (ATCondTag t1 t2) = do + ((), stmts1) <- scope $ incrementVar' inc path t1 + ((), stmts2) <- scope $ incrementVar' inc path t2 emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "0")) (BList stmts1) (BList stmts2) -releaseVar' path (RTBoth t1 t2) = releaseVar' path t1 >> releaseVar' path t2 +incrementVar' inc path (ATBoth t1 t2) = incrementVar' inc path t1 >> incrementVar' inc path t2 compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr @@ -721,3 +739,6 @@ compileScal pedantic typ x = case typ of compose :: Foldable t => t (a -> a) -> a -> a compose = foldr (.) id + +(^) :: Num a => a -> Int -> a +(^) = (Prelude.^) diff --git a/src/Compile/Exec.hs b/src/Compile/Exec.hs index 83fcdad..7c5cb79 100644 --- a/src/Compile/Exec.hs +++ b/src/Compile/Exec.hs @@ -28,7 +28,7 @@ buildKernel csource funnames = do path <- mkdtemp template let outso = path ++ "/out.so" - let args = ["-O3", "-march=native", "-shared", "-fPIC", "-std=c99", "-x", "c", "-o", outso, "-"] + let args = ["-O3", "-march=native", "-shared", "-fPIC", "-std=c99", "-x", "c", "-o", outso, "-", "-Wall", "-Wextra", "-Wno-unused-parameter"] _ <- readProcess "gcc" args csource hPutStrLn stderr $ "[chad] loading kernel " ++ path diff --git a/src/Language.hs b/src/Language.hs index a7737e0..70cc4f9 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -69,7 +69,7 @@ inr = NEInr knownTy case_ :: NExpr env (TEither a b) -> (Var name1 a :-> NExpr ('(name1, a) : env) c) -> (Var name2 b :-> NExpr ('(name2, b) : env) c) -> NExpr env c case_ e (v1 :-> e1) (v2 :-> e2) = NECase e v1 e1 v2 e2 -constArr_ :: (KnownNat n, KnownScalTy t) => Array n (ScalRep t) -> NExpr env (TArr n (TScal t)) +constArr_ :: forall t n env. (KnownNat n, KnownScalTy t) => Array n (ScalRep t) -> NExpr env (TArr n (TScal t)) constArr_ x = let ty = knownScalTy in case scalRepIsShow ty of |