diff options
Diffstat (limited to 'src/CHAD')
-rw-r--r-- | src/CHAD/Accum.hs | 52 | ||||
-rw-r--r-- | src/CHAD/EnvDescr.hs | 53 | ||||
-rw-r--r-- | src/CHAD/Top.hs | 63 | ||||
-rw-r--r-- | src/CHAD/Types.hs | 85 | ||||
-rw-r--r-- | src/CHAD/Types/ToTan.hs | 23 |
5 files changed, 189 insertions, 87 deletions
diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs index b61b5ff..7212232 100644 --- a/src/CHAD/Accum.hs +++ b/src/CHAD/Accum.hs @@ -1,18 +1,54 @@ -{-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE TypeOperators #-} +-- | TODO this module is a grab-bag of random utility functions that are shared +-- between CHAD and CHAD.Top. module CHAD.Accum where import AST import CHAD.Types import Data +import AST.Env +d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t)) +d2zeroInfo STNil _ = ENil ext +d2zeroInfo (STPair a b) e = + eunPair e $ \_ e1 e2 -> + EPair ext (d2zeroInfo a e1) (d2zeroInfo b e2) +d2zeroInfo STEither{} _ = ENil ext +d2zeroInfo STLEither{} _ = ENil ext +d2zeroInfo STMaybe{} _ = ENil ext +d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e +d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext +d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program" -makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) -makeAccumulators SNil e = e -makeAccumulators (t `SCons` envpro) e = - makeAccumulators envpro $ - EWith ext t (EZero ext t) e +d2deepZeroInfo :: STy t -> Ex env (D1 t) -> Ex env (DeepZeroInfo (D2 t)) +d2deepZeroInfo STNil _ = ENil ext +d2deepZeroInfo (STPair a b) e = + eunPair e $ \_ e1 e2 -> + EPair ext (d2deepZeroInfo a e1) (d2deepZeroInfo b e2) +d2deepZeroInfo (STEither a b) e = + ECase ext e + (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ))) + (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ))) +d2deepZeroInfo (STLEither a b) e = + elcase e + (ELNil ext (tDeepZeroInfo (d2M a)) (tDeepZeroInfo (d2M b))) + (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ))) + (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ))) +d2deepZeroInfo (STMaybe a) e = + emaybe e + (ENothing ext (tDeepZeroInfo (d2M a))) + (EJust ext (d2deepZeroInfo a (EVar ext (d1 a) IZ))) +d2deepZeroInfo (STArr _ t) e = emap (d2deepZeroInfo t (EVar ext (d1 t) IZ)) e +d2deepZeroInfo (STScal t) _ | Refl <- lemDeepZeroInfoScal t = ENil ext +d2deepZeroInfo STAccum{} _ = error "accumulators not allowed in source program" + +makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) +makeAccumulators _ SNil e = e +makeAccumulators w (t `SCons` envpro) e = + makeAccumulators (WPop w) envpro $ + EWith ext (d2M t) (EDeepZero ext (d2M t) (d2deepZeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) uninvertTup SNil _ e = EPair ext e (ENil ext) @@ -25,3 +61,7 @@ uninvertTup (t `SCons` list) tcore e = (ESnd ext (EVar ext recT IZ)) (ESnd ext (EFst ext (EVar ext recT IZ)))) +subenvD1E :: Subenv env env' -> Subenv (D1E env) (D1E env') +subenvD1E SETop = SETop +subenvD1E (SEYesR sub) = SEYesR (subenvD1E sub) +subenvD1E (SENo sub) = SENo (subenvD1E sub) diff --git a/src/CHAD/EnvDescr.hs b/src/CHAD/EnvDescr.hs index fcd91f7..49ae0e6 100644 --- a/src/CHAD/EnvDescr.hs +++ b/src/CHAD/EnvDescr.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} @@ -9,10 +10,13 @@ module CHAD.EnvDescr where import Data.Kind (Type) +import Data.Some import GHC.TypeLits (Symbol) +import Analysis.Identity (ValId(..)) import AST.Env import AST.Types +import AST.Weaken import CHAD.Types import Data @@ -27,12 +31,17 @@ deriving instance Show (Storage s) -- | Environment description data Descr env sto where DTop :: Descr '[] '[] - DPush :: Descr env sto -> (STy t, Storage s) -> Descr (t : env) (s : sto) + DPush :: Descr env sto -> (STy t, Maybe (ValId t), Storage s) -> Descr (t : env) (s : sto) deriving instance Show (Descr env sto) descrList :: Descr env sto -> SList STy env descrList DTop = SNil -descrList (des `DPush` (t, _)) = t `SCons` descrList des +descrList (des `DPush` (t, _, _)) = t `SCons` descrList des + +descrPrj :: Descr env sto -> Idx env t -> (STy t, Maybe (ValId t), Some Storage) +descrPrj (_ `DPush` (ty, vid, sto)) IZ = (ty, vid, Some sto) +descrPrj (des `DPush` _) (IS i) = descrPrj des i +descrPrj DTop i = case i of {} -- | This could have more precise typing on the output storage. subDescr :: Descr env sto -> Subenv env env' @@ -43,13 +52,13 @@ subDescr :: Descr env sto -> Subenv env env' -> r) -> r subDescr DTop SETop k = k DTop SETop SETop SETop -subDescr (des `DPush` (t, sto)) (SEYes sub) k = +subDescr (des `DPush` (t, vid, sto)) (SEYesR sub) k = subDescr des sub $ \des' submerge subaccum subd1e -> case sto of - SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e) - SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e) - SDiscr -> k (des' `DPush` (t, sto)) submerge subaccum (SEYes subd1e) -subDescr (des `DPush` (_, sto)) (SENo sub) k = + SMerge -> k (des' `DPush` (t, vid, sto)) (SEYesR submerge) subaccum (SEYesR subd1e) + SAccum -> k (des' `DPush` (t, vid, sto)) submerge (SEYesR subaccum) (SEYesR subd1e) + SDiscr -> k (des' `DPush` (t, vid, sto)) submerge subaccum (SEYesR subd1e) +subDescr (des `DPush` (_, _, sto)) (SENo sub) k = subDescr des sub $ \des' submerge subaccum subd1e -> case sto of SMerge -> k des' (SENo submerge) subaccum (SENo subd1e) @@ -64,12 +73,24 @@ type family Select env sto s where select :: Storage s -> Descr env sto -> SList STy (Select env sto s) select _ DTop = SNil -select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des) -select s@SMerge (DPush des (_, SAccum)) = select s des -select s@SDiscr (DPush des (_, SAccum)) = select s des -select s@SAccum (DPush des (_, SMerge)) = select s des -select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des) -select s@SDiscr (DPush des (_, SMerge)) = select s des -select s@SAccum (DPush des (_, SDiscr)) = select s des -select s@SMerge (DPush des (_, SDiscr)) = select s des -select s@SDiscr (DPush des (t, SDiscr)) = SCons t (select s des) +select s@SAccum (DPush des (t, _, SAccum)) = SCons t (select s des) +select s@SMerge (DPush des (_, _, SAccum)) = select s des +select s@SDiscr (DPush des (_, _, SAccum)) = select s des +select s@SAccum (DPush des (_, _, SMerge)) = select s des +select s@SMerge (DPush des (t, _, SMerge)) = SCons t (select s des) +select s@SDiscr (DPush des (_, _, SMerge)) = select s des +select s@SAccum (DPush des (_, _, SDiscr)) = select s des +select s@SMerge (DPush des (_, _, SDiscr)) = select s des +select s@SDiscr (DPush des (t, _, SDiscr)) = SCons t (select s des) + +selectSub :: Storage s -> Descr env sto -> Subenv env (Select env sto s) +selectSub _ DTop = SETop +selectSub s@SAccum (DPush des (_, _, SAccum)) = SEYesR (selectSub s des) +selectSub s@SMerge (DPush des (_, _, SAccum)) = SENo (selectSub s des) +selectSub s@SDiscr (DPush des (_, _, SAccum)) = SENo (selectSub s des) +selectSub s@SAccum (DPush des (_, _, SMerge)) = SENo (selectSub s des) +selectSub s@SMerge (DPush des (_, _, SMerge)) = SEYesR (selectSub s des) +selectSub s@SDiscr (DPush des (_, _, SMerge)) = SENo (selectSub s des) +selectSub s@SAccum (DPush des (_, _, SDiscr)) = SENo (selectSub s des) +selectSub s@SMerge (DPush des (_, _, SDiscr)) = SENo (selectSub s des) +selectSub s@SDiscr (DPush des (_, _, SDiscr)) = SEYesR (selectSub s des) diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs index d058132..484779e 100644 --- a/src/CHAD/Top.hs +++ b/src/CHAD/Top.hs @@ -10,13 +10,18 @@ {-# LANGUAGE TypeOperators #-} module CHAD.Top where +import Analysis.Identity import AST +import AST.Env +import AST.Sparse +import AST.SplitLets import AST.Weaken.Auto import CHAD import CHAD.Accum import CHAD.EnvDescr import CHAD.Types import Data +import qualified Data.VarMap as VarMap type family MergeEnv env where @@ -25,7 +30,7 @@ type family MergeEnv env where mergeDescr :: SList STy env -> Descr env (MergeEnv env) mergeDescr SNil = DTop -mergeDescr (t `SCons` env) = mergeDescr env `DPush` (t, SMerge) +mergeDescr (t `SCons` env) = mergeDescr env `DPush` (t, Nothing, SMerge) mergeEnvNoAccum :: SList f env -> Select env (MergeEnv env) "accum" :~: '[] mergeEnvNoAccum SNil = Refl @@ -38,38 +43,25 @@ mergeEnvOnlyMerge (_ `SCons` env) | Refl <- mergeEnvOnlyMerge env = Refl accumDescr :: SList STy env -> (forall sto. Descr env sto -> r) -> r accumDescr SNil k = k DTop accumDescr (t `SCons` env) k = accumDescr env $ \des -> - if hasArrays t then k (des `DPush` (t, SAccum)) - else k (des `DPush` (t, SMerge)) - -d1Identity :: STy t -> D1 t :~: t -d1Identity = \case - STNil -> Refl - STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl - STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl - STMaybe t | Refl <- d1Identity t -> Refl - STArr _ t | Refl <- d1Identity t -> Refl - STScal _ -> Refl - STAccum{} -> error "Accumulators not allowed in input program" - -d1eIdentity :: SList STy env -> D1E env :~: env -d1eIdentity SNil = Refl -d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl + if hasArrays t then k (des `DPush` (t, Nothing, SAccum)) + else k (des `DPush` (t, Nothing, SMerge)) reassembleD2E :: Descr env sto + -> D1E env :> env' -> Ex env' (TPair (Tup (D2E (Select env sto "accum"))) (Tup (D2E (Select env sto "merge")))) -> Ex env' (Tup (D2E env)) -reassembleD2E DTop _ = ENil ext -reassembleD2E (des `DPush` (_, SAccum)) e = - ELet ext e $ - EPair ext (reassembleD2E des (EPair ext (EFst ext (EFst ext (EVar ext (typeOf e) IZ))) - (ESnd ext (EVar ext (typeOf e) IZ)))) - (ESnd ext (EFst ext (EVar ext (typeOf e) IZ))) -reassembleD2E (des `DPush` (_, SMerge)) e = - ELet ext e $ - EPair ext (reassembleD2E des (EPair ext (EFst ext (EVar ext (typeOf e) IZ)) - (EFst ext (ESnd ext (EVar ext (typeOf e) IZ))))) - (ESnd ext (ESnd ext (EVar ext (typeOf e) IZ))) -reassembleD2E (des `DPush` (t, SDiscr)) e = EPair ext (reassembleD2E des e) (EZero ext t) +reassembleD2E DTop _ _ = ENil ext +reassembleD2E (des `DPush` (_, _, SAccum)) w e = + eunPair e $ \w1 e1 e2 -> + eunPair e1 $ \w2 e11 e12 -> + EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext e11 (weakenExpr w2 e2))) e12 +reassembleD2E (des `DPush` (_, _, SMerge)) w e = + eunPair e $ \w1 e1 e2 -> + eunPair e2 $ \w2 e21 e22 -> + EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext (weakenExpr w2 e1) e21)) e22 +reassembleD2E (des `DPush` (t, _, SDiscr)) w e = + EPair ext (reassembleD2E des (WPop w) e) + (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env))) chad config env (term :: Ex env t) @@ -79,21 +71,24 @@ chad config env (term :: Ex env t) let t1 = STPair (d1 (typeOf term)) (tTup (d2e (select SMerge descr))) tvar = STPair t1 (tTup (d2e (select SAccum descr))) in ELet ext (uninvertTup (d2e (select SAccum descr)) t1 $ - makeAccumulators (select SAccum descr) $ + makeAccumulators (WSink .> wUndoSubenv (subenvD1E (selectSub SAccum descr))) (select SAccum descr) $ weakenExpr (autoWeak (#d (auto1 @(D2 t)) &. #acenv (d2ace (select SAccum descr)) &. #tl (d1e env)) (#d :++: #acenv :++: #tl) (#acenv :++: #d :++: #tl)) $ - freezeRet descr (drev descr term)) $ + freezeRet descr (drev descr VarMap.empty (spDense (d2M (typeOf term))) term')) $ EPair ext (EFst ext (EFst ext (EVar ext tvar IZ))) - (reassembleD2E descr (EPair ext (ESnd ext (EVar ext tvar IZ)) - (ESnd ext (EFst ext (EVar ext tvar IZ))))) + (reassembleD2E descr (WSink .> WSink) + (EPair ext (ESnd ext (EVar ext tvar IZ)) + (ESnd ext (EFst ext (EVar ext tvar IZ))))) | False <- chcArgArrayAccum config , Refl <- mergeEnvNoAccum env , Refl <- mergeEnvOnlyMerge env - = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) term) + = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty (spDense (d2M (typeOf term))) term') + where + term' = identityAnalysis env (splitLets term) chad' :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env))) chad' config env term diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs index 7f49cef..44ac20e 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Types.hs @@ -1,8 +1,10 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} module CHAD.Types where +import AST.Accum import AST.Types import Data @@ -11,16 +13,18 @@ type family D1 t where D1 TNil = TNil D1 (TPair a b) = TPair (D1 a) (D1 b) D1 (TEither a b) = TEither (D1 a) (D1 b) + D1 (TLEither a b) = TLEither (D1 a) (D1 b) D1 (TMaybe a) = TMaybe (D1 a) D1 (TArr n t) = TArr n (D1 t) D1 (TScal t) = TScal t type family D2 t where D2 TNil = TNil - D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b)) - D2 (TEither a b) = TMaybe (TEither (D2 a) (D2 b)) + D2 (TPair a b) = TPair (D2 a) (D2 b) + D2 (TEither a b) = TLEither (D2 a) (D2 b) + D2 (TLEither a b) = TLEither (D2 a) (D2 b) D2 (TMaybe t) = TMaybe (D2 t) - D2 (TArr n t) = TMaybe (TArr n (D2 t)) + D2 (TArr n t) = TArr n (D2 t) D2 (TScal t) = D2s t type family D2s t where @@ -40,12 +44,13 @@ type family D2E env where type family D2AcE env where D2AcE '[] = '[] - D2AcE (t : env) = TAccum t : D2AcE env + D2AcE (t : env) = TAccum (D2 t) : D2AcE env d1 :: STy t -> STy (D1 t) d1 STNil = STNil d1 (STPair a b) = STPair (d1 a) (d1 b) d1 (STEither a b) = STEither (d1 a) (d1 b) +d1 (STLEither a b) = STLEither (d1 a) (d1 b) d1 (STMaybe t) = STMaybe (d1 t) d1 (STArr n t) = STArr n (d1 t) d1 (STScal t) = STScal t @@ -55,27 +60,34 @@ d1e :: SList STy env -> SList STy (D1E env) d1e SNil = SNil d1e (t `SCons` env) = d1 t `SCons` d1e env +d2M :: STy t -> SMTy (D2 t) +d2M STNil = SMTNil +d2M (STPair a b) = SMTPair (d2M a) (d2M b) +d2M (STEither a b) = SMTLEither (d2M a) (d2M b) +d2M (STLEither a b) = SMTLEither (d2M a) (d2M b) +d2M (STMaybe t) = SMTMaybe (d2M t) +d2M (STArr n t) = SMTArr n (d2M t) +d2M (STScal t) = case t of + STI32 -> SMTNil + STI64 -> SMTNil + STF32 -> SMTScal STF32 + STF64 -> SMTScal STF64 + STBool -> SMTNil +d2M STAccum{} = error "Accumulators not allowed in input program" + d2 :: STy t -> STy (D2 t) -d2 STNil = STNil -d2 (STPair a b) = STMaybe (STPair (d2 a) (d2 b)) -d2 (STEither a b) = STMaybe (STEither (d2 a) (d2 b)) -d2 (STMaybe t) = STMaybe (d2 t) -d2 (STArr n t) = STMaybe (STArr n (d2 t)) -d2 (STScal t) = case t of - STI32 -> STNil - STI64 -> STNil - STF32 -> STScal STF32 - STF64 -> STScal STF64 - STBool -> STNil -d2 STAccum{} = error "Accumulators not allowed in input program" +d2 = fromSMTy . d2M + +d2eM :: SList STy env -> SList SMTy (D2E env) +d2eM SNil = SNil +d2eM (t `SCons` ts) = d2M t `SCons` d2eM ts d2e :: SList STy env -> SList STy (D2E env) -d2e SNil = SNil -d2e (t `SCons` ts) = d2 t `SCons` d2e ts +d2e = slistMap fromSMTy . d2eM d2ace :: SList STy env -> SList STy (D2AcE env) d2ace SNil = SNil -d2ace (t `SCons` ts) = STAccum t `SCons` d2ace ts +d2ace (t `SCons` ts) = STAccum (d2M t) `SCons` d2ace ts data CHADConfig = CHADConfig @@ -85,6 +97,8 @@ data CHADConfig = CHADConfig chcCaseArrayAccum :: Bool , -- | Introduce top-level arguments containing arrays in accumulator mode. chcArgArrayAccum :: Bool + , -- | Place with-blocks around array variable scopes, and redirect accumulations there. + chcSmartWith :: Bool } deriving (Show) @@ -93,12 +107,14 @@ defaultConfig = CHADConfig { chcLetArrayAccum = False , chcCaseArrayAccum = False , chcArgArrayAccum = False + , chcSmartWith = False } chcSetAccum :: CHADConfig -> CHADConfig chcSetAccum c = c { chcLetArrayAccum = True , chcCaseArrayAccum = True - , chcArgArrayAccum = True } + , chcArgArrayAccum = True + , chcSmartWith = True } ------------------------------------ LEMMAS ------------------------------------ @@ -106,3 +122,32 @@ chcSetAccum c = c { chcLetArrayAccum = True indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx)) indexTupD1Id SZ = Refl indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl + +lemZeroInfoScal :: SScalTy t -> ZeroInfo (D2s t) :~: TNil +lemZeroInfoScal STI32 = Refl +lemZeroInfoScal STI64 = Refl +lemZeroInfoScal STF32 = Refl +lemZeroInfoScal STF64 = Refl +lemZeroInfoScal STBool = Refl + +lemDeepZeroInfoScal :: SScalTy t -> DeepZeroInfo (D2s t) :~: TNil +lemDeepZeroInfoScal STI32 = Refl +lemDeepZeroInfoScal STI64 = Refl +lemDeepZeroInfoScal STF32 = Refl +lemDeepZeroInfoScal STF64 = Refl +lemDeepZeroInfoScal STBool = Refl + +d1Identity :: STy t -> D1 t :~: t +d1Identity = \case + STNil -> Refl + STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STMaybe t | Refl <- d1Identity t -> Refl + STArr _ t | Refl <- d1Identity t -> Refl + STScal _ -> Refl + STAccum{} -> error "Accumulators not allowed in input program" + +d1eIdentity :: SList STy env -> D1E env :~: env +d1eIdentity SNil = Refl +d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Types/ToTan.hs index f843206..888fed4 100644 --- a/src/CHAD/Types/ToTan.hs +++ b/src/CHAD/Types/ToTan.hs @@ -19,24 +19,25 @@ toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) = toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t) toTan typ primal der = case typ of STNil -> der - STPair t1 t2 -> case der of - Nothing -> bimap (zeroTan t1) (zeroTan t2) primal - Just (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal + STPair t1 t2 -> bimap (\p1 -> toTan t1 p1 (fst der)) (\p2 -> toTan t2 p2 (snd der)) primal STEither t1 t2 -> case der of Nothing -> bimap (zeroTan t1) (zeroTan t2) primal Just d -> case (primal, d) of (Left p, Left d') -> Left (toTan t1 p d') (Right p, Right d') -> Right (toTan t2 p d') _ -> error "Primal and cotangent disagree on Either alternative" + STLEither t1 t2 -> case (primal, der) of + (_, Nothing) -> Nothing + (Just (Left p), Just (Left d)) -> Just (Left (toTan t1 p d)) + (Just (Right p), Just (Right d)) -> Just (Right (toTan t2 p d)) + _ -> error "Primal and cotangent disagree on LEither alternative" STMaybe t -> liftA2 (toTan t) primal der - STArr _ t -> case der of - Nothing -> arrayMap (zeroTan t) primal - Just d - | arrayShape primal == arrayShape d -> - arrayGenerateLin (arrayShape primal) $ \i -> - toTan t (arrayIndexLinear primal i) (arrayIndexLinear d i) - | otherwise -> - error "Primal and cotangent disagree on array shape" + STArr _ t + | arrayShape primal == arrayShape der -> + arrayGenerateLin (arrayShape primal) $ \i -> + toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i) + | otherwise -> + error "Primal and cotangent disagree on array shape" STScal sty -> case sty of STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der STAccum{} -> error "Accumulators not allowed in input program" |