diff options
| -rw-r--r-- | src/AST.hs | 21 | ||||
| -rw-r--r-- | src/CHAD.hs | 179 | ||||
| -rw-r--r-- | src/ForwardAD/DualNumbers.hs | 21 | 
3 files changed, 114 insertions, 107 deletions
| @@ -355,3 +355,24 @@ eidxEq (SS n) a b                                          (ESnd ext (EVar ext ty IZ))))          (eidxEq n (EFst ext (EVar ext ty (IS IZ)))                    (EFst ext (EVar ext ty IZ))) + +emap :: Ex (a : env) b -> Ex env (TArr n a) -> Ex env (TArr n b) +emap f arr = +  let STArr n t = typeOf arr +  in ELet ext arr $ +       EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $ +         ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ)) +                            (EVar ext (tTup (sreplicate n tIx)) IZ)) $ +           weakenExpr (WCopy (WSink .> WSink)) f + +ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) +ezip a b = +  let STArr n t1 = typeOf a +      STArr _ t2 = typeOf b +  in ELet ext a $ +     ELet ext (weakenExpr WSink b) $ +       EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $ +         EPair ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ))) +                             (EVar ext (tTup (sreplicate n tIx)) IZ)) +                   (EIdx ext (EVar ext (STArr n t2) (IS IZ)) +                             (EVar ext (tTup (sreplicate n tIx)) IZ)) diff --git a/src/CHAD.hs b/src/CHAD.hs index 1fd34d8..d45898a 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -20,12 +20,9 @@  -- useful here.  {-# LANGUAGE PartialTypeSignatures #-}  {-# OPTIONS -Wno-partial-type-signatures #-} - --- TODO DO NOT COMMIT THIS -{-# OPTIONS -Wno-unused-top-binds #-}  module CHAD ( -  -- drev, -  -- freezeRet, +  drev, +  freezeRet,    Storage(..),    Descr(..),    Select, @@ -57,11 +54,14 @@ tapeTy :: SList STy binds -> STy (Tape binds)  tapeTy SNil = STNil  tapeTy (SCons t ts) = STPair t (tapeTy ts) -bindingsCollect :: Bindings f env binds -> Append binds env :> env2 -> Ex env2 (Tape binds) -bindingsCollect BTop _ = ENil ext -bindingsCollect (BPush binds (t, _)) w = +bindingsCollect :: Bindings f env binds -> Subenv binds tapebinds +                -> Append binds env :> env2 -> Ex env2 (Tape tapebinds) +bindingsCollect BTop SETop _ = ENil ext +bindingsCollect (BPush binds (t, _)) (SEYes sub) w =    EPair ext (EVar ext t (w @> IZ)) -            (bindingsCollect binds (w .> WSink)) +            (bindingsCollect binds sub (w .> WSink)) +bindingsCollect (BPush binds _) (SENo sub) w = +  bindingsCollect binds sub (w .> WSink)  -- In order from large to small: i.e. in reverse order from what we want,  -- because in a Bindings, the head of the list is the bottom-most entry. @@ -195,6 +195,7 @@ reconstructBindings binds tape =  --------------------------------- VECTORISATION --------------------------------  -- Currently only used in D[build1], should be removed. +{-  type family Vectorise n list where    Vectorise _ '[] = '[]    Vectorise n (t : ts) = TArr n t : Vectorise n ts @@ -239,6 +240,7 @@ vectorise1Binds env n (bs `BPush` (t, e)) =        e' = EBuild1 ext (EVar ext tIx (sinkWithBindings bs' @> n))                         (vectoriseExpr SNil (bindingsBinds bs) env e)    in bs' `BPush` (STArr (SS SZ) t, e') +-}  ---------------------- ENVIRONMENT DESCRIPTION AND STORAGE --------------------- @@ -824,18 +826,17 @@ drev des = \case                (EError (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl")                (weakenExpr (WCopy (wSinks' @[_,_])) e2))) -{-    ECase _ e (a :: Ex _ t) b      | STEither t1 t2 <- typeOf e -    , Ret (e0 :: Bindings _ _ e_binds) e1 subE e2 <- drev des e -    , Ret (a0 :: Bindings _ _ rhs_a_binds) a1 subA a2 <- drev (des `DPush` (t1, SMerge)) a -    , Ret (b0 :: Bindings _ _ rhs_b_binds) b1 subB b2 <- drev (des `DPush` (t2, SMerge)) b +    , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des e +    , Ret (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drev (des `DPush` (t1, SMerge)) a +    , Ret (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drev (des `DPush` (t2, SMerge)) b      , Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))      , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) -    , let tapeA = tapeTy (bindingsBinds a0) -    , let tapeB = tapeTy (bindingsBinds b0) -    , let collectA = bindingsCollect a0 -    , let collectB = bindingsCollect b0 +    , let tapeA = tapeTy (subList (bindingsBinds a0) subtapeA) +    , let tapeB = tapeTy (subList (bindingsBinds b0) subtapeB) +    , let collectA = bindingsCollect a0 subtapeA +    , let collectB = bindingsCollect b0 subtapeB      , (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB)      , let (a0', wa0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) a0      , let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0 @@ -850,42 +851,43 @@ drev des = \case              ECase ext e1                (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0'))))                (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0')))))) +        (SEYes subtapeE)          (EFst ext (EVar ext tPrimal IZ))          subOut          (ELet ext             (ECase ext (ESnd ext (EVar ext tPrimal (IS IZ))) -              (let (rebinds, prerebinds) = reconstructBindings (bindingsBinds a0) IZ +              (let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds a0) subtapeA) IZ                 in letBinds rebinds $                      ELet ext -                      (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_binds : D2 t : t_primal_ty : Append e_binds (D2AcE (Select env sto "accum"))) (sappend (bindingsBinds a0) prerebinds) @> IS IZ)) $ +                      (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_tape : D2 t : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend (subList (bindingsBinds a0) subtapeA) prerebinds) @> IS IZ)) $                      ELet ext                        (weakenExpr (autoWeak (#d (auto1 @(D2 t)) -                                             &. #a0 (bindingsBinds a0) +                                             &. #ta0 (subList (bindingsBinds a0) subtapeA)                                               &. #prea0 prerebinds                                               &. #recon (tapeA `SCons` d2 (typeOf a) `SCons` SNil) -                                             &. #binds (tPrimal `SCons` bindingsBinds e0) +                                             &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE)                                               &. #tl (d2ace (select SAccum des))) -                                            (#d :++: #a0 :++: #tl) -                                            (#d :++: (#a0 :++: #prea0) :++: #recon :++: #binds :++: #tl)) +                                            (#d :++: #ta0 :++: #tl) +                                            (#d :++: (#ta0 :++: #prea0) :++: #recon :++: #binds :++: #tl))                                    a2') $                      EPair ext                       (expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $                          EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ))                       (EInl ext (d2 t2)                         (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ)))) -              (let (rebinds, prerebinds) = reconstructBindings (bindingsBinds b0) IZ +              (let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds b0) subtapeB) IZ                 in letBinds rebinds $                      ELet ext -                      (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_binds : D2 t : t_primal_ty : Append e_binds (D2AcE (Select env sto "accum"))) (sappend (bindingsBinds b0) prerebinds) @> IS IZ)) $ +                      (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_tape : D2 t : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend (subList (bindingsBinds b0) subtapeB) prerebinds) @> IS IZ)) $                      ELet ext                        (weakenExpr (autoWeak (#d (auto1 @(D2 t)) -                                             &. #b0 (bindingsBinds b0) +                                             &. #tb0 (subList (bindingsBinds b0) subtapeB)                                               &. #preb0 prerebinds                                               &. #recon (tapeB `SCons` d2 (typeOf a) `SCons` SNil) -                                             &. #binds (tPrimal `SCons` bindingsBinds e0) +                                             &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE)                                               &. #tl (d2ace (select SAccum des))) -                                            (#d :++: #b0 :++: #tl) -                                            (#d :++: (#b0 :++: #preb0) :++: #recon :++: #binds :++: #tl)) +                                            (#d :++: #tb0 :++: #tl) +                                            (#d :++: (#tb0 :++: #preb0) :++: #recon :++: #binds :++: #tl))                                    b2') $                      EPair ext                        (expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $ @@ -901,21 +903,24 @@ drev des = \case    EConst _ t val ->      Ret BTop +        SETop          (EConst ext t val)          (subenvNone (select SMerge des))          (ENil ext)    EOp _ op e -    | Ret e0 e1 sub e2 <- drev des e -> +    | Ret e0 subtape e1 sub e2 <- drev des e ->      case d2op op of        Linear d2opfun ->          Ret e0 +            subtape              (d1op op e1)              sub              (ELet ext (d2opfun (EVar ext (d2 (opt2 op)) IZ))                 (weakenExpr (WCopy WSink) e2))        Nonlinear d2opfun ->          Ret (e0 `BPush` (d1 (typeOf e), e1)) +            (SEYes subtape)              (d1op op $ EVar ext (d1 (typeOf e)) IZ)              sub              (ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ)) @@ -924,16 +929,20 @@ drev des = \case    EError t s ->      Ret BTop +        SETop          (EError (d1 t) s)          (subenvNone (select SMerge des))          (ENil ext)    EConstArr _ n t val ->      Ret BTop +        SETop          (EConstArr ext n t val)          (subenvNone (select SMerge des))          (ENil ext) +  EBuild1{} -> error "CHAD of EBuild1: Please use EBuild instead" +  {-    -- TODO: either remove EBuilds1 entirely or rewrite it to work with an array of tapes instead of a vectorised tape    EBuild1 _ ne (orige :: Ex _ eltty)      | Ret (ne0 :: Bindings _ _ ne_binds) ne1 _ _ <- drev des ne  -- allowed to ignore ne2 here because ne has a discrete result @@ -1010,9 +1019,11 @@ drev des = \case           ELet ext (ENil ext) $           ESnd ext (EVar ext (STPair (STArr (SS SZ) STNil) (tTup (d2e envPro))) (IS IZ)))      }} +  -} +  -- TODO: merge the e0 and e1 builds in a single build just like they are merged into a single case in D[case]0, then it can really store only the parts that need to be preserved until D[build]2    EBuild _ (ndim :: SNat ndim) she (orige :: Ex _ eltty) -    | Ret (she0 :: Bindings _ _ she_binds) she1 _ _ <- drev des she  -- allowed to ignore she2 here because she has a discrete result +    | Ret (she0 :: Bindings _ _ she_binds) _ she1 _ _ <- drev des she  -- allowed to ignore she2 here because she has a discrete result      , let eltty = typeOf orige      , shty :: STy shty <- tTup (sreplicate ndim tIx)      , Refl <- indexTupD1Id ndim -> @@ -1020,50 +1031,36 @@ drev des = \case      let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in      subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed ->      accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub wPro -> -    case drev (prodes `DPush` (shty, SMerge)) e of { Ret (e0 :: Bindings _ _ e_binds) e1 sub e2 -> +    case drev (prodes `DPush` (shty, SMerge)) e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 ->      case assertSubenvEmpty sub of { Refl -> -    let tapety = tapeTy (bindingsBinds e0) in -    let collectexpr = bindingsCollect e0 in -    Ret (she0 `BPush` (shty, she1) -              `BPush` (STArr ndim tapety +    let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in +    let collectexpr = bindingsCollect e0 subtapeE in +    Ret (BTop `BPush` (shty, letBinds she0 she1) +              `BPush` (STArr ndim (STPair (d1 eltty) tapety)                        ,EBuild ext ndim                           (EVar ext shty IZ)                           (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil)                                                                                &. #sh (shty `SCons` SNil) -                                                                              &. #she0 (bindingsBinds she0)                                                                                &. #d1env (sD1eEnv des)                                                                                &. #d1env' (sD1eEnv usedDes))                                                                               (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) -                                                                             (#ix :++: #sh :++: #she0 :++: #d1env)) +                                                                             (#ix :++: #sh :++: #d1env))                                                                     e0)) $ -                            collectexpr (autoWeak (#ix (shty `SCons` SNil) +                            let w = autoWeak (#ix (shty `SCons` SNil)                                                     &. #sh (shty `SCons` SNil) -                                                   &. #she0 (bindingsBinds she0)                                                     &. #e0 (bindingsBinds e0)                                                     &. #d1env (sD1eEnv des)                                                     &. #d1env' (sD1eEnv usedDes))                                                    (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) -                                                  (#e0 :++: #ix :++: #sh :++: #she0 :++: #d1env))))) -        (EBuild ext ndim -           (EVar ext shty (IS IZ)) -           (ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (IS IZ)) -                               (EVar ext shty IZ)) $ -            let (rebinds, prerebinds) = reconstructBindings (bindingsBinds e0) IZ -            in letBinds rebinds $ -                 weakenExpr (autoWeak (#ix (shty `SCons` SNil) -                                       &. #sh (shty `SCons` SNil) -                                       &. #she0 (bindingsBinds she0) -                                       &. #e0 (bindingsBinds e0) -                                       &. #tape (tapety `SCons` SNil) -                                       &. #tapearr (STArr ndim tapety `SCons` SNil) -                                       &. #prerebinds prerebinds -                                       &. #d1env (sD1eEnv des) -                                       &. #d1env' (sD1eEnv usedDes)) -                                      (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) -                                      ((#e0 :++: #prerebinds) :++: #tape :++: #ix :++: #tapearr :++: #sh :++: #she0 :++: #d1env)) -                            e1)) +                                                  (#e0 :++: #ix :++: #sh :++: #d1env) +                            in EPair ext (weakenExpr w e1) (collectexpr w))) +              `BPush` (STArr ndim tapety, emap (ESnd ext (EVar ext (STPair (d1 eltty) tapety) IZ)) +                                               (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ))) +        (SEYes (SENo (SEYes SETop))) +        (emap (EFst ext (EVar ext (STPair (d1 eltty) tapety) IZ)) +              (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ)))          (subenvCompose subMergeUsed proSub) -        (let sinkOverEnvPro = wSinks @(D2 t : TArr ndim (Tape e_binds) : Tup (Replicate ndim TIx) : Append she_binds (D2AcE (Select env sto "accum"))) (d2ace envPro) in +        (let sinkOverEnvPro = wSinks @(D2 t : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in           ESnd ext $            uninvertTup (d2e envPro) (STArr ndim STNil) $              makeAccumulators @_ @_ @(TArr ndim TNil) envPro $ @@ -1074,29 +1071,29 @@ drev des = \case                  -- the tape for this element                  ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ))                                     (EVar ext shty (IS IZ))) $ -                let (rebinds, prerebinds) = reconstructBindings (bindingsBinds e0) IZ +                let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ                  in letBinds rebinds $                       weakenExpr (autoWeak (#d (auto1 @(D2 eltty))                                             &. #pro (d2ace envPro) -                                           &. #ebinds (bindingsBinds e0) +                                           &. #etape (subList (bindingsBinds e0) subtapeE)                                             &. #prerebinds prerebinds                                             &. #tape (tapety `SCons` SNil)                                             &. #ix (shty `SCons` SNil)                                             &. #darr (STArr ndim (d2 eltty) `SCons` SNil)                                             &. #tapearr (STArr ndim tapety `SCons` SNil)                                             &. #sh (shty `SCons` SNil) -                                           &. #shebinds (bindingsBinds she0)                                             &. #d2acUsed (d2ace (select SAccum usedDes))                                             &. #d2acEnv (d2ace (select SAccum des))) -                                          (#pro :++: #d :++: #ebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) -                                          ((#ebinds :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #shebinds :++: #d2acEnv) -                                 .> wPro (bindingsBinds e0)) +                                          (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) +                                          ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #d2acEnv) +                                 .> wPro (subList (bindingsBinds e0) subtapeE))                                  e2)      }}    EUnit _ e -    | Ret e0 e1 sub e2 <- drev des e -> +    | Ret e0 subtape e1 sub e2 <- drev des e ->      Ret e0 +        subtape          (EUnit ext e1)          sub          (ELet ext (EIdx0 ext (EVar ext (STArr SZ (d2 (typeOf e))) IZ)) $ @@ -1104,43 +1101,50 @@ drev des = \case    EReplicate1Inner _ en e      -- We're allowed to ignore en2 here because the output of 'ei' is discrete. -    | Rets binds (RetPair en1 _ _ `SCons` RetPair e1 sub e2 `SCons` SNil) +    | Rets binds subtape (RetPair en1 _ _ `SCons` RetPair e1 sub e2 `SCons` SNil)          <- retConcat des $ drev des en `SCons` drev des e `SCons` SNil      , let STArr ndim eltty = typeOf e -> -    Ret (binds `BPush` (d1 (typeOf e), e1)) -        (weakenExpr WSink $ EReplicate1Inner ext en1 e1) +    Ret binds +        subtape +        (EReplicate1Inner ext en1 e1)          sub          (ELet ext (EFold1Inner ext (EPlus eltty (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ))                                     (EZero eltty)                                     (EVar ext (STArr (SS ndim) (d2 eltty)) IZ)) $ -          weakenExpr (WCopy (WSink .> WSink)) e2) +          weakenExpr (WCopy WSink) e2)    EIdx0 _ e -    | Ret e0 e1 sub e2 <- drev des e +    | Ret e0 subtape e1 sub e2 <- drev des e      , STArr _ t <- typeOf e ->      Ret e0 +        subtape          (EIdx0 ext e1)          sub          (ELet ext (EUnit ext (EVar ext (d2 t) IZ)) $           weakenExpr (WCopy WSink) e2) +  EIdx1{} -> error "CHAD of EIdx1: Please use EIdx instead" +  {-    EIdx1 _ e ei      -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. -    | Rets binds (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) +    | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil)          <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil      , STArr (SS n) eltty <- typeOf e -> -    Ret (binds `BPush` (STArr (SS n) (d1 eltty), e1)) -        (EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) IZ) -                   (weakenExpr WSink ei1)) +    Ret (binds `BPush` (STArr (SS n) (d1 eltty), e1) +               `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ))) +        (SEYes (SENo subtape)) +        (EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ)) +                   (weakenExpr (WSink .> WSink) ei1))          sub -        (ELet ext (ebuildUp1 n (EFst ext (EShape ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ)))) -                               (ESnd ext (EShape ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ)))) +        (ELet ext (ebuildUp1 n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) +                               (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ)))                                 (EVar ext (STArr n (d2 eltty)) (IS IZ))) $           weakenExpr (WCopy (WSink .> WSink)) e2) +  -}    EIdx _ e ei      -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. -    | Rets binds (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) +    | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil)          <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil      , STArr n eltty <- typeOf e      , Refl <- indexTupD1Id n @@ -1148,6 +1152,7 @@ drev des = \case      Ret (binds `BPush` (STArr n (d1 eltty), e1)                 `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ))                 `BPush` (tIxN, weakenExpr (WSink .> WSink) ei1)) +        (SEYes (SEYes (SENo subtape)))          (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ)))                    (EVar ext (tTup (sreplicate n tIx)) IZ))          sub @@ -1155,27 +1160,30 @@ drev des = \case                       ECase ext (EOp ext OIf (eidxEq n (EVar ext tIxN IZ) (EVar ext tIxN (IS (IS IZ)))))                         (EVar ext (d2 eltty) (IS (IS IZ)))                         (EZero eltty)) $ -         weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2) +         weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)    EShape _ e      -- Allowed to ignore e2 here because the output of EShape is discrete,      -- hence we'd be passing a zero cotangent to e2 anyway. -    | Ret e0 e1 _ _ <- drev des e +    | Ret e0 subtape e1 _ _ <- drev des e      , STArr n _ <- typeOf e      , Refl <- indexTupD1Id n ->      Ret e0 +        subtape          (EShape ext e1)          (subenvNone (select SMerge des))          (ENil ext)    ESum1Inner _ e -    | Ret e0 e1 sub e2 <- drev des e +    | Ret e0 subtape e1 sub e2 <- drev des e      , STArr (SS n) t <- typeOf e -> -    Ret (e0 `BPush` (STArr (SS n) t, e1)) -        (ESum1Inner ext (EVar ext (STArr (SS n) t) IZ)) +    Ret (e0 `BPush` (STArr (SS n) t, e1) +            `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ))) +        (SEYes (SENo subtape)) +        (ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ)))          sub          (ELet ext (EReplicate1Inner ext -                     (ESnd ext (EShape ext (EVar ext (STArr (SS n) t) (IS IZ)))) +                     (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ)))                       (EVar ext (STArr n (d2 t)) IZ)) $           weakenExpr (WCopy (WSink .> WSink)) e2) @@ -1190,7 +1198,6 @@ drev des = \case    EAccum{} -> err_accum    EZero{} -> err_monoid    EPlus{} -> err_monoid --}    where      err_accum = error "Accumulator operations unsupported in the source program" diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index 4f84e8d..beb93da 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -189,24 +189,3 @@ dfwdDN = \case    where      err_accum = error "Accumulator operations unsupported in the source program"      err_monoid = error "Monoid operations unsupported in the source program" - -emap :: Ex (a : env) b -> Ex env (TArr n a) -> Ex env (TArr n b) -emap f arr = -  let STArr n t = typeOf arr -  in ELet ext arr $ -       EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $ -         ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ)) -                            (EVar ext (tTup (sreplicate n tIx)) IZ)) $ -           weakenExpr (WCopy (WSink .> WSink)) f - -ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) -ezip a b = -  let STArr n t1 = typeOf a -      STArr _ t2 = typeOf b -  in ELet ext a $ -     ELet ext (weakenExpr WSink b) $ -       EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $ -         EPair ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ))) -                             (EVar ext (tTup (sreplicate n tIx)) IZ)) -                   (EIdx ext (EVar ext (STArr n t2) (IS IZ)) -                             (EVar ext (tTup (sreplicate n tIx)) IZ)) | 
