diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-06-09 23:07:36 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-06-09 23:07:36 +0200 | 
| commit | eed0f2999d6f6c8485ef53deb38f9d0a67b4f88e (patch) | |
| tree | 4ae9da61d463650e5916908c45ded0e9132eb5de | |
| parent | 514c4bb0bfe908ec39ab4fa09dbf51bf7db29bd4 (diff) | |
WIP
| -rw-r--r-- | src/AST.hs | 43 | ||||
| -rw-r--r-- | src/AST/Sparse.hs | 35 | ||||
| -rw-r--r-- | src/CHAD.hs | 141 | 
3 files changed, 126 insertions, 93 deletions
| @@ -461,27 +461,30 @@ eidxEq (SS n) a b          (eidxEq n (EFst ext (EVar ext ty (IS IZ)))                    (EFst ext (EVar ext ty IZ))) -emap :: Ex (a : env) b -> Ex env (TArr n a) -> Ex env (TArr n b) -emap f arr = -  let STArr n t = typeOf arr -  in ELet ext arr $ -       EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $ -         ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ)) -                            (EVar ext (tTup (sreplicate n tIx)) IZ)) $ -           weakenExpr (WCopy (WSink .> WSink)) f +emap :: (KnownTy a => Ex (a : env) b) -> Ex env (TArr n a) -> Ex env (TArr n b) +emap f arr +  | STArr n t <- typeOf arr +  , Dict <- styKnown t +  = ELet ext arr $ +      EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $ +        ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ)) +                           (EVar ext (tTup (sreplicate n tIx)) IZ)) $ +          weakenExpr (WCopy (WSink .> WSink)) f -ezipWith :: Ex (b : a : env) c -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c) -ezipWith f arr1 arr2 = -  let STArr n t1 = typeOf arr1 -      STArr _ t2 = typeOf arr2 -  in ELet ext arr1 $ -     ELet ext (weakenExpr WSink arr2) $ -       EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $ -         ELet ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ))) -                            (EVar ext (tTup (sreplicate n tIx)) IZ)) $ -         ELet ext (EIdx ext (EVar ext (STArr n t2) (IS (IS IZ))) -                            (EVar ext (tTup (sreplicate n tIx)) (IS IZ))) $ -           weakenExpr (WCopy (WCopy (WSink .> WSink .> WSink))) f +ezipWith :: ((KnownTy a, KnownTy b) => Ex (b : a : env) c) -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c) +ezipWith f arr1 arr2 +  | STArr n t1 <- typeOf arr1 +  , STArr _ t2 <- typeOf arr2 +  , Dict <- styKnown t1 +  , Dict <- styKnown t2 +  = ELet ext arr1 $ +    ELet ext (weakenExpr WSink arr2) $ +      EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $ +        ELet ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ))) +                           (EVar ext (tTup (sreplicate n tIx)) IZ)) $ +        ELet ext (EIdx ext (EVar ext (STArr n t2) (IS (IS IZ))) +                           (EVar ext (tTup (sreplicate n tIx)) (IS IZ))) $ +          weakenExpr (WCopy (WCopy (WSink .> WSink .> WSink))) f  ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b))  ezip arr1 arr2 = diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs index ddae7fe..369d395 100644 --- a/src/AST/Sparse.hs +++ b/src/AST/Sparse.hs @@ -111,6 +111,37 @@ isAbsent (SpMaybe s) = isAbsent s  isAbsent (SpArr s) = isAbsent s  isAbsent SpScal = False +sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t' +sparsePlus _ SpAbsent _ _ = ENil ext +sparsePlus t sp e1 e2 | Just Refl <- isDense t sp = EPlus ext t e1 e2 +sparsePlus t (SpSparse sp) e1 e2 = sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2  -- heh +sparsePlus (SMTPair t1 t2) (SpPair sp1 sp2) e1 e2 = +  eunPair e1 $ \w1 e1a e1b -> +  eunPair (weakenExpr w1 e2) $ \w2 e2a e2b -> +    EPair ext (sparsePlus t1 sp1 (weakenExpr w2 e1a) e2a) +              (sparsePlus t2 sp2 (weakenExpr w2 e1b) e2b) +sparsePlus (SMTLEither t1 t2) (SpLEither sp1 sp2) e1 e2 = +  elet e2 $ +    elcase (weakenExpr WSink e1) +      (evar IZ) +      (elcase (evar (IS IZ)) +        (ELInl ext (applySparse sp2 (fromSMTy t2)) (evar IZ)) +        (ELInl ext (applySparse sp2 (fromSMTy t2)) (sparsePlus t1 sp1 (evar (IS IZ)) (evar IZ))) +        (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus ll+lr")) +      (elcase (evar (IS IZ)) +        (ELInr ext (applySparse sp1 (fromSMTy t1)) (evar IZ)) +        (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus lr+ll") +        (ELInr ext (applySparse sp1 (fromSMTy t1)) (sparsePlus t2 sp2 (evar (IS IZ)) (evar IZ)))) +sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 = +  elet e2 $ +    emaybe (weakenExpr WSink e1) +      (evar IZ) +      (emaybe (evar (IS IZ)) +        (EJust ext (evar IZ)) +        (EJust ext (sparsePlus t sp (evar (IS IZ)) (evar IZ)))) +sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS IZ)) (evar IZ)) e1 e2 +sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2 +  data SBool b where    SF :: SBool False @@ -120,7 +151,7 @@ deriving instance Show (SBool b)  data Injection sp a b where    -- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that    -- 'sparsePlusS' can provide injections even if the caller doesn't require -  -- them. This eliminates pointless checks. +  -- them. This simplifies the sparsePlusS code.    Inj :: (forall e. Ex e a -> Ex e b) -> Injection sp a b    Noinj :: Injection False a b @@ -138,7 +169,7 @@ withInj2 Noinj _ _ = Noinj  withInj2 _ Noinj _ = Noinj  -- | This function produces quadratically-sized code in the presence of nested --- dynamic sparsity. しょうがない。 +-- dynamic sparsity. TODO can this be improved?  sparsePlusS    :: SBool inj1 -> SBool inj2    -> SMTy t -> Sparse t t1 -> Sparse t t2 diff --git a/src/CHAD.hs b/src/CHAD.hs index 241825e..7cd4c26 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1094,42 +1094,39 @@ drev des accumMap sd = \case      case lemAppendNil @e_binds of { Refl ->      let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in      let collectexpr = bindingsCollectTape (bindingsBinds e0) subtapeE in -    Ret (BTop `BPush` (shty, drevPrimal des she) -              `BPush` (STArr ndim (STPair (d1 eltty) tapety) -                      ,EBuild ext ndim -                         (EVar ext shty IZ) -                         (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil) -                                                                              &. #sh (shty `SCons` SNil) -                                                                              &. #d1env (desD1E des) -                                                                              &. #d1env' (desD1E usedDes)) -                                                                             (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) -                                                                             (#ix :++: #sh :++: #d1env)) -                                                                   e0)) $ -                            let w = autoWeak (#ix (shty `SCons` SNil) -                                              &. #sh (shty `SCons` SNil) -                                              &. #e0 (bindingsBinds e0) -                                              &. #d1env (desD1E des) -                                              &. #d1env' (desD1E usedDes)) -                                             (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) -                                             (#e0 :++: #ix :++: #sh :++: #d1env) -                                w' = w .> wCopies (bindingsBinds e0) (WClosed @(shty : D1E env')) -                            in EPair ext (weakenExpr w e1) (collectexpr w'))) -              `BPush` (STArr ndim tapety, emap (ESnd ext (EVar ext (STPair (d1 eltty) tapety) IZ)) -                                               (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ))) -        (SEYesR (SENo (SEYesR SETop))) -        (emap (EFst ext (EVar ext (STPair (d1 eltty) tapety) IZ)) -              (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ))) +    let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in +    let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in +    Ret (mergePrimalBindings +          `BPush` (shty, weakenExpr (wSinks (d1e envPro)) (drevPrimal des she)) +          `BPush` (STArr ndim (STPair (d1 eltty) tapety) +                  ,EBuild ext ndim +                     (EVar ext shty IZ) +                     (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil) +                                                                          &. #sh (shty `SCons` SNil) +                                                                          &. #propr (d1e envPro) +                                                                          &. #d1env (desD1E des) +                                                                          &. #d1env' (desD1E usedDes)) +                                                                         (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) +                                                                         (#ix :++: #sh :++: #propr :++: #d1env)) +                                                               e0)) $ +                        let w = autoWeak (#ix (shty `SCons` SNil) +                                          &. #sh (shty `SCons` SNil) +                                          &. #e0 (bindingsBinds e0) +                                          &. #propr (d1e envPro) +                                          &. #d1env (desD1E des) +                                          &. #d1env' (desD1E usedDes)) +                                         (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) +                                         (#e0 :++: #ix :++: #sh :++: #propr :++: #d1env) +                            w' = w .> wCopies (bindingsBinds e0) (WClosed @(shty : D1E env')) +                        in EPair ext (weakenExpr w e1) (collectexpr w'))) +          `BPush` (STArr ndim tapety, emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ))) +        (SEYesR (SENo (SEYesR (subenvAll (d1e envPro))))) +        (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ)))          (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) -        (let sinkOverEnvPro = wSinks @(sd : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in +        (let sinkOverEnvPro = wSinks @(sd : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : Append (D1E envPro) (D2AcE (Select env sto "accum"))) (d2ace envPro) in           ESnd ext $             uninvertTup (d2e envPro) (STArr ndim STNil) $ -             -- TODO: what's happening here is that because of the sparsity -             -- rewrite, makeAccumulators needs primals where it previously -             -- didn't. The build derivative is currently not saving those -             -- primals, so the hole below cannot currently be filled. The -             -- appropriate primals (waves hands) need to be stored, so that a -             -- weakening can be provided here. -             makeAccumulators @_ @_ @(TArr ndim TNil) (_ (subenvCompose subMergeUsed proSub)) envPro $ +             makeAccumulators @_ @_ @(TArr ndim TNil) (WSink .> WSink .> WSink .> wRaiseAbove (d1e envPro) (d2ace (select SAccum des))) envPro $                 EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $                   -- the cotangent for this element                   ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ)) @@ -1148,10 +1145,11 @@ drev des accumMap sd = \case                                              &. #darr (auto1 @(TArr ndim sdElt))                                              &. #tapearr (auto1 @(TArr ndim (Tape e_tape)))                                              &. #sh (auto1 @shty) +                                            &. #propr (d1e envPro)                                              &. #d2acUsed (d2ace (select SAccum usedDes))                                              &. #d2acEnv (d2ace (select SAccum des)))                                             (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) -                                           ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #d2acEnv) +                                           ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #propr :++: #d2acEnv)                                    .> wPro (subList (bindingsBinds e0) subtapeE))                                   e2)      }}} @@ -1167,32 +1165,34 @@ drev des accumMap sd = \case             weakenExpr (WCopy WSink) e2)    EReplicate1Inner _ en e -    -- We're allowed to ignore en2 here because the output of 'ei' is discrete. -    | Rets binds subtape (RetPair en1 _ _ `SCons` RetPair e1 sub e2 `SCons` SNil) -        <- retConcat des $ drev des accumMap en `SCons` drev des accumMap e `SCons` SNil +    -- We're allowed to differentiate 'en' as primal-only here because its output is discrete. +    | SpArr sdElt <- sd      , let STArr ndim eltty = typeOf e -> -    Ret binds -        subtape -        (EReplicate1Inner ext en1 e1) -        sub -        (EMaybe ext -          (zeroTup (subList (select SMerge des) sub)) -          (ELet ext (EJust ext (EFold1Inner ext Commut -                        (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) -                        (ezeroD2 eltty) -                        (EVar ext (STArr (SS ndim) (d2 eltty)) IZ))) $ -            weakenExpr (WCopy (WSink .> WSink)) e2) -          (EVar ext (d2 (STArr (SS ndim) eltty)) IZ)) +    -- This pessimistic sparsity union is because the array might have been empty, in which case we need to generate a zero. +    sparsePlusS ST ST (d2M eltty) sdElt SpAbsent $ \sdElt' (Inj inj1) (Inj inj2) _ -> +    case drev des accumMap (SpArr sdElt') e of { Ret binds subtape e1 sub e2 -> +      Ret binds +          subtape +          (EReplicate1Inner ext (weakenExpr (wSinks (bindingsBinds binds)) (drevPrimal des en)) e1) +          sub +          (ELet ext (EFold1Inner ext Commut +                         (sparsePlus (d2M eltty) sdElt' +                            (EVar ext (applySparse sdElt' (d2 eltty)) (IS IZ)) +                            (EVar ext (applySparse sdElt' (d2 eltty)) IZ)) +                         (inj2 (ENil ext)) +                         (emap (inj1 (evar IZ)) $ EVar ext (STArr (SS ndim) (applySparse sdElt (d2 eltty))) IZ)) $ +             weakenExpr (WCopy WSink) e2) +    }    EIdx0 _ e -    | Ret e0 subtape e1 sub e2 <- drev des accumMap e +    | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd) e      , STArr _ t <- typeOf e ->      Ret e0          subtape          (EIdx0 ext e1)          sub -        (ELet ext (EJust ext (EUnit ext (EVar ext (d2 t) IZ))) $ -         weakenExpr (WCopy WSink) e2) +        (ELet ext (EUnit ext (EVar ext (applySparse sd (d2 t)) IZ)) $ +           weakenExpr (WCopy WSink) e2)    EIdx1{} -> error "CHAD of EIdx1: Please use EIdx instead"    {- @@ -1214,26 +1214,25 @@ drev des accumMap sd = \case    -}    EIdx _ e ei -    -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. -    | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) -        <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil -    , STArr n eltty <- typeOf e +    -- We're allowed to differentiate ei as primal because its output is discrete. +    | STArr n eltty <- typeOf e      , Refl <- indexTupD1Id n -    , Refl <- lemZeroInfoD2 eltty -    , let tIxN = tTup (sreplicate n tIx)  -> -    Ret (binds `BPush` (STArr n (d1 eltty), e1) -               `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ)) -               `BPush` (tIxN, weakenExpr (WSink .> WSink) ei1)) -        (SEYesR (SEYesR (SENo subtape))) -        (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) -                  (EVar ext (tTup (sreplicate n tIx)) IZ)) -        sub -        (ELet ext (EOneHot ext (d2M (STArr n eltty)) (SAPJust (SAPArrIdx SAPHere)) -                             (EPair ext (EPair ext (EVar ext tIxN (IS IZ)) -                                                   (EBuild ext n (EVar ext tIxN (IS (IS IZ))) (ENil ext))) -                                        (ENil ext)) -                             (EVar ext (d2 eltty) IZ)) $ -         weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) +    , let tIxN = tTup (sreplicate n tIx) -> +    sparsePlusS ST ST (d2M eltty) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ -> +    case drev des accumMap (SpArr sd') e of { Ret binds subtape e1 sub e2 -> +      Ret (binds `BPush` (STArr n (d1 eltty), e1) +                 `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ)) +                 `BPush` (tIxN, weakenExpr (WSink .> WSink .> wSinks (bindingsBinds binds)) (drevPrimal des ei))) +          (SEYesR (SEYesR (SENo subtape))) +          (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) +                    (EVar ext (tTup (sreplicate n tIx)) IZ)) +          sub +          (ELet ext (EOneHot ext (SMTArr n (applySparse sd' (d2M eltty))) (SAPArrIdx SAPHere) +                                 (EPair ext (EPair ext (EVar ext tIxN (IS IZ)) +                                            (EBuild ext n (EVar ext tIxN (IS (IS IZ))) (ENil ext))) (ENil ext)) +                                 (inj1 $ EVar ext (applySparse sd (d2 eltty)) IZ)) $ +           weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) +    }    EShape _ e      -- Allowed to ignore e2 here because the output of EShape is discrete, | 
