diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:49:45 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:50:25 +0100 |
| commit | 174af2ba568de66e0d890825b8bda930b8e7bb96 (patch) | |
| tree | 5a20f52662e87ff7cf6a6bef5db0713aa6c7884e /src/CHAD/Drev | |
| parent | 92bca235e3aaa287286b6af082d3fce585825a35 (diff) | |
Move module hierarchy under CHAD.
Diffstat (limited to 'src/CHAD/Drev')
| -rw-r--r-- | src/CHAD/Drev/Accum.hs | 72 | ||||
| -rw-r--r-- | src/CHAD/Drev/EnvDescr.hs | 96 | ||||
| -rw-r--r-- | src/CHAD/Drev/Top.hs | 96 | ||||
| -rw-r--r-- | src/CHAD/Drev/Types.hs | 153 | ||||
| -rw-r--r-- | src/CHAD/Drev/Types/ToTan.hs | 43 |
5 files changed, 460 insertions, 0 deletions
diff --git a/src/CHAD/Drev/Accum.hs b/src/CHAD/Drev/Accum.hs new file mode 100644 index 0000000..6f25f11 --- /dev/null +++ b/src/CHAD/Drev/Accum.hs @@ -0,0 +1,72 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE TypeOperators #-} +-- | TODO this module is a grab-bag of random utility functions that are shared +-- between CHAD.Drev and CHAD.Drev.Top. +module CHAD.Drev.Accum where + +import CHAD.AST +import CHAD.Data +import CHAD.Drev.Types +import CHAD.AST.Env + + +d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t)) +d2zeroInfo STNil _ = ENil ext +d2zeroInfo (STPair a b) e = + eunPair e $ \_ e1 e2 -> + EPair ext (d2zeroInfo a e1) (d2zeroInfo b e2) +d2zeroInfo STEither{} _ = ENil ext +d2zeroInfo STLEither{} _ = ENil ext +d2zeroInfo STMaybe{} _ = ENil ext +d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e +d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext +d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program" + +d2deepZeroInfo :: STy t -> Ex env (D1 t) -> Ex env (DeepZeroInfo (D2 t)) +d2deepZeroInfo STNil _ = ENil ext +d2deepZeroInfo (STPair a b) e = + eunPair e $ \_ e1 e2 -> + EPair ext (d2deepZeroInfo a e1) (d2deepZeroInfo b e2) +d2deepZeroInfo (STEither a b) e = + ECase ext e + (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ))) + (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ))) +d2deepZeroInfo (STLEither a b) e = + elcase e + (ELNil ext (tDeepZeroInfo (d2M a)) (tDeepZeroInfo (d2M b))) + (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ))) + (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ))) +d2deepZeroInfo (STMaybe a) e = + emaybe e + (ENothing ext (tDeepZeroInfo (d2M a))) + (EJust ext (d2deepZeroInfo a (EVar ext (d1 a) IZ))) +d2deepZeroInfo (STArr _ t) e = emap (d2deepZeroInfo t (EVar ext (d1 t) IZ)) e +d2deepZeroInfo (STScal t) _ | Refl <- lemDeepZeroInfoScal t = ENil ext +d2deepZeroInfo STAccum{} _ = error "accumulators not allowed in source program" + +-- The weakening is necessary because we need to initialise the created +-- accumulators with zeros. Those zeros are deep and need full primals. This +-- means, in the end, that primals corresponding to environment entries +-- promoted to an accumulator with accumPromote in CHAD need to be stored for +-- the dual. +makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) +makeAccumulators _ SNil e = e +makeAccumulators w (t `SCons` envpro) e = + makeAccumulators (WPop w) envpro $ + EWith ext (d2M t) (EDeepZero ext (d2M t) (d2deepZeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e + +uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) +uninvertTup SNil _ e = EPair ext e (ENil ext) +uninvertTup (t `SCons` list) tcore e = + ELet ext (uninvertTup list (STPair tcore t) e) $ + let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding + in EPair ext + (EFst ext (EFst ext (EVar ext recT IZ))) + (EPair ext + (ESnd ext (EVar ext recT IZ)) + (ESnd ext (EFst ext (EVar ext recT IZ)))) + +subenvD1E :: Subenv env env' -> Subenv (D1E env) (D1E env') +subenvD1E SETop = SETop +subenvD1E (SEYesR sub) = SEYesR (subenvD1E sub) +subenvD1E (SENo sub) = SENo (subenvD1E sub) diff --git a/src/CHAD/Drev/EnvDescr.hs b/src/CHAD/Drev/EnvDescr.hs new file mode 100644 index 0000000..5a90303 --- /dev/null +++ b/src/CHAD/Drev/EnvDescr.hs @@ -0,0 +1,96 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.Drev.EnvDescr where + +import Data.Kind (Type) +import Data.Some +import GHC.TypeLits (Symbol) + +import CHAD.Analysis.Identity (ValId(..)) +import CHAD.AST.Env +import CHAD.AST.Types +import CHAD.AST.Weaken +import CHAD.Data +import CHAD.Drev.Types + + +type Storage :: Symbol -> Type +data Storage s where + SAccum :: Storage "accum" -- ^ in the monad state as a mutable accumulator + SMerge :: Storage "merge" -- ^ just return and merge + SDiscr :: Storage "discr" -- ^ we happen to know this is a discrete type and won't need any contributions +deriving instance Show (Storage s) + +-- | Environment description +data Descr env sto where + DTop :: Descr '[] '[] + DPush :: Descr env sto -> (STy t, Maybe (ValId t), Storage s) -> Descr (t : env) (s : sto) +deriving instance Show (Descr env sto) + +descrList :: Descr env sto -> SList STy env +descrList DTop = SNil +descrList (des `DPush` (t, _, _)) = t `SCons` descrList des + +descrPrj :: Descr env sto -> Idx env t -> (STy t, Maybe (ValId t), Some Storage) +descrPrj (_ `DPush` (ty, vid, sto)) IZ = (ty, vid, Some sto) +descrPrj (des `DPush` _) (IS i) = descrPrj des i +descrPrj DTop i = case i of {} + +-- | This could have more precise typing on the output storage. +subDescr :: Descr env sto -> Subenv env env' + -> (forall sto'. Descr env' sto' + -> Subenv (Select env sto "merge") (Select env' sto' "merge") + -> Subenv (D2AcE (Select env sto "accum")) (D2AcE (Select env' sto' "accum")) + -> Subenv (D1E env) (D1E env') + -> r) + -> r +subDescr DTop SETop k = k DTop SETop SETop SETop +subDescr (des `DPush` (t, vid, sto)) (SEYesR sub) k = + subDescr des sub $ \des' submerge subaccum subd1e -> + case sto of + SMerge -> k (des' `DPush` (t, vid, sto)) (SEYesR submerge) subaccum (SEYesR subd1e) + SAccum -> k (des' `DPush` (t, vid, sto)) submerge (SEYesR subaccum) (SEYesR subd1e) + SDiscr -> k (des' `DPush` (t, vid, sto)) submerge subaccum (SEYesR subd1e) +subDescr (des `DPush` (_, _, sto)) (SENo sub) k = + subDescr des sub $ \des' submerge subaccum subd1e -> + case sto of + SMerge -> k des' (SENo submerge) subaccum (SENo subd1e) + SAccum -> k des' submerge (SENo subaccum) (SENo subd1e) + SDiscr -> k des' submerge subaccum (SENo subd1e) + +-- | Select only the types from the environment that have the specified storage +type family Select env sto s where + Select '[] '[] _ = '[] + Select (t : ts) (s : sto) s = t : Select ts sto s + Select (_ : ts) (_ : sto) s = Select ts sto s + +select :: Storage s -> Descr env sto -> SList STy (Select env sto s) +select _ DTop = SNil +select s@SAccum (DPush des (t, _, SAccum)) = SCons t (select s des) +select s@SMerge (DPush des (_, _, SAccum)) = select s des +select s@SDiscr (DPush des (_, _, SAccum)) = select s des +select s@SAccum (DPush des (_, _, SMerge)) = select s des +select s@SMerge (DPush des (t, _, SMerge)) = SCons t (select s des) +select s@SDiscr (DPush des (_, _, SMerge)) = select s des +select s@SAccum (DPush des (_, _, SDiscr)) = select s des +select s@SMerge (DPush des (_, _, SDiscr)) = select s des +select s@SDiscr (DPush des (t, _, SDiscr)) = SCons t (select s des) + +selectSub :: Storage s -> Descr env sto -> Subenv env (Select env sto s) +selectSub _ DTop = SETop +selectSub s@SAccum (DPush des (_, _, SAccum)) = SEYesR (selectSub s des) +selectSub s@SMerge (DPush des (_, _, SAccum)) = SENo (selectSub s des) +selectSub s@SDiscr (DPush des (_, _, SAccum)) = SENo (selectSub s des) +selectSub s@SAccum (DPush des (_, _, SMerge)) = SENo (selectSub s des) +selectSub s@SMerge (DPush des (_, _, SMerge)) = SEYesR (selectSub s des) +selectSub s@SDiscr (DPush des (_, _, SMerge)) = SENo (selectSub s des) +selectSub s@SAccum (DPush des (_, _, SDiscr)) = SENo (selectSub s des) +selectSub s@SMerge (DPush des (_, _, SDiscr)) = SENo (selectSub s des) +selectSub s@SDiscr (DPush des (_, _, SDiscr)) = SEYesR (selectSub s des) diff --git a/src/CHAD/Drev/Top.hs b/src/CHAD/Drev/Top.hs new file mode 100644 index 0000000..510e73e --- /dev/null +++ b/src/CHAD/Drev/Top.hs @@ -0,0 +1,96 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedLabels #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.Drev.Top where + +import CHAD.Analysis.Identity +import CHAD.AST +import CHAD.AST.Env +import CHAD.AST.Sparse +import CHAD.AST.SplitLets +import CHAD.AST.Weaken.Auto +import CHAD.Data +import qualified CHAD.Data.VarMap as VarMap +import CHAD.Drev +import CHAD.Drev.Accum +import CHAD.Drev.EnvDescr +import CHAD.Drev.Types + + +type family MergeEnv env where + MergeEnv '[] = '[] + MergeEnv (t : ts) = "merge" : MergeEnv ts + +mergeDescr :: SList STy env -> Descr env (MergeEnv env) +mergeDescr SNil = DTop +mergeDescr (t `SCons` env) = mergeDescr env `DPush` (t, Nothing, SMerge) + +mergeEnvNoAccum :: SList f env -> Select env (MergeEnv env) "accum" :~: '[] +mergeEnvNoAccum SNil = Refl +mergeEnvNoAccum (_ `SCons` env) | Refl <- mergeEnvNoAccum env = Refl + +mergeEnvOnlyMerge :: SList f env -> Select env (MergeEnv env) "merge" :~: env +mergeEnvOnlyMerge SNil = Refl +mergeEnvOnlyMerge (_ `SCons` env) | Refl <- mergeEnvOnlyMerge env = Refl + +accumDescr :: SList STy env -> (forall sto. Descr env sto -> r) -> r +accumDescr SNil k = k DTop +accumDescr (t `SCons` env) k = accumDescr env $ \des -> + if typeHasArrays t then k (des `DPush` (t, Nothing, SAccum)) + else k (des `DPush` (t, Nothing, SMerge)) + +reassembleD2E :: Descr env sto + -> D1E env :> env' + -> Ex env' (TPair (Tup (D2E (Select env sto "accum"))) (Tup (D2E (Select env sto "merge")))) + -> Ex env' (Tup (D2E env)) +reassembleD2E DTop _ _ = ENil ext +reassembleD2E (des `DPush` (_, _, SAccum)) w e = + eunPair e $ \w1 e1 e2 -> + eunPair e1 $ \w2 e11 e12 -> + EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext e11 (weakenExpr w2 e2))) e12 +reassembleD2E (des `DPush` (_, _, SMerge)) w e = + eunPair e $ \w1 e1 e2 -> + eunPair e2 $ \w2 e21 e22 -> + EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext (weakenExpr w2 e1) e21)) e22 +reassembleD2E (des `DPush` (t, _, SDiscr)) w e = + EPair ext (reassembleD2E des (WPop w) e) + (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) + +chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env))) +chad config env (term :: Ex env t) + | True <- chcArgArrayAccum config + = let ?config = config + in accumDescr env $ \descr -> + let t1 = STPair (d1 (typeOf term)) (tTup (d2e (select SMerge descr))) + tvar = STPair t1 (tTup (d2e (select SAccum descr))) + in ELet ext (uninvertTup (d2e (select SAccum descr)) t1 $ + makeAccumulators (WSink .> wUndoSubenv (subenvD1E (selectSub SAccum descr))) (select SAccum descr) $ + weakenExpr (autoWeak (#d (auto1 @(D2 t)) + &. #acenv (d2ace (select SAccum descr)) + &. #tl (d1e env)) + (#d :++: #acenv :++: #tl) + (#acenv :++: #d :++: #tl)) $ + freezeRet descr (drev descr VarMap.empty (spDense (d2M (typeOf term))) term')) $ + EPair ext (EFst ext (EFst ext (EVar ext tvar IZ))) + (reassembleD2E descr (WSink .> WSink) + (EPair ext (ESnd ext (EVar ext tvar IZ)) + (ESnd ext (EFst ext (EVar ext tvar IZ))))) + + | False <- chcArgArrayAccum config + , Refl <- mergeEnvNoAccum env + , Refl <- mergeEnvOnlyMerge env + = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty (spDense (d2M (typeOf term))) term') + where + term' = identityAnalysis env (splitLets term) + +chad' :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env))) +chad' config env term + | Refl <- d1eIdentity env, Refl <- d1Identity (typeOf term) + = chad config env term diff --git a/src/CHAD/Drev/Types.hs b/src/CHAD/Drev/Types.hs new file mode 100644 index 0000000..367a974 --- /dev/null +++ b/src/CHAD/Drev/Types.hs @@ -0,0 +1,153 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.Drev.Types where + +import CHAD.AST.Accum +import CHAD.AST.Types +import CHAD.Data + + +type family D1 t where + D1 TNil = TNil + D1 (TPair a b) = TPair (D1 a) (D1 b) + D1 (TEither a b) = TEither (D1 a) (D1 b) + D1 (TLEither a b) = TLEither (D1 a) (D1 b) + D1 (TMaybe a) = TMaybe (D1 a) + D1 (TArr n t) = TArr n (D1 t) + D1 (TScal t) = TScal t + +type family D2 t where + D2 TNil = TNil + D2 (TPair a b) = TPair (D2 a) (D2 b) + D2 (TEither a b) = TLEither (D2 a) (D2 b) + D2 (TLEither a b) = TLEither (D2 a) (D2 b) + D2 (TMaybe t) = TMaybe (D2 t) + D2 (TArr n t) = TArr n (D2 t) + D2 (TScal t) = D2s t + +type family D2s t where + D2s TI32 = TNil + D2s TI64 = TNil + D2s TF32 = TScal TF32 + D2s TF64 = TScal TF64 + D2s TBool = TNil + +type family D1E env where + D1E '[] = '[] + D1E (t : env) = D1 t : D1E env + +type family D2E env where + D2E '[] = '[] + D2E (t : env) = D2 t : D2E env + +type family D2AcE env where + D2AcE '[] = '[] + D2AcE (t : env) = TAccum (D2 t) : D2AcE env + +d1 :: STy t -> STy (D1 t) +d1 STNil = STNil +d1 (STPair a b) = STPair (d1 a) (d1 b) +d1 (STEither a b) = STEither (d1 a) (d1 b) +d1 (STLEither a b) = STLEither (d1 a) (d1 b) +d1 (STMaybe t) = STMaybe (d1 t) +d1 (STArr n t) = STArr n (d1 t) +d1 (STScal t) = STScal t +d1 STAccum{} = error "Accumulators not allowed in input program" + +d1e :: SList STy env -> SList STy (D1E env) +d1e SNil = SNil +d1e (t `SCons` env) = d1 t `SCons` d1e env + +d2M :: STy t -> SMTy (D2 t) +d2M STNil = SMTNil +d2M (STPair a b) = SMTPair (d2M a) (d2M b) +d2M (STEither a b) = SMTLEither (d2M a) (d2M b) +d2M (STLEither a b) = SMTLEither (d2M a) (d2M b) +d2M (STMaybe t) = SMTMaybe (d2M t) +d2M (STArr n t) = SMTArr n (d2M t) +d2M (STScal t) = case t of + STI32 -> SMTNil + STI64 -> SMTNil + STF32 -> SMTScal STF32 + STF64 -> SMTScal STF64 + STBool -> SMTNil +d2M STAccum{} = error "Accumulators not allowed in input program" + +d2 :: STy t -> STy (D2 t) +d2 = fromSMTy . d2M + +d2eM :: SList STy env -> SList SMTy (D2E env) +d2eM SNil = SNil +d2eM (t `SCons` ts) = d2M t `SCons` d2eM ts + +d2e :: SList STy env -> SList STy (D2E env) +d2e = slistMap fromSMTy . d2eM + +d2ace :: SList STy env -> SList STy (D2AcE env) +d2ace SNil = SNil +d2ace (t `SCons` ts) = STAccum (d2M t) `SCons` d2ace ts + + +data CHADConfig = CHADConfig + { -- | D[let] will bind variables containing arrays in accumulator mode. + chcLetArrayAccum :: Bool + , -- | D[case] will bind variables containing arrays in accumulator mode. + chcCaseArrayAccum :: Bool + , -- | Introduce top-level arguments containing arrays in accumulator mode. + chcArgArrayAccum :: Bool + , -- | Place with-blocks around array variable scopes, and redirect accumulations there. + chcSmartWith :: Bool + } + deriving (Show) + +defaultConfig :: CHADConfig +defaultConfig = CHADConfig + { chcLetArrayAccum = False + , chcCaseArrayAccum = False + , chcArgArrayAccum = False + , chcSmartWith = False + } + +chcSetAccum :: CHADConfig -> CHADConfig +chcSetAccum c = c { chcLetArrayAccum = True + , chcCaseArrayAccum = True + , chcArgArrayAccum = True + , chcSmartWith = True } + + +------------------------------------ LEMMAS ------------------------------------ + +indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx)) +indexTupD1Id SZ = Refl +indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl + +lemZeroInfoScal :: SScalTy t -> ZeroInfo (D2s t) :~: TNil +lemZeroInfoScal STI32 = Refl +lemZeroInfoScal STI64 = Refl +lemZeroInfoScal STF32 = Refl +lemZeroInfoScal STF64 = Refl +lemZeroInfoScal STBool = Refl + +lemDeepZeroInfoScal :: SScalTy t -> DeepZeroInfo (D2s t) :~: TNil +lemDeepZeroInfoScal STI32 = Refl +lemDeepZeroInfoScal STI64 = Refl +lemDeepZeroInfoScal STF32 = Refl +lemDeepZeroInfoScal STF64 = Refl +lemDeepZeroInfoScal STBool = Refl + +d1Identity :: STy t -> D1 t :~: t +d1Identity = \case + STNil -> Refl + STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STMaybe t | Refl <- d1Identity t -> Refl + STArr _ t | Refl <- d1Identity t -> Refl + STScal _ -> Refl + STAccum{} -> error "Accumulators not allowed in input program" + +d1eIdentity :: SList STy env -> D1E env :~: env +d1eIdentity SNil = Refl +d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl diff --git a/src/CHAD/Drev/Types/ToTan.hs b/src/CHAD/Drev/Types/ToTan.hs new file mode 100644 index 0000000..019119c --- /dev/null +++ b/src/CHAD/Drev/Types/ToTan.hs @@ -0,0 +1,43 @@ +{-# LANGUAGE GADTs #-} +module CHAD.Drev.Types.ToTan where + +import Data.Bifunctor (bimap) + +import CHAD.Array +import CHAD.AST.Types +import CHAD.Data +import CHAD.Drev.Types +import CHAD.ForwardAD +import CHAD.Interpreter.Rep + + +toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env) +toTanE SNil SNil SNil = SNil +toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) = + Value (toTan t p x) `SCons` toTanE env primal inp + +toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t) +toTan typ primal der = case typ of + STNil -> der + STPair t1 t2 -> bimap (\p1 -> toTan t1 p1 (fst der)) (\p2 -> toTan t2 p2 (snd der)) primal + STEither t1 t2 -> case der of + Nothing -> bimap (zeroTan t1) (zeroTan t2) primal + Just d -> case (primal, d) of + (Left p, Left d') -> Left (toTan t1 p d') + (Right p, Right d') -> Right (toTan t2 p d') + _ -> error "Primal and cotangent disagree on Either alternative" + STLEither t1 t2 -> case (primal, der) of + (_, Nothing) -> Nothing + (Just (Left p), Just (Left d)) -> Just (Left (toTan t1 p d)) + (Just (Right p), Just (Right d)) -> Just (Right (toTan t2 p d)) + _ -> error "Primal and cotangent disagree on LEither alternative" + STMaybe t -> liftA2 (toTan t) primal der + STArr _ t + | arrayShape primal == arrayShape der -> + arrayGenerateLin (arrayShape primal) $ \i -> + toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i) + | otherwise -> + error "Primal and cotangent disagree on array shape" + STScal sty -> case sty of + STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der + STAccum{} -> error "Accumulators not allowed in input program" |
