aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/Compile.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-10 21:49:45 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-10 21:50:25 +0100
commit174af2ba568de66e0d890825b8bda930b8e7bb96 (patch)
tree5a20f52662e87ff7cf6a6bef5db0713aa6c7884e /src/CHAD/Compile.hs
parent92bca235e3aaa287286b6af082d3fce585825a35 (diff)
Move module hierarchy under CHAD.
Diffstat (limited to 'src/CHAD/Compile.hs')
-rw-r--r--src/CHAD/Compile.hs1796
1 files changed, 1796 insertions, 0 deletions
diff --git a/src/CHAD/Compile.hs b/src/CHAD/Compile.hs
new file mode 100644
index 0000000..5b71651
--- /dev/null
+++ b/src/CHAD/Compile.hs
@@ -0,0 +1,1796 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE MagicHash #-}
+{-# LANGUAGE MultiWayIf #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE TypeApplications #-}
+module CHAD.Compile (compile, compileStderr) where
+
+import Control.Applicative (empty)
+import Control.Monad (forM_, when, replicateM)
+import Control.Monad.Trans.Class (lift)
+import Control.Monad.Trans.Maybe
+import Control.Monad.Trans.State.Strict
+import Control.Monad.Trans.Writer.CPS
+import Data.Bifunctor (first)
+import Data.Char (ord)
+import Data.Foldable (toList)
+import Data.Functor.Const
+import qualified Data.Functor.Product as Product
+import Data.Functor.Product (Product)
+import Data.IORef
+import Data.List (foldl1', intersperse, intercalate)
+import qualified Data.Map.Strict as Map
+import Data.Maybe (fromMaybe)
+import qualified Data.Set as Set
+import Data.Set (Set)
+import Data.Some
+import qualified Data.Vector as V
+import Foreign
+import GHC.Exts (int2Word#, addr2Int#)
+import GHC.Num (integerFromWord#)
+import GHC.Ptr (Ptr(..))
+import GHC.Stack (HasCallStack)
+import Numeric (showHex)
+import System.IO (hPutStrLn, stderr)
+import System.IO.Error (mkIOError, userErrorType)
+import System.IO.Unsafe (unsafePerformIO)
+
+import Prelude hiding ((^))
+import qualified Prelude
+
+import CHAD.Array
+import CHAD.AST
+import CHAD.AST.Pretty (ppSTy, ppExpr)
+import CHAD.AST.Sparse.Types (isDense)
+import CHAD.Compile.Exec
+import CHAD.Data
+import CHAD.Interpreter.Rep
+import qualified CHAD.Util.IdGen as IdGen
+
+
+-- In shape and index arrays, the innermost dimension is on the right (last index).
+
+-- TODO: test that I'm properly incrementing and decrementing refcounts in all required places
+
+
+-- Debugging switches. Each flag is written as @toEnum 0@, i.e. 'False';
+-- change the 0 to a 1 to switch a flag on.
+-- NOTE(review): presumably 'toEnum' is used rather than a literal so that the
+-- guarded code is not flagged or discarded as statically dead -- confirm.
+
+-- | Print the compiled AST
+debugPrintAST :: Bool; debugPrintAST = toEnum 0
+-- | Print the generated C source
+debugCSource :: Bool; debugCSource = toEnum 0
+-- | Print extra stuff about reference counts of arrays
+debugRefc :: Bool; debugRefc = toEnum 0
+-- | Print some shape-related information
+debugShapes :: Bool; debugShapes = toEnum 0
+-- | Print information on allocation
+debugAllocs :: Bool; debugAllocs = toEnum 0
+-- | Emit extra C code that checks stuff
+emitChecks :: Bool; emitChecks = toEnum 0
+
+-- | Returns compiled function plus compilation output (warnings)
+--
+-- The expression is translated to C source by 'compileToString' and built
+-- with 'buildKernel'. The returned function writes its arguments into a
+-- temporary buffer at the offsets given by 'KernelOffsets', calls the kernel
+-- on that buffer, and reads the result back out of it. It raises an 'IOError'
+-- if the kernel did not set its success byte to 1.
+compile :: SList STy env -> Ex env t
+ -> IO (SList Value env -> IO (Rep t), String)
+compile = \env expr -> do
+ -- A fresh number for this kernel; 'compileToString' embeds it in the tag
+ -- that prefixes the kernel's diagnostic output.
+ codeID <- atomicModifyIORef' uniqueIdGenRef (\i -> (i + 1, i))
+
+ let (source, offsets) = compileToString codeID env expr
+ when debugPrintAST $ hPutStrLn stderr $ "Compiled AST: <<<\n" ++ ppExpr env expr ++ "\n>>>"
+ when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ lineNumbers source ++ "\x1B[0m>>>"
+ (lib, compileOutput) <- buildKernel source "kernel"
+
+ let result_type = typeOf expr
+ result_size = sizeofSTy result_type
+
+ let function val = do
+ -- One buffer holds the arguments, the success byte and the result; the
+ -- result is the last item, so the buffer ends right after it.
+ allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do
+ -- Reversed so that the values line up with 'koArgOffsets', which
+ -- 'compileToString' computes from the reversed environment as well.
+ let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets)
+ serialiseArguments args ptr $ do
+ callKernelFun lib ptr
+ ok <- peekByteOff @Word8 ptr (koOkResOffset offsets)
+ when (ok /= 1) $
+ ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing)
+ deserialise result_type ptr (koResultOffset offsets)
+ return (function, compileOutput)
+ where
+ -- Serialise the arguments one after another, running the continuation
+ -- inside all of the nested 'serialise' calls.
+ serialiseArguments :: [(Some (Product STy Value), Int)] -> Ptr () -> IO r -> IO r
+ serialiseArguments ((Some (Product.Pair t (Value arg)), off) : args) ptr k =
+ serialise t arg ptr off $
+ serialiseArguments args ptr k
+ serialiseArguments _ _ k = k
+
+-- | Like 'compile', except that whatever the C compiler printed is written
+-- to stderr (if it is non-empty) rather than handed back to the caller.
+compileStderr :: SList STy env -> Ex env t
+              -> IO (SList Value env -> IO (Rep t))
+compileStderr env expr =
+  compile env expr >>= \(kernel, ccOutput) ->
+    if null ccOutput
+      then return kernel
+      else do
+        hPutStrLn stderr ("[chad] Kernel compilation GCC output: <<<\n" ++ ccOutput ++ ">>>")
+        return kernel
+
+
+-- | A C struct declaration, rendered as a @typedef@ by 'printStructDecl'.
+data StructDecl = StructDecl
+ String -- ^ name
+ String -- ^ contents (the text placed between the braces)
+ String -- ^ comment (printed as a trailing @//@ comment when non-empty)
+ deriving (Show)
+
+-- | A statement of the generated C program; rendered by 'printStmt'.
+data Stmt
+ = SVarDecl Bool String String CExpr -- ^ const, type, variable name, right-hand side
+ | SVarDeclUninit String String -- ^ type, variable name (no initialiser)
+ | SAsg String CExpr -- ^ variable name, right-hand side
+ | SBlock (Bag Stmt)
+ | SIf CExpr (Bag Stmt) (Bag Stmt)
+ | SLoop String String CExpr CExpr (Bag Stmt) -- ^ for (<type> <name> = <expr>; name < <expr>; name++) {<stmts>}
+ | SVerbatim String -- ^ no implicit ';', just printed as-is
+ deriving (Show)
+
+-- | An expression of the generated C program; rendered by 'printCExpr',
+-- which inserts parentheses where the context requires them.
+data CExpr
+ = CELit String -- ^ inserted as-is, assumed no parentheses needed
+ | CEStruct String [(String, CExpr)] -- ^ struct construction literal: `(name){.field=expr}`
+ | CEProj CExpr String -- ^ field projection: expr.field
+ | CEPtrProj CExpr String -- ^ field projection through pointer: expr->field
+ | CEAddrOf CExpr -- ^ &expr
+ | CEIndex CExpr CExpr -- ^ expr[expr]
+ | CECall String [CExpr] -- ^ function(arg1, ..., argn)
+ | CEBinop CExpr String CExpr -- ^ expr + expr
+ | CEIf CExpr CExpr CExpr -- ^ expr ? expr : expr
+ | CECast String CExpr -- ^ (<type>)<expr>
+ deriving (Show)
+
+-- | Render a struct declaration as a single-line C @typedef@. A non-empty
+-- comment is appended after the semicolon as a @//@ comment.
+printStructDecl :: StructDecl -> ShowS
+printStructDecl (StructDecl structName body note) rest =
+  "typedef struct { " ++ body ++ " } " ++ structName ++ ";" ++ trailer ++ rest
+  where
+    trailer
+      | null note = ""
+      | otherwise = " // " ++ note
+
+-- | Render a statement as C source. The first argument is the nesting depth
+-- of the statement; it is only used to indent the lines of blocks (two
+-- spaces per level). The first line of the statement itself is not indented
+-- here, and no trailing newline is produced.
+printStmt :: Int -> Stmt -> ShowS
+printStmt indent = \case
+ SVarDecl cnst typ name rhs -> showString (typ ++ " " ++ (if cnst then "const " else "") ++ name ++ " = ") . printCExpr 0 rhs . showString ";"
+ SVarDeclUninit typ name -> showString (typ ++ " " ++ name ++ ";")
+ SAsg name rhs -> showString (name ++ " = ") . printCExpr 0 rhs . showString ";"
+ SBlock stmts
+ | null stmts -> showString "{}"
+ | otherwise ->
+ showString "{"
+ . compose [showString ("\n" ++ replicate (2*indent+2) ' ') . printStmt (indent+1) stmt | stmt <- toList stmts]
+ . showString ("\n" ++ replicate (2*indent) ' ' ++ "}")
+ SIf cond b1 b2 ->
+ showString "if (" . printCExpr 0 cond . showString ") "
+ . printStmt indent (SBlock b1)
+ -- An empty else-branch is omitted entirely.
+ . (if null b2 then id else showString " else " . printStmt indent (SBlock b2))
+ SLoop typ name e1 e2 stmts ->
+ -- The upper bound is printed in context 6, the operand context of "<" in
+ -- the precedence table of 'printCExpr'.
+ showString ("for (" ++ typ ++ " " ++ name ++ " = ")
+ . printCExpr 0 e1 . showString ("; " ++ name ++ " < ") . printCExpr 6 e2
+ . showString ("; " ++ name ++ "++) ")
+ . printStmt indent (SBlock stmts)
+ SVerbatim s -> showString s
+
+-- | Render a C expression. The first argument @d@ describes the syntactic
+-- context the expression is printed in; an expression is parenthesised when
+-- it binds less tightly than its context requires.
+--
+-- d values:
+-- * 0: top level
+-- * 1: in 1st or 2nd component of a ternary operator (technically same as top level, but readability)
+-- * 2-...: various operators (see precTable)
+-- * 80: address-of operator (&)
+-- * 98: inside unknown operator
+-- * 99: left of a field projection
+-- Unlisted operators are conservatively written with full parentheses.
+printCExpr :: Int -> CExpr -> ShowS
+printCExpr d = \case
+ CELit s -> showString s
+ CEStruct name pairs ->
+ showParen (d >= 99) $
+ showString ("(" ++ name ++ "){")
+ . compose (intersperse (showString ", ") [showString ("." ++ n ++ " = ") . printCExpr 0 e
+ | (n, e) <- pairs])
+ . showString "}"
+ CEProj e name -> printCExpr 99 e . showString ("." ++ name)
+ CEPtrProj e name -> printCExpr 99 e . showString ("->" ++ name)
+ CEAddrOf e -> showParen (d > 80) $ showString "&" . printCExpr 80 e
+ CEIndex e1 e2 -> printCExpr 99 e1 . showString "[" . printCExpr 0 e2 . showString "]"
+ CECall n es ->
+ showString (n ++ "(") . compose (intersperse (showString ", ") (map (printCExpr 0) es)) . showString ")"
+ CEBinop e1 n e2 ->
+ -- An operator missing from the table gets precedence -1, so it is always
+ -- parenthesised, and its operands are printed in context 98.
+ let mprec = Map.lookup n precTable
+ p = maybe (-1) fst mprec -- precedence of this operator
+ (d1, d2) = maybe (98, 98) snd mprec -- precedences for the arguments
+ in showParen (d > p) $
+ printCExpr d1 e1 . showString (" " ++ n ++ " ") . printCExpr d2 e2
+ CEIf e1 e2 e3 ->
+ showParen (d > 0) $
+ printCExpr 1 e1 . showString " ? " . printCExpr 1 e2 . showString " : " . printCExpr 0 e3
+ CECast typ e ->
+ showParen (d > 98) $ showString ("(" ++ typ ++ ")") . printCExpr 98 e
+ where
+ -- operator -> (precedence of the operator, (context for the left operand,
+ -- context for the right operand))
+ precTable = Map.fromList
+ [("||", (2, (2, 2)))
+ ,("&&", (3, (3, 3)))
+ ,("==", (4, (5, 5)))
+ ,("!=", (4, (5, 5)))
+ ,("<", (5, (6, 6))) -- Note: this precedence is used in the printing of SLoop
+ ,(">", (5, (6, 6)))
+ ,("<=", (5, (6, 6)))
+ ,(">=", (5, (6, 6)))
+ ,("+", (6, (6, 7)))
+ ,("-", (6, (6, 7)))
+ ,("*", (7, (7, 8)))
+ ,("/", (7, (7, 8)))
+ ,("%", (7, (7, 8)))]
+
+-- | The name of the C type that represents values of the given type: a
+-- fixed-width primitive for scalars, and the generated struct (see
+-- 'genStructName') for everything else.
+repSTy :: STy t -> String
+repSTy = \case
+  STScal STI32  -> "int32_t"
+  STScal STI64  -> "int64_t"
+  STScal STF32  -> "float"
+  STScal STF64  -> "double"
+  STScal STBool -> "uint8_t"
+  ty            -> genStructName ty
+
+-- | Mangled C struct names. 'genStructName' yields @ty_@ followed by an
+-- encoding of the type. 'genArrBufStructName' yields the name of the buffer
+-- struct of an array type, which encodes only the element type and not the
+-- rank; it calls 'error' when given a non-array type.
+genStructName, genArrBufStructName :: STy t -> String
+(genStructName, genArrBufStructName) =
+ (\t -> "ty_" ++ gen t
+ ,\case STArr _ t -> "ty_A_" ++ gen t ++ "_buf" -- just like the normal type, but with _ for the dimension
+ t -> error $ "genArrBufStructName: not an array type: " ++ show t)
+ where
+ -- all tags start with a letter, so the array mangling is unambiguous.
+ gen :: STy t -> String
+ gen STNil = "n"
+ gen (STPair a b) = 'P' : gen a ++ gen b
+ gen (STEither a b) = 'E' : gen a ++ gen b
+ gen (STLEither a b) = 'L' : gen a ++ gen b
+ gen (STMaybe t) = 'M' : gen t
+ gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t
+ gen (STScal st) = case st of
+ STI32 -> "i"
+ STI64 -> "j"
+ STF32 -> "f"
+ STF64 -> "d"
+ STBool -> "b"
+ -- An accumulator is encoded as 'C' followed by the type it accumulates.
+ gen (STAccum t) = 'C' : gen (fromSMTy t)
+
+-- | Struct declarations together with the declarations they depend on.
+-- The subtrees contain structs used in the bodies of the structs in this node.
+-- 'genStructTreeW' therefore emits the subtrees before the node's own structs.
+data StructTree = TreeNode [StructDecl] [StructTree]
+ deriving (Show)
+
+-- | This function generates the actual struct declarations for each of the
+-- types in our language. It thus implicitly "documents" the layout of the
+-- types in the C translation.
+--
+-- For accumulation it is important that for struct representations of monoid
+-- types, the all-zero-bytes value corresponds to the zero value of that type.
+--
+-- The layouts declared here are mirrored on the Haskell side by 'metricsSTy',
+-- 'serialise' and 'deserialise'.
+buildStructTree :: STy t -> StructTree
+buildStructTree topty = case topty of
+ STNil ->
+ TreeNode [StructDecl name "" com] []
+ STPair a b ->
+ TreeNode [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com]
+ [buildStructTree a, buildStructTree b]
+ STEither a b -> -- 0 -> l, 1 -> r
+ TreeNode [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]
+ [buildStructTree a, buildStructTree b]
+ STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r
+ TreeNode [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]
+ [buildStructTree a, buildStructTree b]
+ STMaybe t -> -- 0 -> nothing, 1 -> just
+ TreeNode [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com]
+ [buildStructTree t]
+ STArr n t ->
+ -- The buffer is trailed by a VLA for the actual array data.
+ -- TODO: no buffer if n = 0
+ TreeNode [StructDecl (genArrBufStructName topty) ("size_t refc; " ++ repSTy t ++ " xs[];") ""
+ ,StructDecl name (genArrBufStructName topty ++ " *buf; size_t sh[" ++ show (fromSNat n) ++ "];") com]
+ [buildStructTree t]
+ STScal _ ->
+ -- Scalars are represented by C primitive types (see 'repSTy'), so there is
+ -- nothing to declare.
+ TreeNode [] []
+ STAccum t ->
+ -- An accumulator is a pointer to a separate struct holding the value.
+ TreeNode [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") ""
+ ,StructDecl name (name ++ "_buf *buf;") com]
+ [buildStructTree (fromSMTy t)]
+ where
+ name = genStructName topty
+ -- The pretty-printed type, attached to the declaration as a comment.
+ com = ppSTy 0 topty
+
+-- | Emit the declarations of a 'StructTree', dependencies first.
+--
+-- The state is the set of struct names that were already handled; the writer
+-- output is the list of structs in a valid declaration order. A node whose
+-- structs have all been seen before is skipped together with its subtrees.
+genStructTreeW :: StructTree -> WriterT (Bag StructDecl) (State (Set String)) ()
+genStructTreeW (TreeNode these deps) = do
+  alreadyDone <- lift get
+  let fresh = [sd | sd@(StructDecl nm _ _) <- these, nm `Set.notMember` alreadyDone]
+      freshNames = Set.fromList [nm | StructDecl nm _ _ <- fresh]
+  when (not (null fresh)) $ do
+    lift (modify (Set.union freshNames))
+    mapM_ genStructTreeW deps
+    tell (BList fresh)
+
+-- | All struct declarations needed for the given types, in an order in which
+-- every struct is declared after the structs its body refers to, and with
+-- each struct occurring only once.
+genAllStructs :: Foldable t => t (Some STy) -> [StructDecl]
+genAllStructs tys =
+  toList $ flip evalState Set.empty $ execWriterT $
+    forM_ tys $ \(Some ty) -> genStructTreeW (buildStructTree ty)
+
+-- | The state threaded through 'CompM' while compiling an expression.
+data CompState = CompState
+ { csStructs :: Set (Some STy) -- ^ types whose structs must be declared (see 'emitStruct')
+ , csTopLevelDecls :: Bag String -- ^ top-level C declarations (see 'emitTLD')
+ , csStmts :: Bag Stmt -- ^ statements emitted so far in the current 'scope'
+ , csNextId :: Int } -- ^ next fresh number for name generation
+ deriving (Show)
+
+-- | The compilation monad: a state monad over 'CompState'.
+newtype CompM a = CompM (State CompState a)
+ deriving newtype (Functor, Applicative, Monad)
+
+-- | Run a compilation from the empty state; generated ids start at 1.
+runCompM :: CompM a -> (a, CompState)
+runCompM (CompM action) = runState action initialState
+  where
+    initialState = CompState
+      { csStructs = mempty
+      , csTopLevelDecls = mempty
+      , csStmts = mempty
+      , csNextId = 1 }
+
+-- | Monads that can hand out numbers for name generation (see 'genName'').
+class Monad m => MonadNameGen m where genId :: m Int
+-- In 'CompM' the number comes from (and advances) the 'csNextId' counter.
+instance MonadNameGen CompM where genId = CompM $ state $ \s -> (csNextId s, s { csNextId = csNextId s + 1 })
+instance MonadNameGen IdGen.IdGen where genId = IdGen.genId
+instance MonadNameGen m => MonadNameGen (MaybeT m) where genId = MaybeT (Just <$> genId)
+
+-- | Generate a name consisting of the given prefix followed by a newly
+-- generated number. An empty prefix is replaced by the default of 'genName'.
+genName' :: MonadNameGen m => String -> m String
+genName' prefix
+  | null prefix = genName
+  | otherwise = fmap (\i -> prefix ++ show i) genId
+
+-- | Generate a name with the default prefix @x@.
+genName :: MonadNameGen m => m String
+genName = genName' "x"
+
+-- | Run an 'IdGen.IdGen' computation inside 'CompM': it starts from the
+-- current 'csNextId', and the counter it ends with is stored back.
+onlyIdGen :: IdGen.IdGen a -> CompM a
+onlyIdGen m = CompM $ state $ \s ->
+  let (res, next) = IdGen.runIdGen' (csNextId s) m
+  in (res, s { csNextId = next })
+
+-- | Append a statement to the statements of the current scope.
+emit :: Stmt -> CompM ()
+emit stmt = CompM (modify addStmt)
+  where
+    addStmt s = s { csStmts = csStmts s <> pure stmt }
+
+-- | Run a computation with an initially empty statement list, and return the
+-- statements it emitted alongside its result. The statements that had been
+-- emitted before are put back afterwards, so the inner statements do not end
+-- up in the enclosing scope.
+scope :: CompM a -> CompM (a, Bag Stmt)
+scope m = do
+  saved <- swapStmts mempty
+  res <- m
+  inner <- swapStmts saved
+  return (res, inner)
+  where
+    -- Install a new statement list, returning the previous one.
+    swapStmts new = CompM $ state $ \s -> (csStmts s, s { csStmts = new })
+
+-- | Record that the struct for this type must be declared, and return the
+-- name of that struct.
+emitStruct :: STy t -> CompM String
+emitStruct ty = do
+  CompM $ modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) }
+  pure (genStructName ty)
+
+-- | Like 'emitStruct', but the second component of the result is the name
+-- of the array buffer struct.
+emitArrStruct :: STy t -> CompM (String, String)
+emitArrStruct ty = do
+  CompM $ modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) }
+  pure (genStructName ty, genArrBufStructName ty)
+
+-- | Append a top-level declaration to the generated C file.
+emitTLD :: String -> CompM ()
+emitTLD decl = CompM (modify addDecl)
+  where
+    addDecl s = s { csTopLevelDecls = csTopLevelDecls s <> pure decl }
+
+-- | Give every entry of the environment a C variable name @arg0@, @arg1@, ...
+-- The contents of the input list are ignored; numbering follows the order in
+-- which 'slistMapA' visits the entries.
+-- NOTE(review): the traversal order of 'slistMapA' is not visible here --
+-- confirm which end of the list receives @arg0@.
+nameEnv :: SList f env -> SList (Const String) env
+nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1))
+
+-- | Byte offsets into the single data buffer that is passed to the generated
+-- @kernel@ function; computed by 'compileToString'.
+data KernelOffsets = KernelOffsets
+ { koArgOffsets :: [Int] -- ^ the function arguments
+ , koOkResOffset :: Int -- ^ a byte: 1 if successful execution, 0 if (fatal) error occurred
+ , koResultOffset :: Int -- ^ the function result
+ }
+
+-- | Generate the complete C source of a kernel, together with the byte
+-- offsets that the Haskell side needs in order to pass data to and from it.
+--
+-- The source consists of, in order: the includes, the @PRTAG@ prefix for
+-- diagnostic output (which contains @codeID@), the struct declarations, the
+-- instrumented allocation helpers, the top-level declarations emitted during
+-- compilation, the @typed_kernel@ function containing the compiled
+-- expression, and the untyped @kernel@ entry point. The latter reads the
+-- arguments from a single data buffer and writes the success byte and the
+-- result back into it.
+compileToString :: Int -> SList STy env -> Ex env t -> (String, KernelOffsets)
+compileToString codeID env expr =
+  let args = nameEnv env
+      (res, s) = runCompM (compile' args expr)
+      -- The signature of typed_kernel names the result type, so register it
+      -- explicitly next to the argument types (a no-op if already present).
+      structs = genAllStructs (Set.insert (Some (typeOf expr)) (csStructs s)
+                                 <> Set.fromList (unSList Some env))
+
+      (arg_pairs, arg_metrics) =
+        unzip $ reverse (unSList (\(Product.Pair t (Const n)) -> ((n, repSTy t), metricsSTy t))
+                                 (slistZip env args))
+      (arg_offsets, okres_offset) = computeStructOffsets arg_metrics
+      -- The result follows the success byte, at its own alignment.
+      result_offset = align (alignmentSTy (typeOf expr)) (okres_offset + 1)
+
+      offsets = KernelOffsets
+        { koArgOffsets = arg_offsets
+        , koOkResOffset = okres_offset
+        , koResultOffset = result_offset }
+  in (,offsets) . ($ "") $ compose
+       [showString "#include <stdio.h>\n"
+       ,showString "#include <stdint.h>\n"
+       ,showString "#include <stdbool.h>\n"
+       ,showString "#include <inttypes.h>\n"
+       ,showString "#include <stdlib.h>\n"
+       ,showString "#include <string.h>\n"
+       ,showString "#include <math.h>\n\n"
+       -- PRint-tag
+       ,showString $ "#define PRTAG \"[chad-kernel" ++ show codeID ++ "] \"\n\n"
+
+       ,compose [printStructDecl sd . showString "\n" | sd <- structs]
+       ,showString "\n"
+
+       -- Using %zd and not %zu here because values > SIZET_MAX/2 should be recognisable as "negative"
+       -- The *_instr_fun helpers return a pointer, so their failure path must
+       -- return NULL (not false); the *_instr macros turn a NULL result into
+       -- `return false` from the enclosing bool-returning function.
+       ,showString "static void* malloc_instr_fun(size_t n, int line) {\n"
+       ,showString " void *ptr = malloc(n);\n"
+       ,if debugAllocs then showString " printf(PRTAG \":%d malloc(%zd) -> %p\\n\", line, n, ptr);\n"
+                       else id
+       ,if emitChecks then showString " if (ptr == NULL) { printf(PRTAG \"malloc(%zd) returned NULL on line %d\\n\", n, line); return NULL; }\n"
+                      else id
+       ,showString " return ptr;\n"
+       ,showString "}\n"
+       ,showString "#define malloc_instr(n) ({void *ptr_ = malloc_instr_fun(n, __LINE__); if (ptr_ == NULL) return false; ptr_;})\n"
+       ,showString "static void* calloc_instr_fun(size_t n, int line) {\n"
+       ,showString " void *ptr = calloc(n, 1);\n"
+       ,if debugAllocs then showString " printf(PRTAG \":%d calloc(%zd) -> %p\\n\", line, n, ptr);\n"
+                       else id
+       ,if emitChecks then showString " if (ptr == NULL) { printf(PRTAG \"calloc(%zd, 1) returned NULL on line %d\\n\", n, line); return NULL; }\n"
+                      else id
+       ,showString " return ptr;\n"
+       ,showString "}\n"
+       ,showString "#define calloc_instr(n) ({void *ptr_ = calloc_instr_fun(n, __LINE__); if (ptr_ == NULL) return false; ptr_;})\n"
+       ,showString "static void free_instr(void *ptr) {\n"
+       ,if debugAllocs then showString "printf(PRTAG \"free(%p)\\n\", ptr);\n"
+                       else id
+       ,showString " free(ptr);\n"
+       ,showString "}\n\n"
+
+       ,compose [showString str . showString "\n\n" | str <- toList (csTopLevelDecls s)]
+
+       ,showString $
+          "static bool typed_kernel(" ++
+            repSTy (typeOf expr) ++ " *output" ++
+            concatMap (", " ++)
+                      (reverse (unSList (\(Product.Pair t (Const n)) -> repSTy t ++ " " ++ n) (slistZip env args))) ++
+            ") {\n"
+       ,compose [showString " " . printStmt 1 st . showString "\n" | st <- toList (csStmts s)]
+       ,showString " *output = " . printCExpr 0 res . showString ";\n"
+       ,showString " return true;\n"
+       ,showString "}\n\n"
+
+       ,showString "void kernel(void *data) {\n"
+        -- Some code here assumes that we're on a 64-bit system, so let's check that
+       ,showString $ " if (sizeof(void*) != 8 || sizeof(size_t) != 8) { fprintf(stderr, \"Only 64-bit systems supported\\n\"); *(uint8_t*)(data + " ++ show okres_offset ++ ") = 0; return; }\n"
+       ,if debugRefc then showString " fprintf(stderr, PRTAG \"Start\\n\");\n"
+                     else id
+       ,showString $ " const bool success = typed_kernel(" ++
+                     "\n (" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ")" ++
+                     concat (map (\((arg, typ), off) ->
+                                     ",\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")"
+                                     ++ " /* " ++ arg ++ " */")
+                                 (zip arg_pairs arg_offsets)) ++
+                     "\n );\n"
+       ,showString $ " *(uint8_t*)(data + " ++ show okres_offset ++ ") = success;\n"
+       ,if debugRefc then showString " fprintf(stderr, PRTAG \"Return\\n\");\n"
+                     else id
+       ,showString "}\n"]
+
+-- | Lay out a sequence of fields like the members of a C struct.
+--
+-- The input gives (alignment, sizeof) for each field. The result is the list
+-- of field offsets together with the total size, which is rounded up to the
+-- largest alignment among the fields. The first field is placed at offset 0;
+-- an empty input gives no offsets and size 0.
+computeStructOffsets :: [(Int, Int)] -> ([Int], Int)
+computeStructOffsets = layout 0 0
+  where
+    -- layout <offset of the next field> <largest alignment so far> <fields>
+    layout _ _ [] = ([], 0)
+    layout here worst ((al, sz) : rest) =
+      let worst' = max worst al
+          end = here + sz
+      in case rest of
+           [] -> ([here], align worst' end)
+           (nextAl, _) : _ ->
+             let (offs, total) = layout (align nextAl end) worst' rest
+             in (here : offs, total)
+
+-- | Write a value into memory at @ptr + off@ using the layout of the C
+-- structs, then run the continuation.
+-- Assumes that this is called at the correct alignment.
+--
+-- This is written in continuation-passing style because array buffers are
+-- allocated with 'allocaBytes', so they only exist while the continuation
+-- runs.
+serialise :: STy t -> Rep t -> Ptr () -> Int -> IO r -> IO r
+serialise topty topval ptr off k =
+ -- TODO: this code is quadratic in the depth of the type because of the alignment/sizeOf calls
+ case (topty, topval) of
+ (STNil, ()) -> k
+ (STPair a b, (x, y)) ->
+ serialise a x ptr off $
+ serialise b y ptr (align (alignmentSTy b) (off + sizeofSTy a)) k
+ -- For the tagged types, the tag byte is at 'off' and the payload follows
+ -- at the alignment of the type.
+ (STEither a _, Left x) -> do
+ pokeByteOff ptr off (0 :: Word8) -- alignment of (union {a b}) is the same as alignment of (a + b)
+ serialise a x ptr (off + alignmentSTy topty) k
+ (STEither _ b, Right y) -> do
+ pokeByteOff ptr off (1 :: Word8)
+ serialise b y ptr (off + alignmentSTy topty) k
+ (STLEither _ _, Nothing) -> do
+ pokeByteOff ptr off (0 :: Word8)
+ k
+ (STLEither a _, Just (Left x)) -> do
+ pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b)
+ serialise a x ptr (off + alignmentSTy topty) k
+ (STLEither _ b, Just (Right y)) -> do
+ pokeByteOff ptr off (2 :: Word8)
+ serialise b y ptr (off + alignmentSTy topty) k
+ (STMaybe _, Nothing) -> do
+ pokeByteOff ptr off (0 :: Word8)
+ k
+ (STMaybe t, Just x) -> do
+ pokeByteOff ptr off (1 :: Word8)
+ serialise t x ptr (off + alignmentSTy t) k
+ (STArr n t, Array sh vec) -> do
+ let eltsz = sizeofSTy t
+ -- The buffer is an 8-byte reference count followed by the elements.
+ allocaBytes (8 + shapeSize sh * eltsz) $ \bufptr -> do
+ when debugRefc $
+ hPutStrLn stderr $ "[chad-serialise] Allocating input buffer " ++ showPtr bufptr
+ pokeByteOff ptr off bufptr
+ pokeShape ptr (off + 8) n sh
+
+ -- A huge reference count marks the buffer as not heap-allocated:
+ -- 'deserialise' only frees buffers whose count is below 2^62.
+ pokeByteOff @Word64 bufptr 0 (2 ^ 63)
+
+ let loop i
+ | i == shapeSize sh = k
+ | otherwise =
+ serialise t (vec V.! i) bufptr (8 + i * eltsz) $
+ loop (i+1)
+ loop 0
+ (STScal sty, x) -> case sty of
+ STI32 -> pokeByteOff ptr off (x :: Int32) >> k
+ STI64 -> pokeByteOff ptr off (x :: Int64) >> k
+ STF32 -> pokeByteOff ptr off (x :: Float) >> k
+ STF64 -> pokeByteOff ptr off (x :: Double) >> k
+ STBool -> pokeByteOff ptr off (fromIntegral (fromEnum x) :: Word8) >> k
+ (STAccum{}, _) -> error "Cannot serialise accumulators"
+
+-- | Read a value from memory at @ptr + off@, using the layout of the C
+-- structs. Assumes that this is called at the correct alignment.
+--
+-- An array buffer is freed after its elements have been read, unless its
+-- reference count is at least 2^62 (such a count marks a buffer that must not
+-- be freed, like the input buffers written by 'serialise'). Calls 'error' for
+-- accumulator types and for an invalid 'STLEither' tag.
+deserialise :: STy t -> Ptr () -> Int -> IO (Rep t)
+deserialise topty ptr off =
+  -- TODO: this code is quadratic in the depth of the type because of the alignment/sizeOf calls
+  case topty of
+    STNil -> return ()
+    STPair a b -> do
+      x <- deserialise a ptr off
+      y <- deserialise b ptr (align (alignmentSTy b) (off + sizeofSTy a))
+      return (x, y)
+    STEither a b -> do
+      tag <- peekByteOff @Word8 ptr off
+      if tag == 0  -- alignment of (union {a b}) is the same as alignment of (a + b)
+        then Left <$> deserialise a ptr (off + alignmentSTy topty)
+        else Right <$> deserialise b ptr (off + alignmentSTy topty)
+    STLEither a b -> do
+      tag <- peekByteOff @Word8 ptr off
+      case tag of  -- alignment of (union {a b}) is the same as alignment of (1 + a + b)
+        0 -> return Nothing
+        1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty)
+        2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty)
+        _ -> error "Invalid tag value"
+    STMaybe t -> do
+      tag <- peekByteOff @Word8 ptr off
+      if tag == 0
+        then return Nothing
+        else Just <$> deserialise t ptr (off + alignmentSTy t)
+    STArr n t -> do
+      bufptr <- peekByteOff @(Ptr ()) ptr off
+      sh <- peekShape ptr (off + 8) n
+      -- The buffer starts with its 8-byte reference count; elements follow.
+      refc <- peekByteOff @Word64 bufptr 0
+      when debugRefc $
+        hPutStrLn stderr $ "[chad-deserialise] Got buffer " ++ showPtr bufptr ++ " at refc=" ++ show refc
+      let eltsz = sizeofSTy t
+      arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (8 + i * eltsz))
+      when (refc < 2 ^ 62) $ free bufptr
+      return arr
+    STScal sty -> case sty of
+      STI32 -> peekByteOff @Int32 ptr off
+      STI64 -> peekByteOff @Int64 ptr off
+      STF32 -> peekByteOff @Float ptr off
+      STF64 -> peekByteOff @Double ptr off
+      STBool -> toEnum . fromIntegral <$> peekByteOff @Word8 ptr off
+    STAccum{} -> error "Cannot deserialise accumulators"
+
+-- | @align a off@ rounds @off@ up to the nearest multiple of @a@.
+align :: Int -> Int -> Int
+align a off = padded - padded `mod` a
+  where
+    padded = off + a - 1
+
+-- | The alignment in bytes of the C representation of a type; the first
+-- component of 'metricsSTy'.
+alignmentSTy :: STy t -> Int
+alignmentSTy ty = let (al, _) = metricsSTy ty in al
+
+-- | The size in bytes of the C representation of a type; the second
+-- component of 'metricsSTy'.
+sizeofSTy :: STy t -> Int
+sizeofSTy ty = let (_, sz) = metricsSTy ty in sz
+
+-- | Returns (alignment, sizeof) of the C representation of a type, i.e. of
+-- the types and structs described by 'repSTy' and 'buildStructTree'.
+--
+-- The size is always a multiple of the alignment, as it is for C types; the
+-- pair case and the array element stride in 'serialise' / 'deserialise' rely
+-- on this.
+metricsSTy :: STy t -> (Int, Int)
+metricsSTy STNil = (1, 0)
+metricsSTy (STPair a b) =
+  let (a1, s1) = metricsSTy a
+      (a2, s2) = metricsSTy b
+  in (max a1 a2, align (max a1 a2) (s1 + s2))
+metricsSTy (STEither a b) = taggedUnionMetrics (metricsSTy a) (metricsSTy b)
+metricsSTy (STLEither a b) = taggedUnionMetrics (metricsSTy a) (metricsSTy b)
+metricsSTy (STMaybe t) =
+  let (a, s) = metricsSTy t
+  in (a, a + s)  -- the payload after the tag byte is aligned
+metricsSTy (STArr n _) = (8, 8 + 8 * fromSNat n)  -- buffer pointer plus one size_t per dimension
+metricsSTy (STScal sty) = case sty of
+  STI32 -> (4, 4)
+  STI64 -> (8, 8)
+  STF32 -> (4, 4)
+  STF64 -> (8, 8)
+  STBool -> (1, 1)  -- compiled to uint8_t
+metricsSTy (STAccum t) = metricsSTy (fromSMTy t)
+
+-- | Metrics of @struct { uint8_t tag; union { l; r; }; }@ given the metrics
+-- of the two alternatives. The union starts at the first aligned offset
+-- after the tag byte, and its size is that of the larger alternative rounded
+-- up to the alignment of the union. Without the rounding, a small strictly
+-- aligned alternative next to a larger loosely aligned one would give a size
+-- that is smaller than the C @sizeof@.
+taggedUnionMetrics :: (Int, Int) -> (Int, Int) -> (Int, Int)
+taggedUnionMetrics (a1, s1) (a2, s2) =
+  let al = max a1 a2
+  in (al, al + align al (max s1 s2))
+
+-- | Write a shape to memory at @ptr + off@ as consecutive 64-bit integers.
+-- The dimension at the head of the 'ShCons' chain is stored in the last
+-- slot, and each subsequent one in the slot before it.
+pokeShape :: Ptr () -> Int -> SNat n -> Shape n -> IO ()
+pokeShape ptr off rankNat = writeDims (fromSNat rankNat)
+  where
+    -- The first argument is the number of dimensions still to be written.
+    writeDims :: Int -> Shape m -> IO ()
+    writeDims slot = \case
+      ShNil -> return ()
+      rest `ShCons` dim -> do
+        pokeByteOff ptr (off + (slot - 1) * 8) (fromIntegral dim :: Int64)
+        writeDims (slot - 1) rest
+
+-- | Read a shape of the given rank from memory at @ptr + off@, where it is
+-- stored as consecutive 64-bit integers (the inverse of 'pokeShape').
+peekShape :: Ptr () -> Int -> SNat n -> IO (Shape n)
+peekShape ptr off = \case
+  SZ -> return ShNil
+  SS n -> do
+    outer <- peekShape ptr off n
+    dim <- peekByteOff @Int64 ptr (off + fromSNat n * 8)
+    return (outer `ShCons` fromIntegral dim)
+
+compile' :: SList (Const String) env -> Ex env t -> CompM CExpr
+compile' env = \case
+ EVar _ t i -> do
+ let Const var = slistIdx env i
+ incrementVarAlways "var" Increment t var
+ return $ CELit var
+
+ ELet _ rhs body -> do
+ var <- compileAssign "" env rhs
+ rete <- compile' (Const var `SCons` env) body
+ incrementVarAlways "let" Decrement (typeOf rhs) var
+ return rete
+
+ EPair _ a b -> do
+ name <- emitStruct (STPair (typeOf a) (typeOf b))
+ e1 <- compile' env a
+ e2 <- compile' env b
+ return $ CEStruct name [("a", e1), ("b", e2)]
+
+ EFst _ e -> do
+ let STPair _ t2 = typeOf e
+ e' <- compile' env e
+ case incrementVar "fst" Decrement t2 of
+ Nothing -> return $ CEProj e' "a"
+ Just f -> do var <- genName
+ emit $ SVarDecl True (repSTy (typeOf e)) var e'
+ f (var ++ ".b")
+ return $ CEProj (CELit var) "a"
+
+ ESnd _ e -> do
+ let STPair t1 _ = typeOf e
+ e' <- compile' env e
+ case incrementVar "snd" Decrement t1 of
+ Nothing -> return $ CEProj e' "b"
+ Just f -> do var <- genName
+ emit $ SVarDecl True (repSTy (typeOf e)) var e'
+ f (var ++ ".a")
+ return $ CEProj (CELit var) "b"
+
+ ENil _ -> do
+ name <- emitStruct STNil
+ return $ CEStruct name []
+
+ EInl _ t e -> do
+ name <- emitStruct (STEither (typeOf e) t)
+ e1 <- compile' env e
+ return $ CEStruct name [("tag", CELit "0"), ("l", e1)]
+
+ EInr _ t e -> do
+ name <- emitStruct (STEither t (typeOf e))
+ e2 <- compile' env e
+ return $ CEStruct name [("tag", CELit "1"), ("r", e2)]
+
+ ECase _ (EOp _ OIf e) a b -> do
+ e1 <- compile' env e
+ (e2, stmts2) <- scope $ compile' (Const undefined `SCons` env) a -- don't access that nil, stupid you
+ (e3, stmts3) <- scope $ compile' (Const undefined `SCons` env) b
+ retvar <- genName
+ emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
+ emit $ SIf e1
+ (stmts2 <> pure (SAsg retvar e2))
+ (stmts3 <> pure (SAsg retvar e3))
+ return (CELit retvar)
+
+ ECase _ e a b -> do
+ let STEither t1 t2 = typeOf e
+ e1 <- compile' env e
+ var <- genName
+ -- I know those are not variable names, but it's fine, probably
+ (e2, stmts2) <- scope $ compile' (Const (var ++ ".l") `SCons` env) a
+ (e3, stmts3) <- scope $ compile' (Const (var ++ ".r") `SCons` env) b
+ ((), stmtsRel1) <- scope $ incrementVarAlways "case1" Decrement t1 (var ++ ".l")
+ ((), stmtsRel2) <- scope $ incrementVarAlways "case2" Decrement t2 (var ++ ".r")
+ retvar <- genName
+ emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
+ emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1)
+ <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0"))
+ (stmts2
+ <> stmtsRel1
+ <> pure (SAsg retvar e2))
+ (stmts3
+ <> stmtsRel2
+ <> pure (SAsg retvar e3))))
+ return (CELit retvar)
+
+ ENothing _ t -> do
+ name <- emitStruct (STMaybe t)
+ return $ CEStruct name [("tag", CELit "0")]
+
+ EJust _ e -> do
+ name <- emitStruct (STMaybe (typeOf e))
+ e1 <- compile' env e
+ return $ CEStruct name [("tag", CELit "1"), ("j", e1)]
+
+ EMaybe _ a b e -> do
+ let STMaybe t = typeOf e
+ e1 <- compile' env e
+ var <- genName
+ (e2, stmts2) <- scope $ compile' env a
+ (e3, stmts3) <- scope $ compile' (Const (var ++ ".j") `SCons` env) b
+ ((), stmtsRel) <- scope $ incrementVarAlways "maybe" Decrement t (var ++ ".j")
+ retvar <- genName
+ emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
+ emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1)
+ <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0"))
+ (stmts2
+ <> pure (SAsg retvar e2))
+ (stmts3
+ <> stmtsRel
+ <> pure (SAsg retvar e3))))
+ return (CELit retvar)
+
+ ELNil _ t1 t2 -> do
+ name <- emitStruct (STLEither t1 t2)
+ return $ CEStruct name [("tag", CELit "0")]
+
+ ELInl _ t e -> do
+ name <- emitStruct (STLEither (typeOf e) t)
+ e1 <- compile' env e
+ return $ CEStruct name [("tag", CELit "1"), ("l", e1)]
+
+ ELInr _ t e -> do
+ name <- emitStruct (STLEither t (typeOf e))
+ e1 <- compile' env e
+ return $ CEStruct name [("tag", CELit "2"), ("r", e1)]
+
+ ELCase _ e a b c -> do
+ let STLEither t1 t2 = typeOf e
+ e1 <- compile' env e
+ var <- genName
+ (e2, stmts2) <- scope $ compile' env a
+ (e3, stmts3) <- scope $ compile' (Const (var ++ ".l") `SCons` env) b
+ (e4, stmts4) <- scope $ compile' (Const (var ++ ".r") `SCons` env) c
+ ((), stmtsRel1) <- scope $ incrementVarAlways "lcase1" Decrement t1 (var ++ ".l")
+ ((), stmtsRel2) <- scope $ incrementVarAlways "lcase2" Decrement t2 (var ++ ".r")
+ retvar <- genName
+ emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
+ emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1)
+ <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0"))
+ (stmts2 <> pure (SAsg retvar e2))
+ (pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "1"))
+ (stmts3 <> stmtsRel1 <> pure (SAsg retvar e3))
+ (stmts4 <> stmtsRel2 <> pure (SAsg retvar e4))))))
+ return (CELit retvar)
+
+ EConstArr _ n t (Array sh vec) -> do
+ (strname, bufstrname) <- emitArrStruct (STArr n (STScal t))
+ tldname <- genName' "carraybuf"
+ -- Give it a refcount of _half_ the size_t max, so that it can be
+ -- incremented and decremented at will and will "never" reach anything
+ -- where something happens
+ emitTLD $ "static " ++ bufstrname ++ " " ++ tldname ++ " = " ++
+ "(" ++ bufstrname ++ "){.refc = (size_t)1<<63, " ++
+ ".xs = {" ++ intercalate "," (map (compileScal False t) (toList vec)) ++ "}};"
+ return (CEStruct strname
+ [("buf", CEAddrOf (CELit tldname))
+ ,("sh", CELit ("{" ++ intercalate "," (map show (shapeToList sh)) ++ "}"))])
+
+ EBuild _ n esh efun -> do
+ shname <- compileAssign "sh" env esh
+
+ arrname <- allocArray "build" Malloc "arr" n (typeOf efun) Nothing (indexTupleComponents n shname)
+
+ idxargname <- genName' "ix"
+ (funretval, funstmts) <- scope $ compile' (Const idxargname `SCons` env) efun
+
+ linivar <- genName' "li"
+ ivars <- replicateM (fromSNat n) (genName' "i")
+ emit $ SBlock $
+ pure (SVarDecl False "size_t" linivar (CELit "0"))
+ <> compose [pure . SLoop (repSTy tIx) ivar (CELit "0")
+ (CECast (repSTy tIx) (CEIndex (CELit (arrname ++ ".sh")) (CELit (show dimidx))))
+ | (ivar, dimidx) <- zip ivars [0::Int ..]]
+ (pure (SVarDecl True (repSTy (typeOf esh)) idxargname
+ (shapeTupFromLitVars n ivars))
+ <> funstmts
+ <> pure (SAsg (arrname ++ ".buf->xs[" ++ linivar ++ "++]") funretval))
+
+ return (CELit arrname)
+
+ -- TODO: actually generate decent code here
+ EMap _ e1 e2 -> do
+ let STArr n _ = typeOf e2
+ compile' env $
+ elet e2 $
+ EBuild ext n (EShape ext (evar IZ)) $
+ elet (EIdx ext (evar (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink)) e1
+
+ EFold1Inner _ commut efun ex0 earr -> do
+ let STArr (SS n) t = typeOf earr
+
+ -- let vecwid = case commut of Commut -> 8 :: Int
+ -- Noncommut -> 1
+
+ x0name <- compileAssign "foldx0" env ex0
+ arrname <- compileAssign "foldarr" env earr
+
+ zeroRefcountCheck (typeOf earr) "fold1i" arrname
+
+ shszname <- genName' "shsz"
+ -- This n is one less than the shape of the thing we're querying, which is
+ -- unexpected. But it's exactly what we want, so we do it anyway.
+ emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n arrname)
+
+ resname <- allocArray "fold" Malloc "foldres" n t (Just (CELit shszname)) (compileArrShapeComponents n arrname)
+
+ lenname <- genName' "n"
+ emit $ SVarDecl True (repSTy tIx) lenname
+ (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
+
+ ((), x0incrStmts) <- scope $ incrementVarAlways "foldx0" Increment t x0name
+
+ ivar <- genName' "i"
+ jvar <- genName' "j"
+ -- kvar <- if vecwid > 1 then genName' "k" else return ""
+
+ accvar <- genName' "tot"
+ pairvar <- genName' "pair" -- function input
+ (funres, funStmts) <- scope $ compile' (Const pairvar `SCons` env) efun
+
+ let arreltlit = arrname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++
+ ({- if vecwid > 1 then show vecwid ++ " * " ++ jvar ++ " + " ++ kvar else -} jvar) ++ "]"
+ ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldelt" Increment t arreltlit
+
+ pairstrname <- emitStruct (STPair t t)
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $
+ pure (SVarDecl False (repSTy t) accvar (CELit x0name))
+ <> x0incrStmts -- we're copying x0 here
+ <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
+ -- The combination function will consume the array element
+ -- and the accumulator. The accumulator is replaced by
+ -- what comes out of the function anyway, so that's
+ -- fine, but we do need to increment the array element.
+ arreltIncrStmts
+ <> pure (SVarDecl True pairstrname pairvar (CEStruct pairstrname [("a", CELit accvar), ("b", CELit arreltlit)]))
+ <> funStmts
+ <> pure (SAsg accvar funres))
+ <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar))
+
+ incrementVarAlways "foldx0" Decrement t x0name
+ incrementVarAlways "foldarr" Decrement (typeOf earr) arrname
+
+ return (CELit resname)
+
+ ESum1Inner _ e -> do
+ let STArr (SS n) t = typeOf e
+ argname <- compileAssign "sumarg" env e
+
+ zeroRefcountCheck (typeOf e) "sum1i" argname
+
+ shszname <- genName' "shsz"
+ -- This n is one less than the shape of the thing we're querying, like EFold1Inner.
+ emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
+
+ resname <- allocArray "sum" Malloc "sumres" n t (Just (CELit shszname)) (compileArrShapeComponents n argname)
+
+ lenname <- genName' "n"
+ emit $ SVarDecl True (repSTy tIx) lenname
+ (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
+
+ let vecwid = 8 :: Int
+ ivar <- genName' "i"
+ jvar <- genName' "j"
+ kvar <- genName' "k"
+ accvar <- genName' "tot"
+ let nchunks = CEBinop (CELit lenname) "/" (CELit (show vecwid))
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ BList
+ -- we have ScalIsNumeric, so it has 0 and (+) in C
+ [SVerbatim $ repSTy t ++ " " ++ accvar ++ "[" ++ show vecwid ++ "] = {" ++ intercalate "," (replicate vecwid "0") ++ "};"
+ ,SLoop (repSTy tIx) jvar (CELit "0") nchunks $
+ pure $ SLoop (repSTy tIx) kvar (CELit "0") (CELit (show vecwid)) $
+ pure $ SVerbatim $ accvar ++ "[" ++ kvar ++ "] += " ++ argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ show vecwid ++ " * " ++ jvar ++ " + " ++ kvar ++ "];"
+ ,SLoop (repSTy tIx) kvar (CELit "1") (CELit (show vecwid)) $
+ pure $ SVerbatim $ accvar ++ "[0] += " ++ accvar ++ "[" ++ kvar ++ "];"
+ ,SLoop (repSTy tIx) kvar (CEBinop nchunks "*" (CELit (show vecwid))) (CELit lenname) $
+ pure $ SVerbatim $ accvar ++ "[0] += " ++ argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ kvar ++ "];"
+ ,SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit (accvar++"[0]"))]
+
+ incrementVarAlways "sum" Decrement (typeOf e) argname
+
+ return (CELit resname)
+
+ EUnit _ e -> do
+ e' <- compile' env e
+ let typ = STArr SZ (typeOf e)
+ strname <- emitStruct typ
+ name <- genName
+ emit $ SVarDecl True strname name (CEStruct strname
+ [("buf", CECall "malloc_instr" [CELit (show (8 + sizeofSTy (typeOf e)))])])
+ emit $ SAsg (name ++ ".buf->refc") (CELit "1")
+ emit $ SAsg (name ++ ".buf->xs[0]") e'
+ return (CELit name)
+
+ EReplicate1Inner _ elen earg -> do
+ let STArr n t = typeOf earg
+ lenname <- compileAssign "replen" env elen
+ argname <- compileAssign "reparg" env earg
+
+ zeroRefcountCheck (typeOf earg) "replicate1i" argname
+
+ shszname <- genName' "shsz"
+ emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
+
+ resname <- allocArray "repl1i" Malloc "rep" (SS n) t
+ (Just (CEBinop (CELit shszname) "*" (CELit lenname)))
+ (compileArrShapeComponents n argname ++ [CELit lenname])
+
+ ivar <- genName' "i"
+ jvar <- genName' "j"
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $
+ pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
+ pure $ SAsg (resname ++ ".buf->xs[" ++ ivar ++ " * " ++ lenname ++ " + " ++ jvar ++ "]")
+ (CELit (argname ++ ".buf->xs[" ++ ivar ++ "]"))
+
+ incrementVarAlways "repl1i" Decrement (typeOf earg) argname
+
+ return (CELit resname)
+
+ EMaximum1Inner _ e -> compileExtremum "max" "maximum1i" ">" env e
+
+ EMinimum1Inner _ e -> compileExtremum "min" "minimum1i" "<" env e
+
+ EReshape _ dim esh earg -> do
+ let STArr origDim eltty = typeOf earg
+ strname <- emitStruct (STArr dim eltty)
+
+ shname <- compileAssign "reshsh" env esh
+ arrname <- compileAssign "resharg" env earg
+
+ when emitChecks $ do
+ emit $ SIf (CEBinop (compileArrShapeSize origDim arrname) "!=" (CECast "size_t" (prodExpr (indexTupleComponents dim shname))))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: reshape on unequal sizes (%zu <- %zu)\\n\", " ++
+ printCExpr 0 (prodExpr (indexTupleComponents dim shname)) ", " ++
+ printCExpr 0 (compileArrShapeSize origDim arrname) "); return false;")
+ mempty
+
+ return (CEStruct strname
+ [("buf", CEProj (CELit arrname) "buf")
+ ,("sh", CELit ("{" ++ intercalate ", " [printCExpr 0 e "" | e <- indexTupleComponents dim shname] ++ "}"))])
+
+ -- TODO: actually generate decent code here
+ EZip _ e1 e2 -> do
+ let STArr n _ = typeOf e1
+ compile' env $
+ elet e1 $
+ elet (weakenExpr WSink e2) $
+ EBuild ext n (EShape ext (evar (IS IZ))) $
+ EPair ext (EIdx ext (evar (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) IZ))
+ (EIdx ext (evar (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ))
+
+ EFold1InnerD1 _ commut efun ex0 earr -> do
+ let STArr (SS n) t = typeOf earr
+ STPair _ bty = typeOf efun
+
+ x0name <- compileAssign "foldd1x0" env ex0
+ arrname <- compileAssign "foldd1arr" env earr
+
+ zeroRefcountCheck (typeOf earr) "fold1iD1" arrname
+
+ lenname <- genName' "n"
+ emit $ SVarDecl True (repSTy tIx) lenname
+ (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
+
+ shsz1name <- genName' "shszN"
+ emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n arrname) -- take init of arr's shape
+ shsz2name <- genName' "shszSN"
+ emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname))
+
+ resname <- allocArray "foldd1" Malloc "foldd1res" n t (Just (CELit shsz1name)) (compileArrShapeComponents n arrname)
+ storesname <- allocArray "foldd1" Malloc "foldd1stores" (SS n) bty (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) arrname)
+
+ ((), x0incrStmts) <- scope $ incrementVarAlways "foldd1x0" Increment t x0name
+
+ ivar <- genName' "i"
+ jvar <- genName' "j"
+
+ accvar <- genName' "tot"
+ pairvar <- genName' "pair" -- function input
+ (funres, funStmts) <- scope $ compile' (Const pairvar `SCons` env) efun
+ let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ jvar
+ arreltlit = arrname ++ ".buf->xs[" ++ eltidx ++ "]"
+ funresvar <- genName' "res"
+ ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldd1elt" Increment t arreltlit
+
+ pairstrname <- emitStruct (STPair t t)
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $
+ pure (SVarDecl False (repSTy t) accvar (CELit x0name))
+ <> x0incrStmts -- we're copying x0 here
+ <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
+ -- The combination function will consume the array element
+ -- and the accumulator. The accumulator is replaced by
+ -- what comes out of the function anyway, so that's
+ -- fine, but we do need to increment the array element.
+ arreltIncrStmts
+ <> pure (SVarDecl True pairstrname pairvar (CEStruct pairstrname [("a", CELit accvar), ("b", CELit arreltlit)]))
+ <> funStmts
+ <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres)
+ <> pure (SAsg accvar (CEProj (CELit funresvar) "a"))
+ <> pure (SAsg (storesname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b")))
+ <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar))
+
+ incrementVarAlways "foldd1x0" Decrement t x0name
+ incrementVarAlways "foldd1arr" Decrement (typeOf earr) arrname
+
+ strname <- emitStruct (STPair (STArr n t) (STArr (SS n) bty))
+ return (CEStruct strname [("a", CELit resname), ("b", CELit storesname)])
+
+ EFold1InnerD2 _ commut efun estores ectg -> do
+ let STArr n t2 = typeOf ectg
+ STArr _ bty = typeOf estores
+
+ storesname <- compileAssign "foldd2stores" env estores
+ ctgname <- compileAssign "foldd2ctg" env ectg
+
+ zeroRefcountCheck (typeOf ectg) "fold1iD2" ctgname
+
+ lenname <- genName' "n"
+ emit $ SVarDecl True (repSTy tIx) lenname
+ (CELit (storesname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
+
+ shsz1name <- genName' "shszN"
+ emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n storesname) -- take init of the shape
+ shsz2name <- genName' "shszSN"
+ emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname))
+
+ x0ctgname <- allocArray "foldd2" Malloc "foldd2x0ctg" n t2 (Just (CELit shsz1name)) (compileArrShapeComponents n storesname)
+ outctgname <- allocArray "foldd2" Malloc "foldd2outctg" (SS n) t2 (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) storesname)
+
+ ivar <- genName' "i"
+ jvar <- genName' "j"
+
+ accvar <- genName' "acc"
+ let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ lenname ++ "-1 - " ++ jvar
+ storeseltlit = storesname ++ ".buf->xs[" ++ eltidx ++ "]"
+ ctgeltlit = ctgname ++ ".buf->xs[" ++ ivar ++ "]"
+ (funres, funStmts) <- scope $ compile' (Const accvar `SCons` Const storeseltlit `SCons` env) efun
+ funresvar <- genName' "res"
+ ((), storeseltIncrStmts) <- scope $ incrementVarAlways "foldd2selt" Increment bty storeseltlit
+ ((), ctgeltIncrStmts) <- scope $ incrementVarAlways "foldd2celt" Increment bty ctgeltlit
+
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $
+ pure (SVarDecl False (repSTy t2) accvar (CELit ctgeltlit))
+ <> ctgeltIncrStmts
+ -- we need to loop in reverse here, but we let jvar run in the
+ -- forward direction so that we can use SLoop. Note jvar is
+ -- reversed in eltidx above
+ <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
+ -- The combination function will consume the accumulator
+ -- and the stores element. The accumulator is replaced by
+ -- what comes out of the function anyway, so that's
+ -- fine, but we do need to increment the stores element.
+ storeseltIncrStmts
+ <> funStmts
+ <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres)
+ <> pure (SAsg accvar (CEProj (CELit funresvar) "a"))
+ <> pure (SAsg (outctgname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b")))
+ <> pure (SAsg (x0ctgname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar))
+
+ incrementVarAlways "foldd2stores" Decrement (STArr (SS n) bty) storesname
+ incrementVarAlways "foldd2ctg" Decrement (STArr n t2) ctgname
+
+ strname <- emitStruct (STPair (STArr n t2) (STArr (SS n) t2))
+ return (CEStruct strname [("a", CELit x0ctgname), ("b", CELit outctgname)])
+
+ EConst _ t x -> return $ CELit $ compileScal True t x
+
+ EIdx0 _ e -> do
+ let STArr _ t = typeOf e
+ arrname <- compileAssign "" env e
+ zeroRefcountCheck (typeOf e) "idx0" arrname
+ name <- genName
+ emit $ SVarDecl True (repSTy t) name
+ (CEIndex (CEPtrProj (CEProj (CELit arrname) "buf") "xs") (CELit "0"))
+ incrementVarAlways "idx0" Decrement (STArr SZ t) arrname
+ return (CELit name)
+
+ -- EIdx1 _ a b -> error "TODO" -- EIdx1 ext (compile' a) (compile' b)
+
+ EIdx _ earr eidx -> do
+ let STArr n t = typeOf earr
+ arrname <- compileAssign "ixarr" env earr
+ zeroRefcountCheck (typeOf earr) "idx" arrname
+ idxname <- if fromSNat n > 0 -- prevent an unused-varable warning
+ then compileAssign "ixix" env eidx
+ else return "" -- won't be used in this case
+
+ when emitChecks $
+ forM_ (zip [0::Int ..] (indexTupleComponents n idxname)) $ \(i, ixcomp) ->
+ emit $ SIf (CEBinop (CEBinop ixcomp "<" (CELit "0")) "||"
+ (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".sh[" ++ show i ++ "]")))))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: index out of range (arr=%p)\\n\", " ++
+ arrname ++ ".buf); return false;")
+ mempty
+
+ resname <- genName' "ixres"
+ emit $ SVarDecl True (repSTy t) resname (CEIndex (CELit (arrname ++ ".buf->xs")) (toLinearIdx n arrname idxname))
+ incrementVarAlways "idxelt" Increment t resname
+ incrementVarAlways "idx" Decrement (STArr n t) arrname
+ return (CELit resname)
+
+ EShape _ e -> do
+ let STArr n _ = typeOf e
+ t = tTup (sreplicate n tIx)
+ _ <- emitStruct t
+ name <- compileAssign "" env e
+ zeroRefcountCheck (typeOf e) "shape" name
+ resname <- genName
+ emit $ SVarDecl True (repSTy t) resname (compileShapeQuery n name)
+ incrementVarAlways "shape" Decrement (typeOf e) name
+ return (CELit resname)
+
+ EOp _ op (EPair _ e1 e2) -> do
+ e1' <- compile' env e1
+ e2' <- compile' env e2
+ compileOpPair op e1' e2'
+
+ EOp _ op e -> do
+ e' <- compile' env e
+ compileOpGeneral op e'
+
+ ECustom _ _ _ _ earg _ _ e1 e2 -> do
+ name1 <- compileAssign "" env e1
+ name2 <- compileAssign "" env e2
+ case (incrementVar "custom1" Decrement (typeOf e1), incrementVar "custom2" Decrement (typeOf e2)) of
+ (Nothing, Nothing) -> compile' (Const name2 `SCons` Const name1 `SCons` SNil) earg
+ (mfun1, mfun2) -> do
+ name <- compileAssign "" (Const name2 `SCons` Const name1 `SCons` SNil) earg
+ maybe (return ()) ($ name1) mfun1
+ maybe (return ()) ($ name2) mfun2
+ return (CELit name)
+
+ ERecompute _ e -> compile' env e
+
+ EWith _ t e1 e2 -> do
+ actyname <- emitStruct (STAccum t)
+ name1 <- compileAssign "" env e1
+
+ zeroRefcountCheck (typeOf e1) "with" name1
+
+ emit $ SVerbatim $ "// copyForWriting start (" ++ name1 ++ ")"
+ mcopy <- copyForWriting t name1
+ accname <- genName' "accum"
+ emit $ SVarDecl False actyname accname
+ (CEStruct actyname [("buf", CECall "malloc_instr" [CELit (show (sizeofSTy (fromSMTy t)))])])
+ emit $ SAsg (accname++".buf->ac") (maybe (CELit name1) id mcopy)
+ emit $ SVerbatim $ "// initial accumulator constructed (" ++ name1 ++ ")."
+
+ e2' <- compile' (Const accname `SCons` env) e2
+
+ resname <- genName' "acret"
+ emit $ SVarDecl True (repSTy (fromSMTy t)) resname (CELit (accname++".buf->ac"))
+ emit $ SVerbatim $ "free_instr(" ++ accname ++ ".buf);"
+
+ rettyname <- emitStruct (STPair (typeOf e2) (fromSMTy t))
+ return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)]
+
+ EAccum _ t prj eidx sparsity eval eacc | Just Refl <- isDense (acPrjTy prj t) sparsity -> do
+ let -- Add a value (s) into an existing accumulation value (d). If a sparse
+ -- component of d is encountered, s is copied there.
+ add :: SMTy a -> String -> String -> CompM ()
+ add SMTNil _ _ = return ()
+ add (SMTPair t1 t2) d s = do
+ add t1 (d++".a") (s++".a")
+ add t2 (d++".b") (s++".b")
+ add (SMTLEither t1 t2) d s = do
+ ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTLEither t1 t2)) s
+ ((), stmts1) <- scope $ add t1 (d++".l") (s++".l")
+ ((), stmts2) <- scope $ add t2 (d++".r") (s++".r")
+ emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0"))
+ (pure (SAsg d (CELit s))
+ <> srcIncrStmts)
+ ((if emitChecks
+ then pure (SIf (CEBinop (CEBinop (CELit (s++".tag")) "!=" (CELit "0"))
+ "&&"
+ (CEBinop (CELit (s++".tag")) "!=" (CELit (d++".tag"))))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum add leither with different tags " ++
+ "(dest %d, src %d)\\n\", (int)" ++ d ++ ".tag, (int)" ++ s ++ ".tag); " ++
+ "return false;")
+ mempty)
+ else mempty)
+ -- note: s may have tag 0
+ <> pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1"))
+ stmts1
+ (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "2"))
+ stmts2 mempty))))
+ add (SMTMaybe t1) d s = do
+ ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTMaybe t1)) s
+ ((), stmts1) <- scope $ add t1 (d++".j") (s++".j")
+ emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0"))
+ (pure (SAsg d (CELit s))
+ <> srcIncrStmts)
+ (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) stmts1 mempty))
+ add (SMTArr n t1) d s = do
+ when emitChecks $ do
+ let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
+ forM_ [0 .. fromSNat n - 1] $ \j -> do
+ emit $ SIf (CEBinop (CELit (s ++ ".sh[" ++ show j ++ "]"))
+ "!="
+ (CELit (d ++ ".sh[" ++ show j ++ "]")))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum add incorrect (d=%p, " ++
+ "dsh=" ++ shfmt ++ ", s=%p, ssh=" ++ shfmt ++ ")\\n\", " ++
+ d ++ ".buf" ++
+ concat [", " ++ d ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
+ ", " ++ s ++ ".buf" ++
+ concat [", " ++ s ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
+ "); " ++
+ "return false;")
+ mempty
+
+ shsizename <- genName' "acshsz"
+ emit $ SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n s)
+ ivar <- genName' "i"
+ ((), stmts1) <- scope $ add t1 (d++".buf->xs["++ivar++"]") (s++".buf->xs["++ivar++"]")
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename)
+ stmts1
+ add (SMTScal _) d s = emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
+
+ let -- | Dereference an accumulation value and add a given value to that
+ -- position. Sparse components encountered along the way are
+ -- initialised before proceeding downwards.
+ -- accumRef (type) (projection) (accumulation component) (AcIdx variable) (value to accumulate there)
+ accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> String -> CompM ()
+ accumRef _ SAPHere v _ addend = add (acPrjTy prj t) v addend
+
+ accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") i addend
+ accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") i addend
+
+ accumRef (SMTLEither ta _) (SAPLeft prj') v i addend = do
+ when emitChecks $ do
+ emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1"))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +left)\\n\", " ++ v ++ ".tag); " ++
+ "return false;")
+ mempty
+ accumRef ta prj' (v++".l") i addend
+ accumRef (SMTLEither _ tb) (SAPRight prj') v i addend = do
+ when emitChecks $ do
+ emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "2"))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +right)\\n\", " ++ v ++ ".tag); " ++
+ "return false;")
+ mempty
+ accumRef tb prj' (v++".r") i addend
+
+ accumRef (SMTMaybe tj) (SAPJust prj') v i addend = do
+ when emitChecks $ do
+ emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1"))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (maybe tag=%d, +just)\\n\", " ++ v ++ ".tag); " ++
+ "return false;")
+ mempty
+ accumRef tj prj' (v++".j") i addend
+
+ accumRef (SMTArr n t') (SAPArrIdx prj') v i addend = do
+ when emitChecks $ do
+ let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
+ forM_ (zip [0::Int ..]
+ (indexTupleComponents n (i++".a"))) $ \(j, ixcomp) -> do
+ let a .||. b = CEBinop a "||" b
+ emit $ SIf (CEBinop ixcomp "<" (CELit "0")
+ .||.
+ CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".sh[" ++ show j ++ "]"))))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++
+ "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=(D))\\n\", " ++
+ v ++ ".buf" ++
+ concat [", " ++ v ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
+ concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a")] ++
+ "); " ++
+ "return false;")
+ mempty
+
+ accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a")) "]") (i++".b") addend
+
+ nameidx <- compileAssign "acidx" env eidx
+ nameval <- compileAssign "acval" env eval
+ nameacc <- compileAssign "acac" env eacc
+
+ emit $ SVerbatim $ "// compile EAccum start (" ++ show prj ++ ")"
+ accumRef t prj (nameacc++".buf->ac") nameidx nameval
+ emit $ SVerbatim $ "// compile EAccum end"
+
+ incrementVarAlways "accumendsrc" Decrement (typeOf eval) nameval
+
+ return $ CEStruct (repSTy STNil) []
+
+ EAccum{} ->
+ error "Compile: EAccum with non-trivial sparsity should have been eliminated (use AST.UnMonoid)"
+
+ EError _ t s -> do
+ let padleft len c s' = replicate (len - length s) c ++ s'
+ escape = concatMap $ \c -> if | c `elem` "\"\\" -> ['\\',c]
+ | ord c < 32 -> "\\x" ++ padleft 2 '0' (showHex (ord c) "")
+ | otherwise -> [c]
+ emit $ SVerbatim $ "fputs(\"ERROR: " ++ escape s ++ "\\n\", stderr); return false;"
+ case t of
+ STScal _ -> return (CELit "0")
+ _ -> do
+ name <- emitStruct t
+ return $ CEStruct name []
+
+ EZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
+ EDeepZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
+ EPlus{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
+ EOneHot{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
+
+ EIdx1{} -> error "Compile: not implemented: EIdx1"
+
+-- | Compile an expression and return a C literal string that refers to its
+-- value. If the compiled expression already is a 'CELit', its text is returned
+-- unchanged; otherwise the expression is bound to a freshly generated variable
+-- (named after @prefix@) and that variable's name is returned.
+compileAssign :: String -> SList (Const String) env -> Ex env t -> CompM String
+compileAssign prefix env e =
+  compile' env e >>= \case
+    CELit name -> pure name
+    compiled -> do
+      fresh <- genName' prefix
+      emit $ SVarDecl True (repSTy (typeOf e)) fresh compiled
+      pure fresh
+
+-- | Direction in which a reference count is adjusted.
+data Increment
+  = Increment
+  | Decrement
+  deriving stock (Show)
+
+-- | Adjust (increment or decrement) the reference counts in the components of
+-- a variable of the given type. The result is 'Nothing' when 'makeArrayTree'
+-- finds nothing to do for this type; otherwise it is an action that takes the
+-- name of the variable to operate on.
+incrementVar :: String -> Increment -> STy a -> Maybe (String -> CompM ())
+incrementVar marker inc ty = case makeArrayTree ty of
+  ATNoop -> Nothing
+  tree -> Just (\var -> incrementVar' marker inc var tree)
+
+-- | Like 'incrementVar', but already applied to the variable name; does
+-- nothing at all when 'incrementVar' has no work for this type.
+incrementVarAlways :: String -> Increment -> STy a -> String -> CompM ()
+incrementVarAlways marker inc ty var =
+  case incrementVar marker inc ty of
+    Nothing -> pure ()
+    Just adjust -> adjust var
+
+-- | A description of where, inside a value of some type, the arrays sit whose
+-- reference counts have to be adjusted. Built by 'makeArrayTree' and
+-- interpreted by @incrementVar'@.
+data ArrayTree
+  = ATArray (Some SNat) (Some STy)
+    -- ^ we've arrived at an array we need to adjust the refcount of (contains
+    -- rank and element type of the array)
+  | ATNoop
+    -- ^ don't do anything here
+  | ATProj String ArrayTree
+    -- ^ descend one field deeper
+  | ATCondTag ArrayTree ArrayTree
+    -- ^ if tag is 0, first; if 1, second
+  | ATCond3Tag ArrayTree ArrayTree ArrayTree
+    -- ^ if tag is: 0, 1, 2
+  | ATBoth ArrayTree ArrayTree
+    -- ^ do both these paths
+
+-- | 'ATProj', except that projecting into a no-op subtree is itself a no-op.
+smartATProj :: String -> ArrayTree -> ArrayTree
+smartATProj field = \case
+  ATNoop -> ATNoop
+  sub -> ATProj field sub
+
+-- | 'ATCondTag', collapsed to 'ATNoop' when neither branch does anything.
+smartATCondTag :: ArrayTree -> ArrayTree -> ArrayTree
+smartATCondTag onTag0 onTag1 = case (onTag0, onTag1) of
+  (ATNoop, ATNoop) -> ATNoop
+  _ -> ATCondTag onTag0 onTag1
+
+-- | 'ATCond3Tag', collapsed to 'ATNoop' when none of the three branches does
+-- anything.
+smartATCond3Tag :: ArrayTree -> ArrayTree -> ArrayTree -> ArrayTree
+smartATCond3Tag onTag0 onTag1 onTag2 = case (onTag0, onTag1, onTag2) of
+  (ATNoop, ATNoop, ATNoop) -> ATNoop
+  _ -> ATCond3Tag onTag0 onTag1 onTag2
+
+-- | 'ATBoth', except that a no-op on either side is dropped and only the other
+-- side remains.
+smartATBoth :: ArrayTree -> ArrayTree -> ArrayTree
+smartATBoth lhs rhs = case (lhs, rhs) of
+  (ATNoop, _) -> rhs
+  (_, ATNoop) -> lhs
+  _ -> ATBoth lhs rhs
+
+-- | Compute the 'ArrayTree' for a type: the struct fields to descend into
+-- (@a@/@b@ for pairs, @l@/@r@ for eithers, @j@ for maybes) and the tag checks
+-- to perform in order to reach every array inside a value of that type.
+-- Scalars, nil and accumulators contribute nothing.
+makeArrayTree :: STy a -> ArrayTree
+makeArrayTree = \case
+  STNil -> ATNoop
+  STPair a b -> smartATBoth (under "a" a) (under "b" b)
+  STEither a b -> smartATCondTag (under "l" a) (under "r" b)
+  -- tag 0 of an LEither carries no payload, hence the no-op first branch
+  STLEither a b -> smartATCond3Tag ATNoop (under "l" a) (under "r" b)
+  -- likewise for tag 0 of a Maybe
+  STMaybe t -> smartATCondTag ATNoop (under "j" t)
+  STArr n t -> ATArray (Some n) (Some t)
+  STScal _ -> ATNoop
+  STAccum _ -> ATNoop
+  where
+    -- The tree of a component type, reached through the given struct field.
+    under :: String -> STy b -> ArrayTree
+    under field = smartATProj field . makeArrayTree
+
+-- | Emit the statements that walk an 'ArrayTree' rooted at the C expression
+-- @path@ and adjust the reference count of every array encountered. The
+-- @marker@ string only shows up in the diagnostics printed when 'debugRefc'
+-- is set; it is extended on the way down to reflect the position in the tree.
+--
+-- On 'Increment' the counter is simply bumped. On 'Decrement' the buffer is
+-- released with @free_instr@ once the counter reaches zero; if the element
+-- type itself has reference counts to adjust, all elements are decremented
+-- before the buffer is released.
+incrementVar' :: String -> Increment -> String -> ArrayTree -> CompM ()
+incrementVar' marker inc path = \case
+  ATArray (Some n) (Some eltty) -> case inc of
+    Increment -> do
+      emit $ SVerbatim (refc ++ "++;")
+      when debugRefc $
+        emit $ SVerbatim $
+          "fprintf(stderr, PRTAG \"arr %p in+ -> %zu <" ++ marker ++ ">\\n\", "
+            ++ bufptr ++ ", " ++ refc ++ ");"
+    Decrement -> case incrementVar (marker ++ ".elt") Decrement eltty of
+      -- Elements hold no references themselves: just drop the buffer at zero.
+      Nothing
+        | debugRefc -> do
+            emit $ SVerbatim $
+              "fprintf(stderr, PRTAG \"arr %p de- -> %zu <" ++ marker ++ ">\", "
+                ++ bufptr ++ ", " ++ refc ++ " - 1);"
+            emit $ SVerbatim $
+              "if (--" ++ refc ++ " == 0) { fprintf(stderr, \"; free(\"); free_instr("
+                ++ bufptr ++ "); fprintf(stderr, \") ok\\n\"); } else fprintf(stderr, \"\\n\");"
+        | otherwise ->
+            emit $ SVerbatim $
+              "if (--" ++ refc ++ " == 0) free_instr(" ++ bufptr ++ ");"
+      -- Elements need decrementing too: loop over them before freeing.
+      Just decrElt -> do
+        when debugRefc $
+          emit $ SVerbatim $
+            "fprintf(stderr, PRTAG \"arr %p de- -> %zu <" ++ marker ++ "> recfree\\n\", "
+              ++ bufptr ++ ", " ++ refc ++ " - 1);"
+        shszvar <- genName' "frshsz"
+        ivar <- genName' "i"
+        ((), eltDecrStmts) <- scope $ decrElt (path ++ ".buf->xs[" ++ ivar ++ "]")
+        emit $ SIf (CELit ("--" ++ refc ++ " == 0"))
+          (BList
+            [ SVarDecl True "size_t" shszvar (compileArrShapeSize n path)
+            , SLoop "size_t" ivar (CELit "0") (CELit shszvar) eltDecrStmts
+            , SVerbatim $ "free_instr(" ++ bufptr ++ ");"
+            ])
+          mempty
+  ATNoop -> pure ()
+  ATProj field sub ->
+    incrementVar' (marker ++ "." ++ field) inc (path ++ "." ++ field) sub
+  ATCondTag t1 t2 -> do
+    ((), stmts1) <- scope $ incrementVar' (marker ++ ".t1") inc path t1
+    ((), stmts2) <- scope $ incrementVar' (marker ++ ".t2") inc path t2
+    emit $ SIf (tagIs "0") stmts1 stmts2
+  ATCond3Tag t1 t2 t3 -> do
+    ((), stmts1) <- scope $ incrementVar' (marker ++ ".t1") inc path t1
+    ((), stmts2) <- scope $ incrementVar' (marker ++ ".t2") inc path t2
+    ((), stmts3) <- scope $ incrementVar' (marker ++ ".t3") inc path t3
+    -- tag 1 -> second tree, tag 2 -> third tree, anything else -> first tree
+    emit $ SIf (tagIs "1") stmts2 (pure (SIf (tagIs "2") stmts3 stmts1))
+  ATBoth t1 t2 -> do
+    incrementVar' (marker ++ ".1") inc path t1
+    incrementVar' (marker ++ ".2") inc path t2
+  where
+    -- C expression text for the buffer pointer and the refcount of the array at @path@
+    bufptr = path ++ ".buf"
+    refc = path ++ ".buf->refc"
+    -- C test whether the tag field at @path@ has the given value
+    tagIs k = CEBinop (CELit (path ++ ".tag")) "==" (CELit k)
+
+-- | Build the C expression for the linear (flat) offset into an array's
+-- buffer, given the rank, the name of the array variable and the name of the
+-- variable holding the nested index tuple (the @.b@ field is the last
+-- component, @.a@ holds the preceding ones).
+toLinearIdx :: SNat n -> String -> String -> CExpr
+toLinearIdx rank arrvar idxvar = case rank of
+  SZ -> CELit "0"
+  -- a single dimension needs no multiplication
+  SS SZ -> lastComponent
+  SS n ->
+    let prefixOffset = toLinearIdx n arrvar (idxvar ++ ".a")
+        extent = CEIndex (CELit (arrvar ++ ".sh")) (CELit (show (fromSNat n)))
+    in CEBinop (CEBinop prefixOffset "*" extent) "+" lastComponent
+  where
+    lastComponent = CELit (idxvar ++ ".b")
+
+-- fromLinearIdx :: SNat n -> String -> String -> CompM CExpr
+-- fromLinearIdx SZ _ _ = return $ CEStruct (repSTy STNil) []
+-- fromLinearIdx (SS n) arrvar idxvar = do
+-- name <- genName
+-- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".sh[" ++ show (fromSNat n) ++ "]")))
+-- _
+
+-- | Which allocation routine 'allocArray' calls for the array buffer:
+-- @malloc_instr@ or @calloc_instr@.
+data AllocMethod
+  = Malloc
+  | Calloc
+  deriving stock (Show)
+
+-- | Emit code that allocates a fresh array with reference count 1 and return
+-- the name of the new array variable.
+--
+-- The shape must have the outer dimension at the head (and the inner
+-- dimension on the right). Its length must equal the rank, otherwise this
+-- function calls 'error'. If the optional element-count expression is absent,
+-- the product of the shape components is used instead. The @marker@ string is
+-- only used in the 'debugRefc' diagnostic.
+allocArray :: HasCallStack => String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String
+allocArray marker method nameBase rank eltty mshsz shape = do
+  when (length shape /= fromSNat rank) $
+    error "allocArray: shape does not match rank"
+  strname <- emitStruct (STArr rank eltty)
+  arrname <- genName' nameBase
+  let numElts = fromMaybe (prodExpr shape) mshsz
+      baseBytes = CELit (show (fromSNat rank * 8 + 8))
+      payloadBytes = CEBinop numElts "*" (CELit (show (sizeofSTy eltty)))
+      nbytesExpr = CEBinop baseBytes "+" payloadBytes
+      allocator = case method of
+        Malloc -> "malloc_instr"
+        Calloc -> "calloc_instr"
+      shapeInit = "{" ++ intercalate "," (map (\dim -> printCExpr 0 dim "") shape) ++ "}"
+  emit $ SVarDecl True strname arrname $
+    CEStruct strname
+      [ ("buf", CECall allocator [nbytesExpr])
+      , ("sh", CELit shapeInit)
+      ]
+  emit $ SAsg (arrname ++ ".buf->refc") (CELit "1")
+  when debugRefc $
+    emit $ SVerbatim $
+      "fprintf(stderr, PRTAG \"arr %p allocated <" ++ marker ++ ">\\n\", " ++ arrname ++ ".buf);"
+  pure arrname
+
+-- | Build the C expression for the shape of the array variable @var@ as a
+-- nested index tuple: field @b@ holds the last dimension, field @a@ the tuple
+-- of the preceding ones.
+compileShapeQuery :: SNat n -> String -> CExpr
+compileShapeQuery rank var = case rank of
+  SZ -> CEStruct (repSTy STNil) []
+  SS n ->
+    CEStruct
+      (repSTy (tTup (sreplicate (SS n) tIx)))
+      [ ("a", compileShapeQuery n var)
+      , ("b", CEIndex (CELit (var ++ ".sh")) (CELit (show (fromSNat n))))
+      ]
+
+-- | C expression for the product of the first @n@ shape components of an
+-- array. Takes a variable name for the array, not the buffer.
+compileArrShapeSize :: SNat n -> String -> CExpr
+compileArrShapeSize n = prodExpr . compileArrShapeComponents n
+
+-- | C expressions for the first @n@ shape components (@sh[0]@ up to
+-- @sh[n-1]@) of an array. Takes a variable name for the array, not the buffer.
+compileArrShapeComponents :: SNat n -> String -> [CExpr]
+compileArrShapeComponents n var = map component [0 .. fromSNat n - 1]
+  where
+    component i = CELit (var ++ ".sh[" ++ show i ++ "]")
+
+-- | List the components of the nested index tuple stored in the given
+-- variable as C expressions. The @.b@ field of the variable itself comes last
+-- in the list; components found deeper in the @.a@ chain come earlier.
+indexTupleComponents :: SNat n -> String -> [CExpr]
+indexTupleComponents rank tupvar = collect rank tupvar []
+  where
+    -- Walks down the @.a@ chain, prepending the @.b@ field at every level so
+    -- that the deepest component ends up at the head of the list.
+    collect :: SNat m -> String -> [CExpr] -> [CExpr]
+    collect SZ _ acc = acc
+    collect (SS m) var acc = collect m (var ++ ".a") (CELit (var ++ ".b") : acc)
+
+-- | Build the C expression for a nested index tuple out of individual
+-- variable names. Takes variable names with the innermost dimension on the
+-- right. Calls 'error' if the number of names does not match the rank.
+shapeTupFromLitVars :: SNat n -> [String] -> CExpr
+shapeTupFromLitVars rank vars = build rank (reverse vars)
+  where
+    -- takes variables with the innermost dimension at the _head_
+    build :: SNat m -> [String] -> CExpr
+    build r names = case (r, names) of
+      (SZ, []) -> CEStruct (repSTy STNil) []
+      (SS m, name : rest) ->
+        CEStruct
+          (repSTy (tTup (sreplicate (SS m) tIx)))
+          [("a", build m rest), ("b", CELit name)]
+      _ -> error "shapeTupFromLitVars: SNat and list do not correspond"
+
+-- | C expression for the product of the given expressions, folded with
+-- @foldl0'@; the literal @1@ is the value supplied for the empty list.
+prodExpr :: [CExpr] -> CExpr
+prodExpr factors = foldl0' times (CELit "1") factors
+  where
+    times lhs rhs = CEBinop lhs "*" rhs
+
+-- | Compile a primitive operation applied to a single, already compiled
+-- argument expression.
+--
+-- For the two-operand operators the argument is a pair: it is first stored in
+-- a fresh variable, after which the operator is applied to its @a@ and @b@
+-- fields. Single-operand operators are emitted as a C call of the given
+-- function-like text on the argument.
+compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr
+compileOpGeneral op e1 = case op of
+  OAdd _ -> binary "+"
+  OMul _ -> binary "*"
+  ONeg _ -> unary "-"
+  OLt _ -> binary "<"
+  OLe _ -> binary "<="
+  OEq _ -> binary "=="
+  ONot -> unary "!"
+  OAnd -> binary "&&"
+  OOr -> binary "||"
+  OIf -> do
+    name <- emitStruct (STEither STNil STNil)
+    _ <- emitStruct STNil
+    let tagged k = CEStruct name [("tag", CELit k)]
+    pure $ CEIf e1 (tagged "0") (tagged "1")
+  ORound64 -> unary "(int64_t)round" -- ew
+  OToFl64 -> unary "(double)"
+  ORecip _ -> pure $ CEBinop (CELit "1.0") "/" e1
+  OExp STF32 -> unary "expf"
+  OExp STF64 -> unary "exp"
+  OLog STF32 -> unary "logf"
+  OLog STF64 -> unary "log"
+  OIDiv _ -> binary "/"
+  OMod _ -> binary "%"
+  where
+    -- Apply @cop@ to the argument using C call syntax.
+    unary :: String -> CompM CExpr
+    unary cop = pure (CECall cop [e1])
+
+    -- Bind the pair argument to a variable and combine its two fields.
+    binary :: String -> CompM CExpr
+    binary cop = do
+      pairvar <- genName
+      emit $ SVarDecl True (repSTy (opt1 op)) pairvar e1
+      pure $ CEBinop (CEProj (CELit pairvar) "a") cop (CEProj (CELit pairvar) "b")
+
+-- | Compile a two-operand primitive operation whose operands are available as
+-- two separate C expressions, so that no intermediate pair needs to be built.
+-- Calls 'error' for any operator not listed here.
+compileOpPair :: SOp a b -> CExpr -> CExpr -> CompM CExpr
+compileOpPair op e1 e2 = case op of
+  OAdd _ -> binary "+"
+  OMul _ -> binary "*"
+  OLt _ -> binary "<"
+  OLe _ -> binary "<="
+  OEq _ -> binary "=="
+  OAnd -> binary "&&"
+  OOr -> binary "||"
+  OIDiv _ -> binary "/"
+  OMod _ -> binary "%"
+  _ -> error "compileOpPair: got unary operator"
+  where
+    -- Combine the two operands with the given C infix operator.
+    binary :: String -> CompM CExpr
+    binary cop = pure (CEBinop e1 cop e2)
+
+-- | Render a scalar constant as C source text.
+--
+-- Bool: whether to ensure that the literal itself already has the appropriate
+-- type. When set, integer literals are prefixed with a cast to their exact C
+-- type; the other scalar types are rendered the same either way.
+--
+-- Non-finite floating-point values are rendered as constant division
+-- expressions, because 'show' would produce @NaN@ / @Infinity@, which are not
+-- valid C tokens.
+compileScal :: Bool -> SScalTy t -> ScalRep t -> String
+compileScal pedantic typ x = case typ of
+  STI32 -> castIfPedantic "(int32_t)" ++ show x
+  STI64 -> castIfPedantic "(int64_t)" ++ show x
+  STF32 -> floatLit "f" x
+  STF64 -> floatLit "" x
+  STBool -> if x then "1" else "0"
+  where
+    castIfPedantic :: String -> String
+    castIfPedantic c = if pedantic then c else ""
+
+    -- Render a floating-point value with the given C literal suffix.
+    floatLit :: (RealFloat f, Show f) => String -> f -> String
+    floatLit suffix v
+      | isNaN v = "(0.0" ++ suffix ++ "/0.0" ++ suffix ++ ")"
+      | isInfinite v =
+          "(" ++ (if v < 0 then "-" else "") ++ "1.0" ++ suffix ++ "/0.0" ++ suffix ++ ")"
+      | otherwise = show v ++ suffix
+
+-- | Shared code generator for the maximum/minimum reductions along the inner
+-- dimension of an array.
+--
+-- @nameBase@ is used for generated variable names and as marker, @opName@
+-- appears in diagnostics, and @operator@ is the C comparison for which
+-- @x operator current@ means that @x@ replaces the current extremum. The
+-- generated code bails out (returns false) if the inner dimension is empty.
+compileExtremum :: String -> String -> String -> SList (Const String) env -> Ex env (TArr (S n) t) -> CompM CExpr
+compileExtremum nameBase opName operator env e = do
+  let STArr (SS n) t = typeOf e
+  argname <- compileAssign (nameBase ++ "arg") env e
+
+  zeroRefcountCheck (typeOf e) opName argname
+
+  -- Note that n is one less than the rank of the argument array, so this is
+  -- the size of the shape with the inner dimension left off: exactly the
+  -- number of elements in the result.
+  shszname <- genName' "shsz"
+  emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
+
+  resname <-
+    allocArray nameBase Malloc (nameBase ++ "res") n t
+      (Just (CELit shszname))
+      (compileArrShapeComponents n argname)
+
+  -- length of the inner dimension
+  lenname <- genName' "n"
+  emit $ SVarDecl True (repSTy tIx) lenname
+    (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
+
+  emit $ SVerbatim $
+    "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); return false; }"
+
+  ivar <- genName' "i"
+  jvar <- genName' "j"
+  xvar <- genName
+  redvar <- genName' "red" -- use "red", not "acc", to avoid confusion with accumulators
+
+  let idxTy = repSTy tIx
+      eltTy = repSTy t
+      rowStart = lenname ++ " * " ++ ivar
+      argElt offset = CELit (argname ++ ".buf->xs[" ++ offset ++ "]")
+      -- keep the new element only if it compares favourably to the current one
+      better = CEIf (CEBinop (CELit xvar) operator (CELit redvar)) (CELit xvar) (CELit redvar)
+      -- the first element of the row seeds the reduction, so start at 1
+      scanRow =
+        SLoop idxTy jvar (CELit "1") (CELit lenname) $ BList
+          [ SVarDecl True eltTy xvar (argElt (rowStart ++ " + " ++ jvar))
+          , SAsg redvar better
+          ]
+  emit $ SLoop idxTy ivar (CELit "0") (CELit shszname) $ BList
+    [ SVarDecl False eltTy redvar (argElt rowStart)
+    , scanRow
+    , SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit redvar)
+    ]
+
+  incrementVarAlways nameBase Decrement (typeOf e) argname
+
+  pure (CELit resname)
+
-- | Produce a copy of the value of monoid type @topty@ stored at the C
-- lvalue @var@, such that the copy is suitable to write to.
--
-- If this returns Nothing, there was nothing to copy because making a simple
-- value copy in C already makes it suitable to write to. Otherwise, the
-- returned expression denotes the copy; any statements needed to construct it
-- are emitted in the current scope.
copyForWriting :: SMTy t -> String -> CompM (Maybe CExpr)
copyForWriting topty var = case topty of
  SMTNil -> return Nothing

  SMTPair a b -> do
    e1 <- copyForWriting a (var ++ ".a")
    e2 <- copyForWriting b (var ++ ".b")
    case (e1, e2) of
      (Nothing, Nothing) -> return Nothing
      -- at least one component needed a copy; the other is taken over as-is
      _ -> return $ Just $ CEStruct toptyname
             [("a", fromMaybe (CELit (var++".a")) e1)
             ,("b", fromMaybe (CELit (var++".b")) e2)]

  SMTLEither a b -> do
    -- The copy statements of the two alternatives are collected separately,
    -- so that they can be placed in the branch for the respective tag.
    (e1, stmts1) <- scope $ copyForWriting a (var ++ ".l")
    (e2, stmts2) <- scope $ copyForWriting b (var ++ ".r")
    case (e1, e2) of
      (Nothing, Nothing) -> return Nothing
      _ -> do
        name <- genName
        emit $ SVarDeclUninit toptyname name
        -- Tag convention (the same one as used in 'zeroRefcountCheck'):
        -- 1 selects the .l payload, 2 selects the .r payload, and any other
        -- tag (i.e. 0) carries no payload at all.
        emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "1"))
                 (stmts1
                  <> pure (SAsg name (CEStruct toptyname
                                        [("tag", CELit "1"), ("l", fromMaybe (CELit (var++".l")) e1)])))
                 (pure (SIf (CEBinop (CELit (var++".tag")) "==" (CELit "2"))
                          (stmts2
                           <> pure (SAsg name (CEStruct toptyname
                                                 [("tag", CELit "2"), ("r", fromMaybe (CELit (var++".r")) e2)])))
                          (pure (SAsg name (CEStruct toptyname [("tag", CELit "0")])))))
        return (Just (CELit name))

  SMTMaybe t -> do
    (e1, stmts1) <- scope $ copyForWriting t (var ++ ".j")
    case e1 of
      Nothing -> return Nothing
      Just e1' -> do
        name <- genName
        emit $ SVarDeclUninit toptyname name
        -- tag 0: no payload to copy; otherwise copy the .j payload
        emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0"))
                 (pure (SAsg name (CEStruct toptyname [("tag", CELit "0")])))
                 (stmts1
                  <> pure (SAsg name (CEStruct toptyname [("tag", CELit "1"), ("j", e1')])))
        return (Just (CELit name))

  -- If there are no nested arrays, we know that a refcount of 1 means that the
  -- whole thing is owned. Nested arrays have their own refcount, so with
  -- nesting we'd have to check the refcounts of all the nested arrays _too_;
  -- let's not do that. Furthermore, no sub-arrays means that the whole thing
  -- is flat, and we can just memcpy if necessary.
  SMTArr n t | not (typeHasArrays (fromSMTy t)) -> do
    name <- genName
    shszname <- genName' "shsz"
    emit $ SVarDeclUninit toptyname name

    when debugShapes $ do
      let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
      emit $ SVerbatim $
        "fprintf(stderr, PRTAG \"with array " ++ shfmt ++ "\\n\"" ++
        concat [", " ++ var ++ ".sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++
        ");"

    emit $ SIf (CEBinop (CELit (var ++ ".buf->refc")) "==" (CELit "1"))
             -- sole owner: reuse the existing buffer
             (pure (SAsg name (CELit var)))
             -- shared: allocate a new buffer and copy shape and data over
             (let shbytes = fromSNat n * 8
                  databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t))))
                  totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes
              in BList
                   [SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var)
                   ,SAsg name (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])])
                   ,SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");"
                   ,SAsg (name ++ ".buf->refc") (CELit "1")
                   ,SVerbatim $ "memcpy(" ++ name ++ ".buf->xs, " ++ var ++ ".buf->xs, " ++
                                printCExpr 0 databytes ");"])
    return (Just (CELit name))

  -- Arrays with nested arrays: always allocate a new array, and copy the
  -- elements over one by one using a recursively generated element copy.
  SMTArr n t -> do
    shszname <- genName' "shsz"
    emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var)

    let shbytes = fromSNat n * 8
        databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t))))
        totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes

    name <- genName
    emit $ SVarDecl False toptyname name
             (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])])
    emit $ SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");"
    emit $ SAsg (name ++ ".buf->refc") (CELit "1")

    -- put the arrays in variables to cut short the not-quite-var chain
    dstvar <- genName' "cpydst"
    emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") dstvar (CELit (name ++ ".buf->xs"))
    srcvar <- genName' "cpysrc"
    emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") srcvar (CELit (var ++ ".buf->xs"))

    ivar <- genName' "i"

    (cpye, cpystmts) <- scope $ copyForWriting t (srcvar ++ "[" ++ ivar ++ "]")
    -- the element type contains arrays here, so a copy is expected to be needed
    let cpye' = case cpye of
                  Just e -> e
                  Nothing -> error "copyForWriting: arrays cannot be copied as-is, bug"

    emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $
      cpystmts
      <> pure (SAsg (dstvar ++ "[" ++ ivar ++ "]") cpye')

    return (Just (CELit name))

  SMTScal _ -> return Nothing

  where
    -- name of the C type representing the value being copied
    toptyname = repSTy (fromSMTy topty)
+
-- | When 'emitChecks' is enabled, emit statements that verify that no array
-- contained in the value of type @toptyp@ at the C lvalue @topvar@ has a
-- reference count of zero. On failure, the generated code prints a message
-- mentioning @opname@ and returns @false@. Nothing is emitted if the type
-- contains no arrays. Passing a type containing an accumulator is an error.
zeroRefcountCheck :: STy t -> String -> String -> CompM ()
zeroRefcountCheck toptyp opname topvar =
  when emitChecks $ do
    mstmts <- onlyIdGen $ runMaybeT (go toptyp topvar)
    case mstmts of
      Nothing -> return ()
      Just stmts -> forM_ stmts emit
  where
    -- | If this returns 'Nothing', no statements need to be generated for this type.
    go :: STy t -> String -> MaybeT IdGen.IdGen (Bag Stmt)
    go STNil _ = empty
    go (STPair a b) path = do
      (s1, s2) <- combine (go a (path++".a")) (go b (path++".b"))
      return (s1 <> s2)
    go (STEither a b) path = do
      (s1, s2) <- combine (go a (path++".l")) (go b (path++".r"))
      return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "0")) s1 s2
    go (STLEither a b) path = do
      (s1, s2) <- combine (go a (path++".l")) (go b (path++".r"))
      return $ pure $
        SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1"))
          s1
          (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2"))
                   s2
                   mempty))
    go (STMaybe a) path = do
      ss <- go a (path++".j")
      return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) ss mempty
    go (STArr n a) path = do
      ivar <- genName' "i"
      -- Run the element check in isolation: an element type that needs no
      -- checks (e.g. a scalar) must not short-circuit the surrounding MaybeT
      -- computation, because this array's own refcount still has to be checked.
      melts <- lift $ runMaybeT (go a (path++".buf->xs["++ivar++"]"))
      let s1 = SVerbatim $
                 "if (__builtin_expect(" ++ path ++ ".buf->refc == 0, 0)) { " ++
                 "fprintf(stderr, PRTAG \"CHECK: '" ++ opname ++ "' got array " ++
                 "%p with refc=0\\n\", " ++ path ++ ".buf); return false; }"
      case melts of
        -- nothing to check inside the elements: only check the array itself
        Nothing -> return (BList [s1])
        Just ss -> do
          shszname <- genName' "shsz"
          let s2 = SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n path)
          let s3 = SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) ss
          return (BList [s1, s2, s3])
    go STScal{} _ = empty
    go STAccum{} _ = error "zeroRefcountCheck: passed an accumulator"

    -- | Run both computations; the result is 'Nothing' only if both yield
    -- 'Nothing'. If exactly one of them does, 'mempty' is substituted for it.
    combine :: (Monoid a, Monoid b, Monad m) => MaybeT m a -> MaybeT m b -> MaybeT m (a, b)
    combine (MaybeT a) (MaybeT b) = MaybeT $ do
      x <- a
      y <- b
      return $ case (x, y) of
        (Nothing, Nothing) -> Nothing
        (Just x', Nothing) -> Just (x', mempty)
        (Nothing, Just y') -> Just (mempty, y')
        (Just x', Just y') -> Just (x', y')
