diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-27 21:30:17 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-27 21:30:17 +0100 |
| commit | 20f7d7be13cd7869b338f98d1ab3fd33e8bbfb3e (patch) | |
| tree | a21c90034a02cdeb7240563dbbab355e49622d0a /src/CHAD/Interpreter.hs | |
| parent | ae634c056b500a568b2d89b7f8e225404a2c0c62 (diff) | |
WIP user-specified custom typesuser-types
The big roadblock encountered is that accumulation wants addition of
monoids to be elementwise float addition; this fundamentally clashes
with the concept of a user type with a custom zero and plus.
Diffstat (limited to 'src/CHAD/Interpreter.hs')
| -rw-r--r-- | src/CHAD/Interpreter.hs | 14 |
1 files changed, 14 insertions, 0 deletions
diff --git a/src/CHAD/Interpreter.hs b/src/CHAD/Interpreter.hs index 6410b5b..8aa02d7 100644 --- a/src/CHAD/Interpreter.hs +++ b/src/CHAD/Interpreter.hs @@ -227,6 +227,8 @@ interpret'Rec env = \case b' <- interpret' env b return $ onehotM p t a' b' EError _ _ s -> error $ "Interpreter: Program threw error: " ++ s + EUser _ _ e -> interpret' env e + EUnUser _ e -> interpret' env e interpretOp :: SOp a t -> Rep a -> Rep t interpretOp op arg = case op of @@ -267,6 +269,9 @@ zeroM typ zi = case typ of STI64 -> 0 STF32 -> 0.0 STF64 -> 0.0 + SMTUser t -> + interpretOpen False (userZeroInfo t `SCons` SNil) (Value zi `SCons` SNil) + (euserZero t (EVar ext (userZeroInfo t) IZ)) deepZeroM :: SMTy t -> Rep (DeepZeroInfo t) -> Rep t deepZeroM typ zi = case typ of @@ -280,6 +285,9 @@ deepZeroM typ zi = case typ of STI64 -> 0 STF32 -> 0.0 STF64 -> 0.0 + SMTUser t -> + interpretOpen False (userDeepZeroInfo t `SCons` SNil) (Value zi `SCons` SNil) + (euserDeepZero t (EVar ext (userDeepZeroInfo t) IZ)) addM :: SMTy t -> Rep t -> Rep t -> Rep t addM typ a b = case typ of @@ -303,6 +311,9 @@ addM typ a b = case typ of | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> addM t (arrayIndexLinear a i) (arrayIndexLinear b i)) | otherwise -> error "Plus of inconsistently shaped arrays" SMTScal sty -> numericIsNum sty $ a + b + SMTUser t -> + interpretOpen False (userRepTy t `SCons` userRepTy t `SCons` SNil) (Value a `SCons` Value b `SCons` SNil) + (euserPlus t (EVar ext (userRepTy t) IZ) (EVar ext (userRepTy t) (IS IZ))) onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdxS p a) -> Rep b -> Rep a onehotM SAPHere _ _ val = val @@ -329,6 +340,7 @@ newAcDense typ val = case typ of SMTMaybe t1 -> newIORef =<< traverse (newAcDense t1) val SMTArr _ t1 -> arrayMapM (newAcDense t1) val SMTScal _ -> newIORef val + SMTUser _ -> newIORef val onehotArray :: Monad m => (Rep (AcIdxS p a) -> m v) -- ^ the "one" @@ -348,6 +360,7 @@ readAc typ val = case typ of SMTMaybe t -> traverse (readAc t) =<< readIORef val SMTArr _ t -> traverse (readAc t) val SMTScal _ -> readIORef val + SMTUser _ -> readIORef val accumAddSparseD :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxD p a) -> Sparse b c -> Rep c -> AcM s () accumAddSparseD typ prj ref idx sp val = case (typ, prj) of @@ -408,6 +421,7 @@ accumAddDense typ ref sp val = case (typ, sp) of forM_ [0 .. arraySize ref - 1] $ \i -> accumAddDense t1 (arrayIndexLinear ref i) s (arrayIndexLinear val i) (SMTScal sty, SpScal) -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ())) + (SMTUser t, SpUser) -> AcM $ atomicModifyIORef' ref (\x -> (addM (SMTUser t) x val, ())) -- TODO: makeval is always 'error' now. Simplify? realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s () |
