aboutsummaryrefslogtreecommitdiff
path: root/src/Compile.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Compile.hs')
-rw-r--r--src/Compile.hs1711
1 files changed, 0 insertions, 1711 deletions
diff --git a/src/Compile.hs b/src/Compile.hs
deleted file mode 100644
index 6ba3a39..0000000
--- a/src/Compile.hs
+++ /dev/null
@@ -1,1711 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DerivingStrategies #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE GeneralizedNewtypeDeriving #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE MagicHash #-}
-{-# LANGUAGE MultiWayIf #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE TupleSections #-}
-{-# LANGUAGE TypeApplications #-}
-module Compile (compile) where
-
-import Control.Applicative (empty)
-import Control.Monad (forM_, when, replicateM)
-import Control.Monad.Trans.Class (lift)
-import Control.Monad.Trans.Maybe
-import Control.Monad.Trans.State.Strict
-import Control.Monad.Trans.Writer.CPS
-import Data.Bifunctor (first)
-import Data.Char (ord)
-import Data.Foldable (toList)
-import Data.Functor.Const
-import qualified Data.Functor.Product as Product
-import Data.Functor.Product (Product)
-import Data.IORef
-import Data.List (foldl1', intersperse, intercalate)
-import qualified Data.Map.Strict as Map
-import Data.Maybe (fromMaybe)
-import qualified Data.Set as Set
-import Data.Set (Set)
-import Data.Some
-import qualified Data.Vector as V
-import Foreign
-import GHC.Exts (int2Word#, addr2Int#)
-import GHC.Num (integerFromWord#)
-import GHC.Ptr (Ptr(..))
-import Numeric (showHex)
-import System.IO (hPutStrLn, stderr)
-import System.IO.Error (mkIOError, userErrorType)
-import System.IO.Unsafe (unsafePerformIO)
-
-import Prelude hiding ((^))
-import qualified Prelude
-
-import Array
-import AST
-import AST.Pretty (ppSTy, ppExpr)
-import Compile.Exec
-import Data
-import Interpreter.Rep
-import qualified Util.IdGen as IdGen
-
-
--- In shape and index arrays, the innermost dimension is on the right (last index).
-
--- TODO: test that I'm properly incrementing and decrementing refcounts in all required places
-
-
--- | Print the compiled AST
-debugPrintAST :: Bool; debugPrintAST = toEnum 0
--- | Print the generated C source
-debugCSource :: Bool; debugCSource = toEnum 0
--- | Print extra stuff about reference counts of arrays
-debugRefc :: Bool; debugRefc = toEnum 0
--- | Print some shape-related information
-debugShapes :: Bool; debugShapes = toEnum 0
--- | Print information on allocation
-debugAllocs :: Bool; debugAllocs = toEnum 0
--- | Emit extra C code that checks stuff
-emitChecks :: Bool; emitChecks = toEnum 0
-
-compile :: SList STy env -> Ex env t
- -> IO (SList Value env -> IO (Rep t))
-compile = \env expr -> do
- codeID <- atomicModifyIORef' uniqueIdGenRef (\i -> (i + 1, i))
-
- let (source, offsets) = compileToString codeID env expr
- when debugPrintAST $ hPutStrLn stderr $ "Compiled AST: <<<\n" ++ ppExpr env expr ++ "\n>>>"
- when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ lineNumbers source ++ "\x1B[0m>>>"
- lib <- buildKernel source ["kernel"]
-
- let result_type = typeOf expr
- result_size = sizeofSTy result_type
-
- return $ \val -> do
- allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do
- let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets)
- serialiseArguments args ptr $ do
- callKernelFun "kernel" lib ptr
- ok <- peekByteOff @Word8 ptr (koOkResOffset offsets)
- when (ok /= 1) $
- ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing)
- deserialise result_type ptr (koResultOffset offsets)
- where
- serialiseArguments :: [(Some (Product STy Value), Int)] -> Ptr () -> IO r -> IO r
- serialiseArguments ((Some (Product.Pair t (Value arg)), off) : args) ptr k =
- serialise t arg ptr off $
- serialiseArguments args ptr k
- serialiseArguments _ _ k = k
-
-
-data StructDecl = StructDecl
- String -- ^ name
- String -- ^ contents
- String -- ^ comment
- deriving (Show)
-
-data Stmt
- = SVarDecl Bool String String CExpr -- ^ const, type, variable name, right-hand side
- | SVarDeclUninit String String -- ^ type, variable name (no initialiser)
- | SAsg String CExpr -- ^ variable name, right-hand side
- | SBlock (Bag Stmt)
- | SIf CExpr (Bag Stmt) (Bag Stmt)
- | SLoop String String CExpr CExpr (Bag Stmt) -- ^ for (<type> <name> = <expr>; name < <expr>; name++) {<stmts>}
- | SVerbatim String -- ^ no implicit ';', just printed as-is
- deriving (Show)
-
-data CExpr
- = CELit String -- ^ inserted as-is, assumed no parentheses needed
- | CEStruct String [(String, CExpr)] -- ^ struct construction literal: `(name){.field=expr}`
- | CEProj CExpr String -- ^ field projection: expr.field
- | CEPtrProj CExpr String -- ^ field projection through pointer: expr->field
- | CEAddrOf CExpr -- ^ &expr
- | CEIndex CExpr CExpr -- ^ expr[expr]
- | CECall String [CExpr] -- ^ function(arg1, ..., argn)
- | CEBinop CExpr String CExpr -- ^ expr + expr
- | CEIf CExpr CExpr CExpr -- ^ expr ? expr : expr
- | CECast String CExpr -- ^ (<type)<expr>
- deriving (Show)
-
-printStructDecl :: StructDecl -> ShowS
-printStructDecl (StructDecl name contents comment) =
- showString "typedef struct { " . showString contents . showString " } " . showString name
- . showString ";" . (if null comment then id else showString (" // " ++ comment))
-
-printStmt :: Int -> Stmt -> ShowS
-printStmt indent = \case
- SVarDecl cnst typ name rhs -> showString (typ ++ " " ++ (if cnst then "const " else "") ++ name ++ " = ") . printCExpr 0 rhs . showString ";"
- SVarDeclUninit typ name -> showString (typ ++ " " ++ name ++ ";")
- SAsg name rhs -> showString (name ++ " = ") . printCExpr 0 rhs . showString ";"
- SBlock stmts
- | null stmts -> showString "{}"
- | otherwise ->
- showString "{"
- . compose [showString ("\n" ++ replicate (2*indent+2) ' ') . printStmt (indent+1) stmt | stmt <- toList stmts]
- . showString ("\n" ++ replicate (2*indent) ' ' ++ "}")
- SIf cond b1 b2 ->
- showString "if (" . printCExpr 0 cond . showString ") "
- . printStmt indent (SBlock b1)
- . (if null b2 then id else showString " else " . printStmt indent (SBlock b2))
- SLoop typ name e1 e2 stmts ->
- showString ("for (" ++ typ ++ " " ++ name ++ " = ")
- . printCExpr 0 e1 . showString ("; " ++ name ++ " < ") . printCExpr 6 e2
- . showString ("; " ++ name ++ "++) ")
- . printStmt indent (SBlock stmts)
- SVerbatim s -> showString s
-
--- d values:
--- * 0: top level
--- * 1: in 1st or 2nd component of a ternary operator (technically same as top level, but readability)
--- * 2-...: various operators (see precTable)
--- * 80: address-of operator (&)
--- * 98: inside unknown operator
--- * 99: left of a field projection
--- Unlisted operators are conservatively written with full parentheses.
-printCExpr :: Int -> CExpr -> ShowS
-printCExpr d = \case
- CELit s -> showString s
- CEStruct name pairs ->
- showParen (d >= 99) $
- showString ("(" ++ name ++ "){")
- . compose (intersperse (showString ", ") [showString ("." ++ n ++ " = ") . printCExpr 0 e
- | (n, e) <- pairs])
- . showString "}"
- CEProj e name -> printCExpr 99 e . showString ("." ++ name)
- CEPtrProj e name -> printCExpr 99 e . showString ("->" ++ name)
- CEAddrOf e -> showParen (d > 80) $ showString "&" . printCExpr 80 e
- CEIndex e1 e2 -> printCExpr 99 e1 . showString "[" . printCExpr 0 e2 . showString "]"
- CECall n es ->
- showString (n ++ "(") . compose (intersperse (showString ", ") (map (printCExpr 0) es)) . showString ")"
- CEBinop e1 n e2 ->
- let mprec = Map.lookup n precTable
- p = maybe (-1) fst mprec -- precedence of this operator
- (d1, d2) = maybe (98, 98) snd mprec -- precedences for the arguments
- in showParen (d > p) $
- printCExpr d1 e1 . showString (" " ++ n ++ " ") . printCExpr d2 e2
- CEIf e1 e2 e3 ->
- showParen (d > 0) $
- printCExpr 1 e1 . showString " ? " . printCExpr 1 e2 . showString " : " . printCExpr 0 e3
- CECast typ e ->
- showParen (d > 98) $ showString ("(" ++ typ ++ ")") . printCExpr 98 e
- where
- precTable = Map.fromList
- [("||", (2, (2, 2)))
- ,("&&", (3, (3, 3)))
- ,("==", (4, (5, 5)))
- ,("!=", (4, (5, 5)))
- ,("<", (5, (6, 6))) -- Note: this precedence is used in the printing of SLoop
- ,(">", (5, (6, 6)))
- ,("<=", (5, (6, 6)))
- ,(">=", (5, (6, 6)))
- ,("+", (6, (6, 7)))
- ,("-", (6, (6, 7)))
- ,("*", (7, (7, 8)))
- ,("/", (7, (7, 8)))
- ,("%", (7, (7, 8)))]
-
-repSTy :: STy t -> String
-repSTy (STScal st) = case st of
- STI32 -> "int32_t"
- STI64 -> "int64_t"
- STF32 -> "float"
- STF64 -> "double"
- STBool -> "uint8_t"
-repSTy t = genStructName t
-
-genStructName :: STy t -> String
-genStructName = \t -> "ty_" ++ gen t where
- -- all tags start with a letter, so the array mangling is unambiguous.
- gen :: STy t -> String
- gen STNil = "n"
- gen (STPair a b) = 'P' : gen a ++ gen b
- gen (STEither a b) = 'E' : gen a ++ gen b
- gen (STMaybe t) = 'M' : gen t
- gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t
- gen (STScal st) = case st of
- STI32 -> "i"
- STI64 -> "j"
- STF32 -> "f"
- STF64 -> "d"
- STBool -> "b"
- gen (STAccum t) = 'C' : gen (fromSMTy t)
- gen (STLEither a b) = 'L' : gen a ++ gen b
-
--- | This function generates the actual struct declarations for each of the
--- types in our language. It thus implicitly "documents" the layout of the
--- types in the C translation.
---
--- For accumulation it is important that for struct representations of monoid
--- types, the all-zero-bytes value corresponds to the zero value of that type.
-genStruct :: String -> STy t -> [StructDecl]
-genStruct name topty = case topty of
- STNil ->
- [StructDecl name "" com]
- STPair a b ->
- [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com]
- STEither a b -> -- 0 -> l, 1 -> r
- [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]
- STMaybe t -> -- 0 -> nothing, 1 -> just
- [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com]
- STArr n t ->
- -- The buffer is trailed by a VLA for the actual array data.
- -- TODO: put shape in the main struct, not the buffer; it's constant, after all
- -- TODO: no buffer if n = 0
- [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromSNat n) ++ "]; size_t refc; " ++ repSTy t ++ " xs[];") ""
- ,StructDecl name (name ++ "_buf *buf;") com]
- STScal _ ->
- []
- STAccum t ->
- [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") ""
- ,StructDecl name (name ++ "_buf *buf;") com]
- STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r
- [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]
- where
- com = ppSTy 0 topty
-
--- State: already-generated (skippable) struct names
--- Writer: the structs in declaration order
-genStructs :: STy t -> WriterT (Bag StructDecl) (State (Set String)) ()
-genStructs ty = do
- let name = genStructName ty
- seen <- lift $ gets (name `Set.member`)
-
- if seen
- then pure ()
- else do
- -- already mark this struct as generated now, so we don't generate it
- -- twice (unnecessary because no recursive types, but y'know)
- lift $ modify (Set.insert name)
-
- () <- case ty of
- STNil -> pure ()
- STPair a b -> genStructs a >> genStructs b
- STEither a b -> genStructs a >> genStructs b
- STMaybe t -> genStructs t
- STArr _ t -> genStructs t
- STScal _ -> pure ()
- STAccum t -> genStructs (fromSMTy t)
- STLEither a b -> genStructs a >> genStructs b
-
- tell (BList (genStruct name ty))
-
-genAllStructs :: Foldable t => t (Some STy) -> [StructDecl]
-genAllStructs tys = toList $ evalState (execWriterT (mapM_ (\(Some t) -> genStructs t) tys)) mempty
-
-data CompState = CompState
- { csStructs :: Set (Some STy)
- , csTopLevelDecls :: Bag String
- , csStmts :: Bag Stmt
- , csNextId :: Int }
- deriving (Show)
-
-newtype CompM a = CompM (State CompState a)
- deriving newtype (Functor, Applicative, Monad)
-
-runCompM :: CompM a -> (a, CompState)
-runCompM (CompM m) = runState m (CompState mempty mempty mempty 1)
-
-class Monad m => MonadNameGen m where genId :: m Int
-instance MonadNameGen CompM where genId = CompM $ state $ \s -> (csNextId s, s { csNextId = csNextId s + 1 })
-instance MonadNameGen IdGen.IdGen where genId = IdGen.genId
-instance MonadNameGen m => MonadNameGen (MaybeT m) where genId = MaybeT (Just <$> genId)
-
-genName' :: MonadNameGen m => String -> m String
-genName' "" = genName
-genName' prefix = (prefix ++) . show <$> genId
-
-genName :: MonadNameGen m => m String
-genName = genName' "x"
-
-onlyIdGen :: IdGen.IdGen a -> CompM a
-onlyIdGen m = CompM $ do
- i1 <- gets csNextId
- let (res, i2) = IdGen.runIdGen' i1 m
- modify (\s -> s { csNextId = i2 })
- return res
-
-emit :: Stmt -> CompM ()
-emit stmt = CompM $ modify $ \s -> s { csStmts = csStmts s <> pure stmt }
-
-scope :: CompM a -> CompM (a, Bag Stmt)
-scope m = do
- stmts <- CompM $ state $ \s -> (csStmts s, s { csStmts = mempty })
- res <- m
- innerStmts <- CompM $ state $ \s -> (csStmts s, s { csStmts = stmts })
- return (res, innerStmts)
-
-emitStruct :: STy t -> CompM String
-emitStruct ty = CompM $ do
- modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) }
- return (genStructName ty)
-
-emitTLD :: String -> CompM ()
-emitTLD decl = CompM $ modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl }
-
-nameEnv :: SList f env -> SList (Const String) env
-nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1))
-
-data KernelOffsets = KernelOffsets
- { koArgOffsets :: [Int] -- ^ the function arguments
- , koOkResOffset :: Int -- ^ a byte: 1 if successful execution, 0 if (fatal) error occurred
- , koResultOffset :: Int -- ^ the function result
- }
-
-compileToString :: Int -> SList STy env -> Ex env t -> (String, KernelOffsets)
-compileToString codeID env expr =
- let args = nameEnv env
- (res, s) = runCompM (compile' args expr)
- structs = genAllStructs (csStructs s <> Set.fromList (unSList Some env))
-
- (arg_pairs, arg_metrics) =
- unzip $ reverse (unSList (\(Product.Pair t (Const n)) -> ((n, repSTy t), metricsSTy t))
- (slistZip env args))
- (arg_offsets, okres_offset) = computeStructOffsets arg_metrics
- result_offset = align (alignmentSTy (typeOf expr)) (okres_offset + 1)
-
- offsets = KernelOffsets
- { koArgOffsets = arg_offsets
- , koOkResOffset = okres_offset
- , koResultOffset = result_offset }
- in (,offsets) . ($ "") $ compose
- [showString "#include <stdio.h>\n"
- ,showString "#include <stdint.h>\n"
- ,showString "#include <stdbool.h>\n"
- ,showString "#include <inttypes.h>\n"
- ,showString "#include <stdlib.h>\n"
- ,showString "#include <string.h>\n"
- ,showString "#include <math.h>\n\n"
- -- PRint-tag
- ,showString $ "#define PRTAG \"[chad-kernel" ++ show codeID ++ "] \"\n\n"
-
- ,compose [printStructDecl sd . showString "\n" | sd <- structs]
- ,showString "\n"
-
- -- Using %zd and not %zu here because values > SIZET_MAX/2 should be recognisable as "negative"
- ,showString "static void* malloc_instr_fun(size_t n, int line) {\n"
- ,showString " void *ptr = malloc(n);\n"
- ,if debugAllocs then showString " printf(PRTAG \":%d malloc(%zd) -> %p\\n\", line, n, ptr);\n"
- else id
- ,if emitChecks then showString " if (ptr == NULL) { printf(PRTAG \"malloc(%zd) returned NULL on line %d\\n\", n, line); return false; }\n"
- else id
- ,showString " return ptr;\n"
- ,showString "}\n"
- ,showString "#define malloc_instr(n) ({void *ptr_ = malloc_instr_fun(n, __LINE__); if (ptr_ == NULL) return false; ptr_;})\n"
- ,showString "static void* calloc_instr_fun(size_t n, int line) {\n"
- ,showString " void *ptr = calloc(n, 1);\n"
- ,if debugAllocs then showString " printf(PRTAG \":%d calloc(%zd) -> %p\\n\", line, n, ptr);\n"
- else id
- ,if emitChecks then showString " if (ptr == NULL) { printf(PRTAG \"calloc(%zd, 1) returned NULL on line %d\\n\", n, line); return false; }\n"
- else id
- ,showString " return ptr;\n"
- ,showString "}\n"
- ,showString "#define calloc_instr(n) ({void *ptr_ = calloc_instr_fun(n, __LINE__); if (ptr_ == NULL) return false; ptr_;})\n"
- ,showString "static void free_instr(void *ptr) {\n"
- ,if debugAllocs then showString "printf(PRTAG \"free(%p)\\n\", ptr);\n"
- else id
- ,showString " free(ptr);\n"
- ,showString "}\n\n"
-
- ,compose [showString str . showString "\n\n" | str <- toList (csTopLevelDecls s)]
-
- ,showString $
- "static bool typed_kernel(" ++
- repSTy (typeOf expr) ++ " *output" ++
- concatMap (", " ++)
- (reverse (unSList (\(Product.Pair t (Const n)) -> repSTy t ++ " " ++ n) (slistZip env args))) ++
- ") {\n"
- ,compose [showString " " . printStmt 1 st . showString "\n" | st <- toList (csStmts s)]
- ,showString " *output = " . printCExpr 0 res . showString ";\n"
- ,showString " return true;\n"
- ,showString "}\n\n"
-
- ,showString "void kernel(void *data) {\n"
- -- Some code here assumes that we're on a 64-bit system, so let's check that
- ,showString $ " if (sizeof(void*) != 8 || sizeof(size_t) != 8) { fprintf(stderr, \"Only 64-bit systems supported\\n\"); *(uint8_t*)(data + " ++ show okres_offset ++ ") = 0; return; }\n"
- ,if debugRefc then showString " fprintf(stderr, PRTAG \"Start\\n\");\n"
- else id
- ,showString $ " const bool success = typed_kernel(" ++
- "\n (" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ")" ++
- concat (map (\((arg, typ), off) ->
- ",\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")"
- ++ " /* " ++ arg ++ " */")
- (zip arg_pairs arg_offsets)) ++
- "\n );\n"
- ,showString $ " *(uint8_t*)(data + " ++ show okres_offset ++ ") = success;\n"
- ,if debugRefc then showString " fprintf(stderr, PRTAG \"Return\\n\");\n"
- else id
- ,showString "}\n"]
-
--- | Takes list of metrics (alignment, sizeof).
--- Returns (offsets, size of struct).
-computeStructOffsets :: [(Int, Int)] -> ([Int], Int)
-computeStructOffsets = go 0 0
- where
- go off maxal [(al, sz)] =
- ([off], align (max maxal al) (off + sz))
- go off maxal ((al, sz) : pairs@((al2,_):_)) =
- first (off :) $ go (align al2 (off + sz)) (max maxal al) pairs
- go _ _ [] = ([], 0)
-
--- | Assumes that this is called at the correct alignment.
-serialise :: STy t -> Rep t -> Ptr () -> Int -> IO r -> IO r
-serialise topty topval ptr off k =
- -- TODO: this code is quadratic in the depth of the type because of the alignment/sizeOf calls
- case (topty, topval) of
- (STNil, ()) -> k
- (STPair a b, (x, y)) ->
- serialise a x ptr off $
- serialise b y ptr (align (alignmentSTy b) (off + sizeofSTy a)) k
- (STEither a _, Left x) -> do
- pokeByteOff ptr off (0 :: Word8) -- alignment of (union {a b}) is the same as alignment of (a + b)
- serialise a x ptr (off + alignmentSTy topty) k
- (STEither _ b, Right y) -> do
- pokeByteOff ptr off (1 :: Word8)
- serialise b y ptr (off + alignmentSTy topty) k
- (STMaybe _, Nothing) -> do
- pokeByteOff ptr off (0 :: Word8)
- k
- (STMaybe t, Just x) -> do
- pokeByteOff ptr off (1 :: Word8)
- serialise t x ptr (off + alignmentSTy t) k
- (STArr n t, Array sh vec) -> do
- let eltsz = sizeofSTy t
- allocaBytes (fromSNat n * 8 + 8 + shapeSize sh * eltsz) $ \bufptr -> do
- when debugRefc $
- hPutStrLn stderr $ "[chad-serialise] Allocating input buffer " ++ showPtr bufptr
- pokeByteOff ptr off bufptr
-
- pokeShape bufptr 0 n sh
- pokeByteOff @Word64 bufptr (8 * fromSNat n) (2 ^ 63)
-
- let off1 = fromSNat n * 8 + 8
- loop i
- | i == shapeSize sh = k
- | otherwise =
- serialise t (vec V.! i) bufptr (off1 + i * eltsz) $
- loop (i+1)
- loop 0
- (STScal sty, x) -> case sty of
- STI32 -> pokeByteOff ptr off (x :: Int32) >> k
- STI64 -> pokeByteOff ptr off (x :: Int64) >> k
- STF32 -> pokeByteOff ptr off (x :: Float) >> k
- STF64 -> pokeByteOff ptr off (x :: Double) >> k
- STBool -> pokeByteOff ptr off (fromIntegral (fromEnum x) :: Word8) >> k
- (STAccum{}, _) -> error "Cannot serialise accumulators"
- (STLEither _ _, Nothing) -> do
- pokeByteOff ptr off (0 :: Word8)
- k
- (STLEither a _, Just (Left x)) -> do
- pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b)
- serialise a x ptr (off + alignmentSTy topty) k
- (STLEither _ b, Just (Right y)) -> do
- pokeByteOff ptr off (2 :: Word8)
- serialise b y ptr (off + alignmentSTy topty) k
-
--- | Assumes that this is called at the correct alignment.
-deserialise :: STy t -> Ptr () -> Int -> IO (Rep t)
-deserialise topty ptr off =
- -- TODO: this code is quadratic in the depth of the type because of the alignment/sizeOf calls
- case topty of
- STNil -> return ()
- STPair a b -> do
- x <- deserialise a ptr off
- y <- deserialise b ptr (align (alignmentSTy b) (off + sizeofSTy a))
- return (x, y)
- STEither a b -> do
- tag <- peekByteOff @Word8 ptr off
- if tag == 0 -- alignment of (union {a b}) is the same as alignment of (a + b)
- then Left <$> deserialise a ptr (off + alignmentSTy topty)
- else Right <$> deserialise b ptr (off + alignmentSTy topty)
- STMaybe t -> do
- tag <- peekByteOff @Word8 ptr off
- if tag == 0
- then return Nothing
- else Just <$> deserialise t ptr (off + alignmentSTy t)
- STArr n t -> do
- bufptr <- peekByteOff @(Ptr ()) ptr off
- sh <- peekShape bufptr 0 n
- refc <- peekByteOff @Word64 bufptr (8 * fromSNat n)
- when debugRefc $
- hPutStrLn stderr $ "[chad-deserialise] Got buffer " ++ showPtr bufptr ++ " at refc=" ++ show refc
- let off1 = 8 * fromSNat n + 8
- eltsz = sizeofSTy t
- arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (off1 + i * eltsz))
- when (refc < 2 ^ 62) $ free bufptr
- return arr
- STScal sty -> case sty of
- STI32 -> peekByteOff @Int32 ptr off
- STI64 -> peekByteOff @Int64 ptr off
- STF32 -> peekByteOff @Float ptr off
- STF64 -> peekByteOff @Double ptr off
- STBool -> toEnum . fromIntegral <$> peekByteOff @Word8 ptr off
- STAccum{} -> error "Cannot serialise accumulators"
- STLEither a b -> do
- tag <- peekByteOff @Word8 ptr off
- case tag of -- alignment of (union {a b}) is the same as alignment of (a + b)
- 0 -> return Nothing
- 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty)
- 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty)
- _ -> error "Invalid tag value"
-
-align :: Int -> Int -> Int
-align a off = (off + a - 1) `div` a * a
-
-alignmentSTy :: STy t -> Int
-alignmentSTy = fst . metricsSTy
-
-sizeofSTy :: STy t -> Int
-sizeofSTy = snd . metricsSTy
-
--- | Returns (alignment, sizeof)
-metricsSTy :: STy t -> (Int, Int)
-metricsSTy STNil = (1, 0)
-metricsSTy (STPair a b) =
- let (a1, s1) = metricsSTy a
- (a2, s2) = metricsSTy b
- in (max a1 a2, align (max a1 a2) (s1 + s2))
-metricsSTy (STEither a b) =
- let (a1, s1) = metricsSTy a
- (a2, s2) = metricsSTy b
- in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned
-metricsSTy (STMaybe t) =
- let (a, s) = metricsSTy t
- in (a, a + s) -- the union after the tag byte is aligned
-metricsSTy (STArr _ _) = (8, 8)
-metricsSTy (STScal sty) = case sty of
- STI32 -> (4, 4)
- STI64 -> (8, 8)
- STF32 -> (4, 4)
- STF64 -> (8, 8)
- STBool -> (1, 1) -- compiled to uint8_t
-metricsSTy (STAccum t) = metricsSTy (fromSMTy t)
-metricsSTy (STLEither a b) =
- let (a1, s1) = metricsSTy a
- (a2, s2) = metricsSTy b
- in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned
-
-pokeShape :: Ptr () -> Int -> SNat n -> Shape n -> IO ()
-pokeShape ptr off = go . fromSNat
- where
- go :: Int -> Shape n -> IO ()
- go rank = \case
- ShNil -> return ()
- sh `ShCons` n -> do
- pokeByteOff ptr (off + (rank - 1) * 8) (fromIntegral n :: Int64)
- go (rank - 1) sh
-
-peekShape :: Ptr () -> Int -> SNat n -> IO (Shape n)
-peekShape ptr off = \case
- SZ -> return ShNil
- SS n -> ShCons <$> peekShape ptr off n
- <*> (fromIntegral <$> peekByteOff @Int64 ptr (off + (fromSNat n) * 8))
-
-compile' :: SList (Const String) env -> Ex env t -> CompM CExpr
-compile' env = \case
- EVar _ t i -> do
- let Const var = slistIdx env i
- incrementVarAlways "var" Increment t var
- return $ CELit var
-
- ELet _ rhs body -> do
- var <- compileAssign "" env rhs
- rete <- compile' (Const var `SCons` env) body
- incrementVarAlways "let" Decrement (typeOf rhs) var
- return rete
-
- EPair _ a b -> do
- name <- emitStruct (STPair (typeOf a) (typeOf b))
- e1 <- compile' env a
- e2 <- compile' env b
- return $ CEStruct name [("a", e1), ("b", e2)]
-
- EFst _ e -> do
- let STPair _ t2 = typeOf e
- e' <- compile' env e
- case incrementVar "fst" Decrement t2 of
- Nothing -> return $ CEProj e' "a"
- Just f -> do var <- genName
- emit $ SVarDecl True (repSTy (typeOf e)) var e'
- f (var ++ ".b")
- return $ CEProj (CELit var) "a"
-
- ESnd _ e -> do
- let STPair t1 _ = typeOf e
- e' <- compile' env e
- case incrementVar "snd" Decrement t1 of
- Nothing -> return $ CEProj e' "b"
- Just f -> do var <- genName
- emit $ SVarDecl True (repSTy (typeOf e)) var e'
- f (var ++ ".a")
- return $ CEProj (CELit var) "b"
-
- ENil _ -> do
- name <- emitStruct STNil
- return $ CEStruct name []
-
- EInl _ t e -> do
- name <- emitStruct (STEither (typeOf e) t)
- e1 <- compile' env e
- return $ CEStruct name [("tag", CELit "0"), ("l", e1)]
-
- EInr _ t e -> do
- name <- emitStruct (STEither t (typeOf e))
- e2 <- compile' env e
- return $ CEStruct name [("tag", CELit "1"), ("r", e2)]
-
- ECase _ (EOp _ OIf e) a b -> do
- e1 <- compile' env e
- (e2, stmts2) <- scope $ compile' (Const undefined `SCons` env) a -- don't access that nil, stupid you
- (e3, stmts3) <- scope $ compile' (Const undefined `SCons` env) b
- retvar <- genName
- emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
- emit $ SIf e1
- (stmts2 <> pure (SAsg retvar e2))
- (stmts3 <> pure (SAsg retvar e3))
- return (CELit retvar)
-
- ECase _ e a b -> do
- let STEither t1 t2 = typeOf e
- e1 <- compile' env e
- var <- genName
- -- I know those are not variable names, but it's fine, probably
- (e2, stmts2) <- scope $ compile' (Const (var ++ ".l") `SCons` env) a
- (e3, stmts3) <- scope $ compile' (Const (var ++ ".r") `SCons` env) b
- ((), stmtsRel1) <- scope $ incrementVarAlways "case1" Decrement t1 (var ++ ".l")
- ((), stmtsRel2) <- scope $ incrementVarAlways "case2" Decrement t2 (var ++ ".r")
- retvar <- genName
- emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
- emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1)
- <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0"))
- (stmts2
- <> stmtsRel1
- <> pure (SAsg retvar e2))
- (stmts3
- <> stmtsRel2
- <> pure (SAsg retvar e3))))
- return (CELit retvar)
-
- ENothing _ t -> do
- name <- emitStruct (STMaybe t)
- return $ CEStruct name [("tag", CELit "0")]
-
- EJust _ e -> do
- name <- emitStruct (STMaybe (typeOf e))
- e1 <- compile' env e
- return $ CEStruct name [("tag", CELit "1"), ("j", e1)]
-
- EMaybe _ a b e -> do
- let STMaybe t = typeOf e
- e1 <- compile' env e
- var <- genName
- (e2, stmts2) <- scope $ compile' env a
- (e3, stmts3) <- scope $ compile' (Const (var ++ ".j") `SCons` env) b
- ((), stmtsRel) <- scope $ incrementVarAlways "maybe" Decrement t (var ++ ".j")
- retvar <- genName
- emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
- emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1)
- <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0"))
- (stmts2
- <> pure (SAsg retvar e2))
- (stmts3
- <> stmtsRel
- <> pure (SAsg retvar e3))))
- return (CELit retvar)
-
- ELNil _ t1 t2 -> do
- name <- emitStruct (STLEither t1 t2)
- return $ CEStruct name [("tag", CELit "0")]
-
- ELInl _ t e -> do
- name <- emitStruct (STLEither (typeOf e) t)
- e1 <- compile' env e
- return $ CEStruct name [("tag", CELit "1"), ("l", e1)]
-
- ELInr _ t e -> do
- name <- emitStruct (STLEither t (typeOf e))
- e1 <- compile' env e
- return $ CEStruct name [("tag", CELit "2"), ("r", e1)]
-
- ELCase _ e a b c -> do
- let STLEither t1 t2 = typeOf e
- e1 <- compile' env e
- var <- genName
- (e2, stmts2) <- scope $ compile' env a
- (e3, stmts3) <- scope $ compile' (Const (var ++ ".l") `SCons` env) b
- (e4, stmts4) <- scope $ compile' (Const (var ++ ".r") `SCons` env) c
- ((), stmtsRel1) <- scope $ incrementVarAlways "lcase1" Decrement t1 (var ++ ".l")
- ((), stmtsRel2) <- scope $ incrementVarAlways "lcase2" Decrement t2 (var ++ ".r")
- retvar <- genName
- emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
- emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1)
- <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0"))
- (stmts2 <> pure (SAsg retvar e2))
- (pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "1"))
- (stmts3 <> stmtsRel1 <> pure (SAsg retvar e3))
- (stmts4 <> stmtsRel2 <> pure (SAsg retvar e4))))))
- return (CELit retvar)
-
- EConstArr _ n t (Array sh vec) -> do
- strname <- emitStruct (STArr n (STScal t))
- tldname <- genName' "carraybuf"
- -- Give it a refcount of _half_ the size_t max, so that it can be
- -- incremented and decremented at will and will "never" reach anything
- -- where something happens
- emitTLD $ "static " ++ strname ++ "_buf " ++ tldname ++ " = " ++
- "(" ++ strname ++ "_buf){.sh = {" ++ intercalate "," (map show (shapeToList sh)) ++ "}, " ++
- ".refc = (size_t)1<<63, .xs = {" ++ intercalate "," (map (compileScal False t) (toList vec)) ++ "}};"
- return (CEStruct strname [("buf", CEAddrOf (CELit tldname))])
-
- EBuild _ n esh efun -> do
- shname <- compileAssign "sh" env esh
-
- arrname <- allocArray "build" Malloc "arr" n (typeOf efun) Nothing (indexTupleComponents n shname)
-
- idxargname <- genName' "ix"
- (funretval, funstmts) <- scope $ compile' (Const idxargname `SCons` env) efun
-
- linivar <- genName' "li"
- ivars <- replicateM (fromSNat n) (genName' "i")
- emit $ SBlock $
- pure (SVarDecl False "size_t" linivar (CELit "0"))
- <> compose [pure . SLoop (repSTy tIx) ivar (CELit "0")
- (CECast (repSTy tIx) (CEIndex (CELit (arrname ++ ".buf->sh")) (CELit (show dimidx))))
- | (ivar, dimidx) <- zip ivars [0::Int ..]]
- (pure (SVarDecl True (repSTy (typeOf esh)) idxargname
- (shapeTupFromLitVars n ivars))
- <> funstmts
- <> pure (SAsg (arrname ++ ".buf->xs[" ++ linivar ++ "++]") funretval))
-
- return (CELit arrname)
-
- EFold1Inner _ commut efun ex0 earr -> do
- let STArr (SS n) t = typeOf earr
-
- -- let vecwid = case commut of Commut -> 8 :: Int
- -- Noncommut -> 1
-
- x0name <- compileAssign "foldx0" env ex0
- arrname <- compileAssign "foldarr" env earr
-
- zeroRefcountCheck (typeOf earr) "fold1i" arrname
-
- shszname <- genName' "shsz"
- -- This n is one less than the shape of the thing we're querying, which is
- -- unexpected. But it's exactly what we want, so we do it anyway.
- emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n arrname)
-
- resname <- allocArray "fold" Malloc "foldres" n t (Just (CELit shszname)) (compileArrShapeComponents n arrname)
-
- lenname <- genName' "n"
- emit $ SVarDecl True (repSTy tIx) lenname
- (CELit (arrname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))
-
- ((), x0incrStmts) <- scope $ incrementVarAlways "foldx0" Increment t x0name
-
- ivar <- genName' "i"
- jvar <- genName' "j"
- -- kvar <- if vecwid > 1 then genName' "k" else return ""
-
- accvar <- genName' "tot"
- let arreltlit = arrname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++
- ({- if vecwid > 1 then show vecwid ++ " * " ++ jvar ++ " + " ++ kvar else -} jvar) ++ "]"
- (funres, funStmts) <- scope $ compile' (Const arreltlit `SCons` Const accvar `SCons` env) efun
- ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldelt" Increment t arreltlit
-
- emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $
- pure (SVarDecl False (repSTy t) accvar (CELit x0name))
- <> x0incrStmts -- we're copying x0 here
- <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
- -- The combination function will consume the array element
- -- and the accumulator. The accumulator is replaced by
- -- what comes out of the function anyway, so that's
- -- fine, but we do need to increment the array element.
- arreltIncrStmts
- <> funStmts
- <> pure (SAsg accvar funres))
- <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar))
-
- incrementVarAlways "foldx0" Decrement t x0name
- incrementVarAlways "foldarr" Decrement (typeOf earr) arrname
-
- return (CELit resname)
-
- ESum1Inner _ e -> do
- let STArr (SS n) t = typeOf e
- argname <- compileAssign "sumarg" env e
-
- zeroRefcountCheck (typeOf e) "sum1i" argname
-
- shszname <- genName' "shsz"
- -- This n is one less than the shape of the thing we're querying, like EFold1Inner.
- emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
-
- resname <- allocArray "sum" Malloc "sumres" n t (Just (CELit shszname)) (compileArrShapeComponents n argname)
-
- lenname <- genName' "n"
- emit $ SVarDecl True (repSTy tIx) lenname
- (CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))
-
- let vecwid = 8 :: Int
- ivar <- genName' "i"
- jvar <- genName' "j"
- kvar <- genName' "k"
- accvar <- genName' "tot"
- let nchunks = CEBinop (CELit lenname) "/" (CELit (show vecwid))
- emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ BList
- -- we have ScalIsNumeric, so it has 0 and (+) in C
- [SVerbatim $ repSTy t ++ " " ++ accvar ++ "[" ++ show vecwid ++ "] = {" ++ intercalate "," (replicate vecwid "0") ++ "};"
- ,SLoop (repSTy tIx) jvar (CELit "0") nchunks $
- pure $ SLoop (repSTy tIx) kvar (CELit "0") (CELit (show vecwid)) $
- pure $ SVerbatim $ accvar ++ "[" ++ kvar ++ "] += " ++ argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ show vecwid ++ " * " ++ jvar ++ " + " ++ kvar ++ "];"
- ,SLoop (repSTy tIx) kvar (CELit "1") (CELit (show vecwid)) $
- pure $ SVerbatim $ accvar ++ "[0] += " ++ accvar ++ "[" ++ kvar ++ "];"
- ,SLoop (repSTy tIx) kvar (CEBinop nchunks "*" (CELit (show vecwid))) (CELit lenname) $
- pure $ SVerbatim $ accvar ++ "[0] += " ++ argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ kvar ++ "];"
- ,SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit (accvar++"[0]"))]
-
- incrementVarAlways "sum" Decrement (typeOf e) argname
-
- return (CELit resname)
-
- EUnit _ e -> do
- e' <- compile' env e
- let typ = STArr SZ (typeOf e)
- strname <- emitStruct typ
- name <- genName
- emit $ SVarDecl True strname name (CEStruct strname
- [("buf", CECall "malloc_instr" [CELit (show (8 + sizeofSTy (typeOf e)))])])
- emit $ SAsg (name ++ ".buf->refc") (CELit "1")
- emit $ SAsg (name ++ ".buf->xs[0]") e'
- return (CELit name)
-
- EReplicate1Inner _ elen earg -> do
- let STArr n t = typeOf earg
- lenname <- compileAssign "replen" env elen
- argname <- compileAssign "reparg" env earg
-
- zeroRefcountCheck (typeOf earg) "replicate1i" argname
-
- shszname <- genName' "shsz"
- emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
-
- resname <- allocArray "repl1i" Malloc "rep" (SS n) t
- (Just (CEBinop (CELit shszname) "*" (CELit lenname)))
- (compileArrShapeComponents n argname ++ [CELit lenname])
-
- ivar <- genName' "i"
- jvar <- genName' "j"
- emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $
- pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
- pure $ SAsg (resname ++ ".buf->xs[" ++ ivar ++ " * " ++ lenname ++ " + " ++ jvar ++ "]")
- (CELit (argname ++ ".buf->xs[" ++ ivar ++ "]"))
-
- incrementVarAlways "repl1i" Decrement (typeOf earg) argname
-
- return (CELit resname)
-
- EMaximum1Inner _ e -> compileExtremum "max" "maximum1i" ">" env e
-
- EMinimum1Inner _ e -> compileExtremum "min" "minimum1i" "<" env e
-
- EConst _ t x -> return $ CELit $ compileScal True t x
-
- EIdx0 _ e -> do
- let STArr _ t = typeOf e
- arrname <- compileAssign "" env e
- zeroRefcountCheck (typeOf e) "idx0" arrname
- name <- genName
- emit $ SVarDecl True (repSTy t) name
- (CEIndex (CEPtrProj (CEProj (CELit arrname) "buf") "xs") (CELit "0"))
- incrementVarAlways "idx0" Decrement (STArr SZ t) arrname
- return (CELit name)
-
- -- EIdx1 _ a b -> error "TODO" -- EIdx1 ext (compile' a) (compile' b)
-
- EIdx _ earr eidx -> do
- let STArr n t = typeOf earr
- arrname <- compileAssign "ixarr" env earr
- zeroRefcountCheck (typeOf earr) "idx" arrname
- idxname <- if fromSNat n > 0 -- prevent an unused-varable warning
- then compileAssign "ixix" env eidx
- else return "" -- won't be used in this case
-
- when emitChecks $
- forM_ (zip [0::Int ..] (indexTupleComponents n idxname)) $ \(i, ixcomp) ->
- emit $ SIf (CEBinop (CEBinop ixcomp "<" (CELit "0")) "||"
- (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".buf->sh[" ++ show i ++ "]")))))
- (pure $ SVerbatim $
- "fprintf(stderr, PRTAG \"CHECK: index out of range (arr=%p)\\n\", " ++
- arrname ++ ".buf); return false;")
- mempty
-
- resname <- genName' "ixres"
- emit $ SVarDecl True (repSTy t) resname (CEIndex (CELit (arrname ++ ".buf->xs")) (toLinearIdx n arrname idxname))
- incrementVarAlways "idxelt" Increment t resname
- incrementVarAlways "idx" Decrement (STArr n t) arrname
- return (CELit resname)
-
- EShape _ e -> do
- let STArr n _ = typeOf e
- t = tTup (sreplicate n tIx)
- _ <- emitStruct t
- name <- compileAssign "" env e
- zeroRefcountCheck (typeOf e) "shape" name
- resname <- genName
- emit $ SVarDecl True (repSTy t) resname (compileShapeQuery n name)
- incrementVarAlways "shape" Decrement (typeOf e) name
- return (CELit resname)
-
- EOp _ op (EPair _ e1 e2) -> do
- e1' <- compile' env e1
- e2' <- compile' env e2
- compileOpPair op e1' e2'
-
- EOp _ op e -> do
- e' <- compile' env e
- compileOpGeneral op e'
-
- ECustom _ _ _ _ earg _ _ e1 e2 -> do
- name1 <- compileAssign "" env e1
- name2 <- compileAssign "" env e2
- case (incrementVar "custom1" Decrement (typeOf e1), incrementVar "custom2" Decrement (typeOf e2)) of
- (Nothing, Nothing) -> compile' (Const name2 `SCons` Const name1 `SCons` SNil) earg
- (mfun1, mfun2) -> do
- name <- compileAssign "" (Const name2 `SCons` Const name1 `SCons` SNil) earg
- maybe (return ()) ($ name1) mfun1
- maybe (return ()) ($ name2) mfun2
- return (CELit name)
-
- EWith _ t e1 e2 -> do
- actyname <- emitStruct (STAccum t)
- name1 <- compileAssign "" env e1
-
- zeroRefcountCheck (typeOf e1) "with" name1
-
- emit $ SVerbatim $ "// copyForWriting start (" ++ name1 ++ ")"
- mcopy <- copyForWriting t name1
- accname <- genName' "accum"
- emit $ SVarDecl False actyname accname
- (CEStruct actyname [("buf", CECall "malloc_instr" [CELit (show (sizeofSTy (fromSMTy t)))])])
- emit $ SAsg (accname++".buf->ac") (maybe (CELit name1) id mcopy)
- emit $ SVerbatim $ "// initial accumulator constructed (" ++ name1 ++ ")."
-
- e2' <- compile' (Const accname `SCons` env) e2
-
- resname <- genName' "acret"
- emit $ SVarDecl True (repSTy (fromSMTy t)) resname (CELit (accname++".buf->ac"))
- emit $ SVerbatim $ "free_instr(" ++ accname ++ ".buf);"
-
- rettyname <- emitStruct (STPair (typeOf e2) (fromSMTy t))
- return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)]
-
- EAccum _ t prj eidx eval eacc -> do
- let -- Assumes v is a value of type (SMTArr n t1), and initialises it to a
- -- full zero array with the given zero info (for the type SMTArr n t1).
- initZeroArray :: SNat n -> SMTy a -> String -> String -> CompM ()
- initZeroArray n t1 v vzi = do
- shszname <- genName' "inacshsz"
- emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n vzi)
- newarrName <- allocArray "initZero" Calloc "inacarr" n (fromSMTy t1) (Just (CELit shszname)) (compileArrShapeComponents n vzi)
- emit $ SAsg v (CELit newarrName)
- forM_ (initZeroFromMemset t1) $ \f1 -> do
- ivar <- genName' "i"
- ((), initStmts) <- scope $ f1 (v++"["++ivar++"]") (vzi++"["++ivar++"]")
- emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) initStmts
-
- -- If something needs to be done to properly initialise this type to
- -- zero after memory has already been initialised to all-zero bytes,
- -- returns an action that does so.
- -- initZeroFromMemset (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type)
- initZeroFromMemset :: SMTy a -> Maybe (String -> String -> CompM ())
- initZeroFromMemset SMTNil = Nothing
- initZeroFromMemset (SMTPair t1 t2) =
- case (initZeroFromMemset t1, initZeroFromMemset t2) of
- (Nothing, Nothing) -> Nothing
- (mf1, mf2) -> Just $ \v vzi -> do
- forM_ mf1 $ \f1 -> f1 (v++".a") (vzi++".a")
- forM_ mf2 $ \f2 -> f2 (v++".b") (vzi++".b")
- initZeroFromMemset SMTLEither{} = Nothing
- initZeroFromMemset SMTMaybe{} = Nothing
- initZeroFromMemset (SMTArr n t1) = Just $ \v vzi -> initZeroArray n t1 v vzi
- initZeroFromMemset SMTScal{} = Nothing
-
- let -- initZeroZI (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type)
- initZeroZI :: SMTy a -> String -> String -> CompM ()
- initZeroZI SMTNil _ _ = return ()
- initZeroZI (SMTPair t1 t2) v vzi = do
- initZeroZI t1 (v++".a") (vzi++".a")
- initZeroZI t2 (v++".b") (vzi++".b")
- initZeroZI SMTLEither{} v _ = emit $ SAsg (v++".tag") (CELit "0")
- initZeroZI SMTMaybe{} v _ = emit $ SAsg (v++".tag") (CELit "0")
- initZeroZI (SMTArr n t1) v vzi = initZeroArray n t1 v vzi
- initZeroZI (SMTScal sty) v _ = case sty of
- STI32 -> emit $ SAsg v (CELit "0")
- STI64 -> emit $ SAsg v (CELit "0l")
- STF32 -> emit $ SAsg v (CELit "0.0f")
- STF64 -> emit $ SAsg v (CELit "0.0")
-
- let -- Initialise an uninitialised accumulation value, potentially already
- -- with the addend, potentially to zero depending on the nature of the
- -- projection.
- -- 1. If the projection indexes only through dense monoids before
- -- reaching SAPHere, the thing cannot be initialised to zero with
- -- only an AcIdx; it would need to model a zero after the addend,
- -- which is stupid and redundant. In this case, we return Left:
- -- (accumulation value) (AcIdx value) (addend value).
- -- The addend is copied, not consumed. (We can't reliably _always_
- -- consume it, so it's not worth trying to do it sometimes.)
- -- 2. Otherwise, a sparse monoid is found along the way, and we can
- -- initalise the dense prefix of the path to zero by setting the
- -- indexed-through sparse value to a sparse zero. Afterwards, the
- -- main recursion can proceed further. In this case, we return
- -- Right: (accumulation value) (AcIdx value)
- -- initZeroChunk (type) (projection) (variable of that type to initialise to zero) (variable to an AcIdx for the type)
- initZeroChunk :: SMTy a -> SAcPrj p a b
- -> Either (String -> String -> String -> CompM ()) -- dense initialisation with addend
- (String -> String -> CompM ()) -- zero initialisation of sparse chunk
- initZeroChunk izaitoptyp izaitopprj = case (izaitoptyp, izaitopprj) of
- -- reached target before the first sparse constructor
- (t1 , SAPHere ) -> Left $ \v _ addend -> do
- incrementVarAlways "initZeroSparse" Increment (fromSMTy t1) addend
- emit $ SAsg v (CELit addend)
- -- sparse types
- (SMTLEither{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0")
- (SMTMaybe{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0")
- -- dense types
- (SMTPair t1 t2, SAPFst prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do
- f (v++".a") (i++".a")
- initZeroZI t2 (v++".b") (i++".b")
- (SMTPair t1 t2, SAPSnd prj') -> applySkeleton (initZeroChunk t2 prj') $ \f v i -> do
- initZeroZI t1 (v++".a") (i++".a")
- f (v++".b") (i++".b")
- (SMTArr n t1, SAPArrIdx prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do
- initZeroArray n t1 v (i++".a.b")
- linidxvar <- genName' "li"
- emit $ SVarDecl False (repSTy tIx) linidxvar (toLinearIdx n v (i++".a.a"))
- f (v++".buf->xs["++linidxvar++"]") (i++".b")
- where
- applySkeleton (Left densef) skel = Left $ \v i addend -> skel (\v' i' -> densef v' i' addend) v i
- applySkeleton (Right sparsef) skel = Right $ \v i -> skel (\v' i' -> sparsef v' i') v i
-
- let -- Add a value (s) into an existing accumulation value (d). If a sparse
- -- component of d is encountered, s is copied there.
- add :: SMTy a -> String -> String -> CompM ()
- add SMTNil _ _ = return ()
- add (SMTPair t1 t2) d s = do
- add t1 (d++".a") (s++".a")
- add t2 (d++".b") (s++".b")
- add (SMTLEither t1 t2) d s = do
- ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTLEither t1 t2)) s
- ((), stmts1) <- scope $ add t1 (d++".l") (s++".l")
- ((), stmts2) <- scope $ add t2 (d++".r") (s++".r")
- emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0"))
- (pure (SAsg d (CELit s))
- <> srcIncrStmts)
- ((if emitChecks
- then pure (SIf (CEBinop (CEBinop (CELit (s++".tag")) "!=" (CELit "0"))
- "&&"
- (CEBinop (CELit (s++".tag")) "!=" (CELit (d++".tag"))))
- (pure $ SVerbatim $
- "fprintf(stderr, PRTAG \"CHECK: accum add leither with different tags " ++
- "(dest %d, src %d)\\n\", (int)" ++ d ++ ".tag, (int)" ++ s ++ ".tag); " ++
- "return false;")
- mempty)
- else mempty)
- -- note: s may have tag 0
- <> pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1"))
- stmts1
- (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "2"))
- stmts2 mempty))))
- add (SMTMaybe t1) d s = do
- ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTMaybe t1)) s
- ((), stmts1) <- scope $ add t1 (d++".j") (s++".j")
- emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0"))
- (pure (SAsg d (CELit s))
- <> srcIncrStmts)
- (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) stmts1 mempty))
- add (SMTArr n t1) d s = do
- when emitChecks $ do
- let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
- forM_ [0 .. fromSNat n - 1] $ \j -> do
- emit $ SIf (CEBinop (CELit (s ++ ".buf->sh[" ++ show j ++ "]"))
- "!="
- (CELit (d ++ ".buf->sh[" ++ show j ++ "]")))
- (pure $ SVerbatim $
- "fprintf(stderr, PRTAG \"CHECK: accum add incorrect (d=%p, " ++
- "dsh=" ++ shfmt ++ ", s=%p, ssh=" ++ shfmt ++ ")\\n\", " ++
- d ++ ".buf" ++
- concat [", " ++ d ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
- ", " ++ s ++ ".buf" ++
- concat [", " ++ s ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
- "); " ++
- "return false;")
- mempty
-
- shsizename <- genName' "acshsz"
- emit $ SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n s)
- ivar <- genName' "i"
- ((), stmts1) <- scope $ add t1 (d++".buf->xs["++ivar++"]") (s++".buf->xs["++ivar++"]")
- emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename)
- stmts1
- add (SMTScal _) d s = emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
-
- let -- | Dereference an accumulation value and add a given value to that
- -- position. Sparse components encountered along the way are
- -- initialised before proceeding downwards.
- -- accumRef (type) (projection) (accumulation component) (AcIdx variable) (value to accumulate there)
- accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> String -> CompM ()
- accumRef _ SAPHere v _ addend = add (acPrjTy prj t) v addend
-
- accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") (i++".a") addend
- accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") (i++".b") addend
-
- accumRef (SMTLEither ta tb) prj0 v i addend = do
- let chunkres = case prj0 of SAPLeft prj' -> initZeroChunk ta prj'
- SAPRight prj' -> initZeroChunk tb prj'
- subv = v ++ (case prj0 of SAPLeft{} -> ".l"; SAPRight{} -> ".r")
- tagval = case prj0 of SAPLeft{} -> "1"
- SAPRight{} -> "2"
- ((), stmtsAdd) <- scope $ case prj0 of SAPLeft prj' -> accumRef ta prj' subv i addend
- SAPRight prj' -> accumRef tb prj' subv i addend
- case chunkres of
- Left densef -> do
- ((), stmtsSet) <- scope $ densef subv i addend
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsSet)
- stmtsAdd -- TODO: emit check for consistency of tags?
- Right sparsef -> do
- ((), stmtsInit) <- scope $ sparsef subv i
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsInit) mempty
- forM_ stmtsAdd emit
-
- accumRef (SMTMaybe tj) (SAPJust prj') v i addend = do
- case initZeroChunk tj prj' of
- Left densef -> do
- ((), stmtsSet1) <- scope $ densef (v++".j") i addend
- ((), stmtsAdd1) <- scope $ accumRef tj prj' (v++".j") i addend
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit "1")) <> stmtsSet1)
- stmtsAdd1
- Right sparsef -> do
- ((), stmtsInit1) <- scope $ sparsef (v++".j") i
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty
- accumRef tj prj' (v++".j") i addend
-
- accumRef (SMTArr n t') (SAPArrIdx prj') v i addend = do
- when emitChecks $ do
- let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
- forM_ (zip3 [0::Int ..]
- (indexTupleComponents n (i++".a.a"))
- (compileArrShapeComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do
- let a .||. b = CEBinop a "||" b
- emit $ SIf (CEBinop ixcomp "<" (CELit "0")
- .||.
- CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))
- .||.
- CEBinop shcomp "!=" (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))
- (pure $ SVerbatim $
- "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++
- "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++
- v ++ ".buf" ++
- concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
- concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++
- concat [", " ++ printCExpr 2 comp "" | comp <- compileArrShapeComponents n (i++".a.b")] ++
- "); " ++
- "return false;")
- mempty
-
- accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") addend
-
- nameidx <- compileAssign "acidx" env eidx
- nameval <- compileAssign "acval" env eval
- nameacc <- compileAssign "acac" env eacc
-
- emit $ SVerbatim $ "// compile EAccum start (" ++ show prj ++ ")"
- accumRef t prj (nameacc++".buf->ac") nameidx nameval
- emit $ SVerbatim $ "// compile EAccum end"
-
- incrementVarAlways "accumendsrc" Decrement (typeOf eval) nameval
-
- return $ CEStruct (repSTy STNil) []
-
- EError _ t s -> do
- let padleft len c s' = replicate (len - length s) c ++ s'
- escape = concatMap $ \c -> if | c `elem` "\"\\" -> ['\\',c]
- | ord c < 32 -> "\\x" ++ padleft 2 '0' (showHex (ord c) "")
- | otherwise -> [c]
- emit $ SVerbatim $ "fputs(\"ERROR: " ++ escape s ++ "\\n\", stderr); return false;"
- case t of
- STScal _ -> return (CELit "0")
- _ -> do
- name <- emitStruct t
- return $ CEStruct name []
-
- EZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
- EPlus{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
- EOneHot{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
-
- EIdx1{} -> error "Compile: not implemented: EIdx1"
-
-compileAssign :: String -> SList (Const String) env -> Ex env t -> CompM String
-compileAssign prefix env e = do
- e' <- compile' env e
- case e' of
- CELit name -> return name
- _ -> do
- name <- genName' prefix
- emit $ SVarDecl True (repSTy (typeOf e)) name e'
- return name
-
-data Increment = Increment | Decrement
- deriving (Show)
-
--- | Increment reference counts in the components of the given variable.
-incrementVar :: String -> Increment -> STy a -> Maybe (String -> CompM ())
-incrementVar marker inc ty =
- let tree = makeArrayTree ty
- in case tree of ATNoop -> Nothing
- _ -> Just $ \var -> incrementVar' marker inc var tree
-
-incrementVarAlways :: String -> Increment -> STy a -> String -> CompM ()
-incrementVarAlways marker inc ty var = maybe (pure ()) ($ var) (incrementVar marker inc ty)
-
-data ArrayTree = ATArray (Some SNat) (Some STy) -- ^ we've arrived at an array we need to decrement the refcount of (contains rank and element type of the array)
- | ATNoop -- ^ don't do anything here
- | ATProj String ArrayTree -- ^ descend one field deeper
- | ATCondTag ArrayTree ArrayTree -- ^ if tag is 0, first; if 1, second
- | ATCond3Tag ArrayTree ArrayTree ArrayTree -- ^ if tag is: 0, 1, 2
- | ATBoth ArrayTree ArrayTree -- ^ do both these paths
-
-smartATProj :: String -> ArrayTree -> ArrayTree
-smartATProj _ ATNoop = ATNoop
-smartATProj field t = ATProj field t
-
-smartATCondTag :: ArrayTree -> ArrayTree -> ArrayTree
-smartATCondTag ATNoop ATNoop = ATNoop
-smartATCondTag t t' = ATCondTag t t'
-
-smartATCond3Tag :: ArrayTree -> ArrayTree -> ArrayTree -> ArrayTree
-smartATCond3Tag ATNoop ATNoop ATNoop = ATNoop
-smartATCond3Tag t1 t2 t3 = ATCond3Tag t1 t2 t3
-
-smartATBoth :: ArrayTree -> ArrayTree -> ArrayTree
-smartATBoth ATNoop t = t
-smartATBoth t ATNoop = t
-smartATBoth t t' = ATBoth t t'
-
-makeArrayTree :: STy a -> ArrayTree
-makeArrayTree STNil = ATNoop
-makeArrayTree (STPair a b) = smartATBoth (smartATProj "a" (makeArrayTree a))
- (smartATProj "b" (makeArrayTree b))
-makeArrayTree (STEither a b) = smartATCondTag (smartATProj "l" (makeArrayTree a))
- (smartATProj "r" (makeArrayTree b))
-makeArrayTree (STMaybe t) = smartATCondTag ATNoop (smartATProj "j" (makeArrayTree t))
-makeArrayTree (STArr n t) = ATArray (Some n) (Some t)
-makeArrayTree (STScal _) = ATNoop
-makeArrayTree (STAccum _) = ATNoop
-makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop
- (smartATProj "l" (makeArrayTree a))
- (smartATProj "r" (makeArrayTree b))
-
-incrementVar' :: String -> Increment -> String -> ArrayTree -> CompM ()
-incrementVar' marker inc path (ATArray (Some n) (Some eltty)) =
- case inc of
- Increment -> do
- emit $ SVerbatim (path ++ ".buf->refc++;")
- when debugRefc $
- emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p in+ -> %zu <" ++ marker ++ ">\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc);"
- Decrement -> do
- case incrementVar (marker++".elt") Decrement eltty of
- Nothing ->
- if debugRefc
- then do
- emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p de- -> %zu <" ++ marker ++ ">\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);"
- emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) { fprintf(stderr, \"; free(\"); free_instr(" ++ path ++ ".buf); fprintf(stderr, \") ok\\n\"); } else fprintf(stderr, \"\\n\");"
- else do
- emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) free_instr(" ++ path ++ ".buf);"
- Just f -> do
- when debugRefc $
- emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p de- -> %zu <" ++ marker ++ "> recfree\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);"
- shszvar <- genName' "frshsz"
- ivar <- genName' "i"
- ((), eltDecrStmts) <- scope $ f (path ++ ".buf->xs[" ++ ivar ++ "]")
- emit $ SIf (CELit ("--" ++ path ++ ".buf->refc == 0"))
- (BList [SVarDecl True "size_t" shszvar (compileArrShapeSize n path)
- ,SLoop "size_t" ivar (CELit "0") (CELit shszvar) $
- eltDecrStmts
- ,SVerbatim $ "free_instr(" ++ path ++ ".buf);"])
- mempty
-incrementVar' _ _ _ ATNoop = pure ()
-incrementVar' marker inc path (ATProj field t) = incrementVar' (marker++"."++field) inc (path ++ "." ++ field) t
-incrementVar' marker inc path (ATCondTag t1 t2) = do
- ((), stmts1) <- scope $ incrementVar' (marker++".t1") inc path t1
- ((), stmts2) <- scope $ incrementVar' (marker++".t2") inc path t2
- emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "0")) stmts1 stmts2
-incrementVar' marker inc path (ATCond3Tag t1 t2 t3) = do
- ((), stmts1) <- scope $ incrementVar' (marker++".t1") inc path t1
- ((), stmts2) <- scope $ incrementVar' (marker++".t2") inc path t2
- ((), stmts3) <- scope $ incrementVar' (marker++".t3") inc path t3
- emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "1"))
- stmts2
- (pure (SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "2"))
- stmts3
- stmts1))
-incrementVar' marker inc path (ATBoth t1 t2) = incrementVar' (marker++".1") inc path t1 >> incrementVar' (marker++".2") inc path t2
-
-toLinearIdx :: SNat n -> String -> String -> CExpr
-toLinearIdx SZ _ _ = CELit "0"
-toLinearIdx (SS SZ) _ idxvar = CELit (idxvar ++ ".b")
-toLinearIdx (SS n) arrvar idxvar =
- CEBinop (CEBinop (toLinearIdx n arrvar (idxvar ++ ".a"))
- "*" (CEIndex (CELit (arrvar ++ ".buf->sh")) (CELit (show (fromSNat n)))))
- "+" (CELit (idxvar ++ ".b"))
-
--- fromLinearIdx :: SNat n -> String -> String -> CompM CExpr
--- fromLinearIdx SZ _ _ = return $ CEStruct (repSTy STNil) []
--- fromLinearIdx (SS n) arrvar idxvar = do
--- name <- genName
--- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")))
--- _
-
-data AllocMethod = Malloc | Calloc
- deriving (Show)
-
--- | The shape must have the outer dimension at the head (and the inner dimension on the right).
-allocArray :: String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String
-allocArray marker method nameBase rank eltty mshsz shape = do
- when (length shape /= fromSNat rank) $
- error "allocArray: shape does not match rank"
- let arrty = STArr rank eltty
- strname <- emitStruct arrty
- arrname <- genName' nameBase
- shsz <- case mshsz of
- Just e -> return e
- Nothing -> return (foldl0' (\a b -> CEBinop a "*" b) (CELit "1") shape)
- let nbytesExpr = CEBinop (CELit (show (fromSNat rank * 8 + 8)))
- "+"
- (CEBinop shsz "*" (CELit (show (sizeofSTy eltty))))
- emit $ SVarDecl True strname arrname $ CEStruct strname
- [("buf", case method of Malloc -> CECall "malloc_instr" [nbytesExpr]
- Calloc -> CECall "calloc_instr" [nbytesExpr])]
- forM_ (zip shape [0::Int ..]) $ \(dim, i) ->
- emit $ SAsg (arrname ++ ".buf->sh[" ++ show i ++ "]") dim
- emit $ SAsg (arrname ++ ".buf->refc") (CELit "1")
- when debugRefc $
- emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p allocated <" ++ marker ++ ">\\n\", " ++ arrname ++ ".buf);"
- return arrname
-
-compileShapeQuery :: SNat n -> String -> CExpr
-compileShapeQuery SZ _ = CEStruct (repSTy STNil) []
-compileShapeQuery (SS n) var =
- CEStruct (repSTy (tTup (sreplicate (SS n) tIx)))
- [("a", compileShapeQuery n var)
- ,("b", CEIndex (CELit (var ++ ".buf->sh")) (CELit (show (fromSNat n))))]
-
--- | Takes a variable name for the array, not the buffer.
-compileArrShapeSize :: SNat n -> String -> CExpr
-compileArrShapeSize n var = foldl0' (\a b -> CEBinop a "*" b) (CELit "1") (compileArrShapeComponents n var)
-
--- | Takes a variable name for the array, not the buffer.
-compileArrShapeComponents :: SNat n -> String -> [CExpr]
-compileArrShapeComponents n var =
- [CELit (var ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
-
-indexTupleComponents :: SNat n -> String -> [CExpr]
-indexTupleComponents = \n var -> map CELit (toList (go n var))
- where
- go :: SNat n -> String -> Bag String
- go SZ _ = mempty
- go (SS n) var = go n (var ++ ".a") <> pure (var ++ ".b")
-
--- | Takes variable names with the innermost dimension on the right.
-shapeTupFromLitVars :: SNat n -> [String] -> CExpr
-shapeTupFromLitVars = \n -> go n . reverse
- where
- -- takes variables with the innermost dimension at the _head_
- go :: SNat n -> [String] -> CExpr
- go SZ [] = CEStruct (repSTy STNil) []
- go (SS n) (var : vars) = CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) [("a", go n vars), ("b", CELit var)]
- go _ _ = error "shapeTupFromLitVars: SNat and list do not correspond"
-
-compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr
-compileOpGeneral op e1 = do
- let unary cop = return @CompM $ CECall cop [e1]
- let binary cop = do
- name <- genName
- emit $ SVarDecl True (repSTy (opt1 op)) name e1
- return $ CEBinop (CEProj (CELit name) "a") cop (CEProj (CELit name) "b")
- case op of
- OAdd _ -> binary "+"
- OMul _ -> binary "*"
- ONeg _ -> unary "-"
- OLt _ -> binary "<"
- OLe _ -> binary "<="
- OEq _ -> binary "=="
- ONot -> unary "!"
- OAnd -> binary "&&"
- OOr -> binary "||"
- OIf -> do
- name <- emitStruct (STEither STNil STNil)
- _ <- emitStruct STNil
- return $ CEIf e1 (CEStruct name [("tag", CELit "0")])
- (CEStruct name [("tag", CELit "1")])
- ORound64 -> unary "(int64_t)round" -- ew
- OToFl64 -> unary "(double)"
- ORecip _ -> return $ CEBinop (CELit "1.0") "/" e1
- OExp STF32 -> unary "expf"
- OExp STF64 -> unary "exp"
- OLog STF32 -> unary "logf"
- OLog STF64 -> unary "log"
- OIDiv _ -> binary "/"
- OMod _ -> binary "%"
-
-compileOpPair :: SOp a b -> CExpr -> CExpr -> CompM CExpr
-compileOpPair op e1 e2 = do
- let binary cop = return @CompM $ CEBinop e1 cop e2
- case op of
- OAdd _ -> binary "+"
- OMul _ -> binary "*"
- OLt _ -> binary "<"
- OLe _ -> binary "<="
- OEq _ -> binary "=="
- OAnd -> binary "&&"
- OOr -> binary "||"
- OIDiv _ -> binary "/"
- OMod _ -> binary "%"
- _ -> error "compileOpPair: got unary operator"
-
--- | Bool: whether to ensure that the literal itself already has the appropriate type
-compileScal :: Bool -> SScalTy t -> ScalRep t -> String
-compileScal pedantic typ x = case typ of
- STI32 -> (if pedantic then "(int32_t)" else "") ++ show x
- STI64 -> (if pedantic then "(int64_t)" else "") ++ show x
- STF32 -> show x ++ "f"
- STF64 -> show x
- STBool -> if x then "1" else "0"
-
-compileExtremum :: String -> String -> String -> SList (Const String) env -> Ex env (TArr (S n) t) -> CompM CExpr
-compileExtremum nameBase opName operator env e = do
- let STArr (SS n) t = typeOf e
- argname <- compileAssign (nameBase ++ "arg") env e
-
- zeroRefcountCheck (typeOf e) opName argname
-
- shszname <- genName' "shsz"
- -- This n is one less than the shape of the thing we're querying, which is
- -- unexpected. But it's exactly what we want, so we do it anyway.
- emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
-
- resname <- allocArray nameBase Malloc (nameBase ++ "res") n t (Just (CELit shszname)) (compileArrShapeComponents n argname)
-
- lenname <- genName' "n"
- emit $ SVarDecl True (repSTy tIx) lenname
- (CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))
-
- emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); return false; }"
-
- ivar <- genName' "i"
- jvar <- genName' "j"
- xvar <- genName
- redvar <- genName' "red" -- use "red", not "acc", to avoid confusion with accumulators
- emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ BList
- -- we have ScalIsNumeric, so it has 1 and (<) etc. in C
- [SVarDecl False (repSTy t) redvar (CELit (argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ "]"))
- ,SLoop (repSTy tIx) jvar (CELit "1") (CELit lenname) $ BList
- [SVarDecl True (repSTy t) xvar (CELit (argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ jvar ++ "]"))
- ,SAsg redvar $ CEIf (CEBinop (CELit xvar) operator (CELit redvar)) (CELit xvar) (CELit redvar)
- ]
- ,SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit redvar)]
-
- incrementVarAlways nameBase Decrement (typeOf e) argname
-
- return (CELit resname)
-
--- | If this returns Nothing, there was nothing to copy because making a simple
--- value copy in C already makes it suitable to write to.
-copyForWriting :: SMTy t -> String -> CompM (Maybe CExpr)
-copyForWriting topty var = case topty of
- SMTNil -> return Nothing
-
- SMTPair a b -> do
- e1 <- copyForWriting a (var ++ ".a")
- e2 <- copyForWriting b (var ++ ".b")
- case (e1, e2) of
- (Nothing, Nothing) -> return Nothing
- _ -> return $ Just $ CEStruct toptyname
- [("a", fromMaybe (CELit (var++".a")) e1)
- ,("b", fromMaybe (CELit (var++".b")) e2)]
-
- SMTLEither a b -> do
- (e1, stmts1) <- scope $ copyForWriting a (var ++ ".l")
- (e2, stmts2) <- scope $ copyForWriting b (var ++ ".r")
- case (e1, e2) of
- (Nothing, Nothing) -> return Nothing
- _ -> do
- name <- genName
- emit $ SVarDeclUninit toptyname name
- emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0"))
- (stmts1
- <> pure (SAsg name (CEStruct toptyname
- [("tag", CELit "0"), ("l", fromMaybe (CELit (var++".l")) e1)])))
- (stmts2
- <> pure (SAsg name (CEStruct toptyname
- [("tag", CELit "1"), ("r", fromMaybe (CELit (var++".r")) e2)])))
- return (Just (CELit name))
-
- SMTMaybe t -> do
- (e1, stmts1) <- scope $ copyForWriting t (var ++ ".j")
- case e1 of
- Nothing -> return Nothing
- Just e1' -> do
- name <- genName
- emit $ SVarDeclUninit toptyname name
- emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0"))
- (pure (SAsg name (CEStruct toptyname [("tag", CELit "0")])))
- (stmts1
- <> pure (SAsg name (CEStruct toptyname [("tag", CELit "1"), ("j", e1')])))
- return (Just (CELit name))
-
- -- If there are no nested arrays, we know that a refcount of 1 means that the
- -- whole thing is owned. Nested arrays have their own refcount, so with
- -- nesting we'd have to check the refcounts of all the nested arrays _too_;
- -- let's not do that. Furthermore, no sub-arrays means that the whole thing
- -- is flat, and we can just memcpy if necessary.
- SMTArr n t | not (hasArrays (fromSMTy t)) -> do
- name <- genName
- shszname <- genName' "shsz"
- emit $ SVarDeclUninit toptyname name
-
- when debugShapes $ do
- let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
- emit $ SVerbatim $
- "fprintf(stderr, PRTAG \"with array " ++ shfmt ++ "\\n\"" ++
- concat [", " ++ var ++ ".buf->sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++
- ");"
-
- emit $ SIf (CEBinop (CELit (var ++ ".buf->refc")) "==" (CELit "1"))
- (pure (SAsg name (CELit var)))
- (let shbytes = fromSNat n * 8
- databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t))))
- totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes
- in BList
- [SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var)
- ,SAsg name (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])])
- ,SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++
- show shbytes ++ ");"
- ,SAsg (name ++ ".buf->refc") (CELit "1")
- ,SVerbatim $ "memcpy(" ++ name ++ ".buf->xs, " ++ var ++ ".buf->xs, " ++
- printCExpr 0 databytes ");"])
- return (Just (CELit name))
-
- SMTArr n t -> do
- shszname <- genName' "shsz"
- emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var)
-
- let shbytes = fromSNat n * 8
- databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t))))
- totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes
-
- name <- genName
- emit $ SVarDecl False toptyname name
- (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])])
- emit $ SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++
- show shbytes ++ ");"
- emit $ SAsg (name ++ ".buf->refc") (CELit "1")
-
- -- put the arrays in variables to cut short the not-quite-var chain
- dstvar <- genName' "cpydst"
- emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") dstvar (CELit (name ++ ".buf->xs"))
- srcvar <- genName' "cpysrc"
- emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") srcvar (CELit (var ++ ".buf->xs"))
-
- ivar <- genName' "i"
-
- (cpye, cpystmts) <- scope $ copyForWriting t (srcvar ++ "[" ++ ivar ++ "]")
- let cpye' = case cpye of
- Just e -> e
- Nothing -> error "copyForWriting: arrays cannot be copied as-is, bug"
-
- emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $
- cpystmts
- <> pure (SAsg (dstvar ++ "[" ++ ivar ++ "]") cpye')
-
- return (Just (CELit name))
-
- SMTScal _ -> return Nothing
-
- where
- toptyname = repSTy (fromSMTy topty)
-
-zeroRefcountCheck :: STy t -> String -> String -> CompM ()
-zeroRefcountCheck toptyp opname topvar =
- when emitChecks $ do
- mstmts <- onlyIdGen $ runMaybeT (go toptyp topvar)
- case mstmts of
- Nothing -> return ()
- Just stmts -> forM_ stmts emit
- where
- -- | If this returns 'Nothing', no statements need to be generated for this type.
- go :: STy t -> String -> MaybeT IdGen.IdGen (Bag Stmt)
- go STNil _ = empty
- go (STPair a b) path = do
- (s1, s2) <- combine (go a (path++".a")) (go b (path++".b"))
- return (s1 <> s2)
- go (STEither a b) path = do
- (s1, s2) <- combine (go a (path++".l")) (go b (path++".r"))
- return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "0")) s1 s2
- go (STMaybe a) path = do
- ss <- go a (path++".j")
- return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) ss mempty
- go (STArr n a) path = do
- ivar <- genName' "i"
- ss <- go a (path++".buf->xs["++ivar++"]")
- shszname <- genName' "shsz"
- let s1 = SVerbatim $
- "if (__builtin_expect(" ++ path ++ ".buf->refc == 0, 0)) { " ++
- "fprintf(stderr, PRTAG \"CHECK: '" ++ opname ++ "' got array " ++
- "%p with refc=0\\n\", " ++ path ++ ".buf); return false; }"
- let s2 = SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n path)
- let s3 = SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) ss
- return (BList [s1, s2, s3])
- go STScal{} _ = empty
- go STAccum{} _ = error "zeroRefcountCheck: passed an accumulator"
- go (STLEither a b) path = do
- (s1, s2) <- combine (go a (path++".l")) (go b (path++".r"))
- return $ pure $
- SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1"))
- s1
- (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2"))
- s2
- mempty))
-
- combine :: (Monoid a, Monoid b, Monad m) => MaybeT m a -> MaybeT m b -> MaybeT m (a, b)
- combine (MaybeT a) (MaybeT b) = MaybeT $ do
- x <- a
- y <- b
- return $ case (x, y) of
- (Nothing, Nothing) -> Nothing
- (Just x', Nothing) -> Just (x', mempty)
- (Nothing, Just y') -> Just (mempty, y')
- (Just x', Just y') -> Just (x', y')
-
-{-# NOINLINE uniqueIdGenRef #-}
-uniqueIdGenRef :: IORef Int
-uniqueIdGenRef = unsafePerformIO $ newIORef 1
-
-compose :: Foldable t => t (a -> a) -> a -> a
-compose = foldr (.) id
-
-showPtr :: Ptr a -> String
-showPtr (Ptr a) = "0x" ++ showHex (integerFromWord# (int2Word# (addr2Int# a))) ""
-
--- | Type-restricted.
-(^) :: Num a => a -> Int -> a
-(^) = (Prelude.^)
-
-foldl0' :: (a -> a -> a) -> a -> [a] -> a
-foldl0' _ x [] = x
-foldl0' f _ l = foldl1' f l