diff options
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r-- | src/CHAD.hs | 41 |
1 files changed, 32 insertions, 9 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index 3a7b907..df792ce 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -62,14 +62,14 @@ tapeTy :: SList STy binds -> STy (Tape binds) tapeTy SNil = STNil tapeTy (SCons t ts) = STPair t (tapeTy ts) -bindingsCollect :: Bindings f env binds -> Subenv binds tapebinds +bindingsCollectTape :: Bindings f env binds -> Subenv binds tapebinds -> Append binds env :> env2 -> Ex env2 (Tape tapebinds) -bindingsCollect BTop SETop _ = ENil ext -bindingsCollect (BPush binds (t, _)) (SEYes sub) w = +bindingsCollectTape BTop SETop _ = ENil ext +bindingsCollectTape (BPush binds (t, _)) (SEYes sub) w = EPair ext (EVar ext t (w @> IZ)) - (bindingsCollect binds sub (w .> WSink)) -bindingsCollect (BPush binds _) (SENo sub) w = - bindingsCollect binds sub (w .> WSink) + (bindingsCollectTape binds sub (w .> WSink)) +bindingsCollectTape (BPush binds _) (SENo sub) w = + bindingsCollectTape binds sub (w .> WSink) -- In order from large to small: i.e. in reverse order from what we want, -- because in a Bindings, the head of the list is the bottom-most entry. @@ -718,8 +718,8 @@ drev des accumMap = \case , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) , let tapeA = tapeTy (subList (bindingsBinds a0) subtapeA) , let tapeB = tapeTy (subList (bindingsBinds b0) subtapeB) - , let collectA = bindingsCollect a0 subtapeA - , let collectB = bindingsCollect b0 subtapeB + , let collectA = bindingsCollectTape a0 subtapeA + , let collectB = bindingsCollectTape b0 subtapeB , (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB) , let (a0', wa0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) a0 , let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0 @@ -822,6 +822,29 @@ drev des accumMap = \case (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ weakenExpr (WCopy (WSink .> WSink)) b2) + -- TODO: compute primal in direct form here instead of taking the redundantly inefficient CHAD primal + ERecompute _ e -> + deleteUnused (descrList des) (occCountAll e) $ \usedSub -> + let smallE = unsafeWeakenWithSubenv usedSub e in + subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> + case drev usedDes (VarMap.subMap subAccumUsed accumMap) smallE of { Ret e0 subtape e1 sub e2 -> + Ret (collectBindings (desD1E des) subD1eUsed) + (subenvAll (desD1E usedDes)) + (weakenExpr (wRaiseAbove (desD1E usedDes) (desD1E des)) $ letBinds e0 e1) + (subenvCompose subMergeUsed sub) + (letBinds (fst (weakenBindings weakenExpr (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $ + weakenExpr + (autoWeak (#d (auto1 @(D2 t)) + &. #shbinds (bindingsBinds e0) + &. #tape (subList (bindingsBinds e0) subtape) + &. #d1env (desD1E usedDes) + &. #tl' (d2ace (select SAccum usedDes)) + &. #tl (d2ace (select SAccum des))) + (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: LPreW #tl' #tl (wUndoSubenv subAccumUsed)) + (#shbinds :++: #d :++: #d1env :++: #tl)) + e2) + } + EError _ t s -> Ret BTop SETop @@ -849,7 +872,7 @@ drev des accumMap = \case case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> case assertSubenvEmpty sub of { Refl -> let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in - let collectexpr = bindingsCollect e0 subtapeE in + let collectexpr = bindingsCollectTape e0 subtapeE in Ret (BTop `BPush` (shty, letBinds she0 she1) `BPush` (STArr ndim (STPair (d1 eltty) tapety) ,EBuild ext ndim |