diff options
| -rw-r--r-- | src/CHAD.hs | 27 | 
1 files changed, 14 insertions, 13 deletions
| diff --git a/src/CHAD.hs b/src/CHAD.hs index cb48816..cfae98d 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -3,6 +3,7 @@  {-# LANGUAGE GADTs #-}  {-# LANGUAGE LambdaCase #-}  {-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE ImpredicativeTypes #-}  {-# LANGUAGE OverloadedLabels #-}  {-# LANGUAGE PolyKinds #-}  {-# LANGUAGE QuantifiedConstraints #-} @@ -196,14 +197,14 @@ buildReconstructor (SCons t ts) = growRecon t ts (buildReconstructor ts)  -- incidentally also add a bunch of additional bindings, namely 'Reverse  -- (TapeUnfoldings binds)', so the calling code just has to skip those in  -- whatever it wants to do. -reconstructBindings :: SList STy binds -> Idx env (Tape binds) -                    -> (Bindings Ex env (Append binds (Reverse (TapeUnfoldings binds))) +reconstructBindings :: SList STy binds +                    -> (forall env. Idx env (Tape binds) -> Bindings Ex env (Append binds (Reverse (TapeUnfoldings binds)))                         ,SList STy (Reverse (TapeUnfoldings binds))) -reconstructBindings binds tape = -  let Reconstructor unf build = buildReconstructor binds -  in (fst $ weakenBindings weakenExpr (WIdx tape) -             (bconcat (mapBindings fromUnfExpr unf) build) -     ,sreverse (stapeUnfoldings binds)) +reconstructBindings binds = +  (\tape -> let Reconstructor unf build = buildReconstructor binds +            in fst $ weakenBindingsE (WIdx tape) +                      (bconcat (mapBindings fromUnfExpr unf) build) +  ,sreverse (stapeUnfoldings binds))  ---------------------------------- DERIVATIVES --------------------------------- @@ -913,8 +914,8 @@ drev des accumMap sd = \case          subOut          (elet             (ECase ext (ESnd ext (EVar ext tPrimal (IS IZ))) -              (let (rebinds, prerebinds) = reconstructBindings subtapeListA IZ -               in letBinds rebinds $ +              (let (rebinds, prerebinds) = reconstructBindings subtapeListA +               in letBinds (rebinds IZ) $                      ELet ext                        (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_a_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListA prerebinds) @> IS IZ)) $                      elet @@ -929,8 +930,8 @@ drev des accumMap sd = \case                                    a2) $                      EPair ext (sAB_A $ EFst ext (evar IZ))                                (ELInl ext (applySparse sd2 (d2 t2)) (ESnd ext (evar IZ)))) -              (let (rebinds, prerebinds) = reconstructBindings subtapeListB IZ -               in letBinds rebinds $ +              (let (rebinds, prerebinds) = reconstructBindings subtapeListB +               in letBinds (rebinds IZ) $                      ELet ext                        (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_b_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListB prerebinds) @> IS IZ)) $                      elet @@ -1095,8 +1096,8 @@ drev des accumMap sd = \case                   -- the tape for this element                   ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ))                                      (EVar ext shty (IS IZ))) $ -                 let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ -                 in letBinds rebinds $ +                 let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) +                 in letBinds (rebinds IZ) $                        weakenExpr (autoWeak (#d (auto1 @sdElt)                                              &. #pro (d2ace envPro)                                              &. #etape (subList (bindingsBinds e0) subtapeE) | 
