summaryrefslogtreecommitdiff
path: root/src/Interpreter.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-06-21 09:57:45 +0200
committerTom Smeding <tom@tomsmeding.com>2025-06-21 09:57:45 +0200
commitb5ed3d2fcc249cb410b9e86d25e9ef808c6dba97 (patch)
tree66383b16d5d95f939aaa165a783dbbfd99a57fe3 /src/Interpreter.hs
parent8bbc2d2867e3d0a4a1f2810b40e92175779822e1 (diff)
parenta4b3eb76acbec30ffeae119a4dc6e4c9f64396fe (diff)
Merge branch 'sparse'HEADmaster
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r--src/Interpreter.hs138
1 files changed, 72 insertions, 66 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 803a24a..ffc2929 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -21,6 +21,7 @@ module Interpreter (
) where
import Control.Monad (foldM, join, when, forM_)
+import Data.Bifunctor (bimap)
import Data.Bitraversable (bitraverse)
import Data.Char (isSpace)
import Data.Functor.Identity
@@ -35,6 +36,7 @@ import Debug.Trace
import Array
import AST
import AST.Pretty
+import AST.Sparse.Types
import Data
import Interpreter.Rep
@@ -158,14 +160,17 @@ interpret'Rec env = \case
initval <- interpret' env e1
withAccum t (typeOf e2) initval $ \accum ->
interpret' (V (STAccum t) accum `SCons` env) e2
- EAccum _ t p e1 e2 e3 -> do
+ EAccum _ t p e1 sp e2 e3 -> do
idx <- interpret' env e1
val <- interpret' env e2
accum <- interpret' env e3
- accumAddSparse t p accum idx val
+ accumAddSparseD t p accum idx sp val
EZero _ t ezi -> do
zi <- interpret' env ezi
return $ zeroM t zi
+ EDeepZero _ t ezi -> do
+ zi <- interpret' env ezi
+ return $ deepZeroM t zi
EPlus _ t a b -> do
a' <- interpret' env a
b' <- interpret' env b
@@ -216,6 +221,19 @@ zeroM typ zi = case typ of
STF32 -> 0.0
STF64 -> 0.0
+deepZeroM :: SMTy t -> Rep (DeepZeroInfo t) -> Rep t
+deepZeroM typ zi = case typ of
+ SMTNil -> ()
+ SMTPair t1 t2 -> (deepZeroM t1 (fst zi), deepZeroM t2 (snd zi))
+ SMTLEither t1 t2 -> fmap (bimap (deepZeroM t1) (deepZeroM t2)) zi
+ SMTMaybe t -> fmap (deepZeroM t) zi
+ SMTArr _ t -> arrayMap (deepZeroM t) zi
+ SMTScal sty -> case sty of
+ STI32 -> 0
+ STI64 -> 0
+ STF32 -> 0.0
+ STF64 -> 0.0
+
addM :: SMTy t -> Rep t -> Rep t -> Rep t
addM typ a b = case typ of
SMTNil -> ()
@@ -239,7 +257,7 @@ addM typ a b = case typ of
| otherwise -> error "Plus of inconsistently shaped arrays"
SMTScal sty -> numericIsNum sty $ a + b
-onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdx p a) -> Rep b -> Rep a
+onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdxS p a) -> Rep b -> Rep a
onehotM SAPHere _ _ val = val
onehotM (SAPFst prj) (SMTPair a b) idx val = (onehotM prj a (fst idx) val, zeroM b (snd idx))
onehotM (SAPSnd prj) (SMTPair a b) idx val = (zeroM a (fst idx), onehotM prj b (snd idx) val)
@@ -256,15 +274,6 @@ withAccum t _ initval f = AcM $ do
val <- readAc t accum
return (out, val)
-newAcZero :: SMTy t -> Rep (ZeroInfo t) -> IO (RepAc t)
-newAcZero typ zi = case typ of
- SMTNil -> return ()
- SMTPair t1 t2 -> bitraverse (newAcZero t1) (newAcZero t2) zi
- SMTLEither{} -> newIORef Nothing
- SMTMaybe _ -> newIORef Nothing
- SMTArr _ t -> arrayMapM (newAcZero t) zi
- SMTScal sty -> numericIsNum sty $ newIORef 0
-
newAcDense :: SMTy a -> Rep a -> IO (RepAc a)
newAcDense typ val = case typ of
SMTNil -> return ()
@@ -274,26 +283,10 @@ newAcDense typ val = case typ of
SMTArr _ t1 -> arrayMapM (newAcDense t1) val
SMTScal _ -> newIORef val
-newAcSparse :: SMTy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep b -> IO (RepAc a)
-newAcSparse typ prj idx val = case (typ, prj) of
- (_, SAPHere) -> newAcDense typ val
-
- (SMTPair t1 t2, SAPFst prj') ->
- (,) <$> newAcSparse t1 prj' (fst idx) val <*> newAcZero t2 (snd idx)
- (SMTPair t1 t2, SAPSnd prj') ->
- (,) <$> newAcZero t1 (fst idx) <*> newAcSparse t2 prj' (snd idx) val
-
- (SMTLEither t1 _, SAPLeft prj') -> newIORef . Just . Left =<< newAcSparse t1 prj' idx val
- (SMTLEither _ t2, SAPRight prj') -> newIORef . Just . Right =<< newAcSparse t2 prj' idx val
-
- (SMTMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val
-
- (SMTArr n t, SAPArrIdx prj') -> onehotArray (\idx' -> newAcSparse t prj' idx' val) (newAcZero t) n prj' idx
-
onehotArray :: Monad m
- => (Rep (AcIdx p a) -> m v) -- ^ the "one"
+ => (Rep (AcIdxS p a) -> m v) -- ^ the "one"
-> (Rep (ZeroInfo a) -> m v) -- ^ the "zero"
- -> SNat n -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> m (Array n v)
+ -> SNat n -> SAcPrj p a b -> Rep (AcIdxS (APArrIdx p) (TArr n a)) -> m (Array n v)
onehotArray mkone mkzero n _ ((arrindex', ziarr), idx) =
let arrindex = unTupRepIdx IxNil IxCons n arrindex'
arrsh = arrayShape ziarr
@@ -309,54 +302,67 @@ readAc typ val = case typ of
SMTArr _ t -> traverse (readAc t) val
SMTScal _ -> readIORef val
-accumAddDense :: SMTy a -> RepAc a -> Rep a -> AcM s ()
-accumAddDense typ ref val = case typ of
- SMTNil -> return ()
- SMTPair t1 t2 -> do
- accumAddDense t1 (fst ref) (fst val)
- accumAddDense t2 (snd ref) (snd val)
- SMTLEither{} ->
- case val of
- Nothing -> return ()
- Just (Left val1) -> accumAddSparse typ (SAPLeft SAPHere) ref () val1
- Just (Right val2) -> accumAddSparse typ (SAPRight SAPHere) ref () val2
- SMTMaybe{} ->
- case val of
- Nothing -> return ()
- Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val'
- SMTArr _ t1 ->
- forM_ [0 .. arraySize ref - 1] $ \i ->
- accumAddDense t1 (arrayIndexLinear ref i) (arrayIndexLinear val i)
- SMTScal sty -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ()))
-
-accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep b -> AcM s ()
-accumAddSparse typ prj ref idx val = case (typ, prj) of
- (_, SAPHere) -> accumAddDense typ ref val
+accumAddSparseD :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxD p a) -> Sparse b c -> Rep c -> AcM s ()
+accumAddSparseD typ prj ref idx sp val = case (typ, prj) of
+ (_, SAPHere) -> accumAddDense typ ref sp val
- (SMTPair t1 _, SAPFst prj') -> accumAddSparse t1 prj' (fst ref) (fst idx) val
- (SMTPair _ t2, SAPSnd prj') -> accumAddSparse t2 prj' (snd ref) (snd idx) val
+ (SMTPair t1 _, SAPFst prj') -> accumAddSparseD t1 prj' (fst ref) idx sp val
+ (SMTPair _ t2, SAPSnd prj') -> accumAddSparseD t2 prj' (snd ref) idx sp val
(SMTLEither t1 _, SAPLeft prj') ->
- realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val)
- (\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val
- Right{} -> error "Mismatched Either in accumAddSparse (r +l)")
+ realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)")
+ (\case Left ac1 -> accumAddSparseD t1 prj' ac1 idx sp val
+ Right{} -> error "Mismatched Either in accumAddSparseD (r +l)")
(SMTLEither _ t2, SAPRight prj') ->
- realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val)
- (\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val
- Left{} -> error "Mismatched Either in accumAddSparse (l +r)")
+ realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)")
+ (\case Right ac2 -> accumAddSparseD t2 prj' ac2 idx sp val
+ Left{} -> error "Mismatched Either in accumAddSparseD (l +r)")
(SMTMaybe t1, SAPJust prj') ->
- realiseMaybeSparse ref (newAcSparse t1 prj' idx val)
- (\ac -> accumAddSparse t1 prj' ac idx val)
+ realiseMaybeSparse ref (error "Accumulating Just into Nothing (EWith requires EDeepZero)")
+ (\ac -> accumAddSparseD t1 prj' ac idx sp val)
(SMTArr n t1, SAPArrIdx prj') ->
- let ((arrindex', ziarr), idx') = idx
+ let (arrindex', idx') = idx
arrindex = unTupRepIdx IxNil IxCons n arrindex'
- arrsh = arrayShape ziarr
+ arrsh = arrayShape ref
linindex = toLinearIndex arrsh arrindex
- in accumAddSparse t1 prj' (arrayIndexLinear ref linindex) idx' val
+ in accumAddSparseD t1 prj' (arrayIndexLinear ref linindex) idx' sp val
+accumAddDense :: SMTy a -> RepAc a -> Sparse a b -> Rep b -> AcM s ()
+accumAddDense typ ref sp val = case (typ, sp) of
+ (_, _) | isAbsent sp -> return ()
+ (_, SpAbsent) -> return ()
+ (_, SpSparse s) ->
+ case val of
+ Nothing -> return ()
+ Just val' -> accumAddDense typ ref s val'
+ (SMTPair t1 t2, SpPair s1 s2) -> do
+ accumAddDense t1 (fst ref) s1 (fst val)
+ accumAddDense t2 (snd ref) s2 (snd val)
+ (SMTLEither t1 t2, SpLEither s1 s2) ->
+ case val of
+ Nothing -> return ()
+ Just (Left val1) ->
+ realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)")
+ (\case Left ac1 -> accumAddDense t1 ac1 s1 val1
+ Right{} -> error "Mismatched Either in accumAddSparse (r +l)")
+ Just (Right val2) ->
+ realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)")
+ (\case Right ac2 -> accumAddDense t2 ac2 s2 val2
+ Left{} -> error "Mismatched Either in accumAddSparse (l +r)")
+ (SMTMaybe t, SpMaybe s) ->
+ case val of
+ Nothing -> return ()
+ Just val' ->
+ realiseMaybeSparse ref (error "Accumulating Just into Nothing (EAccum requires EDeepZero)")
+ (\ac -> accumAddDense t ac s val')
+ (SMTArr _ t1, SpArr s) ->
+ forM_ [0 .. arraySize ref - 1] $ \i ->
+ accumAddDense t1 (arrayIndexLinear ref i) s (arrayIndexLinear val i)
+ (SMTScal sty, SpScal) -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ()))
+-- TODO: makeval is always 'error' now. Simplify?
realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s ()
realiseMaybeSparse ref makeval modifyval =
-- Try modifying what's already in ref. The 'join' makes the snd