From 4551cf775dcc099a8a534359e92c5ddc9349ac9f Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 21 Apr 2025 23:19:51 +0200 Subject: compile: More checks, don't crash on check fail --- src/Compile.hs | 88 +++++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 59 insertions(+), 29 deletions(-) diff --git a/src/Compile.hs b/src/Compile.hs index 6466065..aa23797 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -6,6 +6,7 @@ {-# LANGUAGE MagicHash #-} {-# LANGUAGE MultiWayIf #-} {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} module Compile (compile) where @@ -34,6 +35,7 @@ import GHC.Num (integerFromWord#) import GHC.Ptr (Ptr(..)) import Numeric (showHex) import System.IO (hPutStrLn, stderr) +import System.IO.Error (mkIOError, userErrorType) import Prelude hiding ((^)) import qualified Prelude @@ -64,27 +66,28 @@ debugShapes :: Bool; debugShapes = toEnum 0 -- | Print information on allocation debugAllocs :: Bool; debugAllocs = toEnum 0 -- | Emit extra C code that checks stuff -emitChecks :: Bool; emitChecks = toEnum 0 +emitChecks :: Bool; emitChecks = toEnum 1 compile :: SList STy env -> Ex env t -> IO (SList Value env -> IO (Rep t)) compile = \env expr -> do - let source = compileToString env expr + let (source, offsets) = compileToString env expr when debugPrintAST $ hPutStrLn stderr $ "Compiled AST: <<<\n" ++ ppExpr env expr ++ "\n>>>" when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ lineNumbers source ++ "\x1B[0m>>>" lib <- buildKernel source ["kernel"] - let arg_metrics = reverse (unSList metricsSTy env) - (arg_offsets, result_offset) = computeStructOffsets arg_metrics - result_type = typeOf expr + let result_type = typeOf expr result_size = sizeofSTy result_type return $ \val -> do - allocaBytes (result_offset + result_size) $ \ptr -> do - let args = zip (reverse (unSList Some (slistZip env val))) arg_offsets + allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do + let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets) serialiseArguments args ptr $ do callKernelFun "kernel" lib ptr - deserialise result_type ptr result_offset + ok <- peekByteOff @Word8 ptr (koOkResOffset offsets) + when (ok /= 1) $ + ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing) + deserialise result_type ptr (koResultOffset offsets) where serialiseArguments :: [(Some (Product STy Value), Int)] -> Ptr () -> IO r -> IO r serialiseArguments ((Some (Product.Pair t (Value arg)), off) : args) ptr k = @@ -331,7 +334,13 @@ emitTLD decl = CompM $ modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> nameEnv :: SList f env -> SList (Const String) env nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1)) -compileToString :: SList STy env -> Ex env t -> String +data KernelOffsets = KernelOffsets + { koArgOffsets :: [Int] -- ^ the function arguments + , koOkResOffset :: Int -- ^ a byte: 1 if successful execution, 0 if (fatal) error occurred + , koResultOffset :: Int -- ^ the function result + } + +compileToString :: SList STy env -> Ex env t -> (String, KernelOffsets) compileToString env expr = let args = nameEnv env (res, s) = runCompM (compile' args expr) @@ -340,27 +349,40 @@ compileToString env expr = (arg_pairs, arg_metrics) = unzip $ reverse (unSList (\(Product.Pair t (Const n)) -> ((n, repSTy t), metricsSTy t)) (slistZip env args)) - (arg_offsets, result_offset') = computeStructOffsets arg_metrics - result_offset = align (alignmentSTy (typeOf expr)) result_offset' - in ($ "") $ compose + (arg_offsets, okres_offset) = computeStructOffsets arg_metrics + result_offset = align (alignmentSTy (typeOf expr)) (okres_offset + 1) + + offsets = KernelOffsets + { koArgOffsets = arg_offsets + , koOkResOffset = okres_offset + , koResultOffset = result_offset } + in (,offsets) . ($ "") $ compose [showString "#include \n" ,showString "#include \n" + ,showString "#include \n" ,showString "#include \n" ,showString "#include \n" ,showString "#include \n" ,showString "#include \n\n" + ,compose [printStructDecl sd . showString "\n" | sd <- structs] ,showString "\n" + + -- Using %zd and not %zu here because values > SIZET_MAX/2 should be recognisable as "negative" ,showString "static void* malloc_instr(size_t n) {\n" ,showString " void *ptr = malloc(n);\n" - ,if debugAllocs then showString "printf(\"[chad-kernel] malloc(%zu) -> %p\\n\", n, ptr);\n" + ,if debugAllocs then showString " printf(\"[chad-kernel] malloc(%zd) -> %p\\n\", n, ptr);\n" else id + ,if emitChecks then showString " if (ptr == NULL) { printf(\"[chad-kernel] malloc(%zd) returned NULL\\n\", n); return false; }\n" + else id ,showString " return ptr;\n" ,showString "}\n" ,showString "static void* calloc_instr(size_t n) {\n" ,showString " void *ptr = calloc(n, 1);\n" - ,if debugAllocs then showString "printf(\"[chad-kernel] calloc(%zu) -> %p\\n\", n, ptr);\n" + ,if debugAllocs then showString " printf(\"[chad-kernel] calloc(%zd) -> %p\\n\", n, ptr);\n" else id + ,if emitChecks then showString " if (ptr == NULL) { printf(\"[chad-kernel] calloc(%zd, 1) returned NULL\\n\", n); return false; }\n" + else id ,showString " return ptr;\n" ,showString "}\n" ,showString "static void free_instr(void *ptr) {\n" @@ -368,25 +390,33 @@ compileToString env expr = else id ,showString " free(ptr);\n" ,showString "}\n\n" + ,compose [showString str . showString "\n\n" | str <- toList (csTopLevelDecls s)] + ,showString $ - "static " ++ repSTy (typeOf expr) ++ " typed_kernel(" ++ - intercalate ", " (reverse (unSList (\(Product.Pair t (Const n)) -> repSTy t ++ " " ++ n) (slistZip env args))) ++ + "static bool typed_kernel(" ++ + repSTy (typeOf expr) ++ " *output" ++ + concatMap (", " ++) + (reverse (unSList (\(Product.Pair t (Const n)) -> repSTy t ++ " " ++ n) (slistZip env args))) ++ ") {\n" ,compose [showString " " . printStmt 1 st . showString "\n" | st <- toList (csStmts s)] - ,showString " return " . printCExpr 0 res . showString ";\n}\n\n" + ,showString " *output = " . printCExpr 0 res . showString ";\n" + ,showString " return true;\n" + ,showString "}\n\n" + ,showString "void kernel(void *data) {\n" -- Some code here assumes that we're on a 64-bit system, so let's check that - ,showString " if (sizeof(void*) != 8 || sizeof(size_t) != 8) { fprintf(stderr, \"Only 64-bit systems supported\\n\"); abort(); }\n" + ,showString $ " if (sizeof(void*) != 8 || sizeof(size_t) != 8) { fprintf(stderr, \"Only 64-bit systems supported\\n\"); *(uint8_t*)(data + " ++ show okres_offset ++ ") = 0; return; }\n" ,if debugRefc then showString " fprintf(stderr, \"[chad-kernel] Start\\n\");\n" else id - ,showString $ " *(" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ") = typed_kernel(" ++ - concat (map (\((arg, typ), off, idx) -> - "\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")" - ++ (if idx < length arg_pairs - 1 then "," else "") - ++ " // " ++ arg) - (zip3 arg_pairs arg_offsets [0::Int ..])) ++ + ,showString $ " const bool success = typed_kernel(" ++ + "\n (" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ")" ++ + concat (map (\((arg, typ), off) -> + ",\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")" + ++ " /* " ++ arg ++ " */") + (zip arg_pairs arg_offsets)) ++ "\n );\n" + ,showString $ " *(uint8_t*)(data + " ++ show okres_offset ++ ") = success;\n" ,if debugRefc then showString " fprintf(stderr, \"[chad-kernel] Return\\n\");\n" else id ,showString "}\n"] @@ -841,7 +871,7 @@ compile' env = \case (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".buf->sh[" ++ show i ++ "]"))))) (pure $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] CHECK: index out of range (arr=%p)\\n\", " ++ - arrname ++ ".buf); abort();") + arrname ++ ".buf); return false;") mempty resname <- genName' "ixres" @@ -998,7 +1028,7 @@ compile' env = \case concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++ concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.b")] ++ "); " ++ - "abort();") + "return false;") mempty accumRef t' prj' (v++".j.buf->xs[" ++ printCExpr 0 (toLinearIdx n (v++".j") (i++".a.a")) "]") (i++".b") @@ -1066,7 +1096,7 @@ compile' env = \case escape = concatMap $ \c -> if | c `elem` "\"\\" -> ['\\',c] | ord c < 32 -> "\\x" ++ padleft 2 '0' (showHex (ord c) "") | otherwise -> [c] - emit $ SVerbatim $ "fputs(\"ERROR: " ++ escape s ++ "\\n\", stderr); exit(1);" + emit $ SVerbatim $ "fputs(\"ERROR: " ++ escape s ++ "\\n\", stderr); return false;" case t of STScal _ -> return (CELit "0") _ -> do @@ -1316,7 +1346,7 @@ compileExtremum nameBase opName operator env e = do emit $ SVarDecl True (repSTy tIx) lenname (CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")) - emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); abort(); }" + emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); return false; }" ivar <- genName' "i" jvar <- genName' "j" @@ -1477,7 +1507,7 @@ zeroRefcountCheck toptyp opname topvar = let s1 = SVerbatim $ "if (__builtin_expect(" ++ path ++ ".buf->refc == 0, 0)) { " ++ "fprintf(stderr, \"[chad-kernel] CHECK: '" ++ opname ++ "' got array " ++ - "%p with refc=0\\n\", " ++ path ++ ".buf); abort(); }" + "%p with refc=0\\n\", " ++ path ++ ".buf); return false; }" let s2 = SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n path) let s3 = SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) ss return (BList [s1, s2, s3]) -- cgit v1.2.3-70-g09d2