From 6da98aedf2f28ec8848d1cb8f5605b0c7e64d644 Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Sat, 15 Mar 2025 11:32:34 +0100
Subject: Complete accumulator revamp!

---
 src/Interpreter.hs | 320 +++++++++++++----------------------------------------
 1 file changed, 78 insertions(+), 242 deletions(-)

(limited to 'src/Interpreter.hs')

diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 11caac0..11184c9 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -1,3 +1,4 @@
+{-# LANGUAGE BangPatterns #-}
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE DerivingStrategies #-}
 {-# LANGUAGE FlexibleContexts #-}
@@ -19,12 +20,10 @@ module Interpreter (
   Value(..),
 ) where
 
-import Control.Monad (foldM, join, when)
-import Data.Bifunctor (bimap)
+import Control.Monad (foldM, join, when, forM_)
 import Data.Bitraversable (bitraverse)
 import Data.Char (isSpace)
 import Data.Functor.Identity
-import Data.Kind (Type)
 import Data.Int (Int64)
 import Data.IORef
 import System.IO (hPutStrLn, stderr)
@@ -64,8 +63,9 @@ interpret' :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList Value en
 interpret' env e = do
   let dep = ?depth
   let lenlimit = max 20 (100 - dep)
-  let trunc s | length s > lenlimit = take (lenlimit - 3) s ++ "..."
-              | otherwise           = s
+  let replace a b = map (\c -> if c == a then b else c)
+  let trunc s | length s > lenlimit = take (lenlimit - 3) (replace '\n' ' ' s) ++ "..."
+              | otherwise           = replace '\n' ' ' s
   when ?prints $ acmDebugLog $ replicate dep ' ' ++ "ev: " ++ trunc (ppExpr env e)
   res <- let ?depth = dep + 1 in interpret'Rec env e
   when ?prints $ acmDebugLog $ replicate dep ' ' ++ "<- " ++ showValue 0 (typeOf e) res ""
@@ -243,7 +243,7 @@ onehotD2 (SAPArrIdx prj _) (STArr n a) idx val =
 withAccum :: STy t -> STy a -> Rep (D2 t) -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep (D2 t))
 withAccum t _ initval f = AcM $ do
   accum <- newAcSparse t SAPHere () initval
-  out <- case f accum of AcM m -> m
+  out <- unAcM $ f accum
   val <- readAcSparse t accum
   return (out, val)
 
@@ -262,59 +262,6 @@ newAcZero = \case
     STBool -> return ()
   STAccum{} -> error "Nested accumulators"
 
--- | Inverted index: the outermost index is at the /outside/ of this list.
-data PartialInvIndex n m where
-  PIIxEnd :: PartialInvIndex m m
-  PIIxCons :: Int -> PartialInvIndex n m -> PartialInvIndex (S n) m
-
--- | Inverted shapey thing: the outermost dimension is at the /outside/ of this list.
-data Inverted (f :: Nat -> Type) n where
-  InvNil :: Inverted f Z
-  InvCons :: Int -> Inverted f n -> Inverted f (S n)
-
-type InvShape = Inverted Shape
-type InvIndex = Inverted Index
-
-class Shapey f where
-  shapeyNil :: f Z
-  shapeyCons :: f n -> Int -> f (S n)
-  shapeyCase :: f n -> (n ~ Z => r) -> (forall m. n ~ S m => f m -> Int -> r) -> r
-instance Shapey Index where
-  shapeyNil = IxNil
-  shapeyCons = IxCons
-  shapeyCase IxNil k0 _ = k0
-  shapeyCase (IxCons idx i) _ k1 = k1 idx i
-instance Shapey Shape where
-  shapeyNil = ShNil
-  shapeyCons = ShCons
-  shapeyCase ShNil k0 _ = k0
-  shapeyCase (ShCons sh n) _ k1 = k1 sh n
-
-invert :: forall f n. Shapey f => f n -> Inverted f n
-invert | Refl <- lemPlusZero @n = flip go InvNil
-  where
-    go :: forall n' m. f n' -> Inverted f m -> Inverted f (n' + m)
-    go sh ish = shapeyCase sh
-                  ish
-                  (\sh' n -> case lemPlusSuccRight @n' @m of Refl -> go sh' (InvCons n ish))
-
-uninvert :: forall f n. Shapey f => Inverted f n -> f n
-uninvert = go shapeyNil
-  where
-    go :: forall n' m. f n' -> Inverted f m -> f (n' + m)
-    go sh InvNil | Refl <- lemPlusZero @n' = sh
-    go sh (InvCons n (ish :: Inverted f predm)) | Refl <- lemPlusSuccRight @n' @predm = go (shapeyCons sh n) ish
-
-piindexMatch :: PartialInvIndex n m -> InvIndex n -> Maybe (InvIndex m)
-piindexMatch PIIxEnd ix = Just ix
-piindexMatch (PIIxCons i pix) (InvCons i' ix)
-  | i == i' = piindexMatch pix ix
-  | otherwise = Nothing
-
-piindexConcat :: PartialInvIndex n m -> InvIndex m -> InvIndex n
-piindexConcat PIIxEnd ix = ix
-piindexConcat (PIIxCons i pix) ix = InvCons i (piindexConcat pix ix)
-
 newAcSparse :: STy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep (D2 b) -> IO (RepAc a)
 newAcSparse typ prj idx val = case (typ, prj) of
   (STNil, SAPHere) -> return ()
@@ -353,27 +300,8 @@ onehotArray :: Monad m
 onehotArray mkone mkzero n _ ((arrindex', arrsh'), idx) =
   let arrindex = unTupRepIdx IxNil IxCons n arrindex'
       arrsh = unTupRepIdx ShNil ShCons n arrsh'
-  in arrayGenerateM arrsh (\i -> if i == arrindex then mkone idx else mkzero)
-
--- newAcDense :: STy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep (D2 b) -> IO (RepAcDense (D2 a))
--- newAcDense typ SZ () val = case typ of
---   STPair t1 t2 -> (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val)
---   STEither t1 t2 -> case val of
---     Left x -> Left <$> newAcSparse t1 SZ () x
---     Right y -> Right <$> newAcSparse t2 SZ () y
---   _ -> error "newAcDense: invalid dense type"
--- newAcDense typ (SS dep) idx val = case typ of
---   STPair t1 t2 ->
---     case (idx, val) of
---       (Left idx', Left val') -> (,) <$> newAcSparse t1 dep idx' val' <*> newAcZero t2
---       (Right idx', Right val') -> (,) <$> newAcZero t1 <*> newAcSparse t2 dep idx' val'
---       _ -> error "Index/value mismatch in newAc pair"
---   STEither t1 t2 ->
---     case (idx, val) of
---       (Left idx', Left val') -> Left <$> newAcSparse t1 dep idx' val'
---       (Right idx', Right val') -> Right <$> newAcSparse t2 dep idx' val'
---       _ -> error "Index/value mismatch in newAc either"
---   _ -> error "newAcDense: invalid dense type"
+      !linindex = toLinearIndex arrsh arrindex
+  in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero)
 
 readAcSparse :: STy t -> RepAc t -> IO (Rep (D2 t))
 readAcSparse typ val = case typ of
@@ -390,14 +318,6 @@ readAcSparse typ val = case typ of
     STBool -> return ()
   STAccum{} -> error "Nested accumulators"
 
--- readAcDense :: STy t -> RepAcDense t -> IO (Rep t)
--- readAcDense typ val = case typ of
---   STPair t1 t2 -> (,) <$> readAcSparse t1 (fst val) <*> readAcSparse t2 (snd val)
---   STEither t1 t2 -> case val of
---     Left x -> Left <$> readAcSparse t1 x
---     Right y -> Right <$> readAcSparse t2 y
---   _ -> error "readAcDense: invalid dense type"
-
 accumAddSparse :: STy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep (D2 b) -> AcM s ()
 accumAddSparse typ prj ref idx val = case (typ, prj) of
   (STNil, SAPHere) -> return ()
@@ -406,26 +326,66 @@ accumAddSparse typ prj ref idx val = case (typ, prj) of
     case val of
       Nothing -> return ()
       Just (val1, val2) ->
-        AcM $ realiseMaybeSparse ref ((,) <$> newAcSparse t1 SAPHere () val1
-                                          <*> newAcSparse t2 SAPHere () val2)
-                                     (\(ac1, ac2) -> do unAcM $ accumAddSparse t1 SAPHere ac1 () val1
-                                                        unAcM $ accumAddSparse t2 SAPHere ac2 () val2)
+        realiseMaybeSparse ref ((,) <$> newAcSparse t1 SAPHere () val1
+                                    <*> newAcSparse t2 SAPHere () val2)
+                               (\(ac1, ac2) -> do accumAddSparse t1 SAPHere ac1 () val1
+                                                  accumAddSparse t2 SAPHere ac2 () val2)
   (STPair t1 t2, SAPFst prj') ->
-    AcM $ realiseMaybeSparse ref ((,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2)
-                                 (\(ac1, _) -> do unAcM $ accumAddSparse t1 prj' ac1 idx val)
+    realiseMaybeSparse ref ((,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2)
+                           (\(ac1, _) -> do accumAddSparse t1 prj' ac1 idx val)
   (STPair t1 t2, SAPSnd prj') ->
-    AcM $ realiseMaybeSparse ref ((,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val)
-                                 (\(_, ac2) -> do unAcM $ accumAddSparse t2 prj' ac2 idx val)
-
-  (STEither t1 t2, SAPHere) -> _ ref val
-  (STEither t1 _, SAPLeft prj') -> _ ref idx val
-  (STEither _ t2, SAPRight prj') -> _ ref idx val
-
-  (STMaybe t1, SAPHere) -> _ ref val
-  (STMaybe t1, SAPJust prj') -> _ ref idx val
+    realiseMaybeSparse ref ((,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val)
+                           (\(_, ac2) -> do accumAddSparse t2 prj' ac2 idx val)
 
-  (STArr _ t1, SAPHere) -> _ ref val
-  (STArr n t, SAPArrIdx prj' _) -> _ ref idx val
+  (STEither{}, SAPHere) ->
+    case val of
+      Nothing -> return ()
+      Just (Left val1) -> accumAddSparse typ (SAPLeft SAPHere) ref () val1
+      Just (Right val2) -> accumAddSparse typ (SAPRight SAPHere) ref () val2
+  (STEither t1 _, SAPLeft prj') ->
+    realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val)
+                           (\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val
+                                  Right{} -> error "Mismatched Either in accumAddSparse (r +l)")
+  (STEither _ t2, SAPRight prj') ->
+    realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val)
+                           (\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val
+                                  Left{} -> error "Mismatched Either in accumAddSparse (l +r)")
+
+  (STMaybe{}, SAPHere) ->
+    case val of
+      Nothing -> return ()
+      Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val'
+  (STMaybe t1, SAPJust prj') ->
+      realiseMaybeSparse ref (newAcSparse t1 prj' idx val)
+                             (\ac -> accumAddSparse t1 prj' ac idx val)
+
+  (STArr _ t1, SAPHere) ->
+    let add ac = forM_ [0 .. arraySize ac - 1] $ \i ->
+                   unAcM $ accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val i)
+    in if arraySize val == 0
+         then return ()
+         else AcM $ join $ atomicModifyIORef' ref $ \ac ->
+                if arraySize ac == 0
+                  then (ac, do newac <- arrayMapM (newAcSparse t1 SAPHere ()) val
+                               join $ atomicModifyIORef' ref $ \ac' ->
+                                 if arraySize ac == 0
+                                   then (newac, return ())
+                                   else (ac', add ac'))
+                  else (ac, add ac)
+  (STArr n t1, SAPArrIdx prj' _) ->
+    let ((arrindex', arrsh'), idx') = idx
+        arrindex = unTupRepIdx IxNil IxCons n arrindex'
+        arrsh = unTupRepIdx ShNil ShCons n arrsh'
+        linindex = toLinearIndex arrsh arrindex
+        add ac = unAcM $ accumAddSparse t1 prj' (arrayIndexLinear ac linindex) idx' val
+    in AcM $ join $ atomicModifyIORef' ref $ \ac ->
+         if arraySize ac == 0
+           then (ac, do newac <- onehotArray (\_ -> newAcSparse t1 prj' idx' val) (newAcZero t1) n prj' idx
+                        join $ atomicModifyIORef' ref $ \ac' ->
+                          if arraySize ac == 0
+                            then (newac, return ())
+                            else (ac', add ac'))
+           else (ac, add ac)
 
   (STScal sty, SAPHere) -> AcM $ case sty of
     STI32 -> return ()
@@ -436,145 +396,21 @@ accumAddSparse typ prj ref idx val = case (typ, prj) of
 
   (STAccum{}, _) -> error "Accumulators not allowed in source program"
 
-realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> IO ()) -> IO ()
+realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s ()
 realiseMaybeSparse ref makeval modifyval =
   -- Try modifying what's already in ref. The 'join' makes the snd
   -- of the function's return value a _continuation_ that is run after
   -- the critical section ends.
-  join $ atomicModifyIORef' ref $ \ac -> case ac of
-           -- Oops, ref's contents was still sparse. Have to initialise
-           -- it first, then try again.
-           Nothing -> (ac, do val <- makeval
-                              join $ atomicModifyIORef' ref $ \ac' -> case ac' of
-                                       Nothing -> (Just val, return ())
-                                       Just val' -> (ac', modifyval val'))
-           -- Yep, ref already had a value in there, so we can just add
-           -- val' to it recursively.
-           Just val -> (ac, modifyval val)
-
-{-
-accumAddSparse typ SZ ref () val = case typ of
-  STNil -> return ()
-  STPair t1 t2 -> AcM $ do
-    (r1, r2) <- readIORef ref
-    unAcM $ accumAddSparse t1 SZ r1 () (fst val)
-    unAcM $ accumAddSparse t2 SZ r2 () (snd val)
-  STMaybe t ->
-    case val of
-      Nothing -> return ()
-      Just val' ->
-        -- Try adding val' to what's already in ref. The 'join' makes the snd
-        -- of the function's return value a _continuation_ that is run after
-        -- the critical section ends.
-        AcM $ join $ atomicModifyIORef' ref $ \ac -> case ac of
-                       -- Oops, ref's contents was still sparse. Have to initialise
-                       -- it first, then try again.
-                       Nothing -> (ac, do newac <- newAcDense t SZ () val'
-                                          join $ atomicModifyIORef' ref $ \ac2 -> case ac2 of
-                                                   Nothing -> (Just newac, return ())
-                                                   Just ac2' -> bimap Just unAcM (accumAddDense t SZ ac2' () val'))
-                       -- Yep, ref already had a value in there, so we can just add
-                       -- val' to it recursively.
-                       Just ac' -> bimap Just unAcM (accumAddDense t SZ ac' () val')
-  STArr _ t -> AcM $ do
-    refs <- readIORef ref
-    case (shapeSize (arrayShape refs), shapeSize (arrayShape val)) of
-      (_, 0) -> return ()
-      (0, _) -> do
-        newrefarr <- traverse (newAcSparse t SZ ()) val
-        join $ atomicModifyIORef' ref $ \refarr ->
-                 if shapeSize (arrayShape refarr) == 0
-                   then (newrefarr, return ())
-                   else -- someone was faster than us in initialising the reference!
-                        (refarr, unAcM $ accumAddSparse typ SZ ref () val)  -- just try again from the start (dropping newrefarr for the GC to clean up)
-      _ | arrayShape refs == arrayShape val ->
-            sequence_ [unAcM $ accumAddSparse t SZ (arrayIndexLinear refs i) () (arrayIndexLinear val i)
-                      | i <- [0 .. shapeSize (arrayShape val) - 1]]
-        | otherwise -> error "Array shape mismatch in accum add"
-  STScal sty -> AcM $ case sty of
-    STI32 -> atomicModifyIORef' ref (\x -> (x + val, ()))
-    STI64 -> atomicModifyIORef' ref (\x -> (x + val, ()))
-    STF32 -> atomicModifyIORef' ref (\x -> (x + val, ()))
-    STF64 -> atomicModifyIORef' ref (\x -> (x + val, ()))
-    STBool -> error "Accumulator of Bool"
-  STAccum{} -> error "Nested accumulators"
-  STEither{} -> error "Bare Either in accumulator"
-
-accumAddSparse typ (SS dep) ref idx val = case typ of
-  STNil -> return ()
-  STPair t1 t2 -> AcM $ do
-    (ref1, ref2) <- readIORef ref
-    case (idx, val) of
-      (Left idx', Left val') -> unAcM $ accumAddSparse t1 dep ref1 idx' val'
-      (Right idx', Right val') -> unAcM $ accumAddSparse t2 dep ref2 idx' val'
-      _ -> error "Index/value mismatch in pair accumulator add"
-  STMaybe t ->
-    AcM $ join $ atomicModifyIORef' ref $ \case
-                   -- Oops, ref's contents was still sparse. Have to initialise
-                   -- it first, then try again.
-                   Nothing -> (Nothing, do newac <- newAcDense t dep idx val
-                                           join $ atomicModifyIORef' ref $ \ac2 -> case ac2 of
-                                                    Nothing -> (Just newac, return ())
-                                                    Just ac2' -> bimap Just unAcM (accumAddDense t dep ac2' idx val))
-                   -- Yep, ref already had a value in there, so we can just add
-                   -- val' to it recursively.
-                   Just ac -> bimap Just unAcM (accumAddDense t dep ac idx val)
-  STArr dim (t :: STy t) -> AcM $ do
-    refs <- readIORef ref
-    if shapeSize (arrayShape refs) == 0
-      then do newrefarr <- newAcArray dim t (SS dep) idx val
-              join $ atomicModifyIORef' ref $ \refarr ->
-                       if shapeSize (arrayShape refarr) == 0
-                         then (newrefarr, return ())
-                         else -- someone was faster than us in initialising the reference!
-                              (refarr, unAcM $ accumAddSparse typ (SS dep) ref idx val)  -- just try again from the start (dropping newrefarr for the GC to clean up)
-      else do let sh = unTupRepIdx ShNil ShCons dim (fst val)
-              go (SS dep) (invert sh) idx (snd val)
-                (\j index idxj valj -> unAcM $ accumAddSparse t j (refs `arrayIndex` index) idxj valj)
-                (\piix subsh val' -> unAcM $ sequence_
-                                       [accumAddSparse t SZ (refs `arrayIndex` uninvert (piindexConcat piix (invert subix)))
-                                                       () (val' `arrayIndex` subix)
-                                       | subix <- enumShape subsh])
-    where
-      go :: SNat i -> InvShape n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i)
-         -> (forall j. SNat j -> Index n -> Rep (AcIdx t j) -> Rep (AcVal t j) -> r)  -- ^ Indexing into element of the array
-         -> (forall m. PartialInvIndex n m -> Shape m -> Rep (TArr m t) -> r)  -- ^ Accumulating onto a subarray
-         -> r
-      go SZ        ish             ()        val' _  k0 = k0 PIIxEnd (uninvert ish) val'  -- ^ Ran out of AcIdx: accumulating onto subarray
-      go (SS dep') InvNil          idx'      val' kj _  = kj dep' IxNil idx' val'  -- ^ Ran out of array dimensions: accumulating into (part of) element
-      go (SS dep') (InvCons _ ish) (i, idx') val' kj k0 =
-        go dep' ish idx' val'
-          (\j index idxj valj -> kj j (IxCons index (fromIntegral @Int64 @Int i)) idxj valj)
-          (\pidxm shm valm -> k0 (PIIxCons (fromIntegral @Int64 @Int i) pidxm) shm valm)
-  STScal{} -> error "Cannot index into scalar"
-  STAccum{} -> error "Nested accumulators"
-  STEither{} -> error "Bare Either in accumulator"
-
-accumAddDense :: forall t i s. STy t -> SNat i -> RepAcDense t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAcDense t, AcM s ())
-accumAddDense typ SZ ref () val = case typ of
-  STPair t1 t2 ->
-    (ref, do accumAddSparse t1 SZ (fst ref) () (fst val)
-             accumAddSparse t2 SZ (snd ref) () (snd val))
-  STEither t1 t2 ->
-    case (ref, val) of
-      (Left ref', Left val') -> (ref, accumAddSparse t1 SZ ref' () val')
-      (Right ref', Right val') -> (ref, accumAddSparse t2 SZ ref' () val')
-      _ -> error "Mismatched Either in accumAddDense either"
-  _ -> error "accumAddDense: invalid dense type"
-
-accumAddDense typ (SS dep) ref idx val = case typ of
-  STPair t1 t2 ->
-    case (idx, val) of
-      (Left idx', Left val') -> (ref, accumAddSparse t1 dep (fst ref) idx' val')
-      (Right idx', Right val') -> (ref, accumAddSparse t2 dep (snd ref) idx' val')
-      _ -> error "Mismatched Either in accumAddDense pair"
-  STEither t1 t2 ->
-    case (ref, idx, val) of
-      (Left ref', Left idx', Left val') -> (Left ref', accumAddSparse t1 dep ref' idx' val')
-      (Right ref', Right idx', Right val') -> (Right ref', accumAddSparse t2 dep ref' idx' val')
-      _ -> error "Mismatched Either in accumAddDense either"
-  _ -> error "accumAddDense: invalid dense type"
--}
+  AcM $ join $ atomicModifyIORef' ref $ \ac -> case ac of
+    -- Oops, ref's contents was still sparse. Have to initialise
+    -- it first, then try again.
+    Nothing -> (ac, do val <- makeval
+                       join $ atomicModifyIORef' ref $ \ac' -> case ac' of
+                                Nothing -> (Just val, return ())
+                                Just val' -> (ac', unAcM $ modifyval val'))
+    -- Yep, ref already had a value in there, so we can just add
+    -- val' to it recursively.
+    Just val -> (ac, unAcM $ modifyval val)
 
 
 numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r
@@ -593,12 +429,12 @@ integralIsIntegral STI64 = id
 
 unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m))
             -> SNat n -> Rep (Tup (Replicate n TIx)) -> f n
-unTupRepIdx nil _    SZ _ = nil
+unTupRepIdx nil _    SZ     _        = nil
 unTupRepIdx nil cons (SS n) (idx, i) = unTupRepIdx nil cons n idx `cons` fromIntegral @Int64 @Int i
 
 tupRepIdx :: (forall m. f (S m) -> (f m, Int))
           -> SNat n -> f n -> Rep (Tup (Replicate n TIx))
-tupRepIdx _      SZ _ = ()
+tupRepIdx _      SZ     _   = ()
 tupRepIdx uncons (SS n) tup =
   let (tup', i) = uncons tup
   in ((,) $! tupRepIdx uncons n tup') $! fromIntegral @Int @Int64 i
-- 
cgit v1.2.3-70-g09d2