aboutsummaryrefslogtreecommitdiff
path: root/src/Compile.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Compile.hs')
-rw-r--r--src/Compile.hs1063
1 files changed, 750 insertions, 313 deletions
diff --git a/src/Compile.hs b/src/Compile.hs
index d9cfd95..a5c4fb7 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -1,13 +1,19 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
module Compile (compile) where
+import Control.Applicative (empty)
import Control.Monad (forM_, when, replicateM)
import Control.Monad.Trans.Class (lift)
+import Control.Monad.Trans.Maybe
import Control.Monad.Trans.State.Strict
import Control.Monad.Trans.Writer.CPS
import Data.Bifunctor (first)
@@ -16,6 +22,7 @@ import Data.Foldable (toList)
import Data.Functor.Const
import qualified Data.Functor.Product as Product
import Data.Functor.Product (Product)
+import Data.IORef
import Data.List (foldl1', intersperse, intercalate)
import qualified Data.Map.Strict as Map
import Data.Maybe (fromMaybe)
@@ -24,52 +31,67 @@ import Data.Set (Set)
import Data.Some
import qualified Data.Vector as V
import Foreign
+import GHC.Exts (int2Word#, addr2Int#)
+import GHC.Num (integerFromWord#)
+import GHC.Ptr (Ptr(..))
import Numeric (showHex)
import System.IO (hPutStrLn, stderr)
+import System.IO.Error (mkIOError, userErrorType)
+import System.IO.Unsafe (unsafePerformIO)
import Prelude hiding ((^))
import qualified Prelude
import Array
import AST
-import AST.Pretty (ppTy)
+import AST.Pretty (ppSTy, ppExpr)
+import AST.Sparse.Types (isDense)
import Compile.Exec
import Data
import Interpreter.Rep
-
-
-{-
-:m *Example Compile AST.UnMonoid
-:seti -XOverloadedLabels -XGADTs
-let array = arrayGenerate (ShNil `ShCons` 10) (\(IxNil `IxCons` i) -> fromIntegral i :: Double) in (($ SCons (Value array) SNil) =<<) $ compile knownEnv $ fromNamed $ lambda @(TArr N1 (TScal TF64)) #x $ body $ #x ! pair nil (round_ (#x ! pair nil 3))
-(($ SNil) =<<) $ compile knownEnv $ fromNamed $ body $ build2 5 3 (#i :-> #j :-> 10 * #i + #j)
--}
+import qualified Util.IdGen as IdGen
-- In shape and index arrays, the innermost dimension is on the right (last index).
+-- TODO: test that I'm properly incrementing and decrementing refcounts in all required places
+
-debug :: Bool
-debug = toEnum 0
+-- | Print the compiled AST
+debugPrintAST :: Bool; debugPrintAST = toEnum 0
+-- | Print the generated C source
+debugCSource :: Bool; debugCSource = toEnum 0
+-- | Print extra stuff about reference counts of arrays
+debugRefc :: Bool; debugRefc = toEnum 0
+-- | Print some shape-related information
+debugShapes :: Bool; debugShapes = toEnum 0
+-- | Print information on allocation
+debugAllocs :: Bool; debugAllocs = toEnum 0
+-- | Emit extra C code that checks stuff
+emitChecks :: Bool; emitChecks = toEnum 0
compile :: SList STy env -> Ex env t
-> IO (SList Value env -> IO (Rep t))
compile = \env expr -> do
- let source = compileToString env expr
- when debug $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ source ++ "\x1B[0m>>>"
- lib <- buildKernel source ["kernel"]
+ codeID <- atomicModifyIORef' uniqueIdGenRef (\i -> (i + 1, i))
- let arg_metrics = reverse (unSList metricsSTy env)
- (arg_offsets, result_offset) = computeStructOffsets arg_metrics
- result_type = typeOf expr
+ let (source, offsets) = compileToString codeID env expr
+ when debugPrintAST $ hPutStrLn stderr $ "Compiled AST: <<<\n" ++ ppExpr env expr ++ "\n>>>"
+ when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ lineNumbers source ++ "\x1B[0m>>>"
+ lib <- buildKernel source "kernel"
+
+ let result_type = typeOf expr
result_size = sizeofSTy result_type
return $ \val -> do
- allocaBytes (result_offset + result_size) $ \ptr -> do
- let args = zip (reverse (unSList Some (slistZip env val))) arg_offsets
+ allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do
+ let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets)
serialiseArguments args ptr $ do
- callKernelFun "kernel" lib ptr
- deserialise result_type ptr result_offset
+ callKernelFun lib ptr
+ ok <- peekByteOff @Word8 ptr (koOkResOffset offsets)
+ when (ok /= 1) $
+ ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing)
+ deserialise result_type ptr (koResultOffset offsets)
where
serialiseArguments :: [(Some (Product STy Value), Int)] -> Ptr () -> IO r -> IO r
serialiseArguments ((Some (Product.Pair t (Value arg)), off) : args) ptr k =
@@ -184,62 +206,68 @@ printCExpr d = \case
,("/", (7, (7, 8)))
,("%", (7, (7, 8)))]
-repTy :: Ty -> String
-repTy (TScal st) = case st of
- TI32 -> "int32_t"
- TI64 -> "int64_t"
- TF32 -> "float"
- TF64 -> "double"
- TBool -> "uint8_t"
-repTy t = genStructName t
-
repSTy :: STy t -> String
-repSTy = repTy . unSTy
-
-genStructName :: Ty -> String
+repSTy (STScal st) = case st of
+ STI32 -> "int32_t"
+ STI64 -> "int64_t"
+ STF32 -> "float"
+ STF64 -> "double"
+ STBool -> "uint8_t"
+repSTy t = genStructName t
+
+genStructName :: STy t -> String
genStructName = \t -> "ty_" ++ gen t where
-- all tags start with a letter, so the array mangling is unambiguous.
- gen :: Ty -> String
- gen TNil = "n"
- gen (TPair a b) = 'P' : gen a ++ gen b
- gen (TEither a b) = 'E' : gen a ++ gen b
- gen (TMaybe t) = 'M' : gen t
- gen (TArr n t) = "A" ++ show (fromNat n) ++ gen t
- gen (TScal st) = case st of
- TI32 -> "i"
- TI64 -> "j"
- TF32 -> "f"
- TF64 -> "d"
- TBool -> "b"
- gen (TAccum t) = 'C' : gen t
+ gen :: STy t -> String
+ gen STNil = "n"
+ gen (STPair a b) = 'P' : gen a ++ gen b
+ gen (STEither a b) = 'E' : gen a ++ gen b
+ gen (STLEither a b) = 'L' : gen a ++ gen b
+ gen (STMaybe t) = 'M' : gen t
+ gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t
+ gen (STScal st) = case st of
+ STI32 -> "i"
+ STI64 -> "j"
+ STF32 -> "f"
+ STF64 -> "d"
+ STBool -> "b"
+ gen (STAccum t) = 'C' : gen (fromSMTy t)
-- | This function generates the actual struct declarations for each of the
-- types in our language. It thus implicitly "documents" the layout of the
-- types in the C translation.
-genStruct :: String -> Ty -> [StructDecl]
+--
+-- For accumulation it is important that for struct representations of monoid
+-- types, the all-zero-bytes value corresponds to the zero value of that type.
+genStruct :: String -> STy t -> [StructDecl]
genStruct name topty = case topty of
- TNil ->
+ STNil ->
[StructDecl name "" com]
- TPair a b ->
- [StructDecl name (repTy a ++ " a; " ++ repTy b ++ " b;") com]
- TEither a b -> -- 0 -> l, 1 -> r
- [StructDecl name ("uint8_t tag; union { " ++ repTy a ++ " l; " ++ repTy b ++ " r; };") com]
- TMaybe t -> -- 0 -> nothing, 1 -> just
- [StructDecl name ("uint8_t tag; " ++ repTy t ++ " j;") com]
- TArr n t ->
+ STPair a b ->
+ [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com]
+ STEither a b -> -- 0 -> l, 1 -> r
+ [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]
+ STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r
+ [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]
+ STMaybe t -> -- 0 -> nothing, 1 -> just
+ [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com]
+ STArr n t ->
-- The buffer is trailed by a VLA for the actual array data.
- [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromNat n) ++ "]; size_t refc; " ++ repTy t ++ " xs[];") ""
+ -- TODO: put shape in the main struct, not the buffer; it's constant, after all
+ -- TODO: no buffer if n = 0
+ [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromSNat n) ++ "]; size_t refc; " ++ repSTy t ++ " xs[];") ""
,StructDecl name (name ++ "_buf *buf;") com]
- TScal _ ->
+ STScal _ ->
[]
- TAccum t ->
- [StructDecl name (repTy t ++ " ac;") com]
+ STAccum t ->
+ [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") ""
+ ,StructDecl name (name ++ "_buf *buf;") com]
where
- com = ppTy 0 topty
+ com = ppSTy 0 topty
-- State: already-generated (skippable) struct names
-- Writer: the structs in declaration order
-genStructs :: Ty -> WriterT (Bag StructDecl) (State (Set String)) ()
+genStructs :: STy t -> WriterT (Bag StructDecl) (State (Set String)) ()
genStructs ty = do
let name = genStructName ty
seen <- lift $ gets (name `Set.member`)
@@ -251,95 +279,163 @@ genStructs ty = do
-- twice (unnecessary because no recursive types, but y'know)
lift $ modify (Set.insert name)
- case ty of
- TNil -> pure ()
- TPair a b -> genStructs a >> genStructs b
- TEither a b -> genStructs a >> genStructs b
- TMaybe t -> genStructs t
- TArr _ t -> genStructs t
- TScal _ -> pure ()
- TAccum t -> genStructs t
+ () <- case ty of
+ STNil -> pure ()
+ STPair a b -> genStructs a >> genStructs b
+ STEither a b -> genStructs a >> genStructs b
+ STLEither a b -> genStructs a >> genStructs b
+ STMaybe t -> genStructs t
+ STArr _ t -> genStructs t
+ STScal _ -> pure ()
+ STAccum t -> genStructs (fromSMTy t)
tell (BList (genStruct name ty))
-genAllStructs :: Foldable t => t Ty -> [StructDecl]
-genAllStructs tys = toList $ evalState (execWriterT (mapM_ genStructs tys)) mempty
+genAllStructs :: Foldable t => t (Some STy) -> [StructDecl]
+genAllStructs tys = toList $ evalState (execWriterT (mapM_ (\(Some t) -> genStructs t) tys)) mempty
data CompState = CompState
- { csStructs :: Set Ty
+ { csStructs :: Set (Some STy)
, csTopLevelDecls :: Bag String
, csStmts :: Bag Stmt
, csNextId :: Int }
deriving (Show)
-type CompM a = State CompState a
+newtype CompM a = CompM (State CompState a)
+ deriving newtype (Functor, Applicative, Monad)
+
+runCompM :: CompM a -> (a, CompState)
+runCompM (CompM m) = runState m (CompState mempty mempty mempty 1)
-genId :: CompM Int
-genId = state $ \s -> (csNextId s, s { csNextId = csNextId s + 1 })
+class Monad m => MonadNameGen m where genId :: m Int
+instance MonadNameGen CompM where genId = CompM $ state $ \s -> (csNextId s, s { csNextId = csNextId s + 1 })
+instance MonadNameGen IdGen.IdGen where genId = IdGen.genId
+instance MonadNameGen m => MonadNameGen (MaybeT m) where genId = MaybeT (Just <$> genId)
-genName' :: String -> CompM String
+genName' :: MonadNameGen m => String -> m String
+genName' "" = genName
genName' prefix = (prefix ++) . show <$> genId
-genName :: CompM String
+genName :: MonadNameGen m => m String
genName = genName' "x"
+onlyIdGen :: IdGen.IdGen a -> CompM a
+onlyIdGen m = CompM $ do
+ i1 <- gets csNextId
+ let (res, i2) = IdGen.runIdGen' i1 m
+ modify (\s -> s { csNextId = i2 })
+ return res
+
emit :: Stmt -> CompM ()
-emit stmt = modify $ \s -> s { csStmts = csStmts s <> pure stmt }
+emit stmt = CompM $ modify $ \s -> s { csStmts = csStmts s <> pure stmt }
-scope :: CompM a -> CompM (a, [Stmt])
+scope :: CompM a -> CompM (a, Bag Stmt)
scope m = do
- stmts <- state $ \s -> (csStmts s, s { csStmts = mempty })
+ stmts <- CompM $ state $ \s -> (csStmts s, s { csStmts = mempty })
res <- m
- innerStmts <- state $ \s -> (csStmts s, s { csStmts = stmts })
- return (res, toList innerStmts)
+ innerStmts <- CompM $ state $ \s -> (csStmts s, s { csStmts = stmts })
+ return (res, innerStmts)
emitStruct :: STy t -> CompM String
-emitStruct ty = do
- let ty' = unSTy ty
- modify $ \s -> s { csStructs = Set.insert ty' (csStructs s) }
- return (genStructName ty')
+emitStruct ty = CompM $ do
+ modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) }
+ return (genStructName ty)
emitTLD :: String -> CompM ()
-emitTLD decl = modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl }
+emitTLD decl = CompM $ modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl }
nameEnv :: SList f env -> SList (Const String) env
nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1))
-compileToString :: SList STy env -> Ex env t -> String
-compileToString env expr =
+data KernelOffsets = KernelOffsets
+ { koArgOffsets :: [Int] -- ^ the function arguments
+ , koOkResOffset :: Int -- ^ a byte: 1 if successful execution, 0 if (fatal) error occurred
+ , koResultOffset :: Int -- ^ the function result
+ }
+
+compileToString :: Int -> SList STy env -> Ex env t -> (String, KernelOffsets)
+compileToString codeID env expr =
let args = nameEnv env
- (res, s) = runState (compile' args expr) (CompState mempty mempty mempty 1)
- structs = genAllStructs (csStructs s <> Set.fromList (unSList unSTy env))
+ (res, s) = runCompM (compile' args expr)
+ structs = genAllStructs (csStructs s <> Set.fromList (unSList Some env))
(arg_pairs, arg_metrics) =
unzip $ reverse (unSList (\(Product.Pair t (Const n)) -> ((n, repSTy t), metricsSTy t))
(slistZip env args))
- (arg_offsets, result_offset') = computeStructOffsets arg_metrics
- result_offset = align (alignmentSTy (typeOf expr)) result_offset'
- in ($ "") $ compose
+ (arg_offsets, okres_offset) = computeStructOffsets arg_metrics
+ result_offset = align (alignmentSTy (typeOf expr)) (okres_offset + 1)
+
+ offsets = KernelOffsets
+ { koArgOffsets = arg_offsets
+ , koOkResOffset = okres_offset
+ , koResultOffset = result_offset }
+ in (,offsets) . ($ "") $ compose
[showString "#include <stdio.h>\n"
,showString "#include <stdint.h>\n"
+ ,showString "#include <stdbool.h>\n"
+ ,showString "#include <inttypes.h>\n"
,showString "#include <stdlib.h>\n"
+ ,showString "#include <string.h>\n"
,showString "#include <math.h>\n\n"
- ,compose $ map (\sd -> printStructDecl sd . showString "\n") structs
+ -- PRint-tag
+ ,showString $ "#define PRTAG \"[chad-kernel" ++ show codeID ++ "] \"\n\n"
+
+ ,compose [printStructDecl sd . showString "\n" | sd <- structs]
,showString "\n"
+
+ -- Using %zd and not %zu here because values > SIZET_MAX/2 should be recognisable as "negative"
+ ,showString "static void* malloc_instr_fun(size_t n, int line) {\n"
+ ,showString " void *ptr = malloc(n);\n"
+ ,if debugAllocs then showString " printf(PRTAG \":%d malloc(%zd) -> %p\\n\", line, n, ptr);\n"
+ else id
+ ,if emitChecks then showString " if (ptr == NULL) { printf(PRTAG \"malloc(%zd) returned NULL on line %d\\n\", n, line); return false; }\n"
+ else id
+ ,showString " return ptr;\n"
+ ,showString "}\n"
+ ,showString "#define malloc_instr(n) ({void *ptr_ = malloc_instr_fun(n, __LINE__); if (ptr_ == NULL) return false; ptr_;})\n"
+ ,showString "static void* calloc_instr_fun(size_t n, int line) {\n"
+ ,showString " void *ptr = calloc(n, 1);\n"
+ ,if debugAllocs then showString " printf(PRTAG \":%d calloc(%zd) -> %p\\n\", line, n, ptr);\n"
+ else id
+ ,if emitChecks then showString " if (ptr == NULL) { printf(PRTAG \"calloc(%zd, 1) returned NULL on line %d\\n\", n, line); return false; }\n"
+ else id
+ ,showString " return ptr;\n"
+ ,showString "}\n"
+ ,showString "#define calloc_instr(n) ({void *ptr_ = calloc_instr_fun(n, __LINE__); if (ptr_ == NULL) return false; ptr_;})\n"
+ ,showString "static void free_instr(void *ptr) {\n"
+ ,if debugAllocs then showString "printf(PRTAG \"free(%p)\\n\", ptr);\n"
+ else id
+ ,showString " free(ptr);\n"
+ ,showString "}\n\n"
+
,compose [showString str . showString "\n\n" | str <- toList (csTopLevelDecls s)]
+
,showString $
- "static " ++ repSTy (typeOf expr) ++ " typed_kernel(" ++
- intercalate ", " (reverse (unSList (\(Product.Pair t (Const n)) -> repSTy t ++ " " ++ n) (slistZip env args))) ++
+ "static bool typed_kernel(" ++
+ repSTy (typeOf expr) ++ " *output" ++
+ concatMap (", " ++)
+ (reverse (unSList (\(Product.Pair t (Const n)) -> repSTy t ++ " " ++ n) (slistZip env args))) ++
") {\n"
- ,compose $ map (\st -> showString " " . printStmt 1 st . showString "\n") (toList (csStmts s))
- ,showString (" return ") . printCExpr 0 res . showString ";\n}\n\n"
+ ,compose [showString " " . printStmt 1 st . showString "\n" | st <- toList (csStmts s)]
+ ,showString " *output = " . printCExpr 0 res . showString ";\n"
+ ,showString " return true;\n"
+ ,showString "}\n\n"
+
,showString "void kernel(void *data) {\n"
-- Some code here assumes that we're on a 64-bit system, so let's check that
- ,showString " if (sizeof(void*) != 8 || sizeof(size_t) != 8) { fprintf(stderr, \"Only 64-bit systems supported\\n\"); abort(); }\n"
- ,showString $ " *(" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ") = typed_kernel(" ++
- concat (map (\((arg, typ), off, idx) ->
- "\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")"
- ++ (if idx < length arg_pairs - 1 then "," else "")
- ++ " // " ++ arg)
- (zip3 arg_pairs arg_offsets [0::Int ..])) ++
+ ,showString $ " if (sizeof(void*) != 8 || sizeof(size_t) != 8) { fprintf(stderr, \"Only 64-bit systems supported\\n\"); *(uint8_t*)(data + " ++ show okres_offset ++ ") = 0; return; }\n"
+ ,if debugRefc then showString " fprintf(stderr, PRTAG \"Start\\n\");\n"
+ else id
+ ,showString $ " const bool success = typed_kernel(" ++
+ "\n (" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ")" ++
+ concat (map (\((arg, typ), off) ->
+ ",\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")"
+ ++ " /* " ++ arg ++ " */")
+ (zip arg_pairs arg_offsets)) ++
"\n );\n"
+ ,showString $ " *(uint8_t*)(data + " ++ show okres_offset ++ ") = success;\n"
+ ,if debugRefc then showString " fprintf(stderr, PRTAG \"Return\\n\");\n"
+ else id
,showString "}\n"]
-- | Takes list of metrics (alignment, sizeof).
@@ -363,11 +459,20 @@ serialise topty topval ptr off k =
serialise a x ptr off $
serialise b y ptr (align (alignmentSTy b) (off + sizeofSTy a)) k
(STEither a _, Left x) -> do
- pokeByteOff ptr off (0 :: Word8) -- alignment of (a + b) is alignment of (union {a b})
+ pokeByteOff ptr off (0 :: Word8) -- alignment of (union {a b}) is the same as alignment of (a + b)
serialise a x ptr (off + alignmentSTy topty) k
(STEither _ b, Right y) -> do
pokeByteOff ptr off (1 :: Word8)
serialise b y ptr (off + alignmentSTy topty) k
+ (STLEither _ _, Nothing) -> do
+ pokeByteOff ptr off (0 :: Word8)
+ k
+ (STLEither a _, Just (Left x)) -> do
+ pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b)
+ serialise a x ptr (off + alignmentSTy topty) k
+ (STLEither _ b, Just (Right y)) -> do
+ pokeByteOff ptr off (2 :: Word8)
+ serialise b y ptr (off + alignmentSTy topty) k
(STMaybe _, Nothing) -> do
pokeByteOff ptr off (0 :: Word8)
k
@@ -377,6 +482,8 @@ serialise topty topval ptr off k =
(STArr n t, Array sh vec) -> do
let eltsz = sizeofSTy t
allocaBytes (fromSNat n * 8 + 8 + shapeSize sh * eltsz) $ \bufptr -> do
+ when debugRefc $
+ hPutStrLn stderr $ "[chad-serialise] Allocating input buffer " ++ showPtr bufptr
pokeByteOff ptr off bufptr
pokeShape bufptr 0 n sh
@@ -409,9 +516,16 @@ deserialise topty ptr off =
return (x, y)
STEither a b -> do
tag <- peekByteOff @Word8 ptr off
- if tag == 0 -- alignment of (a + b) is alignment of (union {a b})
+ if tag == 0 -- alignment of (union {a b}) is the same as alignment of (a + b)
then Left <$> deserialise a ptr (off + alignmentSTy topty)
else Right <$> deserialise b ptr (off + alignmentSTy topty)
+ STLEither a b -> do
+ tag <- peekByteOff @Word8 ptr off
+ case tag of -- alignment of (union {a b}) is the same as alignment of (a + b)
+ 0 -> return Nothing
+ 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty)
+ 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty)
+ _ -> error "Invalid tag value"
STMaybe t -> do
tag <- peekByteOff @Word8 ptr off
if tag == 0
@@ -421,6 +535,8 @@ deserialise topty ptr off =
bufptr <- peekByteOff @(Ptr ()) ptr off
sh <- peekShape bufptr 0 n
refc <- peekByteOff @Word64 bufptr (8 * fromSNat n)
+ when debugRefc $
+ hPutStrLn stderr $ "[chad-deserialise] Got buffer " ++ showPtr bufptr ++ " at refc=" ++ show refc
let off1 = 8 * fromSNat n + 8
eltsz = sizeofSTy t
arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (off1 + i * eltsz))
@@ -454,6 +570,10 @@ metricsSTy (STEither a b) =
let (a1, s1) = metricsSTy a
(a2, s2) = metricsSTy b
in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned
+metricsSTy (STLEither a b) =
+ let (a1, s1) = metricsSTy a
+ (a2, s2) = metricsSTy b
+ in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned
metricsSTy (STMaybe t) =
let (a, s) = metricsSTy t
in (a, a + s) -- the union after the tag byte is aligned
@@ -464,7 +584,7 @@ metricsSTy (STScal sty) = case sty of
STF32 -> (4, 4)
STF64 -> (8, 8)
STBool -> (1, 1) -- compiled to uint8_t
-metricsSTy (STAccum t) = metricsSTy t
+metricsSTy (STAccum t) = metricsSTy (fromSMTy t)
pokeShape :: Ptr () -> Int -> SNat n -> Shape n -> IO ()
pokeShape ptr off = go . fromSNat
@@ -486,15 +606,13 @@ compile' :: SList (Const String) env -> Ex env t -> CompM CExpr
compile' env = \case
EVar _ t i -> do
let Const var = slistIdx env i
- incrementVarAlways Increment t var
+ incrementVarAlways "var" Increment t var
return $ CELit var
ELet _ rhs body -> do
- e <- compile' env rhs
- var <- genName
- emit $ SVarDecl True (repSTy (typeOf rhs)) var e
+ var <- compileAssign "" env rhs
rete <- compile' (Const var `SCons` env) body
- incrementVarAlways Decrement (typeOf rhs) var
+ incrementVarAlways "let" Decrement (typeOf rhs) var
return rete
EPair _ a b -> do
@@ -506,7 +624,7 @@ compile' env = \case
EFst _ e -> do
let STPair _ t2 = typeOf e
e' <- compile' env e
- case incrementVar Decrement t2 of
+ case incrementVar "fst" Decrement t2 of
Nothing -> return $ CEProj e' "a"
Just f -> do var <- genName
emit $ SVarDecl True (repSTy (typeOf e)) var e'
@@ -516,7 +634,7 @@ compile' env = \case
ESnd _ e -> do
let STPair t1 _ = typeOf e
e' <- compile' env e
- case incrementVar Decrement t1 of
+ case incrementVar "snd" Decrement t1 of
Nothing -> return $ CEProj e' "b"
Just f -> do var <- genName
emit $ SVarDecl True (repSTy (typeOf e)) var e'
@@ -544,8 +662,8 @@ compile' env = \case
retvar <- genName
emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
emit $ SIf e1
- (BList stmts2 <> pure (SAsg retvar e2))
- (BList stmts3 <> pure (SAsg retvar e3))
+ (stmts2 <> pure (SAsg retvar e2))
+ (stmts3 <> pure (SAsg retvar e3))
return (CELit retvar)
ECase _ e a b -> do
@@ -555,17 +673,17 @@ compile' env = \case
-- I know those are not variable names, but it's fine, probably
(e2, stmts2) <- scope $ compile' (Const (var ++ ".l") `SCons` env) a
(e3, stmts3) <- scope $ compile' (Const (var ++ ".r") `SCons` env) b
- ((), stmtsRel1) <- scope $ incrementVarAlways Decrement t1 (var ++ ".l")
- ((), stmtsRel2) <- scope $ incrementVarAlways Decrement t2 (var ++ ".r")
+ ((), stmtsRel1) <- scope $ incrementVarAlways "case1" Decrement t1 (var ++ ".l")
+ ((), stmtsRel2) <- scope $ incrementVarAlways "case2" Decrement t2 (var ++ ".r")
retvar <- genName
emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1)
<> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0"))
- (BList stmts2
- <> BList stmtsRel1
+ (stmts2
+ <> stmtsRel1
<> pure (SAsg retvar e2))
- (BList stmts3
- <> BList stmtsRel2
+ (stmts3
+ <> stmtsRel2
<> pure (SAsg retvar e3))))
return (CELit retvar)
@@ -584,18 +702,51 @@ compile' env = \case
var <- genName
(e2, stmts2) <- scope $ compile' env a
(e3, stmts3) <- scope $ compile' (Const (var ++ ".j") `SCons` env) b
- ((), stmtsRel) <- scope $ incrementVarAlways Decrement t (var ++ ".j")
+ ((), stmtsRel) <- scope $ incrementVarAlways "maybe" Decrement t (var ++ ".j")
retvar <- genName
emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1)
<> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0"))
- (BList stmts2
+ (stmts2
<> pure (SAsg retvar e2))
- (BList stmts3
- <> BList stmtsRel
+ (stmts3
+ <> stmtsRel
<> pure (SAsg retvar e3))))
return (CELit retvar)
+ ELNil _ t1 t2 -> do
+ name <- emitStruct (STLEither t1 t2)
+ return $ CEStruct name [("tag", CELit "0")]
+
+ ELInl _ t e -> do
+ name <- emitStruct (STLEither (typeOf e) t)
+ e1 <- compile' env e
+ return $ CEStruct name [("tag", CELit "1"), ("l", e1)]
+
+ ELInr _ t e -> do
+ name <- emitStruct (STLEither t (typeOf e))
+ e1 <- compile' env e
+ return $ CEStruct name [("tag", CELit "2"), ("r", e1)]
+
+ ELCase _ e a b c -> do
+ let STLEither t1 t2 = typeOf e
+ e1 <- compile' env e
+ var <- genName
+ (e2, stmts2) <- scope $ compile' env a
+ (e3, stmts3) <- scope $ compile' (Const (var ++ ".l") `SCons` env) b
+ (e4, stmts4) <- scope $ compile' (Const (var ++ ".r") `SCons` env) c
+ ((), stmtsRel1) <- scope $ incrementVarAlways "lcase1" Decrement t1 (var ++ ".l")
+ ((), stmtsRel2) <- scope $ incrementVarAlways "lcase2" Decrement t2 (var ++ ".r")
+ retvar <- genName
+ emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
+ emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1)
+ <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0"))
+ (stmts2 <> pure (SAsg retvar e2))
+ (pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "1"))
+ (stmts3 <> stmtsRel1 <> pure (SAsg retvar e3))
+ (stmts4 <> stmtsRel2 <> pure (SAsg retvar e4))))))
+ return (CELit retvar)
+
EConstArr _ n t (Array sh vec) -> do
strname <- emitStruct (STArr n (STScal t))
tldname <- genName' "carraybuf"
@@ -608,12 +759,9 @@ compile' env = \case
return (CEStruct strname [("buf", CEAddrOf (CELit tldname))])
EBuild _ n esh efun -> do
- shname <- genName' "sh"
- emit . SVarDecl True (repSTy (typeOf esh)) shname =<< compile' env esh
- shsizename <- genName' "shsz"
- emit $ SVarDecl True "size_t" shsizename (compileShapeSize n shname)
+ shname <- compileAssign "sh" env esh
- arrname <- allocArray "arr" n (typeOf efun) (CELit shsizename) (compileShapeTupIntoArray n shname)
+ arrname <- allocArray "build" Malloc "arr" n (typeOf efun) Nothing (indexTupleComponents n shname)
idxargname <- genName' "ix"
(funretval, funstmts) <- scope $ compile' (Const idxargname `SCons` env) efun
@@ -627,40 +775,98 @@ compile' env = \case
| (ivar, dimidx) <- zip ivars [0::Int ..]]
(pure (SVarDecl True (repSTy (typeOf esh)) idxargname
(shapeTupFromLitVars n ivars))
- <> BList funstmts
+ <> funstmts
<> pure (SAsg (arrname ++ ".buf->xs[" ++ linivar ++ "++]") funretval))
return (CELit arrname)
- -- EFold1Inner _ a b c -> error "TODO" -- EFold1Inner ext (compile' a) (compile' b) (compile' c)
+ EFold1Inner _ commut efun ex0 earr -> do
+ let STArr (SS n) t = typeOf earr
- ESum1Inner _ e -> do
- let STArr (SS n) t = typeOf e
- e' <- compile' env e
- argname <- genName' "sumarg"
- emit $ SVarDecl True (repSTy (STArr (SS n) t)) argname e'
+ -- let vecwid = case commut of Commut -> 8 :: Int
+ -- Noncommut -> 1
+
+ x0name <- compileAssign "foldx0" env ex0
+ arrname <- compileAssign "foldarr" env earr
+
+ zeroRefcountCheck (typeOf earr) "fold1i" arrname
shszname <- genName' "shsz"
-- This n is one less than the shape of the thing we're querying, which is
-- unexpected. But it's exactly what we want, so we do it anyway.
+ emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n arrname)
+
+ resname <- allocArray "fold" Malloc "foldres" n t (Just (CELit shszname)) (compileArrShapeComponents n arrname)
+
+ lenname <- genName' "n"
+ emit $ SVarDecl True (repSTy tIx) lenname
+ (CELit (arrname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))
+
+ ((), x0incrStmts) <- scope $ incrementVarAlways "foldx0" Increment t x0name
+
+ ivar <- genName' "i"
+ jvar <- genName' "j"
+ -- kvar <- if vecwid > 1 then genName' "k" else return ""
+
+ accvar <- genName' "tot"
+ let arreltlit = arrname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++
+ ({- if vecwid > 1 then show vecwid ++ " * " ++ jvar ++ " + " ++ kvar else -} jvar) ++ "]"
+ (funres, funStmts) <- scope $ compile' (Const arreltlit `SCons` Const accvar `SCons` env) efun
+ ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldelt" Increment t arreltlit
+
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $
+ pure (SVarDecl False (repSTy t) accvar (CELit x0name))
+ <> x0incrStmts -- we're copying x0 here
+ <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
+ -- The combination function will consume the array element
+ -- and the accumulator. The accumulator is replaced by
+ -- what comes out of the function anyway, so that's
+ -- fine, but we do need to increment the array element.
+ arreltIncrStmts
+ <> funStmts
+ <> pure (SAsg accvar funres))
+ <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar))
+
+ incrementVarAlways "foldx0" Decrement t x0name
+ incrementVarAlways "foldarr" Decrement (typeOf earr) arrname
+
+ return (CELit resname)
+
+ ESum1Inner _ e -> do
+ let STArr (SS n) t = typeOf e
+ argname <- compileAssign "sumarg" env e
+
+ zeroRefcountCheck (typeOf e) "sum1i" argname
+
+ shszname <- genName' "shsz"
+ -- This n is one less than the shape of the thing we're querying, like EFold1Inner.
emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
- resname <- allocArray "sumres" n t (CELit shszname)
- [CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
+ resname <- allocArray "sum" Malloc "sumres" n t (Just (CELit shszname)) (compileArrShapeComponents n argname)
lenname <- genName' "n"
emit $ SVarDecl True (repSTy tIx) lenname
(CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))
+ let vecwid = 8 :: Int
ivar <- genName' "i"
jvar <- genName' "j"
+ kvar <- genName' "k"
accvar <- genName' "tot"
+ let nchunks = CEBinop (CELit lenname) "/" (CELit (show vecwid))
emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ BList
-- we have ScalIsNumeric, so it has 0 and (+) in C
- [SVarDecl False (repSTy t) accvar (CELit "0")
- ,SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
- pure $ SVerbatim $ accvar ++ " += " ++ argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ jvar ++ "];"
- ,SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)]
+ [SVerbatim $ repSTy t ++ " " ++ accvar ++ "[" ++ show vecwid ++ "] = {" ++ intercalate "," (replicate vecwid "0") ++ "};"
+ ,SLoop (repSTy tIx) jvar (CELit "0") nchunks $
+ pure $ SLoop (repSTy tIx) kvar (CELit "0") (CELit (show vecwid)) $
+ pure $ SVerbatim $ accvar ++ "[" ++ kvar ++ "] += " ++ argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ show vecwid ++ " * " ++ jvar ++ " + " ++ kvar ++ "];"
+ ,SLoop (repSTy tIx) kvar (CELit "1") (CELit (show vecwid)) $
+ pure $ SVerbatim $ accvar ++ "[0] += " ++ accvar ++ "[" ++ kvar ++ "];"
+ ,SLoop (repSTy tIx) kvar (CEBinop nchunks "*" (CELit (show vecwid))) (CELit lenname) $
+ pure $ SVerbatim $ accvar ++ "[0] += " ++ argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ kvar ++ "];"
+ ,SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit (accvar++"[0]"))]
+
+ incrementVarAlways "sum" Decrement (typeOf e) argname
return (CELit resname)
@@ -670,25 +876,24 @@ compile' env = \case
strname <- emitStruct typ
name <- genName
emit $ SVarDecl True strname name (CEStruct strname
- [("buf", CECall "malloc" [CELit (show (8 + sizeofSTy (typeOf e)))])])
+ [("buf", CECall "malloc_instr" [CELit (show (8 + sizeofSTy (typeOf e)))])])
emit $ SAsg (name ++ ".buf->refc") (CELit "1")
emit $ SAsg (name ++ ".buf->xs[0]") e'
return (CELit name)
EReplicate1Inner _ elen earg -> do
let STArr n t = typeOf earg
- lenname <- genName' "replen"
- emit . SVarDecl True (repSTy tIx) lenname =<< compile' env elen
- argname <- genName' "reparg"
- emit . SVarDecl True (repSTy (typeOf earg)) argname =<< compile' env earg
+ lenname <- compileAssign "replen" env elen
+ argname <- compileAssign "reparg" env earg
+
+ zeroRefcountCheck (typeOf earg) "replicate1i" argname
shszname <- genName' "shsz"
emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
- resname <- allocArray "rep" (SS n) t
- (CEBinop (CELit shszname) "*" (CELit lenname))
- ([CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
- ++ [CELit lenname])
+ resname <- allocArray "repl1i" Malloc "rep" (SS n) t
+ (Just (CEBinop (CELit shszname) "*" (CELit lenname)))
+ (compileArrShapeComponents n argname ++ [CELit lenname])
ivar <- genName' "i"
jvar <- genName' "j"
@@ -697,6 +902,8 @@ compile' env = \case
pure $ SAsg (resname ++ ".buf->xs[" ++ ivar ++ " * " ++ lenname ++ " + " ++ jvar ++ "]")
(CELit (argname ++ ".buf->xs[" ++ ivar ++ "]"))
+ incrementVarAlways "repl1i" Decrement (typeOf earg) argname
+
return (CELit resname)
EMaximum1Inner _ e -> compileExtremum "max" "maximum1i" ">" env e
@@ -707,37 +914,48 @@ compile' env = \case
EIdx0 _ e -> do
let STArr _ t = typeOf e
- e' <- compile' env e
- arrname <- genName
- emit $ SVarDecl True (repSTy (STArr SZ t)) arrname e'
+ arrname <- compileAssign "" env e
+ zeroRefcountCheck (typeOf e) "idx0" arrname
name <- genName
emit $ SVarDecl True (repSTy t) name
(CEIndex (CEPtrProj (CEProj (CELit arrname) "buf") "xs") (CELit "0"))
- incrementVarAlways Decrement (STArr SZ t) arrname
+ incrementVarAlways "idx0" Decrement (STArr SZ t) arrname
return (CELit name)
-- EIdx1 _ a b -> error "TODO" -- EIdx1 ext (compile' a) (compile' b)
EIdx _ earr eidx -> do
let STArr n t = typeOf earr
- arrname <- genName' "ixarr"
- idxname <- genName' "ixix"
- emit . SVarDecl True (repSTy (typeOf earr)) arrname =<< compile' env earr
- when (fromSNat n > 0) $ emit . SVarDecl True (repSTy (typeOf eidx)) idxname =<< compile' env eidx
+ arrname <- compileAssign "ixarr" env earr
+ zeroRefcountCheck (typeOf earr) "idx" arrname
+ idxname <- if fromSNat n > 0 -- prevent an unused-varable warning
+ then compileAssign "ixix" env eidx
+ else return "" -- won't be used in this case
+
+ when emitChecks $
+ forM_ (zip [0::Int ..] (indexTupleComponents n idxname)) $ \(i, ixcomp) ->
+ emit $ SIf (CEBinop (CEBinop ixcomp "<" (CELit "0")) "||"
+ (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".buf->sh[" ++ show i ++ "]")))))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: index out of range (arr=%p)\\n\", " ++
+ arrname ++ ".buf); return false;")
+ mempty
+
resname <- genName' "ixres"
emit $ SVarDecl True (repSTy t) resname (CEIndex (CELit (arrname ++ ".buf->xs")) (toLinearIdx n arrname idxname))
- incrementVarAlways Decrement (STArr n t) arrname
+ incrementVarAlways "idxelt" Increment t resname
+ incrementVarAlways "idx" Decrement (STArr n t) arrname
return (CELit resname)
EShape _ e -> do
let STArr n _ = typeOf e
t = tTup (sreplicate n tIx)
_ <- emitStruct t
- name <- genName
- emit . SVarDecl True (repSTy (typeOf e)) name =<< compile' env e
+ name <- compileAssign "" env e
+ zeroRefcountCheck (typeOf e) "shape" name
resname <- genName
emit $ SVarDecl True (repSTy t) resname (compileShapeQuery n name)
- incrementVarAlways Decrement (typeOf e) name
+ incrementVarAlways "shape" Decrement (typeOf e) name
return (CELit resname)
EOp _ op (EPair _ e1 e2) -> do
@@ -749,123 +967,224 @@ compile' env = \case
e' <- compile' env e
compileOpGeneral op e'
- ECustom _ t1 t2 _ earg _ _ e1 e2 -> do
- e1' <- compile' env e1
- name1 <- genName
- emit $ SVarDecl True (repSTy t1) name1 e1'
- e2' <- compile' env e2
- name2 <- genName
- emit $ SVarDecl True (repSTy t2) name2 e2'
- compile' (Const name2 `SCons` Const name1 `SCons` SNil) earg
+ ECustom _ _ _ _ earg _ _ e1 e2 -> do
+ name1 <- compileAssign "" env e1
+ name2 <- compileAssign "" env e2
+ case (incrementVar "custom1" Decrement (typeOf e1), incrementVar "custom2" Decrement (typeOf e2)) of
+ (Nothing, Nothing) -> compile' (Const name2 `SCons` Const name1 `SCons` SNil) earg
+ (mfun1, mfun2) -> do
+ name <- compileAssign "" (Const name2 `SCons` Const name1 `SCons` SNil) earg
+ maybe (return ()) ($ name1) mfun1
+ maybe (return ()) ($ name2) mfun2
+ return (CELit name)
+
+ ERecompute _ e -> compile' env e
EWith _ t e1 e2 -> do
- e1' <- compile' env e1
- name1 <- genName
- emit $ SVarDecl True (repSTy (typeOf e1)) name1 e1'
+ actyname <- emitStruct (STAccum t)
+ name1 <- compileAssign "" env e1
+
+ zeroRefcountCheck (typeOf e1) "with" name1
+ emit $ SVerbatim $ "// copyForWriting start (" ++ name1 ++ ")"
mcopy <- copyForWriting t name1
accname <- genName' "accum"
- emit $ SVarDecl False (repSTy (STAccum t)) accname (maybe (CELit name1) id mcopy)
+ emit $ SVarDecl False actyname accname
+ (CEStruct actyname [("buf", CECall "malloc_instr" [CELit (show (sizeofSTy (fromSMTy t)))])])
+ emit $ SAsg (accname++".buf->ac") (maybe (CELit name1) id mcopy)
+ emit $ SVerbatim $ "// initial accumulator constructed (" ++ name1 ++ ")."
e2' <- compile' (Const accname `SCons` env) e2
- return $ CEStruct (repSTy (STPair (typeOf e2) t)) [("a", e2'), ("b", CELit accname)]
-
- EAccum _ t prj eidx eval eacc -> do
- eidx' <- compile' env eidx
- nameidx <- genName
- emit $ SVarDecl True (repSTy (typeOf eidx)) nameidx eidx'
-
- eval' <- compile' env eval
- nameval <- genName
- emit $ SVarDecl True (repSTy (typeOf eval)) nameval eval'
-
- eacc' <- compile' env eacc
- nameacc <- genName
- emit $ SVarDecl True (repSTy (typeOf eacc)) nameacc eacc'
-
- let accumRef :: STy a -> SAcPrj p a b -> String -> String -> String
- accumRef _ SAPHere v _ = v
- accumRef (STPair ta _) (SAPFst prj') v i = accumRef ta prj' (v++".a") i
- accumRef (STPair _ tb) (SAPSnd prj') v i = accumRef tb prj' (v++".b") i
- accumRef (STEither ta _) (SAPLeft prj') v i = accumRef ta prj' (v++".l") i
- accumRef (STEither _ tb) (SAPRight prj') v i = accumRef tb prj' (v++".r") i
- accumRef (STMaybe tj) (SAPJust prj') v i = accumRef tj prj' (v++".j") i
- accumRef (STArr n t') (SAPArrIdx prj' _) v i =
- accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b")
-
- let add :: STy a -> String -> String -> CompM ()
- add STNil _ _ = return ()
- add (STPair t1 t2) d s = do
+ resname <- genName' "acret"
+ emit $ SVarDecl True (repSTy (fromSMTy t)) resname (CELit (accname++".buf->ac"))
+ emit $ SVerbatim $ "free_instr(" ++ accname ++ ".buf);"
+
+ rettyname <- emitStruct (STPair (typeOf e2) (fromSMTy t))
+ return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)]
+
+ EAccum _ t prj eidx sparsity eval eacc | Just Refl <- isDense (acPrjTy prj t) sparsity -> do
+ let -- Add a value (s) into an existing accumulation value (d). If a sparse
+ -- component of d is encountered, s is copied there.
+ add :: SMTy a -> String -> String -> CompM ()
+ add SMTNil _ _ = return ()
+ add (SMTPair t1 t2) d s = do
add t1 (d++".a") (s++".a")
add t2 (d++".b") (s++".b")
- add (STEither t1 t2) d s = do
+ add (SMTLEither t1 t2) d s = do
+ ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTLEither t1 t2)) s
((), stmts1) <- scope $ add t1 (d++".l") (s++".l")
((), stmts2) <- scope $ add t2 (d++".r") (s++".r")
- emit $ SAsg (d++".tag") (CELit (s++".tag"))
- emit $ SIf (CEBinop (CELit (s++".tag")) "==" (CELit "0"))
- (BList stmts1) (BList stmts2)
- add (STMaybe t1) d s = do
+ emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0"))
+ (pure (SAsg d (CELit s))
+ <> srcIncrStmts)
+ ((if emitChecks
+ then pure (SIf (CEBinop (CEBinop (CELit (s++".tag")) "!=" (CELit "0"))
+ "&&"
+ (CEBinop (CELit (s++".tag")) "!=" (CELit (d++".tag"))))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum add leither with different tags " ++
+ "(dest %d, src %d)\\n\", (int)" ++ d ++ ".tag, (int)" ++ s ++ ".tag); " ++
+ "return false;")
+ mempty)
+ else mempty)
+ -- note: s may have tag 0
+ <> pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1"))
+ stmts1
+ (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "2"))
+ stmts2 mempty))))
+ add (SMTMaybe t1) d s = do
+ ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTMaybe t1)) s
((), stmts1) <- scope $ add t1 (d++".j") (s++".j")
- emit $ SAsg (d++".tag") (CELit (s++".tag"))
- emit $ SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1"))
- (BList stmts1) mempty
- add (STArr n t1) d s = do
+ emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0"))
+ (pure (SAsg d (CELit s))
+ <> srcIncrStmts)
+ (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) stmts1 mempty))
+ add (SMTArr n t1) d s = do
+ when emitChecks $ do
+ let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
+ forM_ [0 .. fromSNat n - 1] $ \j -> do
+ emit $ SIf (CEBinop (CELit (s ++ ".buf->sh[" ++ show j ++ "]"))
+ "!="
+ (CELit (d ++ ".buf->sh[" ++ show j ++ "]")))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum add incorrect (d=%p, " ++
+ "dsh=" ++ shfmt ++ ", s=%p, ssh=" ++ shfmt ++ ")\\n\", " ++
+ d ++ ".buf" ++
+ concat [", " ++ d ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
+ ", " ++ s ++ ".buf" ++
+ concat [", " ++ s ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
+ "); " ++
+ "return false;")
+ mempty
+
shsizename <- genName' "acshsz"
- emit $ SVarDecl True "size_t" shsizename (compileShapeSize n (s++".a.b"))
+ emit $ SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n s)
ivar <- genName' "i"
((), stmts1) <- scope $ add t1 (d++".buf->xs["++ivar++"]") (s++".buf->xs["++ivar++"]")
- emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename) $
- BList stmts1
- add (STScal sty) d s = case sty of
- STI32 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
- STI64 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
- STF32 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
- STF64 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
- STBool -> error "Compile: accumulator add on booleans"
- add (STAccum _) _ _ = error "Compile: nested accumulators unsupported"
-
- let dest = accumRef t prj (nameacc++".ac") nameidx
- add (typeOf eval) dest nameval
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename)
+ stmts1
+ add (SMTScal _) d s = emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
+
+ let -- | Dereference an accumulation value and add a given value to that
+ -- position. Sparse components encountered along the way are
+ -- initialised before proceeding downwards.
+ -- accumRef (type) (projection) (accumulation component) (AcIdx variable) (value to accumulate there)
+ accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> String -> CompM ()
+ accumRef _ SAPHere v _ addend = add (acPrjTy prj t) v addend
+
+ accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") i addend
+ accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") i addend
+
+ accumRef (SMTLEither ta _) (SAPLeft prj') v i addend = do
+ when emitChecks $ do
+ emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1"))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +left)\\n\", " ++ v ++ ".tag); " ++
+ "return false;")
+ mempty
+ accumRef ta prj' (v++".l") i addend
+ accumRef (SMTLEither _ tb) (SAPRight prj') v i addend = do
+ when emitChecks $ do
+ emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "2"))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +right)\\n\", " ++ v ++ ".tag); " ++
+ "return false;")
+ mempty
+ accumRef tb prj' (v++".r") i addend
+
+ accumRef (SMTMaybe tj) (SAPJust prj') v i addend = do
+ when emitChecks $ do
+ emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1"))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (maybe tag=%d, +just)\\n\", " ++ v ++ ".tag); " ++
+ "return false;")
+ mempty
+ accumRef tj prj' (v++".j") i addend
+
+ accumRef (SMTArr n t') (SAPArrIdx prj') v i addend = do
+ when emitChecks $ do
+ let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
+ forM_ (zip [0::Int ..]
+ (indexTupleComponents n (i++".a"))) $ \(j, ixcomp) -> do
+ let a .||. b = CEBinop a "||" b
+ emit $ SIf (CEBinop ixcomp "<" (CELit "0")
+ .||.
+ CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++
+ "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=(D))\\n\", " ++
+ v ++ ".buf" ++
+ concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
+ concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a")] ++
+ "); " ++
+ "return false;")
+ mempty
+
+ accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a")) "]") (i++".b") addend
+
+ nameidx <- compileAssign "acidx" env eidx
+ nameval <- compileAssign "acval" env eval
+ nameacc <- compileAssign "acac" env eacc
+
+ emit $ SVerbatim $ "// compile EAccum start (" ++ show prj ++ ")"
+ accumRef t prj (nameacc++".buf->ac") nameidx nameval
+ emit $ SVerbatim $ "// compile EAccum end"
+
+ incrementVarAlways "accumendsrc" Decrement (typeOf eval) nameval
return $ CEStruct (repSTy STNil) []
+ EAccum{} ->
+ error "Compile: EAccum with non-trivial sparsity should have been eliminated (use AST.UnMonoid)"
+
EError _ t s -> do
let padleft len c s' = replicate (len - length s) c ++ s'
escape = concatMap $ \c -> if | c `elem` "\"\\" -> ['\\',c]
| ord c < 32 -> "\\x" ++ padleft 2 '0' (showHex (ord c) "")
| otherwise -> [c]
- emit $ SVerbatim $ "fprintf(stderr, \"ERROR: %s\\n\", " ++ escape s ++ "); exit(1);"
+ emit $ SVerbatim $ "fputs(\"ERROR: " ++ escape s ++ "\\n\", stderr); return false;"
case t of
STScal _ -> return (CELit "0")
_ -> do
name <- emitStruct t
return $ CEStruct name []
- EZero{} -> error "Compile: monoid operations should have been eliminated"
- EPlus{} -> error "Compile: monoid operations should have been eliminated"
- EOneHot{} -> error "Compile: monoid operations should have been eliminated"
+ EZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
+ EDeepZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
+ EPlus{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
+ EOneHot{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
- EFold1Inner{} -> error "Compile: not implemented: EFold1Inner"
EIdx1{} -> error "Compile: not implemented: EIdx1"
+compileAssign :: String -> SList (Const String) env -> Ex env t -> CompM String
+compileAssign prefix env e = do
+ e' <- compile' env e
+ case e' of
+ CELit name -> return name
+ _ -> do
+ name <- genName' prefix
+ emit $ SVarDecl True (repSTy (typeOf e)) name e'
+ return name
+
data Increment = Increment | Decrement
deriving (Show)
-- | Increment reference counts in the components of the given variable.
-incrementVar :: Increment -> STy a -> Maybe (String -> CompM ())
-incrementVar inc ty =
+incrementVar :: String -> Increment -> STy a -> Maybe (String -> CompM ())
+incrementVar marker inc ty =
let tree = makeArrayTree ty
in case tree of ATNoop -> Nothing
- _ -> Just $ \var -> incrementVar' inc var tree
+ _ -> Just $ \var -> incrementVar' marker inc var tree
-incrementVarAlways :: Increment -> STy a -> String -> CompM ()
-incrementVarAlways inc ty var = maybe (pure ()) ($ var) (incrementVar inc ty)
+incrementVarAlways :: String -> Increment -> STy a -> String -> CompM ()
+incrementVarAlways marker inc ty var = maybe (pure ()) ($ var) (incrementVar marker inc ty)
-data ArrayTree = ATArray -- ^ we've arrived at an array we need to decrement the refcount of
+data ArrayTree = ATArray (Some SNat) (Some STy) -- ^ we've arrived at an array we need to decrement the refcount of (contains rank and element type of the array)
| ATNoop -- ^ don't do anything here
| ATProj String ArrayTree -- ^ descend one field deeper
| ATCondTag ArrayTree ArrayTree -- ^ if tag is 0, first; if 1, second
+ | ATCond3Tag ArrayTree ArrayTree ArrayTree -- ^ if tag is: 0, 1, 2
| ATBoth ArrayTree ArrayTree -- ^ do both these paths
smartATProj :: String -> ArrayTree -> ArrayTree
@@ -876,6 +1195,10 @@ smartATCondTag :: ArrayTree -> ArrayTree -> ArrayTree
smartATCondTag ATNoop ATNoop = ATNoop
smartATCondTag t t' = ATCondTag t t'
+smartATCond3Tag :: ArrayTree -> ArrayTree -> ArrayTree -> ArrayTree
+smartATCond3Tag ATNoop ATNoop ATNoop = ATNoop
+smartATCond3Tag t1 t2 t3 = ATCond3Tag t1 t2 t3
+
smartATBoth :: ArrayTree -> ArrayTree -> ArrayTree
smartATBoth ATNoop t = t
smartATBoth t ATNoop = t
@@ -887,24 +1210,58 @@ makeArrayTree (STPair a b) = smartATBoth (smartATProj "a" (makeArrayTree a))
(smartATProj "b" (makeArrayTree b))
makeArrayTree (STEither a b) = smartATCondTag (smartATProj "l" (makeArrayTree a))
(smartATProj "r" (makeArrayTree b))
-makeArrayTree (STMaybe t) = smartATCondTag ATNoop (makeArrayTree t)
-makeArrayTree (STArr _ _) = ATArray
+makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop
+ (smartATProj "l" (makeArrayTree a))
+ (smartATProj "r" (makeArrayTree b))
+makeArrayTree (STMaybe t) = smartATCondTag ATNoop (smartATProj "j" (makeArrayTree t))
+makeArrayTree (STArr n t) = ATArray (Some n) (Some t)
makeArrayTree (STScal _) = ATNoop
makeArrayTree (STAccum _) = ATNoop
-incrementVar' :: Increment -> String -> ArrayTree -> CompM ()
-incrementVar' inc path ATArray =
+incrementVar' :: String -> Increment -> String -> ArrayTree -> CompM ()
+incrementVar' marker inc path (ATArray (Some n) (Some eltty)) =
case inc of
- Increment -> emit $ SVerbatim (path ++ ".buf->refc++;")
- Decrement ->
- emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) free(" ++ path ++ ".buf);"
-incrementVar' _ _ ATNoop = pure ()
-incrementVar' inc path (ATProj field t) = incrementVar' inc (path ++ "." ++ field) t
-incrementVar' inc path (ATCondTag t1 t2) = do
- ((), stmts1) <- scope $ incrementVar' inc path t1
- ((), stmts2) <- scope $ incrementVar' inc path t2
- emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "0")) (BList stmts1) (BList stmts2)
-incrementVar' inc path (ATBoth t1 t2) = incrementVar' inc path t1 >> incrementVar' inc path t2
+ Increment -> do
+ emit $ SVerbatim (path ++ ".buf->refc++;")
+ when debugRefc $
+ emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p in+ -> %zu <" ++ marker ++ ">\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc);"
+ Decrement -> do
+ case incrementVar (marker++".elt") Decrement eltty of
+ Nothing ->
+ if debugRefc
+ then do
+ emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p de- -> %zu <" ++ marker ++ ">\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);"
+ emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) { fprintf(stderr, \"; free(\"); free_instr(" ++ path ++ ".buf); fprintf(stderr, \") ok\\n\"); } else fprintf(stderr, \"\\n\");"
+ else do
+ emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) free_instr(" ++ path ++ ".buf);"
+ Just f -> do
+ when debugRefc $
+ emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p de- -> %zu <" ++ marker ++ "> recfree\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);"
+ shszvar <- genName' "frshsz"
+ ivar <- genName' "i"
+ ((), eltDecrStmts) <- scope $ f (path ++ ".buf->xs[" ++ ivar ++ "]")
+ emit $ SIf (CELit ("--" ++ path ++ ".buf->refc == 0"))
+ (BList [SVarDecl True "size_t" shszvar (compileArrShapeSize n path)
+ ,SLoop "size_t" ivar (CELit "0") (CELit shszvar) $
+ eltDecrStmts
+ ,SVerbatim $ "free_instr(" ++ path ++ ".buf);"])
+ mempty
+incrementVar' _ _ _ ATNoop = pure ()
+incrementVar' marker inc path (ATProj field t) = incrementVar' (marker++"."++field) inc (path ++ "." ++ field) t
+incrementVar' marker inc path (ATCondTag t1 t2) = do
+ ((), stmts1) <- scope $ incrementVar' (marker++".t1") inc path t1
+ ((), stmts2) <- scope $ incrementVar' (marker++".t2") inc path t2
+ emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "0")) stmts1 stmts2
+incrementVar' marker inc path (ATCond3Tag t1 t2 t3) = do
+ ((), stmts1) <- scope $ incrementVar' (marker++".t1") inc path t1
+ ((), stmts2) <- scope $ incrementVar' (marker++".t2") inc path t2
+ ((), stmts3) <- scope $ incrementVar' (marker++".t3") inc path t3
+ emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "1"))
+ stmts2
+ (pure (SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "2"))
+ stmts3
+ stmts1))
+incrementVar' marker inc path (ATBoth t1 t2) = incrementVar' (marker++".1") inc path t1 >> incrementVar' (marker++".2") inc path t2
toLinearIdx :: SNat n -> String -> String -> CExpr
toLinearIdx SZ _ _ = CELit "0"
@@ -921,21 +1278,31 @@ toLinearIdx (SS n) arrvar idxvar =
-- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")))
-- _
+data AllocMethod = Malloc | Calloc
+ deriving (Show)
+
-- | The shape must have the outer dimension at the head (and the inner dimension on the right).
-allocArray :: String -> SNat n -> STy t -> CExpr -> [CExpr] -> CompM String
-allocArray nameBase rank eltty shsz shape = do
+allocArray :: String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String
+allocArray marker method nameBase rank eltty mshsz shape = do
when (length shape /= fromSNat rank) $
error "allocArray: shape does not match rank"
let arrty = STArr rank eltty
strname <- emitStruct arrty
arrname <- genName' nameBase
+ shsz <- case mshsz of
+ Just e -> return e
+ Nothing -> return (foldl0' (\a b -> CEBinop a "*" b) (CELit "1") shape)
+ let nbytesExpr = CEBinop (CELit (show (fromSNat rank * 8 + 8)))
+ "+"
+ (CEBinop shsz "*" (CELit (show (sizeofSTy eltty))))
emit $ SVarDecl True strname arrname $ CEStruct strname
- [("buf", CECall "malloc" [CEBinop (CELit (show (fromSNat rank * 8 + 8)))
- "+"
- (CEBinop shsz "*" (CELit (show (sizeofSTy eltty))))])]
+ [("buf", case method of Malloc -> CECall "malloc_instr" [nbytesExpr]
+ Calloc -> CECall "calloc_instr" [nbytesExpr])]
forM_ (zip shape [0::Int ..]) $ \(dim, i) ->
emit $ SAsg (arrname ++ ".buf->sh[" ++ show i ++ "]") dim
emit $ SAsg (arrname ++ ".buf->refc") (CELit "1")
+ when debugRefc $
+ emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p allocated <" ++ marker ++ ">\\n\", " ++ arrname ++ ".buf);"
return arrname
compileShapeQuery :: SNat n -> String -> CExpr
@@ -945,20 +1312,17 @@ compileShapeQuery (SS n) var =
[("a", compileShapeQuery n var)
,("b", CEIndex (CELit (var ++ ".buf->sh")) (CELit (show (fromSNat n))))]
-compileShapeSize :: SNat n -> String -> CExpr
-compileShapeSize SZ _ = CELit "1"
-compileShapeSize (SS SZ) var = CELit (var ++ ".b")
-compileShapeSize (SS n) var = CEBinop (compileShapeSize n (var ++ ".a")) "*" (CELit (var ++ ".b"))
-
-- | Takes a variable name for the array, not the buffer.
compileArrShapeSize :: SNat n -> String -> CExpr
-compileArrShapeSize SZ _ = CELit "1"
-compileArrShapeSize n var =
- foldl1' (\a b -> CEBinop a "*" b) [CELit (var ++ ".buf->sh[" ++ show i ++ "]")
- | i <- [0 .. fromSNat n - 1]]
+compileArrShapeSize n var = foldl0' (\a b -> CEBinop a "*" b) (CELit "1") (compileArrShapeComponents n var)
+
+-- | Takes a variable name for the array, not the buffer.
+compileArrShapeComponents :: SNat n -> String -> [CExpr]
+compileArrShapeComponents n var =
+ [CELit (var ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
-compileShapeTupIntoArray :: SNat n -> String -> [CExpr]
-compileShapeTupIntoArray = \n var -> map CELit (toList (go n var))
+indexTupleComponents :: SNat n -> String -> [CExpr]
+indexTupleComponents = \n var -> map CELit (toList (go n var))
where
go :: SNat n -> String -> Bag String
go SZ _ = mempty
@@ -976,7 +1340,7 @@ shapeTupFromLitVars = \n -> go n . reverse
compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr
compileOpGeneral op e1 = do
- let unary cop = return @(State CompState) $ CECall cop [e1]
+ let unary cop = return @CompM $ CECall cop [e1]
let binary cop = do
name <- genName
emit $ SVarDecl True (repSTy (opt1 op)) name e1
@@ -1004,10 +1368,11 @@ compileOpGeneral op e1 = do
OLog STF32 -> unary "logf"
OLog STF64 -> unary "log"
OIDiv _ -> binary "/"
+ OMod _ -> binary "%"
compileOpPair :: SOp a b -> CExpr -> CExpr -> CompM CExpr
compileOpPair op e1 e2 = do
- let binary cop = return @(State CompState) $ CEBinop e1 cop e2
+ let binary cop = return @CompM $ CEBinop e1 cop e2
case op of
OAdd _ -> binary "+"
OMul _ -> binary "*"
@@ -1017,6 +1382,7 @@ compileOpPair op e1 e2 = do
OAnd -> binary "&&"
OOr -> binary "||"
OIDiv _ -> binary "/"
+ OMod _ -> binary "%"
_ -> error "compileOpPair: got unary operator"
-- | Bool: whether to ensure that the literal itself already has the appropriate type
@@ -1031,23 +1397,22 @@ compileScal pedantic typ x = case typ of
compileExtremum :: String -> String -> String -> SList (Const String) env -> Ex env (TArr (S n) t) -> CompM CExpr
compileExtremum nameBase opName operator env e = do
let STArr (SS n) t = typeOf e
- e' <- compile' env e
- argname <- genName' (nameBase ++ "arg")
- emit $ SVarDecl True (repSTy (STArr (SS n) t)) argname e'
+ argname <- compileAssign (nameBase ++ "arg") env e
+
+ zeroRefcountCheck (typeOf e) opName argname
shszname <- genName' "shsz"
-- This n is one less than the shape of the thing we're querying, which is
-- unexpected. But it's exactly what we want, so we do it anyway.
emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
- resname <- allocArray (nameBase ++ "res") n t (CELit shszname)
- [CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
+ resname <- allocArray nameBase Malloc (nameBase ++ "res") n t (Just (CELit shszname)) (compileArrShapeComponents n argname)
lenname <- genName' "n"
emit $ SVarDecl True (repSTy tIx) lenname
(CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))
- emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); abort(); }"
+ emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); return false; }"
ivar <- genName' "i"
jvar <- genName' "j"
@@ -1062,99 +1427,107 @@ compileExtremum nameBase opName operator env e = do
]
,SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit redvar)]
+ incrementVarAlways nameBase Decrement (typeOf e) argname
+
return (CELit resname)
-- | If this returns Nothing, there was nothing to copy because making a simple
-- value copy in C already makes it suitable to write to.
-copyForWriting :: STy t -> String -> CompM (Maybe CExpr)
+copyForWriting :: SMTy t -> String -> CompM (Maybe CExpr)
copyForWriting topty var = case topty of
- STNil -> return Nothing
+ SMTNil -> return Nothing
- STPair a b -> do
+ SMTPair a b -> do
e1 <- copyForWriting a (var ++ ".a")
e2 <- copyForWriting b (var ++ ".b")
case (e1, e2) of
(Nothing, Nothing) -> return Nothing
- _ -> return $ Just $ CEStruct (repSTy topty)
+ _ -> return $ Just $ CEStruct toptyname
[("a", fromMaybe (CELit (var++".a")) e1)
,("b", fromMaybe (CELit (var++".b")) e2)]
- STEither a b -> do
+ SMTLEither a b -> do
(e1, stmts1) <- scope $ copyForWriting a (var ++ ".l")
(e2, stmts2) <- scope $ copyForWriting b (var ++ ".r")
case (e1, e2) of
(Nothing, Nothing) -> return Nothing
_ -> do
name <- genName
- emit $ SVarDeclUninit (repSTy topty) name
- emit $ SIf (CEBinop (CELit var) "==" (CELit "0"))
- (BList stmts1
- <> pure (SAsg name (CEStruct (repSTy topty)
+ emit $ SVarDeclUninit toptyname name
+ emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0"))
+ (stmts1
+ <> pure (SAsg name (CEStruct toptyname
[("tag", CELit "0"), ("l", fromMaybe (CELit (var++".l")) e1)])))
- (BList stmts2
- <> pure (SAsg name (CEStruct (repSTy topty)
+ (stmts2
+ <> pure (SAsg name (CEStruct toptyname
[("tag", CELit "1"), ("r", fromMaybe (CELit (var++".r")) e2)])))
return (Just (CELit name))
- STMaybe t -> do
+ SMTMaybe t -> do
(e1, stmts1) <- scope $ copyForWriting t (var ++ ".j")
case e1 of
Nothing -> return Nothing
Just e1' -> do
name <- genName
- emit $ SVarDeclUninit (repSTy topty) name
- emit $ SIf (CEBinop (CELit var) "==" (CELit "0"))
- (pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "0")])))
- (BList stmts1
- <> pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "1"), ("j", e1')])))
+ emit $ SVarDeclUninit toptyname name
+ emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0"))
+ (pure (SAsg name (CEStruct toptyname [("tag", CELit "0")])))
+ (stmts1
+ <> pure (SAsg name (CEStruct toptyname [("tag", CELit "1"), ("j", e1')])))
return (Just (CELit name))
-- If there are no nested arrays, we know that a refcount of 1 means that the
-- whole thing is owned. Nested arrays have their own refcount, so with
-- nesting we'd have to check the refcounts of all the nested arrays _too_;
- -- at that point we might as well copy the whole thing. Furthermore, no
- -- sub-arrays means that the whole thing is flat, and we can just memcpy if
- -- necessary.
- STArr n t | not (hasArrays t) -> do
+ -- let's not do that. Furthermore, no sub-arrays means that the whole thing
+ -- is flat, and we can just memcpy if necessary.
+ SMTArr n t | not (hasArrays (fromSMTy t)) -> do
name <- genName
shszname <- genName' "shsz"
- emit $ SVarDeclUninit (repSTy (STArr n t)) name
+ emit $ SVarDeclUninit toptyname name
- emit $ SIf (CEBinop (CELit (var ++ ".refc")) "==" (CELit "1"))
+ when debugShapes $ do
+ let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
+ emit $ SVerbatim $
+ "fprintf(stderr, PRTAG \"with array " ++ shfmt ++ "\\n\"" ++
+ concat [", " ++ var ++ ".buf->sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++
+ ");"
+
+ emit $ SIf (CEBinop (CELit (var ++ ".buf->refc")) "==" (CELit "1"))
(pure (SAsg name (CELit var)))
(let shbytes = fromSNat n * 8
- databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy t)))
+ databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t))))
totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes
in BList
[SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var)
- ,SAsg name (CEStruct (repSTy (STArr n t)) [("buf", CECall "malloc" [totalbytes])])
+ ,SAsg name (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])])
,SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++
show shbytes ++ ");"
,SAsg (name ++ ".buf->refc") (CELit "1")
,SVerbatim $ "memcpy(" ++ name ++ ".buf->xs, " ++ var ++ ".buf->xs, " ++
- printCExpr 0 databytes ")"])
+ printCExpr 0 databytes ");"])
return (Just (CELit name))
- STArr n t -> do
+ SMTArr n t -> do
shszname <- genName' "shsz"
emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var)
let shbytes = fromSNat n * 8
- databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy t)))
+ databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t))))
totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes
name <- genName
- emit $ SVarDecl False (repSTy (STArr n t)) name
- (CEStruct (repSTy (STArr n t)) [("buf", CECall "malloc" [totalbytes])])
+ emit $ SVarDecl False toptyname name
+ (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])])
emit $ SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++
show shbytes ++ ");"
emit $ SAsg (name ++ ".buf->refc") (CELit "1")
-- put the arrays in variables to cut short the not-quite-var chain
dstvar <- genName' "cpydst"
- emit $ SVarDecl True (repSTy t ++ " *") dstvar (CELit (name ++ ".buf->xs"))
+ emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") dstvar (CELit (name ++ ".buf->xs"))
srcvar <- genName' "cpysrc"
- emit $ SVarDecl True (repSTy t ++ " *") srcvar (CELit (var ++ ".buf->xs"))
+ emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") srcvar (CELit (var ++ ".buf->xs"))
ivar <- genName' "i"
@@ -1164,18 +1537,82 @@ copyForWriting topty var = case topty of
Nothing -> error "copyForWriting: arrays cannot be copied as-is, bug"
emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $
- BList cpystmts
+ cpystmts
<> pure (SAsg (dstvar ++ "[" ++ ivar ++ "]") cpye')
return (Just (CELit name))
- STScal _ -> return Nothing
+ SMTScal _ -> return Nothing
- STAccum _ -> error "Compile: Nested accumulators not supported"
+ where
+ toptyname = repSTy (fromSMTy topty)
+
+zeroRefcountCheck :: STy t -> String -> String -> CompM ()
+zeroRefcountCheck toptyp opname topvar =
+ when emitChecks $ do
+ mstmts <- onlyIdGen $ runMaybeT (go toptyp topvar)
+ case mstmts of
+ Nothing -> return ()
+ Just stmts -> forM_ stmts emit
+ where
+ -- | If this returns 'Nothing', no statements need to be generated for this type.
+ go :: STy t -> String -> MaybeT IdGen.IdGen (Bag Stmt)
+ go STNil _ = empty
+ go (STPair a b) path = do
+ (s1, s2) <- combine (go a (path++".a")) (go b (path++".b"))
+ return (s1 <> s2)
+ go (STEither a b) path = do
+ (s1, s2) <- combine (go a (path++".l")) (go b (path++".r"))
+ return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "0")) s1 s2
+ go (STLEither a b) path = do
+ (s1, s2) <- combine (go a (path++".l")) (go b (path++".r"))
+ return $ pure $
+ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1"))
+ s1
+ (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2"))
+ s2
+ mempty))
+ go (STMaybe a) path = do
+ ss <- go a (path++".j")
+ return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) ss mempty
+ go (STArr n a) path = do
+ ivar <- genName' "i"
+ ss <- go a (path++".buf->xs["++ivar++"]")
+ shszname <- genName' "shsz"
+ let s1 = SVerbatim $
+ "if (__builtin_expect(" ++ path ++ ".buf->refc == 0, 0)) { " ++
+ "fprintf(stderr, PRTAG \"CHECK: '" ++ opname ++ "' got array " ++
+ "%p with refc=0\\n\", " ++ path ++ ".buf); return false; }"
+ let s2 = SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n path)
+ let s3 = SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) ss
+ return (BList [s1, s2, s3])
+ go STScal{} _ = empty
+ go STAccum{} _ = error "zeroRefcountCheck: passed an accumulator"
+
+ combine :: (Monoid a, Monoid b, Monad m) => MaybeT m a -> MaybeT m b -> MaybeT m (a, b)
+ combine (MaybeT a) (MaybeT b) = MaybeT $ do
+ x <- a
+ y <- b
+ return $ case (x, y) of
+ (Nothing, Nothing) -> Nothing
+ (Just x', Nothing) -> Just (x', mempty)
+ (Nothing, Just y') -> Just (mempty, y')
+ (Just x', Just y') -> Just (x', y')
+
+{-# NOINLINE uniqueIdGenRef #-}
+uniqueIdGenRef :: IORef Int
+uniqueIdGenRef = unsafePerformIO $ newIORef 1
compose :: Foldable t => t (a -> a) -> a -> a
compose = foldr (.) id
+showPtr :: Ptr a -> String
+showPtr (Ptr a) = "0x" ++ showHex (integerFromWord# (int2Word# (addr2Int# a))) ""
+
-- | Type-restricted.
(^) :: Num a => a -> Int -> a
(^) = (Prelude.^)
+
+foldl0' :: (a -> a -> a) -> a -> [a] -> a
+foldl0' _ x [] = x
+foldl0' f _ l = foldl1' f l