diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-03 23:10:08 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-03 23:10:23 +0100 | 
| commit | 3b60a8609649019ba5bce053cdf266b4e3a51dfa (patch) | |
| tree | 009ce6e2be7db7b45feac364ba7a35a96f41e5f0 | |
| parent | 81d88dbc430ca6ec8390636f8b7162887b390873 (diff) | |
| -rw-r--r-- | src/CHAD.hs | 181 | 
1 files changed, 117 insertions, 64 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index 67ffe12..72ce36d 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1048,73 +1048,44 @@ drev des accumMap sd = \case      , let eltty = typeOf orige      , shty :: STy shty <- tTup (sreplicate ndim tIx)      , Refl <- indexTupD1Id ndim -> -    deleteUnused (descrList des) (occEnvPopSome (occCountAll orige)) $ \(usedSub :: Subenv env env') -> -    let e = unsafeWeakenWithSubenv (SEYesR usedSub) orige in -    subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed -> -    accumPromote sdElt usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> -    let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in -    case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro sdElt e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 SETop e2 -> -    case lemAppendNil @e_binds of { Refl -> -    let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in -    let collectexpr = bindingsCollectTape (bindingsBinds e0) subtapeE in -    let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in -    let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in -    Ret (mergePrimalBindings -          `bpush` weakenExpr (wSinks (d1e envPro)) (drevPrimal des she) +    drevLambda des accumMap (shty, SDiscr) sdElt orige $ \(provars :: SList _ envPro) esub proPrimalBinds e0 e1 (e1tape :: Ex _ e_tape) _ wrapAccum e2 -> +    let library = #ix (shty `SCons` SNil) +                  &. #e0 (bindingsBinds e0) +                  &. #propr (d1e provars) +                  &. #d1env (desD1E des) +                  &. #d (auto1 @sdElt) +                  &. #tape (auto1 @e_tape) +                  &. #pro (d2ace provars) +                  &. #d2acEnv (d2ace (select SAccum des)) +                  &. #darr (auto1 @(TArr ndim sdElt)) +                  &. #tapearr (auto1 @(TArr ndim e_tape)) in +    Ret (proPrimalBinds            `bpush` EBuild ext ndim -                    (EVar ext shty IZ) -                    (letBinds (fst (weakenBindingsE (autoWeak (#ix (shty `SCons` SNil) -                                                               &. #sh (shty `SCons` SNil) -                                                               &. #propr (d1e envPro) -                                                               &. #d1env (desD1E des) -                                                               &. #d1env' (desD1E usedDes)) -                                                              (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) -                                                              (#ix :++: #sh :++: #propr :++: #d1env)) +                    (weakenExpr (wSinks (d1e provars)) (drevPrimal des she)) +                    (letBinds (fst (weakenBindingsE (autoWeak library +                                                              (#ix :++: #d1env) +                                                              (#ix :++: #propr :++: #d1env))                                                      e0)) $ -                       let w = autoWeak (#ix (shty `SCons` SNil) -                                         &. #sh (shty `SCons` SNil) -                                         &. #e0 (bindingsBinds e0) -                                         &. #propr (d1e envPro) -                                         &. #d1env (desD1E des) -                                         &. #d1env' (desD1E usedDes)) -                                        (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) -                                        (#e0 :++: #ix :++: #sh :++: #propr :++: #d1env) -                           w' = w .> wCopies (bindingsBinds e0) (WClosed @(shty : D1E env')) -                       in EPair ext (weakenExpr w e1) (collectexpr w')) -          `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ)) -        (SEYesR (SENo (SEYesR (subenvAll (d1e envPro))))) -        (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ))) -        (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) -        (let sinkOverEnvPro = wSinks @(sd : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : Append (D1E envPro) (D2AcE (Select env sto "accum"))) (d2ace envPro) in +                      weakenExpr (autoWeak library (#e0 :++: #ix :++: #d1env) +                                                   (#e0 :++: #ix :++: #propr :++: #d1env)) +                                 (EPair ext e1 e1tape)) +          `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) IZ)) +        (SEYesR (SENo (subenvAll (d1e provars)))) +        (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) (IS IZ))) +        (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) esub) +        (let sinkOverEnvPro = wSinks @(sd : TArr ndim e_tape : Append (D1E envPro) (D2AcE (Select env sto "accum"))) (d2ace provars) in           ESnd ext $ -           uninvertTup (d2e envPro) (STArr ndim STNil) $ -             makeAccumulators @_ @_ @(TArr ndim TNil) (WSink .> WSink .> WSink .> wRaiseAbove (d1e envPro) (d2ace (select SAccum des))) envPro $ -               EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $ -                 -- the cotangent for this element -                 ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ)) -                                    (EVar ext shty IZ)) $ -                 -- the tape for this element -                 ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) -                                    (EVar ext shty (IS IZ))) $ -                 let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) -                 in letBinds (rebinds IZ) $ -                      weakenExpr (autoWeak (#d (auto1 @sdElt) -                                            &. #pro (d2ace envPro) -                                            &. #etape (subList (bindingsBinds e0) subtapeE) -                                            &. #prerebinds prerebinds -                                            &. #tape (auto1 @(Tape e_tape)) -                                            &. #ix (auto1 @shty) -                                            &. #darr (auto1 @(TArr ndim sdElt)) -                                            &. #tapearr (auto1 @(TArr ndim (Tape e_tape))) -                                            &. #sh (auto1 @shty) -                                            &. #propr (d1e envPro) -                                            &. #d2acUsed (d2ace (select SAccum usedDes)) -                                            &. #d2acEnv (d2ace (select SAccum des))) -                                           (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) -                                           ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #propr :++: #d2acEnv) -                                  .> wPro (subList (bindingsBinds e0) subtapeE)) -                                 e2) -    }} +           wrapAccum (WSink .> WSink .> wRaiseAbove (d1e provars) (d2ace (select SAccum des))) $ +             EBuild ext ndim (EShape ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (sinkOverEnvPro @> IZ))) $ +               -- the tape for this element +               ELet ext (EIdx ext (EVar ext (STArr ndim (typeOf e1tape)) (WSink .> sinkOverEnvPro @> IS IZ)) +                                  (EVar ext shty IZ)) $ +               -- the cotangent for this element +               ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> WSink .> sinkOverEnvPro @> IZ)) +                                  (EVar ext shty (IS IZ))) $ +                weakenExpr (autoWeak library (#d :++: #tape :++: #pro :++: #d2acEnv) +                                             (#d :++: #tape :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv)) +                           e2)    EMap{} -> undefined @@ -1505,6 +1476,88 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of      , Refl <- lemAppendNil @tapebinds ->          RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent e2 +drevLambda :: (?config :: CHADConfig, (s == "accum") ~ False) +           => Descr env sto +           -> VarMap Int (D2AcE (Select env sto "accum")) +           -> (STy a, Storage s) +           -> Sparse (D2 t) dt +           -> Expr ValId (a : env) t +           -> (forall provars shbinds tape d2a'. +                  SList STy provars +               -> Subenv (D2E (Select env sto "merge")) (D2E provars) +               -> Bindings Ex (D1E env) (D1E provars)  -- accum-promoted free variables of which we need a primal in the reverse pass (to initialise the accumulator) +               -> Bindings Ex (D1 a : D1E env) shbinds +               -> Ex (Append shbinds (D1 a : D1E env)) (D1 t) +               -> Ex (Append shbinds (D1 a : D1E env)) tape +               -> Sparse (D2 a) d2a' +               -> (forall env' b. +                      D1E provars :> env' +                   -> Ex (Append (D2AcE provars) env') b +                   -> Ex (                       env') (TPair b (Tup (D2E provars)))) +               -> Ex (dt : tape : Append (D2AcE provars) (D2AcE (Select env sto "accum"))) d2a' +               -> r) +           -> r +drevLambda des accumMap (argty, argsto) sd origef k = +  let t = typeOf origef in +  deleteUnused (descrList des) (occEnvPopSome (occCountAll origef)) $ \(usedSub :: Subenv env env') -> +  let ef = unsafeWeakenWithSubenv (SEYesR usedSub) origef in +  subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed -> +  accumPromote (applySparse sd (d2 t)) usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> +  let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in +  let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in +  let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in +  case prf1 prodes argty argsto of { Refl -> +  case drev (prodes `DPush` (argty, Nothing, argsto)) accumMapPro sd ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 -> +  let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in +  extractContrib prodes argty argsto subEf $ \argSp getSparseArg -> +  let library = #fbinds (bindingsBinds ef0) +                &. #ftapebinds (subList (bindingsBinds ef0) subtapeEf) +                &. #ftape (auto1 @(Tape e_tape)) +                &. #arg (d1 argty `SCons` SNil) +                &. #d (applySparse sd (d2 t) `SCons` SNil) +                &. #d1env (desD1E des) +                &. #d1env' (desD1E usedDes) +                &. #propr (d1e envPro) +                &. #d2acUsed (d2ace (select SAccum usedDes)) +                &. #d2acEnv (d2ace (select SAccum des)) +                &. #d2acPro (d2ace envPro) +                &. #efPrerebinds efPrerebinds in +  k envPro +    (subenvD2E (subenvCompose subMergeUsed proSub)) +    mergePrimalBindings +    (fst (weakenBindingsE (WCopy (wUndoSubenv subD1eUsed)) ef0)) +    (weakenExpr (autoWeak library (#fbinds :++: #arg :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) +                                  (#fbinds :++: #arg :++: #d1env)) +                ef1) +    (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#fbinds :++: #arg :++: #d1env))) +    argSp +    (\wpro1 body -> +      uninvertTup (d2e envPro) (typeOf body) $ +        makeAccumulators wpro1 envPro $ +          body) +    (letBinds (efRebinds (IS IZ)) $ +      weakenExpr +        (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) +                          ((#ftapebinds :++: #efPrerebinds) :++: #d :++: #ftape :++: #d2acPro :++: #d2acEnv) +         .> wPro (subList (bindingsBinds ef0) subtapeEf)) +        (getSparseArg ef2)) +  }} +  where +    extractContrib :: (Select env sto "merge" ~ '[], (s == "accum") ~ False) +                   => proxy env sto -> proxy2 a -> Storage s +                      -- if s == "merge", this simplifies to SubenvS '[D2 a] t' +                      -- if s == "discr", this simplifies to SubenvS '[] t' +                   -> SubenvS (D2E (Select (a : env) (s : sto) "merge")) t' +                   -> (forall d'. Sparse (D2 a) d' -> (forall env'. Ex env' (Tup t') -> Ex env' d') -> r) -> r +    extractContrib _ _ SMerge (SENo SETop) k' = k' SpAbsent id +    extractContrib _ _ SMerge (SEYes s SETop) k' = k' s (ESnd ext) +    extractContrib _ _ SDiscr SETop k' = k' SpAbsent id + +    prf1 :: (s == "accum") ~ False => proxy env sto -> proxy2 a -> Storage s +         -> Select (a : env) (s : sto) "accum" :~: Select env sto "accum" +    prf1 _ _ SMerge = Refl +    prf1 _ _ SDiscr = Refl +  -- TODO: proper primal-only transform that doesn't depend on D1 = Id  drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t)  drevPrimal des e  | 
