diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:49:45 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:50:25 +0100 |
| commit | 174af2ba568de66e0d890825b8bda930b8e7bb96 (patch) | |
| tree | 5a20f52662e87ff7cf6a6bef5db0713aa6c7884e /src/Compile.hs | |
| parent | 92bca235e3aaa287286b6af082d3fce585825a35 (diff) | |
Move module hierarchy under CHAD.
Diffstat (limited to 'src/Compile.hs')
| -rw-r--r-- | src/Compile.hs | 1796 |
1 files changed, 0 insertions, 1796 deletions
diff --git a/src/Compile.hs b/src/Compile.hs deleted file mode 100644 index 8627905..0000000 --- a/src/Compile.hs +++ /dev/null @@ -1,1796 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE MagicHash #-} -{-# LANGUAGE MultiWayIf #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE TupleSections #-} -{-# LANGUAGE TypeApplications #-} -module Compile (compile, compileStderr) where - -import Control.Applicative (empty) -import Control.Monad (forM_, when, replicateM) -import Control.Monad.Trans.Class (lift) -import Control.Monad.Trans.Maybe -import Control.Monad.Trans.State.Strict -import Control.Monad.Trans.Writer.CPS -import Data.Bifunctor (first) -import Data.Char (ord) -import Data.Foldable (toList) -import Data.Functor.Const -import qualified Data.Functor.Product as Product -import Data.Functor.Product (Product) -import Data.IORef -import Data.List (foldl1', intersperse, intercalate) -import qualified Data.Map.Strict as Map -import Data.Maybe (fromMaybe) -import qualified Data.Set as Set -import Data.Set (Set) -import Data.Some -import qualified Data.Vector as V -import Foreign -import GHC.Exts (int2Word#, addr2Int#) -import GHC.Num (integerFromWord#) -import GHC.Ptr (Ptr(..)) -import GHC.Stack (HasCallStack) -import Numeric (showHex) -import System.IO (hPutStrLn, stderr) -import System.IO.Error (mkIOError, userErrorType) -import System.IO.Unsafe (unsafePerformIO) - -import Prelude hiding ((^)) -import qualified Prelude - -import Array -import AST -import AST.Pretty (ppSTy, ppExpr) -import AST.Sparse.Types (isDense) -import Compile.Exec -import Data -import Interpreter.Rep -import qualified Util.IdGen as IdGen - - --- In shape and index arrays, the innermost dimension is on the right (last index). - --- TODO: test that I'm properly incrementing and decrementing refcounts in all required places - - --- | Print the compiled AST -debugPrintAST :: Bool; debugPrintAST = toEnum 0 --- | Print the generated C source -debugCSource :: Bool; debugCSource = toEnum 0 --- | Print extra stuff about reference counts of arrays -debugRefc :: Bool; debugRefc = toEnum 0 --- | Print some shape-related information -debugShapes :: Bool; debugShapes = toEnum 0 --- | Print information on allocation -debugAllocs :: Bool; debugAllocs = toEnum 0 --- | Emit extra C code that checks stuff -emitChecks :: Bool; emitChecks = toEnum 0 - --- | Returns compiled function plus compilation output (warnings) -compile :: SList STy env -> Ex env t - -> IO (SList Value env -> IO (Rep t), String) -compile = \env expr -> do - codeID <- atomicModifyIORef' uniqueIdGenRef (\i -> (i + 1, i)) - - let (source, offsets) = compileToString codeID env expr - when debugPrintAST $ hPutStrLn stderr $ "Compiled AST: <<<\n" ++ ppExpr env expr ++ "\n>>>" - when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ lineNumbers source ++ "\x1B[0m>>>" - (lib, compileOutput) <- buildKernel source "kernel" - - let result_type = typeOf expr - result_size = sizeofSTy result_type - - let function val = do - allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do - let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets) - serialiseArguments args ptr $ do - callKernelFun lib ptr - ok <- peekByteOff @Word8 ptr (koOkResOffset offsets) - when (ok /= 1) $ - ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing) - deserialise result_type ptr (koResultOffset offsets) - return (function, compileOutput) - where - serialiseArguments :: [(Some (Product STy Value), Int)] -> Ptr () -> IO r -> IO r - serialiseArguments ((Some (Product.Pair t (Value arg)), off) : args) ptr k = - serialise t arg ptr off $ - serialiseArguments args ptr k - serialiseArguments _ _ k = k - --- | 'compile', but writes any produced C compiler output to stderr. -compileStderr :: SList STy env -> Ex env t - -> IO (SList Value env -> IO (Rep t)) -compileStderr env expr = do - (fun, output) <- compile env expr - when (not (null output)) $ - hPutStrLn stderr $ "[chad] Kernel compilation GCC output: <<<\n" ++ output ++ ">>>" - return fun - - -data StructDecl = StructDecl - String -- ^ name - String -- ^ contents - String -- ^ comment - deriving (Show) - -data Stmt - = SVarDecl Bool String String CExpr -- ^ const, type, variable name, right-hand side - | SVarDeclUninit String String -- ^ type, variable name (no initialiser) - | SAsg String CExpr -- ^ variable name, right-hand side - | SBlock (Bag Stmt) - | SIf CExpr (Bag Stmt) (Bag Stmt) - | SLoop String String CExpr CExpr (Bag Stmt) -- ^ for (<type> <name> = <expr>; name < <expr>; name++) {<stmts>} - | SVerbatim String -- ^ no implicit ';', just printed as-is - deriving (Show) - -data CExpr - = CELit String -- ^ inserted as-is, assumed no parentheses needed - | CEStruct String [(String, CExpr)] -- ^ struct construction literal: `(name){.field=expr}` - | CEProj CExpr String -- ^ field projection: expr.field - | CEPtrProj CExpr String -- ^ field projection through pointer: expr->field - | CEAddrOf CExpr -- ^ &expr - | CEIndex CExpr CExpr -- ^ expr[expr] - | CECall String [CExpr] -- ^ function(arg1, ..., argn) - | CEBinop CExpr String CExpr -- ^ expr + expr - | CEIf CExpr CExpr CExpr -- ^ expr ? expr : expr - | CECast String CExpr -- ^ (<type>)<expr> - deriving (Show) - -printStructDecl :: StructDecl -> ShowS -printStructDecl (StructDecl name contents comment) = - showString "typedef struct { " . showString contents . showString " } " . showString name - . showString ";" . (if null comment then id else showString (" // " ++ comment)) - -printStmt :: Int -> Stmt -> ShowS -printStmt indent = \case - SVarDecl cnst typ name rhs -> showString (typ ++ " " ++ (if cnst then "const " else "") ++ name ++ " = ") . printCExpr 0 rhs . showString ";" - SVarDeclUninit typ name -> showString (typ ++ " " ++ name ++ ";") - SAsg name rhs -> showString (name ++ " = ") . printCExpr 0 rhs . showString ";" - SBlock stmts - | null stmts -> showString "{}" - | otherwise -> - showString "{" - . compose [showString ("\n" ++ replicate (2*indent+2) ' ') . printStmt (indent+1) stmt | stmt <- toList stmts] - . showString ("\n" ++ replicate (2*indent) ' ' ++ "}") - SIf cond b1 b2 -> - showString "if (" . printCExpr 0 cond . showString ") " - . printStmt indent (SBlock b1) - . (if null b2 then id else showString " else " . printStmt indent (SBlock b2)) - SLoop typ name e1 e2 stmts -> - showString ("for (" ++ typ ++ " " ++ name ++ " = ") - . printCExpr 0 e1 . showString ("; " ++ name ++ " < ") . printCExpr 6 e2 - . showString ("; " ++ name ++ "++) ") - . printStmt indent (SBlock stmts) - SVerbatim s -> showString s - --- d values: --- * 0: top level --- * 1: in 1st or 2nd component of a ternary operator (technically same as top level, but readability) --- * 2-...: various operators (see precTable) --- * 80: address-of operator (&) --- * 98: inside unknown operator --- * 99: left of a field projection --- Unlisted operators are conservatively written with full parentheses. -printCExpr :: Int -> CExpr -> ShowS -printCExpr d = \case - CELit s -> showString s - CEStruct name pairs -> - showParen (d >= 99) $ - showString ("(" ++ name ++ "){") - . compose (intersperse (showString ", ") [showString ("." ++ n ++ " = ") . printCExpr 0 e - | (n, e) <- pairs]) - . showString "}" - CEProj e name -> printCExpr 99 e . showString ("." ++ name) - CEPtrProj e name -> printCExpr 99 e . showString ("->" ++ name) - CEAddrOf e -> showParen (d > 80) $ showString "&" . printCExpr 80 e - CEIndex e1 e2 -> printCExpr 99 e1 . showString "[" . printCExpr 0 e2 . showString "]" - CECall n es -> - showString (n ++ "(") . compose (intersperse (showString ", ") (map (printCExpr 0) es)) . showString ")" - CEBinop e1 n e2 -> - let mprec = Map.lookup n precTable - p = maybe (-1) fst mprec -- precedence of this operator - (d1, d2) = maybe (98, 98) snd mprec -- precedences for the arguments - in showParen (d > p) $ - printCExpr d1 e1 . showString (" " ++ n ++ " ") . printCExpr d2 e2 - CEIf e1 e2 e3 -> - showParen (d > 0) $ - printCExpr 1 e1 . showString " ? " . printCExpr 1 e2 . showString " : " . printCExpr 0 e3 - CECast typ e -> - showParen (d > 98) $ showString ("(" ++ typ ++ ")") . printCExpr 98 e - where - precTable = Map.fromList - [("||", (2, (2, 2))) - ,("&&", (3, (3, 3))) - ,("==", (4, (5, 5))) - ,("!=", (4, (5, 5))) - ,("<", (5, (6, 6))) -- Note: this precedence is used in the printing of SLoop - ,(">", (5, (6, 6))) - ,("<=", (5, (6, 6))) - ,(">=", (5, (6, 6))) - ,("+", (6, (6, 7))) - ,("-", (6, (6, 7))) - ,("*", (7, (7, 8))) - ,("/", (7, (7, 8))) - ,("%", (7, (7, 8)))] - -repSTy :: STy t -> String -repSTy (STScal st) = case st of - STI32 -> "int32_t" - STI64 -> "int64_t" - STF32 -> "float" - STF64 -> "double" - STBool -> "uint8_t" -repSTy t = genStructName t - -genStructName, genArrBufStructName :: STy t -> String -(genStructName, genArrBufStructName) = - (\t -> "ty_" ++ gen t - ,\case STArr _ t -> "ty_A_" ++ gen t ++ "_buf" -- just like the normal type, but with _ for the dimension - t -> error $ "genArrBufStructName: not an array type: " ++ show t) - where - -- all tags start with a letter, so the array mangling is unambiguous. - gen :: STy t -> String - gen STNil = "n" - gen (STPair a b) = 'P' : gen a ++ gen b - gen (STEither a b) = 'E' : gen a ++ gen b - gen (STLEither a b) = 'L' : gen a ++ gen b - gen (STMaybe t) = 'M' : gen t - gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t - gen (STScal st) = case st of - STI32 -> "i" - STI64 -> "j" - STF32 -> "f" - STF64 -> "d" - STBool -> "b" - gen (STAccum t) = 'C' : gen (fromSMTy t) - --- The subtrees contain structs used in the bodies of the structs in this node. -data StructTree = TreeNode [StructDecl] [StructTree] - deriving (Show) - --- | This function generates the actual struct declarations for each of the --- types in our language. It thus implicitly "documents" the layout of the --- types in the C translation. --- --- For accumulation it is important that for struct representations of monoid --- types, the all-zero-bytes value corresponds to the zero value of that type. -buildStructTree :: STy t -> StructTree -buildStructTree topty = case topty of - STNil -> - TreeNode [StructDecl name "" com] [] - STPair a b -> - TreeNode [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com] - [buildStructTree a, buildStructTree b] - STEither a b -> -- 0 -> l, 1 -> r - TreeNode [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] - [buildStructTree a, buildStructTree b] - STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r - TreeNode [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] - [buildStructTree a, buildStructTree b] - STMaybe t -> -- 0 -> nothing, 1 -> just - TreeNode [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com] - [buildStructTree t] - STArr n t -> - -- The buffer is trailed by a VLA for the actual array data. - -- TODO: no buffer if n = 0 - TreeNode [StructDecl (genArrBufStructName topty) ("size_t refc; " ++ repSTy t ++ " xs[];") "" - ,StructDecl name (genArrBufStructName topty ++ " *buf; size_t sh[" ++ show (fromSNat n) ++ "];") com] - [buildStructTree t] - STScal _ -> - TreeNode [] [] - STAccum t -> - TreeNode [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") "" - ,StructDecl name (name ++ "_buf *buf;") com] - [buildStructTree (fromSMTy t)] - where - name = genStructName topty - com = ppSTy 0 topty - --- State: already-generated (skippable) struct names --- Writer: the structs in declaration order -genStructTreeW :: StructTree -> WriterT (Bag StructDecl) (State (Set String)) () -genStructTreeW (TreeNode these deps) = do - seen <- lift get - case filter ((`Set.notMember` seen) . nameOf) these of - [] -> pure () - structs -> do - lift $ modify (Set.fromList (map nameOf structs) <>) - mapM_ genStructTreeW deps - tell (BList structs) - where - nameOf (StructDecl name _ _) = name - -genAllStructs :: Foldable t => t (Some STy) -> [StructDecl] -genAllStructs tys = - let m = mapM_ (\(Some t) -> genStructTreeW (buildStructTree t)) tys - in toList (evalState (execWriterT m) mempty) - -data CompState = CompState - { csStructs :: Set (Some STy) - , csTopLevelDecls :: Bag String - , csStmts :: Bag Stmt - , csNextId :: Int } - deriving (Show) - -newtype CompM a = CompM (State CompState a) - deriving newtype (Functor, Applicative, Monad) - -runCompM :: CompM a -> (a, CompState) -runCompM (CompM m) = runState m (CompState mempty mempty mempty 1) - -class Monad m => MonadNameGen m where genId :: m Int -instance MonadNameGen CompM where genId = CompM $ state $ \s -> (csNextId s, s { csNextId = csNextId s + 1 }) -instance MonadNameGen IdGen.IdGen where genId = IdGen.genId -instance MonadNameGen m => MonadNameGen (MaybeT m) where genId = MaybeT (Just <$> genId) - -genName' :: MonadNameGen m => String -> m String -genName' "" = genName -genName' prefix = (prefix ++) . show <$> genId - -genName :: MonadNameGen m => m String -genName = genName' "x" - -onlyIdGen :: IdGen.IdGen a -> CompM a -onlyIdGen m = CompM $ do - i1 <- gets csNextId - let (res, i2) = IdGen.runIdGen' i1 m - modify (\s -> s { csNextId = i2 }) - return res - -emit :: Stmt -> CompM () -emit stmt = CompM $ modify $ \s -> s { csStmts = csStmts s <> pure stmt } - -scope :: CompM a -> CompM (a, Bag Stmt) -scope m = do - stmts <- CompM $ state $ \s -> (csStmts s, s { csStmts = mempty }) - res <- m - innerStmts <- CompM $ state $ \s -> (csStmts s, s { csStmts = stmts }) - return (res, innerStmts) - -emitStruct :: STy t -> CompM String -emitStruct ty = CompM $ do - modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) } - return (genStructName ty) - --- | Also returns the name of the array buffer struct -emitArrStruct :: STy t -> CompM (String, String) -emitArrStruct ty = CompM $ do - modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) } - return (genStructName ty, genArrBufStructName ty) - -emitTLD :: String -> CompM () -emitTLD decl = CompM $ modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl } - -nameEnv :: SList f env -> SList (Const String) env -nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1)) - -data KernelOffsets = KernelOffsets - { koArgOffsets :: [Int] -- ^ the function arguments - , koOkResOffset :: Int -- ^ a byte: 1 if successful execution, 0 if (fatal) error occurred - , koResultOffset :: Int -- ^ the function result - } - -compileToString :: Int -> SList STy env -> Ex env t -> (String, KernelOffsets) -compileToString codeID env expr = - let args = nameEnv env - (res, s) = runCompM (compile' args expr) - structs = genAllStructs (csStructs s <> Set.fromList (unSList Some env)) - - (arg_pairs, arg_metrics) = - unzip $ reverse (unSList (\(Product.Pair t (Const n)) -> ((n, repSTy t), metricsSTy t)) - (slistZip env args)) - (arg_offsets, okres_offset) = computeStructOffsets arg_metrics - result_offset = align (alignmentSTy (typeOf expr)) (okres_offset + 1) - - offsets = KernelOffsets - { koArgOffsets = arg_offsets - , koOkResOffset = okres_offset - , koResultOffset = result_offset } - in (,offsets) . ($ "") $ compose - [showString "#include <stdio.h>\n" - ,showString "#include <stdint.h>\n" - ,showString "#include <stdbool.h>\n" - ,showString "#include <inttypes.h>\n" - ,showString "#include <stdlib.h>\n" - ,showString "#include <string.h>\n" - ,showString "#include <math.h>\n\n" - -- PRint-tag - ,showString $ "#define PRTAG \"[chad-kernel" ++ show codeID ++ "] \"\n\n" - - ,compose [printStructDecl sd . showString "\n" | sd <- structs] - ,showString "\n" - - -- Using %zd and not %zu here because values > SIZET_MAX/2 should be recognisable as "negative" - ,showString "static void* malloc_instr_fun(size_t n, int line) {\n" - ,showString " void *ptr = malloc(n);\n" - ,if debugAllocs then showString " printf(PRTAG \":%d malloc(%zd) -> %p\\n\", line, n, ptr);\n" - else id - ,if emitChecks then showString " if (ptr == NULL) { printf(PRTAG \"malloc(%zd) returned NULL on line %d\\n\", n, line); return false; }\n" - else id - ,showString " return ptr;\n" - ,showString "}\n" - ,showString "#define malloc_instr(n) ({void *ptr_ = malloc_instr_fun(n, __LINE__); if (ptr_ == NULL) return false; ptr_;})\n" - ,showString "static void* calloc_instr_fun(size_t n, int line) {\n" - ,showString " void *ptr = calloc(n, 1);\n" - ,if debugAllocs then showString " printf(PRTAG \":%d calloc(%zd) -> %p\\n\", line, n, ptr);\n" - else id - ,if emitChecks then showString " if (ptr == NULL) { printf(PRTAG \"calloc(%zd, 1) returned NULL on line %d\\n\", n, line); return false; }\n" - else id - ,showString " return ptr;\n" - ,showString "}\n" - ,showString "#define calloc_instr(n) ({void *ptr_ = calloc_instr_fun(n, __LINE__); if (ptr_ == NULL) return false; ptr_;})\n" - ,showString "static void free_instr(void *ptr) {\n" - ,if debugAllocs then showString "printf(PRTAG \"free(%p)\\n\", ptr);\n" - else id - ,showString " free(ptr);\n" - ,showString "}\n\n" - - ,compose [showString str . showString "\n\n" | str <- toList (csTopLevelDecls s)] - - ,showString $ - "static bool typed_kernel(" ++ - repSTy (typeOf expr) ++ " *output" ++ - concatMap (", " ++) - (reverse (unSList (\(Product.Pair t (Const n)) -> repSTy t ++ " " ++ n) (slistZip env args))) ++ - ") {\n" - ,compose [showString " " . printStmt 1 st . showString "\n" | st <- toList (csStmts s)] - ,showString " *output = " . printCExpr 0 res . showString ";\n" - ,showString " return true;\n" - ,showString "}\n\n" - - ,showString "void kernel(void *data) {\n" - -- Some code here assumes that we're on a 64-bit system, so let's check that - ,showString $ " if (sizeof(void*) != 8 || sizeof(size_t) != 8) { fprintf(stderr, \"Only 64-bit systems supported\\n\"); *(uint8_t*)(data + " ++ show okres_offset ++ ") = 0; return; }\n" - ,if debugRefc then showString " fprintf(stderr, PRTAG \"Start\\n\");\n" - else id - ,showString $ " const bool success = typed_kernel(" ++ - "\n (" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ")" ++ - concat (map (\((arg, typ), off) -> - ",\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")" - ++ " /* " ++ arg ++ " */") - (zip arg_pairs arg_offsets)) ++ - "\n );\n" - ,showString $ " *(uint8_t*)(data + " ++ show okres_offset ++ ") = success;\n" - ,if debugRefc then showString " fprintf(stderr, PRTAG \"Return\\n\");\n" - else id - ,showString "}\n"] - --- | Takes list of metrics (alignment, sizeof). --- Returns (offsets, size of struct). -computeStructOffsets :: [(Int, Int)] -> ([Int], Int) -computeStructOffsets = go 0 0 - where - go off maxal [(al, sz)] = - ([off], align (max maxal al) (off + sz)) - go off maxal ((al, sz) : pairs@((al2,_):_)) = - first (off :) $ go (align al2 (off + sz)) (max maxal al) pairs - go _ _ [] = ([], 0) - --- | Assumes that this is called at the correct alignment. -serialise :: STy t -> Rep t -> Ptr () -> Int -> IO r -> IO r -serialise topty topval ptr off k = - -- TODO: this code is quadratic in the depth of the type because of the alignment/sizeOf calls - case (topty, topval) of - (STNil, ()) -> k - (STPair a b, (x, y)) -> - serialise a x ptr off $ - serialise b y ptr (align (alignmentSTy b) (off + sizeofSTy a)) k - (STEither a _, Left x) -> do - pokeByteOff ptr off (0 :: Word8) -- alignment of (union {a b}) is the same as alignment of (a + b) - serialise a x ptr (off + alignmentSTy topty) k - (STEither _ b, Right y) -> do - pokeByteOff ptr off (1 :: Word8) - serialise b y ptr (off + alignmentSTy topty) k - (STLEither _ _, Nothing) -> do - pokeByteOff ptr off (0 :: Word8) - k - (STLEither a _, Just (Left x)) -> do - pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b) - serialise a x ptr (off + alignmentSTy topty) k - (STLEither _ b, Just (Right y)) -> do - pokeByteOff ptr off (2 :: Word8) - serialise b y ptr (off + alignmentSTy topty) k - (STMaybe _, Nothing) -> do - pokeByteOff ptr off (0 :: Word8) - k - (STMaybe t, Just x) -> do - pokeByteOff ptr off (1 :: Word8) - serialise t x ptr (off + alignmentSTy t) k - (STArr n t, Array sh vec) -> do - let eltsz = sizeofSTy t - allocaBytes (8 + shapeSize sh * eltsz) $ \bufptr -> do - when debugRefc $ - hPutStrLn stderr $ "[chad-serialise] Allocating input buffer " ++ showPtr bufptr - pokeByteOff ptr off bufptr - pokeShape ptr (off + 8) n sh - - pokeByteOff @Word64 bufptr 0 (2 ^ 63) - - let loop i - | i == shapeSize sh = k - | otherwise = - serialise t (vec V.! i) bufptr (8 + i * eltsz) $ - loop (i+1) - loop 0 - (STScal sty, x) -> case sty of - STI32 -> pokeByteOff ptr off (x :: Int32) >> k - STI64 -> pokeByteOff ptr off (x :: Int64) >> k - STF32 -> pokeByteOff ptr off (x :: Float) >> k - STF64 -> pokeByteOff ptr off (x :: Double) >> k - STBool -> pokeByteOff ptr off (fromIntegral (fromEnum x) :: Word8) >> k - (STAccum{}, _) -> error "Cannot serialise accumulators" - --- | Assumes that this is called at the correct alignment. -deserialise :: STy t -> Ptr () -> Int -> IO (Rep t) -deserialise topty ptr off = - -- TODO: this code is quadratic in the depth of the type because of the alignment/sizeOf calls - case topty of - STNil -> return () - STPair a b -> do - x <- deserialise a ptr off - y <- deserialise b ptr (align (alignmentSTy b) (off + sizeofSTy a)) - return (x, y) - STEither a b -> do - tag <- peekByteOff @Word8 ptr off - if tag == 0 -- alignment of (union {a b}) is the same as alignment of (a + b) - then Left <$> deserialise a ptr (off + alignmentSTy topty) - else Right <$> deserialise b ptr (off + alignmentSTy topty) - STLEither a b -> do - tag <- peekByteOff @Word8 ptr off - case tag of -- alignment of (union {a b}) is the same as alignment of (a + b) - 0 -> return Nothing - 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty) - 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty) - _ -> error "Invalid tag value" - STMaybe t -> do - tag <- peekByteOff @Word8 ptr off - if tag == 0 - then return Nothing - else Just <$> deserialise t ptr (off + alignmentSTy t) - STArr n t -> do - bufptr <- peekByteOff @(Ptr ()) ptr off - sh <- peekShape ptr (off + 8) n - refc <- peekByteOff @Word64 bufptr 0 - when debugRefc $ - hPutStrLn stderr $ "[chad-deserialise] Got buffer " ++ showPtr bufptr ++ " at refc=" ++ show refc - let eltsz = sizeofSTy t - arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (8 + i * eltsz)) - when (refc < 2 ^ 62) $ free bufptr - return arr - STScal sty -> case sty of - STI32 -> peekByteOff @Int32 ptr off - STI64 -> peekByteOff @Int64 ptr off - STF32 -> peekByteOff @Float ptr off - STF64 -> peekByteOff @Double ptr off - STBool -> toEnum . fromIntegral <$> peekByteOff @Word8 ptr off - STAccum{} -> error "Cannot serialise accumulators" - -align :: Int -> Int -> Int -align a off = (off + a - 1) `div` a * a - -alignmentSTy :: STy t -> Int -alignmentSTy = fst . metricsSTy - -sizeofSTy :: STy t -> Int -sizeofSTy = snd . metricsSTy - --- | Returns (alignment, sizeof) -metricsSTy :: STy t -> (Int, Int) -metricsSTy STNil = (1, 0) -metricsSTy (STPair a b) = - let (a1, s1) = metricsSTy a - (a2, s2) = metricsSTy b - in (max a1 a2, align (max a1 a2) (s1 + s2)) -metricsSTy (STEither a b) = - let (a1, s1) = metricsSTy a - (a2, s2) = metricsSTy b - in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned -metricsSTy (STLEither a b) = - let (a1, s1) = metricsSTy a - (a2, s2) = metricsSTy b - in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned -metricsSTy (STMaybe t) = - let (a, s) = metricsSTy t - in (a, a + s) -- the union after the tag byte is aligned -metricsSTy (STArr n _) = (8, 8 + 8 * fromSNat n) -metricsSTy (STScal sty) = case sty of - STI32 -> (4, 4) - STI64 -> (8, 8) - STF32 -> (4, 4) - STF64 -> (8, 8) - STBool -> (1, 1) -- compiled to uint8_t -metricsSTy (STAccum t) = metricsSTy (fromSMTy t) - -pokeShape :: Ptr () -> Int -> SNat n -> Shape n -> IO () -pokeShape ptr off = go . fromSNat - where - go :: Int -> Shape n -> IO () - go rank = \case - ShNil -> return () - sh `ShCons` n -> do - pokeByteOff ptr (off + (rank - 1) * 8) (fromIntegral n :: Int64) - go (rank - 1) sh - -peekShape :: Ptr () -> Int -> SNat n -> IO (Shape n) -peekShape ptr off = \case - SZ -> return ShNil - SS n -> ShCons <$> peekShape ptr off n - <*> (fromIntegral <$> peekByteOff @Int64 ptr (off + (fromSNat n) * 8)) - -compile' :: SList (Const String) env -> Ex env t -> CompM CExpr -compile' env = \case - EVar _ t i -> do - let Const var = slistIdx env i - incrementVarAlways "var" Increment t var - return $ CELit var - - ELet _ rhs body -> do - var <- compileAssign "" env rhs - rete <- compile' (Const var `SCons` env) body - incrementVarAlways "let" Decrement (typeOf rhs) var - return rete - - EPair _ a b -> do - name <- emitStruct (STPair (typeOf a) (typeOf b)) - e1 <- compile' env a - e2 <- compile' env b - return $ CEStruct name [("a", e1), ("b", e2)] - - EFst _ e -> do - let STPair _ t2 = typeOf e - e' <- compile' env e - case incrementVar "fst" Decrement t2 of - Nothing -> return $ CEProj e' "a" - Just f -> do var <- genName - emit $ SVarDecl True (repSTy (typeOf e)) var e' - f (var ++ ".b") - return $ CEProj (CELit var) "a" - - ESnd _ e -> do - let STPair t1 _ = typeOf e - e' <- compile' env e - case incrementVar "snd" Decrement t1 of - Nothing -> return $ CEProj e' "b" - Just f -> do var <- genName - emit $ SVarDecl True (repSTy (typeOf e)) var e' - f (var ++ ".a") - return $ CEProj (CELit var) "b" - - ENil _ -> do - name <- emitStruct STNil - return $ CEStruct name [] - - EInl _ t e -> do - name <- emitStruct (STEither (typeOf e) t) - e1 <- compile' env e - return $ CEStruct name [("tag", CELit "0"), ("l", e1)] - - EInr _ t e -> do - name <- emitStruct (STEither t (typeOf e)) - e2 <- compile' env e - return $ CEStruct name [("tag", CELit "1"), ("r", e2)] - - ECase _ (EOp _ OIf e) a b -> do - e1 <- compile' env e - (e2, stmts2) <- scope $ compile' (Const undefined `SCons` env) a -- don't access that nil, stupid you - (e3, stmts3) <- scope $ compile' (Const undefined `SCons` env) b - retvar <- genName - emit $ SVarDeclUninit (repSTy (typeOf a)) retvar - emit $ SIf e1 - (stmts2 <> pure (SAsg retvar e2)) - (stmts3 <> pure (SAsg retvar e3)) - return (CELit retvar) - - ECase _ e a b -> do - let STEither t1 t2 = typeOf e - e1 <- compile' env e - var <- genName - -- I know those are not variable names, but it's fine, probably - (e2, stmts2) <- scope $ compile' (Const (var ++ ".l") `SCons` env) a - (e3, stmts3) <- scope $ compile' (Const (var ++ ".r") `SCons` env) b - ((), stmtsRel1) <- scope $ incrementVarAlways "case1" Decrement t1 (var ++ ".l") - ((), stmtsRel2) <- scope $ incrementVarAlways "case2" Decrement t2 (var ++ ".r") - retvar <- genName - emit $ SVarDeclUninit (repSTy (typeOf a)) retvar - emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1) - <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) - (stmts2 - <> stmtsRel1 - <> pure (SAsg retvar e2)) - (stmts3 - <> stmtsRel2 - <> pure (SAsg retvar e3)))) - return (CELit retvar) - - ENothing _ t -> do - name <- emitStruct (STMaybe t) - return $ CEStruct name [("tag", CELit "0")] - - EJust _ e -> do - name <- emitStruct (STMaybe (typeOf e)) - e1 <- compile' env e - return $ CEStruct name [("tag", CELit "1"), ("j", e1)] - - EMaybe _ a b e -> do - let STMaybe t = typeOf e - e1 <- compile' env e - var <- genName - (e2, stmts2) <- scope $ compile' env a - (e3, stmts3) <- scope $ compile' (Const (var ++ ".j") `SCons` env) b - ((), stmtsRel) <- scope $ incrementVarAlways "maybe" Decrement t (var ++ ".j") - retvar <- genName - emit $ SVarDeclUninit (repSTy (typeOf a)) retvar - emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1) - <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) - (stmts2 - <> pure (SAsg retvar e2)) - (stmts3 - <> stmtsRel - <> pure (SAsg retvar e3)))) - return (CELit retvar) - - ELNil _ t1 t2 -> do - name <- emitStruct (STLEither t1 t2) - return $ CEStruct name [("tag", CELit "0")] - - ELInl _ t e -> do - name <- emitStruct (STLEither (typeOf e) t) - e1 <- compile' env e - return $ CEStruct name [("tag", CELit "1"), ("l", e1)] - - ELInr _ t e -> do - name <- emitStruct (STLEither t (typeOf e)) - e1 <- compile' env e - return $ CEStruct name [("tag", CELit "2"), ("r", e1)] - - ELCase _ e a b c -> do - let STLEither t1 t2 = typeOf e - e1 <- compile' env e - var <- genName - (e2, stmts2) <- scope $ compile' env a - (e3, stmts3) <- scope $ compile' (Const (var ++ ".l") `SCons` env) b - (e4, stmts4) <- scope $ compile' (Const (var ++ ".r") `SCons` env) c - ((), stmtsRel1) <- scope $ incrementVarAlways "lcase1" Decrement t1 (var ++ ".l") - ((), stmtsRel2) <- scope $ incrementVarAlways "lcase2" Decrement t2 (var ++ ".r") - retvar <- genName - emit $ SVarDeclUninit (repSTy (typeOf a)) retvar - emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1) - <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) - (stmts2 <> pure (SAsg retvar e2)) - (pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "1")) - (stmts3 <> stmtsRel1 <> pure (SAsg retvar e3)) - (stmts4 <> stmtsRel2 <> pure (SAsg retvar e4)))))) - return (CELit retvar) - - EConstArr _ n t (Array sh vec) -> do - (strname, bufstrname) <- emitArrStruct (STArr n (STScal t)) - tldname <- genName' "carraybuf" - -- Give it a refcount of _half_ the size_t max, so that it can be - -- incremented and decremented at will and will "never" reach anything - -- where something happens - emitTLD $ "static " ++ bufstrname ++ " " ++ tldname ++ " = " ++ - "(" ++ bufstrname ++ "){.refc = (size_t)1<<63, " ++ - ".xs = {" ++ intercalate "," (map (compileScal False t) (toList vec)) ++ "}};" - return (CEStruct strname - [("buf", CEAddrOf (CELit tldname)) - ,("sh", CELit ("{" ++ intercalate "," (map show (shapeToList sh)) ++ "}"))]) - - EBuild _ n esh efun -> do - shname <- compileAssign "sh" env esh - - arrname <- allocArray "build" Malloc "arr" n (typeOf efun) Nothing (indexTupleComponents n shname) - - idxargname <- genName' "ix" - (funretval, funstmts) <- scope $ compile' (Const idxargname `SCons` env) efun - - linivar <- genName' "li" - ivars <- replicateM (fromSNat n) (genName' "i") - emit $ SBlock $ - pure (SVarDecl False "size_t" linivar (CELit "0")) - <> compose [pure . SLoop (repSTy tIx) ivar (CELit "0") - (CECast (repSTy tIx) (CEIndex (CELit (arrname ++ ".sh")) (CELit (show dimidx)))) - | (ivar, dimidx) <- zip ivars [0::Int ..]] - (pure (SVarDecl True (repSTy (typeOf esh)) idxargname - (shapeTupFromLitVars n ivars)) - <> funstmts - <> pure (SAsg (arrname ++ ".buf->xs[" ++ linivar ++ "++]") funretval)) - - return (CELit arrname) - - -- TODO: actually generate decent code here - EMap _ e1 e2 -> do - let STArr n _ = typeOf e2 - compile' env $ - elet e2 $ - EBuild ext n (EShape ext (evar IZ)) $ - elet (EIdx ext (evar (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ)) $ - weakenExpr (WCopy (WSink .> WSink)) e1 - - EFold1Inner _ commut efun ex0 earr -> do - let STArr (SS n) t = typeOf earr - - -- let vecwid = case commut of Commut -> 8 :: Int - -- Noncommut -> 1 - - x0name <- compileAssign "foldx0" env ex0 - arrname <- compileAssign "foldarr" env earr - - zeroRefcountCheck (typeOf earr) "fold1i" arrname - - shszname <- genName' "shsz" - -- This n is one less than the shape of the thing we're querying, which is - -- unexpected. But it's exactly what we want, so we do it anyway. - emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n arrname) - - resname <- allocArray "fold" Malloc "foldres" n t (Just (CELit shszname)) (compileArrShapeComponents n arrname) - - lenname <- genName' "n" - emit $ SVarDecl True (repSTy tIx) lenname - (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]")) - - ((), x0incrStmts) <- scope $ incrementVarAlways "foldx0" Increment t x0name - - ivar <- genName' "i" - jvar <- genName' "j" - -- kvar <- if vecwid > 1 then genName' "k" else return "" - - accvar <- genName' "tot" - pairvar <- genName' "pair" -- function input - (funres, funStmts) <- scope $ compile' (Const pairvar `SCons` env) efun - - let arreltlit = arrname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ - ({- if vecwid > 1 then show vecwid ++ " * " ++ jvar ++ " + " ++ kvar else -} jvar) ++ "]" - ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldelt" Increment t arreltlit - - pairstrname <- emitStruct (STPair t t) - emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ - pure (SVarDecl False (repSTy t) accvar (CELit x0name)) - <> x0incrStmts -- we're copying x0 here - <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ - -- The combination function will consume the array element - -- and the accumulator. The accumulator is replaced by - -- what comes out of the function anyway, so that's - -- fine, but we do need to increment the array element. - arreltIncrStmts - <> pure (SVarDecl True pairstrname pairvar (CEStruct pairstrname [("a", CELit accvar), ("b", CELit arreltlit)])) - <> funStmts - <> pure (SAsg accvar funres)) - <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) - - incrementVarAlways "foldx0" Decrement t x0name - incrementVarAlways "foldarr" Decrement (typeOf earr) arrname - - return (CELit resname) - - ESum1Inner _ e -> do - let STArr (SS n) t = typeOf e - argname <- compileAssign "sumarg" env e - - zeroRefcountCheck (typeOf e) "sum1i" argname - - shszname <- genName' "shsz" - -- This n is one less than the shape of the thing we're querying, like EFold1Inner. - emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) - - resname <- allocArray "sum" Malloc "sumres" n t (Just (CELit shszname)) (compileArrShapeComponents n argname) - - lenname <- genName' "n" - emit $ SVarDecl True (repSTy tIx) lenname - (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]")) - - let vecwid = 8 :: Int - ivar <- genName' "i" - jvar <- genName' "j" - kvar <- genName' "k" - accvar <- genName' "tot" - let nchunks = CEBinop (CELit lenname) "/" (CELit (show vecwid)) - emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ BList - -- we have ScalIsNumeric, so it has 0 and (+) in C - [SVerbatim $ repSTy t ++ " " ++ accvar ++ "[" ++ show vecwid ++ "] = {" ++ intercalate "," (replicate vecwid "0") ++ "};" - ,SLoop (repSTy tIx) jvar (CELit "0") nchunks $ - pure $ SLoop (repSTy tIx) kvar (CELit "0") (CELit (show vecwid)) $ - pure $ SVerbatim $ accvar ++ "[" ++ kvar ++ "] += " ++ argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ show vecwid ++ " * " ++ jvar ++ " + " ++ kvar ++ "];" - ,SLoop (repSTy tIx) kvar (CELit "1") (CELit (show vecwid)) $ - pure $ SVerbatim $ accvar ++ "[0] += " ++ accvar ++ "[" ++ kvar ++ "];" - ,SLoop (repSTy tIx) kvar (CEBinop nchunks "*" (CELit (show vecwid))) (CELit lenname) $ - pure $ SVerbatim $ accvar ++ "[0] += " ++ argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ kvar ++ "];" - ,SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit (accvar++"[0]"))] - - incrementVarAlways "sum" Decrement (typeOf e) argname - - return (CELit resname) - - EUnit _ e -> do - e' <- compile' env e - let typ = STArr SZ (typeOf e) - strname <- emitStruct typ - name <- genName - emit $ SVarDecl True strname name (CEStruct strname - [("buf", CECall "malloc_instr" [CELit (show (8 + sizeofSTy (typeOf e)))])]) - emit $ SAsg (name ++ ".buf->refc") (CELit "1") - emit $ SAsg (name ++ ".buf->xs[0]") e' - return (CELit name) - - EReplicate1Inner _ elen earg -> do - let STArr n t = typeOf earg - lenname <- compileAssign "replen" env elen - argname <- compileAssign "reparg" env earg - - zeroRefcountCheck (typeOf earg) "replicate1i" argname - - shszname <- genName' "shsz" - emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) - - resname <- allocArray "repl1i" Malloc "rep" (SS n) t - (Just (CEBinop (CELit shszname) "*" (CELit lenname))) - (compileArrShapeComponents n argname ++ [CELit lenname]) - - ivar <- genName' "i" - jvar <- genName' "j" - emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ - pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ - pure $ SAsg (resname ++ ".buf->xs[" ++ ivar ++ " * " ++ lenname ++ " + " ++ jvar ++ "]") - (CELit (argname ++ ".buf->xs[" ++ ivar ++ "]")) - - incrementVarAlways "repl1i" Decrement (typeOf earg) argname - - return (CELit resname) - - EMaximum1Inner _ e -> compileExtremum "max" "maximum1i" ">" env e - - EMinimum1Inner _ e -> compileExtremum "min" "minimum1i" "<" env e - - EReshape _ dim esh earg -> do - let STArr origDim eltty = typeOf earg - strname <- emitStruct (STArr dim eltty) - - shname <- compileAssign "reshsh" env esh - arrname <- compileAssign "resharg" env earg - - when emitChecks $ do - emit $ SIf (CEBinop (compileArrShapeSize origDim arrname) "!=" (CECast "size_t" (prodExpr (indexTupleComponents dim shname)))) - (pure $ SVerbatim $ - "fprintf(stderr, PRTAG \"CHECK: reshape on unequal sizes (%zu <- %zu)\\n\", " ++ - printCExpr 0 (prodExpr (indexTupleComponents dim shname)) ", " ++ - printCExpr 0 (compileArrShapeSize origDim arrname) "); return false;") - mempty - - return (CEStruct strname - [("buf", CEProj (CELit arrname) "buf") - ,("sh", CELit ("{" ++ intercalate ", " [printCExpr 0 e "" | e <- indexTupleComponents dim shname] ++ "}"))]) - - -- TODO: actually generate decent code here - EZip _ e1 e2 -> do - let STArr n _ = typeOf e1 - compile' env $ - elet e1 $ - elet (weakenExpr WSink e2) $ - EBuild ext n (EShape ext (evar (IS IZ))) $ - EPair ext (EIdx ext (evar (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) IZ)) - (EIdx ext (evar (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ)) - - EFold1InnerD1 _ commut efun ex0 earr -> do - let STArr (SS n) t = typeOf earr - STPair _ bty = typeOf efun - - x0name <- compileAssign "foldd1x0" env ex0 - arrname <- compileAssign "foldd1arr" env earr - - zeroRefcountCheck (typeOf earr) "fold1iD1" arrname - - lenname <- genName' "n" - emit $ SVarDecl True (repSTy tIx) lenname - (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]")) - - shsz1name <- genName' "shszN" - emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n arrname) -- take init of arr's shape - shsz2name <- genName' "shszSN" - emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname)) - - resname <- allocArray "foldd1" Malloc "foldd1res" n t (Just (CELit shsz1name)) (compileArrShapeComponents n arrname) - storesname <- allocArray "foldd1" Malloc "foldd1stores" (SS n) bty (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) arrname) - - ((), x0incrStmts) <- scope $ incrementVarAlways "foldd1x0" Increment t x0name - - ivar <- genName' "i" - jvar <- genName' "j" - - accvar <- genName' "tot" - pairvar <- genName' "pair" -- function input - (funres, funStmts) <- scope $ compile' (Const pairvar `SCons` env) efun - let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ jvar - arreltlit = arrname ++ ".buf->xs[" ++ eltidx ++ "]" - funresvar <- genName' "res" - ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldd1elt" Increment t arreltlit - - pairstrname <- emitStruct (STPair t t) - emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $ - pure (SVarDecl False (repSTy t) accvar (CELit x0name)) - <> x0incrStmts -- we're copying x0 here - <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ - -- The combination function will consume the array element - -- and the accumulator. The accumulator is replaced by - -- what comes out of the function anyway, so that's - -- fine, but we do need to increment the array element. - arreltIncrStmts - <> pure (SVarDecl True pairstrname pairvar (CEStruct pairstrname [("a", CELit accvar), ("b", CELit arreltlit)])) - <> funStmts - <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres) - <> pure (SAsg accvar (CEProj (CELit funresvar) "a")) - <> pure (SAsg (storesname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b"))) - <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) - - incrementVarAlways "foldd1x0" Decrement t x0name - incrementVarAlways "foldd1arr" Decrement (typeOf earr) arrname - - strname <- emitStruct (STPair (STArr n t) (STArr (SS n) bty)) - return (CEStruct strname [("a", CELit resname), ("b", CELit storesname)]) - - EFold1InnerD2 _ commut efun estores ectg -> do - let STArr n t2 = typeOf ectg - STArr _ bty = typeOf estores - - storesname <- compileAssign "foldd2stores" env estores - ctgname <- compileAssign "foldd2ctg" env ectg - - zeroRefcountCheck (typeOf ectg) "fold1iD2" ctgname - - lenname <- genName' "n" - emit $ SVarDecl True (repSTy tIx) lenname - (CELit (storesname ++ ".sh[" ++ show (fromSNat n) ++ "]")) - - shsz1name <- genName' "shszN" - emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n storesname) -- take init of the shape - shsz2name <- genName' "shszSN" - emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname)) - - x0ctgname <- allocArray "foldd2" Malloc "foldd2x0ctg" n t2 (Just (CELit shsz1name)) (compileArrShapeComponents n storesname) - outctgname <- allocArray "foldd2" Malloc "foldd2outctg" (SS n) t2 (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) storesname) - - ivar <- genName' "i" - jvar <- genName' "j" - - accvar <- genName' "acc" - let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ lenname ++ "-1 - " ++ jvar - storeseltlit = storesname ++ ".buf->xs[" ++ eltidx ++ "]" - ctgeltlit = ctgname ++ ".buf->xs[" ++ ivar ++ "]" - (funres, funStmts) <- scope $ compile' (Const accvar `SCons` Const storeseltlit `SCons` env) efun - funresvar <- genName' "res" - ((), storeseltIncrStmts) <- scope $ incrementVarAlways "foldd2selt" Increment bty storeseltlit - ((), ctgeltIncrStmts) <- scope $ incrementVarAlways "foldd2celt" Increment bty ctgeltlit - - emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $ - pure (SVarDecl False (repSTy t2) accvar (CELit ctgeltlit)) - <> ctgeltIncrStmts - -- we need to loop in reverse here, but we let jvar run in the - -- forward direction so that we can use SLoop. Note jvar is - -- reversed in eltidx above - <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ - -- The combination function will consume the accumulator - -- and the stores element. The accumulator is replaced by - -- what comes out of the function anyway, so that's - -- fine, but we do need to increment the stores element. - storeseltIncrStmts - <> funStmts - <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres) - <> pure (SAsg accvar (CEProj (CELit funresvar) "a")) - <> pure (SAsg (outctgname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b"))) - <> pure (SAsg (x0ctgname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) - - incrementVarAlways "foldd2stores" Decrement (STArr (SS n) bty) storesname - incrementVarAlways "foldd2ctg" Decrement (STArr n t2) ctgname - - strname <- emitStruct (STPair (STArr n t2) (STArr (SS n) t2)) - return (CEStruct strname [("a", CELit x0ctgname), ("b", CELit outctgname)]) - - EConst _ t x -> return $ CELit $ compileScal True t x - - EIdx0 _ e -> do - let STArr _ t = typeOf e - arrname <- compileAssign "" env e - zeroRefcountCheck (typeOf e) "idx0" arrname - name <- genName - emit $ SVarDecl True (repSTy t) name - (CEIndex (CEPtrProj (CEProj (CELit arrname) "buf") "xs") (CELit "0")) - incrementVarAlways "idx0" Decrement (STArr SZ t) arrname - return (CELit name) - - -- EIdx1 _ a b -> error "TODO" -- EIdx1 ext (compile' a) (compile' b) - - EIdx _ earr eidx -> do - let STArr n t = typeOf earr - arrname <- compileAssign "ixarr" env earr - zeroRefcountCheck (typeOf earr) "idx" arrname - idxname <- if fromSNat n > 0 -- prevent an unused-varable warning - then compileAssign "ixix" env eidx - else return "" -- won't be used in this case - - when emitChecks $ - forM_ (zip [0::Int ..] (indexTupleComponents n idxname)) $ \(i, ixcomp) -> - emit $ SIf (CEBinop (CEBinop ixcomp "<" (CELit "0")) "||" - (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".sh[" ++ show i ++ "]"))))) - (pure $ SVerbatim $ - "fprintf(stderr, PRTAG \"CHECK: index out of range (arr=%p)\\n\", " ++ - arrname ++ ".buf); return false;") - mempty - - resname <- genName' "ixres" - emit $ SVarDecl True (repSTy t) resname (CEIndex (CELit (arrname ++ ".buf->xs")) (toLinearIdx n arrname idxname)) - incrementVarAlways "idxelt" Increment t resname - incrementVarAlways "idx" Decrement (STArr n t) arrname - return (CELit resname) - - EShape _ e -> do - let STArr n _ = typeOf e - t = tTup (sreplicate n tIx) - _ <- emitStruct t - name <- compileAssign "" env e - zeroRefcountCheck (typeOf e) "shape" name - resname <- genName - emit $ SVarDecl True (repSTy t) resname (compileShapeQuery n name) - incrementVarAlways "shape" Decrement (typeOf e) name - return (CELit resname) - - EOp _ op (EPair _ e1 e2) -> do - e1' <- compile' env e1 - e2' <- compile' env e2 - compileOpPair op e1' e2' - - EOp _ op e -> do - e' <- compile' env e - compileOpGeneral op e' - - ECustom _ _ _ _ earg _ _ e1 e2 -> do - name1 <- compileAssign "" env e1 - name2 <- compileAssign "" env e2 - case (incrementVar "custom1" Decrement (typeOf e1), incrementVar "custom2" Decrement (typeOf e2)) of - (Nothing, Nothing) -> compile' (Const name2 `SCons` Const name1 `SCons` SNil) earg - (mfun1, mfun2) -> do - name <- compileAssign "" (Const name2 `SCons` Const name1 `SCons` SNil) earg - maybe (return ()) ($ name1) mfun1 - maybe (return ()) ($ name2) mfun2 - return (CELit name) - - ERecompute _ e -> compile' env e - - EWith _ t e1 e2 -> do - actyname <- emitStruct (STAccum t) - name1 <- compileAssign "" env e1 - - zeroRefcountCheck (typeOf e1) "with" name1 - - emit $ SVerbatim $ "// copyForWriting start (" ++ name1 ++ ")" - mcopy <- copyForWriting t name1 - accname <- genName' "accum" - emit $ SVarDecl False actyname accname - (CEStruct actyname [("buf", CECall "malloc_instr" [CELit (show (sizeofSTy (fromSMTy t)))])]) - emit $ SAsg (accname++".buf->ac") (maybe (CELit name1) id mcopy) - emit $ SVerbatim $ "// initial accumulator constructed (" ++ name1 ++ ")." - - e2' <- compile' (Const accname `SCons` env) e2 - - resname <- genName' "acret" - emit $ SVarDecl True (repSTy (fromSMTy t)) resname (CELit (accname++".buf->ac")) - emit $ SVerbatim $ "free_instr(" ++ accname ++ ".buf);" - - rettyname <- emitStruct (STPair (typeOf e2) (fromSMTy t)) - return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)] - - EAccum _ t prj eidx sparsity eval eacc | Just Refl <- isDense (acPrjTy prj t) sparsity -> do - let -- Add a value (s) into an existing accumulation value (d). If a sparse - -- component of d is encountered, s is copied there. - add :: SMTy a -> String -> String -> CompM () - add SMTNil _ _ = return () - add (SMTPair t1 t2) d s = do - add t1 (d++".a") (s++".a") - add t2 (d++".b") (s++".b") - add (SMTLEither t1 t2) d s = do - ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTLEither t1 t2)) s - ((), stmts1) <- scope $ add t1 (d++".l") (s++".l") - ((), stmts2) <- scope $ add t2 (d++".r") (s++".r") - emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0")) - (pure (SAsg d (CELit s)) - <> srcIncrStmts) - ((if emitChecks - then pure (SIf (CEBinop (CEBinop (CELit (s++".tag")) "!=" (CELit "0")) - "&&" - (CEBinop (CELit (s++".tag")) "!=" (CELit (d++".tag")))) - (pure $ SVerbatim $ - "fprintf(stderr, PRTAG \"CHECK: accum add leither with different tags " ++ - "(dest %d, src %d)\\n\", (int)" ++ d ++ ".tag, (int)" ++ s ++ ".tag); " ++ - "return false;") - mempty) - else mempty) - -- note: s may have tag 0 - <> pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) - stmts1 - (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "2")) - stmts2 mempty)))) - add (SMTMaybe t1) d s = do - ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTMaybe t1)) s - ((), stmts1) <- scope $ add t1 (d++".j") (s++".j") - emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0")) - (pure (SAsg d (CELit s)) - <> srcIncrStmts) - (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) stmts1 mempty)) - add (SMTArr n t1) d s = do - when emitChecks $ do - let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" - forM_ [0 .. fromSNat n - 1] $ \j -> do - emit $ SIf (CEBinop (CELit (s ++ ".sh[" ++ show j ++ "]")) - "!=" - (CELit (d ++ ".sh[" ++ show j ++ "]"))) - (pure $ SVerbatim $ - "fprintf(stderr, PRTAG \"CHECK: accum add incorrect (d=%p, " ++ - "dsh=" ++ shfmt ++ ", s=%p, ssh=" ++ shfmt ++ ")\\n\", " ++ - d ++ ".buf" ++ - concat [", " ++ d ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ - ", " ++ s ++ ".buf" ++ - concat [", " ++ s ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ - "); " ++ - "return false;") - mempty - - shsizename <- genName' "acshsz" - emit $ SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n s) - ivar <- genName' "i" - ((), stmts1) <- scope $ add t1 (d++".buf->xs["++ivar++"]") (s++".buf->xs["++ivar++"]") - emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename) - stmts1 - add (SMTScal _) d s = emit $ SVerbatim $ d ++ " += " ++ s ++ ";" - - let -- | Dereference an accumulation value and add a given value to that - -- position. Sparse components encountered along the way are - -- initialised before proceeding downwards. - -- accumRef (type) (projection) (accumulation component) (AcIdx variable) (value to accumulate there) - accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> String -> CompM () - accumRef _ SAPHere v _ addend = add (acPrjTy prj t) v addend - - accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") i addend - accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") i addend - - accumRef (SMTLEither ta _) (SAPLeft prj') v i addend = do - when emitChecks $ do - emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1")) - (pure $ SVerbatim $ - "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +left)\\n\", " ++ v ++ ".tag); " ++ - "return false;") - mempty - accumRef ta prj' (v++".l") i addend - accumRef (SMTLEither _ tb) (SAPRight prj') v i addend = do - when emitChecks $ do - emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "2")) - (pure $ SVerbatim $ - "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +right)\\n\", " ++ v ++ ".tag); " ++ - "return false;") - mempty - accumRef tb prj' (v++".r") i addend - - accumRef (SMTMaybe tj) (SAPJust prj') v i addend = do - when emitChecks $ do - emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1")) - (pure $ SVerbatim $ - "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (maybe tag=%d, +just)\\n\", " ++ v ++ ".tag); " ++ - "return false;") - mempty - accumRef tj prj' (v++".j") i addend - - accumRef (SMTArr n t') (SAPArrIdx prj') v i addend = do - when emitChecks $ do - let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" - forM_ (zip [0::Int ..] - (indexTupleComponents n (i++".a"))) $ \(j, ixcomp) -> do - let a .||. b = CEBinop a "||" b - emit $ SIf (CEBinop ixcomp "<" (CELit "0") - .||. - CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".sh[" ++ show j ++ "]")))) - (pure $ SVerbatim $ - "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++ - "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=(D))\\n\", " ++ - v ++ ".buf" ++ - concat [", " ++ v ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ - concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a")] ++ - "); " ++ - "return false;") - mempty - - accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a")) "]") (i++".b") addend - - nameidx <- compileAssign "acidx" env eidx - nameval <- compileAssign "acval" env eval - nameacc <- compileAssign "acac" env eacc - - emit $ SVerbatim $ "// compile EAccum start (" ++ show prj ++ ")" - accumRef t prj (nameacc++".buf->ac") nameidx nameval - emit $ SVerbatim $ "// compile EAccum end" - - incrementVarAlways "accumendsrc" Decrement (typeOf eval) nameval - - return $ CEStruct (repSTy STNil) [] - - EAccum{} -> - error "Compile: EAccum with non-trivial sparsity should have been eliminated (use AST.UnMonoid)" - - EError _ t s -> do - let padleft len c s' = replicate (len - length s) c ++ s' - escape = concatMap $ \c -> if | c `elem` "\"\\" -> ['\\',c] - | ord c < 32 -> "\\x" ++ padleft 2 '0' (showHex (ord c) "") - | otherwise -> [c] - emit $ SVerbatim $ "fputs(\"ERROR: " ++ escape s ++ "\\n\", stderr); return false;" - case t of - STScal _ -> return (CELit "0") - _ -> do - name <- emitStruct t - return $ CEStruct name [] - - EZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" - EDeepZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" - EPlus{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" - EOneHot{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" - - EIdx1{} -> error "Compile: not implemented: EIdx1" - -compileAssign :: String -> SList (Const String) env -> Ex env t -> CompM String -compileAssign prefix env e = do - e' <- compile' env e - case e' of - CELit name -> return name - _ -> do - name <- genName' prefix - emit $ SVarDecl True (repSTy (typeOf e)) name e' - return name - -data Increment = Increment | Decrement - deriving (Show) - --- | Increment reference counts in the components of the given variable. -incrementVar :: String -> Increment -> STy a -> Maybe (String -> CompM ()) -incrementVar marker inc ty = - let tree = makeArrayTree ty - in case tree of ATNoop -> Nothing - _ -> Just $ \var -> incrementVar' marker inc var tree - -incrementVarAlways :: String -> Increment -> STy a -> String -> CompM () -incrementVarAlways marker inc ty var = maybe (pure ()) ($ var) (incrementVar marker inc ty) - -data ArrayTree = ATArray (Some SNat) (Some STy) -- ^ we've arrived at an array we need to decrement the refcount of (contains rank and element type of the array) - | ATNoop -- ^ don't do anything here - | ATProj String ArrayTree -- ^ descend one field deeper - | ATCondTag ArrayTree ArrayTree -- ^ if tag is 0, first; if 1, second - | ATCond3Tag ArrayTree ArrayTree ArrayTree -- ^ if tag is: 0, 1, 2 - | ATBoth ArrayTree ArrayTree -- ^ do both these paths - -smartATProj :: String -> ArrayTree -> ArrayTree -smartATProj _ ATNoop = ATNoop -smartATProj field t = ATProj field t - -smartATCondTag :: ArrayTree -> ArrayTree -> ArrayTree -smartATCondTag ATNoop ATNoop = ATNoop -smartATCondTag t t' = ATCondTag t t' - -smartATCond3Tag :: ArrayTree -> ArrayTree -> ArrayTree -> ArrayTree -smartATCond3Tag ATNoop ATNoop ATNoop = ATNoop -smartATCond3Tag t1 t2 t3 = ATCond3Tag t1 t2 t3 - -smartATBoth :: ArrayTree -> ArrayTree -> ArrayTree -smartATBoth ATNoop t = t -smartATBoth t ATNoop = t -smartATBoth t t' = ATBoth t t' - -makeArrayTree :: STy a -> ArrayTree -makeArrayTree STNil = ATNoop -makeArrayTree (STPair a b) = smartATBoth (smartATProj "a" (makeArrayTree a)) - (smartATProj "b" (makeArrayTree b)) -makeArrayTree (STEither a b) = smartATCondTag (smartATProj "l" (makeArrayTree a)) - (smartATProj "r" (makeArrayTree b)) -makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop - (smartATProj "l" (makeArrayTree a)) - (smartATProj "r" (makeArrayTree b)) -makeArrayTree (STMaybe t) = smartATCondTag ATNoop (smartATProj "j" (makeArrayTree t)) -makeArrayTree (STArr n t) = ATArray (Some n) (Some t) -makeArrayTree (STScal _) = ATNoop -makeArrayTree (STAccum _) = ATNoop - -incrementVar' :: String -> Increment -> String -> ArrayTree -> CompM () -incrementVar' marker inc path (ATArray (Some n) (Some eltty)) = - case inc of - Increment -> do - emit $ SVerbatim (path ++ ".buf->refc++;") - when debugRefc $ - emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p in+ -> %zu <" ++ marker ++ ">\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc);" - Decrement -> do - case incrementVar (marker++".elt") Decrement eltty of - Nothing -> - if debugRefc - then do - emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p de- -> %zu <" ++ marker ++ ">\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);" - emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) { fprintf(stderr, \"; free(\"); free_instr(" ++ path ++ ".buf); fprintf(stderr, \") ok\\n\"); } else fprintf(stderr, \"\\n\");" - else do - emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) free_instr(" ++ path ++ ".buf);" - Just f -> do - when debugRefc $ - emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p de- -> %zu <" ++ marker ++ "> recfree\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);" - shszvar <- genName' "frshsz" - ivar <- genName' "i" - ((), eltDecrStmts) <- scope $ f (path ++ ".buf->xs[" ++ ivar ++ "]") - emit $ SIf (CELit ("--" ++ path ++ ".buf->refc == 0")) - (BList [SVarDecl True "size_t" shszvar (compileArrShapeSize n path) - ,SLoop "size_t" ivar (CELit "0") (CELit shszvar) $ - eltDecrStmts - ,SVerbatim $ "free_instr(" ++ path ++ ".buf);"]) - mempty -incrementVar' _ _ _ ATNoop = pure () -incrementVar' marker inc path (ATProj field t) = incrementVar' (marker++"."++field) inc (path ++ "." ++ field) t -incrementVar' marker inc path (ATCondTag t1 t2) = do - ((), stmts1) <- scope $ incrementVar' (marker++".t1") inc path t1 - ((), stmts2) <- scope $ incrementVar' (marker++".t2") inc path t2 - emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "0")) stmts1 stmts2 -incrementVar' marker inc path (ATCond3Tag t1 t2 t3) = do - ((), stmts1) <- scope $ incrementVar' (marker++".t1") inc path t1 - ((), stmts2) <- scope $ incrementVar' (marker++".t2") inc path t2 - ((), stmts3) <- scope $ incrementVar' (marker++".t3") inc path t3 - emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "1")) - stmts2 - (pure (SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "2")) - stmts3 - stmts1)) -incrementVar' marker inc path (ATBoth t1 t2) = incrementVar' (marker++".1") inc path t1 >> incrementVar' (marker++".2") inc path t2 - -toLinearIdx :: SNat n -> String -> String -> CExpr -toLinearIdx SZ _ _ = CELit "0" -toLinearIdx (SS SZ) _ idxvar = CELit (idxvar ++ ".b") -toLinearIdx (SS n) arrvar idxvar = - CEBinop (CEBinop (toLinearIdx n arrvar (idxvar ++ ".a")) - "*" (CEIndex (CELit (arrvar ++ ".sh")) (CELit (show (fromSNat n))))) - "+" (CELit (idxvar ++ ".b")) - --- fromLinearIdx :: SNat n -> String -> String -> CompM CExpr --- fromLinearIdx SZ _ _ = return $ CEStruct (repSTy STNil) [] --- fromLinearIdx (SS n) arrvar idxvar = do --- name <- genName --- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".sh[" ++ show (fromSNat n) ++ "]"))) --- _ - -data AllocMethod = Malloc | Calloc - deriving (Show) - --- | The shape must have the outer dimension at the head (and the inner dimension on the right). -allocArray :: HasCallStack => String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String -allocArray marker method nameBase rank eltty mshsz shape = do - when (length shape /= fromSNat rank) $ - error "allocArray: shape does not match rank" - let arrty = STArr rank eltty - strname <- emitStruct arrty - arrname <- genName' nameBase - shsz <- case mshsz of - Just e -> return e - Nothing -> return (foldl0' (\a b -> CEBinop a "*" b) (CELit "1") shape) - let nbytesExpr = CEBinop (CELit (show (fromSNat rank * 8 + 8))) - "+" - (CEBinop shsz "*" (CELit (show (sizeofSTy eltty)))) - emit $ SVarDecl True strname arrname $ CEStruct strname - [("buf", case method of Malloc -> CECall "malloc_instr" [nbytesExpr] - Calloc -> CECall "calloc_instr" [nbytesExpr]) - ,("sh", CELit ("{" ++ intercalate "," [printCExpr 0 dim "" | dim <- shape] ++ "}"))] - emit $ SAsg (arrname ++ ".buf->refc") (CELit "1") - when debugRefc $ - emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p allocated <" ++ marker ++ ">\\n\", " ++ arrname ++ ".buf);" - return arrname - -compileShapeQuery :: SNat n -> String -> CExpr -compileShapeQuery SZ _ = CEStruct (repSTy STNil) [] -compileShapeQuery (SS n) var = - CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) - [("a", compileShapeQuery n var) - ,("b", CEIndex (CELit (var ++ ".sh")) (CELit (show (fromSNat n))))] - --- | Takes a variable name for the array, not the buffer. -compileArrShapeSize :: SNat n -> String -> CExpr -compileArrShapeSize n var = prodExpr (compileArrShapeComponents n var) - --- | Takes a variable name for the array, not the buffer. -compileArrShapeComponents :: SNat n -> String -> [CExpr] -compileArrShapeComponents n var = - [CELit (var ++ ".sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] - -indexTupleComponents :: SNat n -> String -> [CExpr] -indexTupleComponents = \n var -> map CELit (toList (go n var)) - where - go :: SNat n -> String -> Bag String - go SZ _ = mempty - go (SS n) var = go n (var ++ ".a") <> pure (var ++ ".b") - --- | Takes variable names with the innermost dimension on the right. -shapeTupFromLitVars :: SNat n -> [String] -> CExpr -shapeTupFromLitVars = \n -> go n . reverse - where - -- takes variables with the innermost dimension at the _head_ - go :: SNat n -> [String] -> CExpr - go SZ [] = CEStruct (repSTy STNil) [] - go (SS n) (var : vars) = CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) [("a", go n vars), ("b", CELit var)] - go _ _ = error "shapeTupFromLitVars: SNat and list do not correspond" - -prodExpr :: [CExpr] -> CExpr -prodExpr = foldl0' (\a b -> CEBinop a "*" b) (CELit "1") - -compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr -compileOpGeneral op e1 = do - let unary cop = return @CompM $ CECall cop [e1] - let binary cop = do - name <- genName - emit $ SVarDecl True (repSTy (opt1 op)) name e1 - return $ CEBinop (CEProj (CELit name) "a") cop (CEProj (CELit name) "b") - case op of - OAdd _ -> binary "+" - OMul _ -> binary "*" - ONeg _ -> unary "-" - OLt _ -> binary "<" - OLe _ -> binary "<=" - OEq _ -> binary "==" - ONot -> unary "!" - OAnd -> binary "&&" - OOr -> binary "||" - OIf -> do - name <- emitStruct (STEither STNil STNil) - _ <- emitStruct STNil - return $ CEIf e1 (CEStruct name [("tag", CELit "0")]) - (CEStruct name [("tag", CELit "1")]) - ORound64 -> unary "(int64_t)round" -- ew - OToFl64 -> unary "(double)" - ORecip _ -> return $ CEBinop (CELit "1.0") "/" e1 - OExp STF32 -> unary "expf" - OExp STF64 -> unary "exp" - OLog STF32 -> unary "logf" - OLog STF64 -> unary "log" - OIDiv _ -> binary "/" - OMod _ -> binary "%" - -compileOpPair :: SOp a b -> CExpr -> CExpr -> CompM CExpr -compileOpPair op e1 e2 = do - let binary cop = return @CompM $ CEBinop e1 cop e2 - case op of - OAdd _ -> binary "+" - OMul _ -> binary "*" - OLt _ -> binary "<" - OLe _ -> binary "<=" - OEq _ -> binary "==" - OAnd -> binary "&&" - OOr -> binary "||" - OIDiv _ -> binary "/" - OMod _ -> binary "%" - _ -> error "compileOpPair: got unary operator" - --- | Bool: whether to ensure that the literal itself already has the appropriate type -compileScal :: Bool -> SScalTy t -> ScalRep t -> String -compileScal pedantic typ x = case typ of - STI32 -> (if pedantic then "(int32_t)" else "") ++ show x - STI64 -> (if pedantic then "(int64_t)" else "") ++ show x - STF32 -> show x ++ "f" - STF64 -> show x - STBool -> if x then "1" else "0" - -compileExtremum :: String -> String -> String -> SList (Const String) env -> Ex env (TArr (S n) t) -> CompM CExpr -compileExtremum nameBase opName operator env e = do - let STArr (SS n) t = typeOf e - argname <- compileAssign (nameBase ++ "arg") env e - - zeroRefcountCheck (typeOf e) opName argname - - shszname <- genName' "shsz" - -- This n is one less than the shape of the thing we're querying, which is - -- unexpected. But it's exactly what we want, so we do it anyway. - emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) - - resname <- allocArray nameBase Malloc (nameBase ++ "res") n t (Just (CELit shszname)) (compileArrShapeComponents n argname) - - lenname <- genName' "n" - emit $ SVarDecl True (repSTy tIx) lenname - (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]")) - - emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); return false; }" - - ivar <- genName' "i" - jvar <- genName' "j" - xvar <- genName - redvar <- genName' "red" -- use "red", not "acc", to avoid confusion with accumulators - emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ BList - -- we have ScalIsNumeric, so it has 1 and (<) etc. in C - [SVarDecl False (repSTy t) redvar (CELit (argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ "]")) - ,SLoop (repSTy tIx) jvar (CELit "1") (CELit lenname) $ BList - [SVarDecl True (repSTy t) xvar (CELit (argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ jvar ++ "]")) - ,SAsg redvar $ CEIf (CEBinop (CELit xvar) operator (CELit redvar)) (CELit xvar) (CELit redvar) - ] - ,SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit redvar)] - - incrementVarAlways nameBase Decrement (typeOf e) argname - - return (CELit resname) - --- | If this returns Nothing, there was nothing to copy because making a simple --- value copy in C already makes it suitable to write to. -copyForWriting :: SMTy t -> String -> CompM (Maybe CExpr) -copyForWriting topty var = case topty of - SMTNil -> return Nothing - - SMTPair a b -> do - e1 <- copyForWriting a (var ++ ".a") - e2 <- copyForWriting b (var ++ ".b") - case (e1, e2) of - (Nothing, Nothing) -> return Nothing - _ -> return $ Just $ CEStruct toptyname - [("a", fromMaybe (CELit (var++".a")) e1) - ,("b", fromMaybe (CELit (var++".b")) e2)] - - SMTLEither a b -> do - (e1, stmts1) <- scope $ copyForWriting a (var ++ ".l") - (e2, stmts2) <- scope $ copyForWriting b (var ++ ".r") - case (e1, e2) of - (Nothing, Nothing) -> return Nothing - _ -> do - name <- genName - emit $ SVarDeclUninit toptyname name - emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0")) - (stmts1 - <> pure (SAsg name (CEStruct toptyname - [("tag", CELit "0"), ("l", fromMaybe (CELit (var++".l")) e1)]))) - (stmts2 - <> pure (SAsg name (CEStruct toptyname - [("tag", CELit "1"), ("r", fromMaybe (CELit (var++".r")) e2)]))) - return (Just (CELit name)) - - SMTMaybe t -> do - (e1, stmts1) <- scope $ copyForWriting t (var ++ ".j") - case e1 of - Nothing -> return Nothing - Just e1' -> do - name <- genName - emit $ SVarDeclUninit toptyname name - emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0")) - (pure (SAsg name (CEStruct toptyname [("tag", CELit "0")]))) - (stmts1 - <> pure (SAsg name (CEStruct toptyname [("tag", CELit "1"), ("j", e1')]))) - return (Just (CELit name)) - - -- If there are no nested arrays, we know that a refcount of 1 means that the - -- whole thing is owned. Nested arrays have their own refcount, so with - -- nesting we'd have to check the refcounts of all the nested arrays _too_; - -- let's not do that. Furthermore, no sub-arrays means that the whole thing - -- is flat, and we can just memcpy if necessary. - SMTArr n t | not (typeHasArrays (fromSMTy t)) -> do - name <- genName - shszname <- genName' "shsz" - emit $ SVarDeclUninit toptyname name - - when debugShapes $ do - let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" - emit $ SVerbatim $ - "fprintf(stderr, PRTAG \"with array " ++ shfmt ++ "\\n\"" ++ - concat [", " ++ var ++ ".sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++ - ");" - - emit $ SIf (CEBinop (CELit (var ++ ".buf->refc")) "==" (CELit "1")) - (pure (SAsg name (CELit var))) - (let shbytes = fromSNat n * 8 - databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t)))) - totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes - in BList - [SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var) - ,SAsg name (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])]) - ,SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");" - ,SAsg (name ++ ".buf->refc") (CELit "1") - ,SVerbatim $ "memcpy(" ++ name ++ ".buf->xs, " ++ var ++ ".buf->xs, " ++ - printCExpr 0 databytes ");"]) - return (Just (CELit name)) - - SMTArr n t -> do - shszname <- genName' "shsz" - emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var) - - let shbytes = fromSNat n * 8 - databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t)))) - totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes - - name <- genName - emit $ SVarDecl False toptyname name - (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])]) - emit $ SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");" - emit $ SAsg (name ++ ".buf->refc") (CELit "1") - - -- put the arrays in variables to cut short the not-quite-var chain - dstvar <- genName' "cpydst" - emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") dstvar (CELit (name ++ ".buf->xs")) - srcvar <- genName' "cpysrc" - emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") srcvar (CELit (var ++ ".buf->xs")) - - ivar <- genName' "i" - - (cpye, cpystmts) <- scope $ copyForWriting t (srcvar ++ "[" ++ ivar ++ "]") - let cpye' = case cpye of - Just e -> e - Nothing -> error "copyForWriting: arrays cannot be copied as-is, bug" - - emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ - cpystmts - <> pure (SAsg (dstvar ++ "[" ++ ivar ++ "]") cpye') - - return (Just (CELit name)) - - SMTScal _ -> return Nothing - - where - toptyname = repSTy (fromSMTy topty) - -zeroRefcountCheck :: STy t -> String -> String -> CompM () -zeroRefcountCheck toptyp opname topvar = - when emitChecks $ do - mstmts <- onlyIdGen $ runMaybeT (go toptyp topvar) - case mstmts of - Nothing -> return () - Just stmts -> forM_ stmts emit - where - -- | If this returns 'Nothing', no statements need to be generated for this type. - go :: STy t -> String -> MaybeT IdGen.IdGen (Bag Stmt) - go STNil _ = empty - go (STPair a b) path = do - (s1, s2) <- combine (go a (path++".a")) (go b (path++".b")) - return (s1 <> s2) - go (STEither a b) path = do - (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) - return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "0")) s1 s2 - go (STLEither a b) path = do - (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) - return $ pure $ - SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) - s1 - (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2")) - s2 - mempty)) - go (STMaybe a) path = do - ss <- go a (path++".j") - return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) ss mempty - go (STArr n a) path = do - ivar <- genName' "i" - ss <- go a (path++".buf->xs["++ivar++"]") - shszname <- genName' "shsz" - let s1 = SVerbatim $ - "if (__builtin_expect(" ++ path ++ ".buf->refc == 0, 0)) { " ++ - "fprintf(stderr, PRTAG \"CHECK: '" ++ opname ++ "' got array " ++ - "%p with refc=0\\n\", " ++ path ++ ".buf); return false; }" - let s2 = SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n path) - let s3 = SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) ss - return (BList [s1, s2, s3]) - go STScal{} _ = empty - go STAccum{} _ = error "zeroRefcountCheck: passed an accumulator" - - combine :: (Monoid a, Monoid b, Monad m) => MaybeT m a -> MaybeT m b -> MaybeT m (a, b) - combine (MaybeT a) (MaybeT b) = MaybeT $ do - x <- a - y <- b - return $ case (x, y) of - (Nothing, Nothing) -> Nothing - (Just x', Nothing) -> Just (x', mempty) - (Nothing, Just y') -> Just (mempty, y') - (Just x', Just y') -> Just (x', y') - -{-# NOINLINE uniqueIdGenRef #-} -uniqueIdGenRef :: IORef Int -uniqueIdGenRef = unsafePerformIO $ newIORef 1 - -compose :: Foldable t => t (a -> a) -> a -> a -compose = foldr (.) id - -showPtr :: Ptr a -> String -showPtr (Ptr a) = "0x" ++ showHex (integerFromWord# (int2Word# (addr2Int# a))) "" - --- | Type-restricted. -(^) :: Num a => a -> Int -> a -(^) = (Prelude.^) - -foldl0' :: (a -> a -> a) -> a -> [a] -> a -foldl0' _ x [] = x -foldl0' f _ l = foldl1' f l |
