diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-04-29 15:54:12 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-04-29 15:54:12 +0200 |
commit | 3fd8d35cca2a23c137934a170c67e8ce310edf13 (patch) | |
tree | 429fb99f9c1395272f1f9a94bfbc0e003fa39b21 /src/ForwardAD.hs | |
parent | 919a36f8eed21501357185a90e2b7a4d9eaf7f08 (diff) |
Complete monoidal accumulator rewrite
Diffstat (limited to 'src/ForwardAD.hs')
-rw-r--r-- | src/ForwardAD.hs | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs index b7036dd..5756f96 100644 --- a/src/ForwardAD.hs +++ b/src/ForwardAD.hs @@ -29,6 +29,7 @@ type family Tan t where Tan (TMaybe t) = TMaybe (Tan t) Tan (TArr n t) = TArr n (Tan t) Tan (TScal t) = TanS t + Tan (TLEither a b) = TLEither (Tan a) (Tan b) type family TanS t where TanS TI32 = TNil @@ -54,6 +55,11 @@ tanty (STScal t) = case t of STF64 -> STScal STF64 STBool -> STNil tanty STAccum{} = error "Accumulators not allowed in input program" +tanty (STLEither a b) = STLEither (tanty a) (tanty b) + +tanenv :: SList STy env -> SList STy (TanE env) +tanenv SNil = SNil +tanenv (t `SCons` env) = tanty t `SCons` tanenv env zeroTan :: STy t -> Rep t -> Rep (Tan t) zeroTan STNil () = () @@ -69,6 +75,9 @@ zeroTan (STScal STF32) _ = 0.0 zeroTan (STScal STF64) _ = 0.0 zeroTan (STScal STBool) _ = () zeroTan STAccum{} _ = error "Accumulators not allowed in input program" +zeroTan (STLEither _ _) Nothing = Nothing +zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x)) +zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y)) tanScalars :: STy t -> Rep (Tan t) -> [Double] tanScalars STNil () = [] @@ -84,6 +93,9 @@ tanScalars (STScal STF32) x = [realToFrac x] tanScalars (STScal STF64) x = [x] tanScalars (STScal STBool) _ = [] tanScalars STAccum{} _ = error "Accumulators not allowed in input program" +tanScalars (STLEither _ _) Nothing = [] +tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x +tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y tanEScalars :: SList STy env -> SList Value (TanE env) -> [Double] tanEScalars SNil SNil = [] @@ -111,6 +123,10 @@ unzipDN (STScal ty) d = case ty of STF64 -> d STBool -> (d, ()) unzipDN STAccum{} _ = error "Accumulators not allowed in input program" +unzipDN (STLEither a b) d = case d of + Nothing -> (Nothing, Nothing) + Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x) + Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y) dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double dotprodTan STNil _ _ = 0.0 @@ -137,6 +153,12 @@ dotprodTan (STScal ty) x y = case ty of STF64 -> x * y STBool -> 0.0 dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program" +dotprodTan (STLEither a b) x y = case (x, y) of + (Nothing, _) -> 0.0 -- 0 * y = 0 + (_, Nothing) -> 0.0 -- x * 0 = 0 + (Just (Left x'), Just (Left y')) -> dotprodTan a x' y' + (Just (Right x'), Just (Right y')) -> dotprodTan b x' y' + _ -> error "dotprodTan: incompatible LEither alternatives" -- -- Primal expression must be duplicable -- dnConstE :: STy t -> Ex env t -> Ex env (DN t) @@ -174,6 +196,7 @@ dnConst (STScal t) = case t of STF64 -> (,0.0) STBool -> id dnConst STAccum{} = error "Accumulators not allowed in input program" +dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2)) -- | Given a function that computes the forward derivative for a particular -- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this @@ -204,6 +227,11 @@ dnOnehots (STScal t) x = case t of STF64 -> \f -> f (x, 1.0) STBool -> \_ -> () dnOnehots STAccum{} _ = error "Accumulators not allowed in input program" +dnOnehots (STLEither t1 t2) e = + case e of + Nothing -> \_ -> Nothing + Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left))) + Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right))) dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env) dnConstEnv SNil SNil = SNil |