+
-- | Global mutable 'Int' counter with initial value 1.
--
-- The reference is created using 'unsafePerformIO'; the NOINLINE pragma makes
-- sure that this definition is not duplicated by inlining, so that all users
-- share one and the same reference.
{-# NOINLINE uniqueIdGenRef #-}
uniqueIdGenRef :: IORef Int
uniqueIdGenRef = unsafePerformIO (newIORef 1)
+
-- | Compose all functions in a container into a single function. The first
-- function in the container is the outermost one, i.e. it is applied last.
-- An empty container yields the identity function.
compose :: Foldable t => t (a -> a) -> a -> a
compose fs x = foldr ($) x fs
+
-- | Render the address of a pointer in hexadecimal notation, prefixed with
-- @0x@.
showPtr :: Ptr a -> String
showPtr (Ptr addr) = showString "0x" (showHex addrValue "")
  where
    -- the raw address, converted (as an unsigned machine word) to an Integer
    addrValue = integerFromWord# (int2Word# (addr2Int# addr))
+
-- | Type-restricted version of 'Prelude.^': the exponent is always an 'Int'.
(^) :: Num a => a -> Int -> a
base ^ expo = base Prelude.^ expo
+
-- | Strict left fold over a list without a starting value. The second
-- argument is the result for the empty list only; it does not take part in
-- the fold of a non-empty list.
foldl0' :: (a -> a -> a) -> a -> [a] -> a
foldl0' f ifEmpty xs = case xs of
  [] -> ifEmpty
  _ -> foldl1' f xs