summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/Compile.hs52
1 files changed, 32 insertions, 20 deletions
diff --git a/src/Compile.hs b/src/Compile.hs
index fe99c4d..e3eb207 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -22,6 +22,7 @@ import Data.Foldable (toList)
import Data.Functor.Const
import qualified Data.Functor.Product as Product
import Data.Functor.Product (Product)
+import Data.IORef
import Data.List (foldl1', intersperse, intercalate)
import qualified Data.Map.Strict as Map
import Data.Maybe (fromMaybe)
@@ -36,6 +37,7 @@ import GHC.Ptr (Ptr(..))
import Numeric (showHex)
import System.IO (hPutStrLn, stderr)
import System.IO.Error (mkIOError, userErrorType)
+import System.IO.Unsafe (unsafePerformIO)
import Prelude hiding ((^))
import qualified Prelude
@@ -71,7 +73,9 @@ emitChecks :: Bool; emitChecks = toEnum 0
compile :: SList STy env -> Ex env t
-> IO (SList Value env -> IO (Rep t))
compile = \env expr -> do
- let (source, offsets) = compileToString env expr
+ codeID <- atomicModifyIORef' uniqueIdGenRef (\i -> (i + 1, i))
+
+ let (source, offsets) = compileToString codeID env expr
when debugPrintAST $ hPutStrLn stderr $ "Compiled AST: <<<\n" ++ ppExpr env expr ++ "\n>>>"
when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ lineNumbers source ++ "\x1B[0m>>>"
lib <- buildKernel source ["kernel"]
@@ -340,8 +344,8 @@ data KernelOffsets = KernelOffsets
, koResultOffset :: Int -- ^ the function result
}
-compileToString :: SList STy env -> Ex env t -> (String, KernelOffsets)
-compileToString env expr =
+compileToString :: Int -> SList STy env -> Ex env t -> (String, KernelOffsets)
+compileToString codeID env expr =
let args = nameEnv env
(res, s) = runCompM (compile' args expr)
structs = genAllStructs (csStructs s <> Set.fromList (unSList unSTy env))
@@ -364,29 +368,33 @@ compileToString env expr =
,showString "#include <stdlib.h>\n"
,showString "#include <string.h>\n"
,showString "#include <math.h>\n\n"
+ -- PRint-tag
+ ,showString $ "#define PRTAG \"[chad-kernel" ++ show codeID ++ "] \"\n\n"
,compose [printStructDecl sd . showString "\n" | sd <- structs]
,showString "\n"
-- Using %zd and not %zu here because values > SIZET_MAX/2 should be recognisable as "negative"
- ,showString "static void* malloc_instr(size_t n) {\n"
+ ,showString "static void* malloc_instr_fun(size_t n, int line) {\n"
,showString " void *ptr = malloc(n);\n"
- ,if debugAllocs then showString " printf(\"[chad-kernel] malloc(%zd) -> %p\\n\", n, ptr);\n"
+ ,if debugAllocs then showString " printf(PRTAG \":%d malloc(%zd) -> %p\\n\", line, n, ptr);\n"
else id
- ,if emitChecks then showString " if (ptr == NULL) { printf(\"[chad-kernel] malloc(%zd) returned NULL\\n\", n); return false; }\n"
+ ,if emitChecks then showString " if (ptr == NULL) { printf(PRTAG \"malloc(%zd) returned NULL on line %d\\n\", n, line); return false; }\n"
else id
,showString " return ptr;\n"
,showString "}\n"
- ,showString "static void* calloc_instr(size_t n) {\n"
+ ,showString "#define malloc_instr(n) ({void *ptr_ = malloc_instr_fun(n, __LINE__); if (ptr_ == NULL) return false; ptr_;})\n"
+ ,showString "static void* calloc_instr_fun(size_t n, int line) {\n"
,showString " void *ptr = calloc(n, 1);\n"
- ,if debugAllocs then showString " printf(\"[chad-kernel] calloc(%zd) -> %p\\n\", n, ptr);\n"
+ ,if debugAllocs then showString " printf(PRTAG \":%d calloc(%zd) -> %p\\n\", line, n, ptr);\n"
else id
- ,if emitChecks then showString " if (ptr == NULL) { printf(\"[chad-kernel] calloc(%zd, 1) returned NULL\\n\", n); return false; }\n"
+ ,if emitChecks then showString " if (ptr == NULL) { printf(PRTAG \"calloc(%zd, 1) returned NULL on line %d\\n\", n, line); return false; }\n"
else id
,showString " return ptr;\n"
,showString "}\n"
+ ,showString "#define calloc_instr(n) ({void *ptr_ = calloc_instr_fun(n, __LINE__); if (ptr_ == NULL) return false; ptr_;})\n"
,showString "static void free_instr(void *ptr) {\n"
- ,if debugAllocs then showString "printf(\"[chad-kernel] free(%p)\\n\", ptr);\n"
+ ,if debugAllocs then showString "printf(PRTAG \"free(%p)\\n\", ptr);\n"
else id
,showString " free(ptr);\n"
,showString "}\n\n"
@@ -407,7 +415,7 @@ compileToString env expr =
,showString "void kernel(void *data) {\n"
-- Some code here assumes that we're on a 64-bit system, so let's check that
,showString $ " if (sizeof(void*) != 8 || sizeof(size_t) != 8) { fprintf(stderr, \"Only 64-bit systems supported\\n\"); *(uint8_t*)(data + " ++ show okres_offset ++ ") = 0; return; }\n"
- ,if debugRefc then showString " fprintf(stderr, \"[chad-kernel] Start\\n\");\n"
+ ,if debugRefc then showString " fprintf(stderr, PRTAG \"Start\\n\");\n"
else id
,showString $ " const bool success = typed_kernel(" ++
"\n (" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ")" ++
@@ -417,7 +425,7 @@ compileToString env expr =
(zip arg_pairs arg_offsets)) ++
"\n );\n"
,showString $ " *(uint8_t*)(data + " ++ show okres_offset ++ ") = success;\n"
- ,if debugRefc then showString " fprintf(stderr, \"[chad-kernel] Return\\n\");\n"
+ ,if debugRefc then showString " fprintf(stderr, PRTAG \"Return\\n\");\n"
else id
,showString "}\n"]
@@ -870,7 +878,7 @@ compile' env = \case
emit $ SIf (CEBinop (CEBinop ixcomp "<" (CELit "0")) "||"
(CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".buf->sh[" ++ show i ++ "]")))))
(pure $ SVerbatim $
- "fprintf(stderr, \"[chad-kernel] CHECK: index out of range (arr=%p)\\n\", " ++
+ "fprintf(stderr, PRTAG \"CHECK: index out of range (arr=%p)\\n\", " ++
arrname ++ ".buf); return false;")
mempty
@@ -1021,7 +1029,7 @@ compile' env = \case
.||.
CEBinop shcomp "!=" (CECast (repSTy tIx) (CELit (v ++ ".j.buf->sh[" ++ show j ++ "]"))))
(pure $ SVerbatim $
- "fprintf(stderr, \"[chad-kernel] CHECK: accum prj incorrect (arr=%p, " ++
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++
"arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++
v ++ ".j.buf" ++
concat [", " ++ v ++ ".j.buf->sh[" ++ show k ++ "]" | k <- [0 .. fromSNat n - 1]] ++
@@ -1168,19 +1176,19 @@ incrementVar' marker inc path (ATArray (Some n) (Some eltty)) =
Increment -> do
emit $ SVerbatim (path ++ ".buf->refc++;")
when debugRefc $
- emit $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] arr %p in+ -> %zu <" ++ marker ++ ">\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc);"
+ emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p in+ -> %zu <" ++ marker ++ ">\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc);"
Decrement -> do
case incrementVar (marker++".elt") Decrement eltty of
Nothing ->
if debugRefc
then do
- emit $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] arr %p de- -> %zu <" ++ marker ++ ">\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);"
+ emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p de- -> %zu <" ++ marker ++ ">\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);"
emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) { fprintf(stderr, \"; free(\"); free_instr(" ++ path ++ ".buf); fprintf(stderr, \") ok\\n\"); } else fprintf(stderr, \"\\n\");"
else do
emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) free_instr(" ++ path ++ ".buf);"
Just f -> do
when debugRefc $
- emit $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] arr %p de- -> %zu <" ++ marker ++ "> recfree\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);"
+ emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p de- -> %zu <" ++ marker ++ "> recfree\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);"
shszvar <- genName' "frshsz"
ivar <- genName' "i"
((), eltDecrStmts) <- scope $ f (path ++ ".buf->xs[" ++ ivar ++ "]")
@@ -1237,7 +1245,7 @@ allocArray marker method nameBase rank eltty mshsz shape = do
emit $ SAsg (arrname ++ ".buf->sh[" ++ show i ++ "]") dim
emit $ SAsg (arrname ++ ".buf->refc") (CELit "1")
when debugRefc $
- emit $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] arr %p allocated <" ++ marker ++ ">\\n\", " ++ arrname ++ ".buf);"
+ emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p allocated <" ++ marker ++ ">\\n\", " ++ arrname ++ ".buf);"
return arrname
compileShapeQuery :: SNat n -> String -> CExpr
@@ -1423,7 +1431,7 @@ copyForWriting topty var = case topty of
when debugShapes $ do
let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
emit $ SVerbatim $
- "fprintf(stderr, \"[chad-kernel] with array " ++ shfmt ++ "\\n\"" ++
+ "fprintf(stderr, PRTAG \"with array " ++ shfmt ++ "\\n\"" ++
concat [", " ++ var ++ ".buf->sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++
");"
@@ -1506,7 +1514,7 @@ zeroRefcountCheck toptyp opname topvar =
shszname <- genName' "shsz"
let s1 = SVerbatim $
"if (__builtin_expect(" ++ path ++ ".buf->refc == 0, 0)) { " ++
- "fprintf(stderr, \"[chad-kernel] CHECK: '" ++ opname ++ "' got array " ++
+ "fprintf(stderr, PRTAG \"CHECK: '" ++ opname ++ "' got array " ++
"%p with refc=0\\n\", " ++ path ++ ".buf); return false; }"
let s2 = SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n path)
let s3 = SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) ss
@@ -1524,6 +1532,10 @@ zeroRefcountCheck toptyp opname topvar =
(Nothing, Just y') -> Just (mempty, y')
(Just x', Just y') -> Just (x', y')
+{-# NOINLINE uniqueIdGenRef #-}
+uniqueIdGenRef :: IORef Int
+uniqueIdGenRef = unsafePerformIO $ newIORef 1
+
compose :: Foldable t => t (a -> a) -> a -> a
compose = foldr (.) id