diff options
| -rw-r--r-- | src/AST.hs | 6 | ||||
| -rw-r--r-- | src/AST/Count.hs | 12 | ||||
| -rw-r--r-- | src/AST/Env.hs | 20 | ||||
| -rw-r--r-- | src/CHAD.hs | 516 | 
4 files changed, 313 insertions, 241 deletions
| @@ -20,7 +20,6 @@ import Data.Functor.Const  import Data.Kind (Type)  import Array -import AST.Env  import AST.Types  import AST.Weaken  import CHAD.Types @@ -289,11 +288,6 @@ subst' f w = \case  weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t  weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i)) -wUndoSubenv :: Subenv env env' -> env' :> env -wUndoSubenv SETop = WId -wUndoSubenv (SEYes sub) = WCopy (wUndoSubenv sub) -wUndoSubenv (SENo sub) = WSink .> wUndoSubenv sub -  slistIdx :: SList f list -> Idx list t -> f t  slistIdx (SCons x _) IZ = x  slistIdx (SCons _ list) (IS i) = slistIdx list i diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 31720a5..a928743 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -150,12 +150,12 @@ deleteUnused (_ `SCons` env) (OccPush occenv (Occ _ count)) k =  unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t  unsafeWeakenWithSubenv = \sub -> -  subst (\x t i -> case sinkWithSubenv i sub of +  subst (\x t i -> case sinkViaSubenv i sub of                       Just i' -> EVar x t i'                       Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away")    where -    sinkWithSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t) -    sinkWithSubenv IZ (SEYes _) = Just IZ -    sinkWithSubenv IZ (SENo _) = Nothing -    sinkWithSubenv (IS i) (SEYes sub) = IS <$> sinkWithSubenv i sub -    sinkWithSubenv (IS i) (SENo sub) = sinkWithSubenv i sub +    sinkViaSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t) +    sinkViaSubenv IZ (SEYes _) = Just IZ +    sinkViaSubenv IZ (SENo _) = Nothing +    sinkViaSubenv (IS i) (SEYes sub) = IS <$> sinkViaSubenv i sub +    sinkViaSubenv (IS i) (SENo sub) = sinkViaSubenv i sub diff --git a/src/AST/Env.hs b/src/AST/Env.hs index c33bad3..4f34166 100644 --- a/src/AST/Env.hs +++ b/src/AST/Env.hs @@ -1,5 +1,6 @@  {-# LANGUAGE DataKinds #-}  {-# LANGUAGE EmptyCase #-} +{-# LANGUAGE ExplicitForAll #-}  {-# LANGUAGE GADTs #-}  {-# LANGUAGE PolyKinds #-}  {-# LANGUAGE StandaloneDeriving #-} @@ -14,8 +15,8 @@ import Data  -- @env'@ ('SEYes') or not included in @env'@ ('SENo').  data Subenv env env' where    SETop :: Subenv '[] '[] -  SEYes :: Subenv env env' -> Subenv (t : env) (t : env') -  SENo  :: Subenv env env' -> Subenv (t : env) env' +  SEYes :: forall t env env'. Subenv env env' -> Subenv (t : env) (t : env') +  SENo  :: forall t env env'. Subenv env env' -> Subenv (t : env) env'  deriving instance Show (Subenv env env')  subList :: SList f env -> Subenv env env' -> SList f env' @@ -41,3 +42,18 @@ subenvCompose SETop SETop = SETop  subenvCompose (SEYes sub1) (SEYes sub2) = SEYes (subenvCompose sub1 sub2)  subenvCompose (SEYes sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2)  subenvCompose (SENo sub1) sub2 = SENo (subenvCompose sub1 sub2) + +subenvConcat :: Subenv env1 env1' -> Subenv env2 env2' -> Subenv (Append env2 env1) (Append env2' env1') +subenvConcat sub1 SETop = sub1 +subenvConcat sub1 (SEYes sub2) = SEYes (subenvConcat sub1 sub2) +subenvConcat sub1 (SENo sub2) = SENo (subenvConcat sub1 sub2) + +sinkWithSubenv :: Subenv env env' -> env0 :> Append env' env0 +sinkWithSubenv SETop = WId +sinkWithSubenv (SEYes sub) = WSink .> sinkWithSubenv sub +sinkWithSubenv (SENo sub) = sinkWithSubenv sub + +wUndoSubenv :: Subenv env env' -> env' :> env +wUndoSubenv SETop = WId +wUndoSubenv (SEYes sub) = WCopy (wUndoSubenv sub) +wUndoSubenv (SENo sub) = WSink .> wUndoSubenv sub diff --git a/src/CHAD.hs b/src/CHAD.hs index 8319080..1fd34d8 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -20,9 +20,12 @@  -- useful here.  {-# LANGUAGE PartialTypeSignatures #-}  {-# OPTIONS -Wno-partial-type-signatures #-} + +-- TODO DO NOT COMMIT THIS +{-# OPTIONS -Wno-unused-top-binds #-}  module CHAD ( -  drev, -  freezeRet, +  -- drev, +  -- freezeRet,    Storage(..),    Descr(..),    Select, @@ -44,6 +47,8 @@ import Data  import Lemmas +------------------------------ TAPES AND BINDINGS ------------------------------ +  type family Tape binds where    Tape '[] = TNil    Tape (t : ts) = TPair t (Tape ts) @@ -186,6 +191,10 @@ reconstructBindings binds tape =               (bconcat (mapBindings fromUnfExpr unf) build)       ,sreverse (stapeUnfoldings binds)) + +--------------------------------- VECTORISATION -------------------------------- +-- Currently only used in D[build1], should be removed. +  type family Vectorise n list where    Vectorise _ '[] = '[]    Vectorise n (t : ts) = TArr n t : Vectorise n ts @@ -231,12 +240,131 @@ vectorise1Binds env n (bs `BPush` (t, e)) =                         (vectoriseExpr SNil (bindingsBinds bs) env e)    in bs' `BPush` (STArr (SS SZ) t, e') + +---------------------- ENVIRONMENT DESCRIPTION AND STORAGE --------------------- + +type Storage :: Symbol -> Type +data Storage s where +  SAccum :: Storage "accum"  -- ^ in the monad state as a mutable accumulator +  SMerge :: Storage "merge"  -- ^ just return and merge +deriving instance Show (Storage s) + +-- | Environment description +data Descr env sto where +  DTop :: Descr '[] '[] +  DPush :: Descr env sto -> (STy t, Storage s) -> Descr (t : env) (s : sto) +deriving instance Show (Descr env sto) + +descrList :: Descr env sto -> SList STy env +descrList DTop = SNil +descrList (des `DPush` (t, _)) = t `SCons` descrList des + +-- | This could have more precise typing on the output storage. +subDescr :: Descr env sto -> Subenv env env' +         -> (forall sto'. Descr env' sto' +                       -> Subenv (Select env sto "merge") (Select env' sto' "merge") +                       -> Subenv (D2AcE (Select env sto "accum")) (D2AcE (Select env' sto' "accum")) +                       -> Subenv (D1E env) (D1E env') +                       -> r) +         -> r +subDescr DTop SETop k = k DTop SETop SETop SETop +subDescr (des `DPush` (t, sto)) (SEYes sub) k = +  subDescr des sub $ \des' submerge subaccum subd1e -> +    case sto of +      SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e) +      SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e) +subDescr (des `DPush` (_, sto)) (SENo sub) k = +  subDescr des sub $ \des' submerge subaccum subd1e -> +    case sto of +      SMerge -> k des' (SENo submerge) subaccum (SENo subd1e) +      SAccum -> k des' submerge (SENo subaccum) (SENo subd1e) +  -- | Select only the types from the environment that have the specified storage  type family Select env sto s where    Select '[] '[] _ = '[]    Select (t : ts) (s : sto) s = t : Select ts sto s    Select (_ : ts) (_ : sto) s = Select ts sto s +select :: Storage s -> Descr env sto -> SList STy (Select env sto s) +select _ DTop = SNil +select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des) +select s@SMerge (DPush des (_, SAccum)) = select s des +select s@SAccum (DPush des (_, SMerge)) = select s des +select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des) + + +---------------------------------- DERIVATIVES --------------------------------- + +d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t) +d1op (OAdd t) e = EOp ext (OAdd t) e +d1op (OMul t) e = EOp ext (OMul t) e +d1op (ONeg t) e = EOp ext (ONeg t) e +d1op (OLt t) e = EOp ext (OLt t) e +d1op (OLe t) e = EOp ext (OLe t) e +d1op (OEq t) e = EOp ext (OEq t) e +d1op ONot e = EOp ext ONot e +d1op OAnd e = EOp ext OAnd e +d1op OOr e = EOp ext OOr e +d1op OIf e = EOp ext OIf e +d1op ORound64 e = EOp ext ORound64 e +d1op OToFl64 e = EOp ext OToFl64 e + +-- | Both primal and dual must be duplicable expressions +data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) +              | Nonlinear (forall env. Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a)) + +d2op :: SOp a t -> D2Op a t +d2op op = case op of +  OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EInr ext STNil (EPair ext d d) +  OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d -> +    EInr ext STNil (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) +                              (EOp ext (OMul t) (EPair ext (EFst ext e) d))) +  ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d +  OLt t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) +  OLe t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) +  OEq t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) +  ONot -> Linear $ \_ -> ENil ext +  OAnd -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) +  OOr -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) +  OIf -> Linear $ \_ -> ENil ext +  ORound64 -> Linear $ \_ -> EConst ext STF64 0.0 +  OToFl64 -> Linear $ \_ -> ENil ext +  where +    d2opUnArrangeInt :: SScalTy a +                     -> (D2s a ~ TScal a => D2Op (TScal a) t) +                     -> D2Op (TScal a) t +    d2opUnArrangeInt ty float = case ty of +      STI32 -> Linear $ \_ -> ENil ext +      STI64 -> Linear $ \_ -> ENil ext +      STF32 -> float +      STF64 -> float +      STBool -> Linear $ \_ -> ENil ext + +    d2opBinArrangeInt :: SScalTy a +                      -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t) +                      -> D2Op (TPair (TScal a) (TScal a)) t +    d2opBinArrangeInt ty float = case ty of +      STI32 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) +      STI64 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) +      STF32 -> float +      STF64 -> float +      STBool -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) + +sD1eEnv :: Descr env sto -> SList STy (D1E env) +sD1eEnv DTop = SNil +sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d) + +d2ace :: SList STy env -> SList STy (D2AcE env) +d2ace SNil = SNil +d2ace (SCons t ts) = SCons (STAccum (d2 t)) (d2ace ts) + +-- d1W :: env :> env' -> D1E env :> D1E env' +-- d1W WId = WId +-- d1W WSink = WSink +-- d1W (WCopy w) = WCopy (d1W w) +-- d1W (WPop w) = WPop (d1W w) +-- d1W (WThen u w) = WThen (d1W u) (d1W w) +  conv1Idx :: Idx env t -> Idx (D1E env) (D1 t)  conv1Idx IZ = IZ  conv1Idx (IS i) = IS (conv1Idx i) @@ -250,6 +378,16 @@ conv2Idx (DPush des (_, SAccum)) (IS i) = first IS (conv2Idx des i)  conv2Idx (DPush des (_, SMerge)) (IS i) = second IS (conv2Idx des i)  conv2Idx DTop i = case i of {} + +------------------------------------ LEMMAS ------------------------------------ + +indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx)) +indexTupD1Id SZ = Refl +indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl + + +------------------------------------ MONOIDS ----------------------------------- +  zero :: STy t -> Ex env (D2 t)  zero = EZero  -- TODO: this original definition needs to be used as the post-processing after @@ -312,9 +450,77 @@ zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0))  zeroTup SNil = ENil ext  zeroTup (SCons t env) = EPair ext (zeroTup env) (zero t) -indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx)) -indexTupD1Id SZ = Refl -indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl + +------------------------------------ SUBENVS ----------------------------------- + +subenvPlus :: SList STy env +           -> Subenv env env1 -> Subenv env env2 +           -> (forall env3. Subenv env env3 +                         -> Subenv env3 env1 +                         -> Subenv env3 env2 +                         -> (Ex exenv (Tup (D2E env1)) +                             -> Ex exenv (Tup (D2E env2)) +                             -> Ex exenv (Tup (D2E env3))) +                         -> r) +           -> r +subenvPlus SNil SETop SETop k = k SETop SETop SETop (\_ _ -> ENil ext) +subenvPlus (SCons _ env) (SENo sub1) (SENo sub2) k = +  subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> +    k (SENo sub3) s31 s32 pl +subenvPlus (SCons _ env) (SEYes sub1) (SENo sub2) k = +  subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> +    k (SEYes sub3) (SEYes s31) (SENo s32) $ \e1 e2 -> +      ELet ext e1 $ +        EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) +                      (weakenExpr WSink e2)) +                  (ESnd ext (EVar ext (typeOf e1) IZ)) +subenvPlus (SCons _ env) (SENo sub1) (SEYes sub2) k = +  subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> +    k (SEYes sub3) (SENo s31) (SEYes s32) $ \e1 e2 -> +      ELet ext e2 $ +        EPair ext (pl (weakenExpr WSink e1) +                      (EFst ext (EVar ext (typeOf e2) IZ))) +                  (ESnd ext (EVar ext (typeOf e2) IZ)) +subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k = +  subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> +    k (SEYes sub3) (SEYes s31) (SEYes s32) $ \e1 e2 -> +      ELet ext e1 $ +      ELet ext (weakenExpr WSink e2) $ +        EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ))) +                      (EFst ext (EVar ext (typeOf e2) IZ))) +                  (plus t +                    (ESnd ext (EVar ext (typeOf e1) (IS IZ))) +                    (ESnd ext (EVar ext (typeOf e2) IZ))) + +expandSubenvZeros :: SList STy env0 -> Subenv env0 env0Merge -> Ex env (Tup (D2E env0Merge)) -> Ex env (Tup (D2E env0)) +expandSubenvZeros _ SETop _ = ENil ext +expandSubenvZeros (SCons t ts) (SEYes sub) e = +  ELet ext e $ +    let var = EVar ext (STPair (tTup (d2e (subList ts sub))) (d2 t)) IZ +    in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var) +expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (zero t) + +assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[] +assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl +assertSubenvEmpty SETop = Refl +assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty" + +popFromScope +  :: Descr env0 sto +  -> STy a +  -> Subenv (Select (a : env0) ("merge" : sto) "merge") envSub +  -> Ex env (Tup (D2E envSub)) +  -> (forall envSub'. +         Subenv (Select env0 sto "merge") envSub' +      -> Ex env (TPair (Tup (D2E envSub')) (D2 a)) +      -> r) +  -> r +popFromScope _ ty sub e k = case sub of +  SEYes sub' -> k sub' e +  SENo sub' -> k sub' $ EPair ext e (zero ty) + + +--------------------------------- ACCUMULATORS ---------------------------------  accumPromote :: forall dt env sto proxy r.                  proxy dt @@ -411,252 +617,98 @@ uninvertTup (t `SCons` list) tcore e =              (ESnd ext (EVar ext recT IZ))              (ESnd ext (EFst ext (EVar ext recT IZ)))) + +---------------------------- RETURN TRIPLE FROM CHAD --------------------------- +  data Ret env0 sto t = -  forall shbinds env0Merge. +  forall shbinds tapebinds env0Merge.      Ret (Bindings Ex (D1E env0) shbinds)  -- shared binds +        (Subenv shbinds tapebinds)          (Ex (Append shbinds (D1E env0)) (D1 t))          (Subenv (Select env0 sto "merge") env0Merge) -        (Ex (D2 t : Append shbinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge))) +        (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge)))  deriving instance Show (Ret env0 sto t) -data RetPair env0 sto env shbinds t = +data RetPair env0 sto env shbinds tapebinds t =    forall env0Merge.      RetPair (Ex (Append shbinds env) (D1 t))              (Subenv (Select env0 sto "merge") env0Merge) -            (Ex (D2 t : Append shbinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge))) -deriving instance Show (RetPair env0 sto env shbinds t) +            (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge))) +deriving instance Show (RetPair env0 sto env shbinds tapebinds t)  data Rets env0 sto env list = -  forall shbinds. +  forall shbinds tapebinds.      Rets (Bindings Ex env shbinds) -         (SList (RetPair env0 sto env shbinds) list) +         (Subenv shbinds tapebinds) +         (SList (RetPair env0 sto env shbinds tapebinds) list)  deriving instance Show (Rets env0 sto env list) -subenvPlus :: SList STy env -           -> Subenv env env1 -> Subenv env env2 -           -> (forall env3. Subenv env env3 -                         -> Subenv env3 env1 -                         -> Subenv env3 env2 -                         -> (Ex exenv (Tup (D2E env1)) -                             -> Ex exenv (Tup (D2E env2)) -                             -> Ex exenv (Tup (D2E env3))) -                         -> r) -           -> r -subenvPlus SNil SETop SETop k = k SETop SETop SETop (\_ _ -> ENil ext) -subenvPlus (SCons _ env) (SENo sub1) (SENo sub2) k = -  subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> -    k (SENo sub3) s31 s32 pl -subenvPlus (SCons _ env) (SEYes sub1) (SENo sub2) k = -  subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> -    k (SEYes sub3) (SEYes s31) (SENo s32) $ \e1 e2 -> -      ELet ext e1 $ -        EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) -                      (weakenExpr WSink e2)) -                  (ESnd ext (EVar ext (typeOf e1) IZ)) -subenvPlus (SCons _ env) (SENo sub1) (SEYes sub2) k = -  subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> -    k (SEYes sub3) (SENo s31) (SEYes s32) $ \e1 e2 -> -      ELet ext e2 $ -        EPair ext (pl (weakenExpr WSink e1) -                      (EFst ext (EVar ext (typeOf e2) IZ))) -                  (ESnd ext (EVar ext (typeOf e2) IZ)) -subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k = -  subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> -    k (SEYes sub3) (SEYes s31) (SEYes s32) $ \e1 e2 -> -      ELet ext e1 $ -      ELet ext (weakenExpr WSink e2) $ -        EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ))) -                      (EFst ext (EVar ext (typeOf e2) IZ))) -                  (plus t -                    (ESnd ext (EVar ext (typeOf e1) (IS IZ))) -                    (ESnd ext (EVar ext (typeOf e2) IZ))) - -expandSubenvZeros :: SList STy env0 -> Subenv env0 env0Merge -> Ex env (Tup (D2E env0Merge)) -> Ex env (Tup (D2E env0)) -expandSubenvZeros _ SETop _ = ENil ext -expandSubenvZeros (SCons t ts) (SEYes sub) e = -  ELet ext e $ -    let var = EVar ext (STPair (tTup (d2e (subList ts sub))) (d2 t)) IZ -    in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var) -expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (zero t) - -assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[] -assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl -assertSubenvEmpty SETop = Refl -assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty" - -popFromScope -  :: Descr env0 sto -  -> STy a -  -> Subenv (Select (a : env0) ("merge" : sto) "merge") envSub -  -> Ex env (Tup (D2E envSub)) -  -> (forall envSub'. -         Subenv (Select env0 sto "merge") envSub' -      -> Ex env (TPair (Tup (D2E envSub')) (D2 a)) -      -> r) -  -> r -popFromScope _ ty sub e k = case sub of -  SEYes sub' -> k sub' e -  SENo sub' -> k sub' $ EPair ext e (zero ty) - --- d1W :: env :> env' -> D1E env :> D1E env' --- d1W WId = WId --- d1W WSink = WSink --- d1W (WCopy w) = WCopy (d1W w) --- d1W (WPop w) = WPop (d1W w) --- d1W (WThen u w) = WThen (d1W u) (d1W w) - -weakenRetPair :: SList STy shbinds -> env :> env' -> RetPair env0 sto env shbinds t -> RetPair env0 sto env' shbinds t +weakenRetPair :: SList STy shbinds -> env :> env' +              -> RetPair env0 sto env shbinds tapebinds t -> RetPair env0 sto env' shbinds tapebinds t  weakenRetPair bindslist w (RetPair e1 sub e2) = RetPair (weakenExpr (weakenOver bindslist w) e1) sub e2  weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list -weakenRets w (Rets binds list) = +weakenRets w (Rets binds tapesub list) =    let (binds', _) = weakenBindings weakenExpr w binds -  in Rets binds' (slistMap (weakenRetPair (bindingsBinds binds) w) list) +  in Rets binds' tapesub (slistMap (weakenRetPair (bindingsBinds binds) w) list) -rebaseRetPair :: forall env b1 b2 env0 sto t f. -                 Descr env0 sto -> SList f b1 -> SList f b2 -              -> RetPair env0 sto (Append b1 env) b2 t -> RetPair env0 sto env (Append b2 b1) t -rebaseRetPair descr b1 b2 (RetPair p sub d) +rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto t f. +                 Descr env0 sto +              -> SList f b1 -> SList f b2 +              -> Subenv b1 tapebinds1 -> Subenv b2 tapebinds2 +              -> RetPair env0 sto (Append b1 env) b2 tapebinds2 t +              -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) t +rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair p sub d)    | Refl <- lemAppendAssoc @b2 @b1 @env =        RetPair p sub (weakenExpr (autoWeak -                                  (#d (auto1 @(D2 t)) &. #b2 b2 &. #b1 b1 &. #tl (d2ace (select SAccum descr))) -                                  (#d :++: (#b2 :++: #tl)) -                                  (#d :++: ((#b2 :++: #b1) :++: #tl))) +                                  (#d (auto1 @(D2 t)) +                                   &. #t2 (subList b2 subtape2) +                                   &. #t1 (subList b1 subtape1) +                                   &. #tl (d2ace (select SAccum descr))) +                                  (#d :++: (#t2 :++: #tl)) +                                  (#d :++: ((#t2 :++: #t1) :++: #tl)))                                  d)  retConcat :: forall env0 sto list. Descr env0 sto -> SList (Ret env0 sto) list -> Rets env0 sto (D1E env0) list -retConcat _ SNil = Rets BTop SNil -retConcat descr (SCons (Ret (b :: Bindings _ _ shbinds) p sub d) list) -  | Rets binds1 pairs1 <- retConcat descr list -  , Rets (binds :: Bindings _ _ shbinds2) pairs <- weakenRets (sinkWithBindings b) (Rets binds1 pairs1) -  , Refl <- lemAppendAssoc @shbinds2 @shbinds @(D1E env0) -  , Refl <- lemAppendAssoc @shbinds2 @shbinds @(D2AcE (Select env0 sto "accum")) +retConcat _ SNil = Rets BTop SETop SNil +retConcat descr (SCons (Ret (b :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) p sub d) list) +  | Rets (binds :: Bindings _ _ shbinds2) (subtape2 :: Subenv _ tapebinds2) pairs +      <- weakenRets (sinkWithBindings b) (retConcat descr list) +  , Refl <- lemAppendAssoc @shbinds2 @shbinds1 @(D1E env0) +  , Refl <- lemAppendAssoc @tapebinds2 @tapebinds1 @(D2AcE (Select env0 sto "accum"))    = Rets (bconcat b binds) +         (subenvConcat subtape subtape2)           (SCons (RetPair (weakenExpr (sinkWithBindings binds) p)                           sub -                         (weakenExpr (WCopy (sinkWithBindings binds)) d)) -                (slistMap (rebaseRetPair descr (bindingsBinds b) (bindingsBinds binds)) pairs)) - -d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t) -d1op (OAdd t) e = EOp ext (OAdd t) e -d1op (OMul t) e = EOp ext (OMul t) e -d1op (ONeg t) e = EOp ext (ONeg t) e -d1op (OLt t) e = EOp ext (OLt t) e -d1op (OLe t) e = EOp ext (OLe t) e -d1op (OEq t) e = EOp ext (OEq t) e -d1op ONot e = EOp ext ONot e -d1op OAnd e = EOp ext OAnd e -d1op OOr e = EOp ext OOr e -d1op OIf e = EOp ext OIf e -d1op ORound64 e = EOp ext ORound64 e -d1op OToFl64 e = EOp ext OToFl64 e - --- | Both primal and dual must be duplicable expressions -data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) -              | Nonlinear (forall env. Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a)) - -d2op :: SOp a t -> D2Op a t -d2op op = case op of -  OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EInr ext STNil (EPair ext d d) -  OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d -> -    EInr ext STNil (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) -                              (EOp ext (OMul t) (EPair ext (EFst ext e) d))) -  ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d -  OLt t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) -  OLe t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) -  OEq t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) -  ONot -> Linear $ \_ -> ENil ext -  OAnd -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) -  OOr -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) -  OIf -> Linear $ \_ -> ENil ext -  ORound64 -> Linear $ \_ -> EConst ext STF64 0.0 -  OToFl64 -> Linear $ \_ -> ENil ext -  where -    d2opUnArrangeInt :: SScalTy a -                     -> (D2s a ~ TScal a => D2Op (TScal a) t) -                     -> D2Op (TScal a) t -    d2opUnArrangeInt ty float = case ty of -      STI32 -> Linear $ \_ -> ENil ext -      STI64 -> Linear $ \_ -> ENil ext -      STF32 -> float -      STF64 -> float -      STBool -> Linear $ \_ -> ENil ext - -    d2opBinArrangeInt :: SScalTy a -                      -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t) -                      -> D2Op (TPair (TScal a) (TScal a)) t -    d2opBinArrangeInt ty float = case ty of -      STI32 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) -      STI64 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) -      STF32 -> float -      STF64 -> float -      STBool -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) - -type Storage :: Symbol -> Type -data Storage s where -  SAccum :: Storage "accum"  -- ^ in the monad state as a mutable accumulator -  SMerge :: Storage "merge"  -- ^ just return and merge -deriving instance Show (Storage s) - --- | Environment description -data Descr env sto where -  DTop :: Descr '[] '[] -  DPush :: Descr env sto -> (STy t, Storage s) -> Descr (t : env) (s : sto) -deriving instance Show (Descr env sto) - -descrList :: Descr env sto -> SList STy env -descrList DTop = SNil -descrList (des `DPush` (t, _)) = t `SCons` descrList des - -select :: Storage s -> Descr env sto -> SList STy (Select env sto s) -select _ DTop = SNil -select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des) -select s@SMerge (DPush des (_, SAccum)) = select s des -select s@SAccum (DPush des (_, SMerge)) = select s des -select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des) - --- | This could have more precise typing on the output storage. -subDescr :: Descr env sto -> Subenv env env' -         -> (forall sto'. Descr env' sto' -                       -> Subenv (Select env sto "merge") (Select env' sto' "merge") -                       -> Subenv (D2AcE (Select env sto "accum")) (D2AcE (Select env' sto' "accum")) -                       -> Subenv (D1E env) (D1E env') -                       -> r) -         -> r -subDescr DTop SETop k = k DTop SETop SETop SETop -subDescr (des `DPush` (t, sto)) (SEYes sub) k = -  subDescr des sub $ \des' submerge subaccum subd1e -> -    case sto of -      SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e) -      SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e) -subDescr (des `DPush` (_, sto)) (SENo sub) k = -  subDescr des sub $ \des' submerge subaccum subd1e -> -    case sto of -      SMerge -> k des' (SENo submerge) subaccum (SENo subd1e) -      SAccum -> k des' submerge (SENo subaccum) (SENo subd1e) - -sD1eEnv :: Descr env sto -> SList STy (D1E env) -sD1eEnv DTop = SNil -sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d) - -d2ace :: SList STy env -> SList STy (D2AcE env) -d2ace SNil = SNil -d2ace (SCons t ts) = SCons (STAccum (d2 t)) (d2ace ts) +                         (weakenExpr (WCopy (sinkWithSubenv subtape2)) d)) +                (slistMap (rebaseRetPair descr (bindingsBinds b) (bindingsBinds binds) +                                               subtape subtape2) +                          pairs))  freezeRet :: Descr env sto            -> Ret env sto t            -> Ex (D1E env) (D2 t)  -- the incoming cotangent value            -> Ex (Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge")))) -freezeRet descr (Ret e0 e1 sub e2) d = +freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) d =    let (e0', wInsertD2Ac) = weakenBindings weakenExpr (wSinks (d2ace (select SAccum descr))) e0 -      e2' = weakenExpr (WCopy (wCopies (bindingsBinds e0) (wRaiseAbove (d2ace (select SAccum descr)) (sD1eEnv descr)))) e2 +      e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (sD1eEnv descr)))) e2    in letBinds e0' $         EPair ext           (weakenExpr wInsertD2Ac e1)           (ELet ext (weakenExpr (sinkWithBindings e0 .> wSinks (d2ace (select SAccum descr))) d) $ -          ELet ext e2' $ +          ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t)) +                                          &. #tape (subList (bindingsBinds e0) subtape) +                                          &. #shbinds (bindingsBinds e0) +                                          &. #d2ace (d2ace (select SAccum descr)) +                                          &. #tl (sD1eEnv descr)) +                                         (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl) +                                         (#d :++: #shbinds :++: #d2ace :++: #tl)) +                      e2') $            expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ)) + +---------------------------- THE CHAD TRANSFORMATION --------------------------- +  drev :: forall env sto t.          Descr env sto       -> Ex env t -> Ret env sto t @@ -665,50 +717,55 @@ drev des = \case      case conv2Idx des i of        Left accI ->          Ret BTop +            SETop              (EVar ext (d1 t) (conv1Idx i))              (subenvNone (select SMerge des))              (EAccum SZ (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2 t)) (IS accI)))        Right tupI ->          Ret BTop +            SETop              (EVar ext (d1 t) (conv1Idx i))              (subenvOnehot (select SMerge des) tupI)              (EPair ext (ENil ext) (EVar ext (d2 t) IZ))    ELet _ rhs body -    | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des rhs -    , Ret (body0 :: Bindings _ _ body_shbinds) body1 subBody body2 <- drev (des `DPush` (typeOf rhs, SMerge)) body +    | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 +        <- drev des rhs +    , Ret (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 +        <- drev (des `DPush` (typeOf rhs, SMerge)) body      , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0      , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env) -    , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D2AcE (Select env sto "accum")) -    , Refl <- lemAppendNil @body_shbinds -> +    , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) ->      popFromScope des (typeOf rhs) subBody body2 $ \subBody' body2' ->      subenvPlus (select SMerge des) subRHS subBody' $ \subBoth _ _ plus_RHS_Body ->      let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody'))) (d2 (typeOf rhs)) in      Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0') +        (subenvConcat (SENo @d1_a subtapeRHS) subtapeBody)          (weakenExpr wbody0' body1)          subBoth          (ELet ext             (weakenExpr (autoWeak (#d (auto1 @(D2 t)) -                                  &. #body (bindingsBinds body0) -                                  &. #rhs (SCons (typeOf rhs1) (bindingsBinds rhs0)) +                                  &. #body (subList (bindingsBinds body0) subtapeBody) +                                  &. #rhs (subList (bindingsBinds rhs0) subtapeRHS)                                    &. #tl (d2ace (select SAccum des)))                                   (#d :++: #body :++: #tl) -                                 (#d :++: #body :++: #rhs :++: #tl)) +                                 (#d :++: (#body :++: #rhs) :++: #tl))                         body2') $           ELet ext             (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ -             weakenExpr (WCopy (wSinks' @[_,_] .> WPop @d1_a (sinkWithBindings body0'))) rhs2) $ +             weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $           plus_RHS_Body             (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ)             (EFst ext (EVar ext bodyResType (IS IZ))))    EPair _ a b -    | Rets binds (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) +    | Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil)          <- retConcat des $ drev des a `SCons` drev des b `SCons` SNil      , let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) ->      subenvPlus (select SMerge des) subA subB $ \subBoth _ _ plus_A_B ->      Ret binds +        subtape          (EPair ext a1 b1)          subBoth          (ECase ext (EVar ext (STEither STNil (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ) @@ -722,28 +779,31 @@ drev des = \case               (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ)))    EFst _ e -    | Ret e0 e1 sub e2 <- drev des e +    | Ret e0 subtape e1 sub e2 <- drev des e      , STPair t1 t2 <- typeOf e ->      Ret e0 +        subtape          (EFst ext e1)          sub          (ELet ext (EInr ext STNil (EPair ext (EVar ext (d2 t1) IZ) (zero t2))) $             weakenExpr (WCopy WSink) e2)    ESnd _ e -    | Ret e0 e1 sub e2 <- drev des e +    | Ret e0 subtape e1 sub e2 <- drev des e      , STPair t1 t2 <- typeOf e ->      Ret e0 +        subtape          (ESnd ext e1)          sub          (ELet ext (EInr ext STNil (EPair ext (zero t1) (EVar ext (d2 t2) IZ))) $             weakenExpr (WCopy WSink) e2) -  ENil _ -> Ret BTop (ENil ext) (subenvNone (select SMerge des)) (ENil ext) +  ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (select SMerge des)) (ENil ext)    EInl _ t2 e -    | Ret e0 e1 sub e2 <- drev des e -> +    | Ret e0 subtape e1 sub e2 <- drev des e ->      Ret e0 +        subtape          (EInl ext (d1 t2) e1)          sub          (ECase ext (EVar ext (STEither STNil (STEither (d2 (typeOf e)) (d2 t2))) IZ) @@ -753,8 +813,9 @@ drev des = \case                (EError (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr")))    EInr _ t1 e -    | Ret e0 e1 sub e2 <- drev des e -> +    | Ret e0 subtape e1 sub e2 <- drev des e ->      Ret e0 +        subtape          (EInr ext (d1 t1) e1)          sub          (ECase ext (EVar ext (STEither STNil (STEither (d2 t1) (d2 (typeOf e)))) IZ) @@ -763,6 +824,7 @@ drev des = \case                (EError (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl")                (weakenExpr (WCopy (wSinks' @[_,_])) e2))) +{-    ECase _ e (a :: Ex _ t) b      | STEither t1 t2 <- typeOf e      , Ret (e0 :: Bindings _ _ e_binds) e1 subE e2 <- drev des e @@ -962,7 +1024,6 @@ drev des = \case      case assertSubenvEmpty sub of { Refl ->      let tapety = tapeTy (bindingsBinds e0) in      let collectexpr = bindingsCollect e0 in -    -- let ve0 = vectorise1Binds (tIx `SCons` sD1eEnv usedDes) IZ e0 in      Ret (she0 `BPush` (shty, she1)                `BPush` (STArr ndim tapety                        ,EBuild ext ndim @@ -1129,6 +1190,7 @@ drev des = \case    EAccum{} -> err_accum    EZero{} -> err_monoid    EPlus{} -> err_monoid +-}    where      err_accum = error "Accumulator operations unsupported in the source program" | 
