diff options
Diffstat (limited to 'src/ForwardAD.hs')
-rw-r--r-- | src/ForwardAD.hs | 30 |
1 files changed, 29 insertions, 1 deletions
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs index af35f91..b353def 100644 --- a/src/ForwardAD.hs +++ b/src/ForwardAD.hs @@ -26,6 +26,7 @@ type family Tan t where Tan TNil = TNil Tan (TPair a b) = TPair (Tan a) (Tan b) Tan (TEither a b) = TEither (Tan a) (Tan b) + Tan (TLEither a b) = TLEither (Tan a) (Tan b) Tan (TMaybe t) = TMaybe (Tan t) Tan (TArr n t) = TArr n (Tan t) Tan (TScal t) = TanS t @@ -45,6 +46,7 @@ tanty :: STy t -> STy (Tan t) tanty STNil = STNil tanty (STPair a b) = STPair (tanty a) (tanty b) tanty (STEither a b) = STEither (tanty a) (tanty b) +tanty (STLEither a b) = STLEither (tanty a) (tanty b) tanty (STMaybe t) = STMaybe (tanty t) tanty (STArr n t) = STArr n (tanty t) tanty (STScal t) = case t of @@ -55,11 +57,18 @@ tanty (STScal t) = case t of STBool -> STNil tanty STAccum{} = error "Accumulators not allowed in input program" +tanenv :: SList STy env -> SList STy (TanE env) +tanenv SNil = SNil +tanenv (t `SCons` env) = tanty t `SCons` tanenv env + zeroTan :: STy t -> Rep t -> Rep (Tan t) zeroTan STNil () = () zeroTan (STPair a b) (x, y) = (zeroTan a x, zeroTan b y) zeroTan (STEither a _) (Left x) = Left (zeroTan a x) zeroTan (STEither _ b) (Right y) = Right (zeroTan b y) +zeroTan (STLEither _ _) Nothing = Nothing +zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x)) +zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y)) zeroTan (STMaybe _) Nothing = Nothing zeroTan (STMaybe t) (Just x) = Just (zeroTan t x) zeroTan (STArr _ t) x = fmap (zeroTan t) x @@ -75,6 +84,9 @@ tanScalars STNil () = [] tanScalars (STPair a b) (x, y) = tanScalars a x ++ tanScalars b y tanScalars (STEither a _) (Left x) = tanScalars a x tanScalars (STEither _ b) (Right y) = tanScalars b y +tanScalars (STLEither _ _) Nothing = [] +tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x +tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y tanScalars (STMaybe _) Nothing = [] tanScalars (STMaybe t) (Just x) = tanScalars t x tanScalars (STArr _ t) x = foldMap id $ arrayMap (tanScalars t) x @@ -98,6 +110,10 @@ unzipDN (STPair a b) (d1, d2) = unzipDN (STEither a b) d = case d of Left d1 -> bimap Left Left (unzipDN a d1) Right d2 -> bimap Right Right (unzipDN b d2) +unzipDN (STLEither a b) d = case d of + Nothing -> (Nothing, Nothing) + Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x) + Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y) unzipDN (STMaybe t) d = case d of Nothing -> (Nothing, Nothing) Just d' -> bimap Just Just (unzipDN t d') @@ -120,6 +136,12 @@ dotprodTan (STEither a b) x y = case (x, y) of (Left x', Left y') -> dotprodTan a x' y' (Right x', Right y') -> dotprodTan b x' y' _ -> error "dotprodTan: incompatible Either alternatives" +dotprodTan (STLEither a b) x y = case (x, y) of + (Nothing, _) -> 0.0 -- 0 * y = 0 + (_, Nothing) -> 0.0 -- x * 0 = 0 + (Just (Left x'), Just (Left y')) -> dotprodTan a x' y' + (Just (Right x'), Just (Right y')) -> dotprodTan b x' y' + _ -> error "dotprodTan: incompatible LEither alternatives" dotprodTan (STMaybe t) x y = case (x, y) of (Nothing, Nothing) -> 0.0 (Just x', Just y') -> dotprodTan t x' y' @@ -165,6 +187,7 @@ dnConst :: STy t -> Rep t -> Rep (DN t) dnConst STNil = const () dnConst (STPair t1 t2) = bimap (dnConst t1) (dnConst t2) dnConst (STEither t1 t2) = bimap (dnConst t1) (dnConst t2) +dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2)) dnConst (STMaybe t) = fmap (dnConst t) dnConst (STArr _ t) = arrayMap (dnConst t) dnConst (STScal t) = case t of @@ -188,6 +211,11 @@ dnOnehots (STEither t1 t2) e = case e of Left x -> \f -> Left (dnOnehots t1 x (f . Left)) Right y -> \f -> Right (dnOnehots t2 y (f . Right)) +dnOnehots (STLEither t1 t2) e = + case e of + Nothing -> \_ -> Nothing + Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left))) + Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right))) dnOnehots (STMaybe t) m = case m of Nothing -> \_ -> Nothing @@ -223,7 +251,7 @@ data FwdADArtifact env t = FwdADArtifact (SList STy env) (STy t) (SList Value (D makeFwdADArtifactInterp :: SList STy env -> Ex env t -> FwdADArtifact env t makeFwdADArtifactInterp env expr = let dexpr = dfwdDN expr - in FwdADArtifact env (typeOf expr) (\inp -> interpretOpen False inp dexpr) + in FwdADArtifact env (typeOf expr) (\inp -> interpretOpen False (dne env) inp dexpr) {-# NOINLINE makeFwdADArtifactCompile #-} makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t) |