diff options
Diffstat (limited to 'src/CHAD/Drev')
| -rw-r--r-- | src/CHAD/Drev/Accum.hs | 2 | ||||
| -rw-r--r-- | src/CHAD/Drev/Types.hs | 8 | ||||
| -rw-r--r-- | src/CHAD/Drev/Types/ToTan.hs | 1 |
3 files changed, 10 insertions, 1 deletions
diff --git a/src/CHAD/Drev/Accum.hs b/src/CHAD/Drev/Accum.hs index 6f25f11..43305e6 100644 --- a/src/CHAD/Drev/Accum.hs +++ b/src/CHAD/Drev/Accum.hs @@ -21,6 +21,7 @@ d2zeroInfo STMaybe{} _ = ENil ext d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program" +d2zeroInfo (STUser t) e = euserD2ZeroInfo t (EUnUser ext e) d2deepZeroInfo :: STy t -> Ex env (D1 t) -> Ex env (DeepZeroInfo (D2 t)) d2deepZeroInfo STNil _ = ENil ext @@ -43,6 +44,7 @@ d2deepZeroInfo (STMaybe a) e = d2deepZeroInfo (STArr _ t) e = emap (d2deepZeroInfo t (EVar ext (d1 t) IZ)) e d2deepZeroInfo (STScal t) _ | Refl <- lemDeepZeroInfoScal t = ENil ext d2deepZeroInfo STAccum{} _ = error "accumulators not allowed in source program" +d2deepZeroInfo (STUser t) e = euserD2DeepZeroInfo t (EUnUser ext e) -- The weakening is necessary because we need to initialise the created -- accumulators with zeros. Those zeros are deep and need full primals. This diff --git a/src/CHAD/Drev/Types.hs b/src/CHAD/Drev/Types.hs index 367a974..e119de2 100644 --- a/src/CHAD/Drev/Types.hs +++ b/src/CHAD/Drev/Types.hs @@ -4,7 +4,8 @@ {-# LANGUAGE TypeOperators #-} module CHAD.Drev.Types where -import CHAD.AST.Accum +import Data.Proxy + import CHAD.AST.Types import CHAD.Data @@ -17,6 +18,7 @@ type family D1 t where D1 (TMaybe a) = TMaybe (D1 a) D1 (TArr n t) = TArr n (D1 t) D1 (TScal t) = TScal t + D1 (TUser t) = TUser t type family D2 t where D2 TNil = TNil @@ -26,6 +28,7 @@ type family D2 t where D2 (TMaybe t) = TMaybe (D2 t) D2 (TArr n t) = TArr n (D2 t) D2 (TScal t) = D2s t + D2 (TUser t) = TUser (UserD2 t) type family D2s t where D2s TI32 = TNil @@ -55,6 +58,7 @@ d1 (STMaybe t) = STMaybe (d1 t) d1 (STArr n t) = STArr n (d1 t) d1 (STScal t) = STScal t d1 STAccum{} = error "Accumulators not allowed in input program" +d1 (STUser t) = STUser t d1e :: SList STy env -> SList STy (D1E env) d1e SNil = SNil @@ -74,6 +78,7 @@ d2M (STScal t) = case t of STF64 -> SMTScal STF64 STBool -> SMTNil d2M STAccum{} = error "Accumulators not allowed in input program" +d2M (STUser _) = SMTUser Proxy d2 :: STy t -> STy (D2 t) d2 = fromSMTy . d2M @@ -147,6 +152,7 @@ d1Identity = \case STArr _ t | Refl <- d1Identity t -> Refl STScal _ -> Refl STAccum{} -> error "Accumulators not allowed in input program" + STUser{} -> Refl d1eIdentity :: SList STy env -> D1E env :~: env d1eIdentity SNil = Refl diff --git a/src/CHAD/Drev/Types/ToTan.hs b/src/CHAD/Drev/Types/ToTan.hs index 019119c..51403f5 100644 --- a/src/CHAD/Drev/Types/ToTan.hs +++ b/src/CHAD/Drev/Types/ToTan.hs @@ -41,3 +41,4 @@ toTan typ primal der = case typ of STScal sty -> case sty of STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der STAccum{} -> error "Accumulators not allowed in input program" + STUser{} -> error "User types not yet supported in forward AD" |
