aboutsummaryrefslogtreecommitdiff
path: root/src/Interpreter.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r--src/Interpreter.hs398
1 files changed, 187 insertions, 211 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index ddc3479..ffc2929 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -21,9 +21,11 @@ module Interpreter (
) where
import Control.Monad (foldM, join, when, forM_)
+import Data.Bifunctor (bimap)
import Data.Bitraversable (bitraverse)
import Data.Char (isSpace)
import Data.Functor.Identity
+import qualified Data.Functor.Product as Product
import Data.Int (Int64)
import Data.IORef
import System.IO (hPutStrLn, stderr)
@@ -34,7 +36,7 @@ import Debug.Trace
import Array
import AST
import AST.Pretty
-import CHAD.Types
+import AST.Sparse.Types
import Data
import Interpreter.Rep
@@ -48,35 +50,39 @@ runAcM (AcM m) = unsafePerformIO m
acmDebugLog :: String -> AcM s ()
acmDebugLog s = AcM (hPutStrLn stderr s)
+data V t = V (STy t) (Rep t)
+
interpret :: Ex '[] t -> Rep t
-interpret = interpretOpen False SNil
+interpret = interpretOpen False SNil SNil
-- | Bool: whether to trace execution with debug prints (very verbose)
-interpretOpen :: Bool -> SList Value env -> Ex env t -> Rep t
-interpretOpen prints env e =
+interpretOpen :: Bool -> SList STy env -> SList Value env -> Ex env t -> Rep t
+interpretOpen prints env venv e =
runAcM $
let ?depth = 0
?prints = prints
- in interpret' env e
+ in interpret' (slistMap (\(Product.Pair t (Value v)) -> V t v) (slistZip env venv)) e
-interpret' :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList Value env -> Ex env t -> AcM s (Rep t)
+interpret' :: forall env t s. (?prints :: Bool, ?depth :: Int)
+ => SList V env -> Ex env t -> AcM s (Rep t)
interpret' env e = do
+ let tenv = slistMap (\(V t _) -> t) env
let dep = ?depth
let lenlimit = max 20 (100 - dep)
let replace a b = map (\c -> if c == a then b else c)
let trunc s | length s > lenlimit = take (lenlimit - 3) (replace '\n' ' ' s) ++ "..."
| otherwise = replace '\n' ' ' s
- when ?prints $ acmDebugLog $ replicate dep ' ' ++ "ev: " ++ trunc (ppExpr env e)
+ when ?prints $ acmDebugLog $ replicate dep ' ' ++ "ev: " ++ trunc (ppExpr tenv e)
res <- let ?depth = dep + 1 in interpret'Rec env e
when ?prints $ acmDebugLog $ replicate dep ' ' ++ "<- " ++ showValue 0 (typeOf e) res ""
return res
-interpret'Rec :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList Value env -> Ex env t -> AcM s (Rep t)
+interpret'Rec :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList V env -> Ex env t -> AcM s (Rep t)
interpret'Rec env = \case
- EVar _ _ i -> case slistIdx env i of Value x -> return x
+ EVar _ _ i -> case slistIdx env i of V _ x -> return x
ELet _ a b -> do
x <- interpret' env a
- let ?depth = ?depth - 1 in interpret' (Value x `SCons` env) b
+ let ?depth = ?depth - 1 in interpret' (V (typeOf a) x `SCons` env) b
expr | False && trace ("<i> " ++ takeWhile (not . isSpace) (show expr)) False -> undefined
EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b
EFst _ e -> fst <$> interpret' env e
@@ -84,18 +90,32 @@ interpret'Rec env = \case
ENil _ -> return ()
EInl _ _ e -> Left <$> interpret' env e
EInr _ _ e -> Right <$> interpret' env e
- ECase _ e a b -> interpret' env e >>= \case
- Left x -> interpret' (Value x `SCons` env) a
- Right y -> interpret' (Value y `SCons` env) b
+ ECase _ e a b ->
+ let STEither t1 t2 = typeOf e
+ in interpret' env e >>= \case
+ Left x -> interpret' (V t1 x `SCons` env) a
+ Right y -> interpret' (V t2 y `SCons` env) b
ENothing _ _ -> return Nothing
EJust _ e -> Just <$> interpret' env e
- EMaybe _ a b e -> maybe (interpret' env a) (\x -> interpret' (Value x `SCons` env) b) =<< interpret' env e
+ EMaybe _ a b e ->
+ let STMaybe t1 = typeOf e
+ in maybe (interpret' env a) (\x -> interpret' (V t1 x `SCons` env) b) =<< interpret' env e
+ ELNil _ _ _ -> return Nothing
+ ELInl _ _ e -> Just . Left <$> interpret' env e
+ ELInr _ _ e -> Just . Right <$> interpret' env e
+ ELCase _ e a b c ->
+ let STLEither t1 t2 = typeOf e
+ in interpret' env e >>= \case
+ Nothing -> interpret' env a
+ Just (Left x) -> interpret' (V t1 x `SCons` env) b
+ Just (Right y) -> interpret' (V t2 y `SCons` env) c
EConstArr _ _ _ v -> return v
EBuild _ dim a b -> do
sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a
- arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx ixUncons dim idx) `SCons` env) b)
+ arrayGenerateM sh (\idx -> interpret' (V (tTup (sreplicate dim tIx)) (tupRepIdx ixUncons dim idx) `SCons` env) b)
EFold1Inner _ _ a b c -> do
- let f = \x y -> interpret' (Value y `SCons` Value x `SCons` env) a
+ let t = typeOf b
+ let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a
x0 <- interpret' env b
arr <- interpret' env c
let sh `ShCons` n = arrayShape arr
@@ -126,34 +146,39 @@ interpret'Rec env = \case
EConst _ _ v -> return v
EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e
EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b)
- EIdx _ a b
- | STArr n _ <- typeOf a
- -> arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b)
+ EIdx _ a b ->
+ let STArr n _ = typeOf a
+ in arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b)
EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e
EOp _ op e -> interpretOp op <$> interpret' env e
- ECustom _ _ _ _ pr _ _ e1 e2 -> do
+ ECustom _ t1 t2 _ pr _ _ e1 e2 -> do
e1' <- interpret' env e1
e2' <- interpret' env e2
- interpret' (Value e2' `SCons` Value e1' `SCons` SNil) pr
+ interpret' (V t2 e2' `SCons` V t1 e1' `SCons` SNil) pr
+ ERecompute _ e -> interpret' env e
EWith _ t e1 e2 -> do
initval <- interpret' env e1
withAccum t (typeOf e2) initval $ \accum ->
- interpret' (Value accum `SCons` env) e2
- EAccum _ t p e1 e2 e3 -> do
+ interpret' (V (STAccum t) accum `SCons` env) e2
+ EAccum _ t p e1 sp e2 e3 -> do
idx <- interpret' env e1
val <- interpret' env e2
accum <- interpret' env e3
- accumAddSparse t p accum idx val
- EZero _ t -> do
- return $ zeroD2 t
+ accumAddSparseD t p accum idx sp val
+ EZero _ t ezi -> do
+ zi <- interpret' env ezi
+ return $ zeroM t zi
+ EDeepZero _ t ezi -> do
+ zi <- interpret' env ezi
+ return $ deepZeroM t zi
EPlus _ t a b -> do
a' <- interpret' env a
b' <- interpret' env b
- return $ addD2s t a' b'
+ return $ addM t a' b'
EOneHot _ t p a b -> do
a' <- interpret' env a
b' <- interpret' env b
- return $ onehotD2 p t a' b'
+ return $ onehotM p t a' b'
EError _ _ s -> error $ "Interpreter: Program threw error: " ++ s
interpretOp :: SOp a t -> Rep a -> Rep t
@@ -174,6 +199,7 @@ interpretOp op arg = case op of
OExp st -> floatingIsFractional st $ exp arg
OLog st -> floatingIsFractional st $ log arg
OIDiv st -> integralIsIntegral st $ uncurry quot arg
+ OMod st -> integralIsIntegral st $ uncurry rem arg
where
styIsEq :: SScalTy t -> (Eq (Rep (TScal t)) => r) -> r
styIsEq STI32 = id
@@ -182,211 +208,161 @@ interpretOp op arg = case op of
styIsEq STF64 = id
styIsEq STBool = id
-zeroD2 :: STy t -> Rep (D2 t)
-zeroD2 typ = case typ of
- STNil -> ()
- STPair _ _ -> Nothing
- STEither _ _ -> Nothing
- STMaybe _ -> Nothing
- STArr _ _ -> Nothing
- STScal sty -> case sty of
- STI32 -> ()
- STI64 -> ()
+zeroM :: SMTy t -> Rep (ZeroInfo t) -> Rep t
+zeroM typ zi = case typ of
+ SMTNil -> ()
+ SMTPair t1 t2 -> (zeroM t1 (fst zi), zeroM t2 (snd zi))
+ SMTLEither _ _ -> Nothing
+ SMTMaybe _ -> Nothing
+ SMTArr _ t -> arrayMap (zeroM t) zi
+ SMTScal sty -> case sty of
+ STI32 -> 0
+ STI64 -> 0
STF32 -> 0.0
STF64 -> 0.0
- STBool -> ()
- STAccum{} -> error "Zero of Accum"
-addD2s :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t)
-addD2s typ a b = case typ of
- STNil -> ()
- STPair t1 t2 -> case (a, b) of
- (Nothing, _) -> b
- (_, Nothing) -> a
- (Just (x1, x2), Just (y1, y2)) -> Just (addD2s t1 x1 y1, addD2s t2 x2 y2)
- STEither t1 t2 -> case (a, b) of
- (Nothing, _) -> b
- (_, Nothing) -> a
- (Just (Left x), Just (Left y)) -> Just (Left (addD2s t1 x y))
- (Just (Right x), Just (Right y)) -> Just (Right (addD2s t2 x y))
- _ -> error "Plus of inconsistent Eithers"
- STMaybe t -> case (a, b) of
+deepZeroM :: SMTy t -> Rep (DeepZeroInfo t) -> Rep t
+deepZeroM typ zi = case typ of
+ SMTNil -> ()
+ SMTPair t1 t2 -> (deepZeroM t1 (fst zi), deepZeroM t2 (snd zi))
+ SMTLEither t1 t2 -> fmap (bimap (deepZeroM t1) (deepZeroM t2)) zi
+ SMTMaybe t -> fmap (deepZeroM t) zi
+ SMTArr _ t -> arrayMap (deepZeroM t) zi
+ SMTScal sty -> case sty of
+ STI32 -> 0
+ STI64 -> 0
+ STF32 -> 0.0
+ STF64 -> 0.0
+
+addM :: SMTy t -> Rep t -> Rep t -> Rep t
+addM typ a b = case typ of
+ SMTNil -> ()
+ SMTPair t1 t2 -> (addM t1 (fst a) (fst b), addM t2 (snd a) (snd b))
+ SMTLEither t1 t2 -> case (a, b) of
(Nothing, _) -> b
(_, Nothing) -> a
- (Just x, Just y) -> Just (addD2s t x y)
- STArr _ t -> case (a, b) of
+ (Just (Left x), Just (Left y)) -> Just (Left (addM t1 x y))
+ (Just (Right x), Just (Right y)) -> Just (Right (addM t2 x y))
+ _ -> error "Plus of inconsistent LEithers"
+ SMTMaybe t -> case (a, b) of
(Nothing, _) -> b
(_, Nothing) -> a
- (Just x, Just y) ->
- let sh1 = arrayShape x
- sh2 = arrayShape y
- in if | shapeSize sh1 == 0 -> Just y
- | shapeSize sh2 == 0 -> Just x
- | sh1 == sh2 -> Just $ arrayGenerateLin sh1 (\i -> addD2s t (arrayIndexLinear x i) (arrayIndexLinear y i))
- | otherwise -> error "Plus of inconsistently shaped arrays"
- STScal sty -> case sty of
- STI32 -> ()
- STI64 -> ()
- STF32 -> a + b
- STF64 -> a + b
- STBool -> ()
- STAccum{} -> error "Plus of Accum"
-
-onehotD2 :: SAcPrj p a b -> STy a -> Rep (AcIdx p a) -> Rep (D2 b) -> Rep (D2 a)
-onehotD2 SAPHere _ _ val = val
-onehotD2 (SAPFst prj) (STPair a b) idx val = Just (onehotD2 prj a idx val, zeroD2 b)
-onehotD2 (SAPSnd prj) (STPair a b) idx val = Just (zeroD2 a, onehotD2 prj b idx val)
-onehotD2 (SAPLeft prj) (STEither a _) idx val = Just (Left (onehotD2 prj a idx val))
-onehotD2 (SAPRight prj) (STEither _ b) idx val = Just (Right (onehotD2 prj b idx val))
-onehotD2 (SAPJust prj) (STMaybe a) idx val = Just (onehotD2 prj a idx val)
-onehotD2 (SAPArrIdx prj _) (STArr n a) idx val =
- Just $ runIdentity $ onehotArray (\idx' -> Identity (onehotD2 prj a idx' val)) (Identity (zeroD2 a)) n prj idx
-
-withAccum :: STy t -> STy a -> Rep (D2 t) -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep (D2 t))
+ (Just x, Just y) -> Just (addM t x y)
+ SMTArr _ t ->
+ let sh1 = arrayShape a
+ sh2 = arrayShape b
+ in if | shapeSize sh1 == 0 -> b
+ | shapeSize sh2 == 0 -> a
+ | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> addM t (arrayIndexLinear a i) (arrayIndexLinear b i))
+ | otherwise -> error "Plus of inconsistently shaped arrays"
+ SMTScal sty -> numericIsNum sty $ a + b
+
+onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdxS p a) -> Rep b -> Rep a
+onehotM SAPHere _ _ val = val
+onehotM (SAPFst prj) (SMTPair a b) idx val = (onehotM prj a (fst idx) val, zeroM b (snd idx))
+onehotM (SAPSnd prj) (SMTPair a b) idx val = (zeroM a (fst idx), onehotM prj b (snd idx) val)
+onehotM (SAPLeft prj) (SMTLEither a _) idx val = Just (Left (onehotM prj a idx val))
+onehotM (SAPRight prj) (SMTLEither _ b) idx val = Just (Right (onehotM prj b idx val))
+onehotM (SAPJust prj) (SMTMaybe a) idx val = Just (onehotM prj a idx val)
+onehotM (SAPArrIdx prj) (SMTArr n a) idx val =
+ runIdentity $ onehotArray (\idx' -> Identity (onehotM prj a idx' val)) (\zi -> Identity (zeroM a zi)) n prj idx
+
+withAccum :: SMTy t -> STy a -> Rep t -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t)
withAccum t _ initval f = AcM $ do
- accum <- newAcSparse t SAPHere () initval
+ accum <- newAcDense t initval
out <- unAcM $ f accum
- val <- readAcSparse t accum
+ val <- readAc t accum
return (out, val)
-newAcZero :: STy t -> IO (RepAc t)
-newAcZero = \case
- STNil -> return ()
- STPair{} -> newIORef Nothing
- STEither{} -> newIORef Nothing
- STMaybe _ -> newIORef Nothing
- STArr _ _ -> newIORef Nothing
- STScal sty -> case sty of
- STI32 -> return ()
- STI64 -> return ()
- STF32 -> newIORef 0.0
- STF64 -> newIORef 0.0
- STBool -> return ()
- STAccum{} -> error "Nested accumulators"
-
-newAcSparse :: STy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep (D2 b) -> IO (RepAc a)
-newAcSparse typ prj idx val = case (typ, prj) of
- (STNil, SAPHere) -> return ()
- (STPair t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val
- (STEither t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val
- (STMaybe t1, SAPHere) -> newIORef =<< traverse (newAcSparse t1 SAPHere ()) val
- (STArr _ t1, SAPHere) -> newIORef =<< traverse (traverse (newAcSparse t1 SAPHere ())) val
- (STScal sty, SAPHere) -> case sty of
- STI32 -> return ()
- STI64 -> return ()
- STF32 -> newIORef val
- STF64 -> newIORef val
- STBool -> return ()
-
- (STPair t1 t2, SAPFst prj') ->
- newIORef . Just =<< (,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2
- (STPair t1 t2, SAPSnd prj') ->
- newIORef . Just =<< (,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val
-
- (STEither t1 _, SAPLeft prj') -> newIORef . Just . Left =<< newAcSparse t1 prj' idx val
- (STEither _ t2, SAPRight prj') -> newIORef . Just . Right =<< newAcSparse t2 prj' idx val
-
- (STMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val
-
- (STArr n t, SAPArrIdx prj' _) -> newIORef . Just =<< newAcArray n t prj' idx val
-
- (STAccum{}, _) -> error "Accumulators not allowed in source program"
-
-newAcArray :: SNat n -> STy a -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> Rep (D2 b) -> IO (Array n (RepAc a))
-newAcArray n t prj idx val = onehotArray (\idx' -> newAcSparse t prj idx' val) (newAcZero t) n prj idx
+newAcDense :: SMTy a -> Rep a -> IO (RepAc a)
+newAcDense typ val = case typ of
+ SMTNil -> return ()
+ SMTPair t1 t2 -> bitraverse (newAcDense t1) (newAcDense t2) val
+ SMTLEither t1 t2 -> newIORef =<< traverse (bitraverse (newAcDense t1) (newAcDense t2)) val
+ SMTMaybe t1 -> newIORef =<< traverse (newAcDense t1) val
+ SMTArr _ t1 -> arrayMapM (newAcDense t1) val
+ SMTScal _ -> newIORef val
onehotArray :: Monad m
- => (Rep (AcIdx p a) -> m v) -- ^ the "one"
- -> m v -- ^ the "zero"
- -> SNat n -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> m (Array n v)
-onehotArray mkone mkzero n _ ((arrindex', arrsh'), idx) =
+ => (Rep (AcIdxS p a) -> m v) -- ^ the "one"
+ -> (Rep (ZeroInfo a) -> m v) -- ^ the "zero"
+ -> SNat n -> SAcPrj p a b -> Rep (AcIdxS (APArrIdx p) (TArr n a)) -> m (Array n v)
+onehotArray mkone mkzero n _ ((arrindex', ziarr), idx) =
let arrindex = unTupRepIdx IxNil IxCons n arrindex'
- arrsh = unTupRepIdx ShNil ShCons n arrsh'
+ arrsh = arrayShape ziarr
!linindex = toLinearIndex arrsh arrindex
- in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero)
-
-readAcSparse :: STy t -> RepAc t -> IO (Rep (D2 t))
-readAcSparse typ val = case typ of
- STNil -> return ()
- STPair t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val
- STEither t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val
- STMaybe t -> traverse (readAcSparse t) =<< readIORef val
- STArr _ t -> traverse (traverse (readAcSparse t)) =<< readIORef val
- STScal sty -> case sty of
- STI32 -> return ()
- STI64 -> return ()
- STF32 -> readIORef val
- STF64 -> readIORef val
- STBool -> return ()
- STAccum{} -> error "Nested accumulators"
-
-accumAddSparse :: STy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep (D2 b) -> AcM s ()
-accumAddSparse typ prj ref idx val = case (typ, prj) of
- (STNil, SAPHere) -> return ()
-
- (STPair t1 t2, SAPHere) ->
- case val of
- Nothing -> return ()
- Just (val1, val2) ->
- realiseMaybeSparse ref ((,) <$> newAcSparse t1 SAPHere () val1
- <*> newAcSparse t2 SAPHere () val2)
- (\(ac1, ac2) -> do accumAddSparse t1 SAPHere ac1 () val1
- accumAddSparse t2 SAPHere ac2 () val2)
- (STPair t1 t2, SAPFst prj') ->
- realiseMaybeSparse ref ((,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2)
- (\(ac1, _) -> do accumAddSparse t1 prj' ac1 idx val)
- (STPair t1 t2, SAPSnd prj') ->
- realiseMaybeSparse ref ((,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val)
- (\(_, ac2) -> do accumAddSparse t2 prj' ac2 idx val)
-
- (STEither{}, SAPHere) ->
+ in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero (ziarr `arrayIndexLinear` i))
+
+readAc :: SMTy t -> RepAc t -> IO (Rep t)
+readAc typ val = case typ of
+ SMTNil -> return ()
+ SMTPair t1 t2 -> bitraverse (readAc t1) (readAc t2) val
+ SMTLEither t1 t2 -> traverse (bitraverse (readAc t1) (readAc t2)) =<< readIORef val
+ SMTMaybe t -> traverse (readAc t) =<< readIORef val
+ SMTArr _ t -> traverse (readAc t) val
+ SMTScal _ -> readIORef val
+
+accumAddSparseD :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxD p a) -> Sparse b c -> Rep c -> AcM s ()
+accumAddSparseD typ prj ref idx sp val = case (typ, prj) of
+ (_, SAPHere) -> accumAddDense typ ref sp val
+
+ (SMTPair t1 _, SAPFst prj') -> accumAddSparseD t1 prj' (fst ref) idx sp val
+ (SMTPair _ t2, SAPSnd prj') -> accumAddSparseD t2 prj' (snd ref) idx sp val
+
+ (SMTLEither t1 _, SAPLeft prj') ->
+ realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)")
+ (\case Left ac1 -> accumAddSparseD t1 prj' ac1 idx sp val
+ Right{} -> error "Mismatched Either in accumAddSparseD (r +l)")
+ (SMTLEither _ t2, SAPRight prj') ->
+ realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)")
+ (\case Right ac2 -> accumAddSparseD t2 prj' ac2 idx sp val
+ Left{} -> error "Mismatched Either in accumAddSparseD (l +r)")
+
+ (SMTMaybe t1, SAPJust prj') ->
+ realiseMaybeSparse ref (error "Accumulating Just into Nothing (EWith requires EDeepZero)")
+ (\ac -> accumAddSparseD t1 prj' ac idx sp val)
+
+ (SMTArr n t1, SAPArrIdx prj') ->
+ let (arrindex', idx') = idx
+ arrindex = unTupRepIdx IxNil IxCons n arrindex'
+ arrsh = arrayShape ref
+ linindex = toLinearIndex arrsh arrindex
+ in accumAddSparseD t1 prj' (arrayIndexLinear ref linindex) idx' sp val
+
+accumAddDense :: SMTy a -> RepAc a -> Sparse a b -> Rep b -> AcM s ()
+accumAddDense typ ref sp val = case (typ, sp) of
+ (_, _) | isAbsent sp -> return ()
+ (_, SpAbsent) -> return ()
+ (_, SpSparse s) ->
case val of
Nothing -> return ()
- Just (Left val1) -> accumAddSparse typ (SAPLeft SAPHere) ref () val1
- Just (Right val2) -> accumAddSparse typ (SAPRight SAPHere) ref () val2
- (STEither t1 _, SAPLeft prj') ->
- realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val)
- (\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val
- Right{} -> error "Mismatched Either in accumAddSparse (r +l)")
- (STEither _ t2, SAPRight prj') ->
- realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val)
- (\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val
- Left{} -> error "Mismatched Either in accumAddSparse (l +r)")
-
- (STMaybe{}, SAPHere) ->
+ Just val' -> accumAddDense typ ref s val'
+ (SMTPair t1 t2, SpPair s1 s2) -> do
+ accumAddDense t1 (fst ref) s1 (fst val)
+ accumAddDense t2 (snd ref) s2 (snd val)
+ (SMTLEither t1 t2, SpLEither s1 s2) ->
case val of
Nothing -> return ()
- Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val'
- (STMaybe t1, SAPJust prj') ->
- realiseMaybeSparse ref (newAcSparse t1 prj' idx val)
- (\ac -> accumAddSparse t1 prj' ac idx val)
-
- (STArr _ t1, SAPHere) ->
+ Just (Left val1) ->
+ realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)")
+ (\case Left ac1 -> accumAddDense t1 ac1 s1 val1
+ Right{} -> error "Mismatched Either in accumAddSparse (r +l)")
+ Just (Right val2) ->
+ realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)")
+ (\case Right ac2 -> accumAddDense t2 ac2 s2 val2
+ Left{} -> error "Mismatched Either in accumAddSparse (l +r)")
+ (SMTMaybe t, SpMaybe s) ->
case val of
Nothing -> return ()
Just val' ->
- realiseMaybeSparse ref
- (arrayMapM (newAcSparse t1 SAPHere ()) val')
- (\ac -> forM_ [0 .. arraySize ac - 1] $ \i ->
- accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val' i))
- (STArr n t1, SAPArrIdx prj' _) ->
- let ((arrindex', arrsh'), idx') = idx
- arrindex = unTupRepIdx IxNil IxCons n arrindex'
- arrsh = unTupRepIdx ShNil ShCons n arrsh'
- linindex = toLinearIndex arrsh arrindex
- in realiseMaybeSparse ref
- (onehotArray (\_ -> newAcSparse t1 prj' idx' val) (newAcZero t1) n prj' idx)
- (\ac -> accumAddSparse t1 prj' (arrayIndexLinear ac linindex) idx' val)
-
- (STScal sty, SAPHere) -> AcM $ case sty of
- STI32 -> return ()
- STI64 -> return ()
- STF32 -> atomicModifyIORef' ref (\x -> (x + val, ()))
- STF64 -> atomicModifyIORef' ref (\x -> (x + val, ()))
- STBool -> return ()
-
- (STAccum{}, _) -> error "Accumulators not allowed in source program"
-
+ realiseMaybeSparse ref (error "Accumulating Just into Nothing (EAccum requires EDeepZero)")
+ (\ac -> accumAddDense t ac s val')
+ (SMTArr _ t1, SpArr s) ->
+ forM_ [0 .. arraySize ref - 1] $ \i ->
+ accumAddDense t1 (arrayIndexLinear ref i) s (arrayIndexLinear val i)
+ (SMTScal sty, SpScal) -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ()))
+
+-- TODO: makeval is always 'error' now. Simplify?
realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s ()
realiseMaybeSparse ref makeval modifyval =
-- Try modifying what's already in ref. The 'join' makes the snd