diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-06-16 23:21:55 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-06-16 23:21:55 +0200 |
commit | 2b1a40b5933b8b0dceaae744e5b70cb604822c9d (patch) | |
tree | 652d6d88efd2b0b4502819297333305cec5242c4 /src/Interpreter.hs | |
parent | eed0f2999d6f6c8485ef53deb38f9d0a67b4f88e (diff) |
CHAD.hs compiles
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r-- | src/Interpreter.hs | 39 |
1 files changed, 33 insertions, 6 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 803a24a..b3576ce 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -162,7 +162,7 @@ interpret'Rec env = \case idx <- interpret' env e1 val <- interpret' env e2 accum <- interpret' env e3 - accumAddSparse t p accum idx val + accumAddSparseD t p accum idx val EZero _ t ezi -> do zi <- interpret' env ezi return $ zeroM t zi @@ -239,7 +239,7 @@ addM typ a b = case typ of | otherwise -> error "Plus of inconsistently shaped arrays" SMTScal sty -> numericIsNum sty $ a + b -onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdx p a) -> Rep b -> Rep a +onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdxS p a) -> Rep b -> Rep a onehotM SAPHere _ _ val = val onehotM (SAPFst prj) (SMTPair a b) idx val = (onehotM prj a (fst idx) val, zeroM b (snd idx)) onehotM (SAPSnd prj) (SMTPair a b) idx val = (zeroM a (fst idx), onehotM prj b (snd idx) val) @@ -274,7 +274,7 @@ newAcDense typ val = case typ of SMTArr _ t1 -> arrayMapM (newAcDense t1) val SMTScal _ -> newIORef val -newAcSparse :: SMTy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep b -> IO (RepAc a) +newAcSparse :: SMTy a -> SAcPrj p a b -> Rep (AcIdxS p a) -> Rep b -> IO (RepAc a) newAcSparse typ prj idx val = case (typ, prj) of (_, SAPHere) -> newAcDense typ val @@ -291,9 +291,9 @@ newAcSparse typ prj idx val = case (typ, prj) of (SMTArr n t, SAPArrIdx prj') -> onehotArray (\idx' -> newAcSparse t prj' idx' val) (newAcZero t) n prj' idx onehotArray :: Monad m - => (Rep (AcIdx p a) -> m v) -- ^ the "one" + => (Rep (AcIdxS p a) -> m v) -- ^ the "one" -> (Rep (ZeroInfo a) -> m v) -- ^ the "zero" - -> SNat n -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> m (Array n v) + -> SNat n -> SAcPrj p a b -> Rep (AcIdxS (APArrIdx p) (TArr n a)) -> m (Array n v) onehotArray mkone mkzero n _ ((arrindex', ziarr), idx) = let arrindex = unTupRepIdx IxNil IxCons n arrindex' arrsh = arrayShape ziarr @@ -329,7 +329,34 @@ accumAddDense typ ref val = case typ of accumAddDense t1 (arrayIndexLinear ref i) (arrayIndexLinear val i) SMTScal sty -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ())) -accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep b -> AcM s () +accumAddSparseD :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxD p a) -> Rep b -> AcM s () +accumAddSparseD typ prj ref idx val = case (typ, prj) of + (_, SAPHere) -> accumAddDense typ ref val + + (SMTPair t1 _, SAPFst prj') -> accumAddSparseD t1 prj' (fst ref) idx val + (SMTPair _ t2, SAPSnd prj') -> accumAddSparseD t2 prj' (snd ref) idx val + + (SMTLEither t1 _, SAPLeft prj') -> + realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val) + (\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val + Right{} -> error "Mismatched Either in accumAddSparseD (r +l)") + (SMTLEither _ t2, SAPRight prj') -> + realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val) + (\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val + Left{} -> error "Mismatched Either in accumAddSparseD (l +r)") + + (SMTMaybe t1, SAPJust prj') -> + realiseMaybeSparse ref (newAcSparse t1 prj' idx val) + (\ac -> accumAddSparse t1 prj' ac idx val) + + (SMTArr n t1, SAPArrIdx prj') -> + let (arrindex', idx') = idx + arrindex = unTupRepIdx IxNil IxCons n arrindex' + arrsh = arrayShape ref + linindex = toLinearIndex arrsh arrindex + in accumAddSparseD t1 prj' (arrayIndexLinear ref linindex) idx' val + +accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxS p a) -> Rep b -> AcM s () accumAddSparse typ prj ref idx val = case (typ, prj) of (_, SAPHere) -> accumAddDense typ ref val |