diff options
Diffstat (limited to 'src/Compile.hs')
| -rw-r--r-- | src/Compile.hs | 65 | 
1 files changed, 58 insertions, 7 deletions
| diff --git a/src/Compile.hs b/src/Compile.hs index 302c750..3cc8934 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -1,10 +1,11 @@  {-# LANGUAGE DataKinds #-}  {-# LANGUAGE GADTs #-}  {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MagicHash #-}  {-# LANGUAGE MultiWayIf #-}  {-# LANGUAGE PolyKinds #-}  {-# LANGUAGE TypeApplications #-} -module Compile (compile) where +module Compile (compile, debugCSource, debugRefc, emitChecks) where  import Control.Monad (forM_, when, replicateM)  import Control.Monad.Trans.Class (lift) @@ -24,6 +25,9 @@ import Data.Set (Set)  import Data.Some  import qualified Data.Vector as V  import Foreign +import GHC.Exts (int2Word#, addr2Int#) +import GHC.Num (integerFromWord#) +import GHC.Ptr (Ptr(..))  import Numeric (showHex)  import System.IO (hPutStrLn, stderr) @@ -51,14 +55,19 @@ let array = arrayGenerate (ShNil `ShCons` 10) (\(IxNil `IxCons` i) -> fromIntegr  -- TODO: test that I'm properly incrementing and decrementing refcounts in all required places -debug :: Bool -debug = toEnum 0 +debugCSource, debugRefc, emitChecks :: Bool +-- | Print the generated C source +debugCSource = toEnum 0 +-- | Print extra stuff about reference counts of arrays +debugRefc = toEnum 1 +-- | Emit extra C code that checks stuff +emitChecks = toEnum 1  compile :: SList STy env -> Ex env t          -> IO (SList Value env -> IO (Rep t))  compile = \env expr -> do    let source = compileToString env expr -  when debug $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ source ++ "\x1B[0m>>>" +  when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ source ++ "\x1B[0m>>>"    lib <- buildKernel source ["kernel"]    let arg_metrics = reverse (unSList metricsSTy env) @@ -337,6 +346,8 @@ compileToString env expr =         ,showString "void kernel(void *data) {\n"          -- Some code here assumes that we're on a 64-bit system, so let's check that         ,showString "  if (sizeof(void*) != 8 || sizeof(size_t) != 8) { fprintf(stderr, \"Only 64-bit systems supported\\n\"); abort(); }\n" +       ,if debugRefc then showString "  fprintf(stderr, \"[chad-kernel] Start\\n\");\n" +                     else id         ,showString $ "  *(" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ") = typed_kernel(" ++                         concat (map (\((arg, typ), off, idx) ->                                         "\n    *(" ++ typ ++ "*)(data + " ++ show off ++ ")" @@ -344,6 +355,8 @@ compileToString env expr =                                         ++ "  // " ++ arg)                                     (zip3 arg_pairs arg_offsets [0::Int ..])) ++                         "\n  );\n" +       ,if debugRefc then showString "  fprintf(stderr, \"[chad-kernel] Return\\n\");\n" +                     else id         ,showString "}\n"]  -- | Takes list of metrics (alignment, sizeof). @@ -381,6 +394,8 @@ serialise topty topval ptr off k =      (STArr n t, Array sh vec) -> do        let eltsz = sizeofSTy t        allocaBytes (fromSNat n * 8 + 8 + shapeSize sh * eltsz) $ \bufptr -> do +        when debugRefc $ +          hPutStrLn stderr $ "[chad-serialise] Allocating input buffer " ++ showPtr bufptr          pokeByteOff ptr off bufptr          pokeShape bufptr 0 n sh @@ -425,6 +440,8 @@ deserialise topty ptr off =        bufptr <- peekByteOff @(Ptr ()) ptr off        sh <- peekShape bufptr 0 n        refc <- peekByteOff @Word64 bufptr (8 * fromSNat n) +      when debugRefc $ +        hPutStrLn stderr $ "[chad-deserialise] Got buffer " ++ showPtr bufptr ++ " at refc=" ++ show refc        let off1 = 8 * fromSNat n + 8            eltsz = sizeofSTy t        arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (off1 + i * eltsz)) @@ -638,6 +655,9 @@ compile' env = \case      x0name <- compileAssign "foldx0" env ex0      arrname <- compileAssign "foldarr" env earr +    when emitChecks $ +      emit $ SVerbatim $ "if (__builtin_expect(" ++ arrname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: fold1i got array %p with refc=0\\n\", " ++ arrname ++ ".buf); abort(); }" +      shszname <- genName' "shsz"      -- This n is one less than the shape of the thing we're querying, which is      -- unexpected. But it's exactly what we want, so we do it anyway. @@ -681,6 +701,9 @@ compile' env = \case      let STArr (SS n) t = typeOf e      argname <- compileAssign "sumarg" env e +    when emitChecks $ +      emit $ SVerbatim $ "if (__builtin_expect(" ++ argname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: sum1i got array %p with refc=0\\n\", " ++ argname ++ ".buf); abort(); }" +      shszname <- genName' "shsz"      -- This n is one less than the shape of the thing we're querying, like EFold1Inner.      emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) @@ -722,6 +745,9 @@ compile' env = \case      lenname <- compileAssign "replen" env elen      argname <- compileAssign "reparg" env earg +    when emitChecks $ +      emit $ SVerbatim $ "if (__builtin_expect(" ++ argname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: replicate1i got array %p with refc=0\\n\", " ++ argname ++ ".buf); abort(); }" +      shszname <- genName' "shsz"      emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) @@ -750,6 +776,8 @@ compile' env = \case    EIdx0 _ e -> do      let STArr _ t = typeOf e      arrname <- compileAssign "" env e +    when emitChecks $ +      emit $ SVerbatim $ "if (__builtin_expect(" ++ arrname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: idx0 got array %p with refc=0\\n\", " ++ arrname ++ ".buf); abort(); }"      name <- genName      emit $ SVarDecl True (repSTy t) name                (CEIndex (CEPtrProj (CEProj (CELit arrname) "buf") "xs") (CELit "0")) @@ -761,6 +789,8 @@ compile' env = \case    EIdx _ earr eidx -> do      let STArr n t = typeOf earr      arrname <- compileAssign "ixarr" env earr +    when emitChecks $ +      emit $ SVerbatim $ "if (__builtin_expect(" ++ arrname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: idx got array %p with refc=0\\n\", " ++ arrname ++ ".buf); abort(); }"      idxname <- if fromSNat n > 0  -- prevent an unused-varable warning                   then compileAssign "ixix" env eidx                   else return ""  -- won't be used in this case @@ -774,6 +804,8 @@ compile' env = \case          t = tTup (sreplicate n tIx)      _ <- emitStruct t      name <- compileAssign "" env e +    when emitChecks $ +      emit $ SVerbatim $ "if (__builtin_expect(" ++ name ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: shape got array %p with refc=0\\n\", " ++ name ++ ".buf); abort(); }"      resname <- genName      emit $ SVarDecl True (repSTy t) resname (compileShapeQuery n name)      incrementVarAlways Decrement (typeOf e) name @@ -803,6 +835,9 @@ compile' env = \case      actyname <- emitStruct (STAccum t)      name1 <- compileAssign "" env e1 +    when emitChecks $ +      emit $ SVerbatim $ "if (__builtin_expect(" ++ name1 ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: with got array %p with refc=0\\n\", " ++ name1 ++ ".buf); abort(); }" +      mcopy <- copyForWriting t name1      accname <- genName' "accum"      emit $ SVarDecl False actyname accname (CEStruct actyname [("ac", maybe (CELit name1) id mcopy)]) @@ -936,11 +971,19 @@ makeArrayTree (STAccum _) = ATNoop  incrementVar' :: Increment -> String -> ArrayTree -> CompM ()  incrementVar' inc path (ATArray (Some n) (Some eltty)) =    case inc of -    Increment -> emit $ SVerbatim (path ++ ".buf->refc++;") -    Decrement -> +    Increment -> do +      emit $ SVerbatim (path ++ ".buf->refc++;") +      when debugRefc $ +        emit $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] arr %p in+ -> %zu\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc);" +    Decrement -> do        case incrementVar Decrement eltty of -        Nothing -> emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) free(" ++ path ++ ".buf);" +        Nothing -> do +          when debugRefc $ +            emit $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] arr %p de- -> %zu\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);" +          emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) { fprintf(stderr, \"; free(\"); free(" ++ path ++ ".buf); fprintf(stderr, \") ok\\n\"); } else fprintf(stderr, \"\\n\");"          Just f -> do +          when debugRefc $ +            emit $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] arr %p de- -> %zu recfree\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);"            shszvar <- genName' "frshsz"            ivar <- genName' "i"            ((), eltDecrStmts) <- scope $ f (path ++ ".buf->xs[" ++ ivar ++ "]") @@ -988,6 +1031,8 @@ allocArray nameBase rank eltty shsz shape = do    forM_ (zip shape [0::Int ..]) $ \(dim, i) ->      emit $ SAsg (arrname ++ ".buf->sh[" ++ show i ++ "]") dim    emit $ SAsg (arrname ++ ".buf->refc") (CELit "1") +  when debugRefc $ +    emit $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] arr %p allocated\\n\", " ++ arrname ++ ".buf);"    return arrname  compileShapeQuery :: SNat n -> String -> CExpr @@ -1085,6 +1130,9 @@ compileExtremum nameBase opName operator env e = do    let STArr (SS n) t = typeOf e    argname <- compileAssign (nameBase ++ "arg") env e +  when emitChecks $ +    emit $ SVerbatim $ "if (__builtin_expect(" ++ argname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: " ++ opName ++ " got array %p with refc=0\\n\", " ++ argname ++ ".buf); abort(); }" +    shszname <- genName' "shsz"    -- This n is one less than the shape of the thing we're querying, which is    -- unexpected. But it's exactly what we want, so we do it anyway. @@ -1227,6 +1275,9 @@ copyForWriting topty var = case topty of  compose :: Foldable t => t (a -> a) -> a -> a  compose = foldr (.) id +showPtr :: Ptr a -> String +showPtr (Ptr a) = "0x" ++ showHex (integerFromWord# (int2Word# (addr2Int# a))) "" +  -- | Type-restricted.  (^) :: Num a => a -> Int -> a  (^) = (Prelude.^) | 
