diff options
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r-- | src/Interpreter.hs | 398 |
1 files changed, 187 insertions, 211 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs index ddc3479..ffc2929 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -21,9 +21,11 @@ module Interpreter ( ) where import Control.Monad (foldM, join, when, forM_) +import Data.Bifunctor (bimap) import Data.Bitraversable (bitraverse) import Data.Char (isSpace) import Data.Functor.Identity +import qualified Data.Functor.Product as Product import Data.Int (Int64) import Data.IORef import System.IO (hPutStrLn, stderr) @@ -34,7 +36,7 @@ import Debug.Trace import Array import AST import AST.Pretty -import CHAD.Types +import AST.Sparse.Types import Data import Interpreter.Rep @@ -48,35 +50,39 @@ runAcM (AcM m) = unsafePerformIO m acmDebugLog :: String -> AcM s () acmDebugLog s = AcM (hPutStrLn stderr s) +data V t = V (STy t) (Rep t) + interpret :: Ex '[] t -> Rep t -interpret = interpretOpen False SNil +interpret = interpretOpen False SNil SNil -- | Bool: whether to trace execution with debug prints (very verbose) -interpretOpen :: Bool -> SList Value env -> Ex env t -> Rep t -interpretOpen prints env e = +interpretOpen :: Bool -> SList STy env -> SList Value env -> Ex env t -> Rep t +interpretOpen prints env venv e = runAcM $ let ?depth = 0 ?prints = prints - in interpret' env e + in interpret' (slistMap (\(Product.Pair t (Value v)) -> V t v) (slistZip env venv)) e -interpret' :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList Value env -> Ex env t -> AcM s (Rep t) +interpret' :: forall env t s. (?prints :: Bool, ?depth :: Int) + => SList V env -> Ex env t -> AcM s (Rep t) interpret' env e = do + let tenv = slistMap (\(V t _) -> t) env let dep = ?depth let lenlimit = max 20 (100 - dep) let replace a b = map (\c -> if c == a then b else c) let trunc s | length s > lenlimit = take (lenlimit - 3) (replace '\n' ' ' s) ++ "..." | otherwise = replace '\n' ' ' s - when ?prints $ acmDebugLog $ replicate dep ' ' ++ "ev: " ++ trunc (ppExpr env e) + when ?prints $ acmDebugLog $ replicate dep ' ' ++ "ev: " ++ trunc (ppExpr tenv e) res <- let ?depth = dep + 1 in interpret'Rec env e when ?prints $ acmDebugLog $ replicate dep ' ' ++ "<- " ++ showValue 0 (typeOf e) res "" return res -interpret'Rec :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList Value env -> Ex env t -> AcM s (Rep t) +interpret'Rec :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList V env -> Ex env t -> AcM s (Rep t) interpret'Rec env = \case - EVar _ _ i -> case slistIdx env i of Value x -> return x + EVar _ _ i -> case slistIdx env i of V _ x -> return x ELet _ a b -> do x <- interpret' env a - let ?depth = ?depth - 1 in interpret' (Value x `SCons` env) b + let ?depth = ?depth - 1 in interpret' (V (typeOf a) x `SCons` env) b expr | False && trace ("<i> " ++ takeWhile (not . isSpace) (show expr)) False -> undefined EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b EFst _ e -> fst <$> interpret' env e @@ -84,18 +90,32 @@ interpret'Rec env = \case ENil _ -> return () EInl _ _ e -> Left <$> interpret' env e EInr _ _ e -> Right <$> interpret' env e - ECase _ e a b -> interpret' env e >>= \case - Left x -> interpret' (Value x `SCons` env) a - Right y -> interpret' (Value y `SCons` env) b + ECase _ e a b -> + let STEither t1 t2 = typeOf e + in interpret' env e >>= \case + Left x -> interpret' (V t1 x `SCons` env) a + Right y -> interpret' (V t2 y `SCons` env) b ENothing _ _ -> return Nothing EJust _ e -> Just <$> interpret' env e - EMaybe _ a b e -> maybe (interpret' env a) (\x -> interpret' (Value x `SCons` env) b) =<< interpret' env e + EMaybe _ a b e -> + let STMaybe t1 = typeOf e + in maybe (interpret' env a) (\x -> interpret' (V t1 x `SCons` env) b) =<< interpret' env e + ELNil _ _ _ -> return Nothing + ELInl _ _ e -> Just . Left <$> interpret' env e + ELInr _ _ e -> Just . Right <$> interpret' env e + ELCase _ e a b c -> + let STLEither t1 t2 = typeOf e + in interpret' env e >>= \case + Nothing -> interpret' env a + Just (Left x) -> interpret' (V t1 x `SCons` env) b + Just (Right y) -> interpret' (V t2 y `SCons` env) c EConstArr _ _ _ v -> return v EBuild _ dim a b -> do sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a - arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx ixUncons dim idx) `SCons` env) b) + arrayGenerateM sh (\idx -> interpret' (V (tTup (sreplicate dim tIx)) (tupRepIdx ixUncons dim idx) `SCons` env) b) EFold1Inner _ _ a b c -> do - let f = \x y -> interpret' (Value y `SCons` Value x `SCons` env) a + let t = typeOf b + let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a x0 <- interpret' env b arr <- interpret' env c let sh `ShCons` n = arrayShape arr @@ -126,34 +146,39 @@ interpret'Rec env = \case EConst _ _ v -> return v EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b) - EIdx _ a b - | STArr n _ <- typeOf a - -> arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b) + EIdx _ a b -> + let STArr n _ = typeOf a + in arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b) EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e EOp _ op e -> interpretOp op <$> interpret' env e - ECustom _ _ _ _ pr _ _ e1 e2 -> do + ECustom _ t1 t2 _ pr _ _ e1 e2 -> do e1' <- interpret' env e1 e2' <- interpret' env e2 - interpret' (Value e2' `SCons` Value e1' `SCons` SNil) pr + interpret' (V t2 e2' `SCons` V t1 e1' `SCons` SNil) pr + ERecompute _ e -> interpret' env e EWith _ t e1 e2 -> do initval <- interpret' env e1 withAccum t (typeOf e2) initval $ \accum -> - interpret' (Value accum `SCons` env) e2 - EAccum _ t p e1 e2 e3 -> do + interpret' (V (STAccum t) accum `SCons` env) e2 + EAccum _ t p e1 sp e2 e3 -> do idx <- interpret' env e1 val <- interpret' env e2 accum <- interpret' env e3 - accumAddSparse t p accum idx val - EZero _ t -> do - return $ zeroD2 t + accumAddSparseD t p accum idx sp val + EZero _ t ezi -> do + zi <- interpret' env ezi + return $ zeroM t zi + EDeepZero _ t ezi -> do + zi <- interpret' env ezi + return $ deepZeroM t zi EPlus _ t a b -> do a' <- interpret' env a b' <- interpret' env b - return $ addD2s t a' b' + return $ addM t a' b' EOneHot _ t p a b -> do a' <- interpret' env a b' <- interpret' env b - return $ onehotD2 p t a' b' + return $ onehotM p t a' b' EError _ _ s -> error $ "Interpreter: Program threw error: " ++ s interpretOp :: SOp a t -> Rep a -> Rep t @@ -174,6 +199,7 @@ interpretOp op arg = case op of OExp st -> floatingIsFractional st $ exp arg OLog st -> floatingIsFractional st $ log arg OIDiv st -> integralIsIntegral st $ uncurry quot arg + OMod st -> integralIsIntegral st $ uncurry rem arg where styIsEq :: SScalTy t -> (Eq (Rep (TScal t)) => r) -> r styIsEq STI32 = id @@ -182,211 +208,161 @@ interpretOp op arg = case op of styIsEq STF64 = id styIsEq STBool = id -zeroD2 :: STy t -> Rep (D2 t) -zeroD2 typ = case typ of - STNil -> () - STPair _ _ -> Nothing - STEither _ _ -> Nothing - STMaybe _ -> Nothing - STArr _ _ -> Nothing - STScal sty -> case sty of - STI32 -> () - STI64 -> () +zeroM :: SMTy t -> Rep (ZeroInfo t) -> Rep t +zeroM typ zi = case typ of + SMTNil -> () + SMTPair t1 t2 -> (zeroM t1 (fst zi), zeroM t2 (snd zi)) + SMTLEither _ _ -> Nothing + SMTMaybe _ -> Nothing + SMTArr _ t -> arrayMap (zeroM t) zi + SMTScal sty -> case sty of + STI32 -> 0 + STI64 -> 0 STF32 -> 0.0 STF64 -> 0.0 - STBool -> () - STAccum{} -> error "Zero of Accum" -addD2s :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t) -addD2s typ a b = case typ of - STNil -> () - STPair t1 t2 -> case (a, b) of - (Nothing, _) -> b - (_, Nothing) -> a - (Just (x1, x2), Just (y1, y2)) -> Just (addD2s t1 x1 y1, addD2s t2 x2 y2) - STEither t1 t2 -> case (a, b) of - (Nothing, _) -> b - (_, Nothing) -> a - (Just (Left x), Just (Left y)) -> Just (Left (addD2s t1 x y)) - (Just (Right x), Just (Right y)) -> Just (Right (addD2s t2 x y)) - _ -> error "Plus of inconsistent Eithers" - STMaybe t -> case (a, b) of +deepZeroM :: SMTy t -> Rep (DeepZeroInfo t) -> Rep t +deepZeroM typ zi = case typ of + SMTNil -> () + SMTPair t1 t2 -> (deepZeroM t1 (fst zi), deepZeroM t2 (snd zi)) + SMTLEither t1 t2 -> fmap (bimap (deepZeroM t1) (deepZeroM t2)) zi + SMTMaybe t -> fmap (deepZeroM t) zi + SMTArr _ t -> arrayMap (deepZeroM t) zi + SMTScal sty -> case sty of + STI32 -> 0 + STI64 -> 0 + STF32 -> 0.0 + STF64 -> 0.0 + +addM :: SMTy t -> Rep t -> Rep t -> Rep t +addM typ a b = case typ of + SMTNil -> () + SMTPair t1 t2 -> (addM t1 (fst a) (fst b), addM t2 (snd a) (snd b)) + SMTLEither t1 t2 -> case (a, b) of (Nothing, _) -> b (_, Nothing) -> a - (Just x, Just y) -> Just (addD2s t x y) - STArr _ t -> case (a, b) of + (Just (Left x), Just (Left y)) -> Just (Left (addM t1 x y)) + (Just (Right x), Just (Right y)) -> Just (Right (addM t2 x y)) + _ -> error "Plus of inconsistent LEithers" + SMTMaybe t -> case (a, b) of (Nothing, _) -> b (_, Nothing) -> a - (Just x, Just y) -> - let sh1 = arrayShape x - sh2 = arrayShape y - in if | shapeSize sh1 == 0 -> Just y - | shapeSize sh2 == 0 -> Just x - | sh1 == sh2 -> Just $ arrayGenerateLin sh1 (\i -> addD2s t (arrayIndexLinear x i) (arrayIndexLinear y i)) - | otherwise -> error "Plus of inconsistently shaped arrays" - STScal sty -> case sty of - STI32 -> () - STI64 -> () - STF32 -> a + b - STF64 -> a + b - STBool -> () - STAccum{} -> error "Plus of Accum" - -onehotD2 :: SAcPrj p a b -> STy a -> Rep (AcIdx p a) -> Rep (D2 b) -> Rep (D2 a) -onehotD2 SAPHere _ _ val = val -onehotD2 (SAPFst prj) (STPair a b) idx val = Just (onehotD2 prj a idx val, zeroD2 b) -onehotD2 (SAPSnd prj) (STPair a b) idx val = Just (zeroD2 a, onehotD2 prj b idx val) -onehotD2 (SAPLeft prj) (STEither a _) idx val = Just (Left (onehotD2 prj a idx val)) -onehotD2 (SAPRight prj) (STEither _ b) idx val = Just (Right (onehotD2 prj b idx val)) -onehotD2 (SAPJust prj) (STMaybe a) idx val = Just (onehotD2 prj a idx val) -onehotD2 (SAPArrIdx prj _) (STArr n a) idx val = - Just $ runIdentity $ onehotArray (\idx' -> Identity (onehotD2 prj a idx' val)) (Identity (zeroD2 a)) n prj idx - -withAccum :: STy t -> STy a -> Rep (D2 t) -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep (D2 t)) + (Just x, Just y) -> Just (addM t x y) + SMTArr _ t -> + let sh1 = arrayShape a + sh2 = arrayShape b + in if | shapeSize sh1 == 0 -> b + | shapeSize sh2 == 0 -> a + | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> addM t (arrayIndexLinear a i) (arrayIndexLinear b i)) + | otherwise -> error "Plus of inconsistently shaped arrays" + SMTScal sty -> numericIsNum sty $ a + b + +onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdxS p a) -> Rep b -> Rep a +onehotM SAPHere _ _ val = val +onehotM (SAPFst prj) (SMTPair a b) idx val = (onehotM prj a (fst idx) val, zeroM b (snd idx)) +onehotM (SAPSnd prj) (SMTPair a b) idx val = (zeroM a (fst idx), onehotM prj b (snd idx) val) +onehotM (SAPLeft prj) (SMTLEither a _) idx val = Just (Left (onehotM prj a idx val)) +onehotM (SAPRight prj) (SMTLEither _ b) idx val = Just (Right (onehotM prj b idx val)) +onehotM (SAPJust prj) (SMTMaybe a) idx val = Just (onehotM prj a idx val) +onehotM (SAPArrIdx prj) (SMTArr n a) idx val = + runIdentity $ onehotArray (\idx' -> Identity (onehotM prj a idx' val)) (\zi -> Identity (zeroM a zi)) n prj idx + +withAccum :: SMTy t -> STy a -> Rep t -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t) withAccum t _ initval f = AcM $ do - accum <- newAcSparse t SAPHere () initval + accum <- newAcDense t initval out <- unAcM $ f accum - val <- readAcSparse t accum + val <- readAc t accum return (out, val) -newAcZero :: STy t -> IO (RepAc t) -newAcZero = \case - STNil -> return () - STPair{} -> newIORef Nothing - STEither{} -> newIORef Nothing - STMaybe _ -> newIORef Nothing - STArr _ _ -> newIORef Nothing - STScal sty -> case sty of - STI32 -> return () - STI64 -> return () - STF32 -> newIORef 0.0 - STF64 -> newIORef 0.0 - STBool -> return () - STAccum{} -> error "Nested accumulators" - -newAcSparse :: STy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep (D2 b) -> IO (RepAc a) -newAcSparse typ prj idx val = case (typ, prj) of - (STNil, SAPHere) -> return () - (STPair t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val - (STEither t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val - (STMaybe t1, SAPHere) -> newIORef =<< traverse (newAcSparse t1 SAPHere ()) val - (STArr _ t1, SAPHere) -> newIORef =<< traverse (traverse (newAcSparse t1 SAPHere ())) val - (STScal sty, SAPHere) -> case sty of - STI32 -> return () - STI64 -> return () - STF32 -> newIORef val - STF64 -> newIORef val - STBool -> return () - - (STPair t1 t2, SAPFst prj') -> - newIORef . Just =<< (,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2 - (STPair t1 t2, SAPSnd prj') -> - newIORef . Just =<< (,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val - - (STEither t1 _, SAPLeft prj') -> newIORef . Just . Left =<< newAcSparse t1 prj' idx val - (STEither _ t2, SAPRight prj') -> newIORef . Just . Right =<< newAcSparse t2 prj' idx val - - (STMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val - - (STArr n t, SAPArrIdx prj' _) -> newIORef . Just =<< newAcArray n t prj' idx val - - (STAccum{}, _) -> error "Accumulators not allowed in source program" - -newAcArray :: SNat n -> STy a -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> Rep (D2 b) -> IO (Array n (RepAc a)) -newAcArray n t prj idx val = onehotArray (\idx' -> newAcSparse t prj idx' val) (newAcZero t) n prj idx +newAcDense :: SMTy a -> Rep a -> IO (RepAc a) +newAcDense typ val = case typ of + SMTNil -> return () + SMTPair t1 t2 -> bitraverse (newAcDense t1) (newAcDense t2) val + SMTLEither t1 t2 -> newIORef =<< traverse (bitraverse (newAcDense t1) (newAcDense t2)) val + SMTMaybe t1 -> newIORef =<< traverse (newAcDense t1) val + SMTArr _ t1 -> arrayMapM (newAcDense t1) val + SMTScal _ -> newIORef val onehotArray :: Monad m - => (Rep (AcIdx p a) -> m v) -- ^ the "one" - -> m v -- ^ the "zero" - -> SNat n -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> m (Array n v) -onehotArray mkone mkzero n _ ((arrindex', arrsh'), idx) = + => (Rep (AcIdxS p a) -> m v) -- ^ the "one" + -> (Rep (ZeroInfo a) -> m v) -- ^ the "zero" + -> SNat n -> SAcPrj p a b -> Rep (AcIdxS (APArrIdx p) (TArr n a)) -> m (Array n v) +onehotArray mkone mkzero n _ ((arrindex', ziarr), idx) = let arrindex = unTupRepIdx IxNil IxCons n arrindex' - arrsh = unTupRepIdx ShNil ShCons n arrsh' + arrsh = arrayShape ziarr !linindex = toLinearIndex arrsh arrindex - in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero) - -readAcSparse :: STy t -> RepAc t -> IO (Rep (D2 t)) -readAcSparse typ val = case typ of - STNil -> return () - STPair t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val - STEither t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val - STMaybe t -> traverse (readAcSparse t) =<< readIORef val - STArr _ t -> traverse (traverse (readAcSparse t)) =<< readIORef val - STScal sty -> case sty of - STI32 -> return () - STI64 -> return () - STF32 -> readIORef val - STF64 -> readIORef val - STBool -> return () - STAccum{} -> error "Nested accumulators" - -accumAddSparse :: STy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep (D2 b) -> AcM s () -accumAddSparse typ prj ref idx val = case (typ, prj) of - (STNil, SAPHere) -> return () - - (STPair t1 t2, SAPHere) -> - case val of - Nothing -> return () - Just (val1, val2) -> - realiseMaybeSparse ref ((,) <$> newAcSparse t1 SAPHere () val1 - <*> newAcSparse t2 SAPHere () val2) - (\(ac1, ac2) -> do accumAddSparse t1 SAPHere ac1 () val1 - accumAddSparse t2 SAPHere ac2 () val2) - (STPair t1 t2, SAPFst prj') -> - realiseMaybeSparse ref ((,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2) - (\(ac1, _) -> do accumAddSparse t1 prj' ac1 idx val) - (STPair t1 t2, SAPSnd prj') -> - realiseMaybeSparse ref ((,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val) - (\(_, ac2) -> do accumAddSparse t2 prj' ac2 idx val) - - (STEither{}, SAPHere) -> + in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero (ziarr `arrayIndexLinear` i)) + +readAc :: SMTy t -> RepAc t -> IO (Rep t) +readAc typ val = case typ of + SMTNil -> return () + SMTPair t1 t2 -> bitraverse (readAc t1) (readAc t2) val + SMTLEither t1 t2 -> traverse (bitraverse (readAc t1) (readAc t2)) =<< readIORef val + SMTMaybe t -> traverse (readAc t) =<< readIORef val + SMTArr _ t -> traverse (readAc t) val + SMTScal _ -> readIORef val + +accumAddSparseD :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxD p a) -> Sparse b c -> Rep c -> AcM s () +accumAddSparseD typ prj ref idx sp val = case (typ, prj) of + (_, SAPHere) -> accumAddDense typ ref sp val + + (SMTPair t1 _, SAPFst prj') -> accumAddSparseD t1 prj' (fst ref) idx sp val + (SMTPair _ t2, SAPSnd prj') -> accumAddSparseD t2 prj' (snd ref) idx sp val + + (SMTLEither t1 _, SAPLeft prj') -> + realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)") + (\case Left ac1 -> accumAddSparseD t1 prj' ac1 idx sp val + Right{} -> error "Mismatched Either in accumAddSparseD (r +l)") + (SMTLEither _ t2, SAPRight prj') -> + realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)") + (\case Right ac2 -> accumAddSparseD t2 prj' ac2 idx sp val + Left{} -> error "Mismatched Either in accumAddSparseD (l +r)") + + (SMTMaybe t1, SAPJust prj') -> + realiseMaybeSparse ref (error "Accumulating Just into Nothing (EWith requires EDeepZero)") + (\ac -> accumAddSparseD t1 prj' ac idx sp val) + + (SMTArr n t1, SAPArrIdx prj') -> + let (arrindex', idx') = idx + arrindex = unTupRepIdx IxNil IxCons n arrindex' + arrsh = arrayShape ref + linindex = toLinearIndex arrsh arrindex + in accumAddSparseD t1 prj' (arrayIndexLinear ref linindex) idx' sp val + +accumAddDense :: SMTy a -> RepAc a -> Sparse a b -> Rep b -> AcM s () +accumAddDense typ ref sp val = case (typ, sp) of + (_, _) | isAbsent sp -> return () + (_, SpAbsent) -> return () + (_, SpSparse s) -> case val of Nothing -> return () - Just (Left val1) -> accumAddSparse typ (SAPLeft SAPHere) ref () val1 - Just (Right val2) -> accumAddSparse typ (SAPRight SAPHere) ref () val2 - (STEither t1 _, SAPLeft prj') -> - realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val) - (\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val - Right{} -> error "Mismatched Either in accumAddSparse (r +l)") - (STEither _ t2, SAPRight prj') -> - realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val) - (\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val - Left{} -> error "Mismatched Either in accumAddSparse (l +r)") - - (STMaybe{}, SAPHere) -> + Just val' -> accumAddDense typ ref s val' + (SMTPair t1 t2, SpPair s1 s2) -> do + accumAddDense t1 (fst ref) s1 (fst val) + accumAddDense t2 (snd ref) s2 (snd val) + (SMTLEither t1 t2, SpLEither s1 s2) -> case val of Nothing -> return () - Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val' - (STMaybe t1, SAPJust prj') -> - realiseMaybeSparse ref (newAcSparse t1 prj' idx val) - (\ac -> accumAddSparse t1 prj' ac idx val) - - (STArr _ t1, SAPHere) -> + Just (Left val1) -> + realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)") + (\case Left ac1 -> accumAddDense t1 ac1 s1 val1 + Right{} -> error "Mismatched Either in accumAddSparse (r +l)") + Just (Right val2) -> + realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)") + (\case Right ac2 -> accumAddDense t2 ac2 s2 val2 + Left{} -> error "Mismatched Either in accumAddSparse (l +r)") + (SMTMaybe t, SpMaybe s) -> case val of Nothing -> return () Just val' -> - realiseMaybeSparse ref - (arrayMapM (newAcSparse t1 SAPHere ()) val') - (\ac -> forM_ [0 .. arraySize ac - 1] $ \i -> - accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val' i)) - (STArr n t1, SAPArrIdx prj' _) -> - let ((arrindex', arrsh'), idx') = idx - arrindex = unTupRepIdx IxNil IxCons n arrindex' - arrsh = unTupRepIdx ShNil ShCons n arrsh' - linindex = toLinearIndex arrsh arrindex - in realiseMaybeSparse ref - (onehotArray (\_ -> newAcSparse t1 prj' idx' val) (newAcZero t1) n prj' idx) - (\ac -> accumAddSparse t1 prj' (arrayIndexLinear ac linindex) idx' val) - - (STScal sty, SAPHere) -> AcM $ case sty of - STI32 -> return () - STI64 -> return () - STF32 -> atomicModifyIORef' ref (\x -> (x + val, ())) - STF64 -> atomicModifyIORef' ref (\x -> (x + val, ())) - STBool -> return () - - (STAccum{}, _) -> error "Accumulators not allowed in source program" - + realiseMaybeSparse ref (error "Accumulating Just into Nothing (EAccum requires EDeepZero)") + (\ac -> accumAddDense t ac s val') + (SMTArr _ t1, SpArr s) -> + forM_ [0 .. arraySize ref - 1] $ \i -> + accumAddDense t1 (arrayIndexLinear ref i) s (arrayIndexLinear val i) + (SMTScal sty, SpScal) -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ())) + +-- TODO: makeval is always 'error' now. Simplify? realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s () realiseMaybeSparse ref makeval modifyval = -- Try modifying what's already in ref. The 'join' makes the snd |