summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs41
1 files changed, 32 insertions, 9 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 3a7b907..df792ce 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -62,14 +62,14 @@ tapeTy :: SList STy binds -> STy (Tape binds)
tapeTy SNil = STNil
tapeTy (SCons t ts) = STPair t (tapeTy ts)
-bindingsCollect :: Bindings f env binds -> Subenv binds tapebinds
+bindingsCollectTape :: Bindings f env binds -> Subenv binds tapebinds
-> Append binds env :> env2 -> Ex env2 (Tape tapebinds)
-bindingsCollect BTop SETop _ = ENil ext
-bindingsCollect (BPush binds (t, _)) (SEYes sub) w =
+bindingsCollectTape BTop SETop _ = ENil ext
+bindingsCollectTape (BPush binds (t, _)) (SEYes sub) w =
EPair ext (EVar ext t (w @> IZ))
- (bindingsCollect binds sub (w .> WSink))
-bindingsCollect (BPush binds _) (SENo sub) w =
- bindingsCollect binds sub (w .> WSink)
+ (bindingsCollectTape binds sub (w .> WSink))
+bindingsCollectTape (BPush binds _) (SENo sub) w =
+ bindingsCollectTape binds sub (w .> WSink)
-- In order from large to small: i.e. in reverse order from what we want,
-- because in a Bindings, the head of the list is the bottom-most entry.
@@ -718,8 +718,8 @@ drev des accumMap = \case
, Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
, let tapeA = tapeTy (subList (bindingsBinds a0) subtapeA)
, let tapeB = tapeTy (subList (bindingsBinds b0) subtapeB)
- , let collectA = bindingsCollect a0 subtapeA
- , let collectB = bindingsCollect b0 subtapeB
+ , let collectA = bindingsCollectTape a0 subtapeA
+ , let collectB = bindingsCollectTape b0 subtapeB
, (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB)
, let (a0', wa0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) a0
, let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0
@@ -822,6 +822,29 @@ drev des accumMap = \case
(ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $
weakenExpr (WCopy (WSink .> WSink)) b2)
+ -- TODO: compute primal in direct form here instead of taking the redundantly inefficient CHAD primal
+ ERecompute _ e ->
+ deleteUnused (descrList des) (occCountAll e) $ \usedSub ->
+ let smallE = unsafeWeakenWithSubenv usedSub e in
+ subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed ->
+ case drev usedDes (VarMap.subMap subAccumUsed accumMap) smallE of { Ret e0 subtape e1 sub e2 ->
+ Ret (collectBindings (desD1E des) subD1eUsed)
+ (subenvAll (desD1E usedDes))
+ (weakenExpr (wRaiseAbove (desD1E usedDes) (desD1E des)) $ letBinds e0 e1)
+ (subenvCompose subMergeUsed sub)
+ (letBinds (fst (weakenBindings weakenExpr (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $
+ weakenExpr
+ (autoWeak (#d (auto1 @(D2 t))
+ &. #shbinds (bindingsBinds e0)
+ &. #tape (subList (bindingsBinds e0) subtape)
+ &. #d1env (desD1E usedDes)
+ &. #tl' (d2ace (select SAccum usedDes))
+ &. #tl (d2ace (select SAccum des)))
+ (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: LPreW #tl' #tl (wUndoSubenv subAccumUsed))
+ (#shbinds :++: #d :++: #d1env :++: #tl))
+ e2)
+ }
+
EError _ t s ->
Ret BTop
SETop
@@ -849,7 +872,7 @@ drev des accumMap = \case
case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 ->
case assertSubenvEmpty sub of { Refl ->
let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in
- let collectexpr = bindingsCollect e0 subtapeE in
+ let collectexpr = bindingsCollectTape e0 subtapeE in
Ret (BTop `BPush` (shty, letBinds she0 she1)
`BPush` (STArr ndim (STPair (d1 eltty) tapety)
,EBuild ext ndim