diff options
Diffstat (limited to 'src/CHAD/Compile.hs')
| -rw-r--r-- | src/CHAD/Compile.hs | 1796 |
1 files changed, 1796 insertions, 0 deletions
diff --git a/src/CHAD/Compile.hs b/src/CHAD/Compile.hs new file mode 100644 index 0000000..5b71651 --- /dev/null +++ b/src/CHAD/Compile.hs @@ -0,0 +1,1796 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE MultiWayIf #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} +module CHAD.Compile (compile, compileStderr) where + +import Control.Applicative (empty) +import Control.Monad (forM_, when, replicateM) +import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.Maybe +import Control.Monad.Trans.State.Strict +import Control.Monad.Trans.Writer.CPS +import Data.Bifunctor (first) +import Data.Char (ord) +import Data.Foldable (toList) +import Data.Functor.Const +import qualified Data.Functor.Product as Product +import Data.Functor.Product (Product) +import Data.IORef +import Data.List (foldl1', intersperse, intercalate) +import qualified Data.Map.Strict as Map +import Data.Maybe (fromMaybe) +import qualified Data.Set as Set +import Data.Set (Set) +import Data.Some +import qualified Data.Vector as V +import Foreign +import GHC.Exts (int2Word#, addr2Int#) +import GHC.Num (integerFromWord#) +import GHC.Ptr (Ptr(..)) +import GHC.Stack (HasCallStack) +import Numeric (showHex) +import System.IO (hPutStrLn, stderr) +import System.IO.Error (mkIOError, userErrorType) +import System.IO.Unsafe (unsafePerformIO) + +import Prelude hiding ((^)) +import qualified Prelude + +import CHAD.Array +import CHAD.AST +import CHAD.AST.Pretty (ppSTy, ppExpr) +import CHAD.AST.Sparse.Types (isDense) +import CHAD.Compile.Exec +import CHAD.Data +import CHAD.Interpreter.Rep +import qualified CHAD.Util.IdGen as IdGen + + +-- In shape and index arrays, the innermost dimension is on the right (last index). + +-- TODO: test that I'm properly incrementing and decrementing refcounts in all required places + + +-- | Print the compiled AST +debugPrintAST :: Bool; debugPrintAST = toEnum 0 +-- | Print the generated C source +debugCSource :: Bool; debugCSource = toEnum 0 +-- | Print extra stuff about reference counts of arrays +debugRefc :: Bool; debugRefc = toEnum 0 +-- | Print some shape-related information +debugShapes :: Bool; debugShapes = toEnum 0 +-- | Print information on allocation +debugAllocs :: Bool; debugAllocs = toEnum 0 +-- | Emit extra C code that checks stuff +emitChecks :: Bool; emitChecks = toEnum 0 + +-- | Returns compiled function plus compilation output (warnings) +compile :: SList STy env -> Ex env t + -> IO (SList Value env -> IO (Rep t), String) +compile = \env expr -> do + codeID <- atomicModifyIORef' uniqueIdGenRef (\i -> (i + 1, i)) + + let (source, offsets) = compileToString codeID env expr + when debugPrintAST $ hPutStrLn stderr $ "Compiled AST: <<<\n" ++ ppExpr env expr ++ "\n>>>" + when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ lineNumbers source ++ "\x1B[0m>>>" + (lib, compileOutput) <- buildKernel source "kernel" + + let result_type = typeOf expr + result_size = sizeofSTy result_type + + let function val = do + allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do + let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets) + serialiseArguments args ptr $ do + callKernelFun lib ptr + ok <- peekByteOff @Word8 ptr (koOkResOffset offsets) + when (ok /= 1) $ + ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing) + deserialise result_type ptr (koResultOffset offsets) + return (function, compileOutput) + where + serialiseArguments :: [(Some (Product STy Value), Int)] -> Ptr () -> IO r -> IO r + serialiseArguments ((Some (Product.Pair t (Value arg)), off) : args) ptr k = + serialise t arg ptr off $ + serialiseArguments args ptr k + serialiseArguments _ _ k = k + +-- | 'compile', but writes any produced C compiler output to stderr. +compileStderr :: SList STy env -> Ex env t + -> IO (SList Value env -> IO (Rep t)) +compileStderr env expr = do + (fun, output) <- compile env expr + when (not (null output)) $ + hPutStrLn stderr $ "[chad] Kernel compilation GCC output: <<<\n" ++ output ++ ">>>" + return fun + + +data StructDecl = StructDecl + String -- ^ name + String -- ^ contents + String -- ^ comment + deriving (Show) + +data Stmt + = SVarDecl Bool String String CExpr -- ^ const, type, variable name, right-hand side + | SVarDeclUninit String String -- ^ type, variable name (no initialiser) + | SAsg String CExpr -- ^ variable name, right-hand side + | SBlock (Bag Stmt) + | SIf CExpr (Bag Stmt) (Bag Stmt) + | SLoop String String CExpr CExpr (Bag Stmt) -- ^ for (<type> <name> = <expr>; name < <expr>; name++) {<stmts>} + | SVerbatim String -- ^ no implicit ';', just printed as-is + deriving (Show) + +data CExpr + = CELit String -- ^ inserted as-is, assumed no parentheses needed + | CEStruct String [(String, CExpr)] -- ^ struct construction literal: `(name){.field=expr}` + | CEProj CExpr String -- ^ field projection: expr.field + | CEPtrProj CExpr String -- ^ field projection through pointer: expr->field + | CEAddrOf CExpr -- ^ &expr + | CEIndex CExpr CExpr -- ^ expr[expr] + | CECall String [CExpr] -- ^ function(arg1, ..., argn) + | CEBinop CExpr String CExpr -- ^ expr + expr + | CEIf CExpr CExpr CExpr -- ^ expr ? expr : expr + | CECast String CExpr -- ^ (<type>)<expr> + deriving (Show) + +printStructDecl :: StructDecl -> ShowS +printStructDecl (StructDecl name contents comment) = + showString "typedef struct { " . showString contents . showString " } " . showString name + . showString ";" . (if null comment then id else showString (" // " ++ comment)) + +printStmt :: Int -> Stmt -> ShowS +printStmt indent = \case + SVarDecl cnst typ name rhs -> showString (typ ++ " " ++ (if cnst then "const " else "") ++ name ++ " = ") . printCExpr 0 rhs . showString ";" + SVarDeclUninit typ name -> showString (typ ++ " " ++ name ++ ";") + SAsg name rhs -> showString (name ++ " = ") . printCExpr 0 rhs . showString ";" + SBlock stmts + | null stmts -> showString "{}" + | otherwise -> + showString "{" + . compose [showString ("\n" ++ replicate (2*indent+2) ' ') . printStmt (indent+1) stmt | stmt <- toList stmts] + . showString ("\n" ++ replicate (2*indent) ' ' ++ "}") + SIf cond b1 b2 -> + showString "if (" . printCExpr 0 cond . showString ") " + . printStmt indent (SBlock b1) + . (if null b2 then id else showString " else " . printStmt indent (SBlock b2)) + SLoop typ name e1 e2 stmts -> + showString ("for (" ++ typ ++ " " ++ name ++ " = ") + . printCExpr 0 e1 . showString ("; " ++ name ++ " < ") . printCExpr 6 e2 + . showString ("; " ++ name ++ "++) ") + . printStmt indent (SBlock stmts) + SVerbatim s -> showString s + +-- d values: +-- * 0: top level +-- * 1: in 1st or 2nd component of a ternary operator (technically same as top level, but readability) +-- * 2-...: various operators (see precTable) +-- * 80: address-of operator (&) +-- * 98: inside unknown operator +-- * 99: left of a field projection +-- Unlisted operators are conservatively written with full parentheses. +printCExpr :: Int -> CExpr -> ShowS +printCExpr d = \case + CELit s -> showString s + CEStruct name pairs -> + showParen (d >= 99) $ + showString ("(" ++ name ++ "){") + . compose (intersperse (showString ", ") [showString ("." ++ n ++ " = ") . printCExpr 0 e + | (n, e) <- pairs]) + . showString "}" + CEProj e name -> printCExpr 99 e . showString ("." ++ name) + CEPtrProj e name -> printCExpr 99 e . showString ("->" ++ name) + CEAddrOf e -> showParen (d > 80) $ showString "&" . printCExpr 80 e + CEIndex e1 e2 -> printCExpr 99 e1 . showString "[" . printCExpr 0 e2 . showString "]" + CECall n es -> + showString (n ++ "(") . compose (intersperse (showString ", ") (map (printCExpr 0) es)) . showString ")" + CEBinop e1 n e2 -> + let mprec = Map.lookup n precTable + p = maybe (-1) fst mprec -- precedence of this operator + (d1, d2) = maybe (98, 98) snd mprec -- precedences for the arguments + in showParen (d > p) $ + printCExpr d1 e1 . showString (" " ++ n ++ " ") . printCExpr d2 e2 + CEIf e1 e2 e3 -> + showParen (d > 0) $ + printCExpr 1 e1 . showString " ? " . printCExpr 1 e2 . showString " : " . printCExpr 0 e3 + CECast typ e -> + showParen (d > 98) $ showString ("(" ++ typ ++ ")") . printCExpr 98 e + where + precTable = Map.fromList + [("||", (2, (2, 2))) + ,("&&", (3, (3, 3))) + ,("==", (4, (5, 5))) + ,("!=", (4, (5, 5))) + ,("<", (5, (6, 6))) -- Note: this precedence is used in the printing of SLoop + ,(">", (5, (6, 6))) + ,("<=", (5, (6, 6))) + ,(">=", (5, (6, 6))) + ,("+", (6, (6, 7))) + ,("-", (6, (6, 7))) + ,("*", (7, (7, 8))) + ,("/", (7, (7, 8))) + ,("%", (7, (7, 8)))] + +repSTy :: STy t -> String +repSTy (STScal st) = case st of + STI32 -> "int32_t" + STI64 -> "int64_t" + STF32 -> "float" + STF64 -> "double" + STBool -> "uint8_t" +repSTy t = genStructName t + +genStructName, genArrBufStructName :: STy t -> String +(genStructName, genArrBufStructName) = + (\t -> "ty_" ++ gen t + ,\case STArr _ t -> "ty_A_" ++ gen t ++ "_buf" -- just like the normal type, but with _ for the dimension + t -> error $ "genArrBufStructName: not an array type: " ++ show t) + where + -- all tags start with a letter, so the array mangling is unambiguous. + gen :: STy t -> String + gen STNil = "n" + gen (STPair a b) = 'P' : gen a ++ gen b + gen (STEither a b) = 'E' : gen a ++ gen b + gen (STLEither a b) = 'L' : gen a ++ gen b + gen (STMaybe t) = 'M' : gen t + gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t + gen (STScal st) = case st of + STI32 -> "i" + STI64 -> "j" + STF32 -> "f" + STF64 -> "d" + STBool -> "b" + gen (STAccum t) = 'C' : gen (fromSMTy t) + +-- The subtrees contain structs used in the bodies of the structs in this node. +data StructTree = TreeNode [StructDecl] [StructTree] + deriving (Show) + +-- | This function generates the actual struct declarations for each of the +-- types in our language. It thus implicitly "documents" the layout of the +-- types in the C translation. +-- +-- For accumulation it is important that for struct representations of monoid +-- types, the all-zero-bytes value corresponds to the zero value of that type. +buildStructTree :: STy t -> StructTree +buildStructTree topty = case topty of + STNil -> + TreeNode [StructDecl name "" com] [] + STPair a b -> + TreeNode [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com] + [buildStructTree a, buildStructTree b] + STEither a b -> -- 0 -> l, 1 -> r + TreeNode [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] + [buildStructTree a, buildStructTree b] + STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r + TreeNode [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] + [buildStructTree a, buildStructTree b] + STMaybe t -> -- 0 -> nothing, 1 -> just + TreeNode [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com] + [buildStructTree t] + STArr n t -> + -- The buffer is trailed by a VLA for the actual array data. + -- TODO: no buffer if n = 0 + TreeNode [StructDecl (genArrBufStructName topty) ("size_t refc; " ++ repSTy t ++ " xs[];") "" + ,StructDecl name (genArrBufStructName topty ++ " *buf; size_t sh[" ++ show (fromSNat n) ++ "];") com] + [buildStructTree t] + STScal _ -> + TreeNode [] [] + STAccum t -> + TreeNode [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") "" + ,StructDecl name (name ++ "_buf *buf;") com] + [buildStructTree (fromSMTy t)] + where + name = genStructName topty + com = ppSTy 0 topty + +-- State: already-generated (skippable) struct names +-- Writer: the structs in declaration order +genStructTreeW :: StructTree -> WriterT (Bag StructDecl) (State (Set String)) () +genStructTreeW (TreeNode these deps) = do + seen <- lift get + case filter ((`Set.notMember` seen) . nameOf) these of + [] -> pure () + structs -> do + lift $ modify (Set.fromList (map nameOf structs) <>) + mapM_ genStructTreeW deps + tell (BList structs) + where + nameOf (StructDecl name _ _) = name + +genAllStructs :: Foldable t => t (Some STy) -> [StructDecl] +genAllStructs tys = + let m = mapM_ (\(Some t) -> genStructTreeW (buildStructTree t)) tys + in toList (evalState (execWriterT m) mempty) + +data CompState = CompState + { csStructs :: Set (Some STy) + , csTopLevelDecls :: Bag String + , csStmts :: Bag Stmt + , csNextId :: Int } + deriving (Show) + +newtype CompM a = CompM (State CompState a) + deriving newtype (Functor, Applicative, Monad) + +runCompM :: CompM a -> (a, CompState) +runCompM (CompM m) = runState m (CompState mempty mempty mempty 1) + +class Monad m => MonadNameGen m where genId :: m Int +instance MonadNameGen CompM where genId = CompM $ state $ \s -> (csNextId s, s { csNextId = csNextId s + 1 }) +instance MonadNameGen IdGen.IdGen where genId = IdGen.genId +instance MonadNameGen m => MonadNameGen (MaybeT m) where genId = MaybeT (Just <$> genId) + +genName' :: MonadNameGen m => String -> m String +genName' "" = genName +genName' prefix = (prefix ++) . show <$> genId + +genName :: MonadNameGen m => m String +genName = genName' "x" + +onlyIdGen :: IdGen.IdGen a -> CompM a +onlyIdGen m = CompM $ do + i1 <- gets csNextId + let (res, i2) = IdGen.runIdGen' i1 m + modify (\s -> s { csNextId = i2 }) + return res + +emit :: Stmt -> CompM () +emit stmt = CompM $ modify $ \s -> s { csStmts = csStmts s <> pure stmt } + +scope :: CompM a -> CompM (a, Bag Stmt) +scope m = do + stmts <- CompM $ state $ \s -> (csStmts s, s { csStmts = mempty }) + res <- m + innerStmts <- CompM $ state $ \s -> (csStmts s, s { csStmts = stmts }) + return (res, innerStmts) + +emitStruct :: STy t -> CompM String +emitStruct ty = CompM $ do + modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) } + return (genStructName ty) + +-- | Also returns the name of the array buffer struct +emitArrStruct :: STy t -> CompM (String, String) +emitArrStruct ty = CompM $ do + modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) } + return (genStructName ty, genArrBufStructName ty) + +emitTLD :: String -> CompM () +emitTLD decl = CompM $ modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl } + +nameEnv :: SList f env -> SList (Const String) env +nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1)) + +data KernelOffsets = KernelOffsets + { koArgOffsets :: [Int] -- ^ the function arguments + , koOkResOffset :: Int -- ^ a byte: 1 if successful execution, 0 if (fatal) error occurred + , koResultOffset :: Int -- ^ the function result + } + +compileToString :: Int -> SList STy env -> Ex env t -> (String, KernelOffsets) +compileToString codeID env expr = + let args = nameEnv env + (res, s) = runCompM (compile' args expr) + structs = genAllStructs (csStructs s <> Set.fromList (unSList Some env)) + + (arg_pairs, arg_metrics) = + unzip $ reverse (unSList (\(Product.Pair t (Const n)) -> ((n, repSTy t), metricsSTy t)) + (slistZip env args)) + (arg_offsets, okres_offset) = computeStructOffsets arg_metrics + result_offset = align (alignmentSTy (typeOf expr)) (okres_offset + 1) + + offsets = KernelOffsets + { koArgOffsets = arg_offsets + , koOkResOffset = okres_offset + , koResultOffset = result_offset } + in (,offsets) . ($ "") $ compose + [showString "#include <stdio.h>\n" + ,showString "#include <stdint.h>\n" + ,showString "#include <stdbool.h>\n" + ,showString "#include <inttypes.h>\n" + ,showString "#include <stdlib.h>\n" + ,showString "#include <string.h>\n" + ,showString "#include <math.h>\n\n" + -- PRint-tag + ,showString $ "#define PRTAG \"[chad-kernel" ++ show codeID ++ "] \"\n\n" + + ,compose [printStructDecl sd . showString "\n" | sd <- structs] + ,showString "\n" + + -- Using %zd and not %zu here because values > SIZET_MAX/2 should be recognisable as "negative" + ,showString "static void* malloc_instr_fun(size_t n, int line) {\n" + ,showString " void *ptr = malloc(n);\n" + ,if debugAllocs then showString " printf(PRTAG \":%d malloc(%zd) -> %p\\n\", line, n, ptr);\n" + else id + ,if emitChecks then showString " if (ptr == NULL) { printf(PRTAG \"malloc(%zd) returned NULL on line %d\\n\", n, line); return false; }\n" + else id + ,showString " return ptr;\n" + ,showString "}\n" + ,showString "#define malloc_instr(n) ({void *ptr_ = malloc_instr_fun(n, __LINE__); if (ptr_ == NULL) return false; ptr_;})\n" + ,showString "static void* calloc_instr_fun(size_t n, int line) {\n" + ,showString " void *ptr = calloc(n, 1);\n" + ,if debugAllocs then showString " printf(PRTAG \":%d calloc(%zd) -> %p\\n\", line, n, ptr);\n" + else id + ,if emitChecks then showString " if (ptr == NULL) { printf(PRTAG \"calloc(%zd, 1) returned NULL on line %d\\n\", n, line); return false; }\n" + else id + ,showString " return ptr;\n" + ,showString "}\n" + ,showString "#define calloc_instr(n) ({void *ptr_ = calloc_instr_fun(n, __LINE__); if (ptr_ == NULL) return false; ptr_;})\n" + ,showString "static void free_instr(void *ptr) {\n" + ,if debugAllocs then showString "printf(PRTAG \"free(%p)\\n\", ptr);\n" + else id + ,showString " free(ptr);\n" + ,showString "}\n\n" + + ,compose [showString str . showString "\n\n" | str <- toList (csTopLevelDecls s)] + + ,showString $ + "static bool typed_kernel(" ++ + repSTy (typeOf expr) ++ " *output" ++ + concatMap (", " ++) + (reverse (unSList (\(Product.Pair t (Const n)) -> repSTy t ++ " " ++ n) (slistZip env args))) ++ + ") {\n" + ,compose [showString " " . printStmt 1 st . showString "\n" | st <- toList (csStmts s)] + ,showString " *output = " . printCExpr 0 res . showString ";\n" + ,showString " return true;\n" + ,showString "}\n\n" + + ,showString "void kernel(void *data) {\n" + -- Some code here assumes that we're on a 64-bit system, so let's check that + ,showString $ " if (sizeof(void*) != 8 || sizeof(size_t) != 8) { fprintf(stderr, \"Only 64-bit systems supported\\n\"); *(uint8_t*)(data + " ++ show okres_offset ++ ") = 0; return; }\n" + ,if debugRefc then showString " fprintf(stderr, PRTAG \"Start\\n\");\n" + else id + ,showString $ " const bool success = typed_kernel(" ++ + "\n (" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ")" ++ + concat (map (\((arg, typ), off) -> + ",\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")" + ++ " /* " ++ arg ++ " */") + (zip arg_pairs arg_offsets)) ++ + "\n );\n" + ,showString $ " *(uint8_t*)(data + " ++ show okres_offset ++ ") = success;\n" + ,if debugRefc then showString " fprintf(stderr, PRTAG \"Return\\n\");\n" + else id + ,showString "}\n"] + +-- | Takes list of metrics (alignment, sizeof). +-- Returns (offsets, size of struct). +computeStructOffsets :: [(Int, Int)] -> ([Int], Int) +computeStructOffsets = go 0 0 + where + go off maxal [(al, sz)] = + ([off], align (max maxal al) (off + sz)) + go off maxal ((al, sz) : pairs@((al2,_):_)) = + first (off :) $ go (align al2 (off + sz)) (max maxal al) pairs + go _ _ [] = ([], 0) + +-- | Assumes that this is called at the correct alignment. +serialise :: STy t -> Rep t -> Ptr () -> Int -> IO r -> IO r +serialise topty topval ptr off k = + -- TODO: this code is quadratic in the depth of the type because of the alignment/sizeOf calls + case (topty, topval) of + (STNil, ()) -> k + (STPair a b, (x, y)) -> + serialise a x ptr off $ + serialise b y ptr (align (alignmentSTy b) (off + sizeofSTy a)) k + (STEither a _, Left x) -> do + pokeByteOff ptr off (0 :: Word8) -- alignment of (union {a b}) is the same as alignment of (a + b) + serialise a x ptr (off + alignmentSTy topty) k + (STEither _ b, Right y) -> do + pokeByteOff ptr off (1 :: Word8) + serialise b y ptr (off + alignmentSTy topty) k + (STLEither _ _, Nothing) -> do + pokeByteOff ptr off (0 :: Word8) + k + (STLEither a _, Just (Left x)) -> do + pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b) + serialise a x ptr (off + alignmentSTy topty) k + (STLEither _ b, Just (Right y)) -> do + pokeByteOff ptr off (2 :: Word8) + serialise b y ptr (off + alignmentSTy topty) k + (STMaybe _, Nothing) -> do + pokeByteOff ptr off (0 :: Word8) + k + (STMaybe t, Just x) -> do + pokeByteOff ptr off (1 :: Word8) + serialise t x ptr (off + alignmentSTy t) k + (STArr n t, Array sh vec) -> do + let eltsz = sizeofSTy t + allocaBytes (8 + shapeSize sh * eltsz) $ \bufptr -> do + when debugRefc $ + hPutStrLn stderr $ "[chad-serialise] Allocating input buffer " ++ showPtr bufptr + pokeByteOff ptr off bufptr + pokeShape ptr (off + 8) n sh + + pokeByteOff @Word64 bufptr 0 (2 ^ 63) + + let loop i + | i == shapeSize sh = k + | otherwise = + serialise t (vec V.! i) bufptr (8 + i * eltsz) $ + loop (i+1) + loop 0 + (STScal sty, x) -> case sty of + STI32 -> pokeByteOff ptr off (x :: Int32) >> k + STI64 -> pokeByteOff ptr off (x :: Int64) >> k + STF32 -> pokeByteOff ptr off (x :: Float) >> k + STF64 -> pokeByteOff ptr off (x :: Double) >> k + STBool -> pokeByteOff ptr off (fromIntegral (fromEnum x) :: Word8) >> k + (STAccum{}, _) -> error "Cannot serialise accumulators" + +-- | Assumes that this is called at the correct alignment. +deserialise :: STy t -> Ptr () -> Int -> IO (Rep t) +deserialise topty ptr off = + -- TODO: this code is quadratic in the depth of the type because of the alignment/sizeOf calls + case topty of + STNil -> return () + STPair a b -> do + x <- deserialise a ptr off + y <- deserialise b ptr (align (alignmentSTy b) (off + sizeofSTy a)) + return (x, y) + STEither a b -> do + tag <- peekByteOff @Word8 ptr off + if tag == 0 -- alignment of (union {a b}) is the same as alignment of (a + b) + then Left <$> deserialise a ptr (off + alignmentSTy topty) + else Right <$> deserialise b ptr (off + alignmentSTy topty) + STLEither a b -> do + tag <- peekByteOff @Word8 ptr off + case tag of -- alignment of (union {a b}) is the same as alignment of (a + b) + 0 -> return Nothing + 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty) + 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty) + _ -> error "Invalid tag value" + STMaybe t -> do + tag <- peekByteOff @Word8 ptr off + if tag == 0 + then return Nothing + else Just <$> deserialise t ptr (off + alignmentSTy t) + STArr n t -> do + bufptr <- peekByteOff @(Ptr ()) ptr off + sh <- peekShape ptr (off + 8) n + refc <- peekByteOff @Word64 bufptr 0 + when debugRefc $ + hPutStrLn stderr $ "[chad-deserialise] Got buffer " ++ showPtr bufptr ++ " at refc=" ++ show refc + let eltsz = sizeofSTy t + arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (8 + i * eltsz)) + when (refc < 2 ^ 62) $ free bufptr + return arr + STScal sty -> case sty of + STI32 -> peekByteOff @Int32 ptr off + STI64 -> peekByteOff @Int64 ptr off + STF32 -> peekByteOff @Float ptr off + STF64 -> peekByteOff @Double ptr off + STBool -> toEnum . fromIntegral <$> peekByteOff @Word8 ptr off + STAccum{} -> error "Cannot serialise accumulators" + +align :: Int -> Int -> Int +align a off = (off + a - 1) `div` a * a + +alignmentSTy :: STy t -> Int +alignmentSTy = fst . metricsSTy + +sizeofSTy :: STy t -> Int +sizeofSTy = snd . metricsSTy + +-- | Returns (alignment, sizeof) +metricsSTy :: STy t -> (Int, Int) +metricsSTy STNil = (1, 0) +metricsSTy (STPair a b) = + let (a1, s1) = metricsSTy a + (a2, s2) = metricsSTy b + in (max a1 a2, align (max a1 a2) (s1 + s2)) +metricsSTy (STEither a b) = + let (a1, s1) = metricsSTy a + (a2, s2) = metricsSTy b + in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned +metricsSTy (STLEither a b) = + let (a1, s1) = metricsSTy a + (a2, s2) = metricsSTy b + in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned +metricsSTy (STMaybe t) = + let (a, s) = metricsSTy t + in (a, a + s) -- the union after the tag byte is aligned +metricsSTy (STArr n _) = (8, 8 + 8 * fromSNat n) +metricsSTy (STScal sty) = case sty of + STI32 -> (4, 4) + STI64 -> (8, 8) + STF32 -> (4, 4) + STF64 -> (8, 8) + STBool -> (1, 1) -- compiled to uint8_t +metricsSTy (STAccum t) = metricsSTy (fromSMTy t) + +pokeShape :: Ptr () -> Int -> SNat n -> Shape n -> IO () +pokeShape ptr off = go . fromSNat + where + go :: Int -> Shape n -> IO () + go rank = \case + ShNil -> return () + sh `ShCons` n -> do + pokeByteOff ptr (off + (rank - 1) * 8) (fromIntegral n :: Int64) + go (rank - 1) sh + +peekShape :: Ptr () -> Int -> SNat n -> IO (Shape n) +peekShape ptr off = \case + SZ -> return ShNil + SS n -> ShCons <$> peekShape ptr off n + <*> (fromIntegral <$> peekByteOff @Int64 ptr (off + (fromSNat n) * 8)) + +compile' :: SList (Const String) env -> Ex env t -> CompM CExpr +compile' env = \case + EVar _ t i -> do + let Const var = slistIdx env i + incrementVarAlways "var" Increment t var + return $ CELit var + + ELet _ rhs body -> do + var <- compileAssign "" env rhs + rete <- compile' (Const var `SCons` env) body + incrementVarAlways "let" Decrement (typeOf rhs) var + return rete + + EPair _ a b -> do + name <- emitStruct (STPair (typeOf a) (typeOf b)) + e1 <- compile' env a + e2 <- compile' env b + return $ CEStruct name [("a", e1), ("b", e2)] + + EFst _ e -> do + let STPair _ t2 = typeOf e + e' <- compile' env e + case incrementVar "fst" Decrement t2 of + Nothing -> return $ CEProj e' "a" + Just f -> do var <- genName + emit $ SVarDecl True (repSTy (typeOf e)) var e' + f (var ++ ".b") + return $ CEProj (CELit var) "a" + + ESnd _ e -> do + let STPair t1 _ = typeOf e + e' <- compile' env e + case incrementVar "snd" Decrement t1 of + Nothing -> return $ CEProj e' "b" + Just f -> do var <- genName + emit $ SVarDecl True (repSTy (typeOf e)) var e' + f (var ++ ".a") + return $ CEProj (CELit var) "b" + + ENil _ -> do + name <- emitStruct STNil + return $ CEStruct name [] + + EInl _ t e -> do + name <- emitStruct (STEither (typeOf e) t) + e1 <- compile' env e + return $ CEStruct name [("tag", CELit "0"), ("l", e1)] + + EInr _ t e -> do + name <- emitStruct (STEither t (typeOf e)) + e2 <- compile' env e + return $ CEStruct name [("tag", CELit "1"), ("r", e2)] + + ECase _ (EOp _ OIf e) a b -> do + e1 <- compile' env e + (e2, stmts2) <- scope $ compile' (Const undefined `SCons` env) a -- don't access that nil, stupid you + (e3, stmts3) <- scope $ compile' (Const undefined `SCons` env) b + retvar <- genName + emit $ SVarDeclUninit (repSTy (typeOf a)) retvar + emit $ SIf e1 + (stmts2 <> pure (SAsg retvar e2)) + (stmts3 <> pure (SAsg retvar e3)) + return (CELit retvar) + + ECase _ e a b -> do + let STEither t1 t2 = typeOf e + e1 <- compile' env e + var <- genName + -- I know those are not variable names, but it's fine, probably + (e2, stmts2) <- scope $ compile' (Const (var ++ ".l") `SCons` env) a + (e3, stmts3) <- scope $ compile' (Const (var ++ ".r") `SCons` env) b + ((), stmtsRel1) <- scope $ incrementVarAlways "case1" Decrement t1 (var ++ ".l") + ((), stmtsRel2) <- scope $ incrementVarAlways "case2" Decrement t2 (var ++ ".r") + retvar <- genName + emit $ SVarDeclUninit (repSTy (typeOf a)) retvar + emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1) + <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) + (stmts2 + <> stmtsRel1 + <> pure (SAsg retvar e2)) + (stmts3 + <> stmtsRel2 + <> pure (SAsg retvar e3)))) + return (CELit retvar) + + ENothing _ t -> do + name <- emitStruct (STMaybe t) + return $ CEStruct name [("tag", CELit "0")] + + EJust _ e -> do + name <- emitStruct (STMaybe (typeOf e)) + e1 <- compile' env e + return $ CEStruct name [("tag", CELit "1"), ("j", e1)] + + EMaybe _ a b e -> do + let STMaybe t = typeOf e + e1 <- compile' env e + var <- genName + (e2, stmts2) <- scope $ compile' env a + (e3, stmts3) <- scope $ compile' (Const (var ++ ".j") `SCons` env) b + ((), stmtsRel) <- scope $ incrementVarAlways "maybe" Decrement t (var ++ ".j") + retvar <- genName + emit $ SVarDeclUninit (repSTy (typeOf a)) retvar + emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1) + <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) + (stmts2 + <> pure (SAsg retvar e2)) + (stmts3 + <> stmtsRel + <> pure (SAsg retvar e3)))) + return (CELit retvar) + + ELNil _ t1 t2 -> do + name <- emitStruct (STLEither t1 t2) + return $ CEStruct name [("tag", CELit "0")] + + ELInl _ t e -> do + name <- emitStruct (STLEither (typeOf e) t) + e1 <- compile' env e + return $ CEStruct name [("tag", CELit "1"), ("l", e1)] + + ELInr _ t e -> do + name <- emitStruct (STLEither t (typeOf e)) + e1 <- compile' env e + return $ CEStruct name [("tag", CELit "2"), ("r", e1)] + + ELCase _ e a b c -> do + let STLEither t1 t2 = typeOf e + e1 <- compile' env e + var <- genName + (e2, stmts2) <- scope $ compile' env a + (e3, stmts3) <- scope $ compile' (Const (var ++ ".l") `SCons` env) b + (e4, stmts4) <- scope $ compile' (Const (var ++ ".r") `SCons` env) c + ((), stmtsRel1) <- scope $ incrementVarAlways "lcase1" Decrement t1 (var ++ ".l") + ((), stmtsRel2) <- scope $ incrementVarAlways "lcase2" Decrement t2 (var ++ ".r") + retvar <- genName + emit $ SVarDeclUninit (repSTy (typeOf a)) retvar + emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1) + <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) + (stmts2 <> pure (SAsg retvar e2)) + (pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "1")) + (stmts3 <> stmtsRel1 <> pure (SAsg retvar e3)) + (stmts4 <> stmtsRel2 <> pure (SAsg retvar e4)))))) + return (CELit retvar) + + EConstArr _ n t (Array sh vec) -> do + (strname, bufstrname) <- emitArrStruct (STArr n (STScal t)) + tldname <- genName' "carraybuf" + -- Give it a refcount of _half_ the size_t max, so that it can be + -- incremented and decremented at will and will "never" reach anything + -- where something happens + emitTLD $ "static " ++ bufstrname ++ " " ++ tldname ++ " = " ++ + "(" ++ bufstrname ++ "){.refc = (size_t)1<<63, " ++ + ".xs = {" ++ intercalate "," (map (compileScal False t) (toList vec)) ++ "}};" + return (CEStruct strname + [("buf", CEAddrOf (CELit tldname)) + ,("sh", CELit ("{" ++ intercalate "," (map show (shapeToList sh)) ++ "}"))]) + + EBuild _ n esh efun -> do + shname <- compileAssign "sh" env esh + + arrname <- allocArray "build" Malloc "arr" n (typeOf efun) Nothing (indexTupleComponents n shname) + + idxargname <- genName' "ix" + (funretval, funstmts) <- scope $ compile' (Const idxargname `SCons` env) efun + + linivar <- genName' "li" + ivars <- replicateM (fromSNat n) (genName' "i") + emit $ SBlock $ + pure (SVarDecl False "size_t" linivar (CELit "0")) + <> compose [pure . SLoop (repSTy tIx) ivar (CELit "0") + (CECast (repSTy tIx) (CEIndex (CELit (arrname ++ ".sh")) (CELit (show dimidx)))) + | (ivar, dimidx) <- zip ivars [0::Int ..]] + (pure (SVarDecl True (repSTy (typeOf esh)) idxargname + (shapeTupFromLitVars n ivars)) + <> funstmts + <> pure (SAsg (arrname ++ ".buf->xs[" ++ linivar ++ "++]") funretval)) + + return (CELit arrname) + + -- TODO: actually generate decent code here + EMap _ e1 e2 -> do + let STArr n _ = typeOf e2 + compile' env $ + elet e2 $ + EBuild ext n (EShape ext (evar IZ)) $ + elet (EIdx ext (evar (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e1 + + EFold1Inner _ commut efun ex0 earr -> do + let STArr (SS n) t = typeOf earr + + -- let vecwid = case commut of Commut -> 8 :: Int + -- Noncommut -> 1 + + x0name <- compileAssign "foldx0" env ex0 + arrname <- compileAssign "foldarr" env earr + + zeroRefcountCheck (typeOf earr) "fold1i" arrname + + shszname <- genName' "shsz" + -- This n is one less than the shape of the thing we're querying, which is + -- unexpected. But it's exactly what we want, so we do it anyway. + emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n arrname) + + resname <- allocArray "fold" Malloc "foldres" n t (Just (CELit shszname)) (compileArrShapeComponents n arrname) + + lenname <- genName' "n" + emit $ SVarDecl True (repSTy tIx) lenname + (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]")) + + ((), x0incrStmts) <- scope $ incrementVarAlways "foldx0" Increment t x0name + + ivar <- genName' "i" + jvar <- genName' "j" + -- kvar <- if vecwid > 1 then genName' "k" else return "" + + accvar <- genName' "tot" + pairvar <- genName' "pair" -- function input + (funres, funStmts) <- scope $ compile' (Const pairvar `SCons` env) efun + + let arreltlit = arrname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ + ({- if vecwid > 1 then show vecwid ++ " * " ++ jvar ++ " + " ++ kvar else -} jvar) ++ "]" + ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldelt" Increment t arreltlit + + pairstrname <- emitStruct (STPair t t) + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ + pure (SVarDecl False (repSTy t) accvar (CELit x0name)) + <> x0incrStmts -- we're copying x0 here + <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + -- The combination function will consume the array element + -- and the accumulator. The accumulator is replaced by + -- what comes out of the function anyway, so that's + -- fine, but we do need to increment the array element. + arreltIncrStmts + <> pure (SVarDecl True pairstrname pairvar (CEStruct pairstrname [("a", CELit accvar), ("b", CELit arreltlit)])) + <> funStmts + <> pure (SAsg accvar funres)) + <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) + + incrementVarAlways "foldx0" Decrement t x0name + incrementVarAlways "foldarr" Decrement (typeOf earr) arrname + + return (CELit resname) + + ESum1Inner _ e -> do + let STArr (SS n) t = typeOf e + argname <- compileAssign "sumarg" env e + + zeroRefcountCheck (typeOf e) "sum1i" argname + + shszname <- genName' "shsz" + -- This n is one less than the shape of the thing we're querying, like EFold1Inner. + emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) + + resname <- allocArray "sum" Malloc "sumres" n t (Just (CELit shszname)) (compileArrShapeComponents n argname) + + lenname <- genName' "n" + emit $ SVarDecl True (repSTy tIx) lenname + (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]")) + + let vecwid = 8 :: Int + ivar <- genName' "i" + jvar <- genName' "j" + kvar <- genName' "k" + accvar <- genName' "tot" + let nchunks = CEBinop (CELit lenname) "/" (CELit (show vecwid)) + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ BList + -- we have ScalIsNumeric, so it has 0 and (+) in C + [SVerbatim $ repSTy t ++ " " ++ accvar ++ "[" ++ show vecwid ++ "] = {" ++ intercalate "," (replicate vecwid "0") ++ "};" + ,SLoop (repSTy tIx) jvar (CELit "0") nchunks $ + pure $ SLoop (repSTy tIx) kvar (CELit "0") (CELit (show vecwid)) $ + pure $ SVerbatim $ accvar ++ "[" ++ kvar ++ "] += " ++ argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ show vecwid ++ " * " ++ jvar ++ " + " ++ kvar ++ "];" + ,SLoop (repSTy tIx) kvar (CELit "1") (CELit (show vecwid)) $ + pure $ SVerbatim $ accvar ++ "[0] += " ++ accvar ++ "[" ++ kvar ++ "];" + ,SLoop (repSTy tIx) kvar (CEBinop nchunks "*" (CELit (show vecwid))) (CELit lenname) $ + pure $ SVerbatim $ accvar ++ "[0] += " ++ argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ kvar ++ "];" + ,SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit (accvar++"[0]"))] + + incrementVarAlways "sum" Decrement (typeOf e) argname + + return (CELit resname) + + EUnit _ e -> do + e' <- compile' env e + let typ = STArr SZ (typeOf e) + strname <- emitStruct typ + name <- genName + emit $ SVarDecl True strname name (CEStruct strname + [("buf", CECall "malloc_instr" [CELit (show (8 + sizeofSTy (typeOf e)))])]) + emit $ SAsg (name ++ ".buf->refc") (CELit "1") + emit $ SAsg (name ++ ".buf->xs[0]") e' + return (CELit name) + + EReplicate1Inner _ elen earg -> do + let STArr n t = typeOf earg + lenname <- compileAssign "replen" env elen + argname <- compileAssign "reparg" env earg + + zeroRefcountCheck (typeOf earg) "replicate1i" argname + + shszname <- genName' "shsz" + emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) + + resname <- allocArray "repl1i" Malloc "rep" (SS n) t + (Just (CEBinop (CELit shszname) "*" (CELit lenname))) + (compileArrShapeComponents n argname ++ [CELit lenname]) + + ivar <- genName' "i" + jvar <- genName' "j" + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ + pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + pure $ SAsg (resname ++ ".buf->xs[" ++ ivar ++ " * " ++ lenname ++ " + " ++ jvar ++ "]") + (CELit (argname ++ ".buf->xs[" ++ ivar ++ "]")) + + incrementVarAlways "repl1i" Decrement (typeOf earg) argname + + return (CELit resname) + + EMaximum1Inner _ e -> compileExtremum "max" "maximum1i" ">" env e + + EMinimum1Inner _ e -> compileExtremum "min" "minimum1i" "<" env e + + EReshape _ dim esh earg -> do + let STArr origDim eltty = typeOf earg + strname <- emitStruct (STArr dim eltty) + + shname <- compileAssign "reshsh" env esh + arrname <- compileAssign "resharg" env earg + + when emitChecks $ do + emit $ SIf (CEBinop (compileArrShapeSize origDim arrname) "!=" (CECast "size_t" (prodExpr (indexTupleComponents dim shname)))) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: reshape on unequal sizes (%zu <- %zu)\\n\", " ++ + printCExpr 0 (prodExpr (indexTupleComponents dim shname)) ", " ++ + printCExpr 0 (compileArrShapeSize origDim arrname) "); return false;") + mempty + + return (CEStruct strname + [("buf", CEProj (CELit arrname) "buf") + ,("sh", CELit ("{" ++ intercalate ", " [printCExpr 0 e "" | e <- indexTupleComponents dim shname] ++ "}"))]) + + -- TODO: actually generate decent code here + EZip _ e1 e2 -> do + let STArr n _ = typeOf e1 + compile' env $ + elet e1 $ + elet (weakenExpr WSink e2) $ + EBuild ext n (EShape ext (evar (IS IZ))) $ + EPair ext (EIdx ext (evar (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) IZ)) + (EIdx ext (evar (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ)) + + EFold1InnerD1 _ commut efun ex0 earr -> do + let STArr (SS n) t = typeOf earr + STPair _ bty = typeOf efun + + x0name <- compileAssign "foldd1x0" env ex0 + arrname <- compileAssign "foldd1arr" env earr + + zeroRefcountCheck (typeOf earr) "fold1iD1" arrname + + lenname <- genName' "n" + emit $ SVarDecl True (repSTy tIx) lenname + (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]")) + + shsz1name <- genName' "shszN" + emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n arrname) -- take init of arr's shape + shsz2name <- genName' "shszSN" + emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname)) + + resname <- allocArray "foldd1" Malloc "foldd1res" n t (Just (CELit shsz1name)) (compileArrShapeComponents n arrname) + storesname <- allocArray "foldd1" Malloc "foldd1stores" (SS n) bty (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) arrname) + + ((), x0incrStmts) <- scope $ incrementVarAlways "foldd1x0" Increment t x0name + + ivar <- genName' "i" + jvar <- genName' "j" + + accvar <- genName' "tot" + pairvar <- genName' "pair" -- function input + (funres, funStmts) <- scope $ compile' (Const pairvar `SCons` env) efun + let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ jvar + arreltlit = arrname ++ ".buf->xs[" ++ eltidx ++ "]" + funresvar <- genName' "res" + ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldd1elt" Increment t arreltlit + + pairstrname <- emitStruct (STPair t t) + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $ + pure (SVarDecl False (repSTy t) accvar (CELit x0name)) + <> x0incrStmts -- we're copying x0 here + <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + -- The combination function will consume the array element + -- and the accumulator. The accumulator is replaced by + -- what comes out of the function anyway, so that's + -- fine, but we do need to increment the array element. + arreltIncrStmts + <> pure (SVarDecl True pairstrname pairvar (CEStruct pairstrname [("a", CELit accvar), ("b", CELit arreltlit)])) + <> funStmts + <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres) + <> pure (SAsg accvar (CEProj (CELit funresvar) "a")) + <> pure (SAsg (storesname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b"))) + <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) + + incrementVarAlways "foldd1x0" Decrement t x0name + incrementVarAlways "foldd1arr" Decrement (typeOf earr) arrname + + strname <- emitStruct (STPair (STArr n t) (STArr (SS n) bty)) + return (CEStruct strname [("a", CELit resname), ("b", CELit storesname)]) + + EFold1InnerD2 _ commut efun estores ectg -> do + let STArr n t2 = typeOf ectg + STArr _ bty = typeOf estores + + storesname <- compileAssign "foldd2stores" env estores + ctgname <- compileAssign "foldd2ctg" env ectg + + zeroRefcountCheck (typeOf ectg) "fold1iD2" ctgname + + lenname <- genName' "n" + emit $ SVarDecl True (repSTy tIx) lenname + (CELit (storesname ++ ".sh[" ++ show (fromSNat n) ++ "]")) + + shsz1name <- genName' "shszN" + emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n storesname) -- take init of the shape + shsz2name <- genName' "shszSN" + emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname)) + + x0ctgname <- allocArray "foldd2" Malloc "foldd2x0ctg" n t2 (Just (CELit shsz1name)) (compileArrShapeComponents n storesname) + outctgname <- allocArray "foldd2" Malloc "foldd2outctg" (SS n) t2 (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) storesname) + + ivar <- genName' "i" + jvar <- genName' "j" + + accvar <- genName' "acc" + let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ lenname ++ "-1 - " ++ jvar + storeseltlit = storesname ++ ".buf->xs[" ++ eltidx ++ "]" + ctgeltlit = ctgname ++ ".buf->xs[" ++ ivar ++ "]" + (funres, funStmts) <- scope $ compile' (Const accvar `SCons` Const storeseltlit `SCons` env) efun + funresvar <- genName' "res" + ((), storeseltIncrStmts) <- scope $ incrementVarAlways "foldd2selt" Increment bty storeseltlit + ((), ctgeltIncrStmts) <- scope $ incrementVarAlways "foldd2celt" Increment bty ctgeltlit + + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $ + pure (SVarDecl False (repSTy t2) accvar (CELit ctgeltlit)) + <> ctgeltIncrStmts + -- we need to loop in reverse here, but we let jvar run in the + -- forward direction so that we can use SLoop. Note jvar is + -- reversed in eltidx above + <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + -- The combination function will consume the accumulator + -- and the stores element. The accumulator is replaced by + -- what comes out of the function anyway, so that's + -- fine, but we do need to increment the stores element. + storeseltIncrStmts + <> funStmts + <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres) + <> pure (SAsg accvar (CEProj (CELit funresvar) "a")) + <> pure (SAsg (outctgname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b"))) + <> pure (SAsg (x0ctgname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) + + incrementVarAlways "foldd2stores" Decrement (STArr (SS n) bty) storesname + incrementVarAlways "foldd2ctg" Decrement (STArr n t2) ctgname + + strname <- emitStruct (STPair (STArr n t2) (STArr (SS n) t2)) + return (CEStruct strname [("a", CELit x0ctgname), ("b", CELit outctgname)]) + + EConst _ t x -> return $ CELit $ compileScal True t x + + EIdx0 _ e -> do + let STArr _ t = typeOf e + arrname <- compileAssign "" env e + zeroRefcountCheck (typeOf e) "idx0" arrname + name <- genName + emit $ SVarDecl True (repSTy t) name + (CEIndex (CEPtrProj (CEProj (CELit arrname) "buf") "xs") (CELit "0")) + incrementVarAlways "idx0" Decrement (STArr SZ t) arrname + return (CELit name) + + -- EIdx1 _ a b -> error "TODO" -- EIdx1 ext (compile' a) (compile' b) + + EIdx _ earr eidx -> do + let STArr n t = typeOf earr + arrname <- compileAssign "ixarr" env earr + zeroRefcountCheck (typeOf earr) "idx" arrname + idxname <- if fromSNat n > 0 -- prevent an unused-varable warning + then compileAssign "ixix" env eidx + else return "" -- won't be used in this case + + when emitChecks $ + forM_ (zip [0::Int ..] (indexTupleComponents n idxname)) $ \(i, ixcomp) -> + emit $ SIf (CEBinop (CEBinop ixcomp "<" (CELit "0")) "||" + (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".sh[" ++ show i ++ "]"))))) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: index out of range (arr=%p)\\n\", " ++ + arrname ++ ".buf); return false;") + mempty + + resname <- genName' "ixres" + emit $ SVarDecl True (repSTy t) resname (CEIndex (CELit (arrname ++ ".buf->xs")) (toLinearIdx n arrname idxname)) + incrementVarAlways "idxelt" Increment t resname + incrementVarAlways "idx" Decrement (STArr n t) arrname + return (CELit resname) + + EShape _ e -> do + let STArr n _ = typeOf e + t = tTup (sreplicate n tIx) + _ <- emitStruct t + name <- compileAssign "" env e + zeroRefcountCheck (typeOf e) "shape" name + resname <- genName + emit $ SVarDecl True (repSTy t) resname (compileShapeQuery n name) + incrementVarAlways "shape" Decrement (typeOf e) name + return (CELit resname) + + EOp _ op (EPair _ e1 e2) -> do + e1' <- compile' env e1 + e2' <- compile' env e2 + compileOpPair op e1' e2' + + EOp _ op e -> do + e' <- compile' env e + compileOpGeneral op e' + + ECustom _ _ _ _ earg _ _ e1 e2 -> do + name1 <- compileAssign "" env e1 + name2 <- compileAssign "" env e2 + case (incrementVar "custom1" Decrement (typeOf e1), incrementVar "custom2" Decrement (typeOf e2)) of + (Nothing, Nothing) -> compile' (Const name2 `SCons` Const name1 `SCons` SNil) earg + (mfun1, mfun2) -> do + name <- compileAssign "" (Const name2 `SCons` Const name1 `SCons` SNil) earg + maybe (return ()) ($ name1) mfun1 + maybe (return ()) ($ name2) mfun2 + return (CELit name) + + ERecompute _ e -> compile' env e + + EWith _ t e1 e2 -> do + actyname <- emitStruct (STAccum t) + name1 <- compileAssign "" env e1 + + zeroRefcountCheck (typeOf e1) "with" name1 + + emit $ SVerbatim $ "// copyForWriting start (" ++ name1 ++ ")" + mcopy <- copyForWriting t name1 + accname <- genName' "accum" + emit $ SVarDecl False actyname accname + (CEStruct actyname [("buf", CECall "malloc_instr" [CELit (show (sizeofSTy (fromSMTy t)))])]) + emit $ SAsg (accname++".buf->ac") (maybe (CELit name1) id mcopy) + emit $ SVerbatim $ "// initial accumulator constructed (" ++ name1 ++ ")." + + e2' <- compile' (Const accname `SCons` env) e2 + + resname <- genName' "acret" + emit $ SVarDecl True (repSTy (fromSMTy t)) resname (CELit (accname++".buf->ac")) + emit $ SVerbatim $ "free_instr(" ++ accname ++ ".buf);" + + rettyname <- emitStruct (STPair (typeOf e2) (fromSMTy t)) + return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)] + + EAccum _ t prj eidx sparsity eval eacc | Just Refl <- isDense (acPrjTy prj t) sparsity -> do + let -- Add a value (s) into an existing accumulation value (d). If a sparse + -- component of d is encountered, s is copied there. + add :: SMTy a -> String -> String -> CompM () + add SMTNil _ _ = return () + add (SMTPair t1 t2) d s = do + add t1 (d++".a") (s++".a") + add t2 (d++".b") (s++".b") + add (SMTLEither t1 t2) d s = do + ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTLEither t1 t2)) s + ((), stmts1) <- scope $ add t1 (d++".l") (s++".l") + ((), stmts2) <- scope $ add t2 (d++".r") (s++".r") + emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0")) + (pure (SAsg d (CELit s)) + <> srcIncrStmts) + ((if emitChecks + then pure (SIf (CEBinop (CEBinop (CELit (s++".tag")) "!=" (CELit "0")) + "&&" + (CEBinop (CELit (s++".tag")) "!=" (CELit (d++".tag")))) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum add leither with different tags " ++ + "(dest %d, src %d)\\n\", (int)" ++ d ++ ".tag, (int)" ++ s ++ ".tag); " ++ + "return false;") + mempty) + else mempty) + -- note: s may have tag 0 + <> pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) + stmts1 + (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "2")) + stmts2 mempty)))) + add (SMTMaybe t1) d s = do + ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTMaybe t1)) s + ((), stmts1) <- scope $ add t1 (d++".j") (s++".j") + emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0")) + (pure (SAsg d (CELit s)) + <> srcIncrStmts) + (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) stmts1 mempty)) + add (SMTArr n t1) d s = do + when emitChecks $ do + let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" + forM_ [0 .. fromSNat n - 1] $ \j -> do + emit $ SIf (CEBinop (CELit (s ++ ".sh[" ++ show j ++ "]")) + "!=" + (CELit (d ++ ".sh[" ++ show j ++ "]"))) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum add incorrect (d=%p, " ++ + "dsh=" ++ shfmt ++ ", s=%p, ssh=" ++ shfmt ++ ")\\n\", " ++ + d ++ ".buf" ++ + concat [", " ++ d ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ + ", " ++ s ++ ".buf" ++ + concat [", " ++ s ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ + "); " ++ + "return false;") + mempty + + shsizename <- genName' "acshsz" + emit $ SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n s) + ivar <- genName' "i" + ((), stmts1) <- scope $ add t1 (d++".buf->xs["++ivar++"]") (s++".buf->xs["++ivar++"]") + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename) + stmts1 + add (SMTScal _) d s = emit $ SVerbatim $ d ++ " += " ++ s ++ ";" + + let -- | Dereference an accumulation value and add a given value to that + -- position. Sparse components encountered along the way are + -- initialised before proceeding downwards. + -- accumRef (type) (projection) (accumulation component) (AcIdx variable) (value to accumulate there) + accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> String -> CompM () + accumRef _ SAPHere v _ addend = add (acPrjTy prj t) v addend + + accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") i addend + accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") i addend + + accumRef (SMTLEither ta _) (SAPLeft prj') v i addend = do + when emitChecks $ do + emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1")) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +left)\\n\", " ++ v ++ ".tag); " ++ + "return false;") + mempty + accumRef ta prj' (v++".l") i addend + accumRef (SMTLEither _ tb) (SAPRight prj') v i addend = do + when emitChecks $ do + emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "2")) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +right)\\n\", " ++ v ++ ".tag); " ++ + "return false;") + mempty + accumRef tb prj' (v++".r") i addend + + accumRef (SMTMaybe tj) (SAPJust prj') v i addend = do + when emitChecks $ do + emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1")) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (maybe tag=%d, +just)\\n\", " ++ v ++ ".tag); " ++ + "return false;") + mempty + accumRef tj prj' (v++".j") i addend + + accumRef (SMTArr n t') (SAPArrIdx prj') v i addend = do + when emitChecks $ do + let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" + forM_ (zip [0::Int ..] + (indexTupleComponents n (i++".a"))) $ \(j, ixcomp) -> do + let a .||. b = CEBinop a "||" b + emit $ SIf (CEBinop ixcomp "<" (CELit "0") + .||. + CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".sh[" ++ show j ++ "]")))) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++ + "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=(D))\\n\", " ++ + v ++ ".buf" ++ + concat [", " ++ v ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ + concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a")] ++ + "); " ++ + "return false;") + mempty + + accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a")) "]") (i++".b") addend + + nameidx <- compileAssign "acidx" env eidx + nameval <- compileAssign "acval" env eval + nameacc <- compileAssign "acac" env eacc + + emit $ SVerbatim $ "// compile EAccum start (" ++ show prj ++ ")" + accumRef t prj (nameacc++".buf->ac") nameidx nameval + emit $ SVerbatim $ "// compile EAccum end" + + incrementVarAlways "accumendsrc" Decrement (typeOf eval) nameval + + return $ CEStruct (repSTy STNil) [] + + EAccum{} -> + error "Compile: EAccum with non-trivial sparsity should have been eliminated (use AST.UnMonoid)" + + EError _ t s -> do + let padleft len c s' = replicate (len - length s) c ++ s' + escape = concatMap $ \c -> if | c `elem` "\"\\" -> ['\\',c] + | ord c < 32 -> "\\x" ++ padleft 2 '0' (showHex (ord c) "") + | otherwise -> [c] + emit $ SVerbatim $ "fputs(\"ERROR: " ++ escape s ++ "\\n\", stderr); return false;" + case t of + STScal _ -> return (CELit "0") + _ -> do + name <- emitStruct t + return $ CEStruct name [] + + EZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" + EDeepZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" + EPlus{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" + EOneHot{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" + + EIdx1{} -> error "Compile: not implemented: EIdx1" + +compileAssign :: String -> SList (Const String) env -> Ex env t -> CompM String +compileAssign prefix env e = do + e' <- compile' env e + case e' of + CELit name -> return name + _ -> do + name <- genName' prefix + emit $ SVarDecl True (repSTy (typeOf e)) name e' + return name + +data Increment = Increment | Decrement + deriving (Show) + +-- | Increment reference counts in the components of the given variable. +incrementVar :: String -> Increment -> STy a -> Maybe (String -> CompM ()) +incrementVar marker inc ty = + let tree = makeArrayTree ty + in case tree of ATNoop -> Nothing + _ -> Just $ \var -> incrementVar' marker inc var tree + +incrementVarAlways :: String -> Increment -> STy a -> String -> CompM () +incrementVarAlways marker inc ty var = maybe (pure ()) ($ var) (incrementVar marker inc ty) + +data ArrayTree = ATArray (Some SNat) (Some STy) -- ^ we've arrived at an array we need to decrement the refcount of (contains rank and element type of the array) + | ATNoop -- ^ don't do anything here + | ATProj String ArrayTree -- ^ descend one field deeper + | ATCondTag ArrayTree ArrayTree -- ^ if tag is 0, first; if 1, second + | ATCond3Tag ArrayTree ArrayTree ArrayTree -- ^ if tag is: 0, 1, 2 + | ATBoth ArrayTree ArrayTree -- ^ do both these paths + +smartATProj :: String -> ArrayTree -> ArrayTree +smartATProj _ ATNoop = ATNoop +smartATProj field t = ATProj field t + +smartATCondTag :: ArrayTree -> ArrayTree -> ArrayTree +smartATCondTag ATNoop ATNoop = ATNoop +smartATCondTag t t' = ATCondTag t t' + +smartATCond3Tag :: ArrayTree -> ArrayTree -> ArrayTree -> ArrayTree +smartATCond3Tag ATNoop ATNoop ATNoop = ATNoop +smartATCond3Tag t1 t2 t3 = ATCond3Tag t1 t2 t3 + +smartATBoth :: ArrayTree -> ArrayTree -> ArrayTree +smartATBoth ATNoop t = t +smartATBoth t ATNoop = t +smartATBoth t t' = ATBoth t t' + +makeArrayTree :: STy a -> ArrayTree +makeArrayTree STNil = ATNoop +makeArrayTree (STPair a b) = smartATBoth (smartATProj "a" (makeArrayTree a)) + (smartATProj "b" (makeArrayTree b)) +makeArrayTree (STEither a b) = smartATCondTag (smartATProj "l" (makeArrayTree a)) + (smartATProj "r" (makeArrayTree b)) +makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop + (smartATProj "l" (makeArrayTree a)) + (smartATProj "r" (makeArrayTree b)) +makeArrayTree (STMaybe t) = smartATCondTag ATNoop (smartATProj "j" (makeArrayTree t)) +makeArrayTree (STArr n t) = ATArray (Some n) (Some t) +makeArrayTree (STScal _) = ATNoop +makeArrayTree (STAccum _) = ATNoop + +incrementVar' :: String -> Increment -> String -> ArrayTree -> CompM () +incrementVar' marker inc path (ATArray (Some n) (Some eltty)) = + case inc of + Increment -> do + emit $ SVerbatim (path ++ ".buf->refc++;") + when debugRefc $ + emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p in+ -> %zu <" ++ marker ++ ">\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc);" + Decrement -> do + case incrementVar (marker++".elt") Decrement eltty of + Nothing -> + if debugRefc + then do + emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p de- -> %zu <" ++ marker ++ ">\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);" + emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) { fprintf(stderr, \"; free(\"); free_instr(" ++ path ++ ".buf); fprintf(stderr, \") ok\\n\"); } else fprintf(stderr, \"\\n\");" + else do + emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) free_instr(" ++ path ++ ".buf);" + Just f -> do + when debugRefc $ + emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p de- -> %zu <" ++ marker ++ "> recfree\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);" + shszvar <- genName' "frshsz" + ivar <- genName' "i" + ((), eltDecrStmts) <- scope $ f (path ++ ".buf->xs[" ++ ivar ++ "]") + emit $ SIf (CELit ("--" ++ path ++ ".buf->refc == 0")) + (BList [SVarDecl True "size_t" shszvar (compileArrShapeSize n path) + ,SLoop "size_t" ivar (CELit "0") (CELit shszvar) $ + eltDecrStmts + ,SVerbatim $ "free_instr(" ++ path ++ ".buf);"]) + mempty +incrementVar' _ _ _ ATNoop = pure () +incrementVar' marker inc path (ATProj field t) = incrementVar' (marker++"."++field) inc (path ++ "." ++ field) t +incrementVar' marker inc path (ATCondTag t1 t2) = do + ((), stmts1) <- scope $ incrementVar' (marker++".t1") inc path t1 + ((), stmts2) <- scope $ incrementVar' (marker++".t2") inc path t2 + emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "0")) stmts1 stmts2 +incrementVar' marker inc path (ATCond3Tag t1 t2 t3) = do + ((), stmts1) <- scope $ incrementVar' (marker++".t1") inc path t1 + ((), stmts2) <- scope $ incrementVar' (marker++".t2") inc path t2 + ((), stmts3) <- scope $ incrementVar' (marker++".t3") inc path t3 + emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "1")) + stmts2 + (pure (SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "2")) + stmts3 + stmts1)) +incrementVar' marker inc path (ATBoth t1 t2) = incrementVar' (marker++".1") inc path t1 >> incrementVar' (marker++".2") inc path t2 + +toLinearIdx :: SNat n -> String -> String -> CExpr +toLinearIdx SZ _ _ = CELit "0" +toLinearIdx (SS SZ) _ idxvar = CELit (idxvar ++ ".b") +toLinearIdx (SS n) arrvar idxvar = + CEBinop (CEBinop (toLinearIdx n arrvar (idxvar ++ ".a")) + "*" (CEIndex (CELit (arrvar ++ ".sh")) (CELit (show (fromSNat n))))) + "+" (CELit (idxvar ++ ".b")) + +-- fromLinearIdx :: SNat n -> String -> String -> CompM CExpr +-- fromLinearIdx SZ _ _ = return $ CEStruct (repSTy STNil) [] +-- fromLinearIdx (SS n) arrvar idxvar = do +-- name <- genName +-- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".sh[" ++ show (fromSNat n) ++ "]"))) +-- _ + +data AllocMethod = Malloc | Calloc + deriving (Show) + +-- | The shape must have the outer dimension at the head (and the inner dimension on the right). +allocArray :: HasCallStack => String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String +allocArray marker method nameBase rank eltty mshsz shape = do + when (length shape /= fromSNat rank) $ + error "allocArray: shape does not match rank" + let arrty = STArr rank eltty + strname <- emitStruct arrty + arrname <- genName' nameBase + shsz <- case mshsz of + Just e -> return e + Nothing -> return (foldl0' (\a b -> CEBinop a "*" b) (CELit "1") shape) + let nbytesExpr = CEBinop (CELit (show (fromSNat rank * 8 + 8))) + "+" + (CEBinop shsz "*" (CELit (show (sizeofSTy eltty)))) + emit $ SVarDecl True strname arrname $ CEStruct strname + [("buf", case method of Malloc -> CECall "malloc_instr" [nbytesExpr] + Calloc -> CECall "calloc_instr" [nbytesExpr]) + ,("sh", CELit ("{" ++ intercalate "," [printCExpr 0 dim "" | dim <- shape] ++ "}"))] + emit $ SAsg (arrname ++ ".buf->refc") (CELit "1") + when debugRefc $ + emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p allocated <" ++ marker ++ ">\\n\", " ++ arrname ++ ".buf);" + return arrname + +compileShapeQuery :: SNat n -> String -> CExpr +compileShapeQuery SZ _ = CEStruct (repSTy STNil) [] +compileShapeQuery (SS n) var = + CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) + [("a", compileShapeQuery n var) + ,("b", CEIndex (CELit (var ++ ".sh")) (CELit (show (fromSNat n))))] + +-- | Takes a variable name for the array, not the buffer. +compileArrShapeSize :: SNat n -> String -> CExpr +compileArrShapeSize n var = prodExpr (compileArrShapeComponents n var) + +-- | Takes a variable name for the array, not the buffer. +compileArrShapeComponents :: SNat n -> String -> [CExpr] +compileArrShapeComponents n var = + [CELit (var ++ ".sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] + +indexTupleComponents :: SNat n -> String -> [CExpr] +indexTupleComponents = \n var -> map CELit (toList (go n var)) + where + go :: SNat n -> String -> Bag String + go SZ _ = mempty + go (SS n) var = go n (var ++ ".a") <> pure (var ++ ".b") + +-- | Takes variable names with the innermost dimension on the right. +shapeTupFromLitVars :: SNat n -> [String] -> CExpr +shapeTupFromLitVars = \n -> go n . reverse + where + -- takes variables with the innermost dimension at the _head_ + go :: SNat n -> [String] -> CExpr + go SZ [] = CEStruct (repSTy STNil) [] + go (SS n) (var : vars) = CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) [("a", go n vars), ("b", CELit var)] + go _ _ = error "shapeTupFromLitVars: SNat and list do not correspond" + +prodExpr :: [CExpr] -> CExpr +prodExpr = foldl0' (\a b -> CEBinop a "*" b) (CELit "1") + +compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr +compileOpGeneral op e1 = do + let unary cop = return @CompM $ CECall cop [e1] + let binary cop = do + name <- genName + emit $ SVarDecl True (repSTy (opt1 op)) name e1 + return $ CEBinop (CEProj (CELit name) "a") cop (CEProj (CELit name) "b") + case op of + OAdd _ -> binary "+" + OMul _ -> binary "*" + ONeg _ -> unary "-" + OLt _ -> binary "<" + OLe _ -> binary "<=" + OEq _ -> binary "==" + ONot -> unary "!" + OAnd -> binary "&&" + OOr -> binary "||" + OIf -> do + name <- emitStruct (STEither STNil STNil) + _ <- emitStruct STNil + return $ CEIf e1 (CEStruct name [("tag", CELit "0")]) + (CEStruct name [("tag", CELit "1")]) + ORound64 -> unary "(int64_t)round" -- ew + OToFl64 -> unary "(double)" + ORecip _ -> return $ CEBinop (CELit "1.0") "/" e1 + OExp STF32 -> unary "expf" + OExp STF64 -> unary "exp" + OLog STF32 -> unary "logf" + OLog STF64 -> unary "log" + OIDiv _ -> binary "/" + OMod _ -> binary "%" + +compileOpPair :: SOp a b -> CExpr -> CExpr -> CompM CExpr +compileOpPair op e1 e2 = do + let binary cop = return @CompM $ CEBinop e1 cop e2 + case op of + OAdd _ -> binary "+" + OMul _ -> binary "*" + OLt _ -> binary "<" + OLe _ -> binary "<=" + OEq _ -> binary "==" + OAnd -> binary "&&" + OOr -> binary "||" + OIDiv _ -> binary "/" + OMod _ -> binary "%" + _ -> error "compileOpPair: got unary operator" + +-- | Bool: whether to ensure that the literal itself already has the appropriate type +compileScal :: Bool -> SScalTy t -> ScalRep t -> String +compileScal pedantic typ x = case typ of + STI32 -> (if pedantic then "(int32_t)" else "") ++ show x + STI64 -> (if pedantic then "(int64_t)" else "") ++ show x + STF32 -> show x ++ "f" + STF64 -> show x + STBool -> if x then "1" else "0" + +compileExtremum :: String -> String -> String -> SList (Const String) env -> Ex env (TArr (S n) t) -> CompM CExpr +compileExtremum nameBase opName operator env e = do + let STArr (SS n) t = typeOf e + argname <- compileAssign (nameBase ++ "arg") env e + + zeroRefcountCheck (typeOf e) opName argname + + shszname <- genName' "shsz" + -- This n is one less than the shape of the thing we're querying, which is + -- unexpected. But it's exactly what we want, so we do it anyway. + emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) + + resname <- allocArray nameBase Malloc (nameBase ++ "res") n t (Just (CELit shszname)) (compileArrShapeComponents n argname) + + lenname <- genName' "n" + emit $ SVarDecl True (repSTy tIx) lenname + (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]")) + + emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); return false; }" + + ivar <- genName' "i" + jvar <- genName' "j" + xvar <- genName + redvar <- genName' "red" -- use "red", not "acc", to avoid confusion with accumulators + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ BList + -- we have ScalIsNumeric, so it has 1 and (<) etc. in C + [SVarDecl False (repSTy t) redvar (CELit (argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ "]")) + ,SLoop (repSTy tIx) jvar (CELit "1") (CELit lenname) $ BList + [SVarDecl True (repSTy t) xvar (CELit (argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ jvar ++ "]")) + ,SAsg redvar $ CEIf (CEBinop (CELit xvar) operator (CELit redvar)) (CELit xvar) (CELit redvar) + ] + ,SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit redvar)] + + incrementVarAlways nameBase Decrement (typeOf e) argname + + return (CELit resname) + +-- | If this returns Nothing, there was nothing to copy because making a simple +-- value copy in C already makes it suitable to write to. +copyForWriting :: SMTy t -> String -> CompM (Maybe CExpr) +copyForWriting topty var = case topty of + SMTNil -> return Nothing + + SMTPair a b -> do + e1 <- copyForWriting a (var ++ ".a") + e2 <- copyForWriting b (var ++ ".b") + case (e1, e2) of + (Nothing, Nothing) -> return Nothing + _ -> return $ Just $ CEStruct toptyname + [("a", fromMaybe (CELit (var++".a")) e1) + ,("b", fromMaybe (CELit (var++".b")) e2)] + + SMTLEither a b -> do + (e1, stmts1) <- scope $ copyForWriting a (var ++ ".l") + (e2, stmts2) <- scope $ copyForWriting b (var ++ ".r") + case (e1, e2) of + (Nothing, Nothing) -> return Nothing + _ -> do + name <- genName + emit $ SVarDeclUninit toptyname name + emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0")) + (stmts1 + <> pure (SAsg name (CEStruct toptyname + [("tag", CELit "0"), ("l", fromMaybe (CELit (var++".l")) e1)]))) + (stmts2 + <> pure (SAsg name (CEStruct toptyname + [("tag", CELit "1"), ("r", fromMaybe (CELit (var++".r")) e2)]))) + return (Just (CELit name)) + + SMTMaybe t -> do + (e1, stmts1) <- scope $ copyForWriting t (var ++ ".j") + case e1 of + Nothing -> return Nothing + Just e1' -> do + name <- genName + emit $ SVarDeclUninit toptyname name + emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0")) + (pure (SAsg name (CEStruct toptyname [("tag", CELit "0")]))) + (stmts1 + <> pure (SAsg name (CEStruct toptyname [("tag", CELit "1"), ("j", e1')]))) + return (Just (CELit name)) + + -- If there are no nested arrays, we know that a refcount of 1 means that the + -- whole thing is owned. Nested arrays have their own refcount, so with + -- nesting we'd have to check the refcounts of all the nested arrays _too_; + -- let's not do that. Furthermore, no sub-arrays means that the whole thing + -- is flat, and we can just memcpy if necessary. + SMTArr n t | not (typeHasArrays (fromSMTy t)) -> do + name <- genName + shszname <- genName' "shsz" + emit $ SVarDeclUninit toptyname name + + when debugShapes $ do + let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" + emit $ SVerbatim $ + "fprintf(stderr, PRTAG \"with array " ++ shfmt ++ "\\n\"" ++ + concat [", " ++ var ++ ".sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++ + ");" + + emit $ SIf (CEBinop (CELit (var ++ ".buf->refc")) "==" (CELit "1")) + (pure (SAsg name (CELit var))) + (let shbytes = fromSNat n * 8 + databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t)))) + totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes + in BList + [SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var) + ,SAsg name (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])]) + ,SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");" + ,SAsg (name ++ ".buf->refc") (CELit "1") + ,SVerbatim $ "memcpy(" ++ name ++ ".buf->xs, " ++ var ++ ".buf->xs, " ++ + printCExpr 0 databytes ");"]) + return (Just (CELit name)) + + SMTArr n t -> do + shszname <- genName' "shsz" + emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var) + + let shbytes = fromSNat n * 8 + databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t)))) + totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes + + name <- genName + emit $ SVarDecl False toptyname name + (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])]) + emit $ SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");" + emit $ SAsg (name ++ ".buf->refc") (CELit "1") + + -- put the arrays in variables to cut short the not-quite-var chain + dstvar <- genName' "cpydst" + emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") dstvar (CELit (name ++ ".buf->xs")) + srcvar <- genName' "cpysrc" + emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") srcvar (CELit (var ++ ".buf->xs")) + + ivar <- genName' "i" + + (cpye, cpystmts) <- scope $ copyForWriting t (srcvar ++ "[" ++ ivar ++ "]") + let cpye' = case cpye of + Just e -> e + Nothing -> error "copyForWriting: arrays cannot be copied as-is, bug" + + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ + cpystmts + <> pure (SAsg (dstvar ++ "[" ++ ivar ++ "]") cpye') + + return (Just (CELit name)) + + SMTScal _ -> return Nothing + + where + toptyname = repSTy (fromSMTy topty) + +zeroRefcountCheck :: STy t -> String -> String -> CompM () +zeroRefcountCheck toptyp opname topvar = + when emitChecks $ do + mstmts <- onlyIdGen $ runMaybeT (go toptyp topvar) + case mstmts of + Nothing -> return () + Just stmts -> forM_ stmts emit + where + -- | If this returns 'Nothing', no statements need to be generated for this type. + go :: STy t -> String -> MaybeT IdGen.IdGen (Bag Stmt) + go STNil _ = empty + go (STPair a b) path = do + (s1, s2) <- combine (go a (path++".a")) (go b (path++".b")) + return (s1 <> s2) + go (STEither a b) path = do + (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) + return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "0")) s1 s2 + go (STLEither a b) path = do + (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) + return $ pure $ + SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) + s1 + (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2")) + s2 + mempty)) + go (STMaybe a) path = do + ss <- go a (path++".j") + return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) ss mempty + go (STArr n a) path = do + ivar <- genName' "i" + ss <- go a (path++".buf->xs["++ivar++"]") + shszname <- genName' "shsz" + let s1 = SVerbatim $ + "if (__builtin_expect(" ++ path ++ ".buf->refc == 0, 0)) { " ++ + "fprintf(stderr, PRTAG \"CHECK: '" ++ opname ++ "' got array " ++ + "%p with refc=0\\n\", " ++ path ++ ".buf); return false; }" + let s2 = SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n path) + let s3 = SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) ss + return (BList [s1, s2, s3]) + go STScal{} _ = empty + go STAccum{} _ = error "zeroRefcountCheck: passed an accumulator" + + combine :: (Monoid a, Monoid b, Monad m) => MaybeT m a -> MaybeT m b -> MaybeT m (a, b) + combine (MaybeT a) (MaybeT b) = MaybeT $ do + x <- a + y <- b + return $ case (x, y) of + (Nothing, Nothing) -> Nothing + (Just x', Nothing) -> Just (x', mempty) + (Nothing, Just y') -> Just (mempty, y') + (Just x', Just y') -> Just (x', y') + +{-# NOINLINE uniqueIdGenRef #-} +uniqueIdGenRef :: IORef Int +uniqueIdGenRef = unsafePerformIO $ newIORef 1 + +compose :: Foldable t => t (a -> a) -> a -> a +compose = foldr (.) id + +showPtr :: Ptr a -> String +showPtr (Ptr a) = "0x" ++ showHex (integerFromWord# (int2Word# (addr2Int# a))) "" + +-- | Type-restricted. +(^) :: Num a => a -> Int -> a +(^) = (Prelude.^) + +foldl0' :: (a -> a -> a) -> a -> [a] -> a +foldl0' _ x [] = x +foldl0' f _ l = foldl1' f l |
