diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-04-29 20:37:06 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-04-29 20:38:05 +0200 | 
| commit | d0eb9a1edfb4233d557d954f46685f25382234d8 (patch) | |
| tree | 04eb5a746258fcaa2a3b98228c6eadb2b0178ba3 /src/ForwardAD.hs | |
| parent | 4ad7eaba73d5fda8ff5028d1e53966f728d704d3 (diff) | |
Reorder TLEither to after TEither
Diffstat (limited to 'src/ForwardAD.hs')
| -rw-r--r-- | src/ForwardAD.hs | 48 | 
1 files changed, 24 insertions, 24 deletions
| diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs index 5756f96..b353def 100644 --- a/src/ForwardAD.hs +++ b/src/ForwardAD.hs @@ -26,10 +26,10 @@ type family Tan t where    Tan TNil = TNil    Tan (TPair a b) = TPair (Tan a) (Tan b)    Tan (TEither a b) = TEither (Tan a) (Tan b) +  Tan (TLEither a b) = TLEither (Tan a) (Tan b)    Tan (TMaybe t) = TMaybe (Tan t)    Tan (TArr n t) = TArr n (Tan t)    Tan (TScal t) = TanS t -  Tan (TLEither a b) = TLEither (Tan a) (Tan b)  type family TanS t where    TanS TI32 = TNil @@ -46,6 +46,7 @@ tanty :: STy t -> STy (Tan t)  tanty STNil = STNil  tanty (STPair a b) = STPair (tanty a) (tanty b)  tanty (STEither a b) = STEither (tanty a) (tanty b) +tanty (STLEither a b) = STLEither (tanty a) (tanty b)  tanty (STMaybe t) = STMaybe (tanty t)  tanty (STArr n t) = STArr n (tanty t)  tanty (STScal t) = case t of @@ -55,7 +56,6 @@ tanty (STScal t) = case t of    STF64 -> STScal STF64    STBool -> STNil  tanty STAccum{} = error "Accumulators not allowed in input program" -tanty (STLEither a b) = STLEither (tanty a) (tanty b)  tanenv :: SList STy env -> SList STy (TanE env)  tanenv SNil = SNil @@ -66,6 +66,9 @@ zeroTan STNil () = ()  zeroTan (STPair a b) (x, y) = (zeroTan a x, zeroTan b y)  zeroTan (STEither a _) (Left x) = Left (zeroTan a x)  zeroTan (STEither _ b) (Right y) = Right (zeroTan b y) +zeroTan (STLEither _ _) Nothing = Nothing +zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x)) +zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y))  zeroTan (STMaybe _) Nothing = Nothing  zeroTan (STMaybe t) (Just x) = Just (zeroTan t x)  zeroTan (STArr _ t) x = fmap (zeroTan t) x @@ -75,15 +78,15 @@ zeroTan (STScal STF32) _ = 0.0  zeroTan (STScal STF64) _ = 0.0  zeroTan (STScal STBool) _ = ()  zeroTan STAccum{} _ = error "Accumulators not allowed in input program" -zeroTan (STLEither _ _) Nothing = Nothing -zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x)) -zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y))  tanScalars :: STy t -> Rep (Tan t) -> [Double]  tanScalars STNil () = []  tanScalars (STPair a b) (x, y) = tanScalars a x ++ tanScalars b y  tanScalars (STEither a _) (Left x) = tanScalars a x  tanScalars (STEither _ b) (Right y) = tanScalars b y +tanScalars (STLEither _ _) Nothing = [] +tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x +tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y  tanScalars (STMaybe _) Nothing = []  tanScalars (STMaybe t) (Just x) = tanScalars t x  tanScalars (STArr _ t) x = foldMap id $ arrayMap (tanScalars t) x @@ -93,9 +96,6 @@ tanScalars (STScal STF32) x = [realToFrac x]  tanScalars (STScal STF64) x = [x]  tanScalars (STScal STBool) _ = []  tanScalars STAccum{} _ = error "Accumulators not allowed in input program" -tanScalars (STLEither _ _) Nothing = [] -tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x -tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y  tanEScalars :: SList STy env -> SList Value (TanE env) -> [Double]  tanEScalars SNil SNil = [] @@ -110,6 +110,10 @@ unzipDN (STPair a b) (d1, d2) =  unzipDN (STEither a b) d = case d of    Left d1 -> bimap Left Left (unzipDN a d1)    Right d2 -> bimap Right Right (unzipDN b d2) +unzipDN (STLEither a b) d = case d of +  Nothing -> (Nothing, Nothing) +  Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x) +  Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y)  unzipDN (STMaybe t) d = case d of    Nothing -> (Nothing, Nothing)    Just d' -> bimap Just Just (unzipDN t d') @@ -123,10 +127,6 @@ unzipDN (STScal ty) d = case ty of    STF64 -> d    STBool -> (d, ())  unzipDN STAccum{} _ = error "Accumulators not allowed in input program" -unzipDN (STLEither a b) d = case d of -  Nothing -> (Nothing, Nothing) -  Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x) -  Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y)  dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double  dotprodTan STNil _ _ = 0.0 @@ -136,6 +136,12 @@ dotprodTan (STEither a b) x y = case (x, y) of    (Left x', Left y') -> dotprodTan a x' y'    (Right x', Right y') -> dotprodTan b x' y'    _ -> error "dotprodTan: incompatible Either alternatives" +dotprodTan (STLEither a b) x y = case (x, y) of +  (Nothing, _) -> 0.0  -- 0 * y = 0 +  (_, Nothing) -> 0.0  -- x * 0 = 0 +  (Just (Left x'), Just (Left y')) -> dotprodTan a x' y' +  (Just (Right x'), Just (Right y')) -> dotprodTan b x' y' +  _ -> error "dotprodTan: incompatible LEither alternatives"  dotprodTan (STMaybe t) x y = case (x, y) of    (Nothing, Nothing) -> 0.0    (Just x', Just y') -> dotprodTan t x' y' @@ -153,12 +159,6 @@ dotprodTan (STScal ty) x y = case ty of    STF64 -> x * y    STBool -> 0.0  dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program" -dotprodTan (STLEither a b) x y = case (x, y) of -  (Nothing, _) -> 0.0  -- 0 * y = 0 -  (_, Nothing) -> 0.0  -- x * 0 = 0 -  (Just (Left x'), Just (Left y')) -> dotprodTan a x' y' -  (Just (Right x'), Just (Right y')) -> dotprodTan b x' y' -  _ -> error "dotprodTan: incompatible LEither alternatives"  -- -- Primal expression must be duplicable  -- dnConstE :: STy t -> Ex env t -> Ex env (DN t) @@ -187,6 +187,7 @@ dnConst :: STy t -> Rep t -> Rep (DN t)  dnConst STNil = const ()  dnConst (STPair t1 t2) = bimap (dnConst t1) (dnConst t2)  dnConst (STEither t1 t2) = bimap (dnConst t1) (dnConst t2) +dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2))  dnConst (STMaybe t) = fmap (dnConst t)  dnConst (STArr _ t) = arrayMap (dnConst t)  dnConst (STScal t) = case t of @@ -196,7 +197,6 @@ dnConst (STScal t) = case t of    STF64 -> (,0.0)    STBool -> id  dnConst STAccum{} = error "Accumulators not allowed in input program" -dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2))  -- | Given a function that computes the forward derivative for a particular  -- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this @@ -211,6 +211,11 @@ dnOnehots (STEither t1 t2) e =    case e of      Left x -> \f -> Left (dnOnehots t1 x (f . Left))      Right y -> \f -> Right (dnOnehots t2 y (f . Right)) +dnOnehots (STLEither t1 t2) e = +  case e of +    Nothing -> \_ -> Nothing +    Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left))) +    Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right)))  dnOnehots (STMaybe t) m =    case m of      Nothing -> \_ -> Nothing @@ -227,11 +232,6 @@ dnOnehots (STScal t) x = case t of    STF64 -> \f -> f (x, 1.0)    STBool -> \_ -> ()  dnOnehots STAccum{} _ = error "Accumulators not allowed in input program" -dnOnehots (STLEither t1 t2) e = -  case e of -    Nothing -> \_ -> Nothing -    Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left))) -    Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right)))  dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env)  dnConstEnv SNil SNil = SNil | 
