aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/ForwardAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD/ForwardAD.hs')
-rw-r--r--src/CHAD/ForwardAD.hs7
1 files changed, 7 insertions, 0 deletions
diff --git a/src/CHAD/ForwardAD.hs b/src/CHAD/ForwardAD.hs
index 0ae88ce..933a259 100644
--- a/src/CHAD/ForwardAD.hs
+++ b/src/CHAD/ForwardAD.hs
@@ -57,6 +57,7 @@ tanty (STScal t) = case t of
STF64 -> STScal STF64
STBool -> STNil
tanty STAccum{} = error "Accumulators not allowed in input program"
+tanty STUser{} = error "User types not yet supported in forward AD"
tanenv :: SList STy env -> SList STy (TanE env)
tanenv SNil = SNil
@@ -79,6 +80,7 @@ zeroTan (STScal STF32) _ = 0.0
zeroTan (STScal STF64) _ = 0.0
zeroTan (STScal STBool) _ = ()
zeroTan STAccum{} _ = error "Accumulators not allowed in input program"
+zeroTan STUser{} _ = error "User types not yet supported in forward AD"
tanScalars :: STy t -> Rep (Tan t) -> [Double]
tanScalars STNil () = []
@@ -97,6 +99,7 @@ tanScalars (STScal STF32) x = [realToFrac x]
tanScalars (STScal STF64) x = [x]
tanScalars (STScal STBool) _ = []
tanScalars STAccum{} _ = error "Accumulators not allowed in input program"
+tanScalars STUser{} _ = []
tanEScalars :: SList STy env -> SList Value (TanE env) -> [Double]
tanEScalars SNil SNil = []
@@ -128,6 +131,7 @@ unzipDN (STScal ty) d = case ty of
STF64 -> d
STBool -> (d, ())
unzipDN STAccum{} _ = error "Accumulators not allowed in input program"
+unzipDN STUser{} _ = error "User types not yet supported in forward AD"
dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double
dotprodTan STNil _ _ = 0.0
@@ -160,6 +164,7 @@ dotprodTan (STScal ty) x y = case ty of
STF64 -> x * y
STBool -> 0.0
dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program"
+dotprodTan STUser{} _ _ = 0.0
-- -- Primal expression must be duplicable
-- dnConstE :: STy t -> Ex env t -> Ex env (DN t)
@@ -198,6 +203,7 @@ dnConst (STScal t) = case t of
STF64 -> (,0.0)
STBool -> id
dnConst STAccum{} = error "Accumulators not allowed in input program"
+dnConst STUser{} = error "User types not yet supported in forward AD"
-- | Given a function that computes the forward derivative for a particular
-- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this
@@ -233,6 +239,7 @@ dnOnehots (STScal t) x = case t of
STF64 -> \f -> f (x, 1.0)
STBool -> \_ -> ()
dnOnehots STAccum{} _ = error "Accumulators not allowed in input program"
+dnOnehots ty@STUser{} x = const (zeroTan ty x)
dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env)
dnConstEnv SNil SNil = SNil