summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-08-30 17:48:15 +0200
committerTom Smeding <tom@tomsmeding.com>2024-08-30 17:48:15 +0200
commit8b047ff11ebd4715647bfc041a190f72dcf4d5a9 (patch)
treee8440120b7bbd4e45b367acb3f7185d25e7f3766 /src/CHAD.hs
parentf4b94d7cc2cb05611b462ba278e4f12f7a7a5e5e (diff)
Migrate to accumulators (mostly removing EVM code)
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs281
1 files changed, 134 insertions, 147 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 2513f84..e209b67 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1,16 +1,16 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE EmptyCase #-}
-{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE UndecidableInstances #-}
-- I want to bring various type variables in scope using type annotations in
@@ -29,7 +29,6 @@ module CHAD (
import Data.Bifunctor (first, second)
import Data.Functor.Const
import Data.Kind (Type)
-import Data.Some
import GHC.TypeLits (Symbol)
import AST
@@ -254,7 +253,7 @@ type family D2 t where
D2 TNil = TNil
D2 (TPair a b) = TEither TNil (TPair (D2 a) (D2 b))
D2 (TEither a b) = TEither TNil (TEither (D2 a) (D2 b))
- -- D2 (TArr n t) = _
+ D2 (TArr n t) = TArr n (D2 t)
D2 (TScal t) = D2s t
type family D2s t where
@@ -264,6 +263,9 @@ type family D2s t where
D2s TF64 = TScal TF64
D2s TBool = TNil
+type family D2Ac t where
+ D2Ac (TArr n t) = TAccum n t
+
type family D1E env where
D1E '[] = '[]
D1E (t : env) = D1 t : D1E env
@@ -272,6 +274,10 @@ type family D2E env where
D2E '[] = '[]
D2E (t : env) = D2 t : D2E env
+type family D2AcE env where
+ D2AcE '[] = '[]
+ D2AcE (t : env) = D2Ac t : D2AcE env
+
-- | Select only the types from the environment that have the specified storage
type family Select env sto s where
Select '[] '[] _ = '[]
@@ -284,20 +290,24 @@ d1 (STPair a b) = STPair (d1 a) (d1 b)
d1 (STEither a b) = STEither (d1 a) (d1 b)
d1 (STArr n t) = STArr n (d1 t)
d1 (STScal t) = STScal t
-d1 STEVM{} = error "EVM not allowed in input program"
+d1 STAccum{} = error "Accumulators not allowed in input program"
d2 :: STy t -> STy (D2 t)
d2 STNil = STNil
d2 (STPair a b) = STEither STNil (STPair (d2 a) (d2 b))
d2 (STEither a b) = STEither STNil (STEither (d2 a) (d2 b))
-d2 STArr{} = error "TODO arrays"
+d2 (STArr n t) = STArr n (d2 t)
d2 (STScal t) = case t of
STI32 -> STNil
STI64 -> STNil
STF32 -> STScal STF32
STF64 -> STScal STF64
STBool -> STNil
-d2 STEVM{} = error "EVM not allowed in input program"
+d2 STAccum{} = error "Accumulators not allowed in input program"
+
+d2ac :: STy t -> STy (D2Ac t)
+d2ac (STArr n t) = STAccum n t
+d2ac _ = error "Only arrays may appear in the accumulator environment"
conv1Idx :: Idx env t -> Idx (D1E env) (D1 t)
conv1Idx IZ = IZ
@@ -322,7 +332,7 @@ zero (STScal t) = case t of
STF32 -> EConst ext STF32 0.0
STF64 -> EConst ext STF64 0.0
STBool -> ENil ext
-zero STEVM{} = error "EVM not allowed in input program"
+zero STAccum{} = error "Accumulators not allowed in input program"
plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t)
plus STNil _ _ = ENil ext
@@ -350,7 +360,7 @@ plus (STScal t) a b = case t of
STF32 -> EOp ext (OAdd STF32) (EPair ext a b)
STF64 -> EOp ext (OAdd STF64) (EPair ext a b)
STBool -> ENil ext
-plus STEVM{} _ _ = error "EVM not allowed in input program"
+plus STAccum{} _ _ = error "Accumulators not allowed in input program"
plusSparse :: STy a
-> Ex env (TEither TNil a) -> Ex env (TEither TNil a)
@@ -388,14 +398,14 @@ data Ret env0 sto t =
Ret (Bindings Ex (D1E env0) shbinds) -- shared binds
(Ex (Append shbinds (D1E env0)) (D1 t))
(Subenv (Select env0 sto "merge") env0Merge)
- (Ex (D2 t : shbinds) (TEVM (D2E (Select env0 sto "accum")) (Tup (D2E env0Merge))))
+ (Ex (D2 t : Append shbinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge)))
deriving instance Show (Ret env0 sto t)
data RetPair env0 sto env shbinds t =
forall env0Merge.
RetPair (Ex (Append shbinds env) (D1 t))
(Subenv (Select env0 sto "merge") env0Merge)
- (Ex (D2 t : shbinds) (TEVM (D2E (Select env0 sto "accum")) (Tup (D2E env0Merge))))
+ (Ex (D2 t : Append shbinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge)))
deriving instance Show (RetPair env0 sto env shbinds t)
data Rets env0 sto env list =
@@ -430,7 +440,8 @@ subenvPlus :: SList STy env
-> r
subenvPlus SNil SETop SETop k = k SETop SETop SETop (\_ _ -> ENil ext)
subenvPlus (SCons _ env) (SENo sub1) (SENo sub2) k =
- subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> k (SENo sub3) s31 s32 pl
+ subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
+ k (SENo sub3) s31 s32 pl
subenvPlus (SCons _ env) (SEYes sub1) (SENo sub2) k =
subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
k (SEYes sub3) (SEYes s31) (SENo s32) $ \e1 e2 ->
@@ -464,24 +475,19 @@ expandSubenvZeros (SCons t ts) (SEYes sub) e =
in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var)
expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (zero t)
-unscope :: Descr env0 sto
- -> STy a -> Storage s
- -> Subenv (Select (a : env0) (s : sto) "merge") envSub
- -> Ex env (TEVM (D2E (Select (a : env0) (s : sto) "accum")) (Tup (D2E envSub)))
- -> (forall envSub'.
- Subenv (Select env0 sto "merge") envSub'
- -> Ex env (TEVM (D2E (Select env0 sto "accum")) (TPair (Tup (D2E envSub')) (D2 a)))
- -> r)
- -> r
-unscope des ty s sub e k = case s of
- SAccum -> k sub (EMScope e)
- SMerge -> case sub of
- SEYes sub' -> k sub' e
- SENo sub' -> k sub' $
- EMBind e $
- EMReturn (d2e (select SAccum des)) $
- EPair ext (EVar ext (tTup (d2e (subList (select SMerge des) sub'))) IZ)
- (zero ty)
+popFromScope
+ :: Descr env0 sto
+ -> STy a
+ -> Subenv (Select (a : env0) ("merge" : sto) "merge") envSub
+ -> Ex env (Tup (D2E envSub))
+ -> (forall envSub'.
+ Subenv (Select env0 sto "merge") envSub'
+ -> Ex env (TPair (Tup (D2E envSub')) (D2 a))
+ -> r)
+ -> r
+popFromScope _ ty sub e k = case sub of
+ SEYes sub' -> k sub' e
+ SENo sub' -> k sub' $ EPair ext e (zero ty)
-- d1W :: env :> env' -> D1E env :> D1E env'
-- d1W WId = WId
@@ -501,7 +507,7 @@ weakenRets w (Rets binds list) =
rebaseRetPair :: forall env b1 b2 env0 sto t f. SList f b1 -> SList f b2 -> RetPair env0 sto (Append b1 env) b2 t -> RetPair env0 sto env (Append b2 b1) t
rebaseRetPair b1 b2 (RetPair p sub d)
| Refl <- lemAppendAssoc @b2 @b1 @env =
- RetPair p sub (weakenExpr (WCopy (wRaiseAbove b2 b1)) d)
+ RetPair p sub (weakenExpr (WCopy (wStack @(D2AcE (Select env0 sto "accum")) (wRaiseAbove b2 b1))) d)
retConcat :: forall env0 sto list. SList (Ret env0 sto) list -> Rets env0 sto (D1E env0) list
retConcat SNil = Rets BTop SNil
@@ -509,37 +515,12 @@ retConcat (SCons (Ret (b :: Bindings _ _ shbinds) p sub d) list)
| Rets binds1 pairs1 <- retConcat list
, Rets (binds :: Bindings _ _ shbinds2) pairs <- weakenRets (sinkWithBindings b) (Rets binds1 pairs1)
, Refl <- lemAppendAssoc @shbinds2 @shbinds @(D1E env0)
+ , Refl <- lemAppendAssoc @shbinds2 @shbinds @(D2AcE (Select env0 sto "accum"))
= Rets (bconcat b binds)
(SCons (RetPair (weakenExpr (sinkWithBindings binds) p)
sub
(weakenExpr (WCopy (sinkWithBindings binds)) d))
(slistMap (rebaseRetPair (bindingsBinds b) (bindingsBinds binds)) pairs))
--- list ~ a : list'
--- SCons (Ret b p sub d) list :: SList (Ret env0 sto) list
--- Ret b p sub d :: Ret env0 sto a <- existential shbinds
--- b :: Bindings Ex (D1E env0) shbinds
--- p :: Ex (Append shbinds (D1E env0)) (D1 a)
--- d :: Ex (D2 a : shbinds) (TEVM ...)
---
--- list :: SList (Ret env0 sto) list'
--- retConcat list :: Rets env0 sto (D1E env0) list' <- existential shbinds1
--- binds1 :: Bindings Ex (D1E env0) shbinds1
--- pairs1 :: SList (RetPair env0 sto (D1E env0) shbinds1) list'
---
--- sinkWithBindings b :: forall e. e :> Append shbinds e
--- Rets binds pairs :: Rets env0 sto (Append shbinds (D1E env0)) list' <- existential shbinds2
--- binds :: Bindings Ex (Append shbinds (D1E env0)) shbinds2
--- pairs :: SList (RetPair env0 sto (Append shbinds (D1E env0)) shbinds2) list'
---
--- we choose shbindsR ~ Append shbinds2 shbinds
--- result :: Rets env0 sto (D1E env0) list
--- result.1 :: Bindings Ex (D1E env0) shbindsR == Bindings Ex (D1E env0) (Append shbinds2 shbinds)
--- result.2 :: SList (RetPair env0 sto (D1E env0) shbindsR) list
--- result.2.head :: RetPair env0 sto (D1E env0) shbindsR a
--- result.2.tail :: SList (RetPair env0 sto (D1E env0) shbindsR) list'
--- = SList (RetPair env0 sto (D1E env0) (Append shbinds2 shbinds)) list'
---
--- wanted: shbinds1 :> shbindsR
d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t)
d1op (OAdd t) e = EOp ext (OAdd t) e
@@ -557,7 +538,7 @@ data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))
d2op :: SOp a t -> D2Op a t
d2op op = case op of
- OAdd _ -> Linear $ \d -> EInr ext STNil (EPair ext d d)
+ OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EInr ext STNil (EPair ext d d)
OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d ->
EInr ext STNil (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d))
(EOp ext (OMul t) (EPair ext (EFst ext e) d)))
@@ -611,86 +592,89 @@ sD1eEnv :: Descr env sto -> SList (Const ()) (D1E env)
sD1eEnv DTop = SNil
sD1eEnv (DPush d _) = SCons (Const ()) (sD1eEnv d)
+d2e :: SList STy env -> SList STy (D2E env)
+d2e SNil = SNil
+d2e (SCons t ts) = SCons (d2 t) (d2e ts)
+
+d2ace :: SList STy env -> SList STy (D2AcE env)
+d2ace SNil = SNil
+d2ace (SCons t ts) = SCons (d2ac t) (d2ace ts)
+
freezeRet :: Descr env sto
-> Ret env sto t
-> Ex (D1E env) (D2 t) -- the incoming cotangent value
- -> Ex (D1E env) (TPair (D1 t) (TEVM (D2E (Select env sto "accum")) (Tup (D2E (Select env sto "merge")))))
+ -> Ex (Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge"))))
freezeRet descr (Ret e0 e1 sub e2) d =
- let e2' = weakenExpr (WCopy (wRaiseAbove (bindingsBinds e0) (sD1eEnv descr))) e2
- in letBinds e0 $
+ let (e0', wInsertD2Ac) = weakenBindings weakenExpr (wSinks (d2ace (select SAccum descr))) e0
+ e2' = weakenExpr (WCopy (wCopies (bindingsBinds e0) (wRaiseAbove (d2ace (select SAccum descr)) (sD1eEnv descr)))) e2
+ in letBinds e0' $
EPair ext
- e1
- (ELet ext (weakenExpr (sinkWithBindings e0) d)
- (EMBind e2'
- (EMReturn (d2e (select SAccum descr))
- (expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ)))))
-
-d2e :: SList STy env -> SList STy (D2E env)
-d2e SNil = SNil
-d2e (SCons t ts) = SCons (d2 t) (d2e ts)
+ (weakenExpr wInsertD2Ac e1)
+ (ELet ext (weakenExpr (sinkWithBindings e0 .> wSinks (d2ace (select SAccum descr))) d) $
+ ELet ext e2' $
+ expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ))
drev :: forall env sto t.
Descr env sto
- -> (forall env' sto' t'. Descr env' sto' -> STy t' -> Some Storage)
-> Ex env t -> Ret env sto t
-drev des policy = \case
+drev des = \case
EVar _ t i ->
case conv2Idx des i of
- Left accumI ->
+ Left _ ->
Ret BTop
(EVar ext (d1 t) (conv1Idx i))
(subenvNone (select SMerge des))
- (EMOne d2acc accumI (EVar ext (d2 t) IZ))
+ (ENil ext)
Right tupI ->
Ret BTop
(EVar ext (d1 t) (conv1Idx i))
(subenvOnehot (select SMerge des) tupI)
- (EMReturn d2acc (EPair ext (ENil ext) (EVar ext (d2 t) IZ)))
+ (EPair ext (ENil ext) (EVar ext (d2 t) IZ))
ELet _ rhs body
- | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des policy rhs
- , Some storage <- policy des (typeOf rhs)
- , Ret (body0 :: Bindings _ _ body_shbinds) body1 subBody body2 <- drev (des `DPush` (typeOf rhs, storage)) policy body
+ | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des rhs
+ , Ret (body0 :: Bindings _ _ body_shbinds) body1 subBody body2 <- drev (des `DPush` (typeOf rhs, SMerge)) body
, let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0
, Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env)
+ , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D2AcE (Select env sto "accum"))
, Refl <- lemAppendNil @body_shbinds ->
- unscope des (typeOf rhs) storage subBody body2 $ \subBody' body2' ->
+ popFromScope des (typeOf rhs) subBody body2 $ \subBody' body2' ->
subenvPlus (select SMerge des) subRHS subBody' $ \subBoth _ _ plus_RHS_Body ->
let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody'))) (d2 (typeOf rhs)) in
Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0')
(weakenExpr wbody0' body1)
subBoth
- (EMBind
- (weakenExpr (WCopy (wRaiseAbove (bindingsBinds body0) (SCons (typeOf rhs1) (bindingsBinds rhs0)))) body2')
- (EMBind
- (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $
- weakenExpr (WCopy (wSinks' @[_,_] .> WPop @d1_a (sinkWithBindings body0'))) rhs2)
- (EMReturn d2acc (plus_RHS_Body
- (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ)
- (EFst ext (EVar ext bodyResType (IS IZ)))))))
+ (ELet ext
+ (weakenExpr (WCopy (wStack @(D2AcE (Select env sto "accum")) (wRaiseAbove (bindingsBinds body0) (SCons (typeOf rhs1) (bindingsBinds rhs0)))))
+ body2') $
+ ELet ext
+ (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $
+ weakenExpr (WCopy (wSinks' @[_,_] .> WPop @d1_a (sinkWithBindings body0'))) rhs2) $
+ plus_RHS_Body
+ (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ)
+ (EFst ext (EVar ext bodyResType (IS IZ))))
EPair _ a b
| Rets binds (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil)
- <- retConcat $ drev des policy a `SCons` drev des policy b `SCons` SNil
+ <- retConcat $ drev des a `SCons` drev des b `SCons` SNil
, let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) ->
subenvPlus (select SMerge des) subA subB $ \subBoth _ _ plus_A_B ->
Ret binds
(EPair ext a1 b1)
subBoth
(ECase ext (EVar ext (STEither STNil (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ)
- (EMReturn d2acc (zeroTup (subList (select SMerge des) subBoth)))
- (EMBind (ELet ext (EFst ext (EVar ext dt IZ))
+ (zeroTup (subList (select SMerge des) subBoth))
+ (ELet ext (ELet ext (EFst ext (EVar ext dt IZ))
(weakenExpr (WCopy (wSinks' @[_,_])) a2)) $
- EMBind (ELet ext (ESnd ext (EVar ext dt (IS IZ)))
+ ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ)))
(weakenExpr (WCopy (wSinks' @[_,_,_])) b2)) $
- EMReturn d2acc
- (plus_A_B
- (EVar ext (tTup (d2e (subList (select SMerge des) subA))) (IS IZ))
- (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ))))
+ plus_A_B
+ (EVar ext (tTup (d2e (subList (select SMerge des) subA))) (IS IZ))
+ (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ)))
EFst _ e
- | Ret e0 e1 sub e2 <- drev des policy e
+ | Ret e0 e1 sub e2 <- drev des e
, STPair t1 t2 <- typeOf e ->
Ret e0
(EFst ext e1)
@@ -699,7 +683,7 @@ drev des policy = \case
weakenExpr (WCopy WSink) e2)
ESnd _ e
- | Ret e0 e1 sub e2 <- drev des policy e
+ | Ret e0 e1 sub e2 <- drev des e
, STPair t1 t2 <- typeOf e ->
Ret e0
(ESnd ext e1)
@@ -707,46 +691,47 @@ drev des policy = \case
(ELet ext (EInr ext STNil (EPair ext (zero t1) (EVar ext (d2 t2) IZ))) $
weakenExpr (WCopy WSink) e2)
- ENil _ -> Ret BTop (ENil ext) (subenvNone (select SMerge des)) (EMReturn d2acc (ENil ext))
+ ENil _ -> Ret BTop (ENil ext) (subenvNone (select SMerge des)) (ENil ext)
EInl _ t2 e
- | Ret e0 e1 sub e2 <- drev des policy e ->
+ | Ret e0 e1 sub e2 <- drev des e ->
Ret e0
(EInl ext (d1 t2) e1)
sub
(ECase ext (EVar ext (STEither STNil (STEither (d2 (typeOf e)) (d2 t2))) IZ)
- (EMReturn d2acc (zeroTup (subList (select SMerge des) sub)))
+ (zeroTup (subList (select SMerge des) sub))
(ECase ext (EVar ext (STEither (d2 (typeOf e)) (d2 t2)) IZ)
(weakenExpr (WCopy (wSinks' @[_,_])) e2)
- (EError (STEVM d2acc (tTup (d2e (subList (select SMerge des) sub)))) "inl<-dinr")))
+ (EError (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr")))
EInr _ t1 e
- | Ret e0 e1 sub e2 <- drev des policy e ->
+ | Ret e0 e1 sub e2 <- drev des e ->
Ret e0
(EInr ext (d1 t1) e1)
sub
(ECase ext (EVar ext (STEither STNil (STEither (d2 t1) (d2 (typeOf e)))) IZ)
- (EMReturn d2acc (zeroTup (subList (select SMerge des) sub)))
+ (zeroTup (subList (select SMerge des) sub))
(ECase ext (EVar ext (STEither (d2 t1) (d2 (typeOf e))) IZ)
- (EError (STEVM d2acc (tTup (d2e (subList (select SMerge des) sub)))) "inr<-dinl")
+ (EError (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl")
(weakenExpr (WCopy (wSinks' @[_,_])) e2)))
ECase _ e (a :: Ex _ t) b
| STEither t1 t2 <- typeOf e
- , Ret (e0 :: Bindings _ _ e_binds) e1 subE e2 <- drev des policy e
- , Some storageA <- policy des t1
- , Some storageB <- policy des t2
- , Ret (a0 :: Bindings _ _ rhs_a_binds) a1 subA a2 <- drev (des `DPush` (t1, storageA)) policy a
- , Ret (b0 :: Bindings _ _ rhs_b_binds) b1 subB b2 <- drev (des `DPush` (t2, storageB)) policy b
+ , Ret (e0 :: Bindings _ _ e_binds) e1 subE e2 <- drev des e
+ , Ret (a0 :: Bindings _ _ rhs_a_binds) a1 subA a2 <- drev (des `DPush` (t1, SMerge)) a
+ , Ret (b0 :: Bindings _ _ rhs_b_binds) b1 subB b2 <- drev (des `DPush` (t2, SMerge)) b
+ , Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
+ , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
, let tapeA = tapeTy (bindingsBinds a0)
, let tapeB = tapeTy (bindingsBinds b0)
, let collectA = bindingsCollect a0
, let collectB = bindingsCollect b0
, (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB)
, let (a0', wa0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) a0
- , let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0 ->
- unscope des t1 storageA subA a2 $ \subA' a2' ->
- unscope des t2 storageB subB b2 $ \subB' b2' ->
+ , let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0
+ ->
+ popFromScope des t1 subA a2 $ \subA' a2' ->
+ popFromScope des t2 subB b2 $ \subB' b2' ->
subenvPlus (select SMerge des) subA' subB' $ \subAB sAB_A sAB_B _ ->
subenvPlus (select SMerge des) subAB subE $ \subOut _ _ plus_AB_E ->
let tCaseRet = STPair (tTup (d2e (subList (select SMerge des) subAB))) (STEither (d2 t1) (d2 t2)) in
@@ -757,43 +742,49 @@ drev des policy = \case
(letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0'))))))
(EFst ext (EVar ext tPrimal IZ))
subOut
- (EMBind
+ (ELet ext
(ECase ext (ESnd ext (EVar ext tPrimal (IS IZ)))
(let (rebinds, prerebinds) = reconstructBindings (bindingsBinds a0) IZ
in letBinds rebinds $
- ELet ext (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_binds : D2 t : t_primal_ty : e_binds) (sappend (bindingsBinds a0) prerebinds) @> IS IZ)) $
- EMBind (weakenExpr (WCopy (wRaiseAbove (sappend (bindingsBinds a0) prerebinds) (tapeA `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0) .> wRaiseAbove (bindingsBinds a0) prerebinds)) a2')
- (EMReturn d2acc
- (EPair ext
- (expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $
- EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ))
- (EInl ext (d2 t2)
- (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ))))))
+ ELet ext
+ (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_binds : D2 t : t_primal_ty : Append e_binds (D2AcE (Select env sto "accum"))) (sappend (bindingsBinds a0) prerebinds) @> IS IZ)) $
+ ELet ext
+ (weakenExpr (wStack @(D2AcE (Select env sto "accum")) $
+ WCopy (wRaiseAbove (sappend (bindingsBinds a0) prerebinds) (tapeA `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0) .> wRaiseAbove (bindingsBinds a0) prerebinds))
+ a2') $
+ EPair ext
+ (expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $
+ EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ))
+ (EInl ext (d2 t2)
+ (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ))))
(let (rebinds, prerebinds) = reconstructBindings (bindingsBinds b0) IZ
in letBinds rebinds $
- ELet ext (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_binds : D2 t : t_primal_ty : e_binds) (sappend (bindingsBinds b0) prerebinds) @> IS IZ)) $
- EMBind (weakenExpr (WCopy (wRaiseAbove (sappend (bindingsBinds b0) prerebinds) (tapeB `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0) .> wRaiseAbove (bindingsBinds b0) prerebinds)) b2')
- (EMReturn d2acc
- (EPair ext
- (expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $
- EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB'))) (d2 t2)) IZ))
- (EInr ext (d2 t1)
- (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB'))) (d2 t2)) IZ)))))))
- (EMBind (ELet ext (EInr ext STNil (ESnd ext (EVar ext tCaseRet IZ))) $
- weakenExpr (WCopy (wSinks' @[_,_,_])) e2) $
- EMReturn d2acc $
- plus_AB_E
- (EFst ext (EVar ext tCaseRet (IS IZ)))
- (EVar ext (tTup (d2e (subList (select SMerge des) subE))) IZ)))
+ ELet ext
+ (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_binds : D2 t : t_primal_ty : Append e_binds (D2AcE (Select env sto "accum"))) (sappend (bindingsBinds b0) prerebinds) @> IS IZ)) $
+ ELet ext
+ (weakenExpr (wStack @(D2AcE (Select env sto "accum")) $
+ WCopy (wRaiseAbove (sappend (bindingsBinds b0) prerebinds) (tapeB `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0) .> wRaiseAbove (bindingsBinds b0) prerebinds))
+ b2') $
+ EPair ext
+ (expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $
+ EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB'))) (d2 t2)) IZ))
+ (EInr ext (d2 t1)
+ (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB'))) (d2 t2)) IZ))))) $
+ ELet ext
+ (ELet ext (EInr ext STNil (ESnd ext (EVar ext tCaseRet IZ))) $
+ weakenExpr (WCopy (wSinks' @[_,_,_])) e2) $
+ plus_AB_E
+ (EFst ext (EVar ext tCaseRet (IS IZ)))
+ (EVar ext (tTup (d2e (subList (select SMerge des) subE))) IZ))
EConst _ t val ->
Ret BTop
(EConst ext t val)
(subenvNone (select SMerge des))
- (EMReturn d2acc (ENil ext))
+ (ENil ext)
EOp _ op e
- | Ret e0 e1 sub e2 <- drev des policy e ->
+ | Ret e0 e1 sub e2 <- drev des e ->
case d2op op of
Linear d2opfun ->
Ret e0
@@ -813,7 +804,7 @@ drev des policy = \case
Ret BTop
(EError (d1 t) s)
(subenvNone (select SMerge des))
- (EMReturn d2acc (ENil ext))
+ (ENil ext)
-- These should be the next to be implemented, I think
EBuild1{} -> err_unsupported "EBuild1"
@@ -823,13 +814,9 @@ drev des policy = \case
EBuild{} -> err_unsupported "EBuild"
EIdx{} -> err_unsupported "EIdx"
- EMOne{} -> err_evm
- EMScope{} -> err_evm
- EMReturn{} -> err_evm
- EMBind{} -> err_evm
+ EWith{} -> err_accum
+ EAccum{} -> err_accum
where
- d2acc = d2e (select SAccum des)
-
- err_evm = error "EVM operations unsupported in the source program"
+ err_accum = error "Accumulator operations unsupported in the source program"
err_unsupported s = error $ "CHAD: unsupported " ++ s