diff options
| -rw-r--r-- | src/AST/Bindings.hs | 9 | ||||
| -rw-r--r-- | src/AST/SplitLets.hs | 4 | ||||
| -rw-r--r-- | src/CHAD.hs | 105 |
3 files changed, 62 insertions, 56 deletions
diff --git a/src/AST/Bindings.hs b/src/AST/Bindings.hs index 2310f4b..463586a 100644 --- a/src/AST/Bindings.hs +++ b/src/AST/Bindings.hs @@ -28,6 +28,10 @@ data Bindings f env binds where deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env') infixl `BPush` +bpush :: Bindings (Expr x) env binds -> Expr x (Append binds env) t -> Bindings (Expr x) env (t : binds) +bpush b e = b `BPush` (typeOf e, e) +infixl `bpush` + mapBindings :: (forall env' t'. f env' t' -> g env' t') -> Bindings f env binds -> Bindings g env binds mapBindings _ BTop = BTop @@ -42,6 +46,11 @@ weakenBindings wf w (BPush b (t, x)) = let (b', w') = weakenBindings wf w b in (BPush b' (t, wf w' x), WCopy w') +weakenBindingsE :: env1 :> env2 + -> Bindings (Expr x) env1 binds + -> (Bindings (Expr x) env2 binds, Append binds env1 :> Append binds env2) +weakenBindingsE = weakenBindings weakenExpr + weakenOver :: SList STy ts -> env :> env' -> Append ts env :> Append ts env' weakenOver SNil w = w weakenOver (SCons _ ts) w = WCopy (weakenOver ts w) diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index dcaf82f..82ec1d6 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -89,10 +89,10 @@ splitLets' = \sub -> \case -> STy bind1 -> STy bind2 -> Ex (bind2 : bind1 : env) t -> Ex (bind2 : bind1 : env') t split2 sub tbind1 tbind2 body = let (ptrs1', bs1') = split @env' tbind1 - bs1 = fst (weakenBindings weakenExpr WSink bs1') + bs1 = fst (weakenBindingsE WSink bs1') (ptrs2, bs2) = split @(bind1 : env') tbind2 in letBinds bs1 $ - letBinds (fst (weakenBindings weakenExpr (sinkWithBindings @(bind2 : bind1 : env') bs1) bs2)) $ + letBinds (fst (weakenBindingsE (sinkWithBindings @(bind2 : bind1 : env') bs1) bs2)) $ splitLets' (\cases _ IZ w -> subPointers ptrs2 (w .> wCopies (bindingsBinds bs2) (wSinks @(bind2 : bind1 : env') (bindingsBinds bs1))) _ (IS IZ) w -> subPointers ptrs1' (w .> wSinks (bindingsBinds bs2) .> wCopies (bindingsBinds bs1) (WSink @bind2 @(bind1 : env'))) t (IS (IS i)) w -> sub t i (WPop @bind1 (WPop @bind2 (wPops (bindingsBinds bs1) (wPops (bindingsBinds bs2) w))))) diff --git a/src/CHAD.hs b/src/CHAD.hs index 08c0a2f..222617b 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -147,7 +147,7 @@ growRecon t ts (Reconstructor unfbs bs) -- Add a 'fst' at the bottom of the builder stack. -- First we have to weaken most of 'bs' to skip one more binding in the -- unfolder stack above it. - (BPush (fst (weakenBindings weakenExpr + (BPush (fst (weakenBindingsE (wCopies (sappend (sreverse (stapeUnfoldings ts)) (SCons (tapeTy ts) SNil)) (WSink :: env :> (Tape (t : ts) : env))) bs)) (t @@ -664,7 +664,7 @@ weakenRetPair bindslist w (RetPair e1 sub e2) = RetPair (weakenExpr (weakenOver weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list weakenRets w (Rets binds tapesub list) = - let (binds', _) = weakenBindings weakenExpr w binds + let (binds', _) = weakenBindingsE w binds in Rets binds' tapesub (slistMap (weakenRetPair (bindingsBinds binds) w) list) rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto pair f. @@ -705,7 +705,7 @@ freezeRet :: Descr env sto -> Ret env sto (D2 t) t -> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge")))) freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ _ t) = - let (e0', wInsertD2Ac) = weakenBindings weakenExpr (WSink .> wSinks (d2ace (select SAccum descr))) e0 + let (e0', wInsertD2Ac) = weakenBindingsE (WSink .> wSinks (d2ace (select SAccum descr))) e0 e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2 tContribs = tTup (slistMap fromSMTy (subList (d2eM (select SMerge descr)) sub)) library = #d (auto1 @(D2 t)) @@ -785,14 +785,14 @@ drev des accumMap sd = \case | ChosenStorage (storage :: Storage s) <- if chcLetArrayAccum ?config && typeHasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody sdBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) sd body , Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) rhs1 subRHS rhs2 <- drev des accumMap sdBody rhs - , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 + , let (body0', wbody0') = weakenBindingsE (WCopy (sinkWithBindings rhs0)) body0 , Refl <- lemAppendAssoc @body_shbinds @'[D1 a] @rhs_shbinds , Refl <- lemAppendAssoc @body_shbinds @(D1 a : rhs_shbinds) @(D1E env) , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) -> subenvPlus SF SF (d2eM (select SMerge des)) subRHS subBody $ \subBoth _ _ plus_RHS_Body -> let bodyResType = STPair (contribTupTy des subBody) (applySparse sdBody (d2 (typeOf rhs))) in - Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0') + Ret (bconcat (rhs0 `bpush` rhs1) body0') (subenvConcat subtapeRHS subtapeBody) (weakenExpr wbody0' body1) subBoth @@ -900,8 +900,8 @@ drev des accumMap sd = \case , let collectB = bindingsCollectTape @_ @_ @(Append rhs_b_binds (D1 b : Append e_binds (D1E env))) (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB , (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB) - , let (a0', wa0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) a0 - , let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0 + , let (a0', wa0') = weakenBindingsE (WCopy (sinkWithBindings e0)) a0 + , let (b0', wb0') = weakenBindingsE (WCopy (sinkWithBindings e0)) b0 , Refl <- lemAppendNil @(Append rhs_a_binds '[D1 a]) , Refl <- lemAppendNil @(Append rhs_b_binds '[D1 b]) , Refl <- lemAppendAssoc @rhs_a_binds @'[D1 a] @(D1E env) @@ -911,11 +911,9 @@ drev des accumMap sd = \case -> subenvPlus ST ST (d2eM (select SMerge des)) subA subB $ \subAB (Inj sAB_A) (Inj sAB_B) _ -> subenvPlus SF SF (d2eM (select SMerge des)) subAB subE $ \subOut _ _ plus_AB_E -> - Ret (e0 `BPush` - (tPrimal, - ECase ext e1 - (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0'')))) - (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0'')))))) + Ret (e0 `bpush` ECase ext e1 + (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0'')))) + (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0''))))) (SEYesR subtapeE) (EFst ext (EVar ext tPrimal IZ)) subOut @@ -976,7 +974,7 @@ drev des accumMap sd = \case (ELet ext (d2opfun (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ))) (weakenExpr (WCopy WSink) e2)) Nonlinear d2opfun -> - Ret (e0 `BPush` (d1 (typeOf e), e1)) + Ret (e0 `bpush` e1) (SEYesR subtape) (d1op op $ EVar ext (d1 (typeOf e)) IZ) sub @@ -984,15 +982,15 @@ drev des accumMap sd = \case (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ))) (weakenExpr (WCopy (wSinks' @[_,_])) e2)) - ECustom _ _ tb storety srce pr du a b + ECustom _ _ tb _ srce pr du a b -- allowed to ignore a2 because 'a' is the part of the input that is inactive | Ret b0 bsubtape b1 bsub b2 <- drev des accumMap (spDense (d2M tb)) b -> case isDense (d2M (typeOf srce)) sd of Just Refl -> - Ret (b0 `BPush` (d1 (typeOf a), weakenExpr (sinkWithBindings b0) (drevPrimal des a)) - `BPush` (typeOf b1, weakenExpr WSink b1) - `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)) - `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ))) + Ret (b0 `bpush` weakenExpr (sinkWithBindings b0) (drevPrimal des a) + `bpush` weakenExpr WSink b1 + `bpush` weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr) + `bpush` ESnd ext (EVar ext (typeOf pr) IZ)) (SEYesR (SENo (SENo (SENo bsubtape)))) (EFst ext (EVar ext (typeOf pr) (IS IZ))) bsub @@ -1000,9 +998,9 @@ drev des accumMap sd = \case weakenExpr (WCopy (WSink .> WSink)) b2) Nothing -> - Ret (b0 `BPush` (d1 (typeOf a), weakenExpr (sinkWithBindings b0) (drevPrimal des a)) - `BPush` (typeOf b1, weakenExpr WSink b1) - `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr))) + Ret (b0 `bpush` weakenExpr (sinkWithBindings b0) (drevPrimal des a) + `bpush` weakenExpr WSink b1 + `bpush` weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)) (SEYesR (SENo (SENo bsubtape))) (EFst ext (EVar ext (typeOf pr) IZ)) bsub @@ -1023,7 +1021,7 @@ drev des accumMap sd = \case (subenvAll (desD1E usedDes)) (weakenExpr (wSinks (desD1E usedDes)) $ drevPrimal des e) (subenvCompose subMergeUsed' sub) - (letBinds (fst (weakenBindings weakenExpr (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $ + (letBinds (fst (weakenBindingsE (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $ weakenExpr (autoWeak (#d (auto1 @sd) &. #shbinds (bindingsBinds e0) @@ -1068,29 +1066,28 @@ drev des accumMap sd = \case let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in Ret (mergePrimalBindings - `BPush` (shty, weakenExpr (wSinks (d1e envPro)) (drevPrimal des she)) - `BPush` (STArr ndim (STPair (d1 eltty) tapety) - ,EBuild ext ndim - (EVar ext shty IZ) - (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil) - &. #sh (shty `SCons` SNil) - &. #propr (d1e envPro) - &. #d1env (desD1E des) - &. #d1env' (desD1E usedDes)) - (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#ix :++: #sh :++: #propr :++: #d1env)) - e0)) $ - let w = autoWeak (#ix (shty `SCons` SNil) - &. #sh (shty `SCons` SNil) - &. #e0 (bindingsBinds e0) - &. #propr (d1e envPro) - &. #d1env (desD1E des) - &. #d1env' (desD1E usedDes)) - (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#e0 :++: #ix :++: #sh :++: #propr :++: #d1env) - w' = w .> wCopies (bindingsBinds e0) (WClosed @(shty : D1E env')) - in EPair ext (weakenExpr w e1) (collectexpr w'))) - `BPush` (STArr ndim tapety, emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ))) + `bpush` weakenExpr (wSinks (d1e envPro)) (drevPrimal des she) + `bpush` EBuild ext ndim + (EVar ext shty IZ) + (letBinds (fst (weakenBindingsE (autoWeak (#ix (shty `SCons` SNil) + &. #sh (shty `SCons` SNil) + &. #propr (d1e envPro) + &. #d1env (desD1E des) + &. #d1env' (desD1E usedDes)) + (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + (#ix :++: #sh :++: #propr :++: #d1env)) + e0)) $ + let w = autoWeak (#ix (shty `SCons` SNil) + &. #sh (shty `SCons` SNil) + &. #e0 (bindingsBinds e0) + &. #propr (d1e envPro) + &. #d1env (desD1E des) + &. #d1env' (desD1E usedDes)) + (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + (#e0 :++: #ix :++: #sh :++: #propr :++: #d1env) + w' = w .> wCopies (bindingsBinds e0) (WClosed @(shty : D1E env')) + in EPair ext (weakenExpr w e1) (collectexpr w')) + `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ)) (SEYesR (SENo (SEYesR (subenvAll (d1e envPro))))) (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ))) (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) @@ -1172,8 +1169,8 @@ drev des accumMap sd = \case | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil , STArr (SS n) eltty <- typeOf e -> - Ret (binds `BPush` (STArr (SS n) (d1 eltty), e1) - `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ))) + Ret (binds `bpush` e1 + `bpush` EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ)) (SEYesR (SENo subtape)) (EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ)) (weakenExpr (WSink .> WSink) ei1)) @@ -1191,9 +1188,9 @@ drev des accumMap sd = \case , let tIxN = tTup (sreplicate n tIx) -> sparsePlusS ST ST (d2M eltty) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ -> case drev des accumMap (SpArr sd') e of { Ret binds subtape e1 sub e2 -> - Ret (binds `BPush` (STArr n (d1 eltty), e1) - `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ)) - `BPush` (tIxN, weakenExpr (WSink .> WSink .> wSinks (bindingsBinds binds)) (drevPrimal des ei))) + Ret (binds `bpush` e1 + `bpush` EShape ext (EVar ext (typeOf e1) IZ) + `bpush` weakenExpr (WSink .> WSink .> wSinks (bindingsBinds binds)) (drevPrimal des ei)) (SEYesR (SEYesR (SENo subtape))) (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) IZ)) @@ -1225,8 +1222,8 @@ drev des accumMap sd = \case | SpArr sd' <- sd , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e , STArr (SS n) t <- typeOf e -> - Ret (e0 `BPush` (STArr (SS n) t, e1) - `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ))) + Ret (e0 `bpush` e1 + `bpush` EShape ext (EVar ext (STArr (SS n) t) IZ)) (SEYesR (SENo subtape)) (ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ))) sub @@ -1274,8 +1271,8 @@ deriv_extremum extremum des accumMap sd e , let tIxN = tTup (sreplicate (SS n) tIx) = sparsePlusS ST ST (d2M t) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ -> case drev des accumMap (SpArr sd') e of { Ret e0 subtape e1 sub e2 -> - Ret (e0 `BPush` (at, e1) - `BPush` (at', extremum (EVar ext at IZ))) + Ret (e0 `bpush` e1 + `bpush` extremum (EVar ext at IZ)) (SEYesR (SEYesR subtape)) (EVar ext at' IZ) sub |
