diff options
Diffstat (limited to 'src/Compile.hs')
| -rw-r--r-- | src/Compile.hs | 574 |
1 files changed, 311 insertions, 263 deletions
diff --git a/src/Compile.hs b/src/Compile.hs index 6ba3a39..f2063ee 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -34,6 +34,7 @@ import Foreign import GHC.Exts (int2Word#, addr2Int#) import GHC.Num (integerFromWord#) import GHC.Ptr (Ptr(..)) +import GHC.Stack (HasCallStack) import Numeric (showHex) import System.IO (hPutStrLn, stderr) import System.IO.Error (mkIOError, userErrorType) @@ -45,6 +46,7 @@ import qualified Prelude import Array import AST import AST.Pretty (ppSTy, ppExpr) +import AST.Sparse.Types (isDense) import Compile.Exec import Data import Interpreter.Rep @@ -77,7 +79,7 @@ compile = \env expr -> do let (source, offsets) = compileToString codeID env expr when debugPrintAST $ hPutStrLn stderr $ "Compiled AST: <<<\n" ++ ppExpr env expr ++ "\n>>>" when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ lineNumbers source ++ "\x1B[0m>>>" - lib <- buildKernel source ["kernel"] + lib <- buildKernel source "kernel" let result_type = typeOf expr result_size = sizeofSTy result_type @@ -86,7 +88,7 @@ compile = \env expr -> do allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets) serialiseArguments args ptr $ do - callKernelFun "kernel" lib ptr + callKernelFun lib ptr ok <- peekByteOff @Word8 ptr (koOkResOffset offsets) when (ok /= 1) $ ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing) @@ -125,7 +127,7 @@ data CExpr | CECall String [CExpr] -- ^ function(arg1, ..., argn) | CEBinop CExpr String CExpr -- ^ expr + expr | CEIf CExpr CExpr CExpr -- ^ expr ? expr : expr - | CECast String CExpr -- ^ (<type)<expr> + | CECast String CExpr -- ^ (<type>)<expr> deriving (Show) printStructDecl :: StructDecl -> ShowS @@ -214,23 +216,31 @@ repSTy (STScal st) = case st of STBool -> "uint8_t" repSTy t = genStructName t -genStructName :: STy t -> String -genStructName = \t -> "ty_" ++ gen t where - -- all tags start with a letter, so the array mangling is unambiguous. - gen :: STy t -> String - gen STNil = "n" - gen (STPair a b) = 'P' : gen a ++ gen b - gen (STEither a b) = 'E' : gen a ++ gen b - gen (STMaybe t) = 'M' : gen t - gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t - gen (STScal st) = case st of - STI32 -> "i" - STI64 -> "j" - STF32 -> "f" - STF64 -> "d" - STBool -> "b" - gen (STAccum t) = 'C' : gen (fromSMTy t) - gen (STLEither a b) = 'L' : gen a ++ gen b +genStructName, genArrBufStructName :: STy t -> String +(genStructName, genArrBufStructName) = + (\t -> "ty_" ++ gen t + ,\case STArr _ t -> "ty_A_" ++ gen t ++ "_buf" -- just like the normal type, but with _ for the dimension + t -> error $ "genArrBufStructName: not an array type: " ++ show t) + where + -- all tags start with a letter, so the array mangling is unambiguous. + gen :: STy t -> String + gen STNil = "n" + gen (STPair a b) = 'P' : gen a ++ gen b + gen (STEither a b) = 'E' : gen a ++ gen b + gen (STLEither a b) = 'L' : gen a ++ gen b + gen (STMaybe t) = 'M' : gen t + gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t + gen (STScal st) = case st of + STI32 -> "i" + STI64 -> "j" + STF32 -> "f" + STF64 -> "d" + STBool -> "b" + gen (STAccum t) = 'C' : gen (fromSMTy t) + +-- The subtrees contain structs used in the bodies of the structs in this node. +data StructTree = TreeNode [StructDecl] [StructTree] + deriving (Show) -- | This function generates the actual struct declarations for each of the -- types in our language. It thus implicitly "documents" the layout of the @@ -238,60 +248,56 @@ genStructName = \t -> "ty_" ++ gen t where -- -- For accumulation it is important that for struct representations of monoid -- types, the all-zero-bytes value corresponds to the zero value of that type. -genStruct :: String -> STy t -> [StructDecl] -genStruct name topty = case topty of +buildStructTree :: STy t -> StructTree +buildStructTree topty = case topty of STNil -> - [StructDecl name "" com] + TreeNode [StructDecl name "" com] [] STPair a b -> - [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com] + TreeNode [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com] + [buildStructTree a, buildStructTree b] STEither a b -> -- 0 -> l, 1 -> r - [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] + TreeNode [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] + [buildStructTree a, buildStructTree b] + STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r + TreeNode [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] + [buildStructTree a, buildStructTree b] STMaybe t -> -- 0 -> nothing, 1 -> just - [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com] + TreeNode [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com] + [buildStructTree t] STArr n t -> -- The buffer is trailed by a VLA for the actual array data. - -- TODO: put shape in the main struct, not the buffer; it's constant, after all -- TODO: no buffer if n = 0 - [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromSNat n) ++ "]; size_t refc; " ++ repSTy t ++ " xs[];") "" - ,StructDecl name (name ++ "_buf *buf;") com] + TreeNode [StructDecl (genArrBufStructName topty) ("size_t refc; " ++ repSTy t ++ " xs[];") "" + ,StructDecl name (genArrBufStructName topty ++ " *buf; size_t sh[" ++ show (fromSNat n) ++ "];") com] + [buildStructTree t] STScal _ -> - [] + TreeNode [] [] STAccum t -> - [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") "" - ,StructDecl name (name ++ "_buf *buf;") com] - STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r - [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] + TreeNode [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") "" + ,StructDecl name (name ++ "_buf *buf;") com] + [buildStructTree (fromSMTy t)] where + name = genStructName topty com = ppSTy 0 topty -- State: already-generated (skippable) struct names -- Writer: the structs in declaration order -genStructs :: STy t -> WriterT (Bag StructDecl) (State (Set String)) () -genStructs ty = do - let name = genStructName ty - seen <- lift $ gets (name `Set.member`) - - if seen - then pure () - else do - -- already mark this struct as generated now, so we don't generate it - -- twice (unnecessary because no recursive types, but y'know) - lift $ modify (Set.insert name) - - () <- case ty of - STNil -> pure () - STPair a b -> genStructs a >> genStructs b - STEither a b -> genStructs a >> genStructs b - STMaybe t -> genStructs t - STArr _ t -> genStructs t - STScal _ -> pure () - STAccum t -> genStructs (fromSMTy t) - STLEither a b -> genStructs a >> genStructs b - - tell (BList (genStruct name ty)) +genStructTreeW :: StructTree -> WriterT (Bag StructDecl) (State (Set String)) () +genStructTreeW (TreeNode these deps) = do + seen <- lift get + case filter ((`Set.notMember` seen) . nameOf) these of + [] -> pure () + structs -> do + lift $ modify (Set.fromList (map nameOf structs) <>) + mapM_ genStructTreeW deps + tell (BList structs) + where + nameOf (StructDecl name _ _) = name genAllStructs :: Foldable t => t (Some STy) -> [StructDecl] -genAllStructs tys = toList $ evalState (execWriterT (mapM_ (\(Some t) -> genStructs t) tys)) mempty +genAllStructs tys = + let m = mapM_ (\(Some t) -> genStructTreeW (buildStructTree t)) tys + in toList (evalState (execWriterT m) mempty) data CompState = CompState { csStructs :: Set (Some STy) @@ -340,6 +346,12 @@ emitStruct ty = CompM $ do modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) } return (genStructName ty) +-- | Also returns the name of the array buffer struct +emitArrStruct :: STy t -> CompM (String, String) +emitArrStruct ty = CompM $ do + modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) } + return (genStructName ty, genArrBufStructName ty) + emitTLD :: String -> CompM () emitTLD decl = CompM $ modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl } @@ -463,6 +475,15 @@ serialise topty topval ptr off k = (STEither _ b, Right y) -> do pokeByteOff ptr off (1 :: Word8) serialise b y ptr (off + alignmentSTy topty) k + (STLEither _ _, Nothing) -> do + pokeByteOff ptr off (0 :: Word8) + k + (STLEither a _, Just (Left x)) -> do + pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b) + serialise a x ptr (off + alignmentSTy topty) k + (STLEither _ b, Just (Right y)) -> do + pokeByteOff ptr off (2 :: Word8) + serialise b y ptr (off + alignmentSTy topty) k (STMaybe _, Nothing) -> do pokeByteOff ptr off (0 :: Word8) k @@ -471,19 +492,18 @@ serialise topty topval ptr off k = serialise t x ptr (off + alignmentSTy t) k (STArr n t, Array sh vec) -> do let eltsz = sizeofSTy t - allocaBytes (fromSNat n * 8 + 8 + shapeSize sh * eltsz) $ \bufptr -> do + allocaBytes (8 + shapeSize sh * eltsz) $ \bufptr -> do when debugRefc $ hPutStrLn stderr $ "[chad-serialise] Allocating input buffer " ++ showPtr bufptr pokeByteOff ptr off bufptr + pokeShape ptr (off + 8) n sh - pokeShape bufptr 0 n sh - pokeByteOff @Word64 bufptr (8 * fromSNat n) (2 ^ 63) + pokeByteOff @Word64 bufptr 0 (2 ^ 63) - let off1 = fromSNat n * 8 + 8 - loop i + let loop i | i == shapeSize sh = k | otherwise = - serialise t (vec V.! i) bufptr (off1 + i * eltsz) $ + serialise t (vec V.! i) bufptr (8 + i * eltsz) $ loop (i+1) loop 0 (STScal sty, x) -> case sty of @@ -493,15 +513,6 @@ serialise topty topval ptr off k = STF64 -> pokeByteOff ptr off (x :: Double) >> k STBool -> pokeByteOff ptr off (fromIntegral (fromEnum x) :: Word8) >> k (STAccum{}, _) -> error "Cannot serialise accumulators" - (STLEither _ _, Nothing) -> do - pokeByteOff ptr off (0 :: Word8) - k - (STLEither a _, Just (Left x)) -> do - pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b) - serialise a x ptr (off + alignmentSTy topty) k - (STLEither _ b, Just (Right y)) -> do - pokeByteOff ptr off (2 :: Word8) - serialise b y ptr (off + alignmentSTy topty) k -- | Assumes that this is called at the correct alignment. deserialise :: STy t -> Ptr () -> Int -> IO (Rep t) @@ -518,6 +529,13 @@ deserialise topty ptr off = if tag == 0 -- alignment of (union {a b}) is the same as alignment of (a + b) then Left <$> deserialise a ptr (off + alignmentSTy topty) else Right <$> deserialise b ptr (off + alignmentSTy topty) + STLEither a b -> do + tag <- peekByteOff @Word8 ptr off + case tag of -- alignment of (union {a b}) is the same as alignment of (a + b) + 0 -> return Nothing + 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty) + 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty) + _ -> error "Invalid tag value" STMaybe t -> do tag <- peekByteOff @Word8 ptr off if tag == 0 @@ -525,13 +543,12 @@ deserialise topty ptr off = else Just <$> deserialise t ptr (off + alignmentSTy t) STArr n t -> do bufptr <- peekByteOff @(Ptr ()) ptr off - sh <- peekShape bufptr 0 n - refc <- peekByteOff @Word64 bufptr (8 * fromSNat n) + sh <- peekShape ptr (off + 8) n + refc <- peekByteOff @Word64 bufptr 0 when debugRefc $ hPutStrLn stderr $ "[chad-deserialise] Got buffer " ++ showPtr bufptr ++ " at refc=" ++ show refc - let off1 = 8 * fromSNat n + 8 - eltsz = sizeofSTy t - arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (off1 + i * eltsz)) + let eltsz = sizeofSTy t + arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (8 + i * eltsz)) when (refc < 2 ^ 62) $ free bufptr return arr STScal sty -> case sty of @@ -541,13 +558,6 @@ deserialise topty ptr off = STF64 -> peekByteOff @Double ptr off STBool -> toEnum . fromIntegral <$> peekByteOff @Word8 ptr off STAccum{} -> error "Cannot serialise accumulators" - STLEither a b -> do - tag <- peekByteOff @Word8 ptr off - case tag of -- alignment of (union {a b}) is the same as alignment of (a + b) - 0 -> return Nothing - 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty) - 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty) - _ -> error "Invalid tag value" align :: Int -> Int -> Int align a off = (off + a - 1) `div` a * a @@ -569,10 +579,14 @@ metricsSTy (STEither a b) = let (a1, s1) = metricsSTy a (a2, s2) = metricsSTy b in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned +metricsSTy (STLEither a b) = + let (a1, s1) = metricsSTy a + (a2, s2) = metricsSTy b + in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned metricsSTy (STMaybe t) = let (a, s) = metricsSTy t in (a, a + s) -- the union after the tag byte is aligned -metricsSTy (STArr _ _) = (8, 8) +metricsSTy (STArr n _) = (8, 8 + 8 * fromSNat n) metricsSTy (STScal sty) = case sty of STI32 -> (4, 4) STI64 -> (8, 8) @@ -580,10 +594,6 @@ metricsSTy (STScal sty) = case sty of STF64 -> (8, 8) STBool -> (1, 1) -- compiled to uint8_t metricsSTy (STAccum t) = metricsSTy (fromSMTy t) -metricsSTy (STLEither a b) = - let (a1, s1) = metricsSTy a - (a2, s2) = metricsSTy b - in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned pokeShape :: Ptr () -> Int -> SNat n -> Shape n -> IO () pokeShape ptr off = go . fromSNat @@ -747,15 +757,17 @@ compile' env = \case return (CELit retvar) EConstArr _ n t (Array sh vec) -> do - strname <- emitStruct (STArr n (STScal t)) + (strname, bufstrname) <- emitArrStruct (STArr n (STScal t)) tldname <- genName' "carraybuf" -- Give it a refcount of _half_ the size_t max, so that it can be -- incremented and decremented at will and will "never" reach anything -- where something happens - emitTLD $ "static " ++ strname ++ "_buf " ++ tldname ++ " = " ++ - "(" ++ strname ++ "_buf){.sh = {" ++ intercalate "," (map show (shapeToList sh)) ++ "}, " ++ - ".refc = (size_t)1<<63, .xs = {" ++ intercalate "," (map (compileScal False t) (toList vec)) ++ "}};" - return (CEStruct strname [("buf", CEAddrOf (CELit tldname))]) + emitTLD $ "static " ++ bufstrname ++ " " ++ tldname ++ " = " ++ + "(" ++ bufstrname ++ "){.refc = (size_t)1<<63, " ++ + ".xs = {" ++ intercalate "," (map (compileScal False t) (toList vec)) ++ "}};" + return (CEStruct strname + [("buf", CEAddrOf (CELit tldname)) + ,("sh", CELit ("{" ++ intercalate "," (map show (shapeToList sh)) ++ "}"))]) EBuild _ n esh efun -> do shname <- compileAssign "sh" env esh @@ -770,7 +782,7 @@ compile' env = \case emit $ SBlock $ pure (SVarDecl False "size_t" linivar (CELit "0")) <> compose [pure . SLoop (repSTy tIx) ivar (CELit "0") - (CECast (repSTy tIx) (CEIndex (CELit (arrname ++ ".buf->sh")) (CELit (show dimidx)))) + (CECast (repSTy tIx) (CEIndex (CELit (arrname ++ ".sh")) (CELit (show dimidx)))) | (ivar, dimidx) <- zip ivars [0::Int ..]] (pure (SVarDecl True (repSTy (typeOf esh)) idxargname (shapeTupFromLitVars n ivars)) @@ -799,7 +811,7 @@ compile' env = \case lenname <- genName' "n" emit $ SVarDecl True (repSTy tIx) lenname - (CELit (arrname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")) + (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]")) ((), x0incrStmts) <- scope $ incrementVarAlways "foldx0" Increment t x0name @@ -845,7 +857,7 @@ compile' env = \case lenname <- genName' "n" emit $ SVarDecl True (repSTy tIx) lenname - (CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")) + (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]")) let vecwid = 8 :: Int ivar <- genName' "i" @@ -909,6 +921,136 @@ compile' env = \case EMinimum1Inner _ e -> compileExtremum "min" "minimum1i" "<" env e + EReshape _ dim esh earg -> do + let STArr origDim eltty = typeOf earg + strname <- emitStruct (STArr dim eltty) + + shname <- compileAssign "reshsh" env esh + arrname <- compileAssign "resharg" env earg + + when emitChecks $ do + emit $ SIf (CEBinop (compileArrShapeSize origDim arrname) "!=" (CECast "size_t" (prodExpr (indexTupleComponents dim shname)))) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: reshape on unequal sizes (%zu <- %zu)\\n\", " ++ + printCExpr 0 (prodExpr (indexTupleComponents dim shname)) ", " ++ + printCExpr 0 (compileArrShapeSize origDim arrname) "); return false;") + mempty + + return (CEStruct strname + [("buf", CEProj (CELit arrname) "buf") + ,("sh", CELit ("{" ++ intercalate ", " [printCExpr 0 e "" | e <- indexTupleComponents dim shname] ++ "}"))]) + + EFold1InnerD1 _ commut efun ex0 earr -> do + let STArr (SS n) t = typeOf earr + STPair _ bty = typeOf efun + + x0name <- compileAssign "foldd1x0" env ex0 + arrname <- compileAssign "foldd1arr" env earr + + zeroRefcountCheck (typeOf earr) "fold1iD1" arrname + + lenname <- genName' "n" + emit $ SVarDecl True (repSTy tIx) lenname + (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]")) + + shsz1name <- genName' "shszN" + emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n arrname) -- take init of arr's shape + shsz2name <- genName' "shszSN" + emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname)) + + resname <- allocArray "foldd1" Malloc "foldd1res" n t (Just (CELit shsz1name)) (compileArrShapeComponents n arrname) + storesname <- allocArray "foldd1" Malloc "foldd1stores" (SS n) bty (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) arrname) + + ((), x0incrStmts) <- scope $ incrementVarAlways "foldd1x0" Increment t x0name + + ivar <- genName' "i" + jvar <- genName' "j" + + accvar <- genName' "tot" + let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ jvar + arreltlit = arrname ++ ".buf->xs[" ++ eltidx ++ "]" + (funres, funStmts) <- scope $ compile' (Const arreltlit `SCons` Const accvar `SCons` env) efun + funresvar <- genName' "res" + ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldd1elt" Increment t arreltlit + + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $ + pure (SVarDecl False (repSTy t) accvar (CELit x0name)) + <> x0incrStmts -- we're copying x0 here + <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + -- The combination function will consume the array element + -- and the accumulator. The accumulator is replaced by + -- what comes out of the function anyway, so that's + -- fine, but we do need to increment the array element. + arreltIncrStmts + <> funStmts + <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres) + <> pure (SAsg accvar (CEProj (CELit funresvar) "a")) + <> pure (SAsg (storesname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b"))) + <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) + + incrementVarAlways "foldd1x0" Decrement t x0name + incrementVarAlways "foldd1arr" Decrement (typeOf earr) arrname + + strname <- emitStruct (STPair (STArr n t) (STArr (SS n) bty)) + return (CEStruct strname [("a", CELit resname), ("b", CELit storesname)]) + + EFold1InnerD2 _ commut efun estores ectg -> do + let STArr n t2 = typeOf ectg + STArr _ bty = typeOf estores + + storesname <- compileAssign "foldd2stores" env estores + ctgname <- compileAssign "foldd2ctg" env ectg + + zeroRefcountCheck (typeOf ectg) "fold1iD2" ctgname + + lenname <- genName' "n" + emit $ SVarDecl True (repSTy tIx) lenname + (CELit (storesname ++ ".sh[" ++ show (fromSNat n) ++ "]")) + + shsz1name <- genName' "shszN" + emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n storesname) -- take init of the shape + shsz2name <- genName' "shszSN" + emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname)) + + x0ctgname <- allocArray "foldd2" Malloc "foldd2x0ctg" n t2 (Just (CELit shsz1name)) (compileArrShapeComponents n storesname) + outctgname <- allocArray "foldd2" Malloc "foldd2outctg" (SS n) t2 (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) storesname) + + ivar <- genName' "i" + jvar <- genName' "j" + + accvar <- genName' "acc" + let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ lenname ++ "-1 - " ++ jvar + storeseltlit = storesname ++ ".buf->xs[" ++ eltidx ++ "]" + ctgeltlit = ctgname ++ ".buf->xs[" ++ ivar ++ "]" + (funres, funStmts) <- scope $ compile' (Const accvar `SCons` Const storeseltlit `SCons` env) efun + funresvar <- genName' "res" + ((), storeseltIncrStmts) <- scope $ incrementVarAlways "foldd2selt" Increment bty storeseltlit + ((), ctgeltIncrStmts) <- scope $ incrementVarAlways "foldd2celt" Increment bty ctgeltlit + + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $ + pure (SVarDecl False (repSTy t2) accvar (CELit ctgeltlit)) + <> ctgeltIncrStmts + -- we need to loop in reverse here, but we let jvar run in the + -- forward direction so that we can use SLoop. Note jvar is + -- reversed in eltidx above + <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + -- The combination function will consume the accumulator + -- and the stores element. The accumulator is replaced by + -- what comes out of the function anyway, so that's + -- fine, but we do need to increment the stores element. + storeseltIncrStmts + <> funStmts + <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres) + <> pure (SAsg accvar (CEProj (CELit funresvar) "a")) + <> pure (SAsg (outctgname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b"))) + <> pure (SAsg (x0ctgname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) + + incrementVarAlways "foldd2stores" Decrement (STArr (SS n) bty) storesname + incrementVarAlways "foldd2ctg" Decrement (STArr n t2) ctgname + + strname <- emitStruct (STPair (STArr n t2) (STArr (SS n) t2)) + return (CEStruct strname [("a", CELit x0ctgname), ("b", CELit outctgname)]) + EConst _ t x -> return $ CELit $ compileScal True t x EIdx0 _ e -> do @@ -934,7 +1076,7 @@ compile' env = \case when emitChecks $ forM_ (zip [0::Int ..] (indexTupleComponents n idxname)) $ \(i, ixcomp) -> emit $ SIf (CEBinop (CEBinop ixcomp "<" (CELit "0")) "||" - (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".buf->sh[" ++ show i ++ "]"))))) + (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".sh[" ++ show i ++ "]"))))) (pure $ SVerbatim $ "fprintf(stderr, PRTAG \"CHECK: index out of range (arr=%p)\\n\", " ++ arrname ++ ".buf); return false;") @@ -977,6 +1119,8 @@ compile' env = \case maybe (return ()) ($ name2) mfun2 return (CELit name) + ERecompute _ e -> compile' env e + EWith _ t e1 e2 -> do actyname <- emitStruct (STAccum t) name1 <- compileAssign "" env e1 @@ -1000,95 +1144,7 @@ compile' env = \case rettyname <- emitStruct (STPair (typeOf e2) (fromSMTy t)) return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)] - EAccum _ t prj eidx eval eacc -> do - let -- Assumes v is a value of type (SMTArr n t1), and initialises it to a - -- full zero array with the given zero info (for the type SMTArr n t1). - initZeroArray :: SNat n -> SMTy a -> String -> String -> CompM () - initZeroArray n t1 v vzi = do - shszname <- genName' "inacshsz" - emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n vzi) - newarrName <- allocArray "initZero" Calloc "inacarr" n (fromSMTy t1) (Just (CELit shszname)) (compileArrShapeComponents n vzi) - emit $ SAsg v (CELit newarrName) - forM_ (initZeroFromMemset t1) $ \f1 -> do - ivar <- genName' "i" - ((), initStmts) <- scope $ f1 (v++"["++ivar++"]") (vzi++"["++ivar++"]") - emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) initStmts - - -- If something needs to be done to properly initialise this type to - -- zero after memory has already been initialised to all-zero bytes, - -- returns an action that does so. - -- initZeroFromMemset (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type) - initZeroFromMemset :: SMTy a -> Maybe (String -> String -> CompM ()) - initZeroFromMemset SMTNil = Nothing - initZeroFromMemset (SMTPair t1 t2) = - case (initZeroFromMemset t1, initZeroFromMemset t2) of - (Nothing, Nothing) -> Nothing - (mf1, mf2) -> Just $ \v vzi -> do - forM_ mf1 $ \f1 -> f1 (v++".a") (vzi++".a") - forM_ mf2 $ \f2 -> f2 (v++".b") (vzi++".b") - initZeroFromMemset SMTLEither{} = Nothing - initZeroFromMemset SMTMaybe{} = Nothing - initZeroFromMemset (SMTArr n t1) = Just $ \v vzi -> initZeroArray n t1 v vzi - initZeroFromMemset SMTScal{} = Nothing - - let -- initZeroZI (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type) - initZeroZI :: SMTy a -> String -> String -> CompM () - initZeroZI SMTNil _ _ = return () - initZeroZI (SMTPair t1 t2) v vzi = do - initZeroZI t1 (v++".a") (vzi++".a") - initZeroZI t2 (v++".b") (vzi++".b") - initZeroZI SMTLEither{} v _ = emit $ SAsg (v++".tag") (CELit "0") - initZeroZI SMTMaybe{} v _ = emit $ SAsg (v++".tag") (CELit "0") - initZeroZI (SMTArr n t1) v vzi = initZeroArray n t1 v vzi - initZeroZI (SMTScal sty) v _ = case sty of - STI32 -> emit $ SAsg v (CELit "0") - STI64 -> emit $ SAsg v (CELit "0l") - STF32 -> emit $ SAsg v (CELit "0.0f") - STF64 -> emit $ SAsg v (CELit "0.0") - - let -- Initialise an uninitialised accumulation value, potentially already - -- with the addend, potentially to zero depending on the nature of the - -- projection. - -- 1. If the projection indexes only through dense monoids before - -- reaching SAPHere, the thing cannot be initialised to zero with - -- only an AcIdx; it would need to model a zero after the addend, - -- which is stupid and redundant. In this case, we return Left: - -- (accumulation value) (AcIdx value) (addend value). - -- The addend is copied, not consumed. (We can't reliably _always_ - -- consume it, so it's not worth trying to do it sometimes.) - -- 2. Otherwise, a sparse monoid is found along the way, and we can - -- initalise the dense prefix of the path to zero by setting the - -- indexed-through sparse value to a sparse zero. Afterwards, the - -- main recursion can proceed further. In this case, we return - -- Right: (accumulation value) (AcIdx value) - -- initZeroChunk (type) (projection) (variable of that type to initialise to zero) (variable to an AcIdx for the type) - initZeroChunk :: SMTy a -> SAcPrj p a b - -> Either (String -> String -> String -> CompM ()) -- dense initialisation with addend - (String -> String -> CompM ()) -- zero initialisation of sparse chunk - initZeroChunk izaitoptyp izaitopprj = case (izaitoptyp, izaitopprj) of - -- reached target before the first sparse constructor - (t1 , SAPHere ) -> Left $ \v _ addend -> do - incrementVarAlways "initZeroSparse" Increment (fromSMTy t1) addend - emit $ SAsg v (CELit addend) - -- sparse types - (SMTLEither{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0") - (SMTMaybe{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0") - -- dense types - (SMTPair t1 t2, SAPFst prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do - f (v++".a") (i++".a") - initZeroZI t2 (v++".b") (i++".b") - (SMTPair t1 t2, SAPSnd prj') -> applySkeleton (initZeroChunk t2 prj') $ \f v i -> do - initZeroZI t1 (v++".a") (i++".a") - f (v++".b") (i++".b") - (SMTArr n t1, SAPArrIdx prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do - initZeroArray n t1 v (i++".a.b") - linidxvar <- genName' "li" - emit $ SVarDecl False (repSTy tIx) linidxvar (toLinearIdx n v (i++".a.a")) - f (v++".buf->xs["++linidxvar++"]") (i++".b") - where - applySkeleton (Left densef) skel = Left $ \v i addend -> skel (\v' i' -> densef v' i' addend) v i - applySkeleton (Right sparsef) skel = Right $ \v i -> skel (\v' i' -> sparsef v' i') v i - + EAccum _ t prj eidx sparsity eval eacc | Just Refl <- isDense (acPrjTy prj t) sparsity -> do let -- Add a value (s) into an existing accumulation value (d). If a sparse -- component of d is encountered, s is copied there. add :: SMTy a -> String -> String -> CompM () @@ -1129,16 +1185,16 @@ compile' env = \case when emitChecks $ do let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" forM_ [0 .. fromSNat n - 1] $ \j -> do - emit $ SIf (CEBinop (CELit (s ++ ".buf->sh[" ++ show j ++ "]")) + emit $ SIf (CEBinop (CELit (s ++ ".sh[" ++ show j ++ "]")) "!=" - (CELit (d ++ ".buf->sh[" ++ show j ++ "]"))) + (CELit (d ++ ".sh[" ++ show j ++ "]"))) (pure $ SVerbatim $ "fprintf(stderr, PRTAG \"CHECK: accum add incorrect (d=%p, " ++ "dsh=" ++ shfmt ++ ", s=%p, ssh=" ++ shfmt ++ ")\\n\", " ++ d ++ ".buf" ++ - concat [", " ++ d ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ + concat [", " ++ d ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ ", " ++ s ++ ".buf" ++ - concat [", " ++ s ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ + concat [", " ++ s ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ "); " ++ "return false;") mempty @@ -1158,67 +1214,55 @@ compile' env = \case accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> String -> CompM () accumRef _ SAPHere v _ addend = add (acPrjTy prj t) v addend - accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") (i++".a") addend - accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") (i++".b") addend + accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") i addend + accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") i addend - accumRef (SMTLEither ta tb) prj0 v i addend = do - let chunkres = case prj0 of SAPLeft prj' -> initZeroChunk ta prj' - SAPRight prj' -> initZeroChunk tb prj' - subv = v ++ (case prj0 of SAPLeft{} -> ".l"; SAPRight{} -> ".r") - tagval = case prj0 of SAPLeft{} -> "1" - SAPRight{} -> "2" - ((), stmtsAdd) <- scope $ case prj0 of SAPLeft prj' -> accumRef ta prj' subv i addend - SAPRight prj' -> accumRef tb prj' subv i addend - case chunkres of - Left densef -> do - ((), stmtsSet) <- scope $ densef subv i addend - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsSet) - stmtsAdd -- TODO: emit check for consistency of tags? - Right sparsef -> do - ((), stmtsInit) <- scope $ sparsef subv i - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsInit) mempty - forM_ stmtsAdd emit + accumRef (SMTLEither ta _) (SAPLeft prj') v i addend = do + when emitChecks $ do + emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1")) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +left)\\n\", " ++ v ++ ".tag); " ++ + "return false;") + mempty + accumRef ta prj' (v++".l") i addend + accumRef (SMTLEither _ tb) (SAPRight prj') v i addend = do + when emitChecks $ do + emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "2")) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +right)\\n\", " ++ v ++ ".tag); " ++ + "return false;") + mempty + accumRef tb prj' (v++".r") i addend accumRef (SMTMaybe tj) (SAPJust prj') v i addend = do - case initZeroChunk tj prj' of - Left densef -> do - ((), stmtsSet1) <- scope $ densef (v++".j") i addend - ((), stmtsAdd1) <- scope $ accumRef tj prj' (v++".j") i addend - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit "1")) <> stmtsSet1) - stmtsAdd1 - Right sparsef -> do - ((), stmtsInit1) <- scope $ sparsef (v++".j") i - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty - accumRef tj prj' (v++".j") i addend + when emitChecks $ do + emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1")) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (maybe tag=%d, +just)\\n\", " ++ v ++ ".tag); " ++ + "return false;") + mempty + accumRef tj prj' (v++".j") i addend accumRef (SMTArr n t') (SAPArrIdx prj') v i addend = do when emitChecks $ do let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" - forM_ (zip3 [0::Int ..] - (indexTupleComponents n (i++".a.a")) - (compileArrShapeComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do + forM_ (zip [0::Int ..] + (indexTupleComponents n (i++".a"))) $ \(j, ixcomp) -> do let a .||. b = CEBinop a "||" b emit $ SIf (CEBinop ixcomp "<" (CELit "0") .||. - CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))) - .||. - CEBinop shcomp "!=" (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))) + CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".sh[" ++ show j ++ "]")))) (pure $ SVerbatim $ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++ - "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++ + "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=(D))\\n\", " ++ v ++ ".buf" ++ - concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ - concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++ - concat [", " ++ printCExpr 2 comp "" | comp <- compileArrShapeComponents n (i++".a.b")] ++ + concat [", " ++ v ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ + concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a")] ++ "); " ++ "return false;") mempty - accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") addend + accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a")) "]") (i++".b") addend nameidx <- compileAssign "acidx" env eidx nameval <- compileAssign "acval" env eval @@ -1232,6 +1276,9 @@ compile' env = \case return $ CEStruct (repSTy STNil) [] + EAccum{} -> + error "Compile: EAccum with non-trivial sparsity should have been eliminated (use AST.UnMonoid)" + EError _ t s -> do let padleft len c s' = replicate (len - length s) c ++ s' escape = concatMap $ \c -> if | c `elem` "\"\\" -> ['\\',c] @@ -1245,6 +1292,7 @@ compile' env = \case return $ CEStruct name [] EZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" + EDeepZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" EPlus{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" EOneHot{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" @@ -1303,13 +1351,13 @@ makeArrayTree (STPair a b) = smartATBoth (smartATProj "a" (makeArrayTree a)) (smartATProj "b" (makeArrayTree b)) makeArrayTree (STEither a b) = smartATCondTag (smartATProj "l" (makeArrayTree a)) (smartATProj "r" (makeArrayTree b)) +makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop + (smartATProj "l" (makeArrayTree a)) + (smartATProj "r" (makeArrayTree b)) makeArrayTree (STMaybe t) = smartATCondTag ATNoop (smartATProj "j" (makeArrayTree t)) makeArrayTree (STArr n t) = ATArray (Some n) (Some t) makeArrayTree (STScal _) = ATNoop makeArrayTree (STAccum _) = ATNoop -makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop - (smartATProj "l" (makeArrayTree a)) - (smartATProj "r" (makeArrayTree b)) incrementVar' :: String -> Increment -> String -> ArrayTree -> CompM () incrementVar' marker inc path (ATArray (Some n) (Some eltty)) = @@ -1361,21 +1409,21 @@ toLinearIdx SZ _ _ = CELit "0" toLinearIdx (SS SZ) _ idxvar = CELit (idxvar ++ ".b") toLinearIdx (SS n) arrvar idxvar = CEBinop (CEBinop (toLinearIdx n arrvar (idxvar ++ ".a")) - "*" (CEIndex (CELit (arrvar ++ ".buf->sh")) (CELit (show (fromSNat n))))) + "*" (CEIndex (CELit (arrvar ++ ".sh")) (CELit (show (fromSNat n))))) "+" (CELit (idxvar ++ ".b")) -- fromLinearIdx :: SNat n -> String -> String -> CompM CExpr -- fromLinearIdx SZ _ _ = return $ CEStruct (repSTy STNil) [] -- fromLinearIdx (SS n) arrvar idxvar = do -- name <- genName --- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))) +-- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".sh[" ++ show (fromSNat n) ++ "]"))) -- _ data AllocMethod = Malloc | Calloc deriving (Show) -- | The shape must have the outer dimension at the head (and the inner dimension on the right). -allocArray :: String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String +allocArray :: HasCallStack => String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String allocArray marker method nameBase rank eltty mshsz shape = do when (length shape /= fromSNat rank) $ error "allocArray: shape does not match rank" @@ -1390,9 +1438,8 @@ allocArray marker method nameBase rank eltty mshsz shape = do (CEBinop shsz "*" (CELit (show (sizeofSTy eltty)))) emit $ SVarDecl True strname arrname $ CEStruct strname [("buf", case method of Malloc -> CECall "malloc_instr" [nbytesExpr] - Calloc -> CECall "calloc_instr" [nbytesExpr])] - forM_ (zip shape [0::Int ..]) $ \(dim, i) -> - emit $ SAsg (arrname ++ ".buf->sh[" ++ show i ++ "]") dim + Calloc -> CECall "calloc_instr" [nbytesExpr]) + ,("sh", CELit ("{" ++ intercalate "," [printCExpr 0 dim "" | dim <- shape] ++ "}"))] emit $ SAsg (arrname ++ ".buf->refc") (CELit "1") when debugRefc $ emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p allocated <" ++ marker ++ ">\\n\", " ++ arrname ++ ".buf);" @@ -1403,16 +1450,16 @@ compileShapeQuery SZ _ = CEStruct (repSTy STNil) [] compileShapeQuery (SS n) var = CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) [("a", compileShapeQuery n var) - ,("b", CEIndex (CELit (var ++ ".buf->sh")) (CELit (show (fromSNat n))))] + ,("b", CEIndex (CELit (var ++ ".sh")) (CELit (show (fromSNat n))))] -- | Takes a variable name for the array, not the buffer. compileArrShapeSize :: SNat n -> String -> CExpr -compileArrShapeSize n var = foldl0' (\a b -> CEBinop a "*" b) (CELit "1") (compileArrShapeComponents n var) +compileArrShapeSize n var = prodExpr (compileArrShapeComponents n var) -- | Takes a variable name for the array, not the buffer. compileArrShapeComponents :: SNat n -> String -> [CExpr] compileArrShapeComponents n var = - [CELit (var ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] + [CELit (var ++ ".sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] indexTupleComponents :: SNat n -> String -> [CExpr] indexTupleComponents = \n var -> map CELit (toList (go n var)) @@ -1431,6 +1478,9 @@ shapeTupFromLitVars = \n -> go n . reverse go (SS n) (var : vars) = CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) [("a", go n vars), ("b", CELit var)] go _ _ = error "shapeTupFromLitVars: SNat and list do not correspond" +prodExpr :: [CExpr] -> CExpr +prodExpr = foldl0' (\a b -> CEBinop a "*" b) (CELit "1") + compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr compileOpGeneral op e1 = do let unary cop = return @CompM $ CECall cop [e1] @@ -1503,7 +1553,7 @@ compileExtremum nameBase opName operator env e = do lenname <- genName' "n" emit $ SVarDecl True (repSTy tIx) lenname - (CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")) + (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]")) emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); return false; }" @@ -1574,7 +1624,7 @@ copyForWriting topty var = case topty of -- nesting we'd have to check the refcounts of all the nested arrays _too_; -- let's not do that. Furthermore, no sub-arrays means that the whole thing -- is flat, and we can just memcpy if necessary. - SMTArr n t | not (hasArrays (fromSMTy t)) -> do + SMTArr n t | not (typeHasArrays (fromSMTy t)) -> do name <- genName shszname <- genName' "shsz" emit $ SVarDeclUninit toptyname name @@ -1583,7 +1633,7 @@ copyForWriting topty var = case topty of let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" emit $ SVerbatim $ "fprintf(stderr, PRTAG \"with array " ++ shfmt ++ "\\n\"" ++ - concat [", " ++ var ++ ".buf->sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++ + concat [", " ++ var ++ ".sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++ ");" emit $ SIf (CEBinop (CELit (var ++ ".buf->refc")) "==" (CELit "1")) @@ -1594,8 +1644,7 @@ copyForWriting topty var = case topty of in BList [SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var) ,SAsg name (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])]) - ,SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++ - show shbytes ++ ");" + ,SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");" ,SAsg (name ++ ".buf->refc") (CELit "1") ,SVerbatim $ "memcpy(" ++ name ++ ".buf->xs, " ++ var ++ ".buf->xs, " ++ printCExpr 0 databytes ");"]) @@ -1612,8 +1661,7 @@ copyForWriting topty var = case topty of name <- genName emit $ SVarDecl False toptyname name (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])]) - emit $ SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++ - show shbytes ++ ");" + emit $ SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");" emit $ SAsg (name ++ ".buf->refc") (CELit "1") -- put the arrays in variables to cut short the not-quite-var chain @@ -1657,6 +1705,14 @@ zeroRefcountCheck toptyp opname topvar = go (STEither a b) path = do (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "0")) s1 s2 + go (STLEither a b) path = do + (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) + return $ pure $ + SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) + s1 + (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2")) + s2 + mempty)) go (STMaybe a) path = do ss <- go a (path++".j") return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) ss mempty @@ -1673,14 +1729,6 @@ zeroRefcountCheck toptyp opname topvar = return (BList [s1, s2, s3]) go STScal{} _ = empty go STAccum{} _ = error "zeroRefcountCheck: passed an accumulator" - go (STLEither a b) path = do - (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) - return $ pure $ - SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) - s1 - (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2")) - s2 - mempty)) combine :: (Monoid a, Monoid b, Monad m) => MaybeT m a -> MaybeT m b -> MaybeT m (a, b) combine (MaybeT a) (MaybeT b) = MaybeT $ do |
