aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/Drev
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD/Drev')
-rw-r--r--src/CHAD/Drev/Accum.hs72
-rw-r--r--src/CHAD/Drev/EnvDescr.hs96
-rw-r--r--src/CHAD/Drev/Top.hs96
-rw-r--r--src/CHAD/Drev/Types.hs153
-rw-r--r--src/CHAD/Drev/Types/ToTan.hs43
5 files changed, 460 insertions, 0 deletions
diff --git a/src/CHAD/Drev/Accum.hs b/src/CHAD/Drev/Accum.hs
new file mode 100644
index 0000000..6f25f11
--- /dev/null
+++ b/src/CHAD/Drev/Accum.hs
@@ -0,0 +1,72 @@
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE TypeOperators #-}
+-- | TODO this module is a grab-bag of random utility functions that are shared
+-- between CHAD.Drev and CHAD.Drev.Top.
+module CHAD.Drev.Accum where
+
+import CHAD.AST
+import CHAD.Data
+import CHAD.Drev.Types
+import CHAD.AST.Env
+
+
+d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t))
+d2zeroInfo STNil _ = ENil ext
+d2zeroInfo (STPair a b) e =
+ eunPair e $ \_ e1 e2 ->
+ EPair ext (d2zeroInfo a e1) (d2zeroInfo b e2)
+d2zeroInfo STEither{} _ = ENil ext
+d2zeroInfo STLEither{} _ = ENil ext
+d2zeroInfo STMaybe{} _ = ENil ext
+d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e
+d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext
+d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program"
+
+d2deepZeroInfo :: STy t -> Ex env (D1 t) -> Ex env (DeepZeroInfo (D2 t))
+d2deepZeroInfo STNil _ = ENil ext
+d2deepZeroInfo (STPair a b) e =
+ eunPair e $ \_ e1 e2 ->
+ EPair ext (d2deepZeroInfo a e1) (d2deepZeroInfo b e2)
+d2deepZeroInfo (STEither a b) e =
+ ECase ext e
+ (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ)))
+ (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ)))
+d2deepZeroInfo (STLEither a b) e =
+ elcase e
+ (ELNil ext (tDeepZeroInfo (d2M a)) (tDeepZeroInfo (d2M b)))
+ (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ)))
+ (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ)))
+d2deepZeroInfo (STMaybe a) e =
+ emaybe e
+ (ENothing ext (tDeepZeroInfo (d2M a)))
+ (EJust ext (d2deepZeroInfo a (EVar ext (d1 a) IZ)))
+d2deepZeroInfo (STArr _ t) e = emap (d2deepZeroInfo t (EVar ext (d1 t) IZ)) e
+d2deepZeroInfo (STScal t) _ | Refl <- lemDeepZeroInfoScal t = ENil ext
+d2deepZeroInfo STAccum{} _ = error "accumulators not allowed in source program"
+
+-- The weakening is necessary because we need to initialise the created
+-- accumulators with zeros. Those zeros are deep and need full primals. This
+-- means, in the end, that primals corresponding to environment entries
+-- promoted to an accumulator with accumPromote in CHAD need to be stored for
+-- the dual.
+makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
+makeAccumulators _ SNil e = e
+makeAccumulators w (t `SCons` envpro) e =
+ makeAccumulators (WPop w) envpro $
+ EWith ext (d2M t) (EDeepZero ext (d2M t) (d2deepZeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e
+
+uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list))
+uninvertTup SNil _ e = EPair ext e (ENil ext)
+uninvertTup (t `SCons` list) tcore e =
+ ELet ext (uninvertTup list (STPair tcore t) e) $
+ let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding
+ in EPair ext
+ (EFst ext (EFst ext (EVar ext recT IZ)))
+ (EPair ext
+ (ESnd ext (EVar ext recT IZ))
+ (ESnd ext (EFst ext (EVar ext recT IZ))))
+
+subenvD1E :: Subenv env env' -> Subenv (D1E env) (D1E env')
+subenvD1E SETop = SETop
+subenvD1E (SEYesR sub) = SEYesR (subenvD1E sub)
+subenvD1E (SENo sub) = SENo (subenvD1E sub)
diff --git a/src/CHAD/Drev/EnvDescr.hs b/src/CHAD/Drev/EnvDescr.hs
new file mode 100644
index 0000000..5a90303
--- /dev/null
+++ b/src/CHAD/Drev/EnvDescr.hs
@@ -0,0 +1,96 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE EmptyCase #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD.Drev.EnvDescr where
+
+import Data.Kind (Type)
+import Data.Some
+import GHC.TypeLits (Symbol)
+
+import CHAD.Analysis.Identity (ValId(..))
+import CHAD.AST.Env
+import CHAD.AST.Types
+import CHAD.AST.Weaken
+import CHAD.Data
+import CHAD.Drev.Types
+
+
+type Storage :: Symbol -> Type
+data Storage s where
+ SAccum :: Storage "accum" -- ^ in the monad state as a mutable accumulator
+ SMerge :: Storage "merge" -- ^ just return and merge
+ SDiscr :: Storage "discr" -- ^ we happen to know this is a discrete type and won't need any contributions
+deriving instance Show (Storage s)
+
+-- | Environment description
+data Descr env sto where
+ DTop :: Descr '[] '[]
+ DPush :: Descr env sto -> (STy t, Maybe (ValId t), Storage s) -> Descr (t : env) (s : sto)
+deriving instance Show (Descr env sto)
+
+descrList :: Descr env sto -> SList STy env
+descrList DTop = SNil
+descrList (des `DPush` (t, _, _)) = t `SCons` descrList des
+
+descrPrj :: Descr env sto -> Idx env t -> (STy t, Maybe (ValId t), Some Storage)
+descrPrj (_ `DPush` (ty, vid, sto)) IZ = (ty, vid, Some sto)
+descrPrj (des `DPush` _) (IS i) = descrPrj des i
+descrPrj DTop i = case i of {}
+
+-- | This could have more precise typing on the output storage.
+subDescr :: Descr env sto -> Subenv env env'
+ -> (forall sto'. Descr env' sto'
+ -> Subenv (Select env sto "merge") (Select env' sto' "merge")
+ -> Subenv (D2AcE (Select env sto "accum")) (D2AcE (Select env' sto' "accum"))
+ -> Subenv (D1E env) (D1E env')
+ -> r)
+ -> r
+subDescr DTop SETop k = k DTop SETop SETop SETop
+subDescr (des `DPush` (t, vid, sto)) (SEYesR sub) k =
+ subDescr des sub $ \des' submerge subaccum subd1e ->
+ case sto of
+ SMerge -> k (des' `DPush` (t, vid, sto)) (SEYesR submerge) subaccum (SEYesR subd1e)
+ SAccum -> k (des' `DPush` (t, vid, sto)) submerge (SEYesR subaccum) (SEYesR subd1e)
+ SDiscr -> k (des' `DPush` (t, vid, sto)) submerge subaccum (SEYesR subd1e)
+subDescr (des `DPush` (_, _, sto)) (SENo sub) k =
+ subDescr des sub $ \des' submerge subaccum subd1e ->
+ case sto of
+ SMerge -> k des' (SENo submerge) subaccum (SENo subd1e)
+ SAccum -> k des' submerge (SENo subaccum) (SENo subd1e)
+ SDiscr -> k des' submerge subaccum (SENo subd1e)
+
+-- | Select only the types from the environment that have the specified storage
+type family Select env sto s where
+ Select '[] '[] _ = '[]
+ Select (t : ts) (s : sto) s = t : Select ts sto s
+ Select (_ : ts) (_ : sto) s = Select ts sto s
+
+select :: Storage s -> Descr env sto -> SList STy (Select env sto s)
+select _ DTop = SNil
+select s@SAccum (DPush des (t, _, SAccum)) = SCons t (select s des)
+select s@SMerge (DPush des (_, _, SAccum)) = select s des
+select s@SDiscr (DPush des (_, _, SAccum)) = select s des
+select s@SAccum (DPush des (_, _, SMerge)) = select s des
+select s@SMerge (DPush des (t, _, SMerge)) = SCons t (select s des)
+select s@SDiscr (DPush des (_, _, SMerge)) = select s des
+select s@SAccum (DPush des (_, _, SDiscr)) = select s des
+select s@SMerge (DPush des (_, _, SDiscr)) = select s des
+select s@SDiscr (DPush des (t, _, SDiscr)) = SCons t (select s des)
+
+selectSub :: Storage s -> Descr env sto -> Subenv env (Select env sto s)
+selectSub _ DTop = SETop
+selectSub s@SAccum (DPush des (_, _, SAccum)) = SEYesR (selectSub s des)
+selectSub s@SMerge (DPush des (_, _, SAccum)) = SENo (selectSub s des)
+selectSub s@SDiscr (DPush des (_, _, SAccum)) = SENo (selectSub s des)
+selectSub s@SAccum (DPush des (_, _, SMerge)) = SENo (selectSub s des)
+selectSub s@SMerge (DPush des (_, _, SMerge)) = SEYesR (selectSub s des)
+selectSub s@SDiscr (DPush des (_, _, SMerge)) = SENo (selectSub s des)
+selectSub s@SAccum (DPush des (_, _, SDiscr)) = SENo (selectSub s des)
+selectSub s@SMerge (DPush des (_, _, SDiscr)) = SENo (selectSub s des)
+selectSub s@SDiscr (DPush des (_, _, SDiscr)) = SEYesR (selectSub s des)
diff --git a/src/CHAD/Drev/Top.hs b/src/CHAD/Drev/Top.hs
new file mode 100644
index 0000000..510e73e
--- /dev/null
+++ b/src/CHAD/Drev/Top.hs
@@ -0,0 +1,96 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ImplicitParams #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE OverloadedLabels #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD.Drev.Top where
+
+import CHAD.Analysis.Identity
+import CHAD.AST
+import CHAD.AST.Env
+import CHAD.AST.Sparse
+import CHAD.AST.SplitLets
+import CHAD.AST.Weaken.Auto
+import CHAD.Data
+import qualified CHAD.Data.VarMap as VarMap
+import CHAD.Drev
+import CHAD.Drev.Accum
+import CHAD.Drev.EnvDescr
+import CHAD.Drev.Types
+
+
+type family MergeEnv env where
+ MergeEnv '[] = '[]
+ MergeEnv (t : ts) = "merge" : MergeEnv ts
+
+mergeDescr :: SList STy env -> Descr env (MergeEnv env)
+mergeDescr SNil = DTop
+mergeDescr (t `SCons` env) = mergeDescr env `DPush` (t, Nothing, SMerge)
+
+mergeEnvNoAccum :: SList f env -> Select env (MergeEnv env) "accum" :~: '[]
+mergeEnvNoAccum SNil = Refl
+mergeEnvNoAccum (_ `SCons` env) | Refl <- mergeEnvNoAccum env = Refl
+
+mergeEnvOnlyMerge :: SList f env -> Select env (MergeEnv env) "merge" :~: env
+mergeEnvOnlyMerge SNil = Refl
+mergeEnvOnlyMerge (_ `SCons` env) | Refl <- mergeEnvOnlyMerge env = Refl
+
+accumDescr :: SList STy env -> (forall sto. Descr env sto -> r) -> r
+accumDescr SNil k = k DTop
+accumDescr (t `SCons` env) k = accumDescr env $ \des ->
+ if typeHasArrays t then k (des `DPush` (t, Nothing, SAccum))
+ else k (des `DPush` (t, Nothing, SMerge))
+
+reassembleD2E :: Descr env sto
+ -> D1E env :> env'
+ -> Ex env' (TPair (Tup (D2E (Select env sto "accum"))) (Tup (D2E (Select env sto "merge"))))
+ -> Ex env' (Tup (D2E env))
+reassembleD2E DTop _ _ = ENil ext
+reassembleD2E (des `DPush` (_, _, SAccum)) w e =
+ eunPair e $ \w1 e1 e2 ->
+ eunPair e1 $ \w2 e11 e12 ->
+ EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext e11 (weakenExpr w2 e2))) e12
+reassembleD2E (des `DPush` (_, _, SMerge)) w e =
+ eunPair e $ \w1 e1 e2 ->
+ eunPair e2 $ \w2 e21 e22 ->
+ EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext (weakenExpr w2 e1) e21)) e22
+reassembleD2E (des `DPush` (t, _, SDiscr)) w e =
+ EPair ext (reassembleD2E des (WPop w) e)
+ (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ))))
+
+chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env)))
+chad config env (term :: Ex env t)
+ | True <- chcArgArrayAccum config
+ = let ?config = config
+ in accumDescr env $ \descr ->
+ let t1 = STPair (d1 (typeOf term)) (tTup (d2e (select SMerge descr)))
+ tvar = STPair t1 (tTup (d2e (select SAccum descr)))
+ in ELet ext (uninvertTup (d2e (select SAccum descr)) t1 $
+ makeAccumulators (WSink .> wUndoSubenv (subenvD1E (selectSub SAccum descr))) (select SAccum descr) $
+ weakenExpr (autoWeak (#d (auto1 @(D2 t))
+ &. #acenv (d2ace (select SAccum descr))
+ &. #tl (d1e env))
+ (#d :++: #acenv :++: #tl)
+ (#acenv :++: #d :++: #tl)) $
+ freezeRet descr (drev descr VarMap.empty (spDense (d2M (typeOf term))) term')) $
+ EPair ext (EFst ext (EFst ext (EVar ext tvar IZ)))
+ (reassembleD2E descr (WSink .> WSink)
+ (EPair ext (ESnd ext (EVar ext tvar IZ))
+ (ESnd ext (EFst ext (EVar ext tvar IZ)))))
+
+ | False <- chcArgArrayAccum config
+ , Refl <- mergeEnvNoAccum env
+ , Refl <- mergeEnvOnlyMerge env
+ = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty (spDense (d2M (typeOf term))) term')
+ where
+ term' = identityAnalysis env (splitLets term)
+
+chad' :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
+chad' config env term
+ | Refl <- d1eIdentity env, Refl <- d1Identity (typeOf term)
+ = chad config env term
diff --git a/src/CHAD/Drev/Types.hs b/src/CHAD/Drev/Types.hs
new file mode 100644
index 0000000..367a974
--- /dev/null
+++ b/src/CHAD/Drev/Types.hs
@@ -0,0 +1,153 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD.Drev.Types where
+
+import CHAD.AST.Accum
+import CHAD.AST.Types
+import CHAD.Data
+
+
+type family D1 t where
+ D1 TNil = TNil
+ D1 (TPair a b) = TPair (D1 a) (D1 b)
+ D1 (TEither a b) = TEither (D1 a) (D1 b)
+ D1 (TLEither a b) = TLEither (D1 a) (D1 b)
+ D1 (TMaybe a) = TMaybe (D1 a)
+ D1 (TArr n t) = TArr n (D1 t)
+ D1 (TScal t) = TScal t
+
+type family D2 t where
+ D2 TNil = TNil
+ D2 (TPair a b) = TPair (D2 a) (D2 b)
+ D2 (TEither a b) = TLEither (D2 a) (D2 b)
+ D2 (TLEither a b) = TLEither (D2 a) (D2 b)
+ D2 (TMaybe t) = TMaybe (D2 t)
+ D2 (TArr n t) = TArr n (D2 t)
+ D2 (TScal t) = D2s t
+
+type family D2s t where
+ D2s TI32 = TNil
+ D2s TI64 = TNil
+ D2s TF32 = TScal TF32
+ D2s TF64 = TScal TF64
+ D2s TBool = TNil
+
+type family D1E env where
+ D1E '[] = '[]
+ D1E (t : env) = D1 t : D1E env
+
+type family D2E env where
+ D2E '[] = '[]
+ D2E (t : env) = D2 t : D2E env
+
+type family D2AcE env where
+ D2AcE '[] = '[]
+ D2AcE (t : env) = TAccum (D2 t) : D2AcE env
+
+d1 :: STy t -> STy (D1 t)
+d1 STNil = STNil
+d1 (STPair a b) = STPair (d1 a) (d1 b)
+d1 (STEither a b) = STEither (d1 a) (d1 b)
+d1 (STLEither a b) = STLEither (d1 a) (d1 b)
+d1 (STMaybe t) = STMaybe (d1 t)
+d1 (STArr n t) = STArr n (d1 t)
+d1 (STScal t) = STScal t
+d1 STAccum{} = error "Accumulators not allowed in input program"
+
+d1e :: SList STy env -> SList STy (D1E env)
+d1e SNil = SNil
+d1e (t `SCons` env) = d1 t `SCons` d1e env
+
+d2M :: STy t -> SMTy (D2 t)
+d2M STNil = SMTNil
+d2M (STPair a b) = SMTPair (d2M a) (d2M b)
+d2M (STEither a b) = SMTLEither (d2M a) (d2M b)
+d2M (STLEither a b) = SMTLEither (d2M a) (d2M b)
+d2M (STMaybe t) = SMTMaybe (d2M t)
+d2M (STArr n t) = SMTArr n (d2M t)
+d2M (STScal t) = case t of
+ STI32 -> SMTNil
+ STI64 -> SMTNil
+ STF32 -> SMTScal STF32
+ STF64 -> SMTScal STF64
+ STBool -> SMTNil
+d2M STAccum{} = error "Accumulators not allowed in input program"
+
+d2 :: STy t -> STy (D2 t)
+d2 = fromSMTy . d2M
+
+d2eM :: SList STy env -> SList SMTy (D2E env)
+d2eM SNil = SNil
+d2eM (t `SCons` ts) = d2M t `SCons` d2eM ts
+
+d2e :: SList STy env -> SList STy (D2E env)
+d2e = slistMap fromSMTy . d2eM
+
+d2ace :: SList STy env -> SList STy (D2AcE env)
+d2ace SNil = SNil
+d2ace (t `SCons` ts) = STAccum (d2M t) `SCons` d2ace ts
+
+
+data CHADConfig = CHADConfig
+ { -- | D[let] will bind variables containing arrays in accumulator mode.
+ chcLetArrayAccum :: Bool
+ , -- | D[case] will bind variables containing arrays in accumulator mode.
+ chcCaseArrayAccum :: Bool
+ , -- | Introduce top-level arguments containing arrays in accumulator mode.
+ chcArgArrayAccum :: Bool
+ , -- | Place with-blocks around array variable scopes, and redirect accumulations there.
+ chcSmartWith :: Bool
+ }
+ deriving (Show)
+
+defaultConfig :: CHADConfig
+defaultConfig = CHADConfig
+ { chcLetArrayAccum = False
+ , chcCaseArrayAccum = False
+ , chcArgArrayAccum = False
+ , chcSmartWith = False
+ }
+
+chcSetAccum :: CHADConfig -> CHADConfig
+chcSetAccum c = c { chcLetArrayAccum = True
+ , chcCaseArrayAccum = True
+ , chcArgArrayAccum = True
+ , chcSmartWith = True }
+
+
+------------------------------------ LEMMAS ------------------------------------
+
+indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx))
+indexTupD1Id SZ = Refl
+indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl
+
+lemZeroInfoScal :: SScalTy t -> ZeroInfo (D2s t) :~: TNil
+lemZeroInfoScal STI32 = Refl
+lemZeroInfoScal STI64 = Refl
+lemZeroInfoScal STF32 = Refl
+lemZeroInfoScal STF64 = Refl
+lemZeroInfoScal STBool = Refl
+
+lemDeepZeroInfoScal :: SScalTy t -> DeepZeroInfo (D2s t) :~: TNil
+lemDeepZeroInfoScal STI32 = Refl
+lemDeepZeroInfoScal STI64 = Refl
+lemDeepZeroInfoScal STF32 = Refl
+lemDeepZeroInfoScal STF64 = Refl
+lemDeepZeroInfoScal STBool = Refl
+
+d1Identity :: STy t -> D1 t :~: t
+d1Identity = \case
+ STNil -> Refl
+ STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
+ STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
+ STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
+ STMaybe t | Refl <- d1Identity t -> Refl
+ STArr _ t | Refl <- d1Identity t -> Refl
+ STScal _ -> Refl
+ STAccum{} -> error "Accumulators not allowed in input program"
+
+d1eIdentity :: SList STy env -> D1E env :~: env
+d1eIdentity SNil = Refl
+d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl
diff --git a/src/CHAD/Drev/Types/ToTan.hs b/src/CHAD/Drev/Types/ToTan.hs
new file mode 100644
index 0000000..019119c
--- /dev/null
+++ b/src/CHAD/Drev/Types/ToTan.hs
@@ -0,0 +1,43 @@
+{-# LANGUAGE GADTs #-}
+module CHAD.Drev.Types.ToTan where
+
+import Data.Bifunctor (bimap)
+
+import CHAD.Array
+import CHAD.AST.Types
+import CHAD.Data
+import CHAD.Drev.Types
+import CHAD.ForwardAD
+import CHAD.Interpreter.Rep
+
+
+toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env)
+toTanE SNil SNil SNil = SNil
+toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) =
+ Value (toTan t p x) `SCons` toTanE env primal inp
+
+toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t)
+toTan typ primal der = case typ of
+ STNil -> der
+ STPair t1 t2 -> bimap (\p1 -> toTan t1 p1 (fst der)) (\p2 -> toTan t2 p2 (snd der)) primal
+ STEither t1 t2 -> case der of
+ Nothing -> bimap (zeroTan t1) (zeroTan t2) primal
+ Just d -> case (primal, d) of
+ (Left p, Left d') -> Left (toTan t1 p d')
+ (Right p, Right d') -> Right (toTan t2 p d')
+ _ -> error "Primal and cotangent disagree on Either alternative"
+ STLEither t1 t2 -> case (primal, der) of
+ (_, Nothing) -> Nothing
+ (Just (Left p), Just (Left d)) -> Just (Left (toTan t1 p d))
+ (Just (Right p), Just (Right d)) -> Just (Right (toTan t2 p d))
+ _ -> error "Primal and cotangent disagree on LEither alternative"
+ STMaybe t -> liftA2 (toTan t) primal der
+ STArr _ t
+ | arrayShape primal == arrayShape der ->
+ arrayGenerateLin (arrayShape primal) $ \i ->
+ toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i)
+ | otherwise ->
+ error "Primal and cotangent disagree on array shape"
+ STScal sty -> case sty of
+ STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der
+ STAccum{} -> error "Accumulators not allowed in input program"