diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-04-27 23:34:59 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-04-27 23:34:59 +0200 |
commit | b1664532eaebdf0409ab6d93fc0ba2ef8dfbf372 (patch) | |
tree | a40c16fd082bbe4183e7b4194b8cea1408cec379 /src/CHAD.hs | |
parent | c750f8f9f1275d49ff74297e6648e1bfc1c6d918 (diff) |
WIP revamp accumulators again: explicit monoid types
No more D2 in accumulators! Paving the way for configurable sparsity of
products and arrays. The idea is to make separate monoid types for a
"product cotangent" and an "array cotangent" that can be lowered to
either a sparse monoid or a non-sparse monoid. Downsides of this
approach: lots of API duplication.
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r-- | src/CHAD.hs | 83 |
1 files changed, 44 insertions, 39 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index 1126fde..ac308ac 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -292,7 +292,7 @@ conv1Idx IZ = IZ conv1Idx (IS i) = IS (conv1Idx i) data Idx2 env sto t - = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum t)) + = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t))) | Idx2Me (Idx (Select env sto "merge") t) | Idx2Di (Idx (Select env sto "discr") t) @@ -319,7 +319,7 @@ conv2Idx DTop i = case i of {} zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0)) zeroTup SNil = ENil ext -zeroTup (SCons t env) = EPair ext (zeroTup env) (EZero ext t) +zeroTup (t `SCons` env) = EPair ext (zeroTup env) (ezeroD2 t) ------------------------------------ SUBENVS ----------------------------------- @@ -359,7 +359,7 @@ subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k = ELet ext (weakenExpr WSink e2) $ EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ))) (EFst ext (EVar ext (typeOf e2) IZ))) - (EPlus ext t + (EPlus ext (d2M t) (ESnd ext (EVar ext (typeOf e1) (IS IZ))) (ESnd ext (EVar ext (typeOf e2) IZ))) @@ -369,7 +369,7 @@ expandSubenvZeros (SCons t ts) (SEYes sub) e = ELet ext e $ let var = EVar ext (STPair (tTup (d2e (subList ts sub))) (d2 t)) IZ in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var) -expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (EZero ext t) +expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (ezeroD2 t) assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[] assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl @@ -425,11 +425,11 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of (SEYes accrevsub) (VarMap.sink1 accumMap) (\shbinds -> - autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum descr))) + autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr))) (#acc :++: (#pro :++: #d :++: #shb :++: #tl)) (#pro :++: #d :++: #shb :++: #acc :++: #tl) .> WCopy (wf shbinds) - .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum storepl))) + .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl))) (#d :++: #shb :++: #acc :++: #tl) (#acc :++: (#d :++: #shb :++: #tl))) @@ -453,7 +453,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of (SENo accrevsub) (let accumMap' = VarMap.sink1 accumMap in case fromArrayValId vid of - Just i -> VarMap.insert i (STAccum t) IZ accumMap' + Just i -> VarMap.insert i (STAccum (d2M t)) IZ accumMap' Nothing -> accumMap') (\(shbinds :: SList _ shbinds) -> let shbindsC = slistMap (\_ -> Const ()) shbinds @@ -466,7 +466,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of -- goal: | ARE EQUAL || -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) WCopy (wf shbinds) - .> WPick @(TAccum t) @(D2 dt : shbinds) (Const () `SCons` shbindsC) + .> WPick @(TAccum (D2 t)) @(D2 dt : shbinds) (Const () `SCons` shbindsC) (WId @(D2AcE (Select env1 stoRepl "accum")))) -- Discrete values are left as-is, nothing to do @@ -493,6 +493,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of STF64 -> False STBool -> True STAccum{} -> False + STLEither a b -> isDiscrete a && isDiscrete b ---------------------------- RETURN TRIPLE FROM CHAD --------------------------- @@ -596,7 +597,7 @@ drev des accumMap = \case SETop (EVar ext (d1 t) (conv1Idx i)) (subenvNone (select SMerge des)) - (EAccum ext t SAPHere (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum t) (IS accI))) + (EAccum ext (d2M t) SAPHere (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2M t)) (IS accI))) Idx2Me tupI -> Ret BTop @@ -666,7 +667,7 @@ drev des accumMap = \case subtape (EFst ext e1) sub - (ELet ext (EJust ext (EPair ext (EVar ext (d2 t1) IZ) (EZero ext t2))) $ + (ELet ext (EJust ext (EPair ext (EVar ext (d2 t1) IZ) (ezeroD2 t2))) $ weakenExpr (WCopy WSink) e2) ESnd _ e @@ -676,7 +677,7 @@ drev des accumMap = \case subtape (ESnd ext e1) sub - (ELet ext (EJust ext (EPair ext (EZero ext t1) (EVar ext (d2 t2) IZ))) $ + (ELet ext (EJust ext (EPair ext (ezeroD2 t1) (EVar ext (d2 t2) IZ))) $ weakenExpr (WCopy WSink) e2) ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (select SMerge des)) (ENil ext) @@ -687,12 +688,11 @@ drev des accumMap = \case subtape (EInl ext (d1 t2) e1) sub - (EMaybe ext + (ELCase ext + (EVar ext (STLEither (d2 (typeOf e)) (d2 t2)) IZ) (zeroTup (subList (select SMerge des) sub)) - (ECase ext (EVar ext (STEither (d2 (typeOf e)) (d2 t2)) IZ) - (weakenExpr (WCopy (wSinks' @[_,_])) e2) + (weakenExpr (WCopy WSink) e2) (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr")) - (EVar ext (STMaybe (STEither (d2 (typeOf e)) (d2 t2))) IZ)) EInr _ t1 e | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> @@ -700,12 +700,11 @@ drev des accumMap = \case subtape (EInr ext (d1 t1) e1) sub - (EMaybe ext + (ELCase ext + (EVar ext (STLEither (d2 t1) (d2 (typeOf e))) IZ) (zeroTup (subList (select SMerge des) sub)) - (ECase ext (EVar ext (STEither (d2 t1) (d2 (typeOf e))) IZ) - (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl") - (weakenExpr (WCopy (wSinks' @[_,_])) e2)) - (EVar ext (STMaybe (STEither (d2 t1) (d2 (typeOf e)))) IZ)) + (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl") + (weakenExpr (WCopy WSink) e2)) ECase _ e (a :: Expr _ _ t) b | STEither t1 t2 <- typeOf e @@ -727,7 +726,7 @@ drev des accumMap = \case -> subenvPlus (select SMerge des) subA subB $ \subAB sAB_A sAB_B _ -> subenvPlus (select SMerge des) subAB subE $ \subOut _ _ plus_AB_E -> - let tCaseRet = STPair (tTup (d2e (subList (select SMerge des) subAB))) (STEither (d2 t1) (d2 t2)) in + let tCaseRet = STPair (tTup (d2e (subList (select SMerge des) subAB))) (STLEither (d2 t1) (d2 t2)) in Ret (e0 `BPush` (tPrimal, ECase ext e1 @@ -755,7 +754,7 @@ drev des accumMap = \case EPair ext (expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $ EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ)) - (EInl ext (d2 t2) + (ELInl ext (d2 t2) (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ)))) (let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds b0) subtapeB) IZ in letBinds rebinds $ @@ -774,10 +773,10 @@ drev des accumMap = \case EPair ext (expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $ EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ)) - (EInr ext (d2 t1) + (ELInr ext (d2 t1) (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ))))) $ ELet ext - (ELet ext (EJust ext (ESnd ext (EVar ext tCaseRet IZ))) $ + (ELet ext (ESnd ext (EVar ext tCaseRet IZ)) $ weakenExpr (WCopy (wSinks' @[_,_,_])) e2) $ plus_AB_E (EFst ext (EVar ext tCaseRet (IS IZ))) @@ -934,8 +933,8 @@ drev des accumMap = \case (EMaybe ext (zeroTup (subList (select SMerge des) sub)) (ELet ext (EJust ext (EFold1Inner ext Commut - (EPlus ext eltty (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) - (EZero ext eltty) + (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) + (ezeroD2 eltty) (EVar ext (STArr (SS ndim) (d2 eltty)) IZ))) $ weakenExpr (WCopy (WSink .> WSink)) e2) (EVar ext (d2 (STArr (SS ndim) eltty)) IZ)) @@ -975,6 +974,7 @@ drev des accumMap = \case <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil , STArr n eltty <- typeOf e , Refl <- indexTupD1Id n + , Refl <- lemZeroInfoD2 eltty , let tIxN = tTup (sreplicate n tIx) -> Ret (binds `BPush` (STArr n (d1 eltty), e1) `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ)) @@ -983,10 +983,11 @@ drev des accumMap = \case (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) IZ)) sub - (ELet ext (EOneHot ext (STArr n eltty) (SAPArrIdx SAPHere n) - (EPair ext (EPair ext (EVar ext tIxN (IS IZ)) (EVar ext tIxN (IS (IS IZ)))) - (ENil ext)) - (EVar ext (d2 eltty) IZ)) $ + (ELet ext (EOneHot ext (d2M (STArr n eltty)) (SAPJust (SAPArrIdx SAPHere)) + (EPair ext (EPair ext (EVar ext tIxN (IS IZ)) + (EBuild ext n (EVar ext tIxN (IS (IS IZ))) (ENil ext))) + (ENil ext)) + (EVar ext (d2 eltty) IZ)) $ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) EShape _ e @@ -1026,6 +1027,10 @@ drev des accumMap = \case ENothing{} -> err_unsupported "ENothing" EJust{} -> err_unsupported "EJust" EMaybe{} -> err_unsupported "EMaybe" + ELNil{} -> err_unsupported "ELNil" + ELInl{} -> err_unsupported "ELInl" + ELInr{} -> err_unsupported "ELInr" + ELCase{} -> err_unsupported "ELCase" EWith{} -> err_accum EAccum{} -> err_accum @@ -1059,7 +1064,7 @@ drev des accumMap = \case (EIdx ext (EVar ext at (IS (IS (IS (IS IZ))))) (EVar ext tIxN IZ)) (EIdx ext (EVar ext at' (IS (IS (IS IZ)))) (EFst ext (EVar ext tIxN IZ))))) (EIdx ext (EVar ext (STArr n (d2 t)) (IS IZ)) (EFst ext (EVar ext tIxN IZ))) - (EZero ext t))) $ + (ezeroD2 t))) $ weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2) (EVar ext (d2 at') IZ)) @@ -1091,36 +1096,36 @@ drevScoped des accumMap argty argsto argids expr = case argsto of | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr -> case sub of SEYes sub' -> RetScoped e0 subtape e1 sub' e2 - SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (EZero ext argty)) + SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (ezeroD2 argty)) SAccum | Just (VIArr i _) <- argids , Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap - , Just Refl <- testEquality foundTy (STAccum argty) + , Just Refl <- testEquality foundTy (STAccum (d2M argty)) , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) expr -> RetScoped e0 subtape e1 sub $ let wtapebinds = wSinks (subList (bindingsBinds e0) subtape) in - ELet ext (EVar ext (STAccum argty) (WSink .> wtapebinds @> idx)) $ + ELet ext (EVar ext (STAccum (d2M argty)) (WSink .> wtapebinds @> idx)) $ weakenExpr (autoWeak (#d (auto1 @(D2 t)) &. #body (subList (bindingsBinds e0) subtape) - &. #ac (auto1 @(TAccum a)) + &. #ac (auto1 @(TAccum (D2 a))) &. #tl (d2ace (select SAccum des))) (#d :++: #body :++: #ac :++: #tl) (#ac :++: #d :++: #body :++: #tl)) -- Our contribution to the binding's cotangent _here_ is -- zero, because we're contributing to an earlier binding -- of the same value instead. - (EPair ext e2 (EZero ext argty)) + (EPair ext e2 (ezeroD2 argty)) | let accumMap' = case argids of - Just (VIArr i _) -> VarMap.insert i (STAccum argty) IZ (VarMap.sink1 accumMap) + Just (VIArr i _) -> VarMap.insert i (STAccum (d2M argty)) IZ (VarMap.sink1 accumMap) _ -> VarMap.sink1 accumMap , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' expr -> RetScoped e0 subtape e1 sub $ - EWith ext argty (EZero ext argty) $ + EWith ext (d2M argty) (ezeroD2 argty) $ weakenExpr (autoWeak (#d (auto1 @(D2 t)) &. #body (subList (bindingsBinds e0) subtape) - &. #ac (auto1 @(TAccum a)) + &. #ac (auto1 @(TAccum (D2 a))) &. #tl (d2ace (select SAccum des))) (#d :++: #body :++: #ac :++: #tl) (#ac :++: #d :++: #body :++: #tl)) |