summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-02 11:39:37 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-02 11:39:37 +0100
commit54762c98901b21468fa9ff4205107360c3096cd3 (patch)
treea765a9af9a19eaae5c4857f9ff4a9d3188708e52
parent0fffb5731271a551afcf08878cb021ead8dd1dae (diff)
compile: Compile constant array literals
-rw-r--r--src/Compile.hs163
-rw-r--r--src/Compile/Exec.hs2
-rw-r--r--src/Language.hs2
3 files changed, 94 insertions, 73 deletions
diff --git a/src/Compile.hs b/src/Compile.hs
index 564f697..2a23561 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -4,6 +4,7 @@
{-# LANGUAGE TypeApplications #-}
module Compile (compile) where
+import Control.Monad (when)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict
import Control.Monad.Trans.Writer.CPS
@@ -19,6 +20,10 @@ import Data.Set (Set)
import Data.Some
import qualified Data.Vector as V
import Foreign
+import System.IO (hPutStrLn, stderr)
+
+import Prelude hiding ((^))
+import qualified Prelude
import Array
import AST
@@ -28,6 +33,11 @@ import Data
import Interpreter.Rep
+-- :m *Example Compile AST.UnMonoid
+-- :seti -XOverloadedLabels -XGADTs
+-- (($ SCons (Value 2) SNil) =<<) $ compile knownEnv $ fromNamed $ lambda @(TScal TF64) #x $ body $ constArr_ @TF64 (arrayGenerate (ShNil `ShCons` 10) (\(IxNil `IxCons` i) -> fromIntegral i))
+
+
-- In shape and index arrays, the innermost dimension is on the right (last index).
-- TODO: array lifetimes in C?
@@ -36,7 +46,9 @@ import Interpreter.Rep
compile :: SList STy env -> Ex env t
-> IO (SList Value env -> IO (Rep t))
compile = \env expr -> do
- lib <- buildKernel (compileToString env expr) ["kernel"]
+ let source = compileToString env expr
+ hPutStrLn stderr $ "Generated C source: <<<\n" ++ source ++ ">>>"
+ lib <- buildKernel source ["kernel"]
let arg_metrics = reverse (unSList metricsSTy env)
(arg_offsets, result_offset) = computeStructOffsets arg_metrics
@@ -85,7 +97,7 @@ data CExpr
printStructDecl :: StructDecl -> ShowS
printStructDecl (StructDecl name contents comment) =
showString "typedef struct { " . showString contents . showString " } " . showString name
- . showString ("; // " ++ comment)
+ . showString ";" . (if null comment then id else showString (" // " ++ comment))
printStmt :: Int -> Stmt -> ShowS
printStmt indent = \case
@@ -187,7 +199,8 @@ genStruct name topty = case topty of
TMaybe t -> -- 0 -> nothing, 1 -> just
[StructDecl name ("uint8_t tag; " ++ repTy t ++ " a;") com]
TArr n t ->
- [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromNat n) ++ "]; size_t refc; " ++ repTy t ++ " *a;") com
+ -- The buffer is trailed by a VLA for the actual array data.
+ [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromNat n) ++ "]; size_t refc; " ++ repTy t ++ " a[];") ""
,StructDecl name (name ++ "_buf *buf;") com]
TScal _ ->
[]
@@ -289,7 +302,7 @@ compileToString env expr =
,showString (" return ") . printCExpr 0 res . showString ";\n}\n\n"
,showString "void kernel(void *data) {\n"
-- Some code here assumes that we're on a 64-bit system, so let's check that
- ,showString " if (sizeof(void*) != 8) { abort(); }\n"
+ ,showString " if (sizeof(void*) != 8 || sizeof(size_t) != 8) { abort(); }\n"
,showString $ " *(" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ") = typed_kernel(" ++
concat (map (\((arg, typ), off, idx) ->
"\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")"
@@ -372,11 +385,14 @@ deserialise topty ptr off =
then return Nothing
else Just <$> deserialise t ptr (off + alignmentSTy t)
STArr n t -> do
- _ <- error "TODO deserialisation of arrays is wrong after refcount introduction"
- sh <- peekShape ptr off n
- let off1 = off + 8 * fromSNat n
+ bufptr <- peekByteOff @(Ptr ()) ptr off
+ sh <- peekShape bufptr 0 n
+ refc <- peekByteOff @Word64 bufptr (8 * fromSNat n)
+ let off1 = 8 * fromSNat n + 8
eltsz = sizeofSTy t
- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t ptr (off1 + i * eltsz))
+ arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (off1 + i * eltsz))
+ when (refc < 2 ^ 62) $ free bufptr
+ return arr
STScal sty -> case sty of
STI32 -> peekByteOff @Int32 ptr off
STI64 -> peekByteOff @Int64 ptr off
@@ -438,7 +454,7 @@ compile' env = \case
EVar _ t i -> do
let Const var = slistIdx env i
case t of
- STArr{} -> return $ CELit ("(++" ++ var ++ ".refc, " ++ var ++ ")")
+ STArr{} -> return $ CELit ("(++" ++ var ++ "->buf.refc, " ++ var ++ ")")
_ -> return $ CELit var
ELet _ rhs body -> do
@@ -446,7 +462,7 @@ compile' env = \case
var <- genName
emit $ SVarDecl True (repSTy (typeOf rhs)) var e
rete <- compile' (Const var `SCons` env) body
- releaseVarAlways (typeOf rhs) var
+ incrementVarAlways Decrement (typeOf rhs) var
return rete
EPair _ a b -> do
@@ -458,7 +474,7 @@ compile' env = \case
EFst _ e -> do
let STPair _ t2 = typeOf e
e' <- compile' env e
- case releaseVar t2 of
+ case incrementVar Decrement t2 of
Nothing -> return $ CEProj e' "a"
Just f -> do var <- genName
emit $ SVarDecl True (repSTy (typeOf e)) var e'
@@ -468,7 +484,7 @@ compile' env = \case
ESnd _ e -> do
let STPair t1 _ = typeOf e
e' <- compile' env e
- case releaseVar t1 of
+ case incrementVar Decrement t1 of
Nothing -> return $ CEProj e' "b"
Just f -> do var <- genName
emit $ SVarDecl True (repSTy (typeOf e)) var e'
@@ -507,8 +523,8 @@ compile' env = \case
-- I know those are not variable names, but it's fine, probably
(e2, stmts2) <- scope $ compile' (Const (var ++ ".a") `SCons` env) a
(e3, stmts3) <- scope $ compile' (Const (var ++ ".b") `SCons` env) b
- ((), stmtsRel1) <- scope $ releaseVarAlways t1 (var ++ ".a")
- ((), stmtsRel2) <- scope $ releaseVarAlways t2 (var ++ ".b")
+ ((), stmtsRel1) <- scope $ incrementVarAlways Decrement t1 (var ++ ".a")
+ ((), stmtsRel2) <- scope $ incrementVarAlways Decrement t2 (var ++ ".b")
retvar <- genName
emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1)
@@ -536,7 +552,7 @@ compile' env = \case
var <- genName
(e2, stmts2) <- scope $ compile' env a
(e3, stmts3) <- scope $ compile' (Const (var ++ ".a") `SCons` env) b
- ((), stmtsRel) <- scope $ releaseVarAlways t (var ++ ".a")
+ ((), stmtsRel) <- scope $ incrementVarAlways Decrement t (var ++ ".a")
retvar <- genName
emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1)
@@ -550,18 +566,14 @@ compile' env = \case
EConstArr _ n t (Array sh vec) -> do
strname <- emitStruct (STArr n (STScal t))
- tldname <- genName' "carray"
- emitTLD $ "static const " ++ repSTy (STScal t) ++ " " ++
- tldname ++ "[" ++ show (shapeSize sh) ++ "] = {" ++
- intercalate "," (map (compileScal False t) (toList vec)) ++
- "};"
+ tldname <- genName' "carraybuf"
-- Give it a refcount of _half_ the size_t max, so that it can be
-- incremented and decremented at will and will "never" reach anything
-- where something happens
- emitTLD $ "static " ++ strname ++ "_buf " ++ tldname ++ "_buf = " ++
+ emitTLD $ "static " ++ strname ++ "_buf " ++ tldname ++ " = " ++
"(" ++ strname ++ "_buf){.sh = {" ++ intercalate "," (map show (shapeToList sh)) ++ "}, " ++
- ".refc = SIZE_MAX/2, .a = " ++ tldname ++ "};"
- return (CEStruct strname [("buf", CEAddrOf (CELit (tldname ++ "_buf")))])
+ ".refc = (size_t)1<<63, .a = {" ++ intercalate "," (map (compileScal False t) (toList vec)) ++ "}};"
+ return (CEStruct strname [("buf", CEAddrOf (CELit tldname))])
-- EBuild _ n a b -> error "TODO" -- genStruct (STArr n (typeOf b)) <> EBuild ext n (compile' a) (compile' b)
@@ -614,55 +626,61 @@ compile' env = \case
_ -> error "Compile: not implemented"
--- | Decrement reference counts in the components of the given variable.
-releaseVar :: STy a -> Maybe (String -> CompM ())
-releaseVar ty =
- let tree = makeReleaseTree ty
- in case tree of RTNoop -> Nothing
- _ -> Just $ \var -> releaseVar' var tree
-
-releaseVarAlways :: STy a -> String -> CompM ()
-releaseVarAlways ty var = maybe (pure ()) ($ var) (releaseVar ty)
-
-data ReleaseTree = RTArray -- ^ we've arrived at an array we need to decrement the refcount of
- | RTNoop -- ^ don't do anything here
- | RTProj String ReleaseTree -- ^ descend one field deeper
- | RTCondTag ReleaseTree ReleaseTree -- ^ if tag is 0, first; if 1, second
- | RTBoth ReleaseTree ReleaseTree -- ^ do both these paths
-
-smartRTProj :: String -> ReleaseTree -> ReleaseTree
-smartRTProj _ RTNoop = RTNoop
-smartRTProj field t = RTProj field t
-
-smartRTCondTag :: ReleaseTree -> ReleaseTree -> ReleaseTree
-smartRTCondTag RTNoop RTNoop = RTNoop
-smartRTCondTag t t' = RTCondTag t t'
-
-smartRTBoth :: ReleaseTree -> ReleaseTree -> ReleaseTree
-smartRTBoth RTNoop t = t
-smartRTBoth t RTNoop = t
-smartRTBoth t t' = RTBoth t t'
-
-makeReleaseTree :: STy a -> ReleaseTree
-makeReleaseTree STNil = RTNoop
-makeReleaseTree (STPair a b) = smartRTBoth (smartRTProj "a" (makeReleaseTree a))
- (smartRTProj "b" (makeReleaseTree b))
-makeReleaseTree (STEither a b) = smartRTCondTag (smartRTProj "a" (makeReleaseTree a))
- (smartRTProj "b" (makeReleaseTree b))
-makeReleaseTree (STMaybe t) = smartRTCondTag RTNoop (makeReleaseTree t)
-makeReleaseTree (STArr _ _) = RTArray
-makeReleaseTree (STScal _) = RTNoop
-makeReleaseTree (STAccum _) = RTNoop
-
-releaseVar' :: String -> ReleaseTree -> CompM ()
-releaseVar' path RTArray = emit $ SVerbatim (path ++ "--;")
-releaseVar' _ RTNoop = pure ()
-releaseVar' path (RTProj field t) = releaseVar' (path ++ "." ++ field) t
-releaseVar' path (RTCondTag t1 t2) = do
- ((), stmts1) <- scope $ releaseVar' path t1
- ((), stmts2) <- scope $ releaseVar' path t2
+data Increment = Increment | Decrement
+ deriving (Show)
+
+-- | Increment reference counts in the components of the given variable.
+incrementVar :: Increment -> STy a -> Maybe (String -> CompM ())
+incrementVar inc ty =
+ let tree = makeArrayTree ty
+ in case tree of ATNoop -> Nothing
+ _ -> Just $ \var -> incrementVar' inc var tree
+
+incrementVarAlways :: Increment -> STy a -> String -> CompM ()
+incrementVarAlways inc ty var = maybe (pure ()) ($ var) (incrementVar inc ty)
+
+data ArrayTree = ATArray -- ^ we've arrived at an array we need to decrement the refcount of
+ | ATNoop -- ^ don't do anything here
+ | ATProj String ArrayTree -- ^ descend one field deeper
+ | ATCondTag ArrayTree ArrayTree -- ^ if tag is 0, first; if 1, second
+ | ATBoth ArrayTree ArrayTree -- ^ do both these paths
+
+smartATProj :: String -> ArrayTree -> ArrayTree
+smartATProj _ ATNoop = ATNoop
+smartATProj field t = ATProj field t
+
+smartATCondTag :: ArrayTree -> ArrayTree -> ArrayTree
+smartATCondTag ATNoop ATNoop = ATNoop
+smartATCondTag t t' = ATCondTag t t'
+
+smartATBoth :: ArrayTree -> ArrayTree -> ArrayTree
+smartATBoth ATNoop t = t
+smartATBoth t ATNoop = t
+smartATBoth t t' = ATBoth t t'
+
+makeArrayTree :: STy a -> ArrayTree
+makeArrayTree STNil = ATNoop
+makeArrayTree (STPair a b) = smartATBoth (smartATProj "a" (makeArrayTree a))
+ (smartATProj "b" (makeArrayTree b))
+makeArrayTree (STEither a b) = smartATCondTag (smartATProj "a" (makeArrayTree a))
+ (smartATProj "b" (makeArrayTree b))
+makeArrayTree (STMaybe t) = smartATCondTag ATNoop (makeArrayTree t)
+makeArrayTree (STArr _ _) = ATArray
+makeArrayTree (STScal _) = ATNoop
+makeArrayTree (STAccum _) = ATNoop
+
+incrementVar' :: Increment -> String -> ArrayTree -> CompM ()
+incrementVar' inc path ATArray =
+ let op = case inc of Increment -> "++"
+ Decrement -> "--"
+ in emit $ SVerbatim (path ++ "->buf.refc" ++ op ++ ";")
+incrementVar' _ _ ATNoop = pure ()
+incrementVar' inc path (ATProj field t) = incrementVar' inc (path ++ "." ++ field) t
+incrementVar' inc path (ATCondTag t1 t2) = do
+ ((), stmts1) <- scope $ incrementVar' inc path t1
+ ((), stmts2) <- scope $ incrementVar' inc path t2
emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "0")) (BList stmts1) (BList stmts2)
-releaseVar' path (RTBoth t1 t2) = releaseVar' path t1 >> releaseVar' path t2
+incrementVar' inc path (ATBoth t1 t2) = incrementVar' inc path t1 >> incrementVar' inc path t2
compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr
@@ -721,3 +739,6 @@ compileScal pedantic typ x = case typ of
compose :: Foldable t => t (a -> a) -> a -> a
compose = foldr (.) id
+
+(^) :: Num a => a -> Int -> a
+(^) = (Prelude.^)
diff --git a/src/Compile/Exec.hs b/src/Compile/Exec.hs
index 83fcdad..7c5cb79 100644
--- a/src/Compile/Exec.hs
+++ b/src/Compile/Exec.hs
@@ -28,7 +28,7 @@ buildKernel csource funnames = do
path <- mkdtemp template
let outso = path ++ "/out.so"
- let args = ["-O3", "-march=native", "-shared", "-fPIC", "-std=c99", "-x", "c", "-o", outso, "-"]
+ let args = ["-O3", "-march=native", "-shared", "-fPIC", "-std=c99", "-x", "c", "-o", outso, "-", "-Wall", "-Wextra", "-Wno-unused-parameter"]
_ <- readProcess "gcc" args csource
hPutStrLn stderr $ "[chad] loading kernel " ++ path
diff --git a/src/Language.hs b/src/Language.hs
index a7737e0..70cc4f9 100644
--- a/src/Language.hs
+++ b/src/Language.hs
@@ -69,7 +69,7 @@ inr = NEInr knownTy
case_ :: NExpr env (TEither a b) -> (Var name1 a :-> NExpr ('(name1, a) : env) c) -> (Var name2 b :-> NExpr ('(name2, b) : env) c) -> NExpr env c
case_ e (v1 :-> e1) (v2 :-> e2) = NECase e v1 e1 v2 e2
-constArr_ :: (KnownNat n, KnownScalTy t) => Array n (ScalRep t) -> NExpr env (TArr n (TScal t))
+constArr_ :: forall t n env. (KnownNat n, KnownScalTy t) => Array n (ScalRep t) -> NExpr env (TArr n (TScal t))
constArr_ x =
let ty = knownScalTy
in case scalRepIsShow ty of