aboutsummaryrefslogtreecommitdiff
path: root/src/Compile.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Compile.hs')
-rw-r--r--src/Compile.hs574
1 files changed, 311 insertions, 263 deletions
diff --git a/src/Compile.hs b/src/Compile.hs
index 6ba3a39..f2063ee 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -34,6 +34,7 @@ import Foreign
import GHC.Exts (int2Word#, addr2Int#)
import GHC.Num (integerFromWord#)
import GHC.Ptr (Ptr(..))
+import GHC.Stack (HasCallStack)
import Numeric (showHex)
import System.IO (hPutStrLn, stderr)
import System.IO.Error (mkIOError, userErrorType)
@@ -45,6 +46,7 @@ import qualified Prelude
import Array
import AST
import AST.Pretty (ppSTy, ppExpr)
+import AST.Sparse.Types (isDense)
import Compile.Exec
import Data
import Interpreter.Rep
@@ -77,7 +79,7 @@ compile = \env expr -> do
let (source, offsets) = compileToString codeID env expr
when debugPrintAST $ hPutStrLn stderr $ "Compiled AST: <<<\n" ++ ppExpr env expr ++ "\n>>>"
when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ lineNumbers source ++ "\x1B[0m>>>"
- lib <- buildKernel source ["kernel"]
+ lib <- buildKernel source "kernel"
let result_type = typeOf expr
result_size = sizeofSTy result_type
@@ -86,7 +88,7 @@ compile = \env expr -> do
allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do
let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets)
serialiseArguments args ptr $ do
- callKernelFun "kernel" lib ptr
+ callKernelFun lib ptr
ok <- peekByteOff @Word8 ptr (koOkResOffset offsets)
when (ok /= 1) $
ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing)
@@ -125,7 +127,7 @@ data CExpr
| CECall String [CExpr] -- ^ function(arg1, ..., argn)
| CEBinop CExpr String CExpr -- ^ expr + expr
| CEIf CExpr CExpr CExpr -- ^ expr ? expr : expr
- | CECast String CExpr -- ^ (<type)<expr>
+ | CECast String CExpr -- ^ (<type>)<expr>
deriving (Show)
printStructDecl :: StructDecl -> ShowS
@@ -214,23 +216,31 @@ repSTy (STScal st) = case st of
STBool -> "uint8_t"
repSTy t = genStructName t
-genStructName :: STy t -> String
-genStructName = \t -> "ty_" ++ gen t where
- -- all tags start with a letter, so the array mangling is unambiguous.
- gen :: STy t -> String
- gen STNil = "n"
- gen (STPair a b) = 'P' : gen a ++ gen b
- gen (STEither a b) = 'E' : gen a ++ gen b
- gen (STMaybe t) = 'M' : gen t
- gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t
- gen (STScal st) = case st of
- STI32 -> "i"
- STI64 -> "j"
- STF32 -> "f"
- STF64 -> "d"
- STBool -> "b"
- gen (STAccum t) = 'C' : gen (fromSMTy t)
- gen (STLEither a b) = 'L' : gen a ++ gen b
+genStructName, genArrBufStructName :: STy t -> String
+(genStructName, genArrBufStructName) =
+ (\t -> "ty_" ++ gen t
+ ,\case STArr _ t -> "ty_A_" ++ gen t ++ "_buf" -- just like the normal type, but with _ for the dimension
+ t -> error $ "genArrBufStructName: not an array type: " ++ show t)
+ where
+ -- all tags start with a letter, so the array mangling is unambiguous.
+ gen :: STy t -> String
+ gen STNil = "n"
+ gen (STPair a b) = 'P' : gen a ++ gen b
+ gen (STEither a b) = 'E' : gen a ++ gen b
+ gen (STLEither a b) = 'L' : gen a ++ gen b
+ gen (STMaybe t) = 'M' : gen t
+ gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t
+ gen (STScal st) = case st of
+ STI32 -> "i"
+ STI64 -> "j"
+ STF32 -> "f"
+ STF64 -> "d"
+ STBool -> "b"
+ gen (STAccum t) = 'C' : gen (fromSMTy t)
+
+-- The subtrees contain structs used in the bodies of the structs in this node.
+data StructTree = TreeNode [StructDecl] [StructTree]
+ deriving (Show)
-- | This function generates the actual struct declarations for each of the
-- types in our language. It thus implicitly "documents" the layout of the
@@ -238,60 +248,56 @@ genStructName = \t -> "ty_" ++ gen t where
--
-- For accumulation it is important that for struct representations of monoid
-- types, the all-zero-bytes value corresponds to the zero value of that type.
-genStruct :: String -> STy t -> [StructDecl]
-genStruct name topty = case topty of
+buildStructTree :: STy t -> StructTree
+buildStructTree topty = case topty of
STNil ->
- [StructDecl name "" com]
+ TreeNode [StructDecl name "" com] []
STPair a b ->
- [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com]
+ TreeNode [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com]
+ [buildStructTree a, buildStructTree b]
STEither a b -> -- 0 -> l, 1 -> r
- [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]
+ TreeNode [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]
+ [buildStructTree a, buildStructTree b]
+ STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r
+ TreeNode [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]
+ [buildStructTree a, buildStructTree b]
STMaybe t -> -- 0 -> nothing, 1 -> just
- [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com]
+ TreeNode [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com]
+ [buildStructTree t]
STArr n t ->
-- The buffer is trailed by a VLA for the actual array data.
- -- TODO: put shape in the main struct, not the buffer; it's constant, after all
-- TODO: no buffer if n = 0
- [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromSNat n) ++ "]; size_t refc; " ++ repSTy t ++ " xs[];") ""
- ,StructDecl name (name ++ "_buf *buf;") com]
+ TreeNode [StructDecl (genArrBufStructName topty) ("size_t refc; " ++ repSTy t ++ " xs[];") ""
+ ,StructDecl name (genArrBufStructName topty ++ " *buf; size_t sh[" ++ show (fromSNat n) ++ "];") com]
+ [buildStructTree t]
STScal _ ->
- []
+ TreeNode [] []
STAccum t ->
- [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") ""
- ,StructDecl name (name ++ "_buf *buf;") com]
- STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r
- [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]
+ TreeNode [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") ""
+ ,StructDecl name (name ++ "_buf *buf;") com]
+ [buildStructTree (fromSMTy t)]
where
+ name = genStructName topty
com = ppSTy 0 topty
-- State: already-generated (skippable) struct names
-- Writer: the structs in declaration order
-genStructs :: STy t -> WriterT (Bag StructDecl) (State (Set String)) ()
-genStructs ty = do
- let name = genStructName ty
- seen <- lift $ gets (name `Set.member`)
-
- if seen
- then pure ()
- else do
- -- already mark this struct as generated now, so we don't generate it
- -- twice (unnecessary because no recursive types, but y'know)
- lift $ modify (Set.insert name)
-
- () <- case ty of
- STNil -> pure ()
- STPair a b -> genStructs a >> genStructs b
- STEither a b -> genStructs a >> genStructs b
- STMaybe t -> genStructs t
- STArr _ t -> genStructs t
- STScal _ -> pure ()
- STAccum t -> genStructs (fromSMTy t)
- STLEither a b -> genStructs a >> genStructs b
-
- tell (BList (genStruct name ty))
+genStructTreeW :: StructTree -> WriterT (Bag StructDecl) (State (Set String)) ()
+genStructTreeW (TreeNode these deps) = do
+ seen <- lift get
+ case filter ((`Set.notMember` seen) . nameOf) these of
+ [] -> pure ()
+ structs -> do
+ lift $ modify (Set.fromList (map nameOf structs) <>)
+ mapM_ genStructTreeW deps
+ tell (BList structs)
+ where
+ nameOf (StructDecl name _ _) = name
genAllStructs :: Foldable t => t (Some STy) -> [StructDecl]
-genAllStructs tys = toList $ evalState (execWriterT (mapM_ (\(Some t) -> genStructs t) tys)) mempty
+genAllStructs tys =
+ let m = mapM_ (\(Some t) -> genStructTreeW (buildStructTree t)) tys
+ in toList (evalState (execWriterT m) mempty)
data CompState = CompState
{ csStructs :: Set (Some STy)
@@ -340,6 +346,12 @@ emitStruct ty = CompM $ do
modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) }
return (genStructName ty)
+-- | Also returns the name of the array buffer struct
+emitArrStruct :: STy t -> CompM (String, String)
+emitArrStruct ty = CompM $ do
+ modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) }
+ return (genStructName ty, genArrBufStructName ty)
+
emitTLD :: String -> CompM ()
emitTLD decl = CompM $ modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl }
@@ -463,6 +475,15 @@ serialise topty topval ptr off k =
(STEither _ b, Right y) -> do
pokeByteOff ptr off (1 :: Word8)
serialise b y ptr (off + alignmentSTy topty) k
+ (STLEither _ _, Nothing) -> do
+ pokeByteOff ptr off (0 :: Word8)
+ k
+ (STLEither a _, Just (Left x)) -> do
+ pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b)
+ serialise a x ptr (off + alignmentSTy topty) k
+ (STLEither _ b, Just (Right y)) -> do
+ pokeByteOff ptr off (2 :: Word8)
+ serialise b y ptr (off + alignmentSTy topty) k
(STMaybe _, Nothing) -> do
pokeByteOff ptr off (0 :: Word8)
k
@@ -471,19 +492,18 @@ serialise topty topval ptr off k =
serialise t x ptr (off + alignmentSTy t) k
(STArr n t, Array sh vec) -> do
let eltsz = sizeofSTy t
- allocaBytes (fromSNat n * 8 + 8 + shapeSize sh * eltsz) $ \bufptr -> do
+ allocaBytes (8 + shapeSize sh * eltsz) $ \bufptr -> do
when debugRefc $
hPutStrLn stderr $ "[chad-serialise] Allocating input buffer " ++ showPtr bufptr
pokeByteOff ptr off bufptr
+ pokeShape ptr (off + 8) n sh
- pokeShape bufptr 0 n sh
- pokeByteOff @Word64 bufptr (8 * fromSNat n) (2 ^ 63)
+ pokeByteOff @Word64 bufptr 0 (2 ^ 63)
- let off1 = fromSNat n * 8 + 8
- loop i
+ let loop i
| i == shapeSize sh = k
| otherwise =
- serialise t (vec V.! i) bufptr (off1 + i * eltsz) $
+ serialise t (vec V.! i) bufptr (8 + i * eltsz) $
loop (i+1)
loop 0
(STScal sty, x) -> case sty of
@@ -493,15 +513,6 @@ serialise topty topval ptr off k =
STF64 -> pokeByteOff ptr off (x :: Double) >> k
STBool -> pokeByteOff ptr off (fromIntegral (fromEnum x) :: Word8) >> k
(STAccum{}, _) -> error "Cannot serialise accumulators"
- (STLEither _ _, Nothing) -> do
- pokeByteOff ptr off (0 :: Word8)
- k
- (STLEither a _, Just (Left x)) -> do
- pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b)
- serialise a x ptr (off + alignmentSTy topty) k
- (STLEither _ b, Just (Right y)) -> do
- pokeByteOff ptr off (2 :: Word8)
- serialise b y ptr (off + alignmentSTy topty) k
-- | Assumes that this is called at the correct alignment.
deserialise :: STy t -> Ptr () -> Int -> IO (Rep t)
@@ -518,6 +529,13 @@ deserialise topty ptr off =
if tag == 0 -- alignment of (union {a b}) is the same as alignment of (a + b)
then Left <$> deserialise a ptr (off + alignmentSTy topty)
else Right <$> deserialise b ptr (off + alignmentSTy topty)
+ STLEither a b -> do
+ tag <- peekByteOff @Word8 ptr off
+ case tag of -- alignment of (union {a b}) is the same as alignment of (a + b)
+ 0 -> return Nothing
+ 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty)
+ 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty)
+ _ -> error "Invalid tag value"
STMaybe t -> do
tag <- peekByteOff @Word8 ptr off
if tag == 0
@@ -525,13 +543,12 @@ deserialise topty ptr off =
else Just <$> deserialise t ptr (off + alignmentSTy t)
STArr n t -> do
bufptr <- peekByteOff @(Ptr ()) ptr off
- sh <- peekShape bufptr 0 n
- refc <- peekByteOff @Word64 bufptr (8 * fromSNat n)
+ sh <- peekShape ptr (off + 8) n
+ refc <- peekByteOff @Word64 bufptr 0
when debugRefc $
hPutStrLn stderr $ "[chad-deserialise] Got buffer " ++ showPtr bufptr ++ " at refc=" ++ show refc
- let off1 = 8 * fromSNat n + 8
- eltsz = sizeofSTy t
- arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (off1 + i * eltsz))
+ let eltsz = sizeofSTy t
+ arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (8 + i * eltsz))
when (refc < 2 ^ 62) $ free bufptr
return arr
STScal sty -> case sty of
@@ -541,13 +558,6 @@ deserialise topty ptr off =
STF64 -> peekByteOff @Double ptr off
STBool -> toEnum . fromIntegral <$> peekByteOff @Word8 ptr off
STAccum{} -> error "Cannot serialise accumulators"
- STLEither a b -> do
- tag <- peekByteOff @Word8 ptr off
- case tag of -- alignment of (union {a b}) is the same as alignment of (a + b)
- 0 -> return Nothing
- 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty)
- 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty)
- _ -> error "Invalid tag value"
align :: Int -> Int -> Int
align a off = (off + a - 1) `div` a * a
@@ -569,10 +579,14 @@ metricsSTy (STEither a b) =
let (a1, s1) = metricsSTy a
(a2, s2) = metricsSTy b
in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned
+metricsSTy (STLEither a b) =
+ let (a1, s1) = metricsSTy a
+ (a2, s2) = metricsSTy b
+ in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned
metricsSTy (STMaybe t) =
let (a, s) = metricsSTy t
in (a, a + s) -- the union after the tag byte is aligned
-metricsSTy (STArr _ _) = (8, 8)
+metricsSTy (STArr n _) = (8, 8 + 8 * fromSNat n)
metricsSTy (STScal sty) = case sty of
STI32 -> (4, 4)
STI64 -> (8, 8)
@@ -580,10 +594,6 @@ metricsSTy (STScal sty) = case sty of
STF64 -> (8, 8)
STBool -> (1, 1) -- compiled to uint8_t
metricsSTy (STAccum t) = metricsSTy (fromSMTy t)
-metricsSTy (STLEither a b) =
- let (a1, s1) = metricsSTy a
- (a2, s2) = metricsSTy b
- in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned
pokeShape :: Ptr () -> Int -> SNat n -> Shape n -> IO ()
pokeShape ptr off = go . fromSNat
@@ -747,15 +757,17 @@ compile' env = \case
return (CELit retvar)
EConstArr _ n t (Array sh vec) -> do
- strname <- emitStruct (STArr n (STScal t))
+ (strname, bufstrname) <- emitArrStruct (STArr n (STScal t))
tldname <- genName' "carraybuf"
-- Give it a refcount of _half_ the size_t max, so that it can be
-- incremented and decremented at will and will "never" reach anything
-- where something happens
- emitTLD $ "static " ++ strname ++ "_buf " ++ tldname ++ " = " ++
- "(" ++ strname ++ "_buf){.sh = {" ++ intercalate "," (map show (shapeToList sh)) ++ "}, " ++
- ".refc = (size_t)1<<63, .xs = {" ++ intercalate "," (map (compileScal False t) (toList vec)) ++ "}};"
- return (CEStruct strname [("buf", CEAddrOf (CELit tldname))])
+ emitTLD $ "static " ++ bufstrname ++ " " ++ tldname ++ " = " ++
+ "(" ++ bufstrname ++ "){.refc = (size_t)1<<63, " ++
+ ".xs = {" ++ intercalate "," (map (compileScal False t) (toList vec)) ++ "}};"
+ return (CEStruct strname
+ [("buf", CEAddrOf (CELit tldname))
+ ,("sh", CELit ("{" ++ intercalate "," (map show (shapeToList sh)) ++ "}"))])
EBuild _ n esh efun -> do
shname <- compileAssign "sh" env esh
@@ -770,7 +782,7 @@ compile' env = \case
emit $ SBlock $
pure (SVarDecl False "size_t" linivar (CELit "0"))
<> compose [pure . SLoop (repSTy tIx) ivar (CELit "0")
- (CECast (repSTy tIx) (CEIndex (CELit (arrname ++ ".buf->sh")) (CELit (show dimidx))))
+ (CECast (repSTy tIx) (CEIndex (CELit (arrname ++ ".sh")) (CELit (show dimidx))))
| (ivar, dimidx) <- zip ivars [0::Int ..]]
(pure (SVarDecl True (repSTy (typeOf esh)) idxargname
(shapeTupFromLitVars n ivars))
@@ -799,7 +811,7 @@ compile' env = \case
lenname <- genName' "n"
emit $ SVarDecl True (repSTy tIx) lenname
- (CELit (arrname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))
+ (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
((), x0incrStmts) <- scope $ incrementVarAlways "foldx0" Increment t x0name
@@ -845,7 +857,7 @@ compile' env = \case
lenname <- genName' "n"
emit $ SVarDecl True (repSTy tIx) lenname
- (CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))
+ (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
let vecwid = 8 :: Int
ivar <- genName' "i"
@@ -909,6 +921,136 @@ compile' env = \case
EMinimum1Inner _ e -> compileExtremum "min" "minimum1i" "<" env e
+ EReshape _ dim esh earg -> do
+ let STArr origDim eltty = typeOf earg
+ strname <- emitStruct (STArr dim eltty)
+
+ shname <- compileAssign "reshsh" env esh
+ arrname <- compileAssign "resharg" env earg
+
+ when emitChecks $ do
+ emit $ SIf (CEBinop (compileArrShapeSize origDim arrname) "!=" (CECast "size_t" (prodExpr (indexTupleComponents dim shname))))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: reshape on unequal sizes (%zu <- %zu)\\n\", " ++
+ printCExpr 0 (prodExpr (indexTupleComponents dim shname)) ", " ++
+ printCExpr 0 (compileArrShapeSize origDim arrname) "); return false;")
+ mempty
+
+ return (CEStruct strname
+ [("buf", CEProj (CELit arrname) "buf")
+ ,("sh", CELit ("{" ++ intercalate ", " [printCExpr 0 e "" | e <- indexTupleComponents dim shname] ++ "}"))])
+
+ EFold1InnerD1 _ commut efun ex0 earr -> do
+ let STArr (SS n) t = typeOf earr
+ STPair _ bty = typeOf efun
+
+ x0name <- compileAssign "foldd1x0" env ex0
+ arrname <- compileAssign "foldd1arr" env earr
+
+ zeroRefcountCheck (typeOf earr) "fold1iD1" arrname
+
+ lenname <- genName' "n"
+ emit $ SVarDecl True (repSTy tIx) lenname
+ (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
+
+ shsz1name <- genName' "shszN"
+ emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n arrname) -- take init of arr's shape
+ shsz2name <- genName' "shszSN"
+ emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname))
+
+ resname <- allocArray "foldd1" Malloc "foldd1res" n t (Just (CELit shsz1name)) (compileArrShapeComponents n arrname)
+ storesname <- allocArray "foldd1" Malloc "foldd1stores" (SS n) bty (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) arrname)
+
+ ((), x0incrStmts) <- scope $ incrementVarAlways "foldd1x0" Increment t x0name
+
+ ivar <- genName' "i"
+ jvar <- genName' "j"
+
+ accvar <- genName' "tot"
+ let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ jvar
+ arreltlit = arrname ++ ".buf->xs[" ++ eltidx ++ "]"
+ (funres, funStmts) <- scope $ compile' (Const arreltlit `SCons` Const accvar `SCons` env) efun
+ funresvar <- genName' "res"
+ ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldd1elt" Increment t arreltlit
+
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $
+ pure (SVarDecl False (repSTy t) accvar (CELit x0name))
+ <> x0incrStmts -- we're copying x0 here
+ <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
+ -- The combination function will consume the array element
+ -- and the accumulator. The accumulator is replaced by
+ -- what comes out of the function anyway, so that's
+ -- fine, but we do need to increment the array element.
+ arreltIncrStmts
+ <> funStmts
+ <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres)
+ <> pure (SAsg accvar (CEProj (CELit funresvar) "a"))
+ <> pure (SAsg (storesname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b")))
+ <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar))
+
+ incrementVarAlways "foldd1x0" Decrement t x0name
+ incrementVarAlways "foldd1arr" Decrement (typeOf earr) arrname
+
+ strname <- emitStruct (STPair (STArr n t) (STArr (SS n) bty))
+ return (CEStruct strname [("a", CELit resname), ("b", CELit storesname)])
+
+ EFold1InnerD2 _ commut efun estores ectg -> do
+ let STArr n t2 = typeOf ectg
+ STArr _ bty = typeOf estores
+
+ storesname <- compileAssign "foldd2stores" env estores
+ ctgname <- compileAssign "foldd2ctg" env ectg
+
+ zeroRefcountCheck (typeOf ectg) "fold1iD2" ctgname
+
+ lenname <- genName' "n"
+ emit $ SVarDecl True (repSTy tIx) lenname
+ (CELit (storesname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
+
+ shsz1name <- genName' "shszN"
+ emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n storesname) -- take init of the shape
+ shsz2name <- genName' "shszSN"
+ emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname))
+
+ x0ctgname <- allocArray "foldd2" Malloc "foldd2x0ctg" n t2 (Just (CELit shsz1name)) (compileArrShapeComponents n storesname)
+ outctgname <- allocArray "foldd2" Malloc "foldd2outctg" (SS n) t2 (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) storesname)
+
+ ivar <- genName' "i"
+ jvar <- genName' "j"
+
+ accvar <- genName' "acc"
+ let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ lenname ++ "-1 - " ++ jvar
+ storeseltlit = storesname ++ ".buf->xs[" ++ eltidx ++ "]"
+ ctgeltlit = ctgname ++ ".buf->xs[" ++ ivar ++ "]"
+ (funres, funStmts) <- scope $ compile' (Const accvar `SCons` Const storeseltlit `SCons` env) efun
+ funresvar <- genName' "res"
+ ((), storeseltIncrStmts) <- scope $ incrementVarAlways "foldd2selt" Increment bty storeseltlit
+ ((), ctgeltIncrStmts) <- scope $ incrementVarAlways "foldd2celt" Increment bty ctgeltlit
+
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $
+ pure (SVarDecl False (repSTy t2) accvar (CELit ctgeltlit))
+ <> ctgeltIncrStmts
+ -- we need to loop in reverse here, but we let jvar run in the
+ -- forward direction so that we can use SLoop. Note jvar is
+ -- reversed in eltidx above
+ <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
+ -- The combination function will consume the accumulator
+ -- and the stores element. The accumulator is replaced by
+ -- what comes out of the function anyway, so that's
+ -- fine, but we do need to increment the stores element.
+ storeseltIncrStmts
+ <> funStmts
+ <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres)
+ <> pure (SAsg accvar (CEProj (CELit funresvar) "a"))
+ <> pure (SAsg (outctgname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b")))
+ <> pure (SAsg (x0ctgname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar))
+
+ incrementVarAlways "foldd2stores" Decrement (STArr (SS n) bty) storesname
+ incrementVarAlways "foldd2ctg" Decrement (STArr n t2) ctgname
+
+ strname <- emitStruct (STPair (STArr n t2) (STArr (SS n) t2))
+ return (CEStruct strname [("a", CELit x0ctgname), ("b", CELit outctgname)])
+
EConst _ t x -> return $ CELit $ compileScal True t x
EIdx0 _ e -> do
@@ -934,7 +1076,7 @@ compile' env = \case
when emitChecks $
forM_ (zip [0::Int ..] (indexTupleComponents n idxname)) $ \(i, ixcomp) ->
emit $ SIf (CEBinop (CEBinop ixcomp "<" (CELit "0")) "||"
- (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".buf->sh[" ++ show i ++ "]")))))
+ (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".sh[" ++ show i ++ "]")))))
(pure $ SVerbatim $
"fprintf(stderr, PRTAG \"CHECK: index out of range (arr=%p)\\n\", " ++
arrname ++ ".buf); return false;")
@@ -977,6 +1119,8 @@ compile' env = \case
maybe (return ()) ($ name2) mfun2
return (CELit name)
+ ERecompute _ e -> compile' env e
+
EWith _ t e1 e2 -> do
actyname <- emitStruct (STAccum t)
name1 <- compileAssign "" env e1
@@ -1000,95 +1144,7 @@ compile' env = \case
rettyname <- emitStruct (STPair (typeOf e2) (fromSMTy t))
return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)]
- EAccum _ t prj eidx eval eacc -> do
- let -- Assumes v is a value of type (SMTArr n t1), and initialises it to a
- -- full zero array with the given zero info (for the type SMTArr n t1).
- initZeroArray :: SNat n -> SMTy a -> String -> String -> CompM ()
- initZeroArray n t1 v vzi = do
- shszname <- genName' "inacshsz"
- emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n vzi)
- newarrName <- allocArray "initZero" Calloc "inacarr" n (fromSMTy t1) (Just (CELit shszname)) (compileArrShapeComponents n vzi)
- emit $ SAsg v (CELit newarrName)
- forM_ (initZeroFromMemset t1) $ \f1 -> do
- ivar <- genName' "i"
- ((), initStmts) <- scope $ f1 (v++"["++ivar++"]") (vzi++"["++ivar++"]")
- emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) initStmts
-
- -- If something needs to be done to properly initialise this type to
- -- zero after memory has already been initialised to all-zero bytes,
- -- returns an action that does so.
- -- initZeroFromMemset (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type)
- initZeroFromMemset :: SMTy a -> Maybe (String -> String -> CompM ())
- initZeroFromMemset SMTNil = Nothing
- initZeroFromMemset (SMTPair t1 t2) =
- case (initZeroFromMemset t1, initZeroFromMemset t2) of
- (Nothing, Nothing) -> Nothing
- (mf1, mf2) -> Just $ \v vzi -> do
- forM_ mf1 $ \f1 -> f1 (v++".a") (vzi++".a")
- forM_ mf2 $ \f2 -> f2 (v++".b") (vzi++".b")
- initZeroFromMemset SMTLEither{} = Nothing
- initZeroFromMemset SMTMaybe{} = Nothing
- initZeroFromMemset (SMTArr n t1) = Just $ \v vzi -> initZeroArray n t1 v vzi
- initZeroFromMemset SMTScal{} = Nothing
-
- let -- initZeroZI (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type)
- initZeroZI :: SMTy a -> String -> String -> CompM ()
- initZeroZI SMTNil _ _ = return ()
- initZeroZI (SMTPair t1 t2) v vzi = do
- initZeroZI t1 (v++".a") (vzi++".a")
- initZeroZI t2 (v++".b") (vzi++".b")
- initZeroZI SMTLEither{} v _ = emit $ SAsg (v++".tag") (CELit "0")
- initZeroZI SMTMaybe{} v _ = emit $ SAsg (v++".tag") (CELit "0")
- initZeroZI (SMTArr n t1) v vzi = initZeroArray n t1 v vzi
- initZeroZI (SMTScal sty) v _ = case sty of
- STI32 -> emit $ SAsg v (CELit "0")
- STI64 -> emit $ SAsg v (CELit "0l")
- STF32 -> emit $ SAsg v (CELit "0.0f")
- STF64 -> emit $ SAsg v (CELit "0.0")
-
- let -- Initialise an uninitialised accumulation value, potentially already
- -- with the addend, potentially to zero depending on the nature of the
- -- projection.
- -- 1. If the projection indexes only through dense monoids before
- -- reaching SAPHere, the thing cannot be initialised to zero with
- -- only an AcIdx; it would need to model a zero after the addend,
- -- which is stupid and redundant. In this case, we return Left:
- -- (accumulation value) (AcIdx value) (addend value).
- -- The addend is copied, not consumed. (We can't reliably _always_
- -- consume it, so it's not worth trying to do it sometimes.)
- -- 2. Otherwise, a sparse monoid is found along the way, and we can
- -- initalise the dense prefix of the path to zero by setting the
- -- indexed-through sparse value to a sparse zero. Afterwards, the
- -- main recursion can proceed further. In this case, we return
- -- Right: (accumulation value) (AcIdx value)
- -- initZeroChunk (type) (projection) (variable of that type to initialise to zero) (variable to an AcIdx for the type)
- initZeroChunk :: SMTy a -> SAcPrj p a b
- -> Either (String -> String -> String -> CompM ()) -- dense initialisation with addend
- (String -> String -> CompM ()) -- zero initialisation of sparse chunk
- initZeroChunk izaitoptyp izaitopprj = case (izaitoptyp, izaitopprj) of
- -- reached target before the first sparse constructor
- (t1 , SAPHere ) -> Left $ \v _ addend -> do
- incrementVarAlways "initZeroSparse" Increment (fromSMTy t1) addend
- emit $ SAsg v (CELit addend)
- -- sparse types
- (SMTLEither{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0")
- (SMTMaybe{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0")
- -- dense types
- (SMTPair t1 t2, SAPFst prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do
- f (v++".a") (i++".a")
- initZeroZI t2 (v++".b") (i++".b")
- (SMTPair t1 t2, SAPSnd prj') -> applySkeleton (initZeroChunk t2 prj') $ \f v i -> do
- initZeroZI t1 (v++".a") (i++".a")
- f (v++".b") (i++".b")
- (SMTArr n t1, SAPArrIdx prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do
- initZeroArray n t1 v (i++".a.b")
- linidxvar <- genName' "li"
- emit $ SVarDecl False (repSTy tIx) linidxvar (toLinearIdx n v (i++".a.a"))
- f (v++".buf->xs["++linidxvar++"]") (i++".b")
- where
- applySkeleton (Left densef) skel = Left $ \v i addend -> skel (\v' i' -> densef v' i' addend) v i
- applySkeleton (Right sparsef) skel = Right $ \v i -> skel (\v' i' -> sparsef v' i') v i
-
+ EAccum _ t prj eidx sparsity eval eacc | Just Refl <- isDense (acPrjTy prj t) sparsity -> do
let -- Add a value (s) into an existing accumulation value (d). If a sparse
-- component of d is encountered, s is copied there.
add :: SMTy a -> String -> String -> CompM ()
@@ -1129,16 +1185,16 @@ compile' env = \case
when emitChecks $ do
let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
forM_ [0 .. fromSNat n - 1] $ \j -> do
- emit $ SIf (CEBinop (CELit (s ++ ".buf->sh[" ++ show j ++ "]"))
+ emit $ SIf (CEBinop (CELit (s ++ ".sh[" ++ show j ++ "]"))
"!="
- (CELit (d ++ ".buf->sh[" ++ show j ++ "]")))
+ (CELit (d ++ ".sh[" ++ show j ++ "]")))
(pure $ SVerbatim $
"fprintf(stderr, PRTAG \"CHECK: accum add incorrect (d=%p, " ++
"dsh=" ++ shfmt ++ ", s=%p, ssh=" ++ shfmt ++ ")\\n\", " ++
d ++ ".buf" ++
- concat [", " ++ d ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
+ concat [", " ++ d ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
", " ++ s ++ ".buf" ++
- concat [", " ++ s ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
+ concat [", " ++ s ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
"); " ++
"return false;")
mempty
@@ -1158,67 +1214,55 @@ compile' env = \case
accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> String -> CompM ()
accumRef _ SAPHere v _ addend = add (acPrjTy prj t) v addend
- accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") (i++".a") addend
- accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") (i++".b") addend
+ accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") i addend
+ accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") i addend
- accumRef (SMTLEither ta tb) prj0 v i addend = do
- let chunkres = case prj0 of SAPLeft prj' -> initZeroChunk ta prj'
- SAPRight prj' -> initZeroChunk tb prj'
- subv = v ++ (case prj0 of SAPLeft{} -> ".l"; SAPRight{} -> ".r")
- tagval = case prj0 of SAPLeft{} -> "1"
- SAPRight{} -> "2"
- ((), stmtsAdd) <- scope $ case prj0 of SAPLeft prj' -> accumRef ta prj' subv i addend
- SAPRight prj' -> accumRef tb prj' subv i addend
- case chunkres of
- Left densef -> do
- ((), stmtsSet) <- scope $ densef subv i addend
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsSet)
- stmtsAdd -- TODO: emit check for consistency of tags?
- Right sparsef -> do
- ((), stmtsInit) <- scope $ sparsef subv i
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsInit) mempty
- forM_ stmtsAdd emit
+ accumRef (SMTLEither ta _) (SAPLeft prj') v i addend = do
+ when emitChecks $ do
+ emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1"))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +left)\\n\", " ++ v ++ ".tag); " ++
+ "return false;")
+ mempty
+ accumRef ta prj' (v++".l") i addend
+ accumRef (SMTLEither _ tb) (SAPRight prj') v i addend = do
+ when emitChecks $ do
+ emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "2"))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +right)\\n\", " ++ v ++ ".tag); " ++
+ "return false;")
+ mempty
+ accumRef tb prj' (v++".r") i addend
accumRef (SMTMaybe tj) (SAPJust prj') v i addend = do
- case initZeroChunk tj prj' of
- Left densef -> do
- ((), stmtsSet1) <- scope $ densef (v++".j") i addend
- ((), stmtsAdd1) <- scope $ accumRef tj prj' (v++".j") i addend
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit "1")) <> stmtsSet1)
- stmtsAdd1
- Right sparsef -> do
- ((), stmtsInit1) <- scope $ sparsef (v++".j") i
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty
- accumRef tj prj' (v++".j") i addend
+ when emitChecks $ do
+ emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1"))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (maybe tag=%d, +just)\\n\", " ++ v ++ ".tag); " ++
+ "return false;")
+ mempty
+ accumRef tj prj' (v++".j") i addend
accumRef (SMTArr n t') (SAPArrIdx prj') v i addend = do
when emitChecks $ do
let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
- forM_ (zip3 [0::Int ..]
- (indexTupleComponents n (i++".a.a"))
- (compileArrShapeComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do
+ forM_ (zip [0::Int ..]
+ (indexTupleComponents n (i++".a"))) $ \(j, ixcomp) -> do
let a .||. b = CEBinop a "||" b
emit $ SIf (CEBinop ixcomp "<" (CELit "0")
.||.
- CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))
- .||.
- CEBinop shcomp "!=" (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))
+ CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".sh[" ++ show j ++ "]"))))
(pure $ SVerbatim $
"fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++
- "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++
+ "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=(D))\\n\", " ++
v ++ ".buf" ++
- concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
- concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++
- concat [", " ++ printCExpr 2 comp "" | comp <- compileArrShapeComponents n (i++".a.b")] ++
+ concat [", " ++ v ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
+ concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a")] ++
"); " ++
"return false;")
mempty
- accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") addend
+ accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a")) "]") (i++".b") addend
nameidx <- compileAssign "acidx" env eidx
nameval <- compileAssign "acval" env eval
@@ -1232,6 +1276,9 @@ compile' env = \case
return $ CEStruct (repSTy STNil) []
+ EAccum{} ->
+ error "Compile: EAccum with non-trivial sparsity should have been eliminated (use AST.UnMonoid)"
+
EError _ t s -> do
let padleft len c s' = replicate (len - length s) c ++ s'
escape = concatMap $ \c -> if | c `elem` "\"\\" -> ['\\',c]
@@ -1245,6 +1292,7 @@ compile' env = \case
return $ CEStruct name []
EZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
+ EDeepZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
EPlus{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
EOneHot{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
@@ -1303,13 +1351,13 @@ makeArrayTree (STPair a b) = smartATBoth (smartATProj "a" (makeArrayTree a))
(smartATProj "b" (makeArrayTree b))
makeArrayTree (STEither a b) = smartATCondTag (smartATProj "l" (makeArrayTree a))
(smartATProj "r" (makeArrayTree b))
+makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop
+ (smartATProj "l" (makeArrayTree a))
+ (smartATProj "r" (makeArrayTree b))
makeArrayTree (STMaybe t) = smartATCondTag ATNoop (smartATProj "j" (makeArrayTree t))
makeArrayTree (STArr n t) = ATArray (Some n) (Some t)
makeArrayTree (STScal _) = ATNoop
makeArrayTree (STAccum _) = ATNoop
-makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop
- (smartATProj "l" (makeArrayTree a))
- (smartATProj "r" (makeArrayTree b))
incrementVar' :: String -> Increment -> String -> ArrayTree -> CompM ()
incrementVar' marker inc path (ATArray (Some n) (Some eltty)) =
@@ -1361,21 +1409,21 @@ toLinearIdx SZ _ _ = CELit "0"
toLinearIdx (SS SZ) _ idxvar = CELit (idxvar ++ ".b")
toLinearIdx (SS n) arrvar idxvar =
CEBinop (CEBinop (toLinearIdx n arrvar (idxvar ++ ".a"))
- "*" (CEIndex (CELit (arrvar ++ ".buf->sh")) (CELit (show (fromSNat n)))))
+ "*" (CEIndex (CELit (arrvar ++ ".sh")) (CELit (show (fromSNat n)))))
"+" (CELit (idxvar ++ ".b"))
-- fromLinearIdx :: SNat n -> String -> String -> CompM CExpr
-- fromLinearIdx SZ _ _ = return $ CEStruct (repSTy STNil) []
-- fromLinearIdx (SS n) arrvar idxvar = do
-- name <- genName
--- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")))
+-- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".sh[" ++ show (fromSNat n) ++ "]")))
-- _
data AllocMethod = Malloc | Calloc
deriving (Show)
-- | The shape must have the outer dimension at the head (and the inner dimension on the right).
-allocArray :: String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String
+allocArray :: HasCallStack => String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String
allocArray marker method nameBase rank eltty mshsz shape = do
when (length shape /= fromSNat rank) $
error "allocArray: shape does not match rank"
@@ -1390,9 +1438,8 @@ allocArray marker method nameBase rank eltty mshsz shape = do
(CEBinop shsz "*" (CELit (show (sizeofSTy eltty))))
emit $ SVarDecl True strname arrname $ CEStruct strname
[("buf", case method of Malloc -> CECall "malloc_instr" [nbytesExpr]
- Calloc -> CECall "calloc_instr" [nbytesExpr])]
- forM_ (zip shape [0::Int ..]) $ \(dim, i) ->
- emit $ SAsg (arrname ++ ".buf->sh[" ++ show i ++ "]") dim
+ Calloc -> CECall "calloc_instr" [nbytesExpr])
+ ,("sh", CELit ("{" ++ intercalate "," [printCExpr 0 dim "" | dim <- shape] ++ "}"))]
emit $ SAsg (arrname ++ ".buf->refc") (CELit "1")
when debugRefc $
emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p allocated <" ++ marker ++ ">\\n\", " ++ arrname ++ ".buf);"
@@ -1403,16 +1450,16 @@ compileShapeQuery SZ _ = CEStruct (repSTy STNil) []
compileShapeQuery (SS n) var =
CEStruct (repSTy (tTup (sreplicate (SS n) tIx)))
[("a", compileShapeQuery n var)
- ,("b", CEIndex (CELit (var ++ ".buf->sh")) (CELit (show (fromSNat n))))]
+ ,("b", CEIndex (CELit (var ++ ".sh")) (CELit (show (fromSNat n))))]
-- | Takes a variable name for the array, not the buffer.
compileArrShapeSize :: SNat n -> String -> CExpr
-compileArrShapeSize n var = foldl0' (\a b -> CEBinop a "*" b) (CELit "1") (compileArrShapeComponents n var)
+compileArrShapeSize n var = prodExpr (compileArrShapeComponents n var)
-- | Takes a variable name for the array, not the buffer.
compileArrShapeComponents :: SNat n -> String -> [CExpr]
compileArrShapeComponents n var =
- [CELit (var ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
+ [CELit (var ++ ".sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
indexTupleComponents :: SNat n -> String -> [CExpr]
indexTupleComponents = \n var -> map CELit (toList (go n var))
@@ -1431,6 +1478,9 @@ shapeTupFromLitVars = \n -> go n . reverse
go (SS n) (var : vars) = CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) [("a", go n vars), ("b", CELit var)]
go _ _ = error "shapeTupFromLitVars: SNat and list do not correspond"
+prodExpr :: [CExpr] -> CExpr
+prodExpr = foldl0' (\a b -> CEBinop a "*" b) (CELit "1")
+
compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr
compileOpGeneral op e1 = do
let unary cop = return @CompM $ CECall cop [e1]
@@ -1503,7 +1553,7 @@ compileExtremum nameBase opName operator env e = do
lenname <- genName' "n"
emit $ SVarDecl True (repSTy tIx) lenname
- (CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))
+ (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); return false; }"
@@ -1574,7 +1624,7 @@ copyForWriting topty var = case topty of
-- nesting we'd have to check the refcounts of all the nested arrays _too_;
-- let's not do that. Furthermore, no sub-arrays means that the whole thing
-- is flat, and we can just memcpy if necessary.
- SMTArr n t | not (hasArrays (fromSMTy t)) -> do
+ SMTArr n t | not (typeHasArrays (fromSMTy t)) -> do
name <- genName
shszname <- genName' "shsz"
emit $ SVarDeclUninit toptyname name
@@ -1583,7 +1633,7 @@ copyForWriting topty var = case topty of
let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
emit $ SVerbatim $
"fprintf(stderr, PRTAG \"with array " ++ shfmt ++ "\\n\"" ++
- concat [", " ++ var ++ ".buf->sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++
+ concat [", " ++ var ++ ".sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++
");"
emit $ SIf (CEBinop (CELit (var ++ ".buf->refc")) "==" (CELit "1"))
@@ -1594,8 +1644,7 @@ copyForWriting topty var = case topty of
in BList
[SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var)
,SAsg name (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])])
- ,SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++
- show shbytes ++ ");"
+ ,SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");"
,SAsg (name ++ ".buf->refc") (CELit "1")
,SVerbatim $ "memcpy(" ++ name ++ ".buf->xs, " ++ var ++ ".buf->xs, " ++
printCExpr 0 databytes ");"])
@@ -1612,8 +1661,7 @@ copyForWriting topty var = case topty of
name <- genName
emit $ SVarDecl False toptyname name
(CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])])
- emit $ SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++
- show shbytes ++ ");"
+ emit $ SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");"
emit $ SAsg (name ++ ".buf->refc") (CELit "1")
-- put the arrays in variables to cut short the not-quite-var chain
@@ -1657,6 +1705,14 @@ zeroRefcountCheck toptyp opname topvar =
go (STEither a b) path = do
(s1, s2) <- combine (go a (path++".l")) (go b (path++".r"))
return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "0")) s1 s2
+ go (STLEither a b) path = do
+ (s1, s2) <- combine (go a (path++".l")) (go b (path++".r"))
+ return $ pure $
+ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1"))
+ s1
+ (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2"))
+ s2
+ mempty))
go (STMaybe a) path = do
ss <- go a (path++".j")
return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) ss mempty
@@ -1673,14 +1729,6 @@ zeroRefcountCheck toptyp opname topvar =
return (BList [s1, s2, s3])
go STScal{} _ = empty
go STAccum{} _ = error "zeroRefcountCheck: passed an accumulator"
- go (STLEither a b) path = do
- (s1, s2) <- combine (go a (path++".l")) (go b (path++".r"))
- return $ pure $
- SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1"))
- s1
- (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2"))
- s2
- mempty))
combine :: (Monoid a, Monoid b, Monad m) => MaybeT m a -> MaybeT m b -> MaybeT m (a, b)
combine (MaybeT a) (MaybeT b) = MaybeT $ do