aboutsummaryrefslogtreecommitdiff
path: root/src/Compile.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-10 21:49:45 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-10 21:50:25 +0100
commit174af2ba568de66e0d890825b8bda930b8e7bb96 (patch)
tree5a20f52662e87ff7cf6a6bef5db0713aa6c7884e /src/Compile.hs
parent92bca235e3aaa287286b6af082d3fce585825a35 (diff)
Move module hierarchy under CHAD.
Diffstat (limited to 'src/Compile.hs')
-rw-r--r--src/Compile.hs1796
1 files changed, 0 insertions, 1796 deletions
diff --git a/src/Compile.hs b/src/Compile.hs
deleted file mode 100644
index 8627905..0000000
--- a/src/Compile.hs
+++ /dev/null
@@ -1,1796 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DerivingStrategies #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE GeneralizedNewtypeDeriving #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE MagicHash #-}
-{-# LANGUAGE MultiWayIf #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE TupleSections #-}
-{-# LANGUAGE TypeApplications #-}
-module Compile (compile, compileStderr) where
-
-import Control.Applicative (empty)
-import Control.Monad (forM_, when, replicateM)
-import Control.Monad.Trans.Class (lift)
-import Control.Monad.Trans.Maybe
-import Control.Monad.Trans.State.Strict
-import Control.Monad.Trans.Writer.CPS
-import Data.Bifunctor (first)
-import Data.Char (ord)
-import Data.Foldable (toList)
-import Data.Functor.Const
-import qualified Data.Functor.Product as Product
-import Data.Functor.Product (Product)
-import Data.IORef
-import Data.List (foldl1', intersperse, intercalate)
-import qualified Data.Map.Strict as Map
-import Data.Maybe (fromMaybe)
-import qualified Data.Set as Set
-import Data.Set (Set)
-import Data.Some
-import qualified Data.Vector as V
-import Foreign
-import GHC.Exts (int2Word#, addr2Int#)
-import GHC.Num (integerFromWord#)
-import GHC.Ptr (Ptr(..))
-import GHC.Stack (HasCallStack)
-import Numeric (showHex)
-import System.IO (hPutStrLn, stderr)
-import System.IO.Error (mkIOError, userErrorType)
-import System.IO.Unsafe (unsafePerformIO)
-
-import Prelude hiding ((^))
-import qualified Prelude
-
-import Array
-import AST
-import AST.Pretty (ppSTy, ppExpr)
-import AST.Sparse.Types (isDense)
-import Compile.Exec
-import Data
-import Interpreter.Rep
-import qualified Util.IdGen as IdGen
-
-
--- In shape and index arrays, the innermost dimension is on the right (last index).
-
--- TODO: test that I'm properly incrementing and decrementing refcounts in all required places
-
-
--- | Print the compiled AST
-debugPrintAST :: Bool; debugPrintAST = toEnum 0
--- | Print the generated C source
-debugCSource :: Bool; debugCSource = toEnum 0
--- | Print extra stuff about reference counts of arrays
-debugRefc :: Bool; debugRefc = toEnum 0
--- | Print some shape-related information
-debugShapes :: Bool; debugShapes = toEnum 0
--- | Print information on allocation
-debugAllocs :: Bool; debugAllocs = toEnum 0
--- | Emit extra C code that checks stuff
-emitChecks :: Bool; emitChecks = toEnum 0
-
--- | Returns compiled function plus compilation output (warnings)
-compile :: SList STy env -> Ex env t
- -> IO (SList Value env -> IO (Rep t), String)
-compile = \env expr -> do
- codeID <- atomicModifyIORef' uniqueIdGenRef (\i -> (i + 1, i))
-
- let (source, offsets) = compileToString codeID env expr
- when debugPrintAST $ hPutStrLn stderr $ "Compiled AST: <<<\n" ++ ppExpr env expr ++ "\n>>>"
- when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ lineNumbers source ++ "\x1B[0m>>>"
- (lib, compileOutput) <- buildKernel source "kernel"
-
- let result_type = typeOf expr
- result_size = sizeofSTy result_type
-
- let function val = do
- allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do
- let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets)
- serialiseArguments args ptr $ do
- callKernelFun lib ptr
- ok <- peekByteOff @Word8 ptr (koOkResOffset offsets)
- when (ok /= 1) $
- ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing)
- deserialise result_type ptr (koResultOffset offsets)
- return (function, compileOutput)
- where
- serialiseArguments :: [(Some (Product STy Value), Int)] -> Ptr () -> IO r -> IO r
- serialiseArguments ((Some (Product.Pair t (Value arg)), off) : args) ptr k =
- serialise t arg ptr off $
- serialiseArguments args ptr k
- serialiseArguments _ _ k = k
-
--- | 'compile', but writes any produced C compiler output to stderr.
-compileStderr :: SList STy env -> Ex env t
- -> IO (SList Value env -> IO (Rep t))
-compileStderr env expr = do
- (fun, output) <- compile env expr
- when (not (null output)) $
- hPutStrLn stderr $ "[chad] Kernel compilation GCC output: <<<\n" ++ output ++ ">>>"
- return fun
-
-
-data StructDecl = StructDecl
- String -- ^ name
- String -- ^ contents
- String -- ^ comment
- deriving (Show)
-
-data Stmt
- = SVarDecl Bool String String CExpr -- ^ const, type, variable name, right-hand side
- | SVarDeclUninit String String -- ^ type, variable name (no initialiser)
- | SAsg String CExpr -- ^ variable name, right-hand side
- | SBlock (Bag Stmt)
- | SIf CExpr (Bag Stmt) (Bag Stmt)
- | SLoop String String CExpr CExpr (Bag Stmt) -- ^ for (<type> <name> = <expr>; name < <expr>; name++) {<stmts>}
- | SVerbatim String -- ^ no implicit ';', just printed as-is
- deriving (Show)
-
-data CExpr
- = CELit String -- ^ inserted as-is, assumed no parentheses needed
- | CEStruct String [(String, CExpr)] -- ^ struct construction literal: `(name){.field=expr}`
- | CEProj CExpr String -- ^ field projection: expr.field
- | CEPtrProj CExpr String -- ^ field projection through pointer: expr->field
- | CEAddrOf CExpr -- ^ &expr
- | CEIndex CExpr CExpr -- ^ expr[expr]
- | CECall String [CExpr] -- ^ function(arg1, ..., argn)
- | CEBinop CExpr String CExpr -- ^ expr + expr
- | CEIf CExpr CExpr CExpr -- ^ expr ? expr : expr
- | CECast String CExpr -- ^ (<type>)<expr>
- deriving (Show)
-
-printStructDecl :: StructDecl -> ShowS
-printStructDecl (StructDecl name contents comment) =
- showString "typedef struct { " . showString contents . showString " } " . showString name
- . showString ";" . (if null comment then id else showString (" // " ++ comment))
-
-printStmt :: Int -> Stmt -> ShowS
-printStmt indent = \case
- SVarDecl cnst typ name rhs -> showString (typ ++ " " ++ (if cnst then "const " else "") ++ name ++ " = ") . printCExpr 0 rhs . showString ";"
- SVarDeclUninit typ name -> showString (typ ++ " " ++ name ++ ";")
- SAsg name rhs -> showString (name ++ " = ") . printCExpr 0 rhs . showString ";"
- SBlock stmts
- | null stmts -> showString "{}"
- | otherwise ->
- showString "{"
- . compose [showString ("\n" ++ replicate (2*indent+2) ' ') . printStmt (indent+1) stmt | stmt <- toList stmts]
- . showString ("\n" ++ replicate (2*indent) ' ' ++ "}")
- SIf cond b1 b2 ->
- showString "if (" . printCExpr 0 cond . showString ") "
- . printStmt indent (SBlock b1)
- . (if null b2 then id else showString " else " . printStmt indent (SBlock b2))
- SLoop typ name e1 e2 stmts ->
- showString ("for (" ++ typ ++ " " ++ name ++ " = ")
- . printCExpr 0 e1 . showString ("; " ++ name ++ " < ") . printCExpr 6 e2
- . showString ("; " ++ name ++ "++) ")
- . printStmt indent (SBlock stmts)
- SVerbatim s -> showString s
-
--- d values:
--- * 0: top level
--- * 1: in 1st or 2nd component of a ternary operator (technically same as top level, but readability)
--- * 2-...: various operators (see precTable)
--- * 80: address-of operator (&)
--- * 98: inside unknown operator
--- * 99: left of a field projection
--- Unlisted operators are conservatively written with full parentheses.
-printCExpr :: Int -> CExpr -> ShowS
-printCExpr d = \case
- CELit s -> showString s
- CEStruct name pairs ->
- showParen (d >= 99) $
- showString ("(" ++ name ++ "){")
- . compose (intersperse (showString ", ") [showString ("." ++ n ++ " = ") . printCExpr 0 e
- | (n, e) <- pairs])
- . showString "}"
- CEProj e name -> printCExpr 99 e . showString ("." ++ name)
- CEPtrProj e name -> printCExpr 99 e . showString ("->" ++ name)
- CEAddrOf e -> showParen (d > 80) $ showString "&" . printCExpr 80 e
- CEIndex e1 e2 -> printCExpr 99 e1 . showString "[" . printCExpr 0 e2 . showString "]"
- CECall n es ->
- showString (n ++ "(") . compose (intersperse (showString ", ") (map (printCExpr 0) es)) . showString ")"
- CEBinop e1 n e2 ->
- let mprec = Map.lookup n precTable
- p = maybe (-1) fst mprec -- precedence of this operator
- (d1, d2) = maybe (98, 98) snd mprec -- precedences for the arguments
- in showParen (d > p) $
- printCExpr d1 e1 . showString (" " ++ n ++ " ") . printCExpr d2 e2
- CEIf e1 e2 e3 ->
- showParen (d > 0) $
- printCExpr 1 e1 . showString " ? " . printCExpr 1 e2 . showString " : " . printCExpr 0 e3
- CECast typ e ->
- showParen (d > 98) $ showString ("(" ++ typ ++ ")") . printCExpr 98 e
- where
- precTable = Map.fromList
- [("||", (2, (2, 2)))
- ,("&&", (3, (3, 3)))
- ,("==", (4, (5, 5)))
- ,("!=", (4, (5, 5)))
- ,("<", (5, (6, 6))) -- Note: this precedence is used in the printing of SLoop
- ,(">", (5, (6, 6)))
- ,("<=", (5, (6, 6)))
- ,(">=", (5, (6, 6)))
- ,("+", (6, (6, 7)))
- ,("-", (6, (6, 7)))
- ,("*", (7, (7, 8)))
- ,("/", (7, (7, 8)))
- ,("%", (7, (7, 8)))]
-
-repSTy :: STy t -> String
-repSTy (STScal st) = case st of
- STI32 -> "int32_t"
- STI64 -> "int64_t"
- STF32 -> "float"
- STF64 -> "double"
- STBool -> "uint8_t"
-repSTy t = genStructName t
-
-genStructName, genArrBufStructName :: STy t -> String
-(genStructName, genArrBufStructName) =
- (\t -> "ty_" ++ gen t
- ,\case STArr _ t -> "ty_A_" ++ gen t ++ "_buf" -- just like the normal type, but with _ for the dimension
- t -> error $ "genArrBufStructName: not an array type: " ++ show t)
- where
- -- all tags start with a letter, so the array mangling is unambiguous.
- gen :: STy t -> String
- gen STNil = "n"
- gen (STPair a b) = 'P' : gen a ++ gen b
- gen (STEither a b) = 'E' : gen a ++ gen b
- gen (STLEither a b) = 'L' : gen a ++ gen b
- gen (STMaybe t) = 'M' : gen t
- gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t
- gen (STScal st) = case st of
- STI32 -> "i"
- STI64 -> "j"
- STF32 -> "f"
- STF64 -> "d"
- STBool -> "b"
- gen (STAccum t) = 'C' : gen (fromSMTy t)
-
--- The subtrees contain structs used in the bodies of the structs in this node.
-data StructTree = TreeNode [StructDecl] [StructTree]
- deriving (Show)
-
--- | This function generates the actual struct declarations for each of the
--- types in our language. It thus implicitly "documents" the layout of the
--- types in the C translation.
---
--- For accumulation it is important that for struct representations of monoid
--- types, the all-zero-bytes value corresponds to the zero value of that type.
-buildStructTree :: STy t -> StructTree
-buildStructTree topty = case topty of
- STNil ->
- TreeNode [StructDecl name "" com] []
- STPair a b ->
- TreeNode [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com]
- [buildStructTree a, buildStructTree b]
- STEither a b -> -- 0 -> l, 1 -> r
- TreeNode [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]
- [buildStructTree a, buildStructTree b]
- STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r
- TreeNode [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]
- [buildStructTree a, buildStructTree b]
- STMaybe t -> -- 0 -> nothing, 1 -> just
- TreeNode [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com]
- [buildStructTree t]
- STArr n t ->
- -- The buffer is trailed by a VLA for the actual array data.
- -- TODO: no buffer if n = 0
- TreeNode [StructDecl (genArrBufStructName topty) ("size_t refc; " ++ repSTy t ++ " xs[];") ""
- ,StructDecl name (genArrBufStructName topty ++ " *buf; size_t sh[" ++ show (fromSNat n) ++ "];") com]
- [buildStructTree t]
- STScal _ ->
- TreeNode [] []
- STAccum t ->
- TreeNode [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") ""
- ,StructDecl name (name ++ "_buf *buf;") com]
- [buildStructTree (fromSMTy t)]
- where
- name = genStructName topty
- com = ppSTy 0 topty
-
--- State: already-generated (skippable) struct names
--- Writer: the structs in declaration order
-genStructTreeW :: StructTree -> WriterT (Bag StructDecl) (State (Set String)) ()
-genStructTreeW (TreeNode these deps) = do
- seen <- lift get
- case filter ((`Set.notMember` seen) . nameOf) these of
- [] -> pure ()
- structs -> do
- lift $ modify (Set.fromList (map nameOf structs) <>)
- mapM_ genStructTreeW deps
- tell (BList structs)
- where
- nameOf (StructDecl name _ _) = name
-
-genAllStructs :: Foldable t => t (Some STy) -> [StructDecl]
-genAllStructs tys =
- let m = mapM_ (\(Some t) -> genStructTreeW (buildStructTree t)) tys
- in toList (evalState (execWriterT m) mempty)
-
-data CompState = CompState
- { csStructs :: Set (Some STy)
- , csTopLevelDecls :: Bag String
- , csStmts :: Bag Stmt
- , csNextId :: Int }
- deriving (Show)
-
-newtype CompM a = CompM (State CompState a)
- deriving newtype (Functor, Applicative, Monad)
-
-runCompM :: CompM a -> (a, CompState)
-runCompM (CompM m) = runState m (CompState mempty mempty mempty 1)
-
-class Monad m => MonadNameGen m where genId :: m Int
-instance MonadNameGen CompM where genId = CompM $ state $ \s -> (csNextId s, s { csNextId = csNextId s + 1 })
-instance MonadNameGen IdGen.IdGen where genId = IdGen.genId
-instance MonadNameGen m => MonadNameGen (MaybeT m) where genId = MaybeT (Just <$> genId)
-
-genName' :: MonadNameGen m => String -> m String
-genName' "" = genName
-genName' prefix = (prefix ++) . show <$> genId
-
-genName :: MonadNameGen m => m String
-genName = genName' "x"
-
-onlyIdGen :: IdGen.IdGen a -> CompM a
-onlyIdGen m = CompM $ do
- i1 <- gets csNextId
- let (res, i2) = IdGen.runIdGen' i1 m
- modify (\s -> s { csNextId = i2 })
- return res
-
-emit :: Stmt -> CompM ()
-emit stmt = CompM $ modify $ \s -> s { csStmts = csStmts s <> pure stmt }
-
-scope :: CompM a -> CompM (a, Bag Stmt)
-scope m = do
- stmts <- CompM $ state $ \s -> (csStmts s, s { csStmts = mempty })
- res <- m
- innerStmts <- CompM $ state $ \s -> (csStmts s, s { csStmts = stmts })
- return (res, innerStmts)
-
-emitStruct :: STy t -> CompM String
-emitStruct ty = CompM $ do
- modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) }
- return (genStructName ty)
-
--- | Also returns the name of the array buffer struct
-emitArrStruct :: STy t -> CompM (String, String)
-emitArrStruct ty = CompM $ do
- modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) }
- return (genStructName ty, genArrBufStructName ty)
-
-emitTLD :: String -> CompM ()
-emitTLD decl = CompM $ modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl }
-
-nameEnv :: SList f env -> SList (Const String) env
-nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1))
-
-data KernelOffsets = KernelOffsets
- { koArgOffsets :: [Int] -- ^ the function arguments
- , koOkResOffset :: Int -- ^ a byte: 1 if successful execution, 0 if (fatal) error occurred
- , koResultOffset :: Int -- ^ the function result
- }
-
-compileToString :: Int -> SList STy env -> Ex env t -> (String, KernelOffsets)
-compileToString codeID env expr =
- let args = nameEnv env
- (res, s) = runCompM (compile' args expr)
- structs = genAllStructs (csStructs s <> Set.fromList (unSList Some env))
-
- (arg_pairs, arg_metrics) =
- unzip $ reverse (unSList (\(Product.Pair t (Const n)) -> ((n, repSTy t), metricsSTy t))
- (slistZip env args))
- (arg_offsets, okres_offset) = computeStructOffsets arg_metrics
- result_offset = align (alignmentSTy (typeOf expr)) (okres_offset + 1)
-
- offsets = KernelOffsets
- { koArgOffsets = arg_offsets
- , koOkResOffset = okres_offset
- , koResultOffset = result_offset }
- in (,offsets) . ($ "") $ compose
- [showString "#include <stdio.h>\n"
- ,showString "#include <stdint.h>\n"
- ,showString "#include <stdbool.h>\n"
- ,showString "#include <inttypes.h>\n"
- ,showString "#include <stdlib.h>\n"
- ,showString "#include <string.h>\n"
- ,showString "#include <math.h>\n\n"
- -- PRint-tag
- ,showString $ "#define PRTAG \"[chad-kernel" ++ show codeID ++ "] \"\n\n"
-
- ,compose [printStructDecl sd . showString "\n" | sd <- structs]
- ,showString "\n"
-
- -- Using %zd and not %zu here because values > SIZET_MAX/2 should be recognisable as "negative"
- ,showString "static void* malloc_instr_fun(size_t n, int line) {\n"
- ,showString " void *ptr = malloc(n);\n"
- ,if debugAllocs then showString " printf(PRTAG \":%d malloc(%zd) -> %p\\n\", line, n, ptr);\n"
- else id
- ,if emitChecks then showString " if (ptr == NULL) { printf(PRTAG \"malloc(%zd) returned NULL on line %d\\n\", n, line); return false; }\n"
- else id
- ,showString " return ptr;\n"
- ,showString "}\n"
- ,showString "#define malloc_instr(n) ({void *ptr_ = malloc_instr_fun(n, __LINE__); if (ptr_ == NULL) return false; ptr_;})\n"
- ,showString "static void* calloc_instr_fun(size_t n, int line) {\n"
- ,showString " void *ptr = calloc(n, 1);\n"
- ,if debugAllocs then showString " printf(PRTAG \":%d calloc(%zd) -> %p\\n\", line, n, ptr);\n"
- else id
- ,if emitChecks then showString " if (ptr == NULL) { printf(PRTAG \"calloc(%zd, 1) returned NULL on line %d\\n\", n, line); return false; }\n"
- else id
- ,showString " return ptr;\n"
- ,showString "}\n"
- ,showString "#define calloc_instr(n) ({void *ptr_ = calloc_instr_fun(n, __LINE__); if (ptr_ == NULL) return false; ptr_;})\n"
- ,showString "static void free_instr(void *ptr) {\n"
- ,if debugAllocs then showString "printf(PRTAG \"free(%p)\\n\", ptr);\n"
- else id
- ,showString " free(ptr);\n"
- ,showString "}\n\n"
-
- ,compose [showString str . showString "\n\n" | str <- toList (csTopLevelDecls s)]
-
- ,showString $
- "static bool typed_kernel(" ++
- repSTy (typeOf expr) ++ " *output" ++
- concatMap (", " ++)
- (reverse (unSList (\(Product.Pair t (Const n)) -> repSTy t ++ " " ++ n) (slistZip env args))) ++
- ") {\n"
- ,compose [showString " " . printStmt 1 st . showString "\n" | st <- toList (csStmts s)]
- ,showString " *output = " . printCExpr 0 res . showString ";\n"
- ,showString " return true;\n"
- ,showString "}\n\n"
-
- ,showString "void kernel(void *data) {\n"
- -- Some code here assumes that we're on a 64-bit system, so let's check that
- ,showString $ " if (sizeof(void*) != 8 || sizeof(size_t) != 8) { fprintf(stderr, \"Only 64-bit systems supported\\n\"); *(uint8_t*)(data + " ++ show okres_offset ++ ") = 0; return; }\n"
- ,if debugRefc then showString " fprintf(stderr, PRTAG \"Start\\n\");\n"
- else id
- ,showString $ " const bool success = typed_kernel(" ++
- "\n (" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ")" ++
- concat (map (\((arg, typ), off) ->
- ",\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")"
- ++ " /* " ++ arg ++ " */")
- (zip arg_pairs arg_offsets)) ++
- "\n );\n"
- ,showString $ " *(uint8_t*)(data + " ++ show okres_offset ++ ") = success;\n"
- ,if debugRefc then showString " fprintf(stderr, PRTAG \"Return\\n\");\n"
- else id
- ,showString "}\n"]
-
--- | Takes list of metrics (alignment, sizeof).
--- Returns (offsets, size of struct).
-computeStructOffsets :: [(Int, Int)] -> ([Int], Int)
-computeStructOffsets = go 0 0
- where
- go off maxal [(al, sz)] =
- ([off], align (max maxal al) (off + sz))
- go off maxal ((al, sz) : pairs@((al2,_):_)) =
- first (off :) $ go (align al2 (off + sz)) (max maxal al) pairs
- go _ _ [] = ([], 0)
-
--- | Assumes that this is called at the correct alignment.
-serialise :: STy t -> Rep t -> Ptr () -> Int -> IO r -> IO r
-serialise topty topval ptr off k =
- -- TODO: this code is quadratic in the depth of the type because of the alignment/sizeOf calls
- case (topty, topval) of
- (STNil, ()) -> k
- (STPair a b, (x, y)) ->
- serialise a x ptr off $
- serialise b y ptr (align (alignmentSTy b) (off + sizeofSTy a)) k
- (STEither a _, Left x) -> do
- pokeByteOff ptr off (0 :: Word8) -- alignment of (union {a b}) is the same as alignment of (a + b)
- serialise a x ptr (off + alignmentSTy topty) k
- (STEither _ b, Right y) -> do
- pokeByteOff ptr off (1 :: Word8)
- serialise b y ptr (off + alignmentSTy topty) k
- (STLEither _ _, Nothing) -> do
- pokeByteOff ptr off (0 :: Word8)
- k
- (STLEither a _, Just (Left x)) -> do
- pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b)
- serialise a x ptr (off + alignmentSTy topty) k
- (STLEither _ b, Just (Right y)) -> do
- pokeByteOff ptr off (2 :: Word8)
- serialise b y ptr (off + alignmentSTy topty) k
- (STMaybe _, Nothing) -> do
- pokeByteOff ptr off (0 :: Word8)
- k
- (STMaybe t, Just x) -> do
- pokeByteOff ptr off (1 :: Word8)
- serialise t x ptr (off + alignmentSTy t) k
- (STArr n t, Array sh vec) -> do
- let eltsz = sizeofSTy t
- allocaBytes (8 + shapeSize sh * eltsz) $ \bufptr -> do
- when debugRefc $
- hPutStrLn stderr $ "[chad-serialise] Allocating input buffer " ++ showPtr bufptr
- pokeByteOff ptr off bufptr
- pokeShape ptr (off + 8) n sh
-
- pokeByteOff @Word64 bufptr 0 (2 ^ 63)
-
- let loop i
- | i == shapeSize sh = k
- | otherwise =
- serialise t (vec V.! i) bufptr (8 + i * eltsz) $
- loop (i+1)
- loop 0
- (STScal sty, x) -> case sty of
- STI32 -> pokeByteOff ptr off (x :: Int32) >> k
- STI64 -> pokeByteOff ptr off (x :: Int64) >> k
- STF32 -> pokeByteOff ptr off (x :: Float) >> k
- STF64 -> pokeByteOff ptr off (x :: Double) >> k
- STBool -> pokeByteOff ptr off (fromIntegral (fromEnum x) :: Word8) >> k
- (STAccum{}, _) -> error "Cannot serialise accumulators"
-
--- | Assumes that this is called at the correct alignment.
-deserialise :: STy t -> Ptr () -> Int -> IO (Rep t)
-deserialise topty ptr off =
- -- TODO: this code is quadratic in the depth of the type because of the alignment/sizeOf calls
- case topty of
- STNil -> return ()
- STPair a b -> do
- x <- deserialise a ptr off
- y <- deserialise b ptr (align (alignmentSTy b) (off + sizeofSTy a))
- return (x, y)
- STEither a b -> do
- tag <- peekByteOff @Word8 ptr off
- if tag == 0 -- alignment of (union {a b}) is the same as alignment of (a + b)
- then Left <$> deserialise a ptr (off + alignmentSTy topty)
- else Right <$> deserialise b ptr (off + alignmentSTy topty)
- STLEither a b -> do
- tag <- peekByteOff @Word8 ptr off
- case tag of -- alignment of (union {a b}) is the same as alignment of (a + b)
- 0 -> return Nothing
- 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty)
- 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty)
- _ -> error "Invalid tag value"
- STMaybe t -> do
- tag <- peekByteOff @Word8 ptr off
- if tag == 0
- then return Nothing
- else Just <$> deserialise t ptr (off + alignmentSTy t)
- STArr n t -> do
- bufptr <- peekByteOff @(Ptr ()) ptr off
- sh <- peekShape ptr (off + 8) n
- refc <- peekByteOff @Word64 bufptr 0
- when debugRefc $
- hPutStrLn stderr $ "[chad-deserialise] Got buffer " ++ showPtr bufptr ++ " at refc=" ++ show refc
- let eltsz = sizeofSTy t
- arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (8 + i * eltsz))
- when (refc < 2 ^ 62) $ free bufptr
- return arr
- STScal sty -> case sty of
- STI32 -> peekByteOff @Int32 ptr off
- STI64 -> peekByteOff @Int64 ptr off
- STF32 -> peekByteOff @Float ptr off
- STF64 -> peekByteOff @Double ptr off
- STBool -> toEnum . fromIntegral <$> peekByteOff @Word8 ptr off
- STAccum{} -> error "Cannot serialise accumulators"
-
-align :: Int -> Int -> Int
-align a off = (off + a - 1) `div` a * a
-
-alignmentSTy :: STy t -> Int
-alignmentSTy = fst . metricsSTy
-
-sizeofSTy :: STy t -> Int
-sizeofSTy = snd . metricsSTy
-
--- | Returns (alignment, sizeof)
-metricsSTy :: STy t -> (Int, Int)
-metricsSTy STNil = (1, 0)
-metricsSTy (STPair a b) =
- let (a1, s1) = metricsSTy a
- (a2, s2) = metricsSTy b
- in (max a1 a2, align (max a1 a2) (s1 + s2))
-metricsSTy (STEither a b) =
- let (a1, s1) = metricsSTy a
- (a2, s2) = metricsSTy b
- in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned
-metricsSTy (STLEither a b) =
- let (a1, s1) = metricsSTy a
- (a2, s2) = metricsSTy b
- in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned
-metricsSTy (STMaybe t) =
- let (a, s) = metricsSTy t
- in (a, a + s) -- the union after the tag byte is aligned
-metricsSTy (STArr n _) = (8, 8 + 8 * fromSNat n)
-metricsSTy (STScal sty) = case sty of
- STI32 -> (4, 4)
- STI64 -> (8, 8)
- STF32 -> (4, 4)
- STF64 -> (8, 8)
- STBool -> (1, 1) -- compiled to uint8_t
-metricsSTy (STAccum t) = metricsSTy (fromSMTy t)
-
-pokeShape :: Ptr () -> Int -> SNat n -> Shape n -> IO ()
-pokeShape ptr off = go . fromSNat
- where
- go :: Int -> Shape n -> IO ()
- go rank = \case
- ShNil -> return ()
- sh `ShCons` n -> do
- pokeByteOff ptr (off + (rank - 1) * 8) (fromIntegral n :: Int64)
- go (rank - 1) sh
-
-peekShape :: Ptr () -> Int -> SNat n -> IO (Shape n)
-peekShape ptr off = \case
- SZ -> return ShNil
- SS n -> ShCons <$> peekShape ptr off n
- <*> (fromIntegral <$> peekByteOff @Int64 ptr (off + (fromSNat n) * 8))
-
-compile' :: SList (Const String) env -> Ex env t -> CompM CExpr
-compile' env = \case
- EVar _ t i -> do
- let Const var = slistIdx env i
- incrementVarAlways "var" Increment t var
- return $ CELit var
-
- ELet _ rhs body -> do
- var <- compileAssign "" env rhs
- rete <- compile' (Const var `SCons` env) body
- incrementVarAlways "let" Decrement (typeOf rhs) var
- return rete
-
- EPair _ a b -> do
- name <- emitStruct (STPair (typeOf a) (typeOf b))
- e1 <- compile' env a
- e2 <- compile' env b
- return $ CEStruct name [("a", e1), ("b", e2)]
-
- EFst _ e -> do
- let STPair _ t2 = typeOf e
- e' <- compile' env e
- case incrementVar "fst" Decrement t2 of
- Nothing -> return $ CEProj e' "a"
- Just f -> do var <- genName
- emit $ SVarDecl True (repSTy (typeOf e)) var e'
- f (var ++ ".b")
- return $ CEProj (CELit var) "a"
-
- ESnd _ e -> do
- let STPair t1 _ = typeOf e
- e' <- compile' env e
- case incrementVar "snd" Decrement t1 of
- Nothing -> return $ CEProj e' "b"
- Just f -> do var <- genName
- emit $ SVarDecl True (repSTy (typeOf e)) var e'
- f (var ++ ".a")
- return $ CEProj (CELit var) "b"
-
- ENil _ -> do
- name <- emitStruct STNil
- return $ CEStruct name []
-
- EInl _ t e -> do
- name <- emitStruct (STEither (typeOf e) t)
- e1 <- compile' env e
- return $ CEStruct name [("tag", CELit "0"), ("l", e1)]
-
- EInr _ t e -> do
- name <- emitStruct (STEither t (typeOf e))
- e2 <- compile' env e
- return $ CEStruct name [("tag", CELit "1"), ("r", e2)]
-
- ECase _ (EOp _ OIf e) a b -> do
- e1 <- compile' env e
- (e2, stmts2) <- scope $ compile' (Const undefined `SCons` env) a -- don't access that nil, stupid you
- (e3, stmts3) <- scope $ compile' (Const undefined `SCons` env) b
- retvar <- genName
- emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
- emit $ SIf e1
- (stmts2 <> pure (SAsg retvar e2))
- (stmts3 <> pure (SAsg retvar e3))
- return (CELit retvar)
-
- ECase _ e a b -> do
- let STEither t1 t2 = typeOf e
- e1 <- compile' env e
- var <- genName
- -- I know those are not variable names, but it's fine, probably
- (e2, stmts2) <- scope $ compile' (Const (var ++ ".l") `SCons` env) a
- (e3, stmts3) <- scope $ compile' (Const (var ++ ".r") `SCons` env) b
- ((), stmtsRel1) <- scope $ incrementVarAlways "case1" Decrement t1 (var ++ ".l")
- ((), stmtsRel2) <- scope $ incrementVarAlways "case2" Decrement t2 (var ++ ".r")
- retvar <- genName
- emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
- emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1)
- <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0"))
- (stmts2
- <> stmtsRel1
- <> pure (SAsg retvar e2))
- (stmts3
- <> stmtsRel2
- <> pure (SAsg retvar e3))))
- return (CELit retvar)
-
- ENothing _ t -> do
- name <- emitStruct (STMaybe t)
- return $ CEStruct name [("tag", CELit "0")]
-
- EJust _ e -> do
- name <- emitStruct (STMaybe (typeOf e))
- e1 <- compile' env e
- return $ CEStruct name [("tag", CELit "1"), ("j", e1)]
-
- EMaybe _ a b e -> do
- let STMaybe t = typeOf e
- e1 <- compile' env e
- var <- genName
- (e2, stmts2) <- scope $ compile' env a
- (e3, stmts3) <- scope $ compile' (Const (var ++ ".j") `SCons` env) b
- ((), stmtsRel) <- scope $ incrementVarAlways "maybe" Decrement t (var ++ ".j")
- retvar <- genName
- emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
- emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1)
- <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0"))
- (stmts2
- <> pure (SAsg retvar e2))
- (stmts3
- <> stmtsRel
- <> pure (SAsg retvar e3))))
- return (CELit retvar)
-
- ELNil _ t1 t2 -> do
- name <- emitStruct (STLEither t1 t2)
- return $ CEStruct name [("tag", CELit "0")]
-
- ELInl _ t e -> do
- name <- emitStruct (STLEither (typeOf e) t)
- e1 <- compile' env e
- return $ CEStruct name [("tag", CELit "1"), ("l", e1)]
-
- ELInr _ t e -> do
- name <- emitStruct (STLEither t (typeOf e))
- e1 <- compile' env e
- return $ CEStruct name [("tag", CELit "2"), ("r", e1)]
-
- ELCase _ e a b c -> do
- let STLEither t1 t2 = typeOf e
- e1 <- compile' env e
- var <- genName
- (e2, stmts2) <- scope $ compile' env a
- (e3, stmts3) <- scope $ compile' (Const (var ++ ".l") `SCons` env) b
- (e4, stmts4) <- scope $ compile' (Const (var ++ ".r") `SCons` env) c
- ((), stmtsRel1) <- scope $ incrementVarAlways "lcase1" Decrement t1 (var ++ ".l")
- ((), stmtsRel2) <- scope $ incrementVarAlways "lcase2" Decrement t2 (var ++ ".r")
- retvar <- genName
- emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
- emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1)
- <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0"))
- (stmts2 <> pure (SAsg retvar e2))
- (pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "1"))
- (stmts3 <> stmtsRel1 <> pure (SAsg retvar e3))
- (stmts4 <> stmtsRel2 <> pure (SAsg retvar e4))))))
- return (CELit retvar)
-
- EConstArr _ n t (Array sh vec) -> do
- (strname, bufstrname) <- emitArrStruct (STArr n (STScal t))
- tldname <- genName' "carraybuf"
- -- Give it a refcount of _half_ the size_t max, so that it can be
- -- incremented and decremented at will and will "never" reach anything
- -- where something happens
- emitTLD $ "static " ++ bufstrname ++ " " ++ tldname ++ " = " ++
- "(" ++ bufstrname ++ "){.refc = (size_t)1<<63, " ++
- ".xs = {" ++ intercalate "," (map (compileScal False t) (toList vec)) ++ "}};"
- return (CEStruct strname
- [("buf", CEAddrOf (CELit tldname))
- ,("sh", CELit ("{" ++ intercalate "," (map show (shapeToList sh)) ++ "}"))])
-
- EBuild _ n esh efun -> do
- shname <- compileAssign "sh" env esh
-
- arrname <- allocArray "build" Malloc "arr" n (typeOf efun) Nothing (indexTupleComponents n shname)
-
- idxargname <- genName' "ix"
- (funretval, funstmts) <- scope $ compile' (Const idxargname `SCons` env) efun
-
- linivar <- genName' "li"
- ivars <- replicateM (fromSNat n) (genName' "i")
- emit $ SBlock $
- pure (SVarDecl False "size_t" linivar (CELit "0"))
- <> compose [pure . SLoop (repSTy tIx) ivar (CELit "0")
- (CECast (repSTy tIx) (CEIndex (CELit (arrname ++ ".sh")) (CELit (show dimidx))))
- | (ivar, dimidx) <- zip ivars [0::Int ..]]
- (pure (SVarDecl True (repSTy (typeOf esh)) idxargname
- (shapeTupFromLitVars n ivars))
- <> funstmts
- <> pure (SAsg (arrname ++ ".buf->xs[" ++ linivar ++ "++]") funretval))
-
- return (CELit arrname)
-
- -- TODO: actually generate decent code here
- EMap _ e1 e2 -> do
- let STArr n _ = typeOf e2
- compile' env $
- elet e2 $
- EBuild ext n (EShape ext (evar IZ)) $
- elet (EIdx ext (evar (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ)) $
- weakenExpr (WCopy (WSink .> WSink)) e1
-
- EFold1Inner _ commut efun ex0 earr -> do
- let STArr (SS n) t = typeOf earr
-
- -- let vecwid = case commut of Commut -> 8 :: Int
- -- Noncommut -> 1
-
- x0name <- compileAssign "foldx0" env ex0
- arrname <- compileAssign "foldarr" env earr
-
- zeroRefcountCheck (typeOf earr) "fold1i" arrname
-
- shszname <- genName' "shsz"
- -- This n is one less than the shape of the thing we're querying, which is
- -- unexpected. But it's exactly what we want, so we do it anyway.
- emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n arrname)
-
- resname <- allocArray "fold" Malloc "foldres" n t (Just (CELit shszname)) (compileArrShapeComponents n arrname)
-
- lenname <- genName' "n"
- emit $ SVarDecl True (repSTy tIx) lenname
- (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
-
- ((), x0incrStmts) <- scope $ incrementVarAlways "foldx0" Increment t x0name
-
- ivar <- genName' "i"
- jvar <- genName' "j"
- -- kvar <- if vecwid > 1 then genName' "k" else return ""
-
- accvar <- genName' "tot"
- pairvar <- genName' "pair" -- function input
- (funres, funStmts) <- scope $ compile' (Const pairvar `SCons` env) efun
-
- let arreltlit = arrname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++
- ({- if vecwid > 1 then show vecwid ++ " * " ++ jvar ++ " + " ++ kvar else -} jvar) ++ "]"
- ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldelt" Increment t arreltlit
-
- pairstrname <- emitStruct (STPair t t)
- emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $
- pure (SVarDecl False (repSTy t) accvar (CELit x0name))
- <> x0incrStmts -- we're copying x0 here
- <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
- -- The combination function will consume the array element
- -- and the accumulator. The accumulator is replaced by
- -- what comes out of the function anyway, so that's
- -- fine, but we do need to increment the array element.
- arreltIncrStmts
- <> pure (SVarDecl True pairstrname pairvar (CEStruct pairstrname [("a", CELit accvar), ("b", CELit arreltlit)]))
- <> funStmts
- <> pure (SAsg accvar funres))
- <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar))
-
- incrementVarAlways "foldx0" Decrement t x0name
- incrementVarAlways "foldarr" Decrement (typeOf earr) arrname
-
- return (CELit resname)
-
- ESum1Inner _ e -> do
- let STArr (SS n) t = typeOf e
- argname <- compileAssign "sumarg" env e
-
- zeroRefcountCheck (typeOf e) "sum1i" argname
-
- shszname <- genName' "shsz"
- -- This n is one less than the shape of the thing we're querying, like EFold1Inner.
- emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
-
- resname <- allocArray "sum" Malloc "sumres" n t (Just (CELit shszname)) (compileArrShapeComponents n argname)
-
- lenname <- genName' "n"
- emit $ SVarDecl True (repSTy tIx) lenname
- (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
-
- let vecwid = 8 :: Int
- ivar <- genName' "i"
- jvar <- genName' "j"
- kvar <- genName' "k"
- accvar <- genName' "tot"
- let nchunks = CEBinop (CELit lenname) "/" (CELit (show vecwid))
- emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ BList
- -- we have ScalIsNumeric, so it has 0 and (+) in C
- [SVerbatim $ repSTy t ++ " " ++ accvar ++ "[" ++ show vecwid ++ "] = {" ++ intercalate "," (replicate vecwid "0") ++ "};"
- ,SLoop (repSTy tIx) jvar (CELit "0") nchunks $
- pure $ SLoop (repSTy tIx) kvar (CELit "0") (CELit (show vecwid)) $
- pure $ SVerbatim $ accvar ++ "[" ++ kvar ++ "] += " ++ argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ show vecwid ++ " * " ++ jvar ++ " + " ++ kvar ++ "];"
- ,SLoop (repSTy tIx) kvar (CELit "1") (CELit (show vecwid)) $
- pure $ SVerbatim $ accvar ++ "[0] += " ++ accvar ++ "[" ++ kvar ++ "];"
- ,SLoop (repSTy tIx) kvar (CEBinop nchunks "*" (CELit (show vecwid))) (CELit lenname) $
- pure $ SVerbatim $ accvar ++ "[0] += " ++ argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ kvar ++ "];"
- ,SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit (accvar++"[0]"))]
-
- incrementVarAlways "sum" Decrement (typeOf e) argname
-
- return (CELit resname)
-
- EUnit _ e -> do
- e' <- compile' env e
- let typ = STArr SZ (typeOf e)
- strname <- emitStruct typ
- name <- genName
- emit $ SVarDecl True strname name (CEStruct strname
- [("buf", CECall "malloc_instr" [CELit (show (8 + sizeofSTy (typeOf e)))])])
- emit $ SAsg (name ++ ".buf->refc") (CELit "1")
- emit $ SAsg (name ++ ".buf->xs[0]") e'
- return (CELit name)
-
- EReplicate1Inner _ elen earg -> do
- let STArr n t = typeOf earg
- lenname <- compileAssign "replen" env elen
- argname <- compileAssign "reparg" env earg
-
- zeroRefcountCheck (typeOf earg) "replicate1i" argname
-
- shszname <- genName' "shsz"
- emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
-
- resname <- allocArray "repl1i" Malloc "rep" (SS n) t
- (Just (CEBinop (CELit shszname) "*" (CELit lenname)))
- (compileArrShapeComponents n argname ++ [CELit lenname])
-
- ivar <- genName' "i"
- jvar <- genName' "j"
- emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $
- pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
- pure $ SAsg (resname ++ ".buf->xs[" ++ ivar ++ " * " ++ lenname ++ " + " ++ jvar ++ "]")
- (CELit (argname ++ ".buf->xs[" ++ ivar ++ "]"))
-
- incrementVarAlways "repl1i" Decrement (typeOf earg) argname
-
- return (CELit resname)
-
- EMaximum1Inner _ e -> compileExtremum "max" "maximum1i" ">" env e
-
- EMinimum1Inner _ e -> compileExtremum "min" "minimum1i" "<" env e
-
- EReshape _ dim esh earg -> do
- let STArr origDim eltty = typeOf earg
- strname <- emitStruct (STArr dim eltty)
-
- shname <- compileAssign "reshsh" env esh
- arrname <- compileAssign "resharg" env earg
-
- when emitChecks $ do
- emit $ SIf (CEBinop (compileArrShapeSize origDim arrname) "!=" (CECast "size_t" (prodExpr (indexTupleComponents dim shname))))
- (pure $ SVerbatim $
- "fprintf(stderr, PRTAG \"CHECK: reshape on unequal sizes (%zu <- %zu)\\n\", " ++
- printCExpr 0 (prodExpr (indexTupleComponents dim shname)) ", " ++
- printCExpr 0 (compileArrShapeSize origDim arrname) "); return false;")
- mempty
-
- return (CEStruct strname
- [("buf", CEProj (CELit arrname) "buf")
- ,("sh", CELit ("{" ++ intercalate ", " [printCExpr 0 e "" | e <- indexTupleComponents dim shname] ++ "}"))])
-
- -- TODO: actually generate decent code here
- EZip _ e1 e2 -> do
- let STArr n _ = typeOf e1
- compile' env $
- elet e1 $
- elet (weakenExpr WSink e2) $
- EBuild ext n (EShape ext (evar (IS IZ))) $
- EPair ext (EIdx ext (evar (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) IZ))
- (EIdx ext (evar (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ))
-
- EFold1InnerD1 _ commut efun ex0 earr -> do
- let STArr (SS n) t = typeOf earr
- STPair _ bty = typeOf efun
-
- x0name <- compileAssign "foldd1x0" env ex0
- arrname <- compileAssign "foldd1arr" env earr
-
- zeroRefcountCheck (typeOf earr) "fold1iD1" arrname
-
- lenname <- genName' "n"
- emit $ SVarDecl True (repSTy tIx) lenname
- (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
-
- shsz1name <- genName' "shszN"
- emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n arrname) -- take init of arr's shape
- shsz2name <- genName' "shszSN"
- emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname))
-
- resname <- allocArray "foldd1" Malloc "foldd1res" n t (Just (CELit shsz1name)) (compileArrShapeComponents n arrname)
- storesname <- allocArray "foldd1" Malloc "foldd1stores" (SS n) bty (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) arrname)
-
- ((), x0incrStmts) <- scope $ incrementVarAlways "foldd1x0" Increment t x0name
-
- ivar <- genName' "i"
- jvar <- genName' "j"
-
- accvar <- genName' "tot"
- pairvar <- genName' "pair" -- function input
- (funres, funStmts) <- scope $ compile' (Const pairvar `SCons` env) efun
- let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ jvar
- arreltlit = arrname ++ ".buf->xs[" ++ eltidx ++ "]"
- funresvar <- genName' "res"
- ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldd1elt" Increment t arreltlit
-
- pairstrname <- emitStruct (STPair t t)
- emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $
- pure (SVarDecl False (repSTy t) accvar (CELit x0name))
- <> x0incrStmts -- we're copying x0 here
- <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
- -- The combination function will consume the array element
- -- and the accumulator. The accumulator is replaced by
- -- what comes out of the function anyway, so that's
- -- fine, but we do need to increment the array element.
- arreltIncrStmts
- <> pure (SVarDecl True pairstrname pairvar (CEStruct pairstrname [("a", CELit accvar), ("b", CELit arreltlit)]))
- <> funStmts
- <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres)
- <> pure (SAsg accvar (CEProj (CELit funresvar) "a"))
- <> pure (SAsg (storesname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b")))
- <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar))
-
- incrementVarAlways "foldd1x0" Decrement t x0name
- incrementVarAlways "foldd1arr" Decrement (typeOf earr) arrname
-
- strname <- emitStruct (STPair (STArr n t) (STArr (SS n) bty))
- return (CEStruct strname [("a", CELit resname), ("b", CELit storesname)])
-
- EFold1InnerD2 _ commut efun estores ectg -> do
- let STArr n t2 = typeOf ectg
- STArr _ bty = typeOf estores
-
- storesname <- compileAssign "foldd2stores" env estores
- ctgname <- compileAssign "foldd2ctg" env ectg
-
- zeroRefcountCheck (typeOf ectg) "fold1iD2" ctgname
-
- lenname <- genName' "n"
- emit $ SVarDecl True (repSTy tIx) lenname
- (CELit (storesname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
-
- shsz1name <- genName' "shszN"
- emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n storesname) -- take init of the shape
- shsz2name <- genName' "shszSN"
- emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname))
-
- x0ctgname <- allocArray "foldd2" Malloc "foldd2x0ctg" n t2 (Just (CELit shsz1name)) (compileArrShapeComponents n storesname)
- outctgname <- allocArray "foldd2" Malloc "foldd2outctg" (SS n) t2 (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) storesname)
-
- ivar <- genName' "i"
- jvar <- genName' "j"
-
- accvar <- genName' "acc"
- let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ lenname ++ "-1 - " ++ jvar
- storeseltlit = storesname ++ ".buf->xs[" ++ eltidx ++ "]"
- ctgeltlit = ctgname ++ ".buf->xs[" ++ ivar ++ "]"
- (funres, funStmts) <- scope $ compile' (Const accvar `SCons` Const storeseltlit `SCons` env) efun
- funresvar <- genName' "res"
- ((), storeseltIncrStmts) <- scope $ incrementVarAlways "foldd2selt" Increment bty storeseltlit
- ((), ctgeltIncrStmts) <- scope $ incrementVarAlways "foldd2celt" Increment bty ctgeltlit
-
- emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $
- pure (SVarDecl False (repSTy t2) accvar (CELit ctgeltlit))
- <> ctgeltIncrStmts
- -- we need to loop in reverse here, but we let jvar run in the
- -- forward direction so that we can use SLoop. Note jvar is
- -- reversed in eltidx above
- <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
- -- The combination function will consume the accumulator
- -- and the stores element. The accumulator is replaced by
- -- what comes out of the function anyway, so that's
- -- fine, but we do need to increment the stores element.
- storeseltIncrStmts
- <> funStmts
- <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres)
- <> pure (SAsg accvar (CEProj (CELit funresvar) "a"))
- <> pure (SAsg (outctgname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b")))
- <> pure (SAsg (x0ctgname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar))
-
- incrementVarAlways "foldd2stores" Decrement (STArr (SS n) bty) storesname
- incrementVarAlways "foldd2ctg" Decrement (STArr n t2) ctgname
-
- strname <- emitStruct (STPair (STArr n t2) (STArr (SS n) t2))
- return (CEStruct strname [("a", CELit x0ctgname), ("b", CELit outctgname)])
-
- EConst _ t x -> return $ CELit $ compileScal True t x
-
- EIdx0 _ e -> do
- let STArr _ t = typeOf e
- arrname <- compileAssign "" env e
- zeroRefcountCheck (typeOf e) "idx0" arrname
- name <- genName
- emit $ SVarDecl True (repSTy t) name
- (CEIndex (CEPtrProj (CEProj (CELit arrname) "buf") "xs") (CELit "0"))
- incrementVarAlways "idx0" Decrement (STArr SZ t) arrname
- return (CELit name)
-
- -- EIdx1 _ a b -> error "TODO" -- EIdx1 ext (compile' a) (compile' b)
-
- EIdx _ earr eidx -> do
- let STArr n t = typeOf earr
- arrname <- compileAssign "ixarr" env earr
- zeroRefcountCheck (typeOf earr) "idx" arrname
- idxname <- if fromSNat n > 0 -- prevent an unused-varable warning
- then compileAssign "ixix" env eidx
- else return "" -- won't be used in this case
-
- when emitChecks $
- forM_ (zip [0::Int ..] (indexTupleComponents n idxname)) $ \(i, ixcomp) ->
- emit $ SIf (CEBinop (CEBinop ixcomp "<" (CELit "0")) "||"
- (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".sh[" ++ show i ++ "]")))))
- (pure $ SVerbatim $
- "fprintf(stderr, PRTAG \"CHECK: index out of range (arr=%p)\\n\", " ++
- arrname ++ ".buf); return false;")
- mempty
-
- resname <- genName' "ixres"
- emit $ SVarDecl True (repSTy t) resname (CEIndex (CELit (arrname ++ ".buf->xs")) (toLinearIdx n arrname idxname))
- incrementVarAlways "idxelt" Increment t resname
- incrementVarAlways "idx" Decrement (STArr n t) arrname
- return (CELit resname)
-
- EShape _ e -> do
- let STArr n _ = typeOf e
- t = tTup (sreplicate n tIx)
- _ <- emitStruct t
- name <- compileAssign "" env e
- zeroRefcountCheck (typeOf e) "shape" name
- resname <- genName
- emit $ SVarDecl True (repSTy t) resname (compileShapeQuery n name)
- incrementVarAlways "shape" Decrement (typeOf e) name
- return (CELit resname)
-
- EOp _ op (EPair _ e1 e2) -> do
- e1' <- compile' env e1
- e2' <- compile' env e2
- compileOpPair op e1' e2'
-
- EOp _ op e -> do
- e' <- compile' env e
- compileOpGeneral op e'
-
- ECustom _ _ _ _ earg _ _ e1 e2 -> do
- name1 <- compileAssign "" env e1
- name2 <- compileAssign "" env e2
- case (incrementVar "custom1" Decrement (typeOf e1), incrementVar "custom2" Decrement (typeOf e2)) of
- (Nothing, Nothing) -> compile' (Const name2 `SCons` Const name1 `SCons` SNil) earg
- (mfun1, mfun2) -> do
- name <- compileAssign "" (Const name2 `SCons` Const name1 `SCons` SNil) earg
- maybe (return ()) ($ name1) mfun1
- maybe (return ()) ($ name2) mfun2
- return (CELit name)
-
- ERecompute _ e -> compile' env e
-
- EWith _ t e1 e2 -> do
- actyname <- emitStruct (STAccum t)
- name1 <- compileAssign "" env e1
-
- zeroRefcountCheck (typeOf e1) "with" name1
-
- emit $ SVerbatim $ "// copyForWriting start (" ++ name1 ++ ")"
- mcopy <- copyForWriting t name1
- accname <- genName' "accum"
- emit $ SVarDecl False actyname accname
- (CEStruct actyname [("buf", CECall "malloc_instr" [CELit (show (sizeofSTy (fromSMTy t)))])])
- emit $ SAsg (accname++".buf->ac") (maybe (CELit name1) id mcopy)
- emit $ SVerbatim $ "// initial accumulator constructed (" ++ name1 ++ ")."
-
- e2' <- compile' (Const accname `SCons` env) e2
-
- resname <- genName' "acret"
- emit $ SVarDecl True (repSTy (fromSMTy t)) resname (CELit (accname++".buf->ac"))
- emit $ SVerbatim $ "free_instr(" ++ accname ++ ".buf);"
-
- rettyname <- emitStruct (STPair (typeOf e2) (fromSMTy t))
- return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)]
-
- EAccum _ t prj eidx sparsity eval eacc | Just Refl <- isDense (acPrjTy prj t) sparsity -> do
- let -- Add a value (s) into an existing accumulation value (d). If a sparse
- -- component of d is encountered, s is copied there.
- add :: SMTy a -> String -> String -> CompM ()
- add SMTNil _ _ = return ()
- add (SMTPair t1 t2) d s = do
- add t1 (d++".a") (s++".a")
- add t2 (d++".b") (s++".b")
- add (SMTLEither t1 t2) d s = do
- ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTLEither t1 t2)) s
- ((), stmts1) <- scope $ add t1 (d++".l") (s++".l")
- ((), stmts2) <- scope $ add t2 (d++".r") (s++".r")
- emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0"))
- (pure (SAsg d (CELit s))
- <> srcIncrStmts)
- ((if emitChecks
- then pure (SIf (CEBinop (CEBinop (CELit (s++".tag")) "!=" (CELit "0"))
- "&&"
- (CEBinop (CELit (s++".tag")) "!=" (CELit (d++".tag"))))
- (pure $ SVerbatim $
- "fprintf(stderr, PRTAG \"CHECK: accum add leither with different tags " ++
- "(dest %d, src %d)\\n\", (int)" ++ d ++ ".tag, (int)" ++ s ++ ".tag); " ++
- "return false;")
- mempty)
- else mempty)
- -- note: s may have tag 0
- <> pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1"))
- stmts1
- (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "2"))
- stmts2 mempty))))
- add (SMTMaybe t1) d s = do
- ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTMaybe t1)) s
- ((), stmts1) <- scope $ add t1 (d++".j") (s++".j")
- emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0"))
- (pure (SAsg d (CELit s))
- <> srcIncrStmts)
- (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) stmts1 mempty))
- add (SMTArr n t1) d s = do
- when emitChecks $ do
- let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
- forM_ [0 .. fromSNat n - 1] $ \j -> do
- emit $ SIf (CEBinop (CELit (s ++ ".sh[" ++ show j ++ "]"))
- "!="
- (CELit (d ++ ".sh[" ++ show j ++ "]")))
- (pure $ SVerbatim $
- "fprintf(stderr, PRTAG \"CHECK: accum add incorrect (d=%p, " ++
- "dsh=" ++ shfmt ++ ", s=%p, ssh=" ++ shfmt ++ ")\\n\", " ++
- d ++ ".buf" ++
- concat [", " ++ d ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
- ", " ++ s ++ ".buf" ++
- concat [", " ++ s ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
- "); " ++
- "return false;")
- mempty
-
- shsizename <- genName' "acshsz"
- emit $ SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n s)
- ivar <- genName' "i"
- ((), stmts1) <- scope $ add t1 (d++".buf->xs["++ivar++"]") (s++".buf->xs["++ivar++"]")
- emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename)
- stmts1
- add (SMTScal _) d s = emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
-
- let -- | Dereference an accumulation value and add a given value to that
- -- position. Sparse components encountered along the way are
- -- initialised before proceeding downwards.
- -- accumRef (type) (projection) (accumulation component) (AcIdx variable) (value to accumulate there)
- accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> String -> CompM ()
- accumRef _ SAPHere v _ addend = add (acPrjTy prj t) v addend
-
- accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") i addend
- accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") i addend
-
- accumRef (SMTLEither ta _) (SAPLeft prj') v i addend = do
- when emitChecks $ do
- emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1"))
- (pure $ SVerbatim $
- "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +left)\\n\", " ++ v ++ ".tag); " ++
- "return false;")
- mempty
- accumRef ta prj' (v++".l") i addend
- accumRef (SMTLEither _ tb) (SAPRight prj') v i addend = do
- when emitChecks $ do
- emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "2"))
- (pure $ SVerbatim $
- "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +right)\\n\", " ++ v ++ ".tag); " ++
- "return false;")
- mempty
- accumRef tb prj' (v++".r") i addend
-
- accumRef (SMTMaybe tj) (SAPJust prj') v i addend = do
- when emitChecks $ do
- emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1"))
- (pure $ SVerbatim $
- "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (maybe tag=%d, +just)\\n\", " ++ v ++ ".tag); " ++
- "return false;")
- mempty
- accumRef tj prj' (v++".j") i addend
-
- accumRef (SMTArr n t') (SAPArrIdx prj') v i addend = do
- when emitChecks $ do
- let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
- forM_ (zip [0::Int ..]
- (indexTupleComponents n (i++".a"))) $ \(j, ixcomp) -> do
- let a .||. b = CEBinop a "||" b
- emit $ SIf (CEBinop ixcomp "<" (CELit "0")
- .||.
- CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".sh[" ++ show j ++ "]"))))
- (pure $ SVerbatim $
- "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++
- "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=(D))\\n\", " ++
- v ++ ".buf" ++
- concat [", " ++ v ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
- concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a")] ++
- "); " ++
- "return false;")
- mempty
-
- accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a")) "]") (i++".b") addend
-
- nameidx <- compileAssign "acidx" env eidx
- nameval <- compileAssign "acval" env eval
- nameacc <- compileAssign "acac" env eacc
-
- emit $ SVerbatim $ "// compile EAccum start (" ++ show prj ++ ")"
- accumRef t prj (nameacc++".buf->ac") nameidx nameval
- emit $ SVerbatim $ "// compile EAccum end"
-
- incrementVarAlways "accumendsrc" Decrement (typeOf eval) nameval
-
- return $ CEStruct (repSTy STNil) []
-
- EAccum{} ->
- error "Compile: EAccum with non-trivial sparsity should have been eliminated (use AST.UnMonoid)"
-
- EError _ t s -> do
- let padleft len c s' = replicate (len - length s) c ++ s'
- escape = concatMap $ \c -> if | c `elem` "\"\\" -> ['\\',c]
- | ord c < 32 -> "\\x" ++ padleft 2 '0' (showHex (ord c) "")
- | otherwise -> [c]
- emit $ SVerbatim $ "fputs(\"ERROR: " ++ escape s ++ "\\n\", stderr); return false;"
- case t of
- STScal _ -> return (CELit "0")
- _ -> do
- name <- emitStruct t
- return $ CEStruct name []
-
- EZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
- EDeepZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
- EPlus{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
- EOneHot{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
-
- EIdx1{} -> error "Compile: not implemented: EIdx1"
-
-compileAssign :: String -> SList (Const String) env -> Ex env t -> CompM String
-compileAssign prefix env e = do
- e' <- compile' env e
- case e' of
- CELit name -> return name
- _ -> do
- name <- genName' prefix
- emit $ SVarDecl True (repSTy (typeOf e)) name e'
- return name
-
-data Increment = Increment | Decrement
- deriving (Show)
-
--- | Increment reference counts in the components of the given variable.
-incrementVar :: String -> Increment -> STy a -> Maybe (String -> CompM ())
-incrementVar marker inc ty =
- let tree = makeArrayTree ty
- in case tree of ATNoop -> Nothing
- _ -> Just $ \var -> incrementVar' marker inc var tree
-
-incrementVarAlways :: String -> Increment -> STy a -> String -> CompM ()
-incrementVarAlways marker inc ty var = maybe (pure ()) ($ var) (incrementVar marker inc ty)
-
-data ArrayTree = ATArray (Some SNat) (Some STy) -- ^ we've arrived at an array we need to decrement the refcount of (contains rank and element type of the array)
- | ATNoop -- ^ don't do anything here
- | ATProj String ArrayTree -- ^ descend one field deeper
- | ATCondTag ArrayTree ArrayTree -- ^ if tag is 0, first; if 1, second
- | ATCond3Tag ArrayTree ArrayTree ArrayTree -- ^ if tag is: 0, 1, 2
- | ATBoth ArrayTree ArrayTree -- ^ do both these paths
-
-smartATProj :: String -> ArrayTree -> ArrayTree
-smartATProj _ ATNoop = ATNoop
-smartATProj field t = ATProj field t
-
-smartATCondTag :: ArrayTree -> ArrayTree -> ArrayTree
-smartATCondTag ATNoop ATNoop = ATNoop
-smartATCondTag t t' = ATCondTag t t'
-
-smartATCond3Tag :: ArrayTree -> ArrayTree -> ArrayTree -> ArrayTree
-smartATCond3Tag ATNoop ATNoop ATNoop = ATNoop
-smartATCond3Tag t1 t2 t3 = ATCond3Tag t1 t2 t3
-
-smartATBoth :: ArrayTree -> ArrayTree -> ArrayTree
-smartATBoth ATNoop t = t
-smartATBoth t ATNoop = t
-smartATBoth t t' = ATBoth t t'
-
-makeArrayTree :: STy a -> ArrayTree
-makeArrayTree STNil = ATNoop
-makeArrayTree (STPair a b) = smartATBoth (smartATProj "a" (makeArrayTree a))
- (smartATProj "b" (makeArrayTree b))
-makeArrayTree (STEither a b) = smartATCondTag (smartATProj "l" (makeArrayTree a))
- (smartATProj "r" (makeArrayTree b))
-makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop
- (smartATProj "l" (makeArrayTree a))
- (smartATProj "r" (makeArrayTree b))
-makeArrayTree (STMaybe t) = smartATCondTag ATNoop (smartATProj "j" (makeArrayTree t))
-makeArrayTree (STArr n t) = ATArray (Some n) (Some t)
-makeArrayTree (STScal _) = ATNoop
-makeArrayTree (STAccum _) = ATNoop
-
-incrementVar' :: String -> Increment -> String -> ArrayTree -> CompM ()
-incrementVar' marker inc path (ATArray (Some n) (Some eltty)) =
- case inc of
- Increment -> do
- emit $ SVerbatim (path ++ ".buf->refc++;")
- when debugRefc $
- emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p in+ -> %zu <" ++ marker ++ ">\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc);"
- Decrement -> do
- case incrementVar (marker++".elt") Decrement eltty of
- Nothing ->
- if debugRefc
- then do
- emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p de- -> %zu <" ++ marker ++ ">\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);"
- emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) { fprintf(stderr, \"; free(\"); free_instr(" ++ path ++ ".buf); fprintf(stderr, \") ok\\n\"); } else fprintf(stderr, \"\\n\");"
- else do
- emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) free_instr(" ++ path ++ ".buf);"
- Just f -> do
- when debugRefc $
- emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p de- -> %zu <" ++ marker ++ "> recfree\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);"
- shszvar <- genName' "frshsz"
- ivar <- genName' "i"
- ((), eltDecrStmts) <- scope $ f (path ++ ".buf->xs[" ++ ivar ++ "]")
- emit $ SIf (CELit ("--" ++ path ++ ".buf->refc == 0"))
- (BList [SVarDecl True "size_t" shszvar (compileArrShapeSize n path)
- ,SLoop "size_t" ivar (CELit "0") (CELit shszvar) $
- eltDecrStmts
- ,SVerbatim $ "free_instr(" ++ path ++ ".buf);"])
- mempty
-incrementVar' _ _ _ ATNoop = pure ()
-incrementVar' marker inc path (ATProj field t) = incrementVar' (marker++"."++field) inc (path ++ "." ++ field) t
-incrementVar' marker inc path (ATCondTag t1 t2) = do
- ((), stmts1) <- scope $ incrementVar' (marker++".t1") inc path t1
- ((), stmts2) <- scope $ incrementVar' (marker++".t2") inc path t2
- emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "0")) stmts1 stmts2
-incrementVar' marker inc path (ATCond3Tag t1 t2 t3) = do
- ((), stmts1) <- scope $ incrementVar' (marker++".t1") inc path t1
- ((), stmts2) <- scope $ incrementVar' (marker++".t2") inc path t2
- ((), stmts3) <- scope $ incrementVar' (marker++".t3") inc path t3
- emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "1"))
- stmts2
- (pure (SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "2"))
- stmts3
- stmts1))
-incrementVar' marker inc path (ATBoth t1 t2) = incrementVar' (marker++".1") inc path t1 >> incrementVar' (marker++".2") inc path t2
-
-toLinearIdx :: SNat n -> String -> String -> CExpr
-toLinearIdx SZ _ _ = CELit "0"
-toLinearIdx (SS SZ) _ idxvar = CELit (idxvar ++ ".b")
-toLinearIdx (SS n) arrvar idxvar =
- CEBinop (CEBinop (toLinearIdx n arrvar (idxvar ++ ".a"))
- "*" (CEIndex (CELit (arrvar ++ ".sh")) (CELit (show (fromSNat n)))))
- "+" (CELit (idxvar ++ ".b"))
-
--- fromLinearIdx :: SNat n -> String -> String -> CompM CExpr
--- fromLinearIdx SZ _ _ = return $ CEStruct (repSTy STNil) []
--- fromLinearIdx (SS n) arrvar idxvar = do
--- name <- genName
--- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".sh[" ++ show (fromSNat n) ++ "]")))
--- _
-
-data AllocMethod = Malloc | Calloc
- deriving (Show)
-
--- | The shape must have the outer dimension at the head (and the inner dimension on the right).
-allocArray :: HasCallStack => String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String
-allocArray marker method nameBase rank eltty mshsz shape = do
- when (length shape /= fromSNat rank) $
- error "allocArray: shape does not match rank"
- let arrty = STArr rank eltty
- strname <- emitStruct arrty
- arrname <- genName' nameBase
- shsz <- case mshsz of
- Just e -> return e
- Nothing -> return (foldl0' (\a b -> CEBinop a "*" b) (CELit "1") shape)
- let nbytesExpr = CEBinop (CELit (show (fromSNat rank * 8 + 8)))
- "+"
- (CEBinop shsz "*" (CELit (show (sizeofSTy eltty))))
- emit $ SVarDecl True strname arrname $ CEStruct strname
- [("buf", case method of Malloc -> CECall "malloc_instr" [nbytesExpr]
- Calloc -> CECall "calloc_instr" [nbytesExpr])
- ,("sh", CELit ("{" ++ intercalate "," [printCExpr 0 dim "" | dim <- shape] ++ "}"))]
- emit $ SAsg (arrname ++ ".buf->refc") (CELit "1")
- when debugRefc $
- emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p allocated <" ++ marker ++ ">\\n\", " ++ arrname ++ ".buf);"
- return arrname
-
-compileShapeQuery :: SNat n -> String -> CExpr
-compileShapeQuery SZ _ = CEStruct (repSTy STNil) []
-compileShapeQuery (SS n) var =
- CEStruct (repSTy (tTup (sreplicate (SS n) tIx)))
- [("a", compileShapeQuery n var)
- ,("b", CEIndex (CELit (var ++ ".sh")) (CELit (show (fromSNat n))))]
-
--- | Takes a variable name for the array, not the buffer.
-compileArrShapeSize :: SNat n -> String -> CExpr
-compileArrShapeSize n var = prodExpr (compileArrShapeComponents n var)
-
--- | Takes a variable name for the array, not the buffer.
-compileArrShapeComponents :: SNat n -> String -> [CExpr]
-compileArrShapeComponents n var =
- [CELit (var ++ ".sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
-
-indexTupleComponents :: SNat n -> String -> [CExpr]
-indexTupleComponents = \n var -> map CELit (toList (go n var))
- where
- go :: SNat n -> String -> Bag String
- go SZ _ = mempty
- go (SS n) var = go n (var ++ ".a") <> pure (var ++ ".b")
-
--- | Takes variable names with the innermost dimension on the right.
-shapeTupFromLitVars :: SNat n -> [String] -> CExpr
-shapeTupFromLitVars = \n -> go n . reverse
- where
- -- takes variables with the innermost dimension at the _head_
- go :: SNat n -> [String] -> CExpr
- go SZ [] = CEStruct (repSTy STNil) []
- go (SS n) (var : vars) = CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) [("a", go n vars), ("b", CELit var)]
- go _ _ = error "shapeTupFromLitVars: SNat and list do not correspond"
-
-prodExpr :: [CExpr] -> CExpr
-prodExpr = foldl0' (\a b -> CEBinop a "*" b) (CELit "1")
-
-compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr
-compileOpGeneral op e1 = do
- let unary cop = return @CompM $ CECall cop [e1]
- let binary cop = do
- name <- genName
- emit $ SVarDecl True (repSTy (opt1 op)) name e1
- return $ CEBinop (CEProj (CELit name) "a") cop (CEProj (CELit name) "b")
- case op of
- OAdd _ -> binary "+"
- OMul _ -> binary "*"
- ONeg _ -> unary "-"
- OLt _ -> binary "<"
- OLe _ -> binary "<="
- OEq _ -> binary "=="
- ONot -> unary "!"
- OAnd -> binary "&&"
- OOr -> binary "||"
- OIf -> do
- name <- emitStruct (STEither STNil STNil)
- _ <- emitStruct STNil
- return $ CEIf e1 (CEStruct name [("tag", CELit "0")])
- (CEStruct name [("tag", CELit "1")])
- ORound64 -> unary "(int64_t)round" -- ew
- OToFl64 -> unary "(double)"
- ORecip _ -> return $ CEBinop (CELit "1.0") "/" e1
- OExp STF32 -> unary "expf"
- OExp STF64 -> unary "exp"
- OLog STF32 -> unary "logf"
- OLog STF64 -> unary "log"
- OIDiv _ -> binary "/"
- OMod _ -> binary "%"
-
-compileOpPair :: SOp a b -> CExpr -> CExpr -> CompM CExpr
-compileOpPair op e1 e2 = do
- let binary cop = return @CompM $ CEBinop e1 cop e2
- case op of
- OAdd _ -> binary "+"
- OMul _ -> binary "*"
- OLt _ -> binary "<"
- OLe _ -> binary "<="
- OEq _ -> binary "=="
- OAnd -> binary "&&"
- OOr -> binary "||"
- OIDiv _ -> binary "/"
- OMod _ -> binary "%"
- _ -> error "compileOpPair: got unary operator"
-
--- | Bool: whether to ensure that the literal itself already has the appropriate type
-compileScal :: Bool -> SScalTy t -> ScalRep t -> String
-compileScal pedantic typ x = case typ of
- STI32 -> (if pedantic then "(int32_t)" else "") ++ show x
- STI64 -> (if pedantic then "(int64_t)" else "") ++ show x
- STF32 -> show x ++ "f"
- STF64 -> show x
- STBool -> if x then "1" else "0"
-
-compileExtremum :: String -> String -> String -> SList (Const String) env -> Ex env (TArr (S n) t) -> CompM CExpr
-compileExtremum nameBase opName operator env e = do
- let STArr (SS n) t = typeOf e
- argname <- compileAssign (nameBase ++ "arg") env e
-
- zeroRefcountCheck (typeOf e) opName argname
-
- shszname <- genName' "shsz"
- -- This n is one less than the shape of the thing we're querying, which is
- -- unexpected. But it's exactly what we want, so we do it anyway.
- emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
-
- resname <- allocArray nameBase Malloc (nameBase ++ "res") n t (Just (CELit shszname)) (compileArrShapeComponents n argname)
-
- lenname <- genName' "n"
- emit $ SVarDecl True (repSTy tIx) lenname
- (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
-
- emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); return false; }"
-
- ivar <- genName' "i"
- jvar <- genName' "j"
- xvar <- genName
- redvar <- genName' "red" -- use "red", not "acc", to avoid confusion with accumulators
- emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ BList
- -- we have ScalIsNumeric, so it has 1 and (<) etc. in C
- [SVarDecl False (repSTy t) redvar (CELit (argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ "]"))
- ,SLoop (repSTy tIx) jvar (CELit "1") (CELit lenname) $ BList
- [SVarDecl True (repSTy t) xvar (CELit (argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ jvar ++ "]"))
- ,SAsg redvar $ CEIf (CEBinop (CELit xvar) operator (CELit redvar)) (CELit xvar) (CELit redvar)
- ]
- ,SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit redvar)]
-
- incrementVarAlways nameBase Decrement (typeOf e) argname
-
- return (CELit resname)
-
--- | If this returns Nothing, there was nothing to copy because making a simple
--- value copy in C already makes it suitable to write to.
-copyForWriting :: SMTy t -> String -> CompM (Maybe CExpr)
-copyForWriting topty var = case topty of
- SMTNil -> return Nothing
-
- SMTPair a b -> do
- e1 <- copyForWriting a (var ++ ".a")
- e2 <- copyForWriting b (var ++ ".b")
- case (e1, e2) of
- (Nothing, Nothing) -> return Nothing
- _ -> return $ Just $ CEStruct toptyname
- [("a", fromMaybe (CELit (var++".a")) e1)
- ,("b", fromMaybe (CELit (var++".b")) e2)]
-
- SMTLEither a b -> do
- (e1, stmts1) <- scope $ copyForWriting a (var ++ ".l")
- (e2, stmts2) <- scope $ copyForWriting b (var ++ ".r")
- case (e1, e2) of
- (Nothing, Nothing) -> return Nothing
- _ -> do
- name <- genName
- emit $ SVarDeclUninit toptyname name
- emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0"))
- (stmts1
- <> pure (SAsg name (CEStruct toptyname
- [("tag", CELit "0"), ("l", fromMaybe (CELit (var++".l")) e1)])))
- (stmts2
- <> pure (SAsg name (CEStruct toptyname
- [("tag", CELit "1"), ("r", fromMaybe (CELit (var++".r")) e2)])))
- return (Just (CELit name))
-
- SMTMaybe t -> do
- (e1, stmts1) <- scope $ copyForWriting t (var ++ ".j")
- case e1 of
- Nothing -> return Nothing
- Just e1' -> do
- name <- genName
- emit $ SVarDeclUninit toptyname name
- emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0"))
- (pure (SAsg name (CEStruct toptyname [("tag", CELit "0")])))
- (stmts1
- <> pure (SAsg name (CEStruct toptyname [("tag", CELit "1"), ("j", e1')])))
- return (Just (CELit name))
-
- -- If there are no nested arrays, we know that a refcount of 1 means that the
- -- whole thing is owned. Nested arrays have their own refcount, so with
- -- nesting we'd have to check the refcounts of all the nested arrays _too_;
- -- let's not do that. Furthermore, no sub-arrays means that the whole thing
- -- is flat, and we can just memcpy if necessary.
- SMTArr n t | not (typeHasArrays (fromSMTy t)) -> do
- name <- genName
- shszname <- genName' "shsz"
- emit $ SVarDeclUninit toptyname name
-
- when debugShapes $ do
- let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
- emit $ SVerbatim $
- "fprintf(stderr, PRTAG \"with array " ++ shfmt ++ "\\n\"" ++
- concat [", " ++ var ++ ".sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++
- ");"
-
- emit $ SIf (CEBinop (CELit (var ++ ".buf->refc")) "==" (CELit "1"))
- (pure (SAsg name (CELit var)))
- (let shbytes = fromSNat n * 8
- databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t))))
- totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes
- in BList
- [SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var)
- ,SAsg name (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])])
- ,SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");"
- ,SAsg (name ++ ".buf->refc") (CELit "1")
- ,SVerbatim $ "memcpy(" ++ name ++ ".buf->xs, " ++ var ++ ".buf->xs, " ++
- printCExpr 0 databytes ");"])
- return (Just (CELit name))
-
- SMTArr n t -> do
- shszname <- genName' "shsz"
- emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var)
-
- let shbytes = fromSNat n * 8
- databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t))))
- totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes
-
- name <- genName
- emit $ SVarDecl False toptyname name
- (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])])
- emit $ SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");"
- emit $ SAsg (name ++ ".buf->refc") (CELit "1")
-
- -- put the arrays in variables to cut short the not-quite-var chain
- dstvar <- genName' "cpydst"
- emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") dstvar (CELit (name ++ ".buf->xs"))
- srcvar <- genName' "cpysrc"
- emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") srcvar (CELit (var ++ ".buf->xs"))
-
- ivar <- genName' "i"
-
- (cpye, cpystmts) <- scope $ copyForWriting t (srcvar ++ "[" ++ ivar ++ "]")
- let cpye' = case cpye of
- Just e -> e
- Nothing -> error "copyForWriting: arrays cannot be copied as-is, bug"
-
- emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $
- cpystmts
- <> pure (SAsg (dstvar ++ "[" ++ ivar ++ "]") cpye')
-
- return (Just (CELit name))
-
- SMTScal _ -> return Nothing
-
- where
- toptyname = repSTy (fromSMTy topty)
-
-zeroRefcountCheck :: STy t -> String -> String -> CompM ()
-zeroRefcountCheck toptyp opname topvar =
- when emitChecks $ do
- mstmts <- onlyIdGen $ runMaybeT (go toptyp topvar)
- case mstmts of
- Nothing -> return ()
- Just stmts -> forM_ stmts emit
- where
- -- | If this returns 'Nothing', no statements need to be generated for this type.
- go :: STy t -> String -> MaybeT IdGen.IdGen (Bag Stmt)
- go STNil _ = empty
- go (STPair a b) path = do
- (s1, s2) <- combine (go a (path++".a")) (go b (path++".b"))
- return (s1 <> s2)
- go (STEither a b) path = do
- (s1, s2) <- combine (go a (path++".l")) (go b (path++".r"))
- return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "0")) s1 s2
- go (STLEither a b) path = do
- (s1, s2) <- combine (go a (path++".l")) (go b (path++".r"))
- return $ pure $
- SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1"))
- s1
- (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2"))
- s2
- mempty))
- go (STMaybe a) path = do
- ss <- go a (path++".j")
- return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) ss mempty
- go (STArr n a) path = do
- ivar <- genName' "i"
- ss <- go a (path++".buf->xs["++ivar++"]")
- shszname <- genName' "shsz"
- let s1 = SVerbatim $
- "if (__builtin_expect(" ++ path ++ ".buf->refc == 0, 0)) { " ++
- "fprintf(stderr, PRTAG \"CHECK: '" ++ opname ++ "' got array " ++
- "%p with refc=0\\n\", " ++ path ++ ".buf); return false; }"
- let s2 = SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n path)
- let s3 = SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) ss
- return (BList [s1, s2, s3])
- go STScal{} _ = empty
- go STAccum{} _ = error "zeroRefcountCheck: passed an accumulator"
-
- combine :: (Monoid a, Monoid b, Monad m) => MaybeT m a -> MaybeT m b -> MaybeT m (a, b)
- combine (MaybeT a) (MaybeT b) = MaybeT $ do
- x <- a
- y <- b
- return $ case (x, y) of
- (Nothing, Nothing) -> Nothing
- (Just x', Nothing) -> Just (x', mempty)
- (Nothing, Just y') -> Just (mempty, y')
- (Just x', Just y') -> Just (x', y')
-
-{-# NOINLINE uniqueIdGenRef #-}
-uniqueIdGenRef :: IORef Int
-uniqueIdGenRef = unsafePerformIO $ newIORef 1
-
-compose :: Foldable t => t (a -> a) -> a -> a
-compose = foldr (.) id
-
-showPtr :: Ptr a -> String
-showPtr (Ptr a) = "0x" ++ showHex (integerFromWord# (int2Word# (addr2Int# a))) ""
-
--- | Type-restricted.
-(^) :: Num a => a -> Int -> a
-(^) = (Prelude.^)
-
-foldl0' :: (a -> a -> a) -> a -> [a] -> a
-foldl0' _ x [] = x
-foldl0' f _ l = foldl1' f l