From 16a836d078caefc3526031c084e2527cba0da3a8 Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Wed, 26 Mar 2025 23:47:17 +0100
Subject: Fix various issues in Compile (still broken)

---
 src/Compile.hs | 196 +++++++++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 148 insertions(+), 48 deletions(-)

diff --git a/src/Compile.hs b/src/Compile.hs
index 5501746..00b90e3 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -240,7 +240,8 @@ genStruct name topty = case topty of
   STScal _ ->
     []
   STAccum t ->
-    [StructDecl name (repSTy (CHAD.d2 t) ++ " ac;") com]
+    [StructDecl (name ++ "_buf") (repSTy (CHAD.d2 t) ++ " ac;") ""
+    ,StructDecl name (name ++ "_buf *buf;") com]
   where
     com = ppSTy 0 topty
 
@@ -637,10 +638,8 @@ compile' env = \case
 
   EBuild _ n esh efun -> do
     shname <- compileAssign "sh" env esh
-    shsizename <- genName' "shsz"
-    emit $ SVarDecl True "size_t" shsizename (compileShapeSize n shname)
 
-    arrname <- allocArray "arr" n (typeOf efun) (CELit shsizename) (indexTupleComponents n shname)
+    arrname <- allocArray Malloc "arr" n (typeOf efun) Nothing (indexTupleComponents n shname)
 
     idxargname <- genName' "ix"
     (funretval, funstmts) <- scope $ compile' (Const idxargname `SCons` env) efun
@@ -671,7 +670,7 @@ compile' env = \case
     -- unexpected. But it's exactly what we want, so we do it anyway.
     emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n arrname)
 
-    resname <- allocArray "foldres" n t (CELit shszname)
+    resname <- allocArray Malloc "foldres" n t (Just (CELit shszname))
                   [CELit (arrname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
 
     lenname <- genName' "n"
@@ -715,7 +714,7 @@ compile' env = \case
     -- This n is one less than the shape of the thing we're querying, like EFold1Inner.
     emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
 
-    resname <- allocArray "sumres" n t (CELit shszname)
+    resname <- allocArray Malloc "sumres" n t (Just (CELit shszname))
                   [CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
 
     lenname <- genName' "n"
@@ -757,8 +756,8 @@ compile' env = \case
     shszname <- genName' "shsz"
     emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
 
-    resname <- allocArray "rep" (SS n) t
-                 (CEBinop (CELit shszname) "*" (CELit lenname))
+    resname <- allocArray Malloc "rep" (SS n) t
+                 (Just (CEBinop (CELit shszname) "*" (CELit lenname)))
                  ([CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
                   ++ [CELit lenname])
 
@@ -850,28 +849,98 @@ compile' env = \case
 
     zeroRefcountCheck (typeOf e1) "with" name1
 
+    emit $ SVerbatim $ "// copyForWriting start (" ++ name1 ++ ")"
     mcopy <- copyForWriting (CHAD.d2 t) name1
     accname <- genName' "accum"
-    emit $ SVarDecl False actyname accname (CEStruct actyname [("ac", maybe (CELit name1) id mcopy)])
+    emit $ SVarDecl False actyname accname
+              (CEStruct actyname [("buf", CECall "malloc" [CELit (show (sizeofSTy (CHAD.d2 t)))])])
+    emit $ SAsg (accname++".buf->ac") (maybe (CELit name1) id mcopy)
+    emit $ SVerbatim $ "// initial accumulator constructed (" ++ name1 ++ ")."
 
     e2' <- compile' (Const accname `SCons` env) e2
 
+    resname <- genName' "acret"
+    emit $ SVarDecl True (repSTy (CHAD.d2 t)) resname (CELit (accname++".buf->ac"))
+    emit $ SVerbatim $ "free(" ++ accname ++ ".buf);"
+
     rettyname <- emitStruct (STPair (typeOf e2) (CHAD.d2 t))
-    return $ CEStruct rettyname [("a", e2'), ("b", CEProj (CELit accname) "ac")]
+    return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)]
 
   EAccum _ t prj eidx eval eacc -> do
     nameidx <- compileAssign "acidx" env eidx
     nameval <- compileAssign "acval" env eval
-    nameacc <- compileAssign "acac" env eacc
 
+    -- Generate the variable manually because this one has to be non-const.
+    eacc' <- compile' env eacc
+    nameacc <- genName' "acac"
+    emit $ SVarDecl False (repSTy (typeOf eacc)) nameacc eacc'
+
+    let -- Expects a variable reference to a value of type @D2 a@.
+        setZero :: STy a -> String -> CompM ()
+        setZero STNil _ = return ()
+        setZero STPair{} v = emit $ SAsg (v++".tag") (CELit "0")  -- Maybe (Pair (D2 a) (D2 b))
+        setZero STEither{} v = emit $ SAsg (v++".tag") (CELit "0")  -- Maybe (Either (D2 a) (D2 b))
+        setZero STMaybe{} v = emit $ SAsg (v++".tag") (CELit "0")  -- Maybe (D2 a)
+        setZero STArr{} v = emit $ SAsg (v++".tag") (CELit "0")  -- Maybe (Arr n (D2 a))
+        setZero (STScal sty) v = case sty of
+          STI32 -> return ()  -- Nil
+          STI64 -> return ()  -- Nil
+          STF32 -> emit $ SAsg v (CELit "0.0f")
+          STF64 -> emit $ SAsg v (CELit "0.0")
+          STBool -> return ()  -- Nil
+        setZero STAccum{} _ = error "Compile: setZero: nested accumulators unsupported"
+
+        initD2Pair :: STy a -> STy b -> String -> CompM ()
+        initD2Pair a b v = do  -- Maybe (Pair (D2 a) (D2 b))
+          ((), stmts1) <- scope $ setZero a (v++".j.a")
+          ((), stmts2) <- scope $ setZero b (v++".j.b")
+          emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
+                   (pure (SAsg (v++".tag") (CELit "1")) <> stmts1 <> stmts2)
+                   mempty
+
+        initD2Either :: STy a -> STy b -> String -> Either () () -> CompM ()
+        initD2Either a b v side = do  -- Maybe (Either (D2 a) (D2 b))
+          ((), stmts) <- case side of
+                           Left () -> scope $ setZero a (v++".j.l")
+                           Right () -> scope $ setZero b (v++".j.r")
+          emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
+                   (pure (SAsg (v++".tag") (CELit "1")) <> stmts)
+                   mempty
+
+        initD2Maybe :: STy a -> String -> CompM ()
+        initD2Maybe a v = do  -- Maybe (D2 a)
+          ((), stmts) <- scope $ setZero a (v++".j")
+          emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
+                   (pure (SAsg (v++".tag") (CELit "1")) <> stmts)
+                   mempty
+
+    -- mind: this has to traverse the D2 of these things, and it also has to
+    -- initialise data structures that are still sparse in the accumulator.
     let accumRef :: STy a -> SAcPrj p a b -> String -> String -> CompM String
         accumRef _ SAPHere v _ = pure v
-        accumRef (STPair ta _) (SAPFst prj') v i = accumRef ta prj' (v++".a") i
-        accumRef (STPair _ tb) (SAPSnd prj') v i = accumRef tb prj' (v++".b") i
-        accumRef (STEither ta _) (SAPLeft prj') v i = accumRef ta prj' (v++".l") i
-        accumRef (STEither _ tb) (SAPRight prj') v i = accumRef tb prj' (v++".r") i
-        accumRef (STMaybe tj) (SAPJust prj') v i = accumRef tj prj' (v++".j") i
+        accumRef (STPair ta tb) (SAPFst prj') v i = do
+          initD2Pair ta tb v
+          accumRef ta prj' (v++".j.a") i
+        accumRef (STPair ta tb) (SAPSnd prj') v i = do
+          initD2Pair ta tb v
+          accumRef tb prj' (v++".j.b") i
+        accumRef (STEither ta tb) (SAPLeft prj') v i = do
+          initD2Either ta tb v (Left ())
+          accumRef ta prj' (v++".j.l") i
+        accumRef (STEither ta tb) (SAPRight prj') v i = do
+          initD2Either ta tb v (Right ())
+          accumRef tb prj' (v++".j.r") i
+        accumRef (STMaybe tj) (SAPJust prj') v i = do
+          initD2Maybe tj v
+          accumRef tj prj' (v++".j") i
         accumRef (STArr n t') (SAPArrIdx prj' _) v i = do
+          (newarrName, newarrStmts) <- scope $ allocArray Calloc "prjarr" n t' Nothing (indexTupleComponents n (i++".a.b"))
+          emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
+                   (pure (SAsg (v++".tag") (CELit "1"))
+                    <> newarrStmts
+                    <> pure (SAsg (v++".j") (CELit newarrName)))
+                   mempty
+
           when emitChecks $ do
             let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
             forM_ (zip3 [0::Int ..]
@@ -880,58 +949,77 @@ compile' env = \case
               let a .||. b = CEBinop a "||" b
               emit $ SIf (CEBinop ixcomp "<" (CELit "0")
                           .||.
-                          CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))
+                          CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".j.buf->sh[" ++ show j ++ "]")))
                           .||.
-                          CEBinop shcomp "!=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))))
+                          CEBinop shcomp "!=" (CECast (repSTy tIx) (CELit (v ++ ".j.buf->sh[" ++ show j ++ "]"))))
                        (pure $ SVerbatim $
                           "fprintf(stderr, \"[chad-kernel] CHECK: accum prj incorrect (arr=%p, " ++
                           "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++
-                          v ++ ".buf" ++
-                          concat [", " ++ v ++ ".buf->sh[" ++ show k ++ "]" | k <- [0 .. fromSNat n - 1]] ++
+                          v ++ ".j.buf" ++
+                          concat [", " ++ v ++ ".j.buf->sh[" ++ show k ++ "]" | k <- [0 .. fromSNat n - 1]] ++
                           concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++
                           concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.b")] ++
                           "); " ++
                           "abort();")
                        mempty
 
-          accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b")
+          accumRef t' prj' (v++".j.buf->xs[" ++ printCExpr 0 (toLinearIdx n (v++".j") (i++".a.a")) "]") (i++".b")
 
+    -- mind: this has to add the D2 of these things, and it also has to
+    -- initialise data structures that are still sparse in the accumulator.
     let add :: STy a -> String -> String -> CompM ()
         add STNil _ _ = return ()
         add (STPair t1 t2) d s = do
-          add t1 (d++".a") (s++".a")
-          add t2 (d++".b") (s++".b")
+          ((), stmts1) <- scope $ add t1 (d++".j.a") (s++".j.a")
+          ((), stmts2) <- scope $ add t2 (d++".j.b") (s++".j.b")
+          emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0"))
+                   (pure (SAsg d (CELit s)))
+                   (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1"))
+                            (stmts1 <> stmts2)
+                            mempty))
         add (STEither t1 t2) d s = do
-          ((), stmts1) <- scope $ add t1 (d++".l") (s++".l")
-          ((), stmts2) <- scope $ add t2 (d++".r") (s++".r")
-          emit $ SAsg (d++".tag") (CELit (s++".tag"))
-          emit $ SIf (CEBinop (CELit (s++".tag")) "==" (CELit "0"))
-                     stmts1 stmts2
+          ((), stmts1) <- scope $ add t1 (d++".j.l") (s++".j.l")
+          ((), stmts2) <- scope $ add t2 (d++".j.r") (s++".j.r")
+          emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0"))
+                   (pure (SAsg d (CELit s)))
+                   (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1"))
+                            (pure (SAsg (d++".j.tag") (CELit (s++".j.tag")))
+                             <> pure (SIf (CEBinop (CELit (s++".j.tag")) "==" (CELit "0"))
+                                        stmts1 stmts2))
+                            mempty))
         add (STMaybe t1) d s = do
           ((), stmts1) <- scope $ add t1 (d++".j") (s++".j")
-          emit $ SAsg (d++".tag") (CELit (s++".tag"))
-          emit $ SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1"))
-                     stmts1 mempty
+          emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0"))
+                   (pure (SAsg d (CELit s)))
+                   (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1"))
+                            (pure (SAsg (d++".tag") (CELit "1")) <> stmts1)
+                            mempty))
         add (STArr n t1) d s = do
           shsizename <- genName' "acshsz"
-          emit $ SVarDecl True (repSTy tIx) shsizename (compileShapeSize n (s++".a.b"))
           ivar <- genName' "i"
-          -- TODO: emit check here for the source being either empty or equal in shape to the destination
-          ((), stmts1) <- scope $ add t1 (d++".buf->xs["++ivar++"]") (s++".buf->xs["++ivar++"]")
-          emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename) $
-                   stmts1
+          ((), stmts1) <- scope $ add t1 (d++".j.buf->xs["++ivar++"]") (s++".j.buf->xs["++ivar++"]")
+          emit $ SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1"))
+                   (pure (SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0"))
+                            (pure (SAsg d (CELit s)))
+                            (pure (SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n (s++".j")))
+                             -- TODO: emit check here for the source being either equal in shape to the destination
+                             <> pure (SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename)
+                                        stmts1))))
+                   mempty
         add (STScal sty) d s = case sty of
-          STI32 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
-          STI64 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
+          STI32 -> return ()
+          STI64 -> return ()
           STF32 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
           STF64 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
-          STBool -> error "Compile: accumulator add on booleans"
+          STBool -> return ()
         add (STAccum _) _ _ = error "Compile: nested accumulators unsupported"
 
-    dest <- accumRef t prj (nameacc++".ac") nameidx
-    add (typeOf eval) dest nameval
+    emit $ SVerbatim $ "// compile EAccum start (" ++ show prj ++ ")"
+    dest <- accumRef t prj (nameacc++".buf->ac") nameidx
+    add (acPrjTy prj t) dest nameval
 
     incrementVarAlways Decrement (typeOf eval) nameval
+    emit $ SVerbatim $ "// compile EAccum end"
 
     return $ CEStruct (repSTy STNil) []
 
@@ -998,7 +1086,7 @@ makeArrayTree (STPair a b) = smartATBoth (smartATProj "a" (makeArrayTree a))
                                          (smartATProj "b" (makeArrayTree b))
 makeArrayTree (STEither a b) = smartATCondTag (smartATProj "l" (makeArrayTree a))
                                               (smartATProj "r" (makeArrayTree b))
-makeArrayTree (STMaybe t) = smartATCondTag ATNoop (makeArrayTree t)
+makeArrayTree (STMaybe t) = smartATCondTag ATNoop (smartATProj "j" (makeArrayTree t))
 makeArrayTree (STArr n t) = ATArray (Some n) (Some t)
 makeArrayTree (STScal _) = ATNoop
 makeArrayTree (STAccum _) = ATNoop
@@ -1054,18 +1142,26 @@ toLinearIdx (SS n) arrvar idxvar =
 --   emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")))
 --   _
 
+data AllocMethod = Malloc | Calloc
+  deriving (Show)
+
 -- | The shape must have the outer dimension at the head (and the inner dimension on the right).
-allocArray :: String -> SNat n -> STy t -> CExpr -> [CExpr] -> CompM String
-allocArray nameBase rank eltty shsz shape = do
+allocArray :: AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String
+allocArray method nameBase rank eltty mshsz shape = do
   when (length shape /= fromSNat rank) $
     error "allocArray: shape does not match rank"
   let arrty = STArr rank eltty
   strname <- emitStruct arrty
   arrname <- genName' nameBase
+  shsz <- case mshsz of
+            Just e -> return e
+            Nothing -> return (foldl0' (\a b -> CEBinop a "*" b) (CELit "1") shape)
+  let nbytesExpr = CEBinop (CELit (show (fromSNat rank * 8 + 8)))
+                           "+"
+                           (CEBinop shsz "*" (CELit (show (sizeofSTy eltty))))
   emit $ SVarDecl True strname arrname $ CEStruct strname
-            [("buf", CECall "malloc" [CEBinop (CELit (show (fromSNat rank * 8 + 8)))
-                                              "+"
-                                              (CEBinop shsz "*" (CELit (show (sizeofSTy eltty))))])]
+            [("buf", case method of Malloc -> CECall "malloc" [nbytesExpr]
+                                    Calloc -> CECall "calloc" [nbytesExpr, CELit "1"])]
   forM_ (zip shape [0::Int ..]) $ \(dim, i) ->
     emit $ SAsg (arrname ++ ".buf->sh[" ++ show i ++ "]") dim
   emit $ SAsg (arrname ++ ".buf->refc") (CELit "1")
@@ -1175,7 +1271,7 @@ compileExtremum nameBase opName operator env e = do
   -- unexpected. But it's exactly what we want, so we do it anyway.
   emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
 
-  resname <- allocArray (nameBase ++ "res") n t (CELit shszname)
+  resname <- allocArray Malloc (nameBase ++ "res") n t (Just (CELit shszname))
                 [CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
 
   lenname <- genName' "n"
@@ -1369,3 +1465,7 @@ showPtr (Ptr a) = "0x" ++ showHex (integerFromWord# (int2Word# (addr2Int# a))) "
 -- | Type-restricted.
 (^) :: Num a => a -> Int -> a
 (^) = (Prelude.^)
+
+foldl0' :: (a -> a -> a) -> a -> [a] -> a
+foldl0' _ x [] = x
+foldl0' f _ l = foldl1' f l
-- 
cgit v1.2.3-70-g09d2