summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-23 11:59:17 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-23 11:59:17 +0100
commitfa1906183e91a3f0fadd27a29375b860ac40e53c (patch)
tree6c18b3c064e2e2d5c6d8dd549d59303f434bc997
parentb87518c60f3034411bffc0c4745141db6a8d81d3 (diff)
Compile: More checkstest-compile-chad
-rw-r--r--src/Compile.hs123
1 files changed, 80 insertions, 43 deletions
diff --git a/src/Compile.hs b/src/Compile.hs
index 3cc8934..09c3ed5 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -42,26 +42,19 @@ import Data
import Interpreter.Rep
-{-
-:m *Example Compile AST.UnMonoid
-:seti -XOverloadedLabels -XGADTs
-let array = arrayGenerate (ShNil `ShCons` 10) (\(IxNil `IxCons` i) -> fromIntegral i :: Double) in (($ SCons (Value array) SNil) =<<) $ compile knownEnv $ fromNamed $ lambda @(TArr N1 (TScal TF64)) #x $ body $ #x ! pair nil (round_ (#x ! pair nil 3))
-(($ SNil) =<<) $ compile knownEnv $ fromNamed $ body $ build2 5 3 (#i :-> #j :-> 10 * #i + #j)
--}
-
-
-- In shape and index arrays, the innermost dimension is on the right (last index).
-- TODO: test that I'm properly incrementing and decrementing refcounts in all required places
-debugCSource, debugRefc, emitChecks :: Bool
-- | Print the generated C source
-debugCSource = toEnum 0
+debugCSource :: Bool; debugCSource = toEnum 0
-- | Print extra stuff about reference counts of arrays
-debugRefc = toEnum 1
+debugRefc :: Bool; debugRefc = toEnum 0
+-- | Print some shape-related information
+debugShapes :: Bool; debugShapes = toEnum 0
-- | Emit extra C code that checks stuff
-emitChecks = toEnum 1
+emitChecks :: Bool; emitChecks = toEnum 0
compile :: SList STy env -> Ex env t
-> IO (SList Value env -> IO (Rep t))
@@ -331,18 +324,19 @@ compileToString env expr =
in ($ "") $ compose
[showString "#include <stdio.h>\n"
,showString "#include <stdint.h>\n"
+ ,showString "#include <inttypes.h>\n"
,showString "#include <stdlib.h>\n"
,showString "#include <string.h>\n"
,showString "#include <math.h>\n\n"
- ,compose $ map (\sd -> printStructDecl sd . showString "\n") structs
+ ,compose [printStructDecl sd . showString "\n" | sd <- structs]
,showString "\n"
,compose [showString str . showString "\n\n" | str <- toList (csTopLevelDecls s)]
,showString $
"static " ++ repSTy (typeOf expr) ++ " typed_kernel(" ++
intercalate ", " (reverse (unSList (\(Product.Pair t (Const n)) -> repSTy t ++ " " ++ n) (slistZip env args))) ++
") {\n"
- ,compose $ map (\st -> showString " " . printStmt 1 st . showString "\n") (toList (csStmts s))
- ,showString (" return ") . printCExpr 0 res . showString ";\n}\n\n"
+ ,compose [showString " " . printStmt 1 st . showString "\n" | st <- toList (csStmts s)]
+ ,showString " return " . printCExpr 0 res . showString ";\n}\n\n"
,showString "void kernel(void *data) {\n"
-- Some code here assumes that we're on a 64-bit system, so let's check that
,showString " if (sizeof(void*) != 8 || sizeof(size_t) != 8) { fprintf(stderr, \"Only 64-bit systems supported\\n\"); abort(); }\n"
@@ -631,7 +625,7 @@ compile' env = \case
shsizename <- genName' "shsz"
emit $ SVarDecl True "size_t" shsizename (compileShapeSize n shname)
- arrname <- allocArray "arr" n (typeOf efun) (CELit shsizename) (compileShapeTupIntoArray n shname)
+ arrname <- allocArray "arr" n (typeOf efun) (CELit shsizename) (indexTupleComponents n shname)
idxargname <- genName' "ix"
(funretval, funstmts) <- scope $ compile' (Const idxargname `SCons` env) efun
@@ -655,8 +649,7 @@ compile' env = \case
x0name <- compileAssign "foldx0" env ex0
arrname <- compileAssign "foldarr" env earr
- when emitChecks $
- emit $ SVerbatim $ "if (__builtin_expect(" ++ arrname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: fold1i got array %p with refc=0\\n\", " ++ arrname ++ ".buf); abort(); }"
+ zeroRefcountCheck "fold1i" arrname
shszname <- genName' "shsz"
-- This n is one less than the shape of the thing we're querying, which is
@@ -701,8 +694,7 @@ compile' env = \case
let STArr (SS n) t = typeOf e
argname <- compileAssign "sumarg" env e
- when emitChecks $
- emit $ SVerbatim $ "if (__builtin_expect(" ++ argname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: sum1i got array %p with refc=0\\n\", " ++ argname ++ ".buf); abort(); }"
+ zeroRefcountCheck "sum1i" argname
shszname <- genName' "shsz"
-- This n is one less than the shape of the thing we're querying, like EFold1Inner.
@@ -745,8 +737,7 @@ compile' env = \case
lenname <- compileAssign "replen" env elen
argname <- compileAssign "reparg" env earg
- when emitChecks $
- emit $ SVerbatim $ "if (__builtin_expect(" ++ argname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: replicate1i got array %p with refc=0\\n\", " ++ argname ++ ".buf); abort(); }"
+ zeroRefcountCheck "replicate1i" argname
shszname <- genName' "shsz"
emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
@@ -776,8 +767,7 @@ compile' env = \case
EIdx0 _ e -> do
let STArr _ t = typeOf e
arrname <- compileAssign "" env e
- when emitChecks $
- emit $ SVerbatim $ "if (__builtin_expect(" ++ arrname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: idx0 got array %p with refc=0\\n\", " ++ arrname ++ ".buf); abort(); }"
+ zeroRefcountCheck "idx0" arrname
name <- genName
emit $ SVarDecl True (repSTy t) name
(CEIndex (CEPtrProj (CEProj (CELit arrname) "buf") "xs") (CELit "0"))
@@ -789,11 +779,20 @@ compile' env = \case
EIdx _ earr eidx -> do
let STArr n t = typeOf earr
arrname <- compileAssign "ixarr" env earr
- when emitChecks $
- emit $ SVerbatim $ "if (__builtin_expect(" ++ arrname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: idx got array %p with refc=0\\n\", " ++ arrname ++ ".buf); abort(); }"
+ zeroRefcountCheck "idx" arrname
idxname <- if fromSNat n > 0 -- prevent an unused-varable warning
then compileAssign "ixix" env eidx
else return "" -- won't be used in this case
+
+ when emitChecks $
+ forM_ (zip [0::Int ..] (indexTupleComponents n idxname)) $ \(i, ixcomp) ->
+ emit $ SIf (CEBinop (CEBinop ixcomp "<" (CELit "0")) "||"
+ (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".buf->sh[" ++ show i ++ "]")))))
+ (pure $ SVerbatim $
+ "fprintf(stderr, \"[chad-kernel] CHECK: index out of range (arr=%p)\\n\", " ++
+ arrname ++ ".buf); abort();")
+ mempty
+
resname <- genName' "ixres"
emit $ SVarDecl True (repSTy t) resname (CEIndex (CELit (arrname ++ ".buf->xs")) (toLinearIdx n arrname idxname))
incrementVarAlways Decrement (STArr n t) arrname
@@ -804,8 +803,7 @@ compile' env = \case
t = tTup (sreplicate n tIx)
_ <- emitStruct t
name <- compileAssign "" env e
- when emitChecks $
- emit $ SVerbatim $ "if (__builtin_expect(" ++ name ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: shape got array %p with refc=0\\n\", " ++ name ++ ".buf); abort(); }"
+ zeroRefcountCheck "shape" name
resname <- genName
emit $ SVarDecl True (repSTy t) resname (compileShapeQuery n name)
incrementVarAlways Decrement (typeOf e) name
@@ -835,8 +833,7 @@ compile' env = \case
actyname <- emitStruct (STAccum t)
name1 <- compileAssign "" env e1
- when emitChecks $
- emit $ SVerbatim $ "if (__builtin_expect(" ++ name1 ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: with got array %p with refc=0\\n\", " ++ name1 ++ ".buf); abort(); }"
+ zeroRefcountCheck "with" name1
mcopy <- copyForWriting t name1
accname <- genName' "accum"
@@ -852,14 +849,36 @@ compile' env = \case
nameval <- compileAssign "acval" env eval
nameacc <- compileAssign "acac" env eacc
- let accumRef :: STy a -> SAcPrj p a b -> String -> String -> String
- accumRef _ SAPHere v _ = v
+ let accumRef :: STy a -> SAcPrj p a b -> String -> String -> CompM String
+ accumRef _ SAPHere v _ = pure v
accumRef (STPair ta _) (SAPFst prj') v i = accumRef ta prj' (v++".a") i
accumRef (STPair _ tb) (SAPSnd prj') v i = accumRef tb prj' (v++".b") i
accumRef (STEither ta _) (SAPLeft prj') v i = accumRef ta prj' (v++".l") i
accumRef (STEither _ tb) (SAPRight prj') v i = accumRef tb prj' (v++".r") i
accumRef (STMaybe tj) (SAPJust prj') v i = accumRef tj prj' (v++".j") i
- accumRef (STArr n t') (SAPArrIdx prj' _) v i =
+ accumRef (STArr n t') (SAPArrIdx prj' _) v i = do
+ when emitChecks $ do
+ let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
+ forM_ (zip3 [0::Int ..]
+ (indexTupleComponents n (i++".a.a"))
+ (indexTupleComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do
+ let a .||. b = CEBinop a "||" b
+ emit $ SIf (CEBinop ixcomp "<" (CELit "0")
+ .||.
+ CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))
+ .||.
+ CEBinop shcomp "!=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))))
+ (pure $ SVerbatim $
+ "fprintf(stderr, \"[chad-kernel] CHECK: accum prj incorrect (arr=%p, " ++
+ "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++
+ v ++ ".buf" ++
+ concat [", " ++ v ++ ".buf->sh[" ++ show k ++ "]" | k <- [0 .. fromSNat n - 1]] ++
+ concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++
+ concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.b")] ++
+ "); " ++
+ "abort();")
+ mempty
+
accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b")
let add :: STy a -> String -> String -> CompM ()
@@ -880,8 +899,9 @@ compile' env = \case
stmts1 mempty
add (STArr n t1) d s = do
shsizename <- genName' "acshsz"
- emit $ SVarDecl True "size_t" shsizename (compileShapeSize n (s++".a.b"))
+ emit $ SVarDecl True (repSTy tIx) shsizename (compileShapeSize n (s++".a.b"))
ivar <- genName' "i"
+ -- TODO: emit check here for the source being either empty or equal in shape to the destination
((), stmts1) <- scope $ add t1 (d++".buf->xs["++ivar++"]") (s++".buf->xs["++ivar++"]")
emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename) $
stmts1
@@ -893,7 +913,7 @@ compile' env = \case
STBool -> error "Compile: accumulator add on booleans"
add (STAccum _) _ _ = error "Compile: nested accumulators unsupported"
- let dest = accumRef t prj (nameacc++".ac") nameidx
+ dest <- accumRef t prj (nameacc++".ac") nameidx
add (typeOf eval) dest nameval
incrementVarAlways Decrement (typeOf eval) nameval
@@ -977,10 +997,13 @@ incrementVar' inc path (ATArray (Some n) (Some eltty)) =
emit $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] arr %p in+ -> %zu\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc);"
Decrement -> do
case incrementVar Decrement eltty of
- Nothing -> do
- when debugRefc $
- emit $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] arr %p de- -> %zu\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);"
- emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) { fprintf(stderr, \"; free(\"); free(" ++ path ++ ".buf); fprintf(stderr, \") ok\\n\"); } else fprintf(stderr, \"\\n\");"
+ Nothing ->
+ if debugRefc
+ then do
+ emit $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] arr %p de- -> %zu\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);"
+ emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) { fprintf(stderr, \"; free(\"); free(" ++ path ++ ".buf); fprintf(stderr, \") ok\\n\"); } else fprintf(stderr, \"\\n\");"
+ else do
+ emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) free(" ++ path ++ ".buf);"
Just f -> do
when debugRefc $
emit $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] arr %p de- -> %zu recfree\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);"
@@ -1054,8 +1077,8 @@ compileArrShapeSize n var =
foldl1' (\a b -> CEBinop a "*" b) [CELit (var ++ ".buf->sh[" ++ show i ++ "]")
| i <- [0 .. fromSNat n - 1]]
-compileShapeTupIntoArray :: SNat n -> String -> [CExpr]
-compileShapeTupIntoArray = \n var -> map CELit (toList (go n var))
+indexTupleComponents :: SNat n -> String -> [CExpr]
+indexTupleComponents = \n var -> map CELit (toList (go n var))
where
go :: SNat n -> String -> Bag String
go SZ _ = mempty
@@ -1130,8 +1153,7 @@ compileExtremum nameBase opName operator env e = do
let STArr (SS n) t = typeOf e
argname <- compileAssign (nameBase ++ "arg") env e
- when emitChecks $
- emit $ SVerbatim $ "if (__builtin_expect(" ++ argname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: " ++ opName ++ " got array %p with refc=0\\n\", " ++ argname ++ ".buf); abort(); }"
+ zeroRefcountCheck opName argname
shszname <- genName' "shsz"
-- This n is one less than the shape of the thing we're querying, which is
@@ -1219,6 +1241,13 @@ copyForWriting topty var = case topty of
shszname <- genName' "shsz"
emit $ SVarDeclUninit (repSTy (STArr n t)) name
+ when debugShapes $ do
+ let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
+ emit $ SVerbatim $
+ "fprintf(stderr, \"[chad-kernel] with array " ++ shfmt ++ "\\n\"" ++
+ concat [", " ++ var ++ ".buf->sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++
+ ");"
+
emit $ SIf (CEBinop (CELit (var ++ ".buf->refc")) "==" (CELit "1"))
(pure (SAsg name (CELit var)))
(let shbytes = fromSNat n * 8
@@ -1272,6 +1301,14 @@ copyForWriting topty var = case topty of
STAccum _ -> error "Compile: Nested accumulators not supported"
+zeroRefcountCheck :: String -> String -> CompM ()
+zeroRefcountCheck opname arrvar =
+ when emitChecks $
+ emit $ SVerbatim $
+ "if (__builtin_expect(" ++ arrvar ++ ".buf->refc == 0, 0)) { " ++
+ "fprintf(stderr, \"[chad-kernel] CHECK: '" ++ opname ++ "' got array " ++
+ "%p with refc=0\\n\", " ++ arrvar ++ ".buf); abort(); }"
+
compose :: Foldable t => t (a -> a) -> a -> a
compose = foldr (.) id