aboutsummaryrefslogtreecommitdiff
path: root/src/Compile.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Compile.hs')
-rw-r--r--src/Compile.hs78
1 files changed, 37 insertions, 41 deletions
diff --git a/src/Compile.hs b/src/Compile.hs
index 2b7cd9e..064e0b6 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -253,10 +253,9 @@ genStruct name topty = case topty of
[StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com]
STArr n t ->
-- The buffer is trailed by a VLA for the actual array data.
- -- TODO: put shape in the main struct, not the buffer; it's constant, after all
-- TODO: no buffer if n = 0
- [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromSNat n) ++ "]; size_t refc; " ++ repSTy t ++ " xs[];") ""
- ,StructDecl name (name ++ "_buf *buf;") com]
+ [StructDecl (name ++ "_buf") ("size_t refc; " ++ repSTy t ++ " xs[];") ""
+ ,StructDecl name (name ++ "_buf *buf; size_t sh[" ++ show (fromSNat n) ++ "];") com]
STScal _ ->
[]
STAccum t ->
@@ -481,19 +480,18 @@ serialise topty topval ptr off k =
serialise t x ptr (off + alignmentSTy t) k
(STArr n t, Array sh vec) -> do
let eltsz = sizeofSTy t
- allocaBytes (fromSNat n * 8 + 8 + shapeSize sh * eltsz) $ \bufptr -> do
+ allocaBytes (8 + shapeSize sh * eltsz) $ \bufptr -> do
when debugRefc $
hPutStrLn stderr $ "[chad-serialise] Allocating input buffer " ++ showPtr bufptr
pokeByteOff ptr off bufptr
+ pokeShape ptr (off + 8) n sh
- pokeShape bufptr 0 n sh
- pokeByteOff @Word64 bufptr (8 * fromSNat n) (2 ^ 63)
+ pokeByteOff @Word64 bufptr 0 (2 ^ 63)
- let off1 = fromSNat n * 8 + 8
- loop i
+ let loop i
| i == shapeSize sh = k
| otherwise =
- serialise t (vec V.! i) bufptr (off1 + i * eltsz) $
+ serialise t (vec V.! i) bufptr (8 + i * eltsz) $
loop (i+1)
loop 0
(STScal sty, x) -> case sty of
@@ -533,13 +531,12 @@ deserialise topty ptr off =
else Just <$> deserialise t ptr (off + alignmentSTy t)
STArr n t -> do
bufptr <- peekByteOff @(Ptr ()) ptr off
- sh <- peekShape bufptr 0 n
- refc <- peekByteOff @Word64 bufptr (8 * fromSNat n)
+ sh <- peekShape ptr (off + 8) n
+ refc <- peekByteOff @Word64 bufptr 0
when debugRefc $
hPutStrLn stderr $ "[chad-deserialise] Got buffer " ++ showPtr bufptr ++ " at refc=" ++ show refc
- let off1 = 8 * fromSNat n + 8
- eltsz = sizeofSTy t
- arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (off1 + i * eltsz))
+ let eltsz = sizeofSTy t
+ arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (8 + i * eltsz))
when (refc < 2 ^ 62) $ free bufptr
return arr
STScal sty -> case sty of
@@ -577,7 +574,7 @@ metricsSTy (STLEither a b) =
metricsSTy (STMaybe t) =
let (a, s) = metricsSTy t
in (a, a + s) -- the union after the tag byte is aligned
-metricsSTy (STArr _ _) = (8, 8)
+metricsSTy (STArr n _) = (8, 8 + 8 * fromSNat n)
metricsSTy (STScal sty) = case sty of
STI32 -> (4, 4)
STI64 -> (8, 8)
@@ -754,9 +751,11 @@ compile' env = \case
-- incremented and decremented at will and will "never" reach anything
-- where something happens
emitTLD $ "static " ++ strname ++ "_buf " ++ tldname ++ " = " ++
- "(" ++ strname ++ "_buf){.sh = {" ++ intercalate "," (map show (shapeToList sh)) ++ "}, " ++
- ".refc = (size_t)1<<63, .xs = {" ++ intercalate "," (map (compileScal False t) (toList vec)) ++ "}};"
- return (CEStruct strname [("buf", CEAddrOf (CELit tldname))])
+ "(" ++ strname ++ "_buf){.refc = (size_t)1<<63, " ++
+ ".xs = {" ++ intercalate "," (map (compileScal False t) (toList vec)) ++ "}};"
+ return (CEStruct strname
+ [("buf", CEAddrOf (CELit tldname))
+ ,("sh", CELit ("{" ++ intercalate "," (map show (shapeToList sh)) ++ "}"))])
EBuild _ n esh efun -> do
shname <- compileAssign "sh" env esh
@@ -771,7 +770,7 @@ compile' env = \case
emit $ SBlock $
pure (SVarDecl False "size_t" linivar (CELit "0"))
<> compose [pure . SLoop (repSTy tIx) ivar (CELit "0")
- (CECast (repSTy tIx) (CEIndex (CELit (arrname ++ ".buf->sh")) (CELit (show dimidx))))
+ (CECast (repSTy tIx) (CEIndex (CELit (arrname ++ ".sh")) (CELit (show dimidx))))
| (ivar, dimidx) <- zip ivars [0::Int ..]]
(pure (SVarDecl True (repSTy (typeOf esh)) idxargname
(shapeTupFromLitVars n ivars))
@@ -800,7 +799,7 @@ compile' env = \case
lenname <- genName' "n"
emit $ SVarDecl True (repSTy tIx) lenname
- (CELit (arrname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))
+ (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
((), x0incrStmts) <- scope $ incrementVarAlways "foldx0" Increment t x0name
@@ -846,7 +845,7 @@ compile' env = \case
lenname <- genName' "n"
emit $ SVarDecl True (repSTy tIx) lenname
- (CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))
+ (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
let vecwid = 8 :: Int
ivar <- genName' "i"
@@ -935,7 +934,7 @@ compile' env = \case
when emitChecks $
forM_ (zip [0::Int ..] (indexTupleComponents n idxname)) $ \(i, ixcomp) ->
emit $ SIf (CEBinop (CEBinop ixcomp "<" (CELit "0")) "||"
- (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".buf->sh[" ++ show i ++ "]")))))
+ (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".sh[" ++ show i ++ "]")))))
(pure $ SVerbatim $
"fprintf(stderr, PRTAG \"CHECK: index out of range (arr=%p)\\n\", " ++
arrname ++ ".buf); return false;")
@@ -1044,16 +1043,16 @@ compile' env = \case
when emitChecks $ do
let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
forM_ [0 .. fromSNat n - 1] $ \j -> do
- emit $ SIf (CEBinop (CELit (s ++ ".buf->sh[" ++ show j ++ "]"))
+ emit $ SIf (CEBinop (CELit (s ++ ".sh[" ++ show j ++ "]"))
"!="
- (CELit (d ++ ".buf->sh[" ++ show j ++ "]")))
+ (CELit (d ++ ".sh[" ++ show j ++ "]")))
(pure $ SVerbatim $
"fprintf(stderr, PRTAG \"CHECK: accum add incorrect (d=%p, " ++
"dsh=" ++ shfmt ++ ", s=%p, ssh=" ++ shfmt ++ ")\\n\", " ++
d ++ ".buf" ++
- concat [", " ++ d ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
+ concat [", " ++ d ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
", " ++ s ++ ".buf" ++
- concat [", " ++ s ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
+ concat [", " ++ s ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
"); " ++
"return false;")
mempty
@@ -1110,12 +1109,12 @@ compile' env = \case
let a .||. b = CEBinop a "||" b
emit $ SIf (CEBinop ixcomp "<" (CELit "0")
.||.
- CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))))
+ CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".sh[" ++ show j ++ "]"))))
(pure $ SVerbatim $
"fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++
"arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=(D))\\n\", " ++
v ++ ".buf" ++
- concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
+ concat [", " ++ v ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a")] ++
"); " ++
"return false;")
@@ -1268,14 +1267,14 @@ toLinearIdx SZ _ _ = CELit "0"
toLinearIdx (SS SZ) _ idxvar = CELit (idxvar ++ ".b")
toLinearIdx (SS n) arrvar idxvar =
CEBinop (CEBinop (toLinearIdx n arrvar (idxvar ++ ".a"))
- "*" (CEIndex (CELit (arrvar ++ ".buf->sh")) (CELit (show (fromSNat n)))))
+ "*" (CEIndex (CELit (arrvar ++ ".sh")) (CELit (show (fromSNat n)))))
"+" (CELit (idxvar ++ ".b"))
-- fromLinearIdx :: SNat n -> String -> String -> CompM CExpr
-- fromLinearIdx SZ _ _ = return $ CEStruct (repSTy STNil) []
-- fromLinearIdx (SS n) arrvar idxvar = do
-- name <- genName
--- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")))
+-- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".sh[" ++ show (fromSNat n) ++ "]")))
-- _
data AllocMethod = Malloc | Calloc
@@ -1297,9 +1296,8 @@ allocArray marker method nameBase rank eltty mshsz shape = do
(CEBinop shsz "*" (CELit (show (sizeofSTy eltty))))
emit $ SVarDecl True strname arrname $ CEStruct strname
[("buf", case method of Malloc -> CECall "malloc_instr" [nbytesExpr]
- Calloc -> CECall "calloc_instr" [nbytesExpr])]
- forM_ (zip shape [0::Int ..]) $ \(dim, i) ->
- emit $ SAsg (arrname ++ ".buf->sh[" ++ show i ++ "]") dim
+ Calloc -> CECall "calloc_instr" [nbytesExpr])
+ ,("sh", CELit ("{" ++ intercalate "," [printCExpr 0 dim "" | dim <- shape] ++ "}"))]
emit $ SAsg (arrname ++ ".buf->refc") (CELit "1")
when debugRefc $
emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p allocated <" ++ marker ++ ">\\n\", " ++ arrname ++ ".buf);"
@@ -1310,7 +1308,7 @@ compileShapeQuery SZ _ = CEStruct (repSTy STNil) []
compileShapeQuery (SS n) var =
CEStruct (repSTy (tTup (sreplicate (SS n) tIx)))
[("a", compileShapeQuery n var)
- ,("b", CEIndex (CELit (var ++ ".buf->sh")) (CELit (show (fromSNat n))))]
+ ,("b", CEIndex (CELit (var ++ ".sh")) (CELit (show (fromSNat n))))]
-- | Takes a variable name for the array, not the buffer.
compileArrShapeSize :: SNat n -> String -> CExpr
@@ -1319,7 +1317,7 @@ compileArrShapeSize n var = foldl0' (\a b -> CEBinop a "*" b) (CELit "1") (compi
-- | Takes a variable name for the array, not the buffer.
compileArrShapeComponents :: SNat n -> String -> [CExpr]
compileArrShapeComponents n var =
- [CELit (var ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
+ [CELit (var ++ ".sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
indexTupleComponents :: SNat n -> String -> [CExpr]
indexTupleComponents = \n var -> map CELit (toList (go n var))
@@ -1410,7 +1408,7 @@ compileExtremum nameBase opName operator env e = do
lenname <- genName' "n"
emit $ SVarDecl True (repSTy tIx) lenname
- (CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))
+ (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); return false; }"
@@ -1490,7 +1488,7 @@ copyForWriting topty var = case topty of
let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
emit $ SVerbatim $
"fprintf(stderr, PRTAG \"with array " ++ shfmt ++ "\\n\"" ++
- concat [", " ++ var ++ ".buf->sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++
+ concat [", " ++ var ++ ".sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++
");"
emit $ SIf (CEBinop (CELit (var ++ ".buf->refc")) "==" (CELit "1"))
@@ -1501,8 +1499,7 @@ copyForWriting topty var = case topty of
in BList
[SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var)
,SAsg name (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])])
- ,SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++
- show shbytes ++ ");"
+ ,SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");"
,SAsg (name ++ ".buf->refc") (CELit "1")
,SVerbatim $ "memcpy(" ++ name ++ ".buf->xs, " ++ var ++ ".buf->xs, " ++
printCExpr 0 databytes ");"])
@@ -1519,8 +1516,7 @@ copyForWriting topty var = case topty of
name <- genName
emit $ SVarDecl False toptyname name
(CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])])
- emit $ SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++
- show shbytes ++ ");"
+ emit $ SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");"
emit $ SAsg (name ++ ".buf->refc") (CELit "1")
-- put the arrays in variables to cut short the not-quite-var chain