aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/ForwardAD.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-27 21:30:17 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-27 21:30:17 +0100
commit20f7d7be13cd7869b338f98d1ab3fd33e8bbfb3e (patch)
treea21c90034a02cdeb7240563dbbab355e49622d0a /src/CHAD/ForwardAD.hs
parentae634c056b500a568b2d89b7f8e225404a2c0c62 (diff)
WIP user-specified custom typesuser-types
The big roadblock encountered is that accumulation wants addition of monoids to be elementwise float addition; this fundamentally clashes with the concept of a user type with a custom zero and plus.
Diffstat (limited to 'src/CHAD/ForwardAD.hs')
-rw-r--r--src/CHAD/ForwardAD.hs7
1 files changed, 7 insertions, 0 deletions
diff --git a/src/CHAD/ForwardAD.hs b/src/CHAD/ForwardAD.hs
index 0ae88ce..933a259 100644
--- a/src/CHAD/ForwardAD.hs
+++ b/src/CHAD/ForwardAD.hs
@@ -57,6 +57,7 @@ tanty (STScal t) = case t of
STF64 -> STScal STF64
STBool -> STNil
tanty STAccum{} = error "Accumulators not allowed in input program"
+tanty STUser{} = error "User types not yet supported in forward AD"
tanenv :: SList STy env -> SList STy (TanE env)
tanenv SNil = SNil
@@ -79,6 +80,7 @@ zeroTan (STScal STF32) _ = 0.0
zeroTan (STScal STF64) _ = 0.0
zeroTan (STScal STBool) _ = ()
zeroTan STAccum{} _ = error "Accumulators not allowed in input program"
+zeroTan STUser{} _ = error "User types not yet supported in forward AD"
tanScalars :: STy t -> Rep (Tan t) -> [Double]
tanScalars STNil () = []
@@ -97,6 +99,7 @@ tanScalars (STScal STF32) x = [realToFrac x]
tanScalars (STScal STF64) x = [x]
tanScalars (STScal STBool) _ = []
tanScalars STAccum{} _ = error "Accumulators not allowed in input program"
+tanScalars STUser{} _ = []
tanEScalars :: SList STy env -> SList Value (TanE env) -> [Double]
tanEScalars SNil SNil = []
@@ -128,6 +131,7 @@ unzipDN (STScal ty) d = case ty of
STF64 -> d
STBool -> (d, ())
unzipDN STAccum{} _ = error "Accumulators not allowed in input program"
+unzipDN STUser{} _ = error "User types not yet supported in forward AD"
dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double
dotprodTan STNil _ _ = 0.0
@@ -160,6 +164,7 @@ dotprodTan (STScal ty) x y = case ty of
STF64 -> x * y
STBool -> 0.0
dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program"
+dotprodTan STUser{} _ _ = 0.0
-- -- Primal expression must be duplicable
-- dnConstE :: STy t -> Ex env t -> Ex env (DN t)
@@ -198,6 +203,7 @@ dnConst (STScal t) = case t of
STF64 -> (,0.0)
STBool -> id
dnConst STAccum{} = error "Accumulators not allowed in input program"
+dnConst STUser{} = error "User types not yet supported in forward AD"
-- | Given a function that computes the forward derivative for a particular
-- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this
@@ -233,6 +239,7 @@ dnOnehots (STScal t) x = case t of
STF64 -> \f -> f (x, 1.0)
STBool -> \_ -> ()
dnOnehots STAccum{} _ = error "Accumulators not allowed in input program"
+dnOnehots ty@STUser{} x = const (zeroTan ty x)
dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env)
dnConstEnv SNil SNil = SNil