diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-02 00:21:36 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-02 00:21:36 +0100 |
commit | 0fffb5731271a551afcf08878cb021ead8dd1dae (patch) | |
tree | fd4fff6010db0fb63515930708b3d8bfd234c367 | |
parent | 0ebdcff2aa06ee95f95597f2984e2cd335899d37 (diff) |
compile: WIP reference-counted arrays
-rw-r--r-- | src/Array.hs | 7 | ||||
-rw-r--r-- | src/Compile.hs | 195 | ||||
-rw-r--r-- | src/Compile/Exec.hs | 2 | ||||
-rw-r--r-- | src/Data.hs | 3 |
4 files changed, 159 insertions, 48 deletions
diff --git a/src/Array.hs b/src/Array.hs index 82c3f31..059600f 100644 --- a/src/Array.hs +++ b/src/Array.hs @@ -63,6 +63,13 @@ emptyShape (SS m) = emptyShape m `ShCons` 0 enumShape :: Shape n -> [Index n] enumShape sh = map (fromLinearIndex sh) [0 .. shapeSize sh - 1] +shapeToList :: Shape n -> [Int] +shapeToList = go [] + where + go :: [Int] -> Shape n -> [Int] + go suff ShNil = suff + go suff (sh `ShCons` n) = go (n:suff) sh + -- | TODO: this Vector is a boxed vector, which is horrendously inefficient. data Array (n :: Nat) t = Array (Shape n) (Vector t) diff --git a/src/Compile.hs b/src/Compile.hs index 582a1df..564f697 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -67,8 +67,8 @@ data Stmt = SVarDecl Bool String String CExpr -- ^ const, type, variable name, right-hand side | SVarDeclUninit String String -- ^ type, variable name (no initialiser) | SAsg String CExpr -- ^ variable name, right-hand side - | SBlock [Stmt] - | SIf CExpr [Stmt] [Stmt] + | SBlock (Bag Stmt) + | SIf CExpr (Bag Stmt) (Bag Stmt) | SVerbatim String -- ^ no implicit ';', just printed as-is deriving (Show) @@ -76,6 +76,7 @@ data CExpr = CELit String -- ^ inserted as-is, assumed no parentheses needed | CEStruct String [(String, CExpr)] -- ^ struct construction literal: `(name){.field=expr}` | CEProj CExpr String -- ^ field projection: expr.field + | CEAddrOf CExpr -- ^ &expr | CECall String [CExpr] -- ^ function(arg1, ..., argn) | CEBinop CExpr String CExpr -- ^ expr + expr | CEIf CExpr CExpr CExpr -- ^ expr ? expr : expr @@ -93,7 +94,7 @@ printStmt indent = \case SAsg name rhs -> showString (name ++ " = ") . printCExpr 0 rhs . showString ";" SBlock stmts -> showString "{" - . compose [showString ("\n" ++ replicate (2*indent+2) ' ') . printStmt (indent+1) stmt | stmt <- stmts] + . compose [showString ("\n" ++ replicate (2*indent+2) ' ') . printStmt (indent+1) stmt | stmt <- toList stmts] . showString ("\n" ++ replicate (2*indent) ' ' ++ "}") SIf cond b1 b2 -> showString "if (" . printCExpr 0 cond . showString ") " @@ -104,6 +105,7 @@ printStmt indent = \case -- * 0: top level -- * 1: in 1st or 2nd component of a ternary operator (technically same as top level, but readability) -- * 2-...: various operators (see precTable) +-- * 80: address-of operator (&) -- * 98: inside unknown operator -- * 99: left of a field projection -- Unlisted operators are conservatively written with full parentheses. @@ -117,6 +119,7 @@ printCExpr d = \case | (n, e) <- pairs]) . showString "}" CEProj e name -> printCExpr 99 e . showString ("." ++ name) + CEAddrOf e -> showParen (d > 80) $ showString "&" . printCExpr 80 e CECall n es -> showString (n ++ "(") . compose (intersperse (showString ", ") (map (printCExpr 0) es)) . showString ")" CEBinop e1 n e2 -> @@ -173,22 +176,23 @@ genStructName = \t -> "ty_" ++ gen t where TBool -> "b" gen (TAccum t) = 'C' : gen t -genStruct :: String -> Ty -> Maybe StructDecl +genStruct :: String -> Ty -> [StructDecl] genStruct name topty = case topty of TNil -> - Just $ StructDecl name "" com + [StructDecl name "" com] TPair a b -> - Just $ StructDecl name (repTy a ++ " a; " ++ repTy b ++ " b;") com + [StructDecl name (repTy a ++ " a; " ++ repTy b ++ " b;") com] TEither a b -> -- 0 -> a, 1 -> b - Just $ StructDecl name ("uint8_t tag; union { " ++ repTy a ++ " a; " ++ repTy b ++ " b; };") com + [StructDecl name ("uint8_t tag; union { " ++ repTy a ++ " a; " ++ repTy b ++ " b; };") com] TMaybe t -> -- 0 -> nothing, 1 -> just - Just $ StructDecl name ("uint8_t tag; " ++ repTy t ++ " a;") com + [StructDecl name ("uint8_t tag; " ++ repTy t ++ " a;") com] TArr n t -> - Just $ StructDecl name ("size_t sh[" ++ show (fromNat n) ++ "]; " ++ repTy t ++ " *a;") com + [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromNat n) ++ "]; size_t refc; " ++ repTy t ++ " *a;") com + ,StructDecl name (name ++ "_buf *buf;") com] TScal _ -> - Nothing + [] TAccum t -> - Just $ StructDecl name (repTy t ++ " a;") com + [StructDecl name (repTy t ++ " a;") com] where com = ppTy 0 topty @@ -202,7 +206,8 @@ genStructs ty = do if seen then pure () else do - -- already mark this struct as generated now, so we don't generate it twice + -- already mark this struct as generated now, so we don't generate it + -- twice (unnecessary because no recursive types, but y'know) lift $ modify (Set.insert name) case ty of @@ -214,13 +219,14 @@ genStructs ty = do TScal _ -> pure () TAccum t -> genStructs t - tell (maybe mempty pure (genStruct name ty)) + tell (BList (genStruct name ty)) genAllStructs :: Foldable t => t Ty -> [StructDecl] genAllStructs tys = toList $ evalState (execWriterT (mapM_ genStructs tys)) mempty data CompState = CompState { csStructs :: Set Ty + , csTopLevelDecls :: Bag String , csStmts :: Bag Stmt , csNextId :: Int } deriving (Show) @@ -230,8 +236,11 @@ type CompM a = State CompState a genId :: CompM Int genId = state $ \s -> (csNextId s, s { csNextId = csNextId s + 1 }) +genName' :: String -> CompM String +genName' prefix = (prefix ++) . show <$> genId + genName :: CompM String -genName = ('x' :) . show <$> genId +genName = genName' "x" emit :: Stmt -> CompM () emit stmt = modify $ \s -> s { csStmts = csStmts s <> pure stmt } @@ -249,13 +258,16 @@ emitStruct ty = do modify $ \s -> s { csStructs = Set.insert ty' (csStructs s) } return (genStructName ty') +emitTLD :: String -> CompM () +emitTLD decl = modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl } + nameEnv :: SList f env -> SList (Const String) env nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1)) compileToString :: SList STy env -> Ex env t -> String compileToString env expr = let args = nameEnv env - (res, s) = runState (compile' args expr) (CompState mempty mempty 1) + (res, s) = runState (compile' args expr) (CompState mempty mempty mempty 1) structs = genAllStructs (csStructs s <> Set.fromList (unSList unSTy env)) (arg_pairs, arg_metrics) = @@ -268,6 +280,7 @@ compileToString env expr = ,showString "#include <stdlib.h>\n\n" ,compose $ map (\sd -> printStructDecl sd . showString "\n") structs ,showString "\n" + ,compose [showString str . showString "\n\n" | str <- toList (csTopLevelDecls s)] ,showString $ "static " ++ repSTy (typeOf expr) ++ " typed_kernel(" ++ intercalate ", " (reverse (unSList (\(Product.Pair t (Const n)) -> repSTy t ++ " " ++ n) (slistZip env args))) ++ @@ -319,6 +332,7 @@ serialise topty topval ptr off k = pokeByteOff ptr off (1 :: Word8) serialise t x ptr (off + alignmentSTy t) k (STArr n t, Array sh vec) -> do + _ <- error "TODO serialisation of arrays is wrong after refcount introduction" pokeShape ptr off n sh let off1 = off + 8 * fromSNat n eltsz = sizeofSTy t @@ -358,6 +372,7 @@ deserialise topty ptr off = then return Nothing else Just <$> deserialise t ptr (off + alignmentSTy t) STArr n t -> do + _ <- error "TODO deserialisation of arrays is wrong after refcount introduction" sh <- peekShape ptr off n let off1 = off + 8 * fromSNat n eltsz = sizeofSTy t @@ -420,13 +435,19 @@ peekShape ptr off = \case compile' :: SList (Const String) env -> Ex env t -> CompM CExpr compile' env = \case - EVar _ _ i -> return $ CELit (getConst (slistIdx env i)) + EVar _ t i -> do + let Const var = slistIdx env i + case t of + STArr{} -> return $ CELit ("(++" ++ var ++ ".refc, " ++ var ++ ")") + _ -> return $ CELit var ELet _ rhs body -> do e <- compile' env rhs var <- genName emit $ SVarDecl True (repSTy (typeOf rhs)) var e - compile' (Const var `SCons` env) body + rete <- compile' (Const var `SCons` env) body + releaseVarAlways (typeOf rhs) var + return rete EPair _ a b -> do name <- emitStruct (STPair (typeOf a) (typeOf b)) @@ -434,9 +455,25 @@ compile' env = \case e2 <- compile' env b return $ CEStruct name [("a", e1), ("b", e2)] - EFst _ e -> CEProj <$> compile' env e <*> pure "a" - - ESnd _ e -> CEProj <$> compile' env e <*> pure "b" + EFst _ e -> do + let STPair _ t2 = typeOf e + e' <- compile' env e + case releaseVar t2 of + Nothing -> return $ CEProj e' "a" + Just f -> do var <- genName + emit $ SVarDecl True (repSTy (typeOf e)) var e' + f (var ++ ".b") + return $ CEProj (CELit var) "a" + + ESnd _ e -> do + let STPair t1 _ = typeOf e + e' <- compile' env e + case releaseVar t1 of + Nothing -> return $ CEProj e' "b" + Just f -> do var <- genName + emit $ SVarDecl True (repSTy (typeOf e)) var e' + f (var ++ ".a") + return $ CEProj (CELit var) "b" ENil _ -> do name <- emitStruct STNil @@ -459,28 +496,28 @@ compile' env = \case retvar <- genName emit $ SVarDeclUninit (repSTy (typeOf a)) retvar emit $ SIf e1 - (stmts2 <> pure (SAsg retvar e2)) - (stmts3 <> pure (SAsg retvar e3)) + (BList stmts2 <> pure (SAsg retvar e2)) + (BList stmts3 <> pure (SAsg retvar e3)) return (CELit retvar) ECase _ e a b -> do let STEither t1 t2 = typeOf e e1 <- compile' env e var <- genName - fieldvar <- genName - (e2, stmts2) <- scope $ compile' (Const fieldvar `SCons` env) a - (e3, stmts3) <- scope $ compile' (Const fieldvar `SCons` env) b + -- I know those are not variable names, but it's fine, probably + (e2, stmts2) <- scope $ compile' (Const (var ++ ".a") `SCons` env) a + (e3, stmts3) <- scope $ compile' (Const (var ++ ".b") `SCons` env) b + ((), stmtsRel1) <- scope $ releaseVarAlways t1 (var ++ ".a") + ((), stmtsRel2) <- scope $ releaseVarAlways t2 (var ++ ".b") retvar <- genName emit $ SVarDeclUninit (repSTy (typeOf a)) retvar emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1) <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) - (pure (SVarDecl True (repSTy t1) fieldvar - (CEProj (CELit var) "a")) - <> stmts2 + (BList stmts2 + <> BList stmtsRel1 <> pure (SAsg retvar e2)) - (pure (SVarDecl True (repSTy t2) fieldvar - (CEProj (CELit var) "b")) - <> stmts3 + (BList stmts3 + <> BList stmtsRel2 <> pure (SAsg retvar e3)))) return (CELit retvar) @@ -494,26 +531,37 @@ compile' env = \case return $ CEStruct name [("tag", CELit "1"), ("a", e1)] EMaybe _ a b e -> do + let STMaybe t = typeOf e e1 <- compile' env e var <- genName - fieldvar <- genName (e2, stmts2) <- scope $ compile' env a - (e3, stmts3) <- scope $ compile' (Const fieldvar `SCons` env) b + (e3, stmts3) <- scope $ compile' (Const (var ++ ".a") `SCons` env) b + ((), stmtsRel) <- scope $ releaseVarAlways t (var ++ ".a") retvar <- genName emit $ SVarDeclUninit (repSTy (typeOf a)) retvar emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1) <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) - (stmts2 + (BList stmts2 <> pure (SAsg retvar e2)) - (pure (SVarDecl True (repSTy (typeOf b)) fieldvar - (CEProj (CELit var) "a")) - <> stmts3 + (BList stmts3 + <> BList stmtsRel <> pure (SAsg retvar e3)))) return (CELit retvar) - -- EConstArr _ n t arr -> do - -- name <- emitStruct (STArr n (STScal t)) - -- error "TODO" + EConstArr _ n t (Array sh vec) -> do + strname <- emitStruct (STArr n (STScal t)) + tldname <- genName' "carray" + emitTLD $ "static const " ++ repSTy (STScal t) ++ " " ++ + tldname ++ "[" ++ show (shapeSize sh) ++ "] = {" ++ + intercalate "," (map (compileScal False t) (toList vec)) ++ + "};" + -- Give it a refcount of _half_ the size_t max, so that it can be + -- incremented and decremented at will and will "never" reach anything + -- where something happens + emitTLD $ "static " ++ strname ++ "_buf " ++ tldname ++ "_buf = " ++ + "(" ++ strname ++ "_buf){.sh = {" ++ intercalate "," (map show (shapeToList sh)) ++ "}, " ++ + ".refc = SIZE_MAX/2, .a = " ++ tldname ++ "};" + return (CEStruct strname [("buf", CEAddrOf (CELit (tldname ++ "_buf")))]) -- EBuild _ n a b -> error "TODO" -- genStruct (STArr n (typeOf b)) <> EBuild ext n (compile' a) (compile' b) @@ -529,12 +577,7 @@ compile' env = \case -- EMinimum1Inner _ e -> error "TODO" -- EMinimum1Inner ext (compile' e) - EConst _ t x -> case t of - STI32 -> return $ CELit $ "(int32_t)" ++ show x - STI64 -> return $ CELit $ "(int64_t)" ++ show x - STF32 -> return $ CELit $ show x ++ "f" - STF64 -> return $ CELit $ show x - STBool -> return $ CELit $ if x then "1" else "0" + EConst _ t x -> return $ CELit $ compileScal True t x -- EIdx0 _ e -> error "TODO" -- EIdx0 ext (compile' e) @@ -571,6 +614,57 @@ compile' env = \case _ -> error "Compile: not implemented" +-- | Decrement reference counts in the components of the given variable. +releaseVar :: STy a -> Maybe (String -> CompM ()) +releaseVar ty = + let tree = makeReleaseTree ty + in case tree of RTNoop -> Nothing + _ -> Just $ \var -> releaseVar' var tree + +releaseVarAlways :: STy a -> String -> CompM () +releaseVarAlways ty var = maybe (pure ()) ($ var) (releaseVar ty) + +data ReleaseTree = RTArray -- ^ we've arrived at an array we need to decrement the refcount of + | RTNoop -- ^ don't do anything here + | RTProj String ReleaseTree -- ^ descend one field deeper + | RTCondTag ReleaseTree ReleaseTree -- ^ if tag is 0, first; if 1, second + | RTBoth ReleaseTree ReleaseTree -- ^ do both these paths + +smartRTProj :: String -> ReleaseTree -> ReleaseTree +smartRTProj _ RTNoop = RTNoop +smartRTProj field t = RTProj field t + +smartRTCondTag :: ReleaseTree -> ReleaseTree -> ReleaseTree +smartRTCondTag RTNoop RTNoop = RTNoop +smartRTCondTag t t' = RTCondTag t t' + +smartRTBoth :: ReleaseTree -> ReleaseTree -> ReleaseTree +smartRTBoth RTNoop t = t +smartRTBoth t RTNoop = t +smartRTBoth t t' = RTBoth t t' + +makeReleaseTree :: STy a -> ReleaseTree +makeReleaseTree STNil = RTNoop +makeReleaseTree (STPair a b) = smartRTBoth (smartRTProj "a" (makeReleaseTree a)) + (smartRTProj "b" (makeReleaseTree b)) +makeReleaseTree (STEither a b) = smartRTCondTag (smartRTProj "a" (makeReleaseTree a)) + (smartRTProj "b" (makeReleaseTree b)) +makeReleaseTree (STMaybe t) = smartRTCondTag RTNoop (makeReleaseTree t) +makeReleaseTree (STArr _ _) = RTArray +makeReleaseTree (STScal _) = RTNoop +makeReleaseTree (STAccum _) = RTNoop + +releaseVar' :: String -> ReleaseTree -> CompM () +releaseVar' path RTArray = emit $ SVerbatim (path ++ "--;") +releaseVar' _ RTNoop = pure () +releaseVar' path (RTProj field t) = releaseVar' (path ++ "." ++ field) t +releaseVar' path (RTCondTag t1 t2) = do + ((), stmts1) <- scope $ releaseVar' path t1 + ((), stmts2) <- scope $ releaseVar' path t2 + emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "0")) (BList stmts1) (BList stmts2) +releaseVar' path (RTBoth t1 t2) = releaseVar' path t1 >> releaseVar' path t2 + + compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr compileOpGeneral op e1 = do let unary cop = return @(State CompState) $ CECall cop [e1] @@ -616,5 +710,14 @@ compileOpPair op e1 e2 = do OIDiv _ -> binary "/" _ -> error "compileOpPair: got unary operator" +-- | Bool: whether to ensure that the literal itself already has the appropriate type +compileScal :: Bool -> SScalTy t -> ScalRep t -> String +compileScal pedantic typ x = case typ of + STI32 -> (if pedantic then "(int32_t)" else "") ++ show x + STI64 -> (if pedantic then "(int64_t)" else "") ++ show x + STF32 -> show x ++ "f" + STF64 -> show x + STBool -> if x then "1" else "0" + compose :: Foldable t => t (a -> a) -> a -> a compose = foldr (.) id diff --git a/src/Compile/Exec.hs b/src/Compile/Exec.hs index 163be2b..83fcdad 100644 --- a/src/Compile/Exec.hs +++ b/src/Compile/Exec.hs @@ -28,7 +28,7 @@ buildKernel csource funnames = do path <- mkdtemp template let outso = path ++ "/out.so" - let args = ["-O3", "-march=native", "-shared", "-fPIC", "-x", "c", "-o", outso, "-"] + let args = ["-O3", "-march=native", "-shared", "-fPIC", "-std=c99", "-x", "c", "-o", outso, "-"] _ <- readProcess "gcc" args csource hPutStrLn stderr $ "[chad] loading kernel " ++ path diff --git a/src/Data.hs b/src/Data.hs index 8005737..00790fe 100644 --- a/src/Data.hs +++ b/src/Data.hs @@ -129,7 +129,7 @@ vecGenerate = \n f -> go n f SZ unsafeCoerceRefl :: a :~: b unsafeCoerceRefl = unsafeCoerce Refl -data Bag t = BNone | BOne t | BTwo (Bag t) (Bag t) | BMany [Bag t] +data Bag t = BNone | BOne t | BTwo !(Bag t) !(Bag t) | BMany [Bag t] | BList [t] deriving (Show, Functor, Foldable, Traversable) -- | This instance is mostly there just for 'pure' @@ -139,6 +139,7 @@ instance Applicative Bag where BOne f <*> b = f <$> b BTwo b1 b2 <*> b = BTwo (b1 <*> b) (b2 <*> b) BMany bs <*> b = BMany (map (<*> b) bs) + BList bs <*> b = BMany (map (<$> b) bs) instance Semigroup (Bag t) where (<>) = BTwo instance Monoid (Bag t) where mempty = BNone |