diff options
| -rw-r--r-- | src/CHAD.hs | 27 |
1 files changed, 14 insertions, 13 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index cb48816..cfae98d 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -3,6 +3,7 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE ImpredicativeTypes #-} {-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE QuantifiedConstraints #-} @@ -196,14 +197,14 @@ buildReconstructor (SCons t ts) = growRecon t ts (buildReconstructor ts) -- incidentally also add a bunch of additional bindings, namely 'Reverse -- (TapeUnfoldings binds)', so the calling code just has to skip those in -- whatever it wants to do. -reconstructBindings :: SList STy binds -> Idx env (Tape binds) - -> (Bindings Ex env (Append binds (Reverse (TapeUnfoldings binds))) +reconstructBindings :: SList STy binds + -> (forall env. Idx env (Tape binds) -> Bindings Ex env (Append binds (Reverse (TapeUnfoldings binds))) ,SList STy (Reverse (TapeUnfoldings binds))) -reconstructBindings binds tape = - let Reconstructor unf build = buildReconstructor binds - in (fst $ weakenBindings weakenExpr (WIdx tape) - (bconcat (mapBindings fromUnfExpr unf) build) - ,sreverse (stapeUnfoldings binds)) +reconstructBindings binds = + (\tape -> let Reconstructor unf build = buildReconstructor binds + in fst $ weakenBindingsE (WIdx tape) + (bconcat (mapBindings fromUnfExpr unf) build) + ,sreverse (stapeUnfoldings binds)) ---------------------------------- DERIVATIVES --------------------------------- @@ -913,8 +914,8 @@ drev des accumMap sd = \case subOut (elet (ECase ext (ESnd ext (EVar ext tPrimal (IS IZ))) - (let (rebinds, prerebinds) = reconstructBindings subtapeListA IZ - in letBinds rebinds $ + (let (rebinds, prerebinds) = reconstructBindings subtapeListA + in letBinds (rebinds IZ) $ ELet ext (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_a_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListA prerebinds) @> IS IZ)) $ elet @@ -929,8 +930,8 @@ drev des accumMap sd = \case a2) $ EPair ext (sAB_A $ EFst ext (evar IZ)) (ELInl ext (applySparse sd2 (d2 t2)) (ESnd ext (evar IZ)))) - (let (rebinds, prerebinds) = reconstructBindings subtapeListB IZ - in letBinds rebinds $ + (let (rebinds, prerebinds) = reconstructBindings subtapeListB + in letBinds (rebinds IZ) $ ELet ext (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_b_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListB prerebinds) @> IS IZ)) $ elet @@ -1095,8 +1096,8 @@ drev des accumMap sd = \case -- the tape for this element ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) (EVar ext shty (IS IZ))) $ - let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ - in letBinds rebinds $ + let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) + in letBinds (rebinds IZ) $ weakenExpr (autoWeak (#d (auto1 @sdElt) &. #pro (d2ace envPro) &. #etape (subList (bindingsBinds e0) subtapeE) |
