summaryrefslogtreecommitdiff
path: root/src/Interpreter.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-06-16 23:21:55 +0200
committerTom Smeding <tom@tomsmeding.com>2025-06-16 23:21:55 +0200
commit2b1a40b5933b8b0dceaae744e5b70cb604822c9d (patch)
tree652d6d88efd2b0b4502819297333305cec5242c4 /src/Interpreter.hs
parenteed0f2999d6f6c8485ef53deb38f9d0a67b4f88e (diff)
CHAD.hs compiles
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r--src/Interpreter.hs39
1 files changed, 33 insertions, 6 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 803a24a..b3576ce 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -162,7 +162,7 @@ interpret'Rec env = \case
idx <- interpret' env e1
val <- interpret' env e2
accum <- interpret' env e3
- accumAddSparse t p accum idx val
+ accumAddSparseD t p accum idx val
EZero _ t ezi -> do
zi <- interpret' env ezi
return $ zeroM t zi
@@ -239,7 +239,7 @@ addM typ a b = case typ of
| otherwise -> error "Plus of inconsistently shaped arrays"
SMTScal sty -> numericIsNum sty $ a + b
-onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdx p a) -> Rep b -> Rep a
+onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdxS p a) -> Rep b -> Rep a
onehotM SAPHere _ _ val = val
onehotM (SAPFst prj) (SMTPair a b) idx val = (onehotM prj a (fst idx) val, zeroM b (snd idx))
onehotM (SAPSnd prj) (SMTPair a b) idx val = (zeroM a (fst idx), onehotM prj b (snd idx) val)
@@ -274,7 +274,7 @@ newAcDense typ val = case typ of
SMTArr _ t1 -> arrayMapM (newAcDense t1) val
SMTScal _ -> newIORef val
-newAcSparse :: SMTy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep b -> IO (RepAc a)
+newAcSparse :: SMTy a -> SAcPrj p a b -> Rep (AcIdxS p a) -> Rep b -> IO (RepAc a)
newAcSparse typ prj idx val = case (typ, prj) of
(_, SAPHere) -> newAcDense typ val
@@ -291,9 +291,9 @@ newAcSparse typ prj idx val = case (typ, prj) of
(SMTArr n t, SAPArrIdx prj') -> onehotArray (\idx' -> newAcSparse t prj' idx' val) (newAcZero t) n prj' idx
onehotArray :: Monad m
- => (Rep (AcIdx p a) -> m v) -- ^ the "one"
+ => (Rep (AcIdxS p a) -> m v) -- ^ the "one"
-> (Rep (ZeroInfo a) -> m v) -- ^ the "zero"
- -> SNat n -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> m (Array n v)
+ -> SNat n -> SAcPrj p a b -> Rep (AcIdxS (APArrIdx p) (TArr n a)) -> m (Array n v)
onehotArray mkone mkzero n _ ((arrindex', ziarr), idx) =
let arrindex = unTupRepIdx IxNil IxCons n arrindex'
arrsh = arrayShape ziarr
@@ -329,7 +329,34 @@ accumAddDense typ ref val = case typ of
accumAddDense t1 (arrayIndexLinear ref i) (arrayIndexLinear val i)
SMTScal sty -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ()))
-accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep b -> AcM s ()
+accumAddSparseD :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxD p a) -> Rep b -> AcM s ()
+accumAddSparseD typ prj ref idx val = case (typ, prj) of
+ (_, SAPHere) -> accumAddDense typ ref val
+
+ (SMTPair t1 _, SAPFst prj') -> accumAddSparseD t1 prj' (fst ref) idx val
+ (SMTPair _ t2, SAPSnd prj') -> accumAddSparseD t2 prj' (snd ref) idx val
+
+ (SMTLEither t1 _, SAPLeft prj') ->
+ realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val)
+ (\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val
+ Right{} -> error "Mismatched Either in accumAddSparseD (r +l)")
+ (SMTLEither _ t2, SAPRight prj') ->
+ realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val)
+ (\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val
+ Left{} -> error "Mismatched Either in accumAddSparseD (l +r)")
+
+ (SMTMaybe t1, SAPJust prj') ->
+ realiseMaybeSparse ref (newAcSparse t1 prj' idx val)
+ (\ac -> accumAddSparse t1 prj' ac idx val)
+
+ (SMTArr n t1, SAPArrIdx prj') ->
+ let (arrindex', idx') = idx
+ arrindex = unTupRepIdx IxNil IxCons n arrindex'
+ arrsh = arrayShape ref
+ linindex = toLinearIndex arrsh arrindex
+ in accumAddSparseD t1 prj' (arrayIndexLinear ref linindex) idx' val
+
+accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxS p a) -> Rep b -> AcM s ()
accumAddSparse typ prj ref idx val = case (typ, prj) of
(_, SAPHere) -> accumAddDense typ ref val