From 1d748ea62d02e4f66fd0f8be9815b8c3843f8356 Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Sun, 15 Sep 2024 22:14:11 +0200
Subject: WIP Accum stuff

---
 src/Array.hs           |   5 +-
 src/Interpreter.hs     | 198 ++++++++++++++++++++++++++++++++++++++++++++-----
 src/Interpreter/Rep.hs |  37 +++++----
 3 files changed, 203 insertions(+), 37 deletions(-)

diff --git a/src/Array.hs b/src/Array.hs
index 0d585a9..d7dadbf 100644
--- a/src/Array.hs
+++ b/src/Array.hs
@@ -1,6 +1,7 @@
-{-# LANGUAGE KindSignatures #-}
 {-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveTraversable #-}
 {-# LANGUAGE GADTs #-}
+{-# LANGUAGE KindSignatures #-}
 {-# LANGUAGE StandaloneDeriving #-}
 {-# LANGUAGE TupleSections #-}
 module Array where
@@ -47,7 +48,7 @@ emptyShape (SS m) = emptyShape m `ShCons` 0
 
 -- | TODO: this Vector is a boxed vector, which is horrendously inefficient.
 data Array (n :: Nat) t = Array (Shape n) (Vector t)
-  deriving (Show)
+  deriving (Show, Functor, Foldable, Traversable)
 
 arrayShape :: Array n t -> Shape n
 arrayShape (Array sh _) = sh
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 8728ec0..f58cefb 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -9,15 +9,16 @@
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE DerivingStrategies #-}
 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE TupleSections #-}
 module Interpreter (
   interpret,
   interpret',
   Value,
 ) where
 
-import Control.Monad (foldM)
+import Control.Monad (foldM, join)
 import Data.Int (Int64)
-import Data.Proxy
+import Data.IORef
 import System.IO.Unsafe (unsafePerformIO)
 
 import Array
@@ -25,9 +26,10 @@ import AST
 import CHAD.Types
 import Data
 import Interpreter.Rep
+import Data.Bifunctor (first)
 
 
-newtype AcM s a = AcM (IO a)
+newtype AcM s a = AcM { unAcM :: IO a }
   deriving newtype (Functor, Applicative, Monad)
 
 runAcM :: (forall s. AcM s a) -> a
@@ -53,9 +55,9 @@ interpret' env = \case
   ECase _ e a b -> interpret' env e >>= \case
                      Left x -> interpret' (Value x `SCons` env) a
                      Right y -> interpret' (Value y `SCons` env) b
-  ENothing _ _ -> _
-  EJust _ _ -> _
-  EMaybe _ _ _ _ -> _
+  ENothing _ _ -> return Nothing
+  EJust _ e -> Just <$> interpret' env e
+  EMaybe _ a b e -> maybe (interpret' env a) (\x -> interpret' (Value x `SCons` env) b) =<< interpret' env e
   EConstArr _ _ _ v -> return v
   EBuild1 _ a b -> do
     n <- fromIntegral @Int64 @Int <$> interpret' env a
@@ -88,19 +90,20 @@ interpret' env = \case
   EOp _ op e -> interpretOp op <$> interpret' env e
   EWith e1 e2 -> do
     initval <- interpret' env e1
-    withAccum (typeOf e1) initval $ \accum ->
+    withAccum (typeOf e1) (typeOf e2) initval $ \accum ->
       interpret' (Value accum `SCons` env) e2
   EAccum i e1 e2 e3 -> do
+    let STAccum t = typeOf e3
     idx <- interpret' env e1
     val <- interpret' env e2
     accum <- interpret' env e3
-    accumAdd accum i idx val
+    accumAddSparse t i accum idx val
   EZero t -> do
-    return $ makeZero t
+    return $ zeroD2 t
   EPlus t a b -> do
     a' <- interpret' env a
     b' <- interpret' env b
-    return $ makePlus t a' b'
+    return $ addD2s t a' b'
   EError _ s -> error $ "Interpreter: Program threw error: " ++ s
 
 interpretOp :: SOp a t -> Rep a -> Rep t
@@ -114,8 +117,8 @@ interpretOp op arg = case op of
   ONot -> not arg
   OIf -> if arg then Left () else Right ()
 
-makeZero :: STy t -> Rep (D2 t)
-makeZero typ = case typ of
+zeroD2 :: STy t -> Rep (D2 t)
+zeroD2 typ = case typ of
   STNil -> ()
   STPair _ _ -> Left ()
   STEither _ _ -> Left ()
@@ -129,25 +132,29 @@ makeZero typ = case typ of
                   STBool -> ()
   STAccum{} -> error "Zero of Accum"
 
-makePlus :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t)
-makePlus typ a b = case typ of
+addD2s :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t)
+addD2s typ a b = case typ of
   STNil -> ()
   STPair t1 t2 -> case (a, b) of
     (Left (), _) -> b
     (_, Left ()) -> a
-    (Right (x1, x2), Right (y1, y2)) -> Right (makePlus t1 x1 y1, makePlus t2 x2 y2)
+    (Right (x1, x2), Right (y1, y2)) -> Right (addD2s t1 x1 y1, addD2s t2 x2 y2)
   STEither t1 t2 -> case (a, b) of
     (Left (), _) -> b
     (_, Left ()) -> a
-    (Right (Left x), Right (Left y)) -> Right (Left (makePlus t1 x y))
-    (Right (Right x), Right (Right y)) -> Right (Right (makePlus t2 x y))
+    (Right (Left x), Right (Left y)) -> Right (Left (addD2s t1 x y))
+    (Right (Right x), Right (Right y)) -> Right (Right (addD2s t2 x y))
     _ -> error "Plus of inconsistent Eithers"
+  STMaybe t -> case (a, b) of
+    (Nothing, _) -> b
+    (_, Nothing) -> a
+    (Just x, Just y) -> Just (addD2s t x y)
   STArr _ t ->
     let sh1 = arrayShape a
         sh2 = arrayShape b
     in if | shapeSize sh1 == 0 -> b
           | shapeSize sh2 == 0 -> a
-          | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> makePlus t (arrayIndexLinear a i) (arrayIndexLinear b i))
+          | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> addD2s t (arrayIndexLinear a i) (arrayIndexLinear b i))
           | otherwise -> error "Plus of inconsistently shaped arrays"
   STScal sty -> case sty of
     STI32 -> ()
@@ -157,6 +164,159 @@ makePlus typ a b = case typ of
     STBool -> ()
   STAccum{} -> error "Plus of Accum"
 
+withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t)
+withAccum t _ initval f = AcM $ do
+  accum <- newAcSparse t initval
+  out <- case f accum of AcM m -> m
+  val <- readAcSparse t accum
+  return (out, val)
+
+newAcSparse :: STy t -> Rep t -> IO (RepAcSparse t)
+newAcSparse typ val = case typ of
+  STNil -> return ()
+  STPair{} -> newIORef =<<newAcDense typ val
+  STMaybe t -> newIORef =<< traverse (newAcDense t) val
+  STArr _ t -> newIORef =<< traverse (newAcSparse t) val
+  STScal{} -> newIORef val
+  STAccum{} -> error "Nested accumulators"
+  STEither{} -> error "Bare Either in accumulator"
+
+newAcDense :: STy t -> Rep t -> IO (RepAcDense t)
+newAcDense typ val = case typ of
+  STNil -> return ()
+  STPair t1 t2 -> (,) <$> newAcSparse t1 (fst val) <*> newAcSparse t2 (snd val)
+  STEither t1 t2 -> case val of
+    Left x -> Left <$> newAcSparse t1 x
+    Right y -> Right <$> newAcSparse t2 y
+  STMaybe t -> traverse (newAcSparse t) val
+  STArr _ t -> traverse (newAcSparse t) val
+  STScal{} -> return val
+  STAccum{} -> error "Nested accumulators"
+
+readAcSparse :: STy t -> RepAcSparse t -> IO (Rep t)
+readAcSparse typ val = case typ of
+  STNil -> return ()
+  STPair t1 t2 -> do
+    (a, b) <- readIORef val
+    (,) <$> readAcSparse t1 a <*> readAcSparse t2 b
+  STMaybe t -> traverse (readAcDense t) =<< readIORef val
+  STArr _ t -> traverse (readAcSparse t) =<< readIORef val
+  STScal{} -> readIORef val
+  STAccum{} -> error "Nested accumulators"
+  STEither{} -> error "Bare Either in accumulator"
+
+readAcDense :: STy t -> RepAcDense t -> IO (Rep t)
+readAcDense typ val = case typ of
+  STNil -> return ()
+  STPair t1 t2 -> (,) <$> readAcSparse t1 (fst val) <*> readAcSparse t2 (snd val)
+  STEither t1 t2 -> case val of
+    Left x -> Left <$> readAcSparse t1 x
+    Right y -> Right <$> readAcSparse t2 y
+  STMaybe t -> traverse (readAcSparse t) val
+  STArr _ t -> traverse (readAcSparse t) val
+  STScal{} -> return val
+  STAccum{} -> error "Nested accumulators"
+
+accumAddSparse :: STy t -> SNat i -> RepAcSparse t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s ()
+accumAddSparse typ SZ ref () val = case typ of
+  STNil -> return ()
+  STPair t1 t2 -> AcM $ do
+    (r1, r2) <- readIORef ref
+    unAcM $ accumAddSparse t1 SZ r1 () (fst val)
+    unAcM $ accumAddSparse t2 SZ r2 () (snd val)
+  STMaybe t ->
+    join $ AcM $ atomicModifyIORef' ref $ \ac -> case (ac, val) of
+                   (Nothing, _) -> (ac, _)
+                   (Just{}, Nothing) -> (ac, return ())
+                   (Just ac', Just val') -> first Just (accumAddDense t SZ ac' () val')
+  STArr _ t -> _ ref val
+  STScal{} -> _ ref val
+  STAccum{} -> error "Nested accumulators"
+  STEither{} -> error "Bare Either in accumulator"
+accumAddSparse typ (SS dep) ref idx val = case typ of
+  STNil -> return ()
+  STPair t1 t2 -> _ ref idx val
+  STMaybe t -> _ ref idx val
+  STArr _ t -> _ ref idx val
+  STScal{} -> _ ref idx val
+  STAccum{} -> error "Nested accumulators"
+  STEither{} -> error "Bare Either in accumulator"
+
+accumAddDense :: forall t i s. STy t -> SNat i -> RepAcDense t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAcDense t, AcM s ())
+accumAddDense = _
+
+-- accumAddVal :: forall t i s. STy t -> SNat i -> RepAc t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAc t, AcM s ())
+-- accumAddVal typ SZ ac () val = case typ of
+--   STNil -> ((), return ())
+--   STPair t1 t2 ->
+--     let (ac1', m1) = accumAddVal t1 SZ (fst ac) () (fst val)
+--         (ac2', m2) = accumAddVal t2 SZ (snd ac) () (snd val)
+--     in ((ac1', ac2'), m1 >> m2)
+--   STMaybe t -> case t of
+--     STEither t1 t2 -> (ac, accumAddValME t1 t2 ac val)
+--     STNil -> def ; STPair{} -> def ; STMaybe{} -> def ; STArr{} -> def ; STScal{} -> def ; STAccum{} -> def
+--     where def :: (t ~ TMaybe a, RepAc (TMaybe a) ~ IORef (Maybe (RepAc a))) => (RepAc t, AcM s ())
+--           def = (ac, accumAddValM t ac val)
+--   STArr n t
+--     | shapeSize (arrayShape ac) == 0 -> makeRepAc (STArr n t) val
+--   STEither{} -> error "Bare Either in accumulator"
+--   _ -> _
+-- accumAddVal typ (SS dep) ac idx val = case typ of
+--   STNil -> ((), return ())
+--   STPair t1 t2 ->
+--     case (idx, val) of
+--       (Left idx', Left val') -> first (,snd ac) $ accumAddVal t1 dep (fst ac) idx' val'
+--       (Right idx', Right val') -> first (fst ac,) $ accumAddVal t2 dep (snd ac) idx' val'
+--       _ -> error "Inconsistent idx and val in accumulator add operation"
+--   _ -> _
+
+-- accumAddValME :: STy a -> STy b
+--               -> IORef (Maybe (Either (RepAc a) (RepAc b)))
+--               -> Maybe (Either (Rep a) (Rep b))
+--               -> AcM s ()
+-- accumAddValME t1 t2 ac val =
+--   case val of
+--     Nothing -> return ()
+--     Just val' ->
+--       join $ AcM $ atomicModifyIORef' ac $ \ac' -> case (ac', val') of
+--                      (Nothing, _) -> (Nothing, AcM $ initAccumOrTryAgainME t1 t2 ac val' (unAcM $ accumAddValME t1 t2 ac val))
+--                      (Just (Left x), Left val'1) -> first (Just . Left) $ accumAddVal t1 SZ x () val'1
+--                      (Just (Right y), Right val'2) -> first (Just . Right) $ accumAddVal t2 SZ y () val'2
+--                      _ -> error "Inconsistent accumulator and value in add operation on Maybe Either"
+
+-- initAccumOrTryAgainME :: STy a -> STy b
+--                       -> IORef (Maybe (Either (RepAc a) (RepAc b)))
+--                       -> Either (Rep a) (Rep b)
+--                       -> IO ()
+--                       -> IO ()
+-- initAccumOrTryAgainME t1 t2 ac val onRace = do
+--   newContents <- case val of Left x -> Left <$> makeRepAc t1 x
+--                              Right y -> Right <$> makeRepAc t2 y
+--   join $ atomicModifyIORef' ac (\case Nothing -> (Just newContents, return ())
+--                                       value@Just{} -> (value, onRace))
+
+-- accumAddValM :: STy t
+--              -> IORef (Maybe (RepAc t))
+--              -> Maybe (Rep t)
+--              -> AcM s ()
+-- accumAddValM t ac val =
+--   case val of
+--     Nothing -> return ()
+--     Just val' ->
+--       join $ AcM $ atomicModifyIORef' ac $ \ac' -> case ac' of
+--                      Nothing -> (Nothing, AcM $ initAccumOrTryAgainM t ac val' (unAcM $ accumAddValM t ac val))
+--                      Just x -> first Just $ accumAddVal t SZ x () val'
+
+-- initAccumOrTryAgainM :: STy t
+--                      -> IORef (Maybe (RepAc t))
+--                      -> Rep t
+--                      -> IO ()
+--                      -> IO ()
+-- initAccumOrTryAgainM t ac val onRace = do
+--   newContents <- makeRepAc t val
+--   join $ atomicModifyIORef' ac (\case Nothing -> (Just newContents, return ())
+--                                       value@Just{} -> (value, onRace))
+
 numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r
 numericIsNum STI32 = id
 numericIsNum STI64 = id
@@ -166,7 +326,7 @@ numericIsNum STF64 = id
 unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m))
             -> SNat n -> Rep (Tup (Replicate n TIx)) -> f n
 unTupRepIdx nil _    SZ _ = nil
-unTupRepIdx nil cons (SS n) (idx, i) = unTupRepIdx p nil cons n idx `cons` fromIntegral @Int64 @Int i
+unTupRepIdx nil cons (SS n) (idx, i) = unTupRepIdx nil cons n idx `cons` fromIntegral @Int64 @Int i
 
 tupRepIdx :: (forall m. f (S m) -> (f m, Int))
           -> SNat n -> f n -> Rep (Tup (Replicate n TIx))
diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs
index 7add442..680196c 100644
--- a/src/Interpreter/Rep.hs
+++ b/src/Interpreter/Rep.hs
@@ -1,9 +1,9 @@
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE UndecidableInstances #-}
 module Interpreter.Rep where
 
 import Data.IORef
-import qualified Data.Vector.Mutable as MV
 import GHC.TypeError
 
 import Array
@@ -17,19 +17,24 @@ type family Rep t where
   Rep (TMaybe t) = Maybe (Rep t)
   Rep (TArr n t) = Array n (Rep t)
   Rep (TScal sty) = ScalRep sty
-  Rep (TAccum t) = IORef (RepAc t)
+  Rep (TAccum t) = RepAcSparse t
 
-type family RepAc t where
-  RepAc TNil = ()
-  RepAc (TPair a b) = (RepAc a, RepAc b)
-  -- This is annoying when working with values of type 'RepAc t', because
-  -- failing a pattern match does not generate negative type information.
-  -- However, it works, saves us from having to defining a LEither type
-  -- first-class in the type system with
-  --   Rep (LEither a b) = Maybe (Either a b)
-  -- and it's not even incorrect, in a way.
-  RepAc (TMaybe (TEither a b)) = IORef (Maybe (Either (RepAc a) (RepAc b)))
-  RepAc (TMaybe t) = IORef (Maybe (RepAc t))
-  RepAc (TArr n t) = (Shape n, MV.IOVector (RepAc t))
-  RepAc (TScal sty) = IORef (ScalRep sty)
-  RepAc (TAccum t) = TypeError (Text "Nested accumulators")
+-- Mutable, and has an O(1) zero.
+type family RepAcSparse t where
+  RepAcSparse TNil = ()
+  RepAcSparse (TPair a b) = IORef (RepAcDense (TPair a b))
+  RepAcSparse (TEither a b) = TypeError (Text "Non-sparse coproduct is not a monoid")
+  RepAcSparse (TMaybe t) = IORef (Maybe (RepAcDense t))  -- allow the value to be dense, because the Maybe's zero can be used for the contents
+  RepAcSparse (TArr n t) = IORef (RepAcDense (TArr n t))  -- empty array is zero
+  RepAcSparse (TScal sty) = IORef (ScalRep sty)
+  RepAcSparse (TAccum t) = TypeError (Text "RepAcSparse: Nested accumulators")
+
+-- Immutable, and does not necessarily have a zero.
+type family RepAcDense t where
+  RepAcDense TNil = ()
+  RepAcDense (TPair a b) = (RepAcSparse a, RepAcSparse b)
+  RepAcDense (TEither a b) = Either (RepAcSparse a) (RepAcSparse b)
+  RepAcDense (TMaybe t) = Maybe (RepAcSparse t)
+  RepAcDense (TArr n t) = Array n (RepAcSparse t)
+  RepAcDense (TScal sty) = ScalRep sty
+  RepAcDense (TAccum t) = TypeError (Text "RepAcDense: Nested accumulators")
-- 
cgit v1.2.3-70-g09d2