From 20f7d7be13cd7869b338f98d1ab3fd33e8bbfb3e Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Thu, 27 Nov 2025 21:30:17 +0100 Subject: WIP user-specified custom types The big roadblock encountered is that accumulation wants addition of monoids to be elementwise float addition; this fundamentally clashes with the concept of a user type with a custom zero and plus. --- src/CHAD/ForwardAD.hs | 7 +++++++ 1 file changed, 7 insertions(+) (limited to 'src/CHAD/ForwardAD.hs') diff --git a/src/CHAD/ForwardAD.hs b/src/CHAD/ForwardAD.hs index 0ae88ce..933a259 100644 --- a/src/CHAD/ForwardAD.hs +++ b/src/CHAD/ForwardAD.hs @@ -57,6 +57,7 @@ tanty (STScal t) = case t of STF64 -> STScal STF64 STBool -> STNil tanty STAccum{} = error "Accumulators not allowed in input program" +tanty STUser{} = error "User types not yet supported in forward AD" tanenv :: SList STy env -> SList STy (TanE env) tanenv SNil = SNil @@ -79,6 +80,7 @@ zeroTan (STScal STF32) _ = 0.0 zeroTan (STScal STF64) _ = 0.0 zeroTan (STScal STBool) _ = () zeroTan STAccum{} _ = error "Accumulators not allowed in input program" +zeroTan STUser{} _ = error "User types not yet supported in forward AD" tanScalars :: STy t -> Rep (Tan t) -> [Double] tanScalars STNil () = [] @@ -97,6 +99,7 @@ tanScalars (STScal STF32) x = [realToFrac x] tanScalars (STScal STF64) x = [x] tanScalars (STScal STBool) _ = [] tanScalars STAccum{} _ = error "Accumulators not allowed in input program" +tanScalars STUser{} _ = [] tanEScalars :: SList STy env -> SList Value (TanE env) -> [Double] tanEScalars SNil SNil = [] @@ -128,6 +131,7 @@ unzipDN (STScal ty) d = case ty of STF64 -> d STBool -> (d, ()) unzipDN STAccum{} _ = error "Accumulators not allowed in input program" +unzipDN STUser{} _ = error "User types not yet supported in forward AD" dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double dotprodTan STNil _ _ = 0.0 @@ -160,6 +164,7 @@ dotprodTan (STScal ty) x y = case ty of STF64 -> x * y STBool -> 0.0 dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program" +dotprodTan STUser{} _ _ = 0.0 -- -- Primal expression must be duplicable -- dnConstE :: STy t -> Ex env t -> Ex env (DN t) @@ -198,6 +203,7 @@ dnConst (STScal t) = case t of STF64 -> (,0.0) STBool -> id dnConst STAccum{} = error "Accumulators not allowed in input program" +dnConst STUser{} = error "User types not yet supported in forward AD" -- | Given a function that computes the forward derivative for a particular -- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this @@ -233,6 +239,7 @@ dnOnehots (STScal t) x = case t of STF64 -> \f -> f (x, 1.0) STBool -> \_ -> () dnOnehots STAccum{} _ = error "Accumulators not allowed in input program" +dnOnehots ty@STUser{} x = const (zeroTan ty x) dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env) dnConstEnv SNil SNil = SNil -- cgit v1.2.3-70-g09d2