aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/CHAD.hs27
1 files changed, 14 insertions, 13 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index cb48816..cfae98d 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -3,6 +3,7 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ImplicitParams #-}
+{-# LANGUAGE ImpredicativeTypes #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
@@ -196,14 +197,14 @@ buildReconstructor (SCons t ts) = growRecon t ts (buildReconstructor ts)
-- incidentally also add a bunch of additional bindings, namely 'Reverse
-- (TapeUnfoldings binds)', so the calling code just has to skip those in
-- whatever it wants to do.
-reconstructBindings :: SList STy binds -> Idx env (Tape binds)
- -> (Bindings Ex env (Append binds (Reverse (TapeUnfoldings binds)))
+reconstructBindings :: SList STy binds
+ -> (forall env. Idx env (Tape binds) -> Bindings Ex env (Append binds (Reverse (TapeUnfoldings binds)))
,SList STy (Reverse (TapeUnfoldings binds)))
-reconstructBindings binds tape =
- let Reconstructor unf build = buildReconstructor binds
- in (fst $ weakenBindings weakenExpr (WIdx tape)
- (bconcat (mapBindings fromUnfExpr unf) build)
- ,sreverse (stapeUnfoldings binds))
+reconstructBindings binds =
+ (\tape -> let Reconstructor unf build = buildReconstructor binds
+ in fst $ weakenBindingsE (WIdx tape)
+ (bconcat (mapBindings fromUnfExpr unf) build)
+ ,sreverse (stapeUnfoldings binds))
---------------------------------- DERIVATIVES ---------------------------------
@@ -913,8 +914,8 @@ drev des accumMap sd = \case
subOut
(elet
(ECase ext (ESnd ext (EVar ext tPrimal (IS IZ)))
- (let (rebinds, prerebinds) = reconstructBindings subtapeListA IZ
- in letBinds rebinds $
+ (let (rebinds, prerebinds) = reconstructBindings subtapeListA
+ in letBinds (rebinds IZ) $
ELet ext
(EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_a_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListA prerebinds) @> IS IZ)) $
elet
@@ -929,8 +930,8 @@ drev des accumMap sd = \case
a2) $
EPair ext (sAB_A $ EFst ext (evar IZ))
(ELInl ext (applySparse sd2 (d2 t2)) (ESnd ext (evar IZ))))
- (let (rebinds, prerebinds) = reconstructBindings subtapeListB IZ
- in letBinds rebinds $
+ (let (rebinds, prerebinds) = reconstructBindings subtapeListB
+ in letBinds (rebinds IZ) $
ELet ext
(EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_b_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListB prerebinds) @> IS IZ)) $
elet
@@ -1095,8 +1096,8 @@ drev des accumMap sd = \case
-- the tape for this element
ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ))
(EVar ext shty (IS IZ))) $
- let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ
- in letBinds rebinds $
+ let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE)
+ in letBinds (rebinds IZ) $
weakenExpr (autoWeak (#d (auto1 @sdElt)
&. #pro (d2ace envPro)
&. #etape (subList (bindingsBinds e0) subtapeE)