diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-26 23:47:17 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-26 23:47:27 +0100 |
commit | 16a836d078caefc3526031c084e2527cba0da3a8 (patch) | |
tree | 7ca6a28063e259823750ee37bd4bebae27c5408d /src/Compile.hs | |
parent | f1e867838db63da71fea660740c23ab276a43a6c (diff) |
Fix various issues in Compile (still broken)
Diffstat (limited to 'src/Compile.hs')
-rw-r--r-- | src/Compile.hs | 196 |
1 files changed, 148 insertions, 48 deletions
diff --git a/src/Compile.hs b/src/Compile.hs index 5501746..00b90e3 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -240,7 +240,8 @@ genStruct name topty = case topty of STScal _ -> [] STAccum t -> - [StructDecl name (repSTy (CHAD.d2 t) ++ " ac;") com] + [StructDecl (name ++ "_buf") (repSTy (CHAD.d2 t) ++ " ac;") "" + ,StructDecl name (name ++ "_buf *buf;") com] where com = ppSTy 0 topty @@ -637,10 +638,8 @@ compile' env = \case EBuild _ n esh efun -> do shname <- compileAssign "sh" env esh - shsizename <- genName' "shsz" - emit $ SVarDecl True "size_t" shsizename (compileShapeSize n shname) - arrname <- allocArray "arr" n (typeOf efun) (CELit shsizename) (indexTupleComponents n shname) + arrname <- allocArray Malloc "arr" n (typeOf efun) Nothing (indexTupleComponents n shname) idxargname <- genName' "ix" (funretval, funstmts) <- scope $ compile' (Const idxargname `SCons` env) efun @@ -671,7 +670,7 @@ compile' env = \case -- unexpected. But it's exactly what we want, so we do it anyway. emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n arrname) - resname <- allocArray "foldres" n t (CELit shszname) + resname <- allocArray Malloc "foldres" n t (Just (CELit shszname)) [CELit (arrname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] lenname <- genName' "n" @@ -715,7 +714,7 @@ compile' env = \case -- This n is one less than the shape of the thing we're querying, like EFold1Inner. emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) - resname <- allocArray "sumres" n t (CELit shszname) + resname <- allocArray Malloc "sumres" n t (Just (CELit shszname)) [CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] lenname <- genName' "n" @@ -757,8 +756,8 @@ compile' env = \case shszname <- genName' "shsz" emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) - resname <- allocArray "rep" (SS n) t - (CEBinop (CELit shszname) "*" (CELit lenname)) + resname <- allocArray Malloc "rep" (SS n) t + (Just (CEBinop (CELit shszname) "*" (CELit lenname))) ([CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] ++ [CELit lenname]) @@ -850,28 +849,98 @@ compile' env = \case zeroRefcountCheck (typeOf e1) "with" name1 + emit $ SVerbatim $ "// copyForWriting start (" ++ name1 ++ ")" mcopy <- copyForWriting (CHAD.d2 t) name1 accname <- genName' "accum" - emit $ SVarDecl False actyname accname (CEStruct actyname [("ac", maybe (CELit name1) id mcopy)]) + emit $ SVarDecl False actyname accname + (CEStruct actyname [("buf", CECall "malloc" [CELit (show (sizeofSTy (CHAD.d2 t)))])]) + emit $ SAsg (accname++".buf->ac") (maybe (CELit name1) id mcopy) + emit $ SVerbatim $ "// initial accumulator constructed (" ++ name1 ++ ")." e2' <- compile' (Const accname `SCons` env) e2 + resname <- genName' "acret" + emit $ SVarDecl True (repSTy (CHAD.d2 t)) resname (CELit (accname++".buf->ac")) + emit $ SVerbatim $ "free(" ++ accname ++ ".buf);" + rettyname <- emitStruct (STPair (typeOf e2) (CHAD.d2 t)) - return $ CEStruct rettyname [("a", e2'), ("b", CEProj (CELit accname) "ac")] + return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)] EAccum _ t prj eidx eval eacc -> do nameidx <- compileAssign "acidx" env eidx nameval <- compileAssign "acval" env eval - nameacc <- compileAssign "acac" env eacc + -- Generate the variable manually because this one has to be non-const. + eacc' <- compile' env eacc + nameacc <- genName' "acac" + emit $ SVarDecl False (repSTy (typeOf eacc)) nameacc eacc' + + let -- Expects a variable reference to a value of type @D2 a@. + setZero :: STy a -> String -> CompM () + setZero STNil _ = return () + setZero STPair{} v = emit $ SAsg (v++".tag") (CELit "0") -- Maybe (Pair (D2 a) (D2 b)) + setZero STEither{} v = emit $ SAsg (v++".tag") (CELit "0") -- Maybe (Either (D2 a) (D2 b)) + setZero STMaybe{} v = emit $ SAsg (v++".tag") (CELit "0") -- Maybe (D2 a) + setZero STArr{} v = emit $ SAsg (v++".tag") (CELit "0") -- Maybe (Arr n (D2 a)) + setZero (STScal sty) v = case sty of + STI32 -> return () -- Nil + STI64 -> return () -- Nil + STF32 -> emit $ SAsg v (CELit "0.0f") + STF64 -> emit $ SAsg v (CELit "0.0") + STBool -> return () -- Nil + setZero STAccum{} _ = error "Compile: setZero: nested accumulators unsupported" + + initD2Pair :: STy a -> STy b -> String -> CompM () + initD2Pair a b v = do -- Maybe (Pair (D2 a) (D2 b)) + ((), stmts1) <- scope $ setZero a (v++".j.a") + ((), stmts2) <- scope $ setZero b (v++".j.b") + emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) + (pure (SAsg (v++".tag") (CELit "1")) <> stmts1 <> stmts2) + mempty + + initD2Either :: STy a -> STy b -> String -> Either () () -> CompM () + initD2Either a b v side = do -- Maybe (Either (D2 a) (D2 b)) + ((), stmts) <- case side of + Left () -> scope $ setZero a (v++".j.l") + Right () -> scope $ setZero b (v++".j.r") + emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) + (pure (SAsg (v++".tag") (CELit "1")) <> stmts) + mempty + + initD2Maybe :: STy a -> String -> CompM () + initD2Maybe a v = do -- Maybe (D2 a) + ((), stmts) <- scope $ setZero a (v++".j") + emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) + (pure (SAsg (v++".tag") (CELit "1")) <> stmts) + mempty + + -- mind: this has to traverse the D2 of these things, and it also has to + -- initialise data structures that are still sparse in the accumulator. let accumRef :: STy a -> SAcPrj p a b -> String -> String -> CompM String accumRef _ SAPHere v _ = pure v - accumRef (STPair ta _) (SAPFst prj') v i = accumRef ta prj' (v++".a") i - accumRef (STPair _ tb) (SAPSnd prj') v i = accumRef tb prj' (v++".b") i - accumRef (STEither ta _) (SAPLeft prj') v i = accumRef ta prj' (v++".l") i - accumRef (STEither _ tb) (SAPRight prj') v i = accumRef tb prj' (v++".r") i - accumRef (STMaybe tj) (SAPJust prj') v i = accumRef tj prj' (v++".j") i + accumRef (STPair ta tb) (SAPFst prj') v i = do + initD2Pair ta tb v + accumRef ta prj' (v++".j.a") i + accumRef (STPair ta tb) (SAPSnd prj') v i = do + initD2Pair ta tb v + accumRef tb prj' (v++".j.b") i + accumRef (STEither ta tb) (SAPLeft prj') v i = do + initD2Either ta tb v (Left ()) + accumRef ta prj' (v++".j.l") i + accumRef (STEither ta tb) (SAPRight prj') v i = do + initD2Either ta tb v (Right ()) + accumRef tb prj' (v++".j.r") i + accumRef (STMaybe tj) (SAPJust prj') v i = do + initD2Maybe tj v + accumRef tj prj' (v++".j") i accumRef (STArr n t') (SAPArrIdx prj' _) v i = do + (newarrName, newarrStmts) <- scope $ allocArray Calloc "prjarr" n t' Nothing (indexTupleComponents n (i++".a.b")) + emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) + (pure (SAsg (v++".tag") (CELit "1")) + <> newarrStmts + <> pure (SAsg (v++".j") (CELit newarrName))) + mempty + when emitChecks $ do let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" forM_ (zip3 [0::Int ..] @@ -880,58 +949,77 @@ compile' env = \case let a .||. b = CEBinop a "||" b emit $ SIf (CEBinop ixcomp "<" (CELit "0") .||. - CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))) + CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".j.buf->sh[" ++ show j ++ "]"))) .||. - CEBinop shcomp "!=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))) + CEBinop shcomp "!=" (CECast (repSTy tIx) (CELit (v ++ ".j.buf->sh[" ++ show j ++ "]")))) (pure $ SVerbatim $ "fprintf(stderr, \"[chad-kernel] CHECK: accum prj incorrect (arr=%p, " ++ "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++ - v ++ ".buf" ++ - concat [", " ++ v ++ ".buf->sh[" ++ show k ++ "]" | k <- [0 .. fromSNat n - 1]] ++ + v ++ ".j.buf" ++ + concat [", " ++ v ++ ".j.buf->sh[" ++ show k ++ "]" | k <- [0 .. fromSNat n - 1]] ++ concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++ concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.b")] ++ "); " ++ "abort();") mempty - accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") + accumRef t' prj' (v++".j.buf->xs[" ++ printCExpr 0 (toLinearIdx n (v++".j") (i++".a.a")) "]") (i++".b") + -- mind: this has to add the D2 of these things, and it also has to + -- initialise data structures that are still sparse in the accumulator. let add :: STy a -> String -> String -> CompM () add STNil _ _ = return () add (STPair t1 t2) d s = do - add t1 (d++".a") (s++".a") - add t2 (d++".b") (s++".b") + ((), stmts1) <- scope $ add t1 (d++".j.a") (s++".j.a") + ((), stmts2) <- scope $ add t2 (d++".j.b") (s++".j.b") + emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0")) + (pure (SAsg d (CELit s))) + (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) + (stmts1 <> stmts2) + mempty)) add (STEither t1 t2) d s = do - ((), stmts1) <- scope $ add t1 (d++".l") (s++".l") - ((), stmts2) <- scope $ add t2 (d++".r") (s++".r") - emit $ SAsg (d++".tag") (CELit (s++".tag")) - emit $ SIf (CEBinop (CELit (s++".tag")) "==" (CELit "0")) - stmts1 stmts2 + ((), stmts1) <- scope $ add t1 (d++".j.l") (s++".j.l") + ((), stmts2) <- scope $ add t2 (d++".j.r") (s++".j.r") + emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0")) + (pure (SAsg d (CELit s))) + (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) + (pure (SAsg (d++".j.tag") (CELit (s++".j.tag"))) + <> pure (SIf (CEBinop (CELit (s++".j.tag")) "==" (CELit "0")) + stmts1 stmts2)) + mempty)) add (STMaybe t1) d s = do ((), stmts1) <- scope $ add t1 (d++".j") (s++".j") - emit $ SAsg (d++".tag") (CELit (s++".tag")) - emit $ SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) - stmts1 mempty + emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0")) + (pure (SAsg d (CELit s))) + (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) + (pure (SAsg (d++".tag") (CELit "1")) <> stmts1) + mempty)) add (STArr n t1) d s = do shsizename <- genName' "acshsz" - emit $ SVarDecl True (repSTy tIx) shsizename (compileShapeSize n (s++".a.b")) ivar <- genName' "i" - -- TODO: emit check here for the source being either empty or equal in shape to the destination - ((), stmts1) <- scope $ add t1 (d++".buf->xs["++ivar++"]") (s++".buf->xs["++ivar++"]") - emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename) $ - stmts1 + ((), stmts1) <- scope $ add t1 (d++".j.buf->xs["++ivar++"]") (s++".j.buf->xs["++ivar++"]") + emit $ SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) + (pure (SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0")) + (pure (SAsg d (CELit s))) + (pure (SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n (s++".j"))) + -- TODO: emit check here for the source being either equal in shape to the destination + <> pure (SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename) + stmts1)))) + mempty add (STScal sty) d s = case sty of - STI32 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";" - STI64 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";" + STI32 -> return () + STI64 -> return () STF32 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";" STF64 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";" - STBool -> error "Compile: accumulator add on booleans" + STBool -> return () add (STAccum _) _ _ = error "Compile: nested accumulators unsupported" - dest <- accumRef t prj (nameacc++".ac") nameidx - add (typeOf eval) dest nameval + emit $ SVerbatim $ "// compile EAccum start (" ++ show prj ++ ")" + dest <- accumRef t prj (nameacc++".buf->ac") nameidx + add (acPrjTy prj t) dest nameval incrementVarAlways Decrement (typeOf eval) nameval + emit $ SVerbatim $ "// compile EAccum end" return $ CEStruct (repSTy STNil) [] @@ -998,7 +1086,7 @@ makeArrayTree (STPair a b) = smartATBoth (smartATProj "a" (makeArrayTree a)) (smartATProj "b" (makeArrayTree b)) makeArrayTree (STEither a b) = smartATCondTag (smartATProj "l" (makeArrayTree a)) (smartATProj "r" (makeArrayTree b)) -makeArrayTree (STMaybe t) = smartATCondTag ATNoop (makeArrayTree t) +makeArrayTree (STMaybe t) = smartATCondTag ATNoop (smartATProj "j" (makeArrayTree t)) makeArrayTree (STArr n t) = ATArray (Some n) (Some t) makeArrayTree (STScal _) = ATNoop makeArrayTree (STAccum _) = ATNoop @@ -1054,18 +1142,26 @@ toLinearIdx (SS n) arrvar idxvar = -- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))) -- _ +data AllocMethod = Malloc | Calloc + deriving (Show) + -- | The shape must have the outer dimension at the head (and the inner dimension on the right). -allocArray :: String -> SNat n -> STy t -> CExpr -> [CExpr] -> CompM String -allocArray nameBase rank eltty shsz shape = do +allocArray :: AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String +allocArray method nameBase rank eltty mshsz shape = do when (length shape /= fromSNat rank) $ error "allocArray: shape does not match rank" let arrty = STArr rank eltty strname <- emitStruct arrty arrname <- genName' nameBase + shsz <- case mshsz of + Just e -> return e + Nothing -> return (foldl0' (\a b -> CEBinop a "*" b) (CELit "1") shape) + let nbytesExpr = CEBinop (CELit (show (fromSNat rank * 8 + 8))) + "+" + (CEBinop shsz "*" (CELit (show (sizeofSTy eltty)))) emit $ SVarDecl True strname arrname $ CEStruct strname - [("buf", CECall "malloc" [CEBinop (CELit (show (fromSNat rank * 8 + 8))) - "+" - (CEBinop shsz "*" (CELit (show (sizeofSTy eltty))))])] + [("buf", case method of Malloc -> CECall "malloc" [nbytesExpr] + Calloc -> CECall "calloc" [nbytesExpr, CELit "1"])] forM_ (zip shape [0::Int ..]) $ \(dim, i) -> emit $ SAsg (arrname ++ ".buf->sh[" ++ show i ++ "]") dim emit $ SAsg (arrname ++ ".buf->refc") (CELit "1") @@ -1175,7 +1271,7 @@ compileExtremum nameBase opName operator env e = do -- unexpected. But it's exactly what we want, so we do it anyway. emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) - resname <- allocArray (nameBase ++ "res") n t (CELit shszname) + resname <- allocArray Malloc (nameBase ++ "res") n t (Just (CELit shszname)) [CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] lenname <- genName' "n" @@ -1369,3 +1465,7 @@ showPtr (Ptr a) = "0x" ++ showHex (integerFromWord# (int2Word# (addr2Int# a))) " -- | Type-restricted. (^) :: Num a => a -> Int -> a (^) = (Prelude.^) + +foldl0' :: (a -> a -> a) -> a -> [a] -> a +foldl0' _ x [] = x +foldl0' f _ l = foldl1' f l |