summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-02 00:21:36 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-02 00:21:36 +0100
commit0fffb5731271a551afcf08878cb021ead8dd1dae (patch)
treefd4fff6010db0fb63515930708b3d8bfd234c367
parent0ebdcff2aa06ee95f95597f2984e2cd335899d37 (diff)
compile: WIP reference-counted arrays
-rw-r--r--src/Array.hs7
-rw-r--r--src/Compile.hs195
-rw-r--r--src/Compile/Exec.hs2
-rw-r--r--src/Data.hs3
4 files changed, 159 insertions, 48 deletions
diff --git a/src/Array.hs b/src/Array.hs
index 82c3f31..059600f 100644
--- a/src/Array.hs
+++ b/src/Array.hs
@@ -63,6 +63,13 @@ emptyShape (SS m) = emptyShape m `ShCons` 0
enumShape :: Shape n -> [Index n]
enumShape sh = map (fromLinearIndex sh) [0 .. shapeSize sh - 1]
+shapeToList :: Shape n -> [Int]
+shapeToList = go []
+ where
+ go :: [Int] -> Shape n -> [Int]
+ go suff ShNil = suff
+ go suff (sh `ShCons` n) = go (n:suff) sh
+
-- | TODO: this Vector is a boxed vector, which is horrendously inefficient.
data Array (n :: Nat) t = Array (Shape n) (Vector t)
diff --git a/src/Compile.hs b/src/Compile.hs
index 582a1df..564f697 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -67,8 +67,8 @@ data Stmt
= SVarDecl Bool String String CExpr -- ^ const, type, variable name, right-hand side
| SVarDeclUninit String String -- ^ type, variable name (no initialiser)
| SAsg String CExpr -- ^ variable name, right-hand side
- | SBlock [Stmt]
- | SIf CExpr [Stmt] [Stmt]
+ | SBlock (Bag Stmt)
+ | SIf CExpr (Bag Stmt) (Bag Stmt)
| SVerbatim String -- ^ no implicit ';', just printed as-is
deriving (Show)
@@ -76,6 +76,7 @@ data CExpr
= CELit String -- ^ inserted as-is, assumed no parentheses needed
| CEStruct String [(String, CExpr)] -- ^ struct construction literal: `(name){.field=expr}`
| CEProj CExpr String -- ^ field projection: expr.field
+ | CEAddrOf CExpr -- ^ &expr
| CECall String [CExpr] -- ^ function(arg1, ..., argn)
| CEBinop CExpr String CExpr -- ^ expr + expr
| CEIf CExpr CExpr CExpr -- ^ expr ? expr : expr
@@ -93,7 +94,7 @@ printStmt indent = \case
SAsg name rhs -> showString (name ++ " = ") . printCExpr 0 rhs . showString ";"
SBlock stmts ->
showString "{"
- . compose [showString ("\n" ++ replicate (2*indent+2) ' ') . printStmt (indent+1) stmt | stmt <- stmts]
+ . compose [showString ("\n" ++ replicate (2*indent+2) ' ') . printStmt (indent+1) stmt | stmt <- toList stmts]
. showString ("\n" ++ replicate (2*indent) ' ' ++ "}")
SIf cond b1 b2 ->
showString "if (" . printCExpr 0 cond . showString ") "
@@ -104,6 +105,7 @@ printStmt indent = \case
-- * 0: top level
-- * 1: in 1st or 2nd component of a ternary operator (technically same as top level, but readability)
-- * 2-...: various operators (see precTable)
+-- * 80: address-of operator (&)
-- * 98: inside unknown operator
-- * 99: left of a field projection
-- Unlisted operators are conservatively written with full parentheses.
@@ -117,6 +119,7 @@ printCExpr d = \case
| (n, e) <- pairs])
. showString "}"
CEProj e name -> printCExpr 99 e . showString ("." ++ name)
+ CEAddrOf e -> showParen (d > 80) $ showString "&" . printCExpr 80 e
CECall n es ->
showString (n ++ "(") . compose (intersperse (showString ", ") (map (printCExpr 0) es)) . showString ")"
CEBinop e1 n e2 ->
@@ -173,22 +176,23 @@ genStructName = \t -> "ty_" ++ gen t where
TBool -> "b"
gen (TAccum t) = 'C' : gen t
-genStruct :: String -> Ty -> Maybe StructDecl
+genStruct :: String -> Ty -> [StructDecl]
genStruct name topty = case topty of
TNil ->
- Just $ StructDecl name "" com
+ [StructDecl name "" com]
TPair a b ->
- Just $ StructDecl name (repTy a ++ " a; " ++ repTy b ++ " b;") com
+ [StructDecl name (repTy a ++ " a; " ++ repTy b ++ " b;") com]
TEither a b -> -- 0 -> a, 1 -> b
- Just $ StructDecl name ("uint8_t tag; union { " ++ repTy a ++ " a; " ++ repTy b ++ " b; };") com
+ [StructDecl name ("uint8_t tag; union { " ++ repTy a ++ " a; " ++ repTy b ++ " b; };") com]
TMaybe t -> -- 0 -> nothing, 1 -> just
- Just $ StructDecl name ("uint8_t tag; " ++ repTy t ++ " a;") com
+ [StructDecl name ("uint8_t tag; " ++ repTy t ++ " a;") com]
TArr n t ->
- Just $ StructDecl name ("size_t sh[" ++ show (fromNat n) ++ "]; " ++ repTy t ++ " *a;") com
+ [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromNat n) ++ "]; size_t refc; " ++ repTy t ++ " *a;") com
+ ,StructDecl name (name ++ "_buf *buf;") com]
TScal _ ->
- Nothing
+ []
TAccum t ->
- Just $ StructDecl name (repTy t ++ " a;") com
+ [StructDecl name (repTy t ++ " a;") com]
where
com = ppTy 0 topty
@@ -202,7 +206,8 @@ genStructs ty = do
if seen
then pure ()
else do
- -- already mark this struct as generated now, so we don't generate it twice
+ -- already mark this struct as generated now, so we don't generate it
+ -- twice (unnecessary because no recursive types, but y'know)
lift $ modify (Set.insert name)
case ty of
@@ -214,13 +219,14 @@ genStructs ty = do
TScal _ -> pure ()
TAccum t -> genStructs t
- tell (maybe mempty pure (genStruct name ty))
+ tell (BList (genStruct name ty))
genAllStructs :: Foldable t => t Ty -> [StructDecl]
genAllStructs tys = toList $ evalState (execWriterT (mapM_ genStructs tys)) mempty
data CompState = CompState
{ csStructs :: Set Ty
+ , csTopLevelDecls :: Bag String
, csStmts :: Bag Stmt
, csNextId :: Int }
deriving (Show)
@@ -230,8 +236,11 @@ type CompM a = State CompState a
genId :: CompM Int
genId = state $ \s -> (csNextId s, s { csNextId = csNextId s + 1 })
+genName' :: String -> CompM String
+genName' prefix = (prefix ++) . show <$> genId
+
genName :: CompM String
-genName = ('x' :) . show <$> genId
+genName = genName' "x"
emit :: Stmt -> CompM ()
emit stmt = modify $ \s -> s { csStmts = csStmts s <> pure stmt }
@@ -249,13 +258,16 @@ emitStruct ty = do
modify $ \s -> s { csStructs = Set.insert ty' (csStructs s) }
return (genStructName ty')
+emitTLD :: String -> CompM ()
+emitTLD decl = modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl }
+
nameEnv :: SList f env -> SList (Const String) env
nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1))
compileToString :: SList STy env -> Ex env t -> String
compileToString env expr =
let args = nameEnv env
- (res, s) = runState (compile' args expr) (CompState mempty mempty 1)
+ (res, s) = runState (compile' args expr) (CompState mempty mempty mempty 1)
structs = genAllStructs (csStructs s <> Set.fromList (unSList unSTy env))
(arg_pairs, arg_metrics) =
@@ -268,6 +280,7 @@ compileToString env expr =
,showString "#include <stdlib.h>\n\n"
,compose $ map (\sd -> printStructDecl sd . showString "\n") structs
,showString "\n"
+ ,compose [showString str . showString "\n\n" | str <- toList (csTopLevelDecls s)]
,showString $
"static " ++ repSTy (typeOf expr) ++ " typed_kernel(" ++
intercalate ", " (reverse (unSList (\(Product.Pair t (Const n)) -> repSTy t ++ " " ++ n) (slistZip env args))) ++
@@ -319,6 +332,7 @@ serialise topty topval ptr off k =
pokeByteOff ptr off (1 :: Word8)
serialise t x ptr (off + alignmentSTy t) k
(STArr n t, Array sh vec) -> do
+ _ <- error "TODO serialisation of arrays is wrong after refcount introduction"
pokeShape ptr off n sh
let off1 = off + 8 * fromSNat n
eltsz = sizeofSTy t
@@ -358,6 +372,7 @@ deserialise topty ptr off =
then return Nothing
else Just <$> deserialise t ptr (off + alignmentSTy t)
STArr n t -> do
+ _ <- error "TODO deserialisation of arrays is wrong after refcount introduction"
sh <- peekShape ptr off n
let off1 = off + 8 * fromSNat n
eltsz = sizeofSTy t
@@ -420,13 +435,19 @@ peekShape ptr off = \case
compile' :: SList (Const String) env -> Ex env t -> CompM CExpr
compile' env = \case
- EVar _ _ i -> return $ CELit (getConst (slistIdx env i))
+ EVar _ t i -> do
+ let Const var = slistIdx env i
+ case t of
+ STArr{} -> return $ CELit ("(++" ++ var ++ ".refc, " ++ var ++ ")")
+ _ -> return $ CELit var
ELet _ rhs body -> do
e <- compile' env rhs
var <- genName
emit $ SVarDecl True (repSTy (typeOf rhs)) var e
- compile' (Const var `SCons` env) body
+ rete <- compile' (Const var `SCons` env) body
+ releaseVarAlways (typeOf rhs) var
+ return rete
EPair _ a b -> do
name <- emitStruct (STPair (typeOf a) (typeOf b))
@@ -434,9 +455,25 @@ compile' env = \case
e2 <- compile' env b
return $ CEStruct name [("a", e1), ("b", e2)]
- EFst _ e -> CEProj <$> compile' env e <*> pure "a"
-
- ESnd _ e -> CEProj <$> compile' env e <*> pure "b"
+ EFst _ e -> do
+ let STPair _ t2 = typeOf e
+ e' <- compile' env e
+ case releaseVar t2 of
+ Nothing -> return $ CEProj e' "a"
+ Just f -> do var <- genName
+ emit $ SVarDecl True (repSTy (typeOf e)) var e'
+ f (var ++ ".b")
+ return $ CEProj (CELit var) "a"
+
+ ESnd _ e -> do
+ let STPair t1 _ = typeOf e
+ e' <- compile' env e
+ case releaseVar t1 of
+ Nothing -> return $ CEProj e' "b"
+ Just f -> do var <- genName
+ emit $ SVarDecl True (repSTy (typeOf e)) var e'
+ f (var ++ ".a")
+ return $ CEProj (CELit var) "b"
ENil _ -> do
name <- emitStruct STNil
@@ -459,28 +496,28 @@ compile' env = \case
retvar <- genName
emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
emit $ SIf e1
- (stmts2 <> pure (SAsg retvar e2))
- (stmts3 <> pure (SAsg retvar e3))
+ (BList stmts2 <> pure (SAsg retvar e2))
+ (BList stmts3 <> pure (SAsg retvar e3))
return (CELit retvar)
ECase _ e a b -> do
let STEither t1 t2 = typeOf e
e1 <- compile' env e
var <- genName
- fieldvar <- genName
- (e2, stmts2) <- scope $ compile' (Const fieldvar `SCons` env) a
- (e3, stmts3) <- scope $ compile' (Const fieldvar `SCons` env) b
+ -- I know those are not variable names, but it's fine, probably
+ (e2, stmts2) <- scope $ compile' (Const (var ++ ".a") `SCons` env) a
+ (e3, stmts3) <- scope $ compile' (Const (var ++ ".b") `SCons` env) b
+ ((), stmtsRel1) <- scope $ releaseVarAlways t1 (var ++ ".a")
+ ((), stmtsRel2) <- scope $ releaseVarAlways t2 (var ++ ".b")
retvar <- genName
emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1)
<> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0"))
- (pure (SVarDecl True (repSTy t1) fieldvar
- (CEProj (CELit var) "a"))
- <> stmts2
+ (BList stmts2
+ <> BList stmtsRel1
<> pure (SAsg retvar e2))
- (pure (SVarDecl True (repSTy t2) fieldvar
- (CEProj (CELit var) "b"))
- <> stmts3
+ (BList stmts3
+ <> BList stmtsRel2
<> pure (SAsg retvar e3))))
return (CELit retvar)
@@ -494,26 +531,37 @@ compile' env = \case
return $ CEStruct name [("tag", CELit "1"), ("a", e1)]
EMaybe _ a b e -> do
+ let STMaybe t = typeOf e
e1 <- compile' env e
var <- genName
- fieldvar <- genName
(e2, stmts2) <- scope $ compile' env a
- (e3, stmts3) <- scope $ compile' (Const fieldvar `SCons` env) b
+ (e3, stmts3) <- scope $ compile' (Const (var ++ ".a") `SCons` env) b
+ ((), stmtsRel) <- scope $ releaseVarAlways t (var ++ ".a")
retvar <- genName
emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1)
<> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0"))
- (stmts2
+ (BList stmts2
<> pure (SAsg retvar e2))
- (pure (SVarDecl True (repSTy (typeOf b)) fieldvar
- (CEProj (CELit var) "a"))
- <> stmts3
+ (BList stmts3
+ <> BList stmtsRel
<> pure (SAsg retvar e3))))
return (CELit retvar)
- -- EConstArr _ n t arr -> do
- -- name <- emitStruct (STArr n (STScal t))
- -- error "TODO"
+ EConstArr _ n t (Array sh vec) -> do
+ strname <- emitStruct (STArr n (STScal t))
+ tldname <- genName' "carray"
+ emitTLD $ "static const " ++ repSTy (STScal t) ++ " " ++
+ tldname ++ "[" ++ show (shapeSize sh) ++ "] = {" ++
+ intercalate "," (map (compileScal False t) (toList vec)) ++
+ "};"
+ -- Give it a refcount of _half_ the size_t max, so that it can be
+ -- incremented and decremented at will and will "never" reach anything
+ -- where something happens
+ emitTLD $ "static " ++ strname ++ "_buf " ++ tldname ++ "_buf = " ++
+ "(" ++ strname ++ "_buf){.sh = {" ++ intercalate "," (map show (shapeToList sh)) ++ "}, " ++
+ ".refc = SIZE_MAX/2, .a = " ++ tldname ++ "};"
+ return (CEStruct strname [("buf", CEAddrOf (CELit (tldname ++ "_buf")))])
-- EBuild _ n a b -> error "TODO" -- genStruct (STArr n (typeOf b)) <> EBuild ext n (compile' a) (compile' b)
@@ -529,12 +577,7 @@ compile' env = \case
-- EMinimum1Inner _ e -> error "TODO" -- EMinimum1Inner ext (compile' e)
- EConst _ t x -> case t of
- STI32 -> return $ CELit $ "(int32_t)" ++ show x
- STI64 -> return $ CELit $ "(int64_t)" ++ show x
- STF32 -> return $ CELit $ show x ++ "f"
- STF64 -> return $ CELit $ show x
- STBool -> return $ CELit $ if x then "1" else "0"
+ EConst _ t x -> return $ CELit $ compileScal True t x
-- EIdx0 _ e -> error "TODO" -- EIdx0 ext (compile' e)
@@ -571,6 +614,57 @@ compile' env = \case
_ -> error "Compile: not implemented"
+-- | Decrement reference counts in the components of the given variable.
+releaseVar :: STy a -> Maybe (String -> CompM ())
+releaseVar ty =
+ let tree = makeReleaseTree ty
+ in case tree of RTNoop -> Nothing
+ _ -> Just $ \var -> releaseVar' var tree
+
+releaseVarAlways :: STy a -> String -> CompM ()
+releaseVarAlways ty var = maybe (pure ()) ($ var) (releaseVar ty)
+
+data ReleaseTree = RTArray -- ^ we've arrived at an array we need to decrement the refcount of
+ | RTNoop -- ^ don't do anything here
+ | RTProj String ReleaseTree -- ^ descend one field deeper
+ | RTCondTag ReleaseTree ReleaseTree -- ^ if tag is 0, first; if 1, second
+ | RTBoth ReleaseTree ReleaseTree -- ^ do both these paths
+
+smartRTProj :: String -> ReleaseTree -> ReleaseTree
+smartRTProj _ RTNoop = RTNoop
+smartRTProj field t = RTProj field t
+
+smartRTCondTag :: ReleaseTree -> ReleaseTree -> ReleaseTree
+smartRTCondTag RTNoop RTNoop = RTNoop
+smartRTCondTag t t' = RTCondTag t t'
+
+smartRTBoth :: ReleaseTree -> ReleaseTree -> ReleaseTree
+smartRTBoth RTNoop t = t
+smartRTBoth t RTNoop = t
+smartRTBoth t t' = RTBoth t t'
+
+makeReleaseTree :: STy a -> ReleaseTree
+makeReleaseTree STNil = RTNoop
+makeReleaseTree (STPair a b) = smartRTBoth (smartRTProj "a" (makeReleaseTree a))
+ (smartRTProj "b" (makeReleaseTree b))
+makeReleaseTree (STEither a b) = smartRTCondTag (smartRTProj "a" (makeReleaseTree a))
+ (smartRTProj "b" (makeReleaseTree b))
+makeReleaseTree (STMaybe t) = smartRTCondTag RTNoop (makeReleaseTree t)
+makeReleaseTree (STArr _ _) = RTArray
+makeReleaseTree (STScal _) = RTNoop
+makeReleaseTree (STAccum _) = RTNoop
+
+releaseVar' :: String -> ReleaseTree -> CompM ()
+releaseVar' path RTArray = emit $ SVerbatim (path ++ "--;")
+releaseVar' _ RTNoop = pure ()
+releaseVar' path (RTProj field t) = releaseVar' (path ++ "." ++ field) t
+releaseVar' path (RTCondTag t1 t2) = do
+ ((), stmts1) <- scope $ releaseVar' path t1
+ ((), stmts2) <- scope $ releaseVar' path t2
+ emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "0")) (BList stmts1) (BList stmts2)
+releaseVar' path (RTBoth t1 t2) = releaseVar' path t1 >> releaseVar' path t2
+
+
compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr
compileOpGeneral op e1 = do
let unary cop = return @(State CompState) $ CECall cop [e1]
@@ -616,5 +710,14 @@ compileOpPair op e1 e2 = do
OIDiv _ -> binary "/"
_ -> error "compileOpPair: got unary operator"
+-- | Bool: whether to ensure that the literal itself already has the appropriate type
+compileScal :: Bool -> SScalTy t -> ScalRep t -> String
+compileScal pedantic typ x = case typ of
+ STI32 -> (if pedantic then "(int32_t)" else "") ++ show x
+ STI64 -> (if pedantic then "(int64_t)" else "") ++ show x
+ STF32 -> show x ++ "f"
+ STF64 -> show x
+ STBool -> if x then "1" else "0"
+
compose :: Foldable t => t (a -> a) -> a -> a
compose = foldr (.) id
diff --git a/src/Compile/Exec.hs b/src/Compile/Exec.hs
index 163be2b..83fcdad 100644
--- a/src/Compile/Exec.hs
+++ b/src/Compile/Exec.hs
@@ -28,7 +28,7 @@ buildKernel csource funnames = do
path <- mkdtemp template
let outso = path ++ "/out.so"
- let args = ["-O3", "-march=native", "-shared", "-fPIC", "-x", "c", "-o", outso, "-"]
+ let args = ["-O3", "-march=native", "-shared", "-fPIC", "-std=c99", "-x", "c", "-o", outso, "-"]
_ <- readProcess "gcc" args csource
hPutStrLn stderr $ "[chad] loading kernel " ++ path
diff --git a/src/Data.hs b/src/Data.hs
index 8005737..00790fe 100644
--- a/src/Data.hs
+++ b/src/Data.hs
@@ -129,7 +129,7 @@ vecGenerate = \n f -> go n f SZ
unsafeCoerceRefl :: a :~: b
unsafeCoerceRefl = unsafeCoerce Refl
-data Bag t = BNone | BOne t | BTwo (Bag t) (Bag t) | BMany [Bag t]
+data Bag t = BNone | BOne t | BTwo !(Bag t) !(Bag t) | BMany [Bag t] | BList [t]
deriving (Show, Functor, Foldable, Traversable)
-- | This instance is mostly there just for 'pure'
@@ -139,6 +139,7 @@ instance Applicative Bag where
BOne f <*> b = f <$> b
BTwo b1 b2 <*> b = BTwo (b1 <*> b) (b2 <*> b)
BMany bs <*> b = BMany (map (<*> b) bs)
+ BList bs <*> b = BMany (map (<$> b) bs)
instance Semigroup (Bag t) where (<>) = BTwo
instance Monoid (Bag t) where mempty = BNone