summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-22 11:04:06 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-22 11:04:06 +0100
commitb87518c60f3034411bffc0c4745141db6a8d81d3 (patch)
treebc83659bfbf5022c90a97cfec85decdea9883445
parent0a3e53d7b40d2009aca66d2cafd555c2b1d858bb (diff)
Compile: More debugging machinery
-rw-r--r--src/Compile.hs65
1 files changed, 58 insertions, 7 deletions
diff --git a/src/Compile.hs b/src/Compile.hs
index 302c750..3cc8934 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -1,10 +1,11 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeApplications #-}
-module Compile (compile) where
+module Compile (compile, debugCSource, debugRefc, emitChecks) where
import Control.Monad (forM_, when, replicateM)
import Control.Monad.Trans.Class (lift)
@@ -24,6 +25,9 @@ import Data.Set (Set)
import Data.Some
import qualified Data.Vector as V
import Foreign
+import GHC.Exts (int2Word#, addr2Int#)
+import GHC.Num (integerFromWord#)
+import GHC.Ptr (Ptr(..))
import Numeric (showHex)
import System.IO (hPutStrLn, stderr)
@@ -51,14 +55,19 @@ let array = arrayGenerate (ShNil `ShCons` 10) (\(IxNil `IxCons` i) -> fromIntegr
-- TODO: test that I'm properly incrementing and decrementing refcounts in all required places
-debug :: Bool
-debug = toEnum 0
+debugCSource, debugRefc, emitChecks :: Bool
+-- | Print the generated C source
+debugCSource = toEnum 0
+-- | Print extra stuff about reference counts of arrays
+debugRefc = toEnum 1
+-- | Emit extra C code that checks stuff
+emitChecks = toEnum 1
compile :: SList STy env -> Ex env t
-> IO (SList Value env -> IO (Rep t))
compile = \env expr -> do
let source = compileToString env expr
- when debug $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ source ++ "\x1B[0m>>>"
+ when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ source ++ "\x1B[0m>>>"
lib <- buildKernel source ["kernel"]
let arg_metrics = reverse (unSList metricsSTy env)
@@ -337,6 +346,8 @@ compileToString env expr =
,showString "void kernel(void *data) {\n"
-- Some code here assumes that we're on a 64-bit system, so let's check that
,showString " if (sizeof(void*) != 8 || sizeof(size_t) != 8) { fprintf(stderr, \"Only 64-bit systems supported\\n\"); abort(); }\n"
+ ,if debugRefc then showString " fprintf(stderr, \"[chad-kernel] Start\\n\");\n"
+ else id
,showString $ " *(" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ") = typed_kernel(" ++
concat (map (\((arg, typ), off, idx) ->
"\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")"
@@ -344,6 +355,8 @@ compileToString env expr =
++ " // " ++ arg)
(zip3 arg_pairs arg_offsets [0::Int ..])) ++
"\n );\n"
+ ,if debugRefc then showString " fprintf(stderr, \"[chad-kernel] Return\\n\");\n"
+ else id
,showString "}\n"]
-- | Takes list of metrics (alignment, sizeof).
@@ -381,6 +394,8 @@ serialise topty topval ptr off k =
(STArr n t, Array sh vec) -> do
let eltsz = sizeofSTy t
allocaBytes (fromSNat n * 8 + 8 + shapeSize sh * eltsz) $ \bufptr -> do
+ when debugRefc $
+ hPutStrLn stderr $ "[chad-serialise] Allocating input buffer " ++ showPtr bufptr
pokeByteOff ptr off bufptr
pokeShape bufptr 0 n sh
@@ -425,6 +440,8 @@ deserialise topty ptr off =
bufptr <- peekByteOff @(Ptr ()) ptr off
sh <- peekShape bufptr 0 n
refc <- peekByteOff @Word64 bufptr (8 * fromSNat n)
+ when debugRefc $
+ hPutStrLn stderr $ "[chad-deserialise] Got buffer " ++ showPtr bufptr ++ " at refc=" ++ show refc
let off1 = 8 * fromSNat n + 8
eltsz = sizeofSTy t
arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (off1 + i * eltsz))
@@ -638,6 +655,9 @@ compile' env = \case
x0name <- compileAssign "foldx0" env ex0
arrname <- compileAssign "foldarr" env earr
+ when emitChecks $
+ emit $ SVerbatim $ "if (__builtin_expect(" ++ arrname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: fold1i got array %p with refc=0\\n\", " ++ arrname ++ ".buf); abort(); }"
+
shszname <- genName' "shsz"
-- This n is one less than the shape of the thing we're querying, which is
-- unexpected. But it's exactly what we want, so we do it anyway.
@@ -681,6 +701,9 @@ compile' env = \case
let STArr (SS n) t = typeOf e
argname <- compileAssign "sumarg" env e
+ when emitChecks $
+ emit $ SVerbatim $ "if (__builtin_expect(" ++ argname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: sum1i got array %p with refc=0\\n\", " ++ argname ++ ".buf); abort(); }"
+
shszname <- genName' "shsz"
-- This n is one less than the shape of the thing we're querying, like EFold1Inner.
emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
@@ -722,6 +745,9 @@ compile' env = \case
lenname <- compileAssign "replen" env elen
argname <- compileAssign "reparg" env earg
+ when emitChecks $
+ emit $ SVerbatim $ "if (__builtin_expect(" ++ argname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: replicate1i got array %p with refc=0\\n\", " ++ argname ++ ".buf); abort(); }"
+
shszname <- genName' "shsz"
emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
@@ -750,6 +776,8 @@ compile' env = \case
EIdx0 _ e -> do
let STArr _ t = typeOf e
arrname <- compileAssign "" env e
+ when emitChecks $
+ emit $ SVerbatim $ "if (__builtin_expect(" ++ arrname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: idx0 got array %p with refc=0\\n\", " ++ arrname ++ ".buf); abort(); }"
name <- genName
emit $ SVarDecl True (repSTy t) name
(CEIndex (CEPtrProj (CEProj (CELit arrname) "buf") "xs") (CELit "0"))
@@ -761,6 +789,8 @@ compile' env = \case
EIdx _ earr eidx -> do
let STArr n t = typeOf earr
arrname <- compileAssign "ixarr" env earr
+ when emitChecks $
+ emit $ SVerbatim $ "if (__builtin_expect(" ++ arrname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: idx got array %p with refc=0\\n\", " ++ arrname ++ ".buf); abort(); }"
idxname <- if fromSNat n > 0 -- prevent an unused-varable warning
then compileAssign "ixix" env eidx
else return "" -- won't be used in this case
@@ -774,6 +804,8 @@ compile' env = \case
t = tTup (sreplicate n tIx)
_ <- emitStruct t
name <- compileAssign "" env e
+ when emitChecks $
+ emit $ SVerbatim $ "if (__builtin_expect(" ++ name ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: shape got array %p with refc=0\\n\", " ++ name ++ ".buf); abort(); }"
resname <- genName
emit $ SVarDecl True (repSTy t) resname (compileShapeQuery n name)
incrementVarAlways Decrement (typeOf e) name
@@ -803,6 +835,9 @@ compile' env = \case
actyname <- emitStruct (STAccum t)
name1 <- compileAssign "" env e1
+ when emitChecks $
+ emit $ SVerbatim $ "if (__builtin_expect(" ++ name1 ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: with got array %p with refc=0\\n\", " ++ name1 ++ ".buf); abort(); }"
+
mcopy <- copyForWriting t name1
accname <- genName' "accum"
emit $ SVarDecl False actyname accname (CEStruct actyname [("ac", maybe (CELit name1) id mcopy)])
@@ -936,11 +971,19 @@ makeArrayTree (STAccum _) = ATNoop
incrementVar' :: Increment -> String -> ArrayTree -> CompM ()
incrementVar' inc path (ATArray (Some n) (Some eltty)) =
case inc of
- Increment -> emit $ SVerbatim (path ++ ".buf->refc++;")
- Decrement ->
+ Increment -> do
+ emit $ SVerbatim (path ++ ".buf->refc++;")
+ when debugRefc $
+ emit $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] arr %p in+ -> %zu\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc);"
+ Decrement -> do
case incrementVar Decrement eltty of
- Nothing -> emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) free(" ++ path ++ ".buf);"
+ Nothing -> do
+ when debugRefc $
+ emit $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] arr %p de- -> %zu\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);"
+ emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) { fprintf(stderr, \"; free(\"); free(" ++ path ++ ".buf); fprintf(stderr, \") ok\\n\"); } else fprintf(stderr, \"\\n\");"
Just f -> do
+ when debugRefc $
+ emit $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] arr %p de- -> %zu recfree\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);"
shszvar <- genName' "frshsz"
ivar <- genName' "i"
((), eltDecrStmts) <- scope $ f (path ++ ".buf->xs[" ++ ivar ++ "]")
@@ -988,6 +1031,8 @@ allocArray nameBase rank eltty shsz shape = do
forM_ (zip shape [0::Int ..]) $ \(dim, i) ->
emit $ SAsg (arrname ++ ".buf->sh[" ++ show i ++ "]") dim
emit $ SAsg (arrname ++ ".buf->refc") (CELit "1")
+ when debugRefc $
+ emit $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] arr %p allocated\\n\", " ++ arrname ++ ".buf);"
return arrname
compileShapeQuery :: SNat n -> String -> CExpr
@@ -1085,6 +1130,9 @@ compileExtremum nameBase opName operator env e = do
let STArr (SS n) t = typeOf e
argname <- compileAssign (nameBase ++ "arg") env e
+ when emitChecks $
+ emit $ SVerbatim $ "if (__builtin_expect(" ++ argname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: " ++ opName ++ " got array %p with refc=0\\n\", " ++ argname ++ ".buf); abort(); }"
+
shszname <- genName' "shsz"
-- This n is one less than the shape of the thing we're querying, which is
-- unexpected. But it's exactly what we want, so we do it anyway.
@@ -1227,6 +1275,9 @@ copyForWriting topty var = case topty of
compose :: Foldable t => t (a -> a) -> a -> a
compose = foldr (.) id
+showPtr :: Ptr a -> String
+showPtr (Ptr a) = "0x" ++ showHex (integerFromWord# (int2Word# (addr2Int# a))) ""
+
-- | Type-restricted.
(^) :: Num a => a -> Int -> a
(^) = (Prelude.^)