summaryrefslogtreecommitdiff
path: root/src/ForwardAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/ForwardAD.hs')
-rw-r--r--src/ForwardAD.hs28
1 files changed, 28 insertions, 0 deletions
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs
index b7036dd..5756f96 100644
--- a/src/ForwardAD.hs
+++ b/src/ForwardAD.hs
@@ -29,6 +29,7 @@ type family Tan t where
Tan (TMaybe t) = TMaybe (Tan t)
Tan (TArr n t) = TArr n (Tan t)
Tan (TScal t) = TanS t
+ Tan (TLEither a b) = TLEither (Tan a) (Tan b)
type family TanS t where
TanS TI32 = TNil
@@ -54,6 +55,11 @@ tanty (STScal t) = case t of
STF64 -> STScal STF64
STBool -> STNil
tanty STAccum{} = error "Accumulators not allowed in input program"
+tanty (STLEither a b) = STLEither (tanty a) (tanty b)
+
+tanenv :: SList STy env -> SList STy (TanE env)
+tanenv SNil = SNil
+tanenv (t `SCons` env) = tanty t `SCons` tanenv env
zeroTan :: STy t -> Rep t -> Rep (Tan t)
zeroTan STNil () = ()
@@ -69,6 +75,9 @@ zeroTan (STScal STF32) _ = 0.0
zeroTan (STScal STF64) _ = 0.0
zeroTan (STScal STBool) _ = ()
zeroTan STAccum{} _ = error "Accumulators not allowed in input program"
+zeroTan (STLEither _ _) Nothing = Nothing
+zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x))
+zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y))
tanScalars :: STy t -> Rep (Tan t) -> [Double]
tanScalars STNil () = []
@@ -84,6 +93,9 @@ tanScalars (STScal STF32) x = [realToFrac x]
tanScalars (STScal STF64) x = [x]
tanScalars (STScal STBool) _ = []
tanScalars STAccum{} _ = error "Accumulators not allowed in input program"
+tanScalars (STLEither _ _) Nothing = []
+tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x
+tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y
tanEScalars :: SList STy env -> SList Value (TanE env) -> [Double]
tanEScalars SNil SNil = []
@@ -111,6 +123,10 @@ unzipDN (STScal ty) d = case ty of
STF64 -> d
STBool -> (d, ())
unzipDN STAccum{} _ = error "Accumulators not allowed in input program"
+unzipDN (STLEither a b) d = case d of
+ Nothing -> (Nothing, Nothing)
+ Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x)
+ Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y)
dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double
dotprodTan STNil _ _ = 0.0
@@ -137,6 +153,12 @@ dotprodTan (STScal ty) x y = case ty of
STF64 -> x * y
STBool -> 0.0
dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program"
+dotprodTan (STLEither a b) x y = case (x, y) of
+ (Nothing, _) -> 0.0 -- 0 * y = 0
+ (_, Nothing) -> 0.0 -- x * 0 = 0
+ (Just (Left x'), Just (Left y')) -> dotprodTan a x' y'
+ (Just (Right x'), Just (Right y')) -> dotprodTan b x' y'
+ _ -> error "dotprodTan: incompatible LEither alternatives"
-- -- Primal expression must be duplicable
-- dnConstE :: STy t -> Ex env t -> Ex env (DN t)
@@ -174,6 +196,7 @@ dnConst (STScal t) = case t of
STF64 -> (,0.0)
STBool -> id
dnConst STAccum{} = error "Accumulators not allowed in input program"
+dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2))
-- | Given a function that computes the forward derivative for a particular
-- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this
@@ -204,6 +227,11 @@ dnOnehots (STScal t) x = case t of
STF64 -> \f -> f (x, 1.0)
STBool -> \_ -> ()
dnOnehots STAccum{} _ = error "Accumulators not allowed in input program"
+dnOnehots (STLEither t1 t2) e =
+ case e of
+ Nothing -> \_ -> Nothing
+ Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left)))
+ Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right)))
dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env)
dnConstEnv SNil SNil = SNil