aboutsummaryrefslogtreecommitdiff
path: root/src/ForwardAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/ForwardAD.hs')
-rw-r--r--src/ForwardAD.hs30
1 files changed, 29 insertions, 1 deletions
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs
index af35f91..b353def 100644
--- a/src/ForwardAD.hs
+++ b/src/ForwardAD.hs
@@ -26,6 +26,7 @@ type family Tan t where
Tan TNil = TNil
Tan (TPair a b) = TPair (Tan a) (Tan b)
Tan (TEither a b) = TEither (Tan a) (Tan b)
+ Tan (TLEither a b) = TLEither (Tan a) (Tan b)
Tan (TMaybe t) = TMaybe (Tan t)
Tan (TArr n t) = TArr n (Tan t)
Tan (TScal t) = TanS t
@@ -45,6 +46,7 @@ tanty :: STy t -> STy (Tan t)
tanty STNil = STNil
tanty (STPair a b) = STPair (tanty a) (tanty b)
tanty (STEither a b) = STEither (tanty a) (tanty b)
+tanty (STLEither a b) = STLEither (tanty a) (tanty b)
tanty (STMaybe t) = STMaybe (tanty t)
tanty (STArr n t) = STArr n (tanty t)
tanty (STScal t) = case t of
@@ -55,11 +57,18 @@ tanty (STScal t) = case t of
STBool -> STNil
tanty STAccum{} = error "Accumulators not allowed in input program"
+tanenv :: SList STy env -> SList STy (TanE env)
+tanenv SNil = SNil
+tanenv (t `SCons` env) = tanty t `SCons` tanenv env
+
zeroTan :: STy t -> Rep t -> Rep (Tan t)
zeroTan STNil () = ()
zeroTan (STPair a b) (x, y) = (zeroTan a x, zeroTan b y)
zeroTan (STEither a _) (Left x) = Left (zeroTan a x)
zeroTan (STEither _ b) (Right y) = Right (zeroTan b y)
+zeroTan (STLEither _ _) Nothing = Nothing
+zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x))
+zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y))
zeroTan (STMaybe _) Nothing = Nothing
zeroTan (STMaybe t) (Just x) = Just (zeroTan t x)
zeroTan (STArr _ t) x = fmap (zeroTan t) x
@@ -75,6 +84,9 @@ tanScalars STNil () = []
tanScalars (STPair a b) (x, y) = tanScalars a x ++ tanScalars b y
tanScalars (STEither a _) (Left x) = tanScalars a x
tanScalars (STEither _ b) (Right y) = tanScalars b y
+tanScalars (STLEither _ _) Nothing = []
+tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x
+tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y
tanScalars (STMaybe _) Nothing = []
tanScalars (STMaybe t) (Just x) = tanScalars t x
tanScalars (STArr _ t) x = foldMap id $ arrayMap (tanScalars t) x
@@ -98,6 +110,10 @@ unzipDN (STPair a b) (d1, d2) =
unzipDN (STEither a b) d = case d of
Left d1 -> bimap Left Left (unzipDN a d1)
Right d2 -> bimap Right Right (unzipDN b d2)
+unzipDN (STLEither a b) d = case d of
+ Nothing -> (Nothing, Nothing)
+ Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x)
+ Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y)
unzipDN (STMaybe t) d = case d of
Nothing -> (Nothing, Nothing)
Just d' -> bimap Just Just (unzipDN t d')
@@ -120,6 +136,12 @@ dotprodTan (STEither a b) x y = case (x, y) of
(Left x', Left y') -> dotprodTan a x' y'
(Right x', Right y') -> dotprodTan b x' y'
_ -> error "dotprodTan: incompatible Either alternatives"
+dotprodTan (STLEither a b) x y = case (x, y) of
+ (Nothing, _) -> 0.0 -- 0 * y = 0
+ (_, Nothing) -> 0.0 -- x * 0 = 0
+ (Just (Left x'), Just (Left y')) -> dotprodTan a x' y'
+ (Just (Right x'), Just (Right y')) -> dotprodTan b x' y'
+ _ -> error "dotprodTan: incompatible LEither alternatives"
dotprodTan (STMaybe t) x y = case (x, y) of
(Nothing, Nothing) -> 0.0
(Just x', Just y') -> dotprodTan t x' y'
@@ -165,6 +187,7 @@ dnConst :: STy t -> Rep t -> Rep (DN t)
dnConst STNil = const ()
dnConst (STPair t1 t2) = bimap (dnConst t1) (dnConst t2)
dnConst (STEither t1 t2) = bimap (dnConst t1) (dnConst t2)
+dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2))
dnConst (STMaybe t) = fmap (dnConst t)
dnConst (STArr _ t) = arrayMap (dnConst t)
dnConst (STScal t) = case t of
@@ -188,6 +211,11 @@ dnOnehots (STEither t1 t2) e =
case e of
Left x -> \f -> Left (dnOnehots t1 x (f . Left))
Right y -> \f -> Right (dnOnehots t2 y (f . Right))
+dnOnehots (STLEither t1 t2) e =
+ case e of
+ Nothing -> \_ -> Nothing
+ Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left)))
+ Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right)))
dnOnehots (STMaybe t) m =
case m of
Nothing -> \_ -> Nothing
@@ -223,7 +251,7 @@ data FwdADArtifact env t = FwdADArtifact (SList STy env) (STy t) (SList Value (D
makeFwdADArtifactInterp :: SList STy env -> Ex env t -> FwdADArtifact env t
makeFwdADArtifactInterp env expr =
let dexpr = dfwdDN expr
- in FwdADArtifact env (typeOf expr) (\inp -> interpretOpen False inp dexpr)
+ in FwdADArtifact env (typeOf expr) (\inp -> interpretOpen False (dne env) inp dexpr)
{-# NOINLINE makeFwdADArtifactCompile #-}
makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t)