diff options
Diffstat (limited to 'src/ForwardAD.hs')
-rw-r--r-- | src/ForwardAD.hs | 48 |
1 files changed, 24 insertions, 24 deletions
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs index 5756f96..b353def 100644 --- a/src/ForwardAD.hs +++ b/src/ForwardAD.hs @@ -26,10 +26,10 @@ type family Tan t where Tan TNil = TNil Tan (TPair a b) = TPair (Tan a) (Tan b) Tan (TEither a b) = TEither (Tan a) (Tan b) + Tan (TLEither a b) = TLEither (Tan a) (Tan b) Tan (TMaybe t) = TMaybe (Tan t) Tan (TArr n t) = TArr n (Tan t) Tan (TScal t) = TanS t - Tan (TLEither a b) = TLEither (Tan a) (Tan b) type family TanS t where TanS TI32 = TNil @@ -46,6 +46,7 @@ tanty :: STy t -> STy (Tan t) tanty STNil = STNil tanty (STPair a b) = STPair (tanty a) (tanty b) tanty (STEither a b) = STEither (tanty a) (tanty b) +tanty (STLEither a b) = STLEither (tanty a) (tanty b) tanty (STMaybe t) = STMaybe (tanty t) tanty (STArr n t) = STArr n (tanty t) tanty (STScal t) = case t of @@ -55,7 +56,6 @@ tanty (STScal t) = case t of STF64 -> STScal STF64 STBool -> STNil tanty STAccum{} = error "Accumulators not allowed in input program" -tanty (STLEither a b) = STLEither (tanty a) (tanty b) tanenv :: SList STy env -> SList STy (TanE env) tanenv SNil = SNil @@ -66,6 +66,9 @@ zeroTan STNil () = () zeroTan (STPair a b) (x, y) = (zeroTan a x, zeroTan b y) zeroTan (STEither a _) (Left x) = Left (zeroTan a x) zeroTan (STEither _ b) (Right y) = Right (zeroTan b y) +zeroTan (STLEither _ _) Nothing = Nothing +zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x)) +zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y)) zeroTan (STMaybe _) Nothing = Nothing zeroTan (STMaybe t) (Just x) = Just (zeroTan t x) zeroTan (STArr _ t) x = fmap (zeroTan t) x @@ -75,15 +78,15 @@ zeroTan (STScal STF32) _ = 0.0 zeroTan (STScal STF64) _ = 0.0 zeroTan (STScal STBool) _ = () zeroTan STAccum{} _ = error "Accumulators not allowed in input program" -zeroTan (STLEither _ _) Nothing = Nothing -zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x)) -zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y)) tanScalars :: STy t -> Rep (Tan t) -> [Double] tanScalars STNil () = [] tanScalars (STPair a b) (x, y) = tanScalars a x ++ tanScalars b y tanScalars (STEither a _) (Left x) = tanScalars a x tanScalars (STEither _ b) (Right y) = tanScalars b y +tanScalars (STLEither _ _) Nothing = [] +tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x +tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y tanScalars (STMaybe _) Nothing = [] tanScalars (STMaybe t) (Just x) = tanScalars t x tanScalars (STArr _ t) x = foldMap id $ arrayMap (tanScalars t) x @@ -93,9 +96,6 @@ tanScalars (STScal STF32) x = [realToFrac x] tanScalars (STScal STF64) x = [x] tanScalars (STScal STBool) _ = [] tanScalars STAccum{} _ = error "Accumulators not allowed in input program" -tanScalars (STLEither _ _) Nothing = [] -tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x -tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y tanEScalars :: SList STy env -> SList Value (TanE env) -> [Double] tanEScalars SNil SNil = [] @@ -110,6 +110,10 @@ unzipDN (STPair a b) (d1, d2) = unzipDN (STEither a b) d = case d of Left d1 -> bimap Left Left (unzipDN a d1) Right d2 -> bimap Right Right (unzipDN b d2) +unzipDN (STLEither a b) d = case d of + Nothing -> (Nothing, Nothing) + Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x) + Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y) unzipDN (STMaybe t) d = case d of Nothing -> (Nothing, Nothing) Just d' -> bimap Just Just (unzipDN t d') @@ -123,10 +127,6 @@ unzipDN (STScal ty) d = case ty of STF64 -> d STBool -> (d, ()) unzipDN STAccum{} _ = error "Accumulators not allowed in input program" -unzipDN (STLEither a b) d = case d of - Nothing -> (Nothing, Nothing) - Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x) - Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y) dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double dotprodTan STNil _ _ = 0.0 @@ -136,6 +136,12 @@ dotprodTan (STEither a b) x y = case (x, y) of (Left x', Left y') -> dotprodTan a x' y' (Right x', Right y') -> dotprodTan b x' y' _ -> error "dotprodTan: incompatible Either alternatives" +dotprodTan (STLEither a b) x y = case (x, y) of + (Nothing, _) -> 0.0 -- 0 * y = 0 + (_, Nothing) -> 0.0 -- x * 0 = 0 + (Just (Left x'), Just (Left y')) -> dotprodTan a x' y' + (Just (Right x'), Just (Right y')) -> dotprodTan b x' y' + _ -> error "dotprodTan: incompatible LEither alternatives" dotprodTan (STMaybe t) x y = case (x, y) of (Nothing, Nothing) -> 0.0 (Just x', Just y') -> dotprodTan t x' y' @@ -153,12 +159,6 @@ dotprodTan (STScal ty) x y = case ty of STF64 -> x * y STBool -> 0.0 dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program" -dotprodTan (STLEither a b) x y = case (x, y) of - (Nothing, _) -> 0.0 -- 0 * y = 0 - (_, Nothing) -> 0.0 -- x * 0 = 0 - (Just (Left x'), Just (Left y')) -> dotprodTan a x' y' - (Just (Right x'), Just (Right y')) -> dotprodTan b x' y' - _ -> error "dotprodTan: incompatible LEither alternatives" -- -- Primal expression must be duplicable -- dnConstE :: STy t -> Ex env t -> Ex env (DN t) @@ -187,6 +187,7 @@ dnConst :: STy t -> Rep t -> Rep (DN t) dnConst STNil = const () dnConst (STPair t1 t2) = bimap (dnConst t1) (dnConst t2) dnConst (STEither t1 t2) = bimap (dnConst t1) (dnConst t2) +dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2)) dnConst (STMaybe t) = fmap (dnConst t) dnConst (STArr _ t) = arrayMap (dnConst t) dnConst (STScal t) = case t of @@ -196,7 +197,6 @@ dnConst (STScal t) = case t of STF64 -> (,0.0) STBool -> id dnConst STAccum{} = error "Accumulators not allowed in input program" -dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2)) -- | Given a function that computes the forward derivative for a particular -- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this @@ -211,6 +211,11 @@ dnOnehots (STEither t1 t2) e = case e of Left x -> \f -> Left (dnOnehots t1 x (f . Left)) Right y -> \f -> Right (dnOnehots t2 y (f . Right)) +dnOnehots (STLEither t1 t2) e = + case e of + Nothing -> \_ -> Nothing + Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left))) + Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right))) dnOnehots (STMaybe t) m = case m of Nothing -> \_ -> Nothing @@ -227,11 +232,6 @@ dnOnehots (STScal t) x = case t of STF64 -> \f -> f (x, 1.0) STBool -> \_ -> () dnOnehots STAccum{} _ = error "Accumulators not allowed in input program" -dnOnehots (STLEither t1 t2) e = - case e of - Nothing -> \_ -> Nothing - Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left))) - Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right))) dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env) dnConstEnv SNil SNil = SNil |