diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-22 11:04:06 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-22 11:04:06 +0100 |
commit | b87518c60f3034411bffc0c4745141db6a8d81d3 (patch) | |
tree | bc83659bfbf5022c90a97cfec85decdea9883445 | |
parent | 0a3e53d7b40d2009aca66d2cafd555c2b1d858bb (diff) |
Compile: More debugging machinery
-rw-r--r-- | src/Compile.hs | 65 |
1 files changed, 58 insertions, 7 deletions
diff --git a/src/Compile.hs b/src/Compile.hs index 302c750..3cc8934 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -1,10 +1,11 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MagicHash #-} {-# LANGUAGE MultiWayIf #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TypeApplications #-} -module Compile (compile) where +module Compile (compile, debugCSource, debugRefc, emitChecks) where import Control.Monad (forM_, when, replicateM) import Control.Monad.Trans.Class (lift) @@ -24,6 +25,9 @@ import Data.Set (Set) import Data.Some import qualified Data.Vector as V import Foreign +import GHC.Exts (int2Word#, addr2Int#) +import GHC.Num (integerFromWord#) +import GHC.Ptr (Ptr(..)) import Numeric (showHex) import System.IO (hPutStrLn, stderr) @@ -51,14 +55,19 @@ let array = arrayGenerate (ShNil `ShCons` 10) (\(IxNil `IxCons` i) -> fromIntegr -- TODO: test that I'm properly incrementing and decrementing refcounts in all required places -debug :: Bool -debug = toEnum 0 +debugCSource, debugRefc, emitChecks :: Bool +-- | Print the generated C source +debugCSource = toEnum 0 +-- | Print extra stuff about reference counts of arrays +debugRefc = toEnum 1 +-- | Emit extra C code that checks stuff +emitChecks = toEnum 1 compile :: SList STy env -> Ex env t -> IO (SList Value env -> IO (Rep t)) compile = \env expr -> do let source = compileToString env expr - when debug $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ source ++ "\x1B[0m>>>" + when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ source ++ "\x1B[0m>>>" lib <- buildKernel source ["kernel"] let arg_metrics = reverse (unSList metricsSTy env) @@ -337,6 +346,8 @@ compileToString env expr = ,showString "void kernel(void *data) {\n" -- Some code here assumes that we're on a 64-bit system, so let's check that ,showString " if (sizeof(void*) != 8 || sizeof(size_t) != 8) { fprintf(stderr, \"Only 64-bit systems supported\\n\"); abort(); }\n" + ,if debugRefc then showString " fprintf(stderr, \"[chad-kernel] Start\\n\");\n" + else id ,showString $ " *(" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ") = typed_kernel(" ++ concat (map (\((arg, typ), off, idx) -> "\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")" @@ -344,6 +355,8 @@ compileToString env expr = ++ " // " ++ arg) (zip3 arg_pairs arg_offsets [0::Int ..])) ++ "\n );\n" + ,if debugRefc then showString " fprintf(stderr, \"[chad-kernel] Return\\n\");\n" + else id ,showString "}\n"] -- | Takes list of metrics (alignment, sizeof). @@ -381,6 +394,8 @@ serialise topty topval ptr off k = (STArr n t, Array sh vec) -> do let eltsz = sizeofSTy t allocaBytes (fromSNat n * 8 + 8 + shapeSize sh * eltsz) $ \bufptr -> do + when debugRefc $ + hPutStrLn stderr $ "[chad-serialise] Allocating input buffer " ++ showPtr bufptr pokeByteOff ptr off bufptr pokeShape bufptr 0 n sh @@ -425,6 +440,8 @@ deserialise topty ptr off = bufptr <- peekByteOff @(Ptr ()) ptr off sh <- peekShape bufptr 0 n refc <- peekByteOff @Word64 bufptr (8 * fromSNat n) + when debugRefc $ + hPutStrLn stderr $ "[chad-deserialise] Got buffer " ++ showPtr bufptr ++ " at refc=" ++ show refc let off1 = 8 * fromSNat n + 8 eltsz = sizeofSTy t arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (off1 + i * eltsz)) @@ -638,6 +655,9 @@ compile' env = \case x0name <- compileAssign "foldx0" env ex0 arrname <- compileAssign "foldarr" env earr + when emitChecks $ + emit $ SVerbatim $ "if (__builtin_expect(" ++ arrname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: fold1i got array %p with refc=0\\n\", " ++ arrname ++ ".buf); abort(); }" + shszname <- genName' "shsz" -- This n is one less than the shape of the thing we're querying, which is -- unexpected. But it's exactly what we want, so we do it anyway. @@ -681,6 +701,9 @@ compile' env = \case let STArr (SS n) t = typeOf e argname <- compileAssign "sumarg" env e + when emitChecks $ + emit $ SVerbatim $ "if (__builtin_expect(" ++ argname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: sum1i got array %p with refc=0\\n\", " ++ argname ++ ".buf); abort(); }" + shszname <- genName' "shsz" -- This n is one less than the shape of the thing we're querying, like EFold1Inner. emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) @@ -722,6 +745,9 @@ compile' env = \case lenname <- compileAssign "replen" env elen argname <- compileAssign "reparg" env earg + when emitChecks $ + emit $ SVerbatim $ "if (__builtin_expect(" ++ argname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: replicate1i got array %p with refc=0\\n\", " ++ argname ++ ".buf); abort(); }" + shszname <- genName' "shsz" emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) @@ -750,6 +776,8 @@ compile' env = \case EIdx0 _ e -> do let STArr _ t = typeOf e arrname <- compileAssign "" env e + when emitChecks $ + emit $ SVerbatim $ "if (__builtin_expect(" ++ arrname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: idx0 got array %p with refc=0\\n\", " ++ arrname ++ ".buf); abort(); }" name <- genName emit $ SVarDecl True (repSTy t) name (CEIndex (CEPtrProj (CEProj (CELit arrname) "buf") "xs") (CELit "0")) @@ -761,6 +789,8 @@ compile' env = \case EIdx _ earr eidx -> do let STArr n t = typeOf earr arrname <- compileAssign "ixarr" env earr + when emitChecks $ + emit $ SVerbatim $ "if (__builtin_expect(" ++ arrname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: idx got array %p with refc=0\\n\", " ++ arrname ++ ".buf); abort(); }" idxname <- if fromSNat n > 0 -- prevent an unused-varable warning then compileAssign "ixix" env eidx else return "" -- won't be used in this case @@ -774,6 +804,8 @@ compile' env = \case t = tTup (sreplicate n tIx) _ <- emitStruct t name <- compileAssign "" env e + when emitChecks $ + emit $ SVerbatim $ "if (__builtin_expect(" ++ name ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: shape got array %p with refc=0\\n\", " ++ name ++ ".buf); abort(); }" resname <- genName emit $ SVarDecl True (repSTy t) resname (compileShapeQuery n name) incrementVarAlways Decrement (typeOf e) name @@ -803,6 +835,9 @@ compile' env = \case actyname <- emitStruct (STAccum t) name1 <- compileAssign "" env e1 + when emitChecks $ + emit $ SVerbatim $ "if (__builtin_expect(" ++ name1 ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: with got array %p with refc=0\\n\", " ++ name1 ++ ".buf); abort(); }" + mcopy <- copyForWriting t name1 accname <- genName' "accum" emit $ SVarDecl False actyname accname (CEStruct actyname [("ac", maybe (CELit name1) id mcopy)]) @@ -936,11 +971,19 @@ makeArrayTree (STAccum _) = ATNoop incrementVar' :: Increment -> String -> ArrayTree -> CompM () incrementVar' inc path (ATArray (Some n) (Some eltty)) = case inc of - Increment -> emit $ SVerbatim (path ++ ".buf->refc++;") - Decrement -> + Increment -> do + emit $ SVerbatim (path ++ ".buf->refc++;") + when debugRefc $ + emit $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] arr %p in+ -> %zu\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc);" + Decrement -> do case incrementVar Decrement eltty of - Nothing -> emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) free(" ++ path ++ ".buf);" + Nothing -> do + when debugRefc $ + emit $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] arr %p de- -> %zu\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);" + emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) { fprintf(stderr, \"; free(\"); free(" ++ path ++ ".buf); fprintf(stderr, \") ok\\n\"); } else fprintf(stderr, \"\\n\");" Just f -> do + when debugRefc $ + emit $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] arr %p de- -> %zu recfree\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);" shszvar <- genName' "frshsz" ivar <- genName' "i" ((), eltDecrStmts) <- scope $ f (path ++ ".buf->xs[" ++ ivar ++ "]") @@ -988,6 +1031,8 @@ allocArray nameBase rank eltty shsz shape = do forM_ (zip shape [0::Int ..]) $ \(dim, i) -> emit $ SAsg (arrname ++ ".buf->sh[" ++ show i ++ "]") dim emit $ SAsg (arrname ++ ".buf->refc") (CELit "1") + when debugRefc $ + emit $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] arr %p allocated\\n\", " ++ arrname ++ ".buf);" return arrname compileShapeQuery :: SNat n -> String -> CExpr @@ -1085,6 +1130,9 @@ compileExtremum nameBase opName operator env e = do let STArr (SS n) t = typeOf e argname <- compileAssign (nameBase ++ "arg") env e + when emitChecks $ + emit $ SVerbatim $ "if (__builtin_expect(" ++ argname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: " ++ opName ++ " got array %p with refc=0\\n\", " ++ argname ++ ".buf); abort(); }" + shszname <- genName' "shsz" -- This n is one less than the shape of the thing we're querying, which is -- unexpected. But it's exactly what we want, so we do it anyway. @@ -1227,6 +1275,9 @@ copyForWriting topty var = case topty of compose :: Foldable t => t (a -> a) -> a -> a compose = foldr (.) id +showPtr :: Ptr a -> String +showPtr (Ptr a) = "0x" ++ showHex (integerFromWord# (int2Word# (addr2Int# a))) "" + -- | Type-restricted. (^) :: Num a => a -> Int -> a (^) = (Prelude.^) |