diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-23 11:59:17 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-23 11:59:17 +0100 |
commit | fa1906183e91a3f0fadd27a29375b860ac40e53c (patch) | |
tree | 6c18b3c064e2e2d5c6d8dd549d59303f434bc997 | |
parent | b87518c60f3034411bffc0c4745141db6a8d81d3 (diff) |
Compile: More checkstest-compile-chad
-rw-r--r-- | src/Compile.hs | 123 |
1 files changed, 80 insertions, 43 deletions
diff --git a/src/Compile.hs b/src/Compile.hs index 3cc8934..09c3ed5 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -42,26 +42,19 @@ import Data import Interpreter.Rep -{- -:m *Example Compile AST.UnMonoid -:seti -XOverloadedLabels -XGADTs -let array = arrayGenerate (ShNil `ShCons` 10) (\(IxNil `IxCons` i) -> fromIntegral i :: Double) in (($ SCons (Value array) SNil) =<<) $ compile knownEnv $ fromNamed $ lambda @(TArr N1 (TScal TF64)) #x $ body $ #x ! pair nil (round_ (#x ! pair nil 3)) -(($ SNil) =<<) $ compile knownEnv $ fromNamed $ body $ build2 5 3 (#i :-> #j :-> 10 * #i + #j) --} - - -- In shape and index arrays, the innermost dimension is on the right (last index). -- TODO: test that I'm properly incrementing and decrementing refcounts in all required places -debugCSource, debugRefc, emitChecks :: Bool -- | Print the generated C source -debugCSource = toEnum 0 +debugCSource :: Bool; debugCSource = toEnum 0 -- | Print extra stuff about reference counts of arrays -debugRefc = toEnum 1 +debugRefc :: Bool; debugRefc = toEnum 0 +-- | Print some shape-related information +debugShapes :: Bool; debugShapes = toEnum 0 -- | Emit extra C code that checks stuff -emitChecks = toEnum 1 +emitChecks :: Bool; emitChecks = toEnum 0 compile :: SList STy env -> Ex env t -> IO (SList Value env -> IO (Rep t)) @@ -331,18 +324,19 @@ compileToString env expr = in ($ "") $ compose [showString "#include <stdio.h>\n" ,showString "#include <stdint.h>\n" + ,showString "#include <inttypes.h>\n" ,showString "#include <stdlib.h>\n" ,showString "#include <string.h>\n" ,showString "#include <math.h>\n\n" - ,compose $ map (\sd -> printStructDecl sd . showString "\n") structs + ,compose [printStructDecl sd . showString "\n" | sd <- structs] ,showString "\n" ,compose [showString str . showString "\n\n" | str <- toList (csTopLevelDecls s)] ,showString $ "static " ++ repSTy (typeOf expr) ++ " typed_kernel(" ++ intercalate ", " (reverse (unSList (\(Product.Pair t (Const n)) -> repSTy t ++ " " ++ n) (slistZip env args))) ++ ") {\n" - ,compose $ map (\st -> showString " " . printStmt 1 st . showString "\n") (toList (csStmts s)) - ,showString (" return ") . printCExpr 0 res . showString ";\n}\n\n" + ,compose [showString " " . printStmt 1 st . showString "\n" | st <- toList (csStmts s)] + ,showString " return " . printCExpr 0 res . showString ";\n}\n\n" ,showString "void kernel(void *data) {\n" -- Some code here assumes that we're on a 64-bit system, so let's check that ,showString " if (sizeof(void*) != 8 || sizeof(size_t) != 8) { fprintf(stderr, \"Only 64-bit systems supported\\n\"); abort(); }\n" @@ -631,7 +625,7 @@ compile' env = \case shsizename <- genName' "shsz" emit $ SVarDecl True "size_t" shsizename (compileShapeSize n shname) - arrname <- allocArray "arr" n (typeOf efun) (CELit shsizename) (compileShapeTupIntoArray n shname) + arrname <- allocArray "arr" n (typeOf efun) (CELit shsizename) (indexTupleComponents n shname) idxargname <- genName' "ix" (funretval, funstmts) <- scope $ compile' (Const idxargname `SCons` env) efun @@ -655,8 +649,7 @@ compile' env = \case x0name <- compileAssign "foldx0" env ex0 arrname <- compileAssign "foldarr" env earr - when emitChecks $ - emit $ SVerbatim $ "if (__builtin_expect(" ++ arrname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: fold1i got array %p with refc=0\\n\", " ++ arrname ++ ".buf); abort(); }" + zeroRefcountCheck "fold1i" arrname shszname <- genName' "shsz" -- This n is one less than the shape of the thing we're querying, which is @@ -701,8 +694,7 @@ compile' env = \case let STArr (SS n) t = typeOf e argname <- compileAssign "sumarg" env e - when emitChecks $ - emit $ SVerbatim $ "if (__builtin_expect(" ++ argname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: sum1i got array %p with refc=0\\n\", " ++ argname ++ ".buf); abort(); }" + zeroRefcountCheck "sum1i" argname shszname <- genName' "shsz" -- This n is one less than the shape of the thing we're querying, like EFold1Inner. @@ -745,8 +737,7 @@ compile' env = \case lenname <- compileAssign "replen" env elen argname <- compileAssign "reparg" env earg - when emitChecks $ - emit $ SVerbatim $ "if (__builtin_expect(" ++ argname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: replicate1i got array %p with refc=0\\n\", " ++ argname ++ ".buf); abort(); }" + zeroRefcountCheck "replicate1i" argname shszname <- genName' "shsz" emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) @@ -776,8 +767,7 @@ compile' env = \case EIdx0 _ e -> do let STArr _ t = typeOf e arrname <- compileAssign "" env e - when emitChecks $ - emit $ SVerbatim $ "if (__builtin_expect(" ++ arrname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: idx0 got array %p with refc=0\\n\", " ++ arrname ++ ".buf); abort(); }" + zeroRefcountCheck "idx0" arrname name <- genName emit $ SVarDecl True (repSTy t) name (CEIndex (CEPtrProj (CEProj (CELit arrname) "buf") "xs") (CELit "0")) @@ -789,11 +779,20 @@ compile' env = \case EIdx _ earr eidx -> do let STArr n t = typeOf earr arrname <- compileAssign "ixarr" env earr - when emitChecks $ - emit $ SVerbatim $ "if (__builtin_expect(" ++ arrname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: idx got array %p with refc=0\\n\", " ++ arrname ++ ".buf); abort(); }" + zeroRefcountCheck "idx" arrname idxname <- if fromSNat n > 0 -- prevent an unused-varable warning then compileAssign "ixix" env eidx else return "" -- won't be used in this case + + when emitChecks $ + forM_ (zip [0::Int ..] (indexTupleComponents n idxname)) $ \(i, ixcomp) -> + emit $ SIf (CEBinop (CEBinop ixcomp "<" (CELit "0")) "||" + (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".buf->sh[" ++ show i ++ "]"))))) + (pure $ SVerbatim $ + "fprintf(stderr, \"[chad-kernel] CHECK: index out of range (arr=%p)\\n\", " ++ + arrname ++ ".buf); abort();") + mempty + resname <- genName' "ixres" emit $ SVarDecl True (repSTy t) resname (CEIndex (CELit (arrname ++ ".buf->xs")) (toLinearIdx n arrname idxname)) incrementVarAlways Decrement (STArr n t) arrname @@ -804,8 +803,7 @@ compile' env = \case t = tTup (sreplicate n tIx) _ <- emitStruct t name <- compileAssign "" env e - when emitChecks $ - emit $ SVerbatim $ "if (__builtin_expect(" ++ name ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: shape got array %p with refc=0\\n\", " ++ name ++ ".buf); abort(); }" + zeroRefcountCheck "shape" name resname <- genName emit $ SVarDecl True (repSTy t) resname (compileShapeQuery n name) incrementVarAlways Decrement (typeOf e) name @@ -835,8 +833,7 @@ compile' env = \case actyname <- emitStruct (STAccum t) name1 <- compileAssign "" env e1 - when emitChecks $ - emit $ SVerbatim $ "if (__builtin_expect(" ++ name1 ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: with got array %p with refc=0\\n\", " ++ name1 ++ ".buf); abort(); }" + zeroRefcountCheck "with" name1 mcopy <- copyForWriting t name1 accname <- genName' "accum" @@ -852,14 +849,36 @@ compile' env = \case nameval <- compileAssign "acval" env eval nameacc <- compileAssign "acac" env eacc - let accumRef :: STy a -> SAcPrj p a b -> String -> String -> String - accumRef _ SAPHere v _ = v + let accumRef :: STy a -> SAcPrj p a b -> String -> String -> CompM String + accumRef _ SAPHere v _ = pure v accumRef (STPair ta _) (SAPFst prj') v i = accumRef ta prj' (v++".a") i accumRef (STPair _ tb) (SAPSnd prj') v i = accumRef tb prj' (v++".b") i accumRef (STEither ta _) (SAPLeft prj') v i = accumRef ta prj' (v++".l") i accumRef (STEither _ tb) (SAPRight prj') v i = accumRef tb prj' (v++".r") i accumRef (STMaybe tj) (SAPJust prj') v i = accumRef tj prj' (v++".j") i - accumRef (STArr n t') (SAPArrIdx prj' _) v i = + accumRef (STArr n t') (SAPArrIdx prj' _) v i = do + when emitChecks $ do + let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" + forM_ (zip3 [0::Int ..] + (indexTupleComponents n (i++".a.a")) + (indexTupleComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do + let a .||. b = CEBinop a "||" b + emit $ SIf (CEBinop ixcomp "<" (CELit "0") + .||. + CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))) + .||. + CEBinop shcomp "!=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))) + (pure $ SVerbatim $ + "fprintf(stderr, \"[chad-kernel] CHECK: accum prj incorrect (arr=%p, " ++ + "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++ + v ++ ".buf" ++ + concat [", " ++ v ++ ".buf->sh[" ++ show k ++ "]" | k <- [0 .. fromSNat n - 1]] ++ + concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++ + concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.b")] ++ + "); " ++ + "abort();") + mempty + accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") let add :: STy a -> String -> String -> CompM () @@ -880,8 +899,9 @@ compile' env = \case stmts1 mempty add (STArr n t1) d s = do shsizename <- genName' "acshsz" - emit $ SVarDecl True "size_t" shsizename (compileShapeSize n (s++".a.b")) + emit $ SVarDecl True (repSTy tIx) shsizename (compileShapeSize n (s++".a.b")) ivar <- genName' "i" + -- TODO: emit check here for the source being either empty or equal in shape to the destination ((), stmts1) <- scope $ add t1 (d++".buf->xs["++ivar++"]") (s++".buf->xs["++ivar++"]") emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename) $ stmts1 @@ -893,7 +913,7 @@ compile' env = \case STBool -> error "Compile: accumulator add on booleans" add (STAccum _) _ _ = error "Compile: nested accumulators unsupported" - let dest = accumRef t prj (nameacc++".ac") nameidx + dest <- accumRef t prj (nameacc++".ac") nameidx add (typeOf eval) dest nameval incrementVarAlways Decrement (typeOf eval) nameval @@ -977,10 +997,13 @@ incrementVar' inc path (ATArray (Some n) (Some eltty)) = emit $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] arr %p in+ -> %zu\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc);" Decrement -> do case incrementVar Decrement eltty of - Nothing -> do - when debugRefc $ - emit $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] arr %p de- -> %zu\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);" - emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) { fprintf(stderr, \"; free(\"); free(" ++ path ++ ".buf); fprintf(stderr, \") ok\\n\"); } else fprintf(stderr, \"\\n\");" + Nothing -> + if debugRefc + then do + emit $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] arr %p de- -> %zu\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);" + emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) { fprintf(stderr, \"; free(\"); free(" ++ path ++ ".buf); fprintf(stderr, \") ok\\n\"); } else fprintf(stderr, \"\\n\");" + else do + emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) free(" ++ path ++ ".buf);" Just f -> do when debugRefc $ emit $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] arr %p de- -> %zu recfree\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);" @@ -1054,8 +1077,8 @@ compileArrShapeSize n var = foldl1' (\a b -> CEBinop a "*" b) [CELit (var ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] -compileShapeTupIntoArray :: SNat n -> String -> [CExpr] -compileShapeTupIntoArray = \n var -> map CELit (toList (go n var)) +indexTupleComponents :: SNat n -> String -> [CExpr] +indexTupleComponents = \n var -> map CELit (toList (go n var)) where go :: SNat n -> String -> Bag String go SZ _ = mempty @@ -1130,8 +1153,7 @@ compileExtremum nameBase opName operator env e = do let STArr (SS n) t = typeOf e argname <- compileAssign (nameBase ++ "arg") env e - when emitChecks $ - emit $ SVerbatim $ "if (__builtin_expect(" ++ argname ++ ".buf->refc == 0, 0)) { fprintf(stderr, \"[chad-kernel] CHECK: " ++ opName ++ " got array %p with refc=0\\n\", " ++ argname ++ ".buf); abort(); }" + zeroRefcountCheck opName argname shszname <- genName' "shsz" -- This n is one less than the shape of the thing we're querying, which is @@ -1219,6 +1241,13 @@ copyForWriting topty var = case topty of shszname <- genName' "shsz" emit $ SVarDeclUninit (repSTy (STArr n t)) name + when debugShapes $ do + let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" + emit $ SVerbatim $ + "fprintf(stderr, \"[chad-kernel] with array " ++ shfmt ++ "\\n\"" ++ + concat [", " ++ var ++ ".buf->sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++ + ");" + emit $ SIf (CEBinop (CELit (var ++ ".buf->refc")) "==" (CELit "1")) (pure (SAsg name (CELit var))) (let shbytes = fromSNat n * 8 @@ -1272,6 +1301,14 @@ copyForWriting topty var = case topty of STAccum _ -> error "Compile: Nested accumulators not supported" +zeroRefcountCheck :: String -> String -> CompM () +zeroRefcountCheck opname arrvar = + when emitChecks $ + emit $ SVerbatim $ + "if (__builtin_expect(" ++ arrvar ++ ".buf->refc == 0, 0)) { " ++ + "fprintf(stderr, \"[chad-kernel] CHECK: '" ++ opname ++ "' got array " ++ + "%p with refc=0\\n\", " ++ arrvar ++ ".buf); abort(); }" + compose :: Foldable t => t (a -> a) -> a -> a compose = foldr (.) id |