diff options
Diffstat (limited to 'src/Compile.hs')
| -rw-r--r-- | src/Compile.hs | 88 | 
1 files changed, 59 insertions, 29 deletions
| diff --git a/src/Compile.hs b/src/Compile.hs index 6466065..aa23797 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -6,6 +6,7 @@  {-# LANGUAGE MagicHash #-}  {-# LANGUAGE MultiWayIf #-}  {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TupleSections #-}  {-# LANGUAGE TypeApplications #-}  module Compile (compile) where @@ -34,6 +35,7 @@ import GHC.Num (integerFromWord#)  import GHC.Ptr (Ptr(..))  import Numeric (showHex)  import System.IO (hPutStrLn, stderr) +import System.IO.Error (mkIOError, userErrorType)  import Prelude hiding ((^))  import qualified Prelude @@ -64,27 +66,28 @@ debugShapes :: Bool; debugShapes = toEnum 0  -- | Print information on allocation  debugAllocs :: Bool; debugAllocs = toEnum 0  -- | Emit extra C code that checks stuff -emitChecks :: Bool; emitChecks = toEnum 0 +emitChecks :: Bool; emitChecks = toEnum 1  compile :: SList STy env -> Ex env t          -> IO (SList Value env -> IO (Rep t))  compile = \env expr -> do -  let source = compileToString env expr +  let (source, offsets) = compileToString env expr    when debugPrintAST $ hPutStrLn stderr $ "Compiled AST: <<<\n" ++ ppExpr env expr ++ "\n>>>"    when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ lineNumbers source ++ "\x1B[0m>>>"    lib <- buildKernel source ["kernel"] -  let arg_metrics = reverse (unSList metricsSTy env) -      (arg_offsets, result_offset) = computeStructOffsets arg_metrics -      result_type = typeOf expr +  let result_type = typeOf expr        result_size = sizeofSTy result_type    return $ \val -> do -    allocaBytes (result_offset + result_size) $ \ptr -> do -      let args = zip (reverse (unSList Some (slistZip env val))) arg_offsets +    allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do +      let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets)        serialiseArguments args ptr $ do          callKernelFun "kernel" lib ptr -        deserialise result_type ptr result_offset +        ok <- peekByteOff @Word8 ptr (koOkResOffset offsets) +        when (ok /= 1) $ +          ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing) +        deserialise result_type ptr (koResultOffset offsets)    where      serialiseArguments :: [(Some (Product STy Value), Int)] -> Ptr () -> IO r -> IO r      serialiseArguments ((Some (Product.Pair t (Value arg)), off) : args) ptr k = @@ -331,7 +334,13 @@ emitTLD decl = CompM $ modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <>  nameEnv :: SList f env -> SList (Const String) env  nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1)) -compileToString :: SList STy env -> Ex env t -> String +data KernelOffsets = KernelOffsets +  { koArgOffsets :: [Int]  -- ^ the function arguments +  , koOkResOffset :: Int  -- ^ a byte: 1 if successful execution, 0 if (fatal) error occurred +  , koResultOffset :: Int  -- ^ the function result +  } + +compileToString :: SList STy env -> Ex env t -> (String, KernelOffsets)  compileToString env expr =    let args = nameEnv env        (res, s) = runCompM (compile' args expr) @@ -340,27 +349,40 @@ compileToString env expr =        (arg_pairs, arg_metrics) =          unzip $ reverse (unSList (\(Product.Pair t (Const n)) -> ((n, repSTy t), metricsSTy t))                                   (slistZip env args)) -      (arg_offsets, result_offset') = computeStructOffsets arg_metrics -      result_offset = align (alignmentSTy (typeOf expr)) result_offset' -  in ($ "") $ compose +      (arg_offsets, okres_offset) = computeStructOffsets arg_metrics +      result_offset = align (alignmentSTy (typeOf expr)) (okres_offset + 1) + +      offsets = KernelOffsets +        { koArgOffsets = arg_offsets +        , koOkResOffset = okres_offset +        , koResultOffset = result_offset } +  in (,offsets) . ($ "") $ compose         [showString "#include <stdio.h>\n"         ,showString "#include <stdint.h>\n" +       ,showString "#include <stdbool.h>\n"         ,showString "#include <inttypes.h>\n"         ,showString "#include <stdlib.h>\n"         ,showString "#include <string.h>\n"         ,showString "#include <math.h>\n\n" +         ,compose [printStructDecl sd . showString "\n" | sd <- structs]         ,showString "\n" + +       -- Using %zd and not %zu here because values > SIZET_MAX/2 should be recognisable as "negative"         ,showString "static void* malloc_instr(size_t n) {\n"         ,showString "  void *ptr = malloc(n);\n" -       ,if debugAllocs then showString "printf(\"[chad-kernel] malloc(%zu) -> %p\\n\", n, ptr);\n" +       ,if debugAllocs then showString "  printf(\"[chad-kernel] malloc(%zd) -> %p\\n\", n, ptr);\n"                         else id +       ,if emitChecks then showString "  if (ptr == NULL) { printf(\"[chad-kernel] malloc(%zd) returned NULL\\n\", n); return false; }\n" +                      else id         ,showString "  return ptr;\n"         ,showString "}\n"         ,showString "static void* calloc_instr(size_t n) {\n"         ,showString "  void *ptr = calloc(n, 1);\n" -       ,if debugAllocs then showString "printf(\"[chad-kernel] calloc(%zu) -> %p\\n\", n, ptr);\n" +       ,if debugAllocs then showString "  printf(\"[chad-kernel] calloc(%zd) -> %p\\n\", n, ptr);\n"                         else id +       ,if emitChecks then showString "  if (ptr == NULL) { printf(\"[chad-kernel] calloc(%zd, 1) returned NULL\\n\", n); return false; }\n" +                      else id         ,showString "  return ptr;\n"         ,showString "}\n"         ,showString "static void free_instr(void *ptr) {\n" @@ -368,25 +390,33 @@ compileToString env expr =                         else id         ,showString "  free(ptr);\n"         ,showString "}\n\n" +         ,compose [showString str . showString "\n\n" | str <- toList (csTopLevelDecls s)] +         ,showString $ -          "static " ++ repSTy (typeOf expr) ++ " typed_kernel(" ++ -            intercalate ", " (reverse (unSList (\(Product.Pair t (Const n)) -> repSTy t ++ " " ++ n) (slistZip env args))) ++ +          "static bool typed_kernel(" ++ +            repSTy (typeOf expr) ++ " *output" ++ +            concatMap (", " ++) +              (reverse (unSList (\(Product.Pair t (Const n)) -> repSTy t ++ " " ++ n) (slistZip env args))) ++              ") {\n"         ,compose [showString "  " . printStmt 1 st . showString "\n" | st <- toList (csStmts s)] -       ,showString "  return " . printCExpr 0 res . showString ";\n}\n\n" +       ,showString "  *output = " . printCExpr 0 res . showString ";\n" +       ,showString "  return true;\n" +       ,showString "}\n\n" +         ,showString "void kernel(void *data) {\n"          -- Some code here assumes that we're on a 64-bit system, so let's check that -       ,showString "  if (sizeof(void*) != 8 || sizeof(size_t) != 8) { fprintf(stderr, \"Only 64-bit systems supported\\n\"); abort(); }\n" +       ,showString $ "  if (sizeof(void*) != 8 || sizeof(size_t) != 8) { fprintf(stderr, \"Only 64-bit systems supported\\n\"); *(uint8_t*)(data + " ++ show okres_offset ++ ") = 0; return; }\n"         ,if debugRefc then showString "  fprintf(stderr, \"[chad-kernel] Start\\n\");\n"                       else id -       ,showString $ "  *(" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ") = typed_kernel(" ++ -                       concat (map (\((arg, typ), off, idx) -> -                                       "\n    *(" ++ typ ++ "*)(data + " ++ show off ++ ")" -                                       ++ (if idx < length arg_pairs - 1 then "," else "") -                                       ++ "  // " ++ arg) -                                   (zip3 arg_pairs arg_offsets [0::Int ..])) ++ +       ,showString $ "  const bool success = typed_kernel(" ++ +                       "\n    (" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ")" ++ +                       concat (map (\((arg, typ), off) -> +                                       ",\n    *(" ++ typ ++ "*)(data + " ++ show off ++ ")" +                                       ++ "  /* " ++ arg ++ " */") +                                   (zip arg_pairs arg_offsets)) ++                         "\n  );\n" +       ,showString $ "  *(uint8_t*)(data + " ++ show okres_offset ++ ") = success;\n"         ,if debugRefc then showString "  fprintf(stderr, \"[chad-kernel] Return\\n\");\n"                       else id         ,showString "}\n"] @@ -841,7 +871,7 @@ compile' env = \case                              (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".buf->sh[" ++ show i ++ "]")))))                   (pure $ SVerbatim $                      "fprintf(stderr, \"[chad-kernel] CHECK: index out of range (arr=%p)\\n\", " ++ -                      arrname ++ ".buf); abort();") +                      arrname ++ ".buf); return false;")                   mempty      resname <- genName' "ixres" @@ -998,7 +1028,7 @@ compile' env = \case                            concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++                            concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.b")] ++                            "); " ++ -                          "abort();") +                          "return false;")                         mempty            accumRef t' prj' (v++".j.buf->xs[" ++ printCExpr 0 (toLinearIdx n (v++".j") (i++".a.a")) "]") (i++".b") @@ -1066,7 +1096,7 @@ compile' env = \case          escape = concatMap $ \c -> if | c `elem` "\"\\" -> ['\\',c]                                        | ord c < 32      -> "\\x" ++ padleft 2 '0' (showHex (ord c) "")                                        | otherwise       -> [c] -    emit $ SVerbatim $ "fputs(\"ERROR: " ++ escape s ++ "\\n\", stderr); exit(1);" +    emit $ SVerbatim $ "fputs(\"ERROR: " ++ escape s ++ "\\n\", stderr); return false;"      case t of        STScal _ -> return (CELit "0")        _ -> do @@ -1316,7 +1346,7 @@ compileExtremum nameBase opName operator env e = do    emit $ SVarDecl True (repSTy tIx) lenname                    (CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")) -  emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); abort(); }" +  emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); return false; }"    ivar <- genName' "i"    jvar <- genName' "j" @@ -1477,7 +1507,7 @@ zeroRefcountCheck toptyp opname topvar =        let s1 = SVerbatim $                   "if (__builtin_expect(" ++ path ++ ".buf->refc == 0, 0)) { " ++                   "fprintf(stderr, \"[chad-kernel] CHECK: '" ++ opname ++ "' got array " ++ -                 "%p with refc=0\\n\", " ++ path ++ ".buf); abort(); }" +                 "%p with refc=0\\n\", " ++ path ++ ".buf); return false; }"        let s2 = SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n path)        let s3 = SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) ss        return (BList [s1, s2, s3]) | 
