diff options
Diffstat (limited to 'src/Compile.hs')
-rw-r--r-- | src/Compile.hs | 1063 |
1 files changed, 750 insertions, 313 deletions
diff --git a/src/Compile.hs b/src/Compile.hs index d9cfd95..a5c4fb7 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -1,13 +1,19 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MagicHash #-} {-# LANGUAGE MultiWayIf #-} {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} module Compile (compile) where +import Control.Applicative (empty) import Control.Monad (forM_, when, replicateM) import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.Maybe import Control.Monad.Trans.State.Strict import Control.Monad.Trans.Writer.CPS import Data.Bifunctor (first) @@ -16,6 +22,7 @@ import Data.Foldable (toList) import Data.Functor.Const import qualified Data.Functor.Product as Product import Data.Functor.Product (Product) +import Data.IORef import Data.List (foldl1', intersperse, intercalate) import qualified Data.Map.Strict as Map import Data.Maybe (fromMaybe) @@ -24,52 +31,67 @@ import Data.Set (Set) import Data.Some import qualified Data.Vector as V import Foreign +import GHC.Exts (int2Word#, addr2Int#) +import GHC.Num (integerFromWord#) +import GHC.Ptr (Ptr(..)) import Numeric (showHex) import System.IO (hPutStrLn, stderr) +import System.IO.Error (mkIOError, userErrorType) +import System.IO.Unsafe (unsafePerformIO) import Prelude hiding ((^)) import qualified Prelude import Array import AST -import AST.Pretty (ppTy) +import AST.Pretty (ppSTy, ppExpr) +import AST.Sparse.Types (isDense) import Compile.Exec import Data import Interpreter.Rep - - -{- -:m *Example Compile AST.UnMonoid -:seti -XOverloadedLabels -XGADTs -let array = arrayGenerate (ShNil `ShCons` 10) (\(IxNil `IxCons` i) -> fromIntegral i :: Double) in (($ SCons (Value array) SNil) =<<) $ compile knownEnv $ fromNamed $ lambda @(TArr N1 (TScal TF64)) #x $ body $ #x ! pair nil (round_ (#x ! pair nil 3)) -(($ SNil) =<<) $ compile knownEnv $ fromNamed $ body $ build2 5 3 (#i :-> #j :-> 10 * #i + #j) --} +import qualified Util.IdGen as IdGen -- In shape and index arrays, the innermost dimension is on the right (last index). +-- TODO: test that I'm properly incrementing and decrementing refcounts in all required places + -debug :: Bool -debug = toEnum 0 +-- | Print the compiled AST +debugPrintAST :: Bool; debugPrintAST = toEnum 0 +-- | Print the generated C source +debugCSource :: Bool; debugCSource = toEnum 0 +-- | Print extra stuff about reference counts of arrays +debugRefc :: Bool; debugRefc = toEnum 0 +-- | Print some shape-related information +debugShapes :: Bool; debugShapes = toEnum 0 +-- | Print information on allocation +debugAllocs :: Bool; debugAllocs = toEnum 0 +-- | Emit extra C code that checks stuff +emitChecks :: Bool; emitChecks = toEnum 0 compile :: SList STy env -> Ex env t -> IO (SList Value env -> IO (Rep t)) compile = \env expr -> do - let source = compileToString env expr - when debug $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ source ++ "\x1B[0m>>>" - lib <- buildKernel source ["kernel"] + codeID <- atomicModifyIORef' uniqueIdGenRef (\i -> (i + 1, i)) - let arg_metrics = reverse (unSList metricsSTy env) - (arg_offsets, result_offset) = computeStructOffsets arg_metrics - result_type = typeOf expr + let (source, offsets) = compileToString codeID env expr + when debugPrintAST $ hPutStrLn stderr $ "Compiled AST: <<<\n" ++ ppExpr env expr ++ "\n>>>" + when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ lineNumbers source ++ "\x1B[0m>>>" + lib <- buildKernel source "kernel" + + let result_type = typeOf expr result_size = sizeofSTy result_type return $ \val -> do - allocaBytes (result_offset + result_size) $ \ptr -> do - let args = zip (reverse (unSList Some (slistZip env val))) arg_offsets + allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do + let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets) serialiseArguments args ptr $ do - callKernelFun "kernel" lib ptr - deserialise result_type ptr result_offset + callKernelFun lib ptr + ok <- peekByteOff @Word8 ptr (koOkResOffset offsets) + when (ok /= 1) $ + ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing) + deserialise result_type ptr (koResultOffset offsets) where serialiseArguments :: [(Some (Product STy Value), Int)] -> Ptr () -> IO r -> IO r serialiseArguments ((Some (Product.Pair t (Value arg)), off) : args) ptr k = @@ -184,62 +206,68 @@ printCExpr d = \case ,("/", (7, (7, 8))) ,("%", (7, (7, 8)))] -repTy :: Ty -> String -repTy (TScal st) = case st of - TI32 -> "int32_t" - TI64 -> "int64_t" - TF32 -> "float" - TF64 -> "double" - TBool -> "uint8_t" -repTy t = genStructName t - repSTy :: STy t -> String -repSTy = repTy . unSTy - -genStructName :: Ty -> String +repSTy (STScal st) = case st of + STI32 -> "int32_t" + STI64 -> "int64_t" + STF32 -> "float" + STF64 -> "double" + STBool -> "uint8_t" +repSTy t = genStructName t + +genStructName :: STy t -> String genStructName = \t -> "ty_" ++ gen t where -- all tags start with a letter, so the array mangling is unambiguous. - gen :: Ty -> String - gen TNil = "n" - gen (TPair a b) = 'P' : gen a ++ gen b - gen (TEither a b) = 'E' : gen a ++ gen b - gen (TMaybe t) = 'M' : gen t - gen (TArr n t) = "A" ++ show (fromNat n) ++ gen t - gen (TScal st) = case st of - TI32 -> "i" - TI64 -> "j" - TF32 -> "f" - TF64 -> "d" - TBool -> "b" - gen (TAccum t) = 'C' : gen t + gen :: STy t -> String + gen STNil = "n" + gen (STPair a b) = 'P' : gen a ++ gen b + gen (STEither a b) = 'E' : gen a ++ gen b + gen (STLEither a b) = 'L' : gen a ++ gen b + gen (STMaybe t) = 'M' : gen t + gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t + gen (STScal st) = case st of + STI32 -> "i" + STI64 -> "j" + STF32 -> "f" + STF64 -> "d" + STBool -> "b" + gen (STAccum t) = 'C' : gen (fromSMTy t) -- | This function generates the actual struct declarations for each of the -- types in our language. It thus implicitly "documents" the layout of the -- types in the C translation. -genStruct :: String -> Ty -> [StructDecl] +-- +-- For accumulation it is important that for struct representations of monoid +-- types, the all-zero-bytes value corresponds to the zero value of that type. +genStruct :: String -> STy t -> [StructDecl] genStruct name topty = case topty of - TNil -> + STNil -> [StructDecl name "" com] - TPair a b -> - [StructDecl name (repTy a ++ " a; " ++ repTy b ++ " b;") com] - TEither a b -> -- 0 -> l, 1 -> r - [StructDecl name ("uint8_t tag; union { " ++ repTy a ++ " l; " ++ repTy b ++ " r; };") com] - TMaybe t -> -- 0 -> nothing, 1 -> just - [StructDecl name ("uint8_t tag; " ++ repTy t ++ " j;") com] - TArr n t -> + STPair a b -> + [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com] + STEither a b -> -- 0 -> l, 1 -> r + [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] + STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r + [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] + STMaybe t -> -- 0 -> nothing, 1 -> just + [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com] + STArr n t -> -- The buffer is trailed by a VLA for the actual array data. - [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromNat n) ++ "]; size_t refc; " ++ repTy t ++ " xs[];") "" + -- TODO: put shape in the main struct, not the buffer; it's constant, after all + -- TODO: no buffer if n = 0 + [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromSNat n) ++ "]; size_t refc; " ++ repSTy t ++ " xs[];") "" ,StructDecl name (name ++ "_buf *buf;") com] - TScal _ -> + STScal _ -> [] - TAccum t -> - [StructDecl name (repTy t ++ " ac;") com] + STAccum t -> + [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") "" + ,StructDecl name (name ++ "_buf *buf;") com] where - com = ppTy 0 topty + com = ppSTy 0 topty -- State: already-generated (skippable) struct names -- Writer: the structs in declaration order -genStructs :: Ty -> WriterT (Bag StructDecl) (State (Set String)) () +genStructs :: STy t -> WriterT (Bag StructDecl) (State (Set String)) () genStructs ty = do let name = genStructName ty seen <- lift $ gets (name `Set.member`) @@ -251,95 +279,163 @@ genStructs ty = do -- twice (unnecessary because no recursive types, but y'know) lift $ modify (Set.insert name) - case ty of - TNil -> pure () - TPair a b -> genStructs a >> genStructs b - TEither a b -> genStructs a >> genStructs b - TMaybe t -> genStructs t - TArr _ t -> genStructs t - TScal _ -> pure () - TAccum t -> genStructs t + () <- case ty of + STNil -> pure () + STPair a b -> genStructs a >> genStructs b + STEither a b -> genStructs a >> genStructs b + STLEither a b -> genStructs a >> genStructs b + STMaybe t -> genStructs t + STArr _ t -> genStructs t + STScal _ -> pure () + STAccum t -> genStructs (fromSMTy t) tell (BList (genStruct name ty)) -genAllStructs :: Foldable t => t Ty -> [StructDecl] -genAllStructs tys = toList $ evalState (execWriterT (mapM_ genStructs tys)) mempty +genAllStructs :: Foldable t => t (Some STy) -> [StructDecl] +genAllStructs tys = toList $ evalState (execWriterT (mapM_ (\(Some t) -> genStructs t) tys)) mempty data CompState = CompState - { csStructs :: Set Ty + { csStructs :: Set (Some STy) , csTopLevelDecls :: Bag String , csStmts :: Bag Stmt , csNextId :: Int } deriving (Show) -type CompM a = State CompState a +newtype CompM a = CompM (State CompState a) + deriving newtype (Functor, Applicative, Monad) + +runCompM :: CompM a -> (a, CompState) +runCompM (CompM m) = runState m (CompState mempty mempty mempty 1) -genId :: CompM Int -genId = state $ \s -> (csNextId s, s { csNextId = csNextId s + 1 }) +class Monad m => MonadNameGen m where genId :: m Int +instance MonadNameGen CompM where genId = CompM $ state $ \s -> (csNextId s, s { csNextId = csNextId s + 1 }) +instance MonadNameGen IdGen.IdGen where genId = IdGen.genId +instance MonadNameGen m => MonadNameGen (MaybeT m) where genId = MaybeT (Just <$> genId) -genName' :: String -> CompM String +genName' :: MonadNameGen m => String -> m String +genName' "" = genName genName' prefix = (prefix ++) . show <$> genId -genName :: CompM String +genName :: MonadNameGen m => m String genName = genName' "x" +onlyIdGen :: IdGen.IdGen a -> CompM a +onlyIdGen m = CompM $ do + i1 <- gets csNextId + let (res, i2) = IdGen.runIdGen' i1 m + modify (\s -> s { csNextId = i2 }) + return res + emit :: Stmt -> CompM () -emit stmt = modify $ \s -> s { csStmts = csStmts s <> pure stmt } +emit stmt = CompM $ modify $ \s -> s { csStmts = csStmts s <> pure stmt } -scope :: CompM a -> CompM (a, [Stmt]) +scope :: CompM a -> CompM (a, Bag Stmt) scope m = do - stmts <- state $ \s -> (csStmts s, s { csStmts = mempty }) + stmts <- CompM $ state $ \s -> (csStmts s, s { csStmts = mempty }) res <- m - innerStmts <- state $ \s -> (csStmts s, s { csStmts = stmts }) - return (res, toList innerStmts) + innerStmts <- CompM $ state $ \s -> (csStmts s, s { csStmts = stmts }) + return (res, innerStmts) emitStruct :: STy t -> CompM String -emitStruct ty = do - let ty' = unSTy ty - modify $ \s -> s { csStructs = Set.insert ty' (csStructs s) } - return (genStructName ty') +emitStruct ty = CompM $ do + modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) } + return (genStructName ty) emitTLD :: String -> CompM () -emitTLD decl = modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl } +emitTLD decl = CompM $ modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl } nameEnv :: SList f env -> SList (Const String) env nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1)) -compileToString :: SList STy env -> Ex env t -> String -compileToString env expr = +data KernelOffsets = KernelOffsets + { koArgOffsets :: [Int] -- ^ the function arguments + , koOkResOffset :: Int -- ^ a byte: 1 if successful execution, 0 if (fatal) error occurred + , koResultOffset :: Int -- ^ the function result + } + +compileToString :: Int -> SList STy env -> Ex env t -> (String, KernelOffsets) +compileToString codeID env expr = let args = nameEnv env - (res, s) = runState (compile' args expr) (CompState mempty mempty mempty 1) - structs = genAllStructs (csStructs s <> Set.fromList (unSList unSTy env)) + (res, s) = runCompM (compile' args expr) + structs = genAllStructs (csStructs s <> Set.fromList (unSList Some env)) (arg_pairs, arg_metrics) = unzip $ reverse (unSList (\(Product.Pair t (Const n)) -> ((n, repSTy t), metricsSTy t)) (slistZip env args)) - (arg_offsets, result_offset') = computeStructOffsets arg_metrics - result_offset = align (alignmentSTy (typeOf expr)) result_offset' - in ($ "") $ compose + (arg_offsets, okres_offset) = computeStructOffsets arg_metrics + result_offset = align (alignmentSTy (typeOf expr)) (okres_offset + 1) + + offsets = KernelOffsets + { koArgOffsets = arg_offsets + , koOkResOffset = okres_offset + , koResultOffset = result_offset } + in (,offsets) . ($ "") $ compose [showString "#include <stdio.h>\n" ,showString "#include <stdint.h>\n" + ,showString "#include <stdbool.h>\n" + ,showString "#include <inttypes.h>\n" ,showString "#include <stdlib.h>\n" + ,showString "#include <string.h>\n" ,showString "#include <math.h>\n\n" - ,compose $ map (\sd -> printStructDecl sd . showString "\n") structs + -- PRint-tag + ,showString $ "#define PRTAG \"[chad-kernel" ++ show codeID ++ "] \"\n\n" + + ,compose [printStructDecl sd . showString "\n" | sd <- structs] ,showString "\n" + + -- Using %zd and not %zu here because values > SIZET_MAX/2 should be recognisable as "negative" + ,showString "static void* malloc_instr_fun(size_t n, int line) {\n" + ,showString " void *ptr = malloc(n);\n" + ,if debugAllocs then showString " printf(PRTAG \":%d malloc(%zd) -> %p\\n\", line, n, ptr);\n" + else id + ,if emitChecks then showString " if (ptr == NULL) { printf(PRTAG \"malloc(%zd) returned NULL on line %d\\n\", n, line); return false; }\n" + else id + ,showString " return ptr;\n" + ,showString "}\n" + ,showString "#define malloc_instr(n) ({void *ptr_ = malloc_instr_fun(n, __LINE__); if (ptr_ == NULL) return false; ptr_;})\n" + ,showString "static void* calloc_instr_fun(size_t n, int line) {\n" + ,showString " void *ptr = calloc(n, 1);\n" + ,if debugAllocs then showString " printf(PRTAG \":%d calloc(%zd) -> %p\\n\", line, n, ptr);\n" + else id + ,if emitChecks then showString " if (ptr == NULL) { printf(PRTAG \"calloc(%zd, 1) returned NULL on line %d\\n\", n, line); return false; }\n" + else id + ,showString " return ptr;\n" + ,showString "}\n" + ,showString "#define calloc_instr(n) ({void *ptr_ = calloc_instr_fun(n, __LINE__); if (ptr_ == NULL) return false; ptr_;})\n" + ,showString "static void free_instr(void *ptr) {\n" + ,if debugAllocs then showString "printf(PRTAG \"free(%p)\\n\", ptr);\n" + else id + ,showString " free(ptr);\n" + ,showString "}\n\n" + ,compose [showString str . showString "\n\n" | str <- toList (csTopLevelDecls s)] + ,showString $ - "static " ++ repSTy (typeOf expr) ++ " typed_kernel(" ++ - intercalate ", " (reverse (unSList (\(Product.Pair t (Const n)) -> repSTy t ++ " " ++ n) (slistZip env args))) ++ + "static bool typed_kernel(" ++ + repSTy (typeOf expr) ++ " *output" ++ + concatMap (", " ++) + (reverse (unSList (\(Product.Pair t (Const n)) -> repSTy t ++ " " ++ n) (slistZip env args))) ++ ") {\n" - ,compose $ map (\st -> showString " " . printStmt 1 st . showString "\n") (toList (csStmts s)) - ,showString (" return ") . printCExpr 0 res . showString ";\n}\n\n" + ,compose [showString " " . printStmt 1 st . showString "\n" | st <- toList (csStmts s)] + ,showString " *output = " . printCExpr 0 res . showString ";\n" + ,showString " return true;\n" + ,showString "}\n\n" + ,showString "void kernel(void *data) {\n" -- Some code here assumes that we're on a 64-bit system, so let's check that - ,showString " if (sizeof(void*) != 8 || sizeof(size_t) != 8) { fprintf(stderr, \"Only 64-bit systems supported\\n\"); abort(); }\n" - ,showString $ " *(" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ") = typed_kernel(" ++ - concat (map (\((arg, typ), off, idx) -> - "\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")" - ++ (if idx < length arg_pairs - 1 then "," else "") - ++ " // " ++ arg) - (zip3 arg_pairs arg_offsets [0::Int ..])) ++ + ,showString $ " if (sizeof(void*) != 8 || sizeof(size_t) != 8) { fprintf(stderr, \"Only 64-bit systems supported\\n\"); *(uint8_t*)(data + " ++ show okres_offset ++ ") = 0; return; }\n" + ,if debugRefc then showString " fprintf(stderr, PRTAG \"Start\\n\");\n" + else id + ,showString $ " const bool success = typed_kernel(" ++ + "\n (" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ")" ++ + concat (map (\((arg, typ), off) -> + ",\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")" + ++ " /* " ++ arg ++ " */") + (zip arg_pairs arg_offsets)) ++ "\n );\n" + ,showString $ " *(uint8_t*)(data + " ++ show okres_offset ++ ") = success;\n" + ,if debugRefc then showString " fprintf(stderr, PRTAG \"Return\\n\");\n" + else id ,showString "}\n"] -- | Takes list of metrics (alignment, sizeof). @@ -363,11 +459,20 @@ serialise topty topval ptr off k = serialise a x ptr off $ serialise b y ptr (align (alignmentSTy b) (off + sizeofSTy a)) k (STEither a _, Left x) -> do - pokeByteOff ptr off (0 :: Word8) -- alignment of (a + b) is alignment of (union {a b}) + pokeByteOff ptr off (0 :: Word8) -- alignment of (union {a b}) is the same as alignment of (a + b) serialise a x ptr (off + alignmentSTy topty) k (STEither _ b, Right y) -> do pokeByteOff ptr off (1 :: Word8) serialise b y ptr (off + alignmentSTy topty) k + (STLEither _ _, Nothing) -> do + pokeByteOff ptr off (0 :: Word8) + k + (STLEither a _, Just (Left x)) -> do + pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b) + serialise a x ptr (off + alignmentSTy topty) k + (STLEither _ b, Just (Right y)) -> do + pokeByteOff ptr off (2 :: Word8) + serialise b y ptr (off + alignmentSTy topty) k (STMaybe _, Nothing) -> do pokeByteOff ptr off (0 :: Word8) k @@ -377,6 +482,8 @@ serialise topty topval ptr off k = (STArr n t, Array sh vec) -> do let eltsz = sizeofSTy t allocaBytes (fromSNat n * 8 + 8 + shapeSize sh * eltsz) $ \bufptr -> do + when debugRefc $ + hPutStrLn stderr $ "[chad-serialise] Allocating input buffer " ++ showPtr bufptr pokeByteOff ptr off bufptr pokeShape bufptr 0 n sh @@ -409,9 +516,16 @@ deserialise topty ptr off = return (x, y) STEither a b -> do tag <- peekByteOff @Word8 ptr off - if tag == 0 -- alignment of (a + b) is alignment of (union {a b}) + if tag == 0 -- alignment of (union {a b}) is the same as alignment of (a + b) then Left <$> deserialise a ptr (off + alignmentSTy topty) else Right <$> deserialise b ptr (off + alignmentSTy topty) + STLEither a b -> do + tag <- peekByteOff @Word8 ptr off + case tag of -- alignment of (union {a b}) is the same as alignment of (a + b) + 0 -> return Nothing + 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty) + 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty) + _ -> error "Invalid tag value" STMaybe t -> do tag <- peekByteOff @Word8 ptr off if tag == 0 @@ -421,6 +535,8 @@ deserialise topty ptr off = bufptr <- peekByteOff @(Ptr ()) ptr off sh <- peekShape bufptr 0 n refc <- peekByteOff @Word64 bufptr (8 * fromSNat n) + when debugRefc $ + hPutStrLn stderr $ "[chad-deserialise] Got buffer " ++ showPtr bufptr ++ " at refc=" ++ show refc let off1 = 8 * fromSNat n + 8 eltsz = sizeofSTy t arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (off1 + i * eltsz)) @@ -454,6 +570,10 @@ metricsSTy (STEither a b) = let (a1, s1) = metricsSTy a (a2, s2) = metricsSTy b in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned +metricsSTy (STLEither a b) = + let (a1, s1) = metricsSTy a + (a2, s2) = metricsSTy b + in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned metricsSTy (STMaybe t) = let (a, s) = metricsSTy t in (a, a + s) -- the union after the tag byte is aligned @@ -464,7 +584,7 @@ metricsSTy (STScal sty) = case sty of STF32 -> (4, 4) STF64 -> (8, 8) STBool -> (1, 1) -- compiled to uint8_t -metricsSTy (STAccum t) = metricsSTy t +metricsSTy (STAccum t) = metricsSTy (fromSMTy t) pokeShape :: Ptr () -> Int -> SNat n -> Shape n -> IO () pokeShape ptr off = go . fromSNat @@ -486,15 +606,13 @@ compile' :: SList (Const String) env -> Ex env t -> CompM CExpr compile' env = \case EVar _ t i -> do let Const var = slistIdx env i - incrementVarAlways Increment t var + incrementVarAlways "var" Increment t var return $ CELit var ELet _ rhs body -> do - e <- compile' env rhs - var <- genName - emit $ SVarDecl True (repSTy (typeOf rhs)) var e + var <- compileAssign "" env rhs rete <- compile' (Const var `SCons` env) body - incrementVarAlways Decrement (typeOf rhs) var + incrementVarAlways "let" Decrement (typeOf rhs) var return rete EPair _ a b -> do @@ -506,7 +624,7 @@ compile' env = \case EFst _ e -> do let STPair _ t2 = typeOf e e' <- compile' env e - case incrementVar Decrement t2 of + case incrementVar "fst" Decrement t2 of Nothing -> return $ CEProj e' "a" Just f -> do var <- genName emit $ SVarDecl True (repSTy (typeOf e)) var e' @@ -516,7 +634,7 @@ compile' env = \case ESnd _ e -> do let STPair t1 _ = typeOf e e' <- compile' env e - case incrementVar Decrement t1 of + case incrementVar "snd" Decrement t1 of Nothing -> return $ CEProj e' "b" Just f -> do var <- genName emit $ SVarDecl True (repSTy (typeOf e)) var e' @@ -544,8 +662,8 @@ compile' env = \case retvar <- genName emit $ SVarDeclUninit (repSTy (typeOf a)) retvar emit $ SIf e1 - (BList stmts2 <> pure (SAsg retvar e2)) - (BList stmts3 <> pure (SAsg retvar e3)) + (stmts2 <> pure (SAsg retvar e2)) + (stmts3 <> pure (SAsg retvar e3)) return (CELit retvar) ECase _ e a b -> do @@ -555,17 +673,17 @@ compile' env = \case -- I know those are not variable names, but it's fine, probably (e2, stmts2) <- scope $ compile' (Const (var ++ ".l") `SCons` env) a (e3, stmts3) <- scope $ compile' (Const (var ++ ".r") `SCons` env) b - ((), stmtsRel1) <- scope $ incrementVarAlways Decrement t1 (var ++ ".l") - ((), stmtsRel2) <- scope $ incrementVarAlways Decrement t2 (var ++ ".r") + ((), stmtsRel1) <- scope $ incrementVarAlways "case1" Decrement t1 (var ++ ".l") + ((), stmtsRel2) <- scope $ incrementVarAlways "case2" Decrement t2 (var ++ ".r") retvar <- genName emit $ SVarDeclUninit (repSTy (typeOf a)) retvar emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1) <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) - (BList stmts2 - <> BList stmtsRel1 + (stmts2 + <> stmtsRel1 <> pure (SAsg retvar e2)) - (BList stmts3 - <> BList stmtsRel2 + (stmts3 + <> stmtsRel2 <> pure (SAsg retvar e3)))) return (CELit retvar) @@ -584,18 +702,51 @@ compile' env = \case var <- genName (e2, stmts2) <- scope $ compile' env a (e3, stmts3) <- scope $ compile' (Const (var ++ ".j") `SCons` env) b - ((), stmtsRel) <- scope $ incrementVarAlways Decrement t (var ++ ".j") + ((), stmtsRel) <- scope $ incrementVarAlways "maybe" Decrement t (var ++ ".j") retvar <- genName emit $ SVarDeclUninit (repSTy (typeOf a)) retvar emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1) <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) - (BList stmts2 + (stmts2 <> pure (SAsg retvar e2)) - (BList stmts3 - <> BList stmtsRel + (stmts3 + <> stmtsRel <> pure (SAsg retvar e3)))) return (CELit retvar) + ELNil _ t1 t2 -> do + name <- emitStruct (STLEither t1 t2) + return $ CEStruct name [("tag", CELit "0")] + + ELInl _ t e -> do + name <- emitStruct (STLEither (typeOf e) t) + e1 <- compile' env e + return $ CEStruct name [("tag", CELit "1"), ("l", e1)] + + ELInr _ t e -> do + name <- emitStruct (STLEither t (typeOf e)) + e1 <- compile' env e + return $ CEStruct name [("tag", CELit "2"), ("r", e1)] + + ELCase _ e a b c -> do + let STLEither t1 t2 = typeOf e + e1 <- compile' env e + var <- genName + (e2, stmts2) <- scope $ compile' env a + (e3, stmts3) <- scope $ compile' (Const (var ++ ".l") `SCons` env) b + (e4, stmts4) <- scope $ compile' (Const (var ++ ".r") `SCons` env) c + ((), stmtsRel1) <- scope $ incrementVarAlways "lcase1" Decrement t1 (var ++ ".l") + ((), stmtsRel2) <- scope $ incrementVarAlways "lcase2" Decrement t2 (var ++ ".r") + retvar <- genName + emit $ SVarDeclUninit (repSTy (typeOf a)) retvar + emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1) + <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) + (stmts2 <> pure (SAsg retvar e2)) + (pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "1")) + (stmts3 <> stmtsRel1 <> pure (SAsg retvar e3)) + (stmts4 <> stmtsRel2 <> pure (SAsg retvar e4)))))) + return (CELit retvar) + EConstArr _ n t (Array sh vec) -> do strname <- emitStruct (STArr n (STScal t)) tldname <- genName' "carraybuf" @@ -608,12 +759,9 @@ compile' env = \case return (CEStruct strname [("buf", CEAddrOf (CELit tldname))]) EBuild _ n esh efun -> do - shname <- genName' "sh" - emit . SVarDecl True (repSTy (typeOf esh)) shname =<< compile' env esh - shsizename <- genName' "shsz" - emit $ SVarDecl True "size_t" shsizename (compileShapeSize n shname) + shname <- compileAssign "sh" env esh - arrname <- allocArray "arr" n (typeOf efun) (CELit shsizename) (compileShapeTupIntoArray n shname) + arrname <- allocArray "build" Malloc "arr" n (typeOf efun) Nothing (indexTupleComponents n shname) idxargname <- genName' "ix" (funretval, funstmts) <- scope $ compile' (Const idxargname `SCons` env) efun @@ -627,40 +775,98 @@ compile' env = \case | (ivar, dimidx) <- zip ivars [0::Int ..]] (pure (SVarDecl True (repSTy (typeOf esh)) idxargname (shapeTupFromLitVars n ivars)) - <> BList funstmts + <> funstmts <> pure (SAsg (arrname ++ ".buf->xs[" ++ linivar ++ "++]") funretval)) return (CELit arrname) - -- EFold1Inner _ a b c -> error "TODO" -- EFold1Inner ext (compile' a) (compile' b) (compile' c) + EFold1Inner _ commut efun ex0 earr -> do + let STArr (SS n) t = typeOf earr - ESum1Inner _ e -> do - let STArr (SS n) t = typeOf e - e' <- compile' env e - argname <- genName' "sumarg" - emit $ SVarDecl True (repSTy (STArr (SS n) t)) argname e' + -- let vecwid = case commut of Commut -> 8 :: Int + -- Noncommut -> 1 + + x0name <- compileAssign "foldx0" env ex0 + arrname <- compileAssign "foldarr" env earr + + zeroRefcountCheck (typeOf earr) "fold1i" arrname shszname <- genName' "shsz" -- This n is one less than the shape of the thing we're querying, which is -- unexpected. But it's exactly what we want, so we do it anyway. + emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n arrname) + + resname <- allocArray "fold" Malloc "foldres" n t (Just (CELit shszname)) (compileArrShapeComponents n arrname) + + lenname <- genName' "n" + emit $ SVarDecl True (repSTy tIx) lenname + (CELit (arrname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")) + + ((), x0incrStmts) <- scope $ incrementVarAlways "foldx0" Increment t x0name + + ivar <- genName' "i" + jvar <- genName' "j" + -- kvar <- if vecwid > 1 then genName' "k" else return "" + + accvar <- genName' "tot" + let arreltlit = arrname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ + ({- if vecwid > 1 then show vecwid ++ " * " ++ jvar ++ " + " ++ kvar else -} jvar) ++ "]" + (funres, funStmts) <- scope $ compile' (Const arreltlit `SCons` Const accvar `SCons` env) efun + ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldelt" Increment t arreltlit + + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ + pure (SVarDecl False (repSTy t) accvar (CELit x0name)) + <> x0incrStmts -- we're copying x0 here + <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + -- The combination function will consume the array element + -- and the accumulator. The accumulator is replaced by + -- what comes out of the function anyway, so that's + -- fine, but we do need to increment the array element. + arreltIncrStmts + <> funStmts + <> pure (SAsg accvar funres)) + <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) + + incrementVarAlways "foldx0" Decrement t x0name + incrementVarAlways "foldarr" Decrement (typeOf earr) arrname + + return (CELit resname) + + ESum1Inner _ e -> do + let STArr (SS n) t = typeOf e + argname <- compileAssign "sumarg" env e + + zeroRefcountCheck (typeOf e) "sum1i" argname + + shszname <- genName' "shsz" + -- This n is one less than the shape of the thing we're querying, like EFold1Inner. emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) - resname <- allocArray "sumres" n t (CELit shszname) - [CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] + resname <- allocArray "sum" Malloc "sumres" n t (Just (CELit shszname)) (compileArrShapeComponents n argname) lenname <- genName' "n" emit $ SVarDecl True (repSTy tIx) lenname (CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")) + let vecwid = 8 :: Int ivar <- genName' "i" jvar <- genName' "j" + kvar <- genName' "k" accvar <- genName' "tot" + let nchunks = CEBinop (CELit lenname) "/" (CELit (show vecwid)) emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ BList -- we have ScalIsNumeric, so it has 0 and (+) in C - [SVarDecl False (repSTy t) accvar (CELit "0") - ,SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ - pure $ SVerbatim $ accvar ++ " += " ++ argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ jvar ++ "];" - ,SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)] + [SVerbatim $ repSTy t ++ " " ++ accvar ++ "[" ++ show vecwid ++ "] = {" ++ intercalate "," (replicate vecwid "0") ++ "};" + ,SLoop (repSTy tIx) jvar (CELit "0") nchunks $ + pure $ SLoop (repSTy tIx) kvar (CELit "0") (CELit (show vecwid)) $ + pure $ SVerbatim $ accvar ++ "[" ++ kvar ++ "] += " ++ argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ show vecwid ++ " * " ++ jvar ++ " + " ++ kvar ++ "];" + ,SLoop (repSTy tIx) kvar (CELit "1") (CELit (show vecwid)) $ + pure $ SVerbatim $ accvar ++ "[0] += " ++ accvar ++ "[" ++ kvar ++ "];" + ,SLoop (repSTy tIx) kvar (CEBinop nchunks "*" (CELit (show vecwid))) (CELit lenname) $ + pure $ SVerbatim $ accvar ++ "[0] += " ++ argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ kvar ++ "];" + ,SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit (accvar++"[0]"))] + + incrementVarAlways "sum" Decrement (typeOf e) argname return (CELit resname) @@ -670,25 +876,24 @@ compile' env = \case strname <- emitStruct typ name <- genName emit $ SVarDecl True strname name (CEStruct strname - [("buf", CECall "malloc" [CELit (show (8 + sizeofSTy (typeOf e)))])]) + [("buf", CECall "malloc_instr" [CELit (show (8 + sizeofSTy (typeOf e)))])]) emit $ SAsg (name ++ ".buf->refc") (CELit "1") emit $ SAsg (name ++ ".buf->xs[0]") e' return (CELit name) EReplicate1Inner _ elen earg -> do let STArr n t = typeOf earg - lenname <- genName' "replen" - emit . SVarDecl True (repSTy tIx) lenname =<< compile' env elen - argname <- genName' "reparg" - emit . SVarDecl True (repSTy (typeOf earg)) argname =<< compile' env earg + lenname <- compileAssign "replen" env elen + argname <- compileAssign "reparg" env earg + + zeroRefcountCheck (typeOf earg) "replicate1i" argname shszname <- genName' "shsz" emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) - resname <- allocArray "rep" (SS n) t - (CEBinop (CELit shszname) "*" (CELit lenname)) - ([CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] - ++ [CELit lenname]) + resname <- allocArray "repl1i" Malloc "rep" (SS n) t + (Just (CEBinop (CELit shszname) "*" (CELit lenname))) + (compileArrShapeComponents n argname ++ [CELit lenname]) ivar <- genName' "i" jvar <- genName' "j" @@ -697,6 +902,8 @@ compile' env = \case pure $ SAsg (resname ++ ".buf->xs[" ++ ivar ++ " * " ++ lenname ++ " + " ++ jvar ++ "]") (CELit (argname ++ ".buf->xs[" ++ ivar ++ "]")) + incrementVarAlways "repl1i" Decrement (typeOf earg) argname + return (CELit resname) EMaximum1Inner _ e -> compileExtremum "max" "maximum1i" ">" env e @@ -707,37 +914,48 @@ compile' env = \case EIdx0 _ e -> do let STArr _ t = typeOf e - e' <- compile' env e - arrname <- genName - emit $ SVarDecl True (repSTy (STArr SZ t)) arrname e' + arrname <- compileAssign "" env e + zeroRefcountCheck (typeOf e) "idx0" arrname name <- genName emit $ SVarDecl True (repSTy t) name (CEIndex (CEPtrProj (CEProj (CELit arrname) "buf") "xs") (CELit "0")) - incrementVarAlways Decrement (STArr SZ t) arrname + incrementVarAlways "idx0" Decrement (STArr SZ t) arrname return (CELit name) -- EIdx1 _ a b -> error "TODO" -- EIdx1 ext (compile' a) (compile' b) EIdx _ earr eidx -> do let STArr n t = typeOf earr - arrname <- genName' "ixarr" - idxname <- genName' "ixix" - emit . SVarDecl True (repSTy (typeOf earr)) arrname =<< compile' env earr - when (fromSNat n > 0) $ emit . SVarDecl True (repSTy (typeOf eidx)) idxname =<< compile' env eidx + arrname <- compileAssign "ixarr" env earr + zeroRefcountCheck (typeOf earr) "idx" arrname + idxname <- if fromSNat n > 0 -- prevent an unused-varable warning + then compileAssign "ixix" env eidx + else return "" -- won't be used in this case + + when emitChecks $ + forM_ (zip [0::Int ..] (indexTupleComponents n idxname)) $ \(i, ixcomp) -> + emit $ SIf (CEBinop (CEBinop ixcomp "<" (CELit "0")) "||" + (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".buf->sh[" ++ show i ++ "]"))))) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: index out of range (arr=%p)\\n\", " ++ + arrname ++ ".buf); return false;") + mempty + resname <- genName' "ixres" emit $ SVarDecl True (repSTy t) resname (CEIndex (CELit (arrname ++ ".buf->xs")) (toLinearIdx n arrname idxname)) - incrementVarAlways Decrement (STArr n t) arrname + incrementVarAlways "idxelt" Increment t resname + incrementVarAlways "idx" Decrement (STArr n t) arrname return (CELit resname) EShape _ e -> do let STArr n _ = typeOf e t = tTup (sreplicate n tIx) _ <- emitStruct t - name <- genName - emit . SVarDecl True (repSTy (typeOf e)) name =<< compile' env e + name <- compileAssign "" env e + zeroRefcountCheck (typeOf e) "shape" name resname <- genName emit $ SVarDecl True (repSTy t) resname (compileShapeQuery n name) - incrementVarAlways Decrement (typeOf e) name + incrementVarAlways "shape" Decrement (typeOf e) name return (CELit resname) EOp _ op (EPair _ e1 e2) -> do @@ -749,123 +967,224 @@ compile' env = \case e' <- compile' env e compileOpGeneral op e' - ECustom _ t1 t2 _ earg _ _ e1 e2 -> do - e1' <- compile' env e1 - name1 <- genName - emit $ SVarDecl True (repSTy t1) name1 e1' - e2' <- compile' env e2 - name2 <- genName - emit $ SVarDecl True (repSTy t2) name2 e2' - compile' (Const name2 `SCons` Const name1 `SCons` SNil) earg + ECustom _ _ _ _ earg _ _ e1 e2 -> do + name1 <- compileAssign "" env e1 + name2 <- compileAssign "" env e2 + case (incrementVar "custom1" Decrement (typeOf e1), incrementVar "custom2" Decrement (typeOf e2)) of + (Nothing, Nothing) -> compile' (Const name2 `SCons` Const name1 `SCons` SNil) earg + (mfun1, mfun2) -> do + name <- compileAssign "" (Const name2 `SCons` Const name1 `SCons` SNil) earg + maybe (return ()) ($ name1) mfun1 + maybe (return ()) ($ name2) mfun2 + return (CELit name) + + ERecompute _ e -> compile' env e EWith _ t e1 e2 -> do - e1' <- compile' env e1 - name1 <- genName - emit $ SVarDecl True (repSTy (typeOf e1)) name1 e1' + actyname <- emitStruct (STAccum t) + name1 <- compileAssign "" env e1 + + zeroRefcountCheck (typeOf e1) "with" name1 + emit $ SVerbatim $ "// copyForWriting start (" ++ name1 ++ ")" mcopy <- copyForWriting t name1 accname <- genName' "accum" - emit $ SVarDecl False (repSTy (STAccum t)) accname (maybe (CELit name1) id mcopy) + emit $ SVarDecl False actyname accname + (CEStruct actyname [("buf", CECall "malloc_instr" [CELit (show (sizeofSTy (fromSMTy t)))])]) + emit $ SAsg (accname++".buf->ac") (maybe (CELit name1) id mcopy) + emit $ SVerbatim $ "// initial accumulator constructed (" ++ name1 ++ ")." e2' <- compile' (Const accname `SCons` env) e2 - return $ CEStruct (repSTy (STPair (typeOf e2) t)) [("a", e2'), ("b", CELit accname)] - - EAccum _ t prj eidx eval eacc -> do - eidx' <- compile' env eidx - nameidx <- genName - emit $ SVarDecl True (repSTy (typeOf eidx)) nameidx eidx' - - eval' <- compile' env eval - nameval <- genName - emit $ SVarDecl True (repSTy (typeOf eval)) nameval eval' - - eacc' <- compile' env eacc - nameacc <- genName - emit $ SVarDecl True (repSTy (typeOf eacc)) nameacc eacc' - - let accumRef :: STy a -> SAcPrj p a b -> String -> String -> String - accumRef _ SAPHere v _ = v - accumRef (STPair ta _) (SAPFst prj') v i = accumRef ta prj' (v++".a") i - accumRef (STPair _ tb) (SAPSnd prj') v i = accumRef tb prj' (v++".b") i - accumRef (STEither ta _) (SAPLeft prj') v i = accumRef ta prj' (v++".l") i - accumRef (STEither _ tb) (SAPRight prj') v i = accumRef tb prj' (v++".r") i - accumRef (STMaybe tj) (SAPJust prj') v i = accumRef tj prj' (v++".j") i - accumRef (STArr n t') (SAPArrIdx prj' _) v i = - accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") - - let add :: STy a -> String -> String -> CompM () - add STNil _ _ = return () - add (STPair t1 t2) d s = do + resname <- genName' "acret" + emit $ SVarDecl True (repSTy (fromSMTy t)) resname (CELit (accname++".buf->ac")) + emit $ SVerbatim $ "free_instr(" ++ accname ++ ".buf);" + + rettyname <- emitStruct (STPair (typeOf e2) (fromSMTy t)) + return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)] + + EAccum _ t prj eidx sparsity eval eacc | Just Refl <- isDense (acPrjTy prj t) sparsity -> do + let -- Add a value (s) into an existing accumulation value (d). If a sparse + -- component of d is encountered, s is copied there. + add :: SMTy a -> String -> String -> CompM () + add SMTNil _ _ = return () + add (SMTPair t1 t2) d s = do add t1 (d++".a") (s++".a") add t2 (d++".b") (s++".b") - add (STEither t1 t2) d s = do + add (SMTLEither t1 t2) d s = do + ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTLEither t1 t2)) s ((), stmts1) <- scope $ add t1 (d++".l") (s++".l") ((), stmts2) <- scope $ add t2 (d++".r") (s++".r") - emit $ SAsg (d++".tag") (CELit (s++".tag")) - emit $ SIf (CEBinop (CELit (s++".tag")) "==" (CELit "0")) - (BList stmts1) (BList stmts2) - add (STMaybe t1) d s = do + emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0")) + (pure (SAsg d (CELit s)) + <> srcIncrStmts) + ((if emitChecks + then pure (SIf (CEBinop (CEBinop (CELit (s++".tag")) "!=" (CELit "0")) + "&&" + (CEBinop (CELit (s++".tag")) "!=" (CELit (d++".tag")))) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum add leither with different tags " ++ + "(dest %d, src %d)\\n\", (int)" ++ d ++ ".tag, (int)" ++ s ++ ".tag); " ++ + "return false;") + mempty) + else mempty) + -- note: s may have tag 0 + <> pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) + stmts1 + (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "2")) + stmts2 mempty)))) + add (SMTMaybe t1) d s = do + ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTMaybe t1)) s ((), stmts1) <- scope $ add t1 (d++".j") (s++".j") - emit $ SAsg (d++".tag") (CELit (s++".tag")) - emit $ SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) - (BList stmts1) mempty - add (STArr n t1) d s = do + emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0")) + (pure (SAsg d (CELit s)) + <> srcIncrStmts) + (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) stmts1 mempty)) + add (SMTArr n t1) d s = do + when emitChecks $ do + let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" + forM_ [0 .. fromSNat n - 1] $ \j -> do + emit $ SIf (CEBinop (CELit (s ++ ".buf->sh[" ++ show j ++ "]")) + "!=" + (CELit (d ++ ".buf->sh[" ++ show j ++ "]"))) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum add incorrect (d=%p, " ++ + "dsh=" ++ shfmt ++ ", s=%p, ssh=" ++ shfmt ++ ")\\n\", " ++ + d ++ ".buf" ++ + concat [", " ++ d ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ + ", " ++ s ++ ".buf" ++ + concat [", " ++ s ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ + "); " ++ + "return false;") + mempty + shsizename <- genName' "acshsz" - emit $ SVarDecl True "size_t" shsizename (compileShapeSize n (s++".a.b")) + emit $ SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n s) ivar <- genName' "i" ((), stmts1) <- scope $ add t1 (d++".buf->xs["++ivar++"]") (s++".buf->xs["++ivar++"]") - emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename) $ - BList stmts1 - add (STScal sty) d s = case sty of - STI32 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";" - STI64 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";" - STF32 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";" - STF64 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";" - STBool -> error "Compile: accumulator add on booleans" - add (STAccum _) _ _ = error "Compile: nested accumulators unsupported" - - let dest = accumRef t prj (nameacc++".ac") nameidx - add (typeOf eval) dest nameval + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename) + stmts1 + add (SMTScal _) d s = emit $ SVerbatim $ d ++ " += " ++ s ++ ";" + + let -- | Dereference an accumulation value and add a given value to that + -- position. Sparse components encountered along the way are + -- initialised before proceeding downwards. + -- accumRef (type) (projection) (accumulation component) (AcIdx variable) (value to accumulate there) + accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> String -> CompM () + accumRef _ SAPHere v _ addend = add (acPrjTy prj t) v addend + + accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") i addend + accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") i addend + + accumRef (SMTLEither ta _) (SAPLeft prj') v i addend = do + when emitChecks $ do + emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1")) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +left)\\n\", " ++ v ++ ".tag); " ++ + "return false;") + mempty + accumRef ta prj' (v++".l") i addend + accumRef (SMTLEither _ tb) (SAPRight prj') v i addend = do + when emitChecks $ do + emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "2")) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +right)\\n\", " ++ v ++ ".tag); " ++ + "return false;") + mempty + accumRef tb prj' (v++".r") i addend + + accumRef (SMTMaybe tj) (SAPJust prj') v i addend = do + when emitChecks $ do + emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1")) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (maybe tag=%d, +just)\\n\", " ++ v ++ ".tag); " ++ + "return false;") + mempty + accumRef tj prj' (v++".j") i addend + + accumRef (SMTArr n t') (SAPArrIdx prj') v i addend = do + when emitChecks $ do + let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" + forM_ (zip [0::Int ..] + (indexTupleComponents n (i++".a"))) $ \(j, ixcomp) -> do + let a .||. b = CEBinop a "||" b + emit $ SIf (CEBinop ixcomp "<" (CELit "0") + .||. + CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++ + "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=(D))\\n\", " ++ + v ++ ".buf" ++ + concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ + concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a")] ++ + "); " ++ + "return false;") + mempty + + accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a")) "]") (i++".b") addend + + nameidx <- compileAssign "acidx" env eidx + nameval <- compileAssign "acval" env eval + nameacc <- compileAssign "acac" env eacc + + emit $ SVerbatim $ "// compile EAccum start (" ++ show prj ++ ")" + accumRef t prj (nameacc++".buf->ac") nameidx nameval + emit $ SVerbatim $ "// compile EAccum end" + + incrementVarAlways "accumendsrc" Decrement (typeOf eval) nameval return $ CEStruct (repSTy STNil) [] + EAccum{} -> + error "Compile: EAccum with non-trivial sparsity should have been eliminated (use AST.UnMonoid)" + EError _ t s -> do let padleft len c s' = replicate (len - length s) c ++ s' escape = concatMap $ \c -> if | c `elem` "\"\\" -> ['\\',c] | ord c < 32 -> "\\x" ++ padleft 2 '0' (showHex (ord c) "") | otherwise -> [c] - emit $ SVerbatim $ "fprintf(stderr, \"ERROR: %s\\n\", " ++ escape s ++ "); exit(1);" + emit $ SVerbatim $ "fputs(\"ERROR: " ++ escape s ++ "\\n\", stderr); return false;" case t of STScal _ -> return (CELit "0") _ -> do name <- emitStruct t return $ CEStruct name [] - EZero{} -> error "Compile: monoid operations should have been eliminated" - EPlus{} -> error "Compile: monoid operations should have been eliminated" - EOneHot{} -> error "Compile: monoid operations should have been eliminated" + EZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" + EDeepZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" + EPlus{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" + EOneHot{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" - EFold1Inner{} -> error "Compile: not implemented: EFold1Inner" EIdx1{} -> error "Compile: not implemented: EIdx1" +compileAssign :: String -> SList (Const String) env -> Ex env t -> CompM String +compileAssign prefix env e = do + e' <- compile' env e + case e' of + CELit name -> return name + _ -> do + name <- genName' prefix + emit $ SVarDecl True (repSTy (typeOf e)) name e' + return name + data Increment = Increment | Decrement deriving (Show) -- | Increment reference counts in the components of the given variable. -incrementVar :: Increment -> STy a -> Maybe (String -> CompM ()) -incrementVar inc ty = +incrementVar :: String -> Increment -> STy a -> Maybe (String -> CompM ()) +incrementVar marker inc ty = let tree = makeArrayTree ty in case tree of ATNoop -> Nothing - _ -> Just $ \var -> incrementVar' inc var tree + _ -> Just $ \var -> incrementVar' marker inc var tree -incrementVarAlways :: Increment -> STy a -> String -> CompM () -incrementVarAlways inc ty var = maybe (pure ()) ($ var) (incrementVar inc ty) +incrementVarAlways :: String -> Increment -> STy a -> String -> CompM () +incrementVarAlways marker inc ty var = maybe (pure ()) ($ var) (incrementVar marker inc ty) -data ArrayTree = ATArray -- ^ we've arrived at an array we need to decrement the refcount of +data ArrayTree = ATArray (Some SNat) (Some STy) -- ^ we've arrived at an array we need to decrement the refcount of (contains rank and element type of the array) | ATNoop -- ^ don't do anything here | ATProj String ArrayTree -- ^ descend one field deeper | ATCondTag ArrayTree ArrayTree -- ^ if tag is 0, first; if 1, second + | ATCond3Tag ArrayTree ArrayTree ArrayTree -- ^ if tag is: 0, 1, 2 | ATBoth ArrayTree ArrayTree -- ^ do both these paths smartATProj :: String -> ArrayTree -> ArrayTree @@ -876,6 +1195,10 @@ smartATCondTag :: ArrayTree -> ArrayTree -> ArrayTree smartATCondTag ATNoop ATNoop = ATNoop smartATCondTag t t' = ATCondTag t t' +smartATCond3Tag :: ArrayTree -> ArrayTree -> ArrayTree -> ArrayTree +smartATCond3Tag ATNoop ATNoop ATNoop = ATNoop +smartATCond3Tag t1 t2 t3 = ATCond3Tag t1 t2 t3 + smartATBoth :: ArrayTree -> ArrayTree -> ArrayTree smartATBoth ATNoop t = t smartATBoth t ATNoop = t @@ -887,24 +1210,58 @@ makeArrayTree (STPair a b) = smartATBoth (smartATProj "a" (makeArrayTree a)) (smartATProj "b" (makeArrayTree b)) makeArrayTree (STEither a b) = smartATCondTag (smartATProj "l" (makeArrayTree a)) (smartATProj "r" (makeArrayTree b)) -makeArrayTree (STMaybe t) = smartATCondTag ATNoop (makeArrayTree t) -makeArrayTree (STArr _ _) = ATArray +makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop + (smartATProj "l" (makeArrayTree a)) + (smartATProj "r" (makeArrayTree b)) +makeArrayTree (STMaybe t) = smartATCondTag ATNoop (smartATProj "j" (makeArrayTree t)) +makeArrayTree (STArr n t) = ATArray (Some n) (Some t) makeArrayTree (STScal _) = ATNoop makeArrayTree (STAccum _) = ATNoop -incrementVar' :: Increment -> String -> ArrayTree -> CompM () -incrementVar' inc path ATArray = +incrementVar' :: String -> Increment -> String -> ArrayTree -> CompM () +incrementVar' marker inc path (ATArray (Some n) (Some eltty)) = case inc of - Increment -> emit $ SVerbatim (path ++ ".buf->refc++;") - Decrement -> - emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) free(" ++ path ++ ".buf);" -incrementVar' _ _ ATNoop = pure () -incrementVar' inc path (ATProj field t) = incrementVar' inc (path ++ "." ++ field) t -incrementVar' inc path (ATCondTag t1 t2) = do - ((), stmts1) <- scope $ incrementVar' inc path t1 - ((), stmts2) <- scope $ incrementVar' inc path t2 - emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "0")) (BList stmts1) (BList stmts2) -incrementVar' inc path (ATBoth t1 t2) = incrementVar' inc path t1 >> incrementVar' inc path t2 + Increment -> do + emit $ SVerbatim (path ++ ".buf->refc++;") + when debugRefc $ + emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p in+ -> %zu <" ++ marker ++ ">\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc);" + Decrement -> do + case incrementVar (marker++".elt") Decrement eltty of + Nothing -> + if debugRefc + then do + emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p de- -> %zu <" ++ marker ++ ">\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);" + emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) { fprintf(stderr, \"; free(\"); free_instr(" ++ path ++ ".buf); fprintf(stderr, \") ok\\n\"); } else fprintf(stderr, \"\\n\");" + else do + emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) free_instr(" ++ path ++ ".buf);" + Just f -> do + when debugRefc $ + emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p de- -> %zu <" ++ marker ++ "> recfree\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);" + shszvar <- genName' "frshsz" + ivar <- genName' "i" + ((), eltDecrStmts) <- scope $ f (path ++ ".buf->xs[" ++ ivar ++ "]") + emit $ SIf (CELit ("--" ++ path ++ ".buf->refc == 0")) + (BList [SVarDecl True "size_t" shszvar (compileArrShapeSize n path) + ,SLoop "size_t" ivar (CELit "0") (CELit shszvar) $ + eltDecrStmts + ,SVerbatim $ "free_instr(" ++ path ++ ".buf);"]) + mempty +incrementVar' _ _ _ ATNoop = pure () +incrementVar' marker inc path (ATProj field t) = incrementVar' (marker++"."++field) inc (path ++ "." ++ field) t +incrementVar' marker inc path (ATCondTag t1 t2) = do + ((), stmts1) <- scope $ incrementVar' (marker++".t1") inc path t1 + ((), stmts2) <- scope $ incrementVar' (marker++".t2") inc path t2 + emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "0")) stmts1 stmts2 +incrementVar' marker inc path (ATCond3Tag t1 t2 t3) = do + ((), stmts1) <- scope $ incrementVar' (marker++".t1") inc path t1 + ((), stmts2) <- scope $ incrementVar' (marker++".t2") inc path t2 + ((), stmts3) <- scope $ incrementVar' (marker++".t3") inc path t3 + emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "1")) + stmts2 + (pure (SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "2")) + stmts3 + stmts1)) +incrementVar' marker inc path (ATBoth t1 t2) = incrementVar' (marker++".1") inc path t1 >> incrementVar' (marker++".2") inc path t2 toLinearIdx :: SNat n -> String -> String -> CExpr toLinearIdx SZ _ _ = CELit "0" @@ -921,21 +1278,31 @@ toLinearIdx (SS n) arrvar idxvar = -- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))) -- _ +data AllocMethod = Malloc | Calloc + deriving (Show) + -- | The shape must have the outer dimension at the head (and the inner dimension on the right). -allocArray :: String -> SNat n -> STy t -> CExpr -> [CExpr] -> CompM String -allocArray nameBase rank eltty shsz shape = do +allocArray :: String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String +allocArray marker method nameBase rank eltty mshsz shape = do when (length shape /= fromSNat rank) $ error "allocArray: shape does not match rank" let arrty = STArr rank eltty strname <- emitStruct arrty arrname <- genName' nameBase + shsz <- case mshsz of + Just e -> return e + Nothing -> return (foldl0' (\a b -> CEBinop a "*" b) (CELit "1") shape) + let nbytesExpr = CEBinop (CELit (show (fromSNat rank * 8 + 8))) + "+" + (CEBinop shsz "*" (CELit (show (sizeofSTy eltty)))) emit $ SVarDecl True strname arrname $ CEStruct strname - [("buf", CECall "malloc" [CEBinop (CELit (show (fromSNat rank * 8 + 8))) - "+" - (CEBinop shsz "*" (CELit (show (sizeofSTy eltty))))])] + [("buf", case method of Malloc -> CECall "malloc_instr" [nbytesExpr] + Calloc -> CECall "calloc_instr" [nbytesExpr])] forM_ (zip shape [0::Int ..]) $ \(dim, i) -> emit $ SAsg (arrname ++ ".buf->sh[" ++ show i ++ "]") dim emit $ SAsg (arrname ++ ".buf->refc") (CELit "1") + when debugRefc $ + emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p allocated <" ++ marker ++ ">\\n\", " ++ arrname ++ ".buf);" return arrname compileShapeQuery :: SNat n -> String -> CExpr @@ -945,20 +1312,17 @@ compileShapeQuery (SS n) var = [("a", compileShapeQuery n var) ,("b", CEIndex (CELit (var ++ ".buf->sh")) (CELit (show (fromSNat n))))] -compileShapeSize :: SNat n -> String -> CExpr -compileShapeSize SZ _ = CELit "1" -compileShapeSize (SS SZ) var = CELit (var ++ ".b") -compileShapeSize (SS n) var = CEBinop (compileShapeSize n (var ++ ".a")) "*" (CELit (var ++ ".b")) - -- | Takes a variable name for the array, not the buffer. compileArrShapeSize :: SNat n -> String -> CExpr -compileArrShapeSize SZ _ = CELit "1" -compileArrShapeSize n var = - foldl1' (\a b -> CEBinop a "*" b) [CELit (var ++ ".buf->sh[" ++ show i ++ "]") - | i <- [0 .. fromSNat n - 1]] +compileArrShapeSize n var = foldl0' (\a b -> CEBinop a "*" b) (CELit "1") (compileArrShapeComponents n var) + +-- | Takes a variable name for the array, not the buffer. +compileArrShapeComponents :: SNat n -> String -> [CExpr] +compileArrShapeComponents n var = + [CELit (var ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] -compileShapeTupIntoArray :: SNat n -> String -> [CExpr] -compileShapeTupIntoArray = \n var -> map CELit (toList (go n var)) +indexTupleComponents :: SNat n -> String -> [CExpr] +indexTupleComponents = \n var -> map CELit (toList (go n var)) where go :: SNat n -> String -> Bag String go SZ _ = mempty @@ -976,7 +1340,7 @@ shapeTupFromLitVars = \n -> go n . reverse compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr compileOpGeneral op e1 = do - let unary cop = return @(State CompState) $ CECall cop [e1] + let unary cop = return @CompM $ CECall cop [e1] let binary cop = do name <- genName emit $ SVarDecl True (repSTy (opt1 op)) name e1 @@ -1004,10 +1368,11 @@ compileOpGeneral op e1 = do OLog STF32 -> unary "logf" OLog STF64 -> unary "log" OIDiv _ -> binary "/" + OMod _ -> binary "%" compileOpPair :: SOp a b -> CExpr -> CExpr -> CompM CExpr compileOpPair op e1 e2 = do - let binary cop = return @(State CompState) $ CEBinop e1 cop e2 + let binary cop = return @CompM $ CEBinop e1 cop e2 case op of OAdd _ -> binary "+" OMul _ -> binary "*" @@ -1017,6 +1382,7 @@ compileOpPair op e1 e2 = do OAnd -> binary "&&" OOr -> binary "||" OIDiv _ -> binary "/" + OMod _ -> binary "%" _ -> error "compileOpPair: got unary operator" -- | Bool: whether to ensure that the literal itself already has the appropriate type @@ -1031,23 +1397,22 @@ compileScal pedantic typ x = case typ of compileExtremum :: String -> String -> String -> SList (Const String) env -> Ex env (TArr (S n) t) -> CompM CExpr compileExtremum nameBase opName operator env e = do let STArr (SS n) t = typeOf e - e' <- compile' env e - argname <- genName' (nameBase ++ "arg") - emit $ SVarDecl True (repSTy (STArr (SS n) t)) argname e' + argname <- compileAssign (nameBase ++ "arg") env e + + zeroRefcountCheck (typeOf e) opName argname shszname <- genName' "shsz" -- This n is one less than the shape of the thing we're querying, which is -- unexpected. But it's exactly what we want, so we do it anyway. emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) - resname <- allocArray (nameBase ++ "res") n t (CELit shszname) - [CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] + resname <- allocArray nameBase Malloc (nameBase ++ "res") n t (Just (CELit shszname)) (compileArrShapeComponents n argname) lenname <- genName' "n" emit $ SVarDecl True (repSTy tIx) lenname (CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")) - emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); abort(); }" + emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); return false; }" ivar <- genName' "i" jvar <- genName' "j" @@ -1062,99 +1427,107 @@ compileExtremum nameBase opName operator env e = do ] ,SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit redvar)] + incrementVarAlways nameBase Decrement (typeOf e) argname + return (CELit resname) -- | If this returns Nothing, there was nothing to copy because making a simple -- value copy in C already makes it suitable to write to. -copyForWriting :: STy t -> String -> CompM (Maybe CExpr) +copyForWriting :: SMTy t -> String -> CompM (Maybe CExpr) copyForWriting topty var = case topty of - STNil -> return Nothing + SMTNil -> return Nothing - STPair a b -> do + SMTPair a b -> do e1 <- copyForWriting a (var ++ ".a") e2 <- copyForWriting b (var ++ ".b") case (e1, e2) of (Nothing, Nothing) -> return Nothing - _ -> return $ Just $ CEStruct (repSTy topty) + _ -> return $ Just $ CEStruct toptyname [("a", fromMaybe (CELit (var++".a")) e1) ,("b", fromMaybe (CELit (var++".b")) e2)] - STEither a b -> do + SMTLEither a b -> do (e1, stmts1) <- scope $ copyForWriting a (var ++ ".l") (e2, stmts2) <- scope $ copyForWriting b (var ++ ".r") case (e1, e2) of (Nothing, Nothing) -> return Nothing _ -> do name <- genName - emit $ SVarDeclUninit (repSTy topty) name - emit $ SIf (CEBinop (CELit var) "==" (CELit "0")) - (BList stmts1 - <> pure (SAsg name (CEStruct (repSTy topty) + emit $ SVarDeclUninit toptyname name + emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0")) + (stmts1 + <> pure (SAsg name (CEStruct toptyname [("tag", CELit "0"), ("l", fromMaybe (CELit (var++".l")) e1)]))) - (BList stmts2 - <> pure (SAsg name (CEStruct (repSTy topty) + (stmts2 + <> pure (SAsg name (CEStruct toptyname [("tag", CELit "1"), ("r", fromMaybe (CELit (var++".r")) e2)]))) return (Just (CELit name)) - STMaybe t -> do + SMTMaybe t -> do (e1, stmts1) <- scope $ copyForWriting t (var ++ ".j") case e1 of Nothing -> return Nothing Just e1' -> do name <- genName - emit $ SVarDeclUninit (repSTy topty) name - emit $ SIf (CEBinop (CELit var) "==" (CELit "0")) - (pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "0")]))) - (BList stmts1 - <> pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "1"), ("j", e1')]))) + emit $ SVarDeclUninit toptyname name + emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0")) + (pure (SAsg name (CEStruct toptyname [("tag", CELit "0")]))) + (stmts1 + <> pure (SAsg name (CEStruct toptyname [("tag", CELit "1"), ("j", e1')]))) return (Just (CELit name)) -- If there are no nested arrays, we know that a refcount of 1 means that the -- whole thing is owned. Nested arrays have their own refcount, so with -- nesting we'd have to check the refcounts of all the nested arrays _too_; - -- at that point we might as well copy the whole thing. Furthermore, no - -- sub-arrays means that the whole thing is flat, and we can just memcpy if - -- necessary. - STArr n t | not (hasArrays t) -> do + -- let's not do that. Furthermore, no sub-arrays means that the whole thing + -- is flat, and we can just memcpy if necessary. + SMTArr n t | not (hasArrays (fromSMTy t)) -> do name <- genName shszname <- genName' "shsz" - emit $ SVarDeclUninit (repSTy (STArr n t)) name + emit $ SVarDeclUninit toptyname name - emit $ SIf (CEBinop (CELit (var ++ ".refc")) "==" (CELit "1")) + when debugShapes $ do + let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" + emit $ SVerbatim $ + "fprintf(stderr, PRTAG \"with array " ++ shfmt ++ "\\n\"" ++ + concat [", " ++ var ++ ".buf->sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++ + ");" + + emit $ SIf (CEBinop (CELit (var ++ ".buf->refc")) "==" (CELit "1")) (pure (SAsg name (CELit var))) (let shbytes = fromSNat n * 8 - databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy t))) + databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t)))) totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes in BList [SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var) - ,SAsg name (CEStruct (repSTy (STArr n t)) [("buf", CECall "malloc" [totalbytes])]) + ,SAsg name (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])]) ,SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++ show shbytes ++ ");" ,SAsg (name ++ ".buf->refc") (CELit "1") ,SVerbatim $ "memcpy(" ++ name ++ ".buf->xs, " ++ var ++ ".buf->xs, " ++ - printCExpr 0 databytes ")"]) + printCExpr 0 databytes ");"]) return (Just (CELit name)) - STArr n t -> do + SMTArr n t -> do shszname <- genName' "shsz" emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var) let shbytes = fromSNat n * 8 - databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy t))) + databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t)))) totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes name <- genName - emit $ SVarDecl False (repSTy (STArr n t)) name - (CEStruct (repSTy (STArr n t)) [("buf", CECall "malloc" [totalbytes])]) + emit $ SVarDecl False toptyname name + (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])]) emit $ SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++ show shbytes ++ ");" emit $ SAsg (name ++ ".buf->refc") (CELit "1") -- put the arrays in variables to cut short the not-quite-var chain dstvar <- genName' "cpydst" - emit $ SVarDecl True (repSTy t ++ " *") dstvar (CELit (name ++ ".buf->xs")) + emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") dstvar (CELit (name ++ ".buf->xs")) srcvar <- genName' "cpysrc" - emit $ SVarDecl True (repSTy t ++ " *") srcvar (CELit (var ++ ".buf->xs")) + emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") srcvar (CELit (var ++ ".buf->xs")) ivar <- genName' "i" @@ -1164,18 +1537,82 @@ copyForWriting topty var = case topty of Nothing -> error "copyForWriting: arrays cannot be copied as-is, bug" emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ - BList cpystmts + cpystmts <> pure (SAsg (dstvar ++ "[" ++ ivar ++ "]") cpye') return (Just (CELit name)) - STScal _ -> return Nothing + SMTScal _ -> return Nothing - STAccum _ -> error "Compile: Nested accumulators not supported" + where + toptyname = repSTy (fromSMTy topty) + +zeroRefcountCheck :: STy t -> String -> String -> CompM () +zeroRefcountCheck toptyp opname topvar = + when emitChecks $ do + mstmts <- onlyIdGen $ runMaybeT (go toptyp topvar) + case mstmts of + Nothing -> return () + Just stmts -> forM_ stmts emit + where + -- | If this returns 'Nothing', no statements need to be generated for this type. + go :: STy t -> String -> MaybeT IdGen.IdGen (Bag Stmt) + go STNil _ = empty + go (STPair a b) path = do + (s1, s2) <- combine (go a (path++".a")) (go b (path++".b")) + return (s1 <> s2) + go (STEither a b) path = do + (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) + return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "0")) s1 s2 + go (STLEither a b) path = do + (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) + return $ pure $ + SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) + s1 + (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2")) + s2 + mempty)) + go (STMaybe a) path = do + ss <- go a (path++".j") + return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) ss mempty + go (STArr n a) path = do + ivar <- genName' "i" + ss <- go a (path++".buf->xs["++ivar++"]") + shszname <- genName' "shsz" + let s1 = SVerbatim $ + "if (__builtin_expect(" ++ path ++ ".buf->refc == 0, 0)) { " ++ + "fprintf(stderr, PRTAG \"CHECK: '" ++ opname ++ "' got array " ++ + "%p with refc=0\\n\", " ++ path ++ ".buf); return false; }" + let s2 = SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n path) + let s3 = SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) ss + return (BList [s1, s2, s3]) + go STScal{} _ = empty + go STAccum{} _ = error "zeroRefcountCheck: passed an accumulator" + + combine :: (Monoid a, Monoid b, Monad m) => MaybeT m a -> MaybeT m b -> MaybeT m (a, b) + combine (MaybeT a) (MaybeT b) = MaybeT $ do + x <- a + y <- b + return $ case (x, y) of + (Nothing, Nothing) -> Nothing + (Just x', Nothing) -> Just (x', mempty) + (Nothing, Just y') -> Just (mempty, y') + (Just x', Just y') -> Just (x', y') + +{-# NOINLINE uniqueIdGenRef #-} +uniqueIdGenRef :: IORef Int +uniqueIdGenRef = unsafePerformIO $ newIORef 1 compose :: Foldable t => t (a -> a) -> a -> a compose = foldr (.) id +showPtr :: Ptr a -> String +showPtr (Ptr a) = "0x" ++ showHex (integerFromWord# (int2Word# (addr2Int# a))) "" + -- | Type-restricted. (^) :: Num a => a -> Int -> a (^) = (Prelude.^) + +foldl0' :: (a -> a -> a) -> a -> [a] -> a +foldl0' _ x [] = x +foldl0' f _ l = foldl1' f l |