aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/Interpreter.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD/Interpreter.hs')
-rw-r--r--src/CHAD/Interpreter.hs14
1 files changed, 14 insertions, 0 deletions
diff --git a/src/CHAD/Interpreter.hs b/src/CHAD/Interpreter.hs
index 6410b5b..8aa02d7 100644
--- a/src/CHAD/Interpreter.hs
+++ b/src/CHAD/Interpreter.hs
@@ -227,6 +227,8 @@ interpret'Rec env = \case
b' <- interpret' env b
return $ onehotM p t a' b'
EError _ _ s -> error $ "Interpreter: Program threw error: " ++ s
+ EUser _ _ e -> interpret' env e
+ EUnUser _ e -> interpret' env e
interpretOp :: SOp a t -> Rep a -> Rep t
interpretOp op arg = case op of
@@ -267,6 +269,9 @@ zeroM typ zi = case typ of
STI64 -> 0
STF32 -> 0.0
STF64 -> 0.0
+ SMTUser t ->
+ interpretOpen False (userZeroInfo t `SCons` SNil) (Value zi `SCons` SNil)
+ (euserZero t (EVar ext (userZeroInfo t) IZ))
deepZeroM :: SMTy t -> Rep (DeepZeroInfo t) -> Rep t
deepZeroM typ zi = case typ of
@@ -280,6 +285,9 @@ deepZeroM typ zi = case typ of
STI64 -> 0
STF32 -> 0.0
STF64 -> 0.0
+ SMTUser t ->
+ interpretOpen False (userDeepZeroInfo t `SCons` SNil) (Value zi `SCons` SNil)
+ (euserDeepZero t (EVar ext (userDeepZeroInfo t) IZ))
addM :: SMTy t -> Rep t -> Rep t -> Rep t
addM typ a b = case typ of
@@ -303,6 +311,9 @@ addM typ a b = case typ of
| sh1 == sh2 -> arrayGenerateLin sh1 (\i -> addM t (arrayIndexLinear a i) (arrayIndexLinear b i))
| otherwise -> error "Plus of inconsistently shaped arrays"
SMTScal sty -> numericIsNum sty $ a + b
+ SMTUser t ->
+ interpretOpen False (userRepTy t `SCons` userRepTy t `SCons` SNil) (Value a `SCons` Value b `SCons` SNil)
+ (euserPlus t (EVar ext (userRepTy t) IZ) (EVar ext (userRepTy t) (IS IZ)))
onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdxS p a) -> Rep b -> Rep a
onehotM SAPHere _ _ val = val
@@ -329,6 +340,7 @@ newAcDense typ val = case typ of
SMTMaybe t1 -> newIORef =<< traverse (newAcDense t1) val
SMTArr _ t1 -> arrayMapM (newAcDense t1) val
SMTScal _ -> newIORef val
+ SMTUser _ -> newIORef val
onehotArray :: Monad m
=> (Rep (AcIdxS p a) -> m v) -- ^ the "one"
@@ -348,6 +360,7 @@ readAc typ val = case typ of
SMTMaybe t -> traverse (readAc t) =<< readIORef val
SMTArr _ t -> traverse (readAc t) val
SMTScal _ -> readIORef val
+ SMTUser _ -> readIORef val
accumAddSparseD :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxD p a) -> Sparse b c -> Rep c -> AcM s ()
accumAddSparseD typ prj ref idx sp val = case (typ, prj) of
@@ -408,6 +421,7 @@ accumAddDense typ ref sp val = case (typ, sp) of
forM_ [0 .. arraySize ref - 1] $ \i ->
accumAddDense t1 (arrayIndexLinear ref i) s (arrayIndexLinear val i)
(SMTScal sty, SpScal) -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ()))
+ (SMTUser t, SpUser) -> AcM $ atomicModifyIORef' ref (\x -> (addM (SMTUser t) x val, ()))
-- TODO: makeval is always 'error' now. Simplify?
realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s ()