diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-08-30 17:48:15 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-08-30 17:48:15 +0200 |
commit | 8b047ff11ebd4715647bfc041a190f72dcf4d5a9 (patch) | |
tree | e8440120b7bbd4e45b367acb3f7185d25e7f3766 /src/CHAD.hs | |
parent | f4b94d7cc2cb05611b462ba278e4f12f7a7a5e5e (diff) |
Migrate to accumulators (mostly removing EVM code)
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r-- | src/CHAD.hs | 281 |
1 files changed, 134 insertions, 147 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index 2513f84..e209b67 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1,16 +1,16 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE UndecidableInstances #-} -- I want to bring various type variables in scope using type annotations in @@ -29,7 +29,6 @@ module CHAD ( import Data.Bifunctor (first, second) import Data.Functor.Const import Data.Kind (Type) -import Data.Some import GHC.TypeLits (Symbol) import AST @@ -254,7 +253,7 @@ type family D2 t where D2 TNil = TNil D2 (TPair a b) = TEither TNil (TPair (D2 a) (D2 b)) D2 (TEither a b) = TEither TNil (TEither (D2 a) (D2 b)) - -- D2 (TArr n t) = _ + D2 (TArr n t) = TArr n (D2 t) D2 (TScal t) = D2s t type family D2s t where @@ -264,6 +263,9 @@ type family D2s t where D2s TF64 = TScal TF64 D2s TBool = TNil +type family D2Ac t where + D2Ac (TArr n t) = TAccum n t + type family D1E env where D1E '[] = '[] D1E (t : env) = D1 t : D1E env @@ -272,6 +274,10 @@ type family D2E env where D2E '[] = '[] D2E (t : env) = D2 t : D2E env +type family D2AcE env where + D2AcE '[] = '[] + D2AcE (t : env) = D2Ac t : D2AcE env + -- | Select only the types from the environment that have the specified storage type family Select env sto s where Select '[] '[] _ = '[] @@ -284,20 +290,24 @@ d1 (STPair a b) = STPair (d1 a) (d1 b) d1 (STEither a b) = STEither (d1 a) (d1 b) d1 (STArr n t) = STArr n (d1 t) d1 (STScal t) = STScal t -d1 STEVM{} = error "EVM not allowed in input program" +d1 STAccum{} = error "Accumulators not allowed in input program" d2 :: STy t -> STy (D2 t) d2 STNil = STNil d2 (STPair a b) = STEither STNil (STPair (d2 a) (d2 b)) d2 (STEither a b) = STEither STNil (STEither (d2 a) (d2 b)) -d2 STArr{} = error "TODO arrays" +d2 (STArr n t) = STArr n (d2 t) d2 (STScal t) = case t of STI32 -> STNil STI64 -> STNil STF32 -> STScal STF32 STF64 -> STScal STF64 STBool -> STNil -d2 STEVM{} = error "EVM not allowed in input program" +d2 STAccum{} = error "Accumulators not allowed in input program" + +d2ac :: STy t -> STy (D2Ac t) +d2ac (STArr n t) = STAccum n t +d2ac _ = error "Only arrays may appear in the accumulator environment" conv1Idx :: Idx env t -> Idx (D1E env) (D1 t) conv1Idx IZ = IZ @@ -322,7 +332,7 @@ zero (STScal t) = case t of STF32 -> EConst ext STF32 0.0 STF64 -> EConst ext STF64 0.0 STBool -> ENil ext -zero STEVM{} = error "EVM not allowed in input program" +zero STAccum{} = error "Accumulators not allowed in input program" plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t) plus STNil _ _ = ENil ext @@ -350,7 +360,7 @@ plus (STScal t) a b = case t of STF32 -> EOp ext (OAdd STF32) (EPair ext a b) STF64 -> EOp ext (OAdd STF64) (EPair ext a b) STBool -> ENil ext -plus STEVM{} _ _ = error "EVM not allowed in input program" +plus STAccum{} _ _ = error "Accumulators not allowed in input program" plusSparse :: STy a -> Ex env (TEither TNil a) -> Ex env (TEither TNil a) @@ -388,14 +398,14 @@ data Ret env0 sto t = Ret (Bindings Ex (D1E env0) shbinds) -- shared binds (Ex (Append shbinds (D1E env0)) (D1 t)) (Subenv (Select env0 sto "merge") env0Merge) - (Ex (D2 t : shbinds) (TEVM (D2E (Select env0 sto "accum")) (Tup (D2E env0Merge)))) + (Ex (D2 t : Append shbinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge))) deriving instance Show (Ret env0 sto t) data RetPair env0 sto env shbinds t = forall env0Merge. RetPair (Ex (Append shbinds env) (D1 t)) (Subenv (Select env0 sto "merge") env0Merge) - (Ex (D2 t : shbinds) (TEVM (D2E (Select env0 sto "accum")) (Tup (D2E env0Merge)))) + (Ex (D2 t : Append shbinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge))) deriving instance Show (RetPair env0 sto env shbinds t) data Rets env0 sto env list = @@ -430,7 +440,8 @@ subenvPlus :: SList STy env -> r subenvPlus SNil SETop SETop k = k SETop SETop SETop (\_ _ -> ENil ext) subenvPlus (SCons _ env) (SENo sub1) (SENo sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> k (SENo sub3) s31 s32 pl + subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> + k (SENo sub3) s31 s32 pl subenvPlus (SCons _ env) (SEYes sub1) (SENo sub2) k = subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> k (SEYes sub3) (SEYes s31) (SENo s32) $ \e1 e2 -> @@ -464,24 +475,19 @@ expandSubenvZeros (SCons t ts) (SEYes sub) e = in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var) expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (zero t) -unscope :: Descr env0 sto - -> STy a -> Storage s - -> Subenv (Select (a : env0) (s : sto) "merge") envSub - -> Ex env (TEVM (D2E (Select (a : env0) (s : sto) "accum")) (Tup (D2E envSub))) - -> (forall envSub'. - Subenv (Select env0 sto "merge") envSub' - -> Ex env (TEVM (D2E (Select env0 sto "accum")) (TPair (Tup (D2E envSub')) (D2 a))) - -> r) - -> r -unscope des ty s sub e k = case s of - SAccum -> k sub (EMScope e) - SMerge -> case sub of - SEYes sub' -> k sub' e - SENo sub' -> k sub' $ - EMBind e $ - EMReturn (d2e (select SAccum des)) $ - EPair ext (EVar ext (tTup (d2e (subList (select SMerge des) sub'))) IZ) - (zero ty) +popFromScope + :: Descr env0 sto + -> STy a + -> Subenv (Select (a : env0) ("merge" : sto) "merge") envSub + -> Ex env (Tup (D2E envSub)) + -> (forall envSub'. + Subenv (Select env0 sto "merge") envSub' + -> Ex env (TPair (Tup (D2E envSub')) (D2 a)) + -> r) + -> r +popFromScope _ ty sub e k = case sub of + SEYes sub' -> k sub' e + SENo sub' -> k sub' $ EPair ext e (zero ty) -- d1W :: env :> env' -> D1E env :> D1E env' -- d1W WId = WId @@ -501,7 +507,7 @@ weakenRets w (Rets binds list) = rebaseRetPair :: forall env b1 b2 env0 sto t f. SList f b1 -> SList f b2 -> RetPair env0 sto (Append b1 env) b2 t -> RetPair env0 sto env (Append b2 b1) t rebaseRetPair b1 b2 (RetPair p sub d) | Refl <- lemAppendAssoc @b2 @b1 @env = - RetPair p sub (weakenExpr (WCopy (wRaiseAbove b2 b1)) d) + RetPair p sub (weakenExpr (WCopy (wStack @(D2AcE (Select env0 sto "accum")) (wRaiseAbove b2 b1))) d) retConcat :: forall env0 sto list. SList (Ret env0 sto) list -> Rets env0 sto (D1E env0) list retConcat SNil = Rets BTop SNil @@ -509,37 +515,12 @@ retConcat (SCons (Ret (b :: Bindings _ _ shbinds) p sub d) list) | Rets binds1 pairs1 <- retConcat list , Rets (binds :: Bindings _ _ shbinds2) pairs <- weakenRets (sinkWithBindings b) (Rets binds1 pairs1) , Refl <- lemAppendAssoc @shbinds2 @shbinds @(D1E env0) + , Refl <- lemAppendAssoc @shbinds2 @shbinds @(D2AcE (Select env0 sto "accum")) = Rets (bconcat b binds) (SCons (RetPair (weakenExpr (sinkWithBindings binds) p) sub (weakenExpr (WCopy (sinkWithBindings binds)) d)) (slistMap (rebaseRetPair (bindingsBinds b) (bindingsBinds binds)) pairs)) --- list ~ a : list' --- SCons (Ret b p sub d) list :: SList (Ret env0 sto) list --- Ret b p sub d :: Ret env0 sto a <- existential shbinds --- b :: Bindings Ex (D1E env0) shbinds --- p :: Ex (Append shbinds (D1E env0)) (D1 a) --- d :: Ex (D2 a : shbinds) (TEVM ...) --- --- list :: SList (Ret env0 sto) list' --- retConcat list :: Rets env0 sto (D1E env0) list' <- existential shbinds1 --- binds1 :: Bindings Ex (D1E env0) shbinds1 --- pairs1 :: SList (RetPair env0 sto (D1E env0) shbinds1) list' --- --- sinkWithBindings b :: forall e. e :> Append shbinds e --- Rets binds pairs :: Rets env0 sto (Append shbinds (D1E env0)) list' <- existential shbinds2 --- binds :: Bindings Ex (Append shbinds (D1E env0)) shbinds2 --- pairs :: SList (RetPair env0 sto (Append shbinds (D1E env0)) shbinds2) list' --- --- we choose shbindsR ~ Append shbinds2 shbinds --- result :: Rets env0 sto (D1E env0) list --- result.1 :: Bindings Ex (D1E env0) shbindsR == Bindings Ex (D1E env0) (Append shbinds2 shbinds) --- result.2 :: SList (RetPair env0 sto (D1E env0) shbindsR) list --- result.2.head :: RetPair env0 sto (D1E env0) shbindsR a --- result.2.tail :: SList (RetPair env0 sto (D1E env0) shbindsR) list' --- = SList (RetPair env0 sto (D1E env0) (Append shbinds2 shbinds)) list' --- --- wanted: shbinds1 :> shbindsR d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t) d1op (OAdd t) e = EOp ext (OAdd t) e @@ -557,7 +538,7 @@ data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) d2op :: SOp a t -> D2Op a t d2op op = case op of - OAdd _ -> Linear $ \d -> EInr ext STNil (EPair ext d d) + OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EInr ext STNil (EPair ext d d) OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d -> EInr ext STNil (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) (EOp ext (OMul t) (EPair ext (EFst ext e) d))) @@ -611,86 +592,89 @@ sD1eEnv :: Descr env sto -> SList (Const ()) (D1E env) sD1eEnv DTop = SNil sD1eEnv (DPush d _) = SCons (Const ()) (sD1eEnv d) +d2e :: SList STy env -> SList STy (D2E env) +d2e SNil = SNil +d2e (SCons t ts) = SCons (d2 t) (d2e ts) + +d2ace :: SList STy env -> SList STy (D2AcE env) +d2ace SNil = SNil +d2ace (SCons t ts) = SCons (d2ac t) (d2ace ts) + freezeRet :: Descr env sto -> Ret env sto t -> Ex (D1E env) (D2 t) -- the incoming cotangent value - -> Ex (D1E env) (TPair (D1 t) (TEVM (D2E (Select env sto "accum")) (Tup (D2E (Select env sto "merge"))))) + -> Ex (Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge")))) freezeRet descr (Ret e0 e1 sub e2) d = - let e2' = weakenExpr (WCopy (wRaiseAbove (bindingsBinds e0) (sD1eEnv descr))) e2 - in letBinds e0 $ + let (e0', wInsertD2Ac) = weakenBindings weakenExpr (wSinks (d2ace (select SAccum descr))) e0 + e2' = weakenExpr (WCopy (wCopies (bindingsBinds e0) (wRaiseAbove (d2ace (select SAccum descr)) (sD1eEnv descr)))) e2 + in letBinds e0' $ EPair ext - e1 - (ELet ext (weakenExpr (sinkWithBindings e0) d) - (EMBind e2' - (EMReturn (d2e (select SAccum descr)) - (expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ))))) - -d2e :: SList STy env -> SList STy (D2E env) -d2e SNil = SNil -d2e (SCons t ts) = SCons (d2 t) (d2e ts) + (weakenExpr wInsertD2Ac e1) + (ELet ext (weakenExpr (sinkWithBindings e0 .> wSinks (d2ace (select SAccum descr))) d) $ + ELet ext e2' $ + expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ)) drev :: forall env sto t. Descr env sto - -> (forall env' sto' t'. Descr env' sto' -> STy t' -> Some Storage) -> Ex env t -> Ret env sto t -drev des policy = \case +drev des = \case EVar _ t i -> case conv2Idx des i of - Left accumI -> + Left _ -> Ret BTop (EVar ext (d1 t) (conv1Idx i)) (subenvNone (select SMerge des)) - (EMOne d2acc accumI (EVar ext (d2 t) IZ)) + (ENil ext) Right tupI -> Ret BTop (EVar ext (d1 t) (conv1Idx i)) (subenvOnehot (select SMerge des) tupI) - (EMReturn d2acc (EPair ext (ENil ext) (EVar ext (d2 t) IZ))) + (EPair ext (ENil ext) (EVar ext (d2 t) IZ)) ELet _ rhs body - | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des policy rhs - , Some storage <- policy des (typeOf rhs) - , Ret (body0 :: Bindings _ _ body_shbinds) body1 subBody body2 <- drev (des `DPush` (typeOf rhs, storage)) policy body + | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des rhs + , Ret (body0 :: Bindings _ _ body_shbinds) body1 subBody body2 <- drev (des `DPush` (typeOf rhs, SMerge)) body , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env) + , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D2AcE (Select env sto "accum")) , Refl <- lemAppendNil @body_shbinds -> - unscope des (typeOf rhs) storage subBody body2 $ \subBody' body2' -> + popFromScope des (typeOf rhs) subBody body2 $ \subBody' body2' -> subenvPlus (select SMerge des) subRHS subBody' $ \subBoth _ _ plus_RHS_Body -> let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody'))) (d2 (typeOf rhs)) in Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0') (weakenExpr wbody0' body1) subBoth - (EMBind - (weakenExpr (WCopy (wRaiseAbove (bindingsBinds body0) (SCons (typeOf rhs1) (bindingsBinds rhs0)))) body2') - (EMBind - (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ - weakenExpr (WCopy (wSinks' @[_,_] .> WPop @d1_a (sinkWithBindings body0'))) rhs2) - (EMReturn d2acc (plus_RHS_Body - (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ) - (EFst ext (EVar ext bodyResType (IS IZ))))))) + (ELet ext + (weakenExpr (WCopy (wStack @(D2AcE (Select env sto "accum")) (wRaiseAbove (bindingsBinds body0) (SCons (typeOf rhs1) (bindingsBinds rhs0))))) + body2') $ + ELet ext + (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ + weakenExpr (WCopy (wSinks' @[_,_] .> WPop @d1_a (sinkWithBindings body0'))) rhs2) $ + plus_RHS_Body + (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ) + (EFst ext (EVar ext bodyResType (IS IZ)))) EPair _ a b | Rets binds (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) - <- retConcat $ drev des policy a `SCons` drev des policy b `SCons` SNil + <- retConcat $ drev des a `SCons` drev des b `SCons` SNil , let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) -> subenvPlus (select SMerge des) subA subB $ \subBoth _ _ plus_A_B -> Ret binds (EPair ext a1 b1) subBoth (ECase ext (EVar ext (STEither STNil (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ) - (EMReturn d2acc (zeroTup (subList (select SMerge des) subBoth))) - (EMBind (ELet ext (EFst ext (EVar ext dt IZ)) + (zeroTup (subList (select SMerge des) subBoth)) + (ELet ext (ELet ext (EFst ext (EVar ext dt IZ)) (weakenExpr (WCopy (wSinks' @[_,_])) a2)) $ - EMBind (ELet ext (ESnd ext (EVar ext dt (IS IZ))) + ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ))) (weakenExpr (WCopy (wSinks' @[_,_,_])) b2)) $ - EMReturn d2acc - (plus_A_B - (EVar ext (tTup (d2e (subList (select SMerge des) subA))) (IS IZ)) - (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ)))) + plus_A_B + (EVar ext (tTup (d2e (subList (select SMerge des) subA))) (IS IZ)) + (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ))) EFst _ e - | Ret e0 e1 sub e2 <- drev des policy e + | Ret e0 e1 sub e2 <- drev des e , STPair t1 t2 <- typeOf e -> Ret e0 (EFst ext e1) @@ -699,7 +683,7 @@ drev des policy = \case weakenExpr (WCopy WSink) e2) ESnd _ e - | Ret e0 e1 sub e2 <- drev des policy e + | Ret e0 e1 sub e2 <- drev des e , STPair t1 t2 <- typeOf e -> Ret e0 (ESnd ext e1) @@ -707,46 +691,47 @@ drev des policy = \case (ELet ext (EInr ext STNil (EPair ext (zero t1) (EVar ext (d2 t2) IZ))) $ weakenExpr (WCopy WSink) e2) - ENil _ -> Ret BTop (ENil ext) (subenvNone (select SMerge des)) (EMReturn d2acc (ENil ext)) + ENil _ -> Ret BTop (ENil ext) (subenvNone (select SMerge des)) (ENil ext) EInl _ t2 e - | Ret e0 e1 sub e2 <- drev des policy e -> + | Ret e0 e1 sub e2 <- drev des e -> Ret e0 (EInl ext (d1 t2) e1) sub (ECase ext (EVar ext (STEither STNil (STEither (d2 (typeOf e)) (d2 t2))) IZ) - (EMReturn d2acc (zeroTup (subList (select SMerge des) sub))) + (zeroTup (subList (select SMerge des) sub)) (ECase ext (EVar ext (STEither (d2 (typeOf e)) (d2 t2)) IZ) (weakenExpr (WCopy (wSinks' @[_,_])) e2) - (EError (STEVM d2acc (tTup (d2e (subList (select SMerge des) sub)))) "inl<-dinr"))) + (EError (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr"))) EInr _ t1 e - | Ret e0 e1 sub e2 <- drev des policy e -> + | Ret e0 e1 sub e2 <- drev des e -> Ret e0 (EInr ext (d1 t1) e1) sub (ECase ext (EVar ext (STEither STNil (STEither (d2 t1) (d2 (typeOf e)))) IZ) - (EMReturn d2acc (zeroTup (subList (select SMerge des) sub))) + (zeroTup (subList (select SMerge des) sub)) (ECase ext (EVar ext (STEither (d2 t1) (d2 (typeOf e))) IZ) - (EError (STEVM d2acc (tTup (d2e (subList (select SMerge des) sub)))) "inr<-dinl") + (EError (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl") (weakenExpr (WCopy (wSinks' @[_,_])) e2))) ECase _ e (a :: Ex _ t) b | STEither t1 t2 <- typeOf e - , Ret (e0 :: Bindings _ _ e_binds) e1 subE e2 <- drev des policy e - , Some storageA <- policy des t1 - , Some storageB <- policy des t2 - , Ret (a0 :: Bindings _ _ rhs_a_binds) a1 subA a2 <- drev (des `DPush` (t1, storageA)) policy a - , Ret (b0 :: Bindings _ _ rhs_b_binds) b1 subB b2 <- drev (des `DPush` (t2, storageB)) policy b + , Ret (e0 :: Bindings _ _ e_binds) e1 subE e2 <- drev des e + , Ret (a0 :: Bindings _ _ rhs_a_binds) a1 subA a2 <- drev (des `DPush` (t1, SMerge)) a + , Ret (b0 :: Bindings _ _ rhs_b_binds) b1 subB b2 <- drev (des `DPush` (t2, SMerge)) b + , Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) + , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) , let tapeA = tapeTy (bindingsBinds a0) , let tapeB = tapeTy (bindingsBinds b0) , let collectA = bindingsCollect a0 , let collectB = bindingsCollect b0 , (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB) , let (a0', wa0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) a0 - , let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0 -> - unscope des t1 storageA subA a2 $ \subA' a2' -> - unscope des t2 storageB subB b2 $ \subB' b2' -> + , let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0 + -> + popFromScope des t1 subA a2 $ \subA' a2' -> + popFromScope des t2 subB b2 $ \subB' b2' -> subenvPlus (select SMerge des) subA' subB' $ \subAB sAB_A sAB_B _ -> subenvPlus (select SMerge des) subAB subE $ \subOut _ _ plus_AB_E -> let tCaseRet = STPair (tTup (d2e (subList (select SMerge des) subAB))) (STEither (d2 t1) (d2 t2)) in @@ -757,43 +742,49 @@ drev des policy = \case (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0')))))) (EFst ext (EVar ext tPrimal IZ)) subOut - (EMBind + (ELet ext (ECase ext (ESnd ext (EVar ext tPrimal (IS IZ))) (let (rebinds, prerebinds) = reconstructBindings (bindingsBinds a0) IZ in letBinds rebinds $ - ELet ext (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_binds : D2 t : t_primal_ty : e_binds) (sappend (bindingsBinds a0) prerebinds) @> IS IZ)) $ - EMBind (weakenExpr (WCopy (wRaiseAbove (sappend (bindingsBinds a0) prerebinds) (tapeA `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0) .> wRaiseAbove (bindingsBinds a0) prerebinds)) a2') - (EMReturn d2acc - (EPair ext - (expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $ - EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ)) - (EInl ext (d2 t2) - (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ)))))) + ELet ext + (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_binds : D2 t : t_primal_ty : Append e_binds (D2AcE (Select env sto "accum"))) (sappend (bindingsBinds a0) prerebinds) @> IS IZ)) $ + ELet ext + (weakenExpr (wStack @(D2AcE (Select env sto "accum")) $ + WCopy (wRaiseAbove (sappend (bindingsBinds a0) prerebinds) (tapeA `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0) .> wRaiseAbove (bindingsBinds a0) prerebinds)) + a2') $ + EPair ext + (expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $ + EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ)) + (EInl ext (d2 t2) + (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ)))) (let (rebinds, prerebinds) = reconstructBindings (bindingsBinds b0) IZ in letBinds rebinds $ - ELet ext (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_binds : D2 t : t_primal_ty : e_binds) (sappend (bindingsBinds b0) prerebinds) @> IS IZ)) $ - EMBind (weakenExpr (WCopy (wRaiseAbove (sappend (bindingsBinds b0) prerebinds) (tapeB `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0) .> wRaiseAbove (bindingsBinds b0) prerebinds)) b2') - (EMReturn d2acc - (EPair ext - (expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $ - EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB'))) (d2 t2)) IZ)) - (EInr ext (d2 t1) - (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB'))) (d2 t2)) IZ))))))) - (EMBind (ELet ext (EInr ext STNil (ESnd ext (EVar ext tCaseRet IZ))) $ - weakenExpr (WCopy (wSinks' @[_,_,_])) e2) $ - EMReturn d2acc $ - plus_AB_E - (EFst ext (EVar ext tCaseRet (IS IZ))) - (EVar ext (tTup (d2e (subList (select SMerge des) subE))) IZ))) + ELet ext + (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_binds : D2 t : t_primal_ty : Append e_binds (D2AcE (Select env sto "accum"))) (sappend (bindingsBinds b0) prerebinds) @> IS IZ)) $ + ELet ext + (weakenExpr (wStack @(D2AcE (Select env sto "accum")) $ + WCopy (wRaiseAbove (sappend (bindingsBinds b0) prerebinds) (tapeB `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0) .> wRaiseAbove (bindingsBinds b0) prerebinds)) + b2') $ + EPair ext + (expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $ + EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB'))) (d2 t2)) IZ)) + (EInr ext (d2 t1) + (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB'))) (d2 t2)) IZ))))) $ + ELet ext + (ELet ext (EInr ext STNil (ESnd ext (EVar ext tCaseRet IZ))) $ + weakenExpr (WCopy (wSinks' @[_,_,_])) e2) $ + plus_AB_E + (EFst ext (EVar ext tCaseRet (IS IZ))) + (EVar ext (tTup (d2e (subList (select SMerge des) subE))) IZ)) EConst _ t val -> Ret BTop (EConst ext t val) (subenvNone (select SMerge des)) - (EMReturn d2acc (ENil ext)) + (ENil ext) EOp _ op e - | Ret e0 e1 sub e2 <- drev des policy e -> + | Ret e0 e1 sub e2 <- drev des e -> case d2op op of Linear d2opfun -> Ret e0 @@ -813,7 +804,7 @@ drev des policy = \case Ret BTop (EError (d1 t) s) (subenvNone (select SMerge des)) - (EMReturn d2acc (ENil ext)) + (ENil ext) -- These should be the next to be implemented, I think EBuild1{} -> err_unsupported "EBuild1" @@ -823,13 +814,9 @@ drev des policy = \case EBuild{} -> err_unsupported "EBuild" EIdx{} -> err_unsupported "EIdx" - EMOne{} -> err_evm - EMScope{} -> err_evm - EMReturn{} -> err_evm - EMBind{} -> err_evm + EWith{} -> err_accum + EAccum{} -> err_accum where - d2acc = d2e (select SAccum des) - - err_evm = error "EVM operations unsupported in the source program" + err_accum = error "Accumulator operations unsupported in the source program" err_unsupported s = error $ "CHAD: unsupported " ++ s |