summaryrefslogtreecommitdiff
path: root/src/Compile.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Compile.hs')
-rw-r--r--src/Compile.hs88
1 files changed, 59 insertions, 29 deletions
diff --git a/src/Compile.hs b/src/Compile.hs
index 6466065..aa23797 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -6,6 +6,7 @@
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
module Compile (compile) where
@@ -34,6 +35,7 @@ import GHC.Num (integerFromWord#)
import GHC.Ptr (Ptr(..))
import Numeric (showHex)
import System.IO (hPutStrLn, stderr)
+import System.IO.Error (mkIOError, userErrorType)
import Prelude hiding ((^))
import qualified Prelude
@@ -64,27 +66,28 @@ debugShapes :: Bool; debugShapes = toEnum 0
-- | Print information on allocation
debugAllocs :: Bool; debugAllocs = toEnum 0
-- | Emit extra C code that checks stuff
-emitChecks :: Bool; emitChecks = toEnum 0
+emitChecks :: Bool; emitChecks = toEnum 1
compile :: SList STy env -> Ex env t
-> IO (SList Value env -> IO (Rep t))
compile = \env expr -> do
- let source = compileToString env expr
+ let (source, offsets) = compileToString env expr
when debugPrintAST $ hPutStrLn stderr $ "Compiled AST: <<<\n" ++ ppExpr env expr ++ "\n>>>"
when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ lineNumbers source ++ "\x1B[0m>>>"
lib <- buildKernel source ["kernel"]
- let arg_metrics = reverse (unSList metricsSTy env)
- (arg_offsets, result_offset) = computeStructOffsets arg_metrics
- result_type = typeOf expr
+ let result_type = typeOf expr
result_size = sizeofSTy result_type
return $ \val -> do
- allocaBytes (result_offset + result_size) $ \ptr -> do
- let args = zip (reverse (unSList Some (slistZip env val))) arg_offsets
+ allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do
+ let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets)
serialiseArguments args ptr $ do
callKernelFun "kernel" lib ptr
- deserialise result_type ptr result_offset
+ ok <- peekByteOff @Word8 ptr (koOkResOffset offsets)
+ when (ok /= 1) $
+ ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing)
+ deserialise result_type ptr (koResultOffset offsets)
where
serialiseArguments :: [(Some (Product STy Value), Int)] -> Ptr () -> IO r -> IO r
serialiseArguments ((Some (Product.Pair t (Value arg)), off) : args) ptr k =
@@ -331,7 +334,13 @@ emitTLD decl = CompM $ modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <>
nameEnv :: SList f env -> SList (Const String) env
nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1))
-compileToString :: SList STy env -> Ex env t -> String
+data KernelOffsets = KernelOffsets
+ { koArgOffsets :: [Int] -- ^ the function arguments
+ , koOkResOffset :: Int -- ^ a byte: 1 if successful execution, 0 if (fatal) error occurred
+ , koResultOffset :: Int -- ^ the function result
+ }
+
+compileToString :: SList STy env -> Ex env t -> (String, KernelOffsets)
compileToString env expr =
let args = nameEnv env
(res, s) = runCompM (compile' args expr)
@@ -340,27 +349,40 @@ compileToString env expr =
(arg_pairs, arg_metrics) =
unzip $ reverse (unSList (\(Product.Pair t (Const n)) -> ((n, repSTy t), metricsSTy t))
(slistZip env args))
- (arg_offsets, result_offset') = computeStructOffsets arg_metrics
- result_offset = align (alignmentSTy (typeOf expr)) result_offset'
- in ($ "") $ compose
+ (arg_offsets, okres_offset) = computeStructOffsets arg_metrics
+ result_offset = align (alignmentSTy (typeOf expr)) (okres_offset + 1)
+
+ offsets = KernelOffsets
+ { koArgOffsets = arg_offsets
+ , koOkResOffset = okres_offset
+ , koResultOffset = result_offset }
+ in (,offsets) . ($ "") $ compose
[showString "#include <stdio.h>\n"
,showString "#include <stdint.h>\n"
+ ,showString "#include <stdbool.h>\n"
,showString "#include <inttypes.h>\n"
,showString "#include <stdlib.h>\n"
,showString "#include <string.h>\n"
,showString "#include <math.h>\n\n"
+
,compose [printStructDecl sd . showString "\n" | sd <- structs]
,showString "\n"
+
+ -- Using %zd and not %zu here because values > SIZET_MAX/2 should be recognisable as "negative"
,showString "static void* malloc_instr(size_t n) {\n"
,showString " void *ptr = malloc(n);\n"
- ,if debugAllocs then showString "printf(\"[chad-kernel] malloc(%zu) -> %p\\n\", n, ptr);\n"
+ ,if debugAllocs then showString " printf(\"[chad-kernel] malloc(%zd) -> %p\\n\", n, ptr);\n"
else id
+ ,if emitChecks then showString " if (ptr == NULL) { printf(\"[chad-kernel] malloc(%zd) returned NULL\\n\", n); return false; }\n"
+ else id
,showString " return ptr;\n"
,showString "}\n"
,showString "static void* calloc_instr(size_t n) {\n"
,showString " void *ptr = calloc(n, 1);\n"
- ,if debugAllocs then showString "printf(\"[chad-kernel] calloc(%zu) -> %p\\n\", n, ptr);\n"
+ ,if debugAllocs then showString " printf(\"[chad-kernel] calloc(%zd) -> %p\\n\", n, ptr);\n"
else id
+ ,if emitChecks then showString " if (ptr == NULL) { printf(\"[chad-kernel] calloc(%zd, 1) returned NULL\\n\", n); return false; }\n"
+ else id
,showString " return ptr;\n"
,showString "}\n"
,showString "static void free_instr(void *ptr) {\n"
@@ -368,25 +390,33 @@ compileToString env expr =
else id
,showString " free(ptr);\n"
,showString "}\n\n"
+
,compose [showString str . showString "\n\n" | str <- toList (csTopLevelDecls s)]
+
,showString $
- "static " ++ repSTy (typeOf expr) ++ " typed_kernel(" ++
- intercalate ", " (reverse (unSList (\(Product.Pair t (Const n)) -> repSTy t ++ " " ++ n) (slistZip env args))) ++
+ "static bool typed_kernel(" ++
+ repSTy (typeOf expr) ++ " *output" ++
+ concatMap (", " ++)
+ (reverse (unSList (\(Product.Pair t (Const n)) -> repSTy t ++ " " ++ n) (slistZip env args))) ++
") {\n"
,compose [showString " " . printStmt 1 st . showString "\n" | st <- toList (csStmts s)]
- ,showString " return " . printCExpr 0 res . showString ";\n}\n\n"
+ ,showString " *output = " . printCExpr 0 res . showString ";\n"
+ ,showString " return true;\n"
+ ,showString "}\n\n"
+
,showString "void kernel(void *data) {\n"
-- Some code here assumes that we're on a 64-bit system, so let's check that
- ,showString " if (sizeof(void*) != 8 || sizeof(size_t) != 8) { fprintf(stderr, \"Only 64-bit systems supported\\n\"); abort(); }\n"
+ ,showString $ " if (sizeof(void*) != 8 || sizeof(size_t) != 8) { fprintf(stderr, \"Only 64-bit systems supported\\n\"); *(uint8_t*)(data + " ++ show okres_offset ++ ") = 0; return; }\n"
,if debugRefc then showString " fprintf(stderr, \"[chad-kernel] Start\\n\");\n"
else id
- ,showString $ " *(" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ") = typed_kernel(" ++
- concat (map (\((arg, typ), off, idx) ->
- "\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")"
- ++ (if idx < length arg_pairs - 1 then "," else "")
- ++ " // " ++ arg)
- (zip3 arg_pairs arg_offsets [0::Int ..])) ++
+ ,showString $ " const bool success = typed_kernel(" ++
+ "\n (" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ")" ++
+ concat (map (\((arg, typ), off) ->
+ ",\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")"
+ ++ " /* " ++ arg ++ " */")
+ (zip arg_pairs arg_offsets)) ++
"\n );\n"
+ ,showString $ " *(uint8_t*)(data + " ++ show okres_offset ++ ") = success;\n"
,if debugRefc then showString " fprintf(stderr, \"[chad-kernel] Return\\n\");\n"
else id
,showString "}\n"]
@@ -841,7 +871,7 @@ compile' env = \case
(CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".buf->sh[" ++ show i ++ "]")))))
(pure $ SVerbatim $
"fprintf(stderr, \"[chad-kernel] CHECK: index out of range (arr=%p)\\n\", " ++
- arrname ++ ".buf); abort();")
+ arrname ++ ".buf); return false;")
mempty
resname <- genName' "ixres"
@@ -998,7 +1028,7 @@ compile' env = \case
concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++
concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.b")] ++
"); " ++
- "abort();")
+ "return false;")
mempty
accumRef t' prj' (v++".j.buf->xs[" ++ printCExpr 0 (toLinearIdx n (v++".j") (i++".a.a")) "]") (i++".b")
@@ -1066,7 +1096,7 @@ compile' env = \case
escape = concatMap $ \c -> if | c `elem` "\"\\" -> ['\\',c]
| ord c < 32 -> "\\x" ++ padleft 2 '0' (showHex (ord c) "")
| otherwise -> [c]
- emit $ SVerbatim $ "fputs(\"ERROR: " ++ escape s ++ "\\n\", stderr); exit(1);"
+ emit $ SVerbatim $ "fputs(\"ERROR: " ++ escape s ++ "\\n\", stderr); return false;"
case t of
STScal _ -> return (CELit "0")
_ -> do
@@ -1316,7 +1346,7 @@ compileExtremum nameBase opName operator env e = do
emit $ SVarDecl True (repSTy tIx) lenname
(CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))
- emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); abort(); }"
+ emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); return false; }"
ivar <- genName' "i"
jvar <- genName' "j"
@@ -1477,7 +1507,7 @@ zeroRefcountCheck toptyp opname topvar =
let s1 = SVerbatim $
"if (__builtin_expect(" ++ path ++ ".buf->refc == 0, 0)) { " ++
"fprintf(stderr, \"[chad-kernel] CHECK: '" ++ opname ++ "' got array " ++
- "%p with refc=0\\n\", " ++ path ++ ".buf); abort(); }"
+ "%p with refc=0\\n\", " ++ path ++ ".buf); return false; }"
let s2 = SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n path)
let s3 = SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) ss
return (BList [s1, s2, s3])