aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/Drev
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD/Drev')
-rw-r--r--src/CHAD/Drev/Accum.hs2
-rw-r--r--src/CHAD/Drev/Types.hs8
-rw-r--r--src/CHAD/Drev/Types/ToTan.hs1
3 files changed, 10 insertions, 1 deletions
diff --git a/src/CHAD/Drev/Accum.hs b/src/CHAD/Drev/Accum.hs
index 6f25f11..43305e6 100644
--- a/src/CHAD/Drev/Accum.hs
+++ b/src/CHAD/Drev/Accum.hs
@@ -21,6 +21,7 @@ d2zeroInfo STMaybe{} _ = ENil ext
d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e
d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext
d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program"
+d2zeroInfo (STUser t) e = euserD2ZeroInfo t (EUnUser ext e)
d2deepZeroInfo :: STy t -> Ex env (D1 t) -> Ex env (DeepZeroInfo (D2 t))
d2deepZeroInfo STNil _ = ENil ext
@@ -43,6 +44,7 @@ d2deepZeroInfo (STMaybe a) e =
d2deepZeroInfo (STArr _ t) e = emap (d2deepZeroInfo t (EVar ext (d1 t) IZ)) e
d2deepZeroInfo (STScal t) _ | Refl <- lemDeepZeroInfoScal t = ENil ext
d2deepZeroInfo STAccum{} _ = error "accumulators not allowed in source program"
+d2deepZeroInfo (STUser t) e = euserD2DeepZeroInfo t (EUnUser ext e)
-- The weakening is necessary because we need to initialise the created
-- accumulators with zeros. Those zeros are deep and need full primals. This
diff --git a/src/CHAD/Drev/Types.hs b/src/CHAD/Drev/Types.hs
index 367a974..e119de2 100644
--- a/src/CHAD/Drev/Types.hs
+++ b/src/CHAD/Drev/Types.hs
@@ -4,7 +4,8 @@
{-# LANGUAGE TypeOperators #-}
module CHAD.Drev.Types where
-import CHAD.AST.Accum
+import Data.Proxy
+
import CHAD.AST.Types
import CHAD.Data
@@ -17,6 +18,7 @@ type family D1 t where
D1 (TMaybe a) = TMaybe (D1 a)
D1 (TArr n t) = TArr n (D1 t)
D1 (TScal t) = TScal t
+ D1 (TUser t) = TUser t
type family D2 t where
D2 TNil = TNil
@@ -26,6 +28,7 @@ type family D2 t where
D2 (TMaybe t) = TMaybe (D2 t)
D2 (TArr n t) = TArr n (D2 t)
D2 (TScal t) = D2s t
+ D2 (TUser t) = TUser (UserD2 t)
type family D2s t where
D2s TI32 = TNil
@@ -55,6 +58,7 @@ d1 (STMaybe t) = STMaybe (d1 t)
d1 (STArr n t) = STArr n (d1 t)
d1 (STScal t) = STScal t
d1 STAccum{} = error "Accumulators not allowed in input program"
+d1 (STUser t) = STUser t
d1e :: SList STy env -> SList STy (D1E env)
d1e SNil = SNil
@@ -74,6 +78,7 @@ d2M (STScal t) = case t of
STF64 -> SMTScal STF64
STBool -> SMTNil
d2M STAccum{} = error "Accumulators not allowed in input program"
+d2M (STUser _) = SMTUser Proxy
d2 :: STy t -> STy (D2 t)
d2 = fromSMTy . d2M
@@ -147,6 +152,7 @@ d1Identity = \case
STArr _ t | Refl <- d1Identity t -> Refl
STScal _ -> Refl
STAccum{} -> error "Accumulators not allowed in input program"
+ STUser{} -> Refl
d1eIdentity :: SList STy env -> D1E env :~: env
d1eIdentity SNil = Refl
diff --git a/src/CHAD/Drev/Types/ToTan.hs b/src/CHAD/Drev/Types/ToTan.hs
index 019119c..51403f5 100644
--- a/src/CHAD/Drev/Types/ToTan.hs
+++ b/src/CHAD/Drev/Types/ToTan.hs
@@ -41,3 +41,4 @@ toTan typ primal der = case typ of
STScal sty -> case sty of
STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der
STAccum{} -> error "Accumulators not allowed in input program"
+ STUser{} -> error "User types not yet supported in forward AD"