diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-27 21:30:17 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-27 21:30:17 +0100 |
| commit | 20f7d7be13cd7869b338f98d1ab3fd33e8bbfb3e (patch) | |
| tree | a21c90034a02cdeb7240563dbbab355e49622d0a /src/CHAD/ForwardAD.hs | |
| parent | ae634c056b500a568b2d89b7f8e225404a2c0c62 (diff) | |
WIP user-specified custom typesuser-types
The big roadblock encountered is that accumulation wants addition of
monoids to be elementwise float addition; this fundamentally clashes
with the concept of a user type with a custom zero and plus.
Diffstat (limited to 'src/CHAD/ForwardAD.hs')
| -rw-r--r-- | src/CHAD/ForwardAD.hs | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/src/CHAD/ForwardAD.hs b/src/CHAD/ForwardAD.hs index 0ae88ce..933a259 100644 --- a/src/CHAD/ForwardAD.hs +++ b/src/CHAD/ForwardAD.hs @@ -57,6 +57,7 @@ tanty (STScal t) = case t of STF64 -> STScal STF64 STBool -> STNil tanty STAccum{} = error "Accumulators not allowed in input program" +tanty STUser{} = error "User types not yet supported in forward AD" tanenv :: SList STy env -> SList STy (TanE env) tanenv SNil = SNil @@ -79,6 +80,7 @@ zeroTan (STScal STF32) _ = 0.0 zeroTan (STScal STF64) _ = 0.0 zeroTan (STScal STBool) _ = () zeroTan STAccum{} _ = error "Accumulators not allowed in input program" +zeroTan STUser{} _ = error "User types not yet supported in forward AD" tanScalars :: STy t -> Rep (Tan t) -> [Double] tanScalars STNil () = [] @@ -97,6 +99,7 @@ tanScalars (STScal STF32) x = [realToFrac x] tanScalars (STScal STF64) x = [x] tanScalars (STScal STBool) _ = [] tanScalars STAccum{} _ = error "Accumulators not allowed in input program" +tanScalars STUser{} _ = [] tanEScalars :: SList STy env -> SList Value (TanE env) -> [Double] tanEScalars SNil SNil = [] @@ -128,6 +131,7 @@ unzipDN (STScal ty) d = case ty of STF64 -> d STBool -> (d, ()) unzipDN STAccum{} _ = error "Accumulators not allowed in input program" +unzipDN STUser{} _ = error "User types not yet supported in forward AD" dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double dotprodTan STNil _ _ = 0.0 @@ -160,6 +164,7 @@ dotprodTan (STScal ty) x y = case ty of STF64 -> x * y STBool -> 0.0 dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program" +dotprodTan STUser{} _ _ = 0.0 -- -- Primal expression must be duplicable -- dnConstE :: STy t -> Ex env t -> Ex env (DN t) @@ -198,6 +203,7 @@ dnConst (STScal t) = case t of STF64 -> (,0.0) STBool -> id dnConst STAccum{} = error "Accumulators not allowed in input program" +dnConst STUser{} = error "User types not yet supported in forward AD" -- | Given a function that computes the forward derivative for a particular -- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this @@ -233,6 +239,7 @@ dnOnehots (STScal t) x = case t of STF64 -> \f -> f (x, 1.0) STBool -> \_ -> () dnOnehots STAccum{} _ = error "Accumulators not allowed in input program" +dnOnehots ty@STUser{} x = const (zeroTan ty x) dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env) dnConstEnv SNil SNil = SNil |
