diff options
author | Tom Smeding <tom@tomsmeding.com> | 2023-09-21 23:57:20 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2023-09-21 23:57:38 +0200 |
commit | 3266269f4636a491f74ccf72b02db7cbb5acf26c (patch) | |
tree | ace7ee902c01c8dd2e081afa28913399ce5da31d | |
parent | 302ca6fdb6d0a3ed764a99a3f42829a5a012b258 (diff) |
WIP in merge mode only return free variables
The code typechecks and may well work, but is untested.
-rw-r--r-- | src/AST.hs | 8 | ||||
-rw-r--r-- | src/CHAD.hs | 421 | ||||
-rw-r--r-- | src/Example.hs | 2 |
3 files changed, 368 insertions, 63 deletions
@@ -12,6 +12,7 @@ {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveFoldable #-} {-# LANGUAGE DeriveTraversable #-} +{-# LANGUAGE EmptyCase #-} module AST (module AST, module AST.Weaken) where import Data.Functor.Const @@ -76,7 +77,7 @@ deriving instance Show (SScalTy t) type TIx = TScal TI64 -type Idx :: [Ty] -> Ty -> Type +type Idx :: [k] -> k -> Type data Idx env t where IZ :: Idx (t : env) t IS :: Idx env t -> Idx (a : env) t @@ -276,6 +277,11 @@ slistMap :: (forall t. f t -> g t) -> SList f list -> SList g list slistMap _ SNil = SNil slistMap f (SCons x list) = SCons (f x) (slistMap f list) +slistIdx :: SList f list -> Idx list t -> f t +slistIdx (SCons x _) IZ = x +slistIdx (SCons _ list) (IS i) = slistIdx list i +slistIdx SNil i = case i of {} + idx2int :: Idx env t -> Int idx2int IZ = 0 idx2int (IS n) = 1 + idx2int n diff --git a/src/CHAD.hs b/src/CHAD.hs index b074470..9e1f038 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -11,10 +11,12 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE EmptyCase #-} {-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE PartialTypeSignatures #-} module CHAD where import Data.Bifunctor (first, second) import Data.Kind (Type) +import Data.Proxy import Data.Some import GHC.TypeLits (Symbol) @@ -231,17 +233,26 @@ plusTup env0@(SCons t env) a b = (plus t (ESnd ext (EVar ext (tTup (d2e env0)) (IS IZ))) (ESnd ext (EVar ext (tTup (d2e env0)) IZ))) -data Ret env sto t = - forall env'. - Ret (Bindings Ex (D1E env) env') - (Ex env' (D1 t)) - (Ex (D2 t : env') (TEVM (D2E (Select env sto "accum")) (Tup (D2E (Select env sto "merge"))))) -deriving instance Show (Ret env sto t) +data Subenv env env' where + SETop :: Subenv '[] '[] + SEYes :: Subenv env env' -> Subenv (t : env) (t : env') + SENo :: Subenv env env' -> Subenv (t : env) env' +deriving instance Show (Subenv env env') + +data Ret env0 sto t = + forall env env0F. + Ret (Bindings Ex (D1E env0) env) + (Ex env (D1 t)) + (Subenv (Select env0 sto "merge") env0F) + (Ex (D2 t : env) (TEVM (D2E (Select env0 sto "accum")) (Tup (D2E env0F)))) +deriving instance Show (Ret env0 sto t) data RetPair env0 sto env t = + forall env0F. RetPair (Ex env (D1 t)) - (Ex (D2 t : env) (TEVM (D2E (Select env0 sto "accum")) (Tup (D2E (Select env0 sto "merge"))))) - deriving (Show) + (Subenv (Select env0 sto "merge") env0F) + (Ex (D2 t : env) (TEVM (D2E (Select env0 sto "accum")) (Tup (D2E env0F)))) +deriving instance Show (RetPair env0 sto env t) data Rets env0 sto env list = forall env'. @@ -249,6 +260,39 @@ data Rets env0 sto env list = (SList (RetPair env0 sto env') list) deriving instance Show (Rets env0 sto env list) +subList :: SList f env -> Subenv env env' -> SList f env' +subList SNil SETop = SNil +subList (SCons x xs) (SEYes sub) = SCons x (subList xs sub) +subList (SCons _ xs) (SENo sub) = subList xs sub + +subenvNone :: SList STy env -> Subenv env '[] +subenvNone SNil = SETop +subenvNone (SCons _ env) = SENo (subenvNone env) + +subenvOnehot :: SList STy env -> Idx env t -> Subenv env '[t] +subenvOnehot (SCons _ env) IZ = SEYes (subenvNone env) +subenvOnehot (SCons _ env) (IS i) = SENo (subenvOnehot env i) +subenvOnehot SNil i = case i of {} + +subenvUnion :: Subenv env env1 -> Subenv env env2 -> (forall env3. Subenv env env3 -> Subenv env3 env1 -> Subenv env3 env2 -> r) -> r +subenvUnion SETop SETop k = k SETop SETop SETop +subenvUnion (SENo sub1) (SENo sub2) k = + subenvUnion sub1 sub2 $ \sub3 s31 s32 -> k (SENo sub3) s31 s32 +subenvUnion (SEYes sub1) (SENo sub2) k = + subenvUnion sub1 sub2 $ \sub3 s31 s32 -> k (SEYes sub3) (SEYes s31) (SENo s32) +subenvUnion (SENo sub1) (SEYes sub2) k = + subenvUnion sub1 sub2 $ \sub3 s31 s32 -> k (SEYes sub3) (SENo s31) (SEYes s32) +subenvUnion (SEYes sub1) (SEYes sub2) k = + subenvUnion sub1 sub2 $ \sub3 s31 s32 -> k (SEYes sub3) (SEYes s31) (SEYes s32) + +expandSubenvZeros :: SList STy env0 -> Subenv env0 env0F -> Ex env (Tup (D2E env0F)) -> Ex env (Tup (D2E env0)) +expandSubenvZeros _ SETop _ = ENil ext +expandSubenvZeros (SCons t ts) (SEYes sub) e = + ELet ext e $ + let var = EVar ext (STPair (tTup (d2e (subList ts sub))) (d2 t)) IZ + in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var) +expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (zero t) + -- d1W :: env :> env' -> D1E env :> D1E env' -- d1W WId = WId -- d1W WSink = WSink @@ -257,7 +301,7 @@ deriving instance Show (Rets env0 sto env list) -- d1W (WThen u w) = WThen (d1W u) (d1W w) weakenRetPair :: env :> env' -> RetPair env0 sto env t -> RetPair env0 sto env' t -weakenRetPair w (RetPair e1 e2) = RetPair (weakenExpr w e1) (weakenExpr (WCopy w) e2) +weakenRetPair w (RetPair e1 sub e2) = RetPair (weakenExpr w e1) sub (weakenExpr (WCopy w) e2) weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list weakenRets w (Rets binds list) = @@ -266,10 +310,11 @@ weakenRets w (Rets binds list) = retConcat :: forall env sto list. SList (Ret env sto) list -> Rets env sto (D1E env) list retConcat SNil = Rets BTop SNil -retConcat (SCons (Ret (b :: Bindings Ex (D1E env) env2) p d) list) +retConcat (SCons (Ret (b :: Bindings Ex (D1E env) env2) p sub d) list) | Rets binds pairs <- weakenRets (sinkWithBindings b) (retConcat list) = Rets (bconcat b binds) (SCons (RetPair (weakenExpr (sinkWithBindings binds) p) + sub (weakenExpr (WCopy (sinkWithBindings binds)) d)) pairs) @@ -320,11 +365,6 @@ d2op op = case op of STF64 -> float STBool -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) -freezeRet :: Ret env sto t - -> (forall env'. Ex env' (D2 t)) -- the incoming cotangent value - -> Ex (D1E env) (TPair (D1 t) (TEVM (D2E (Select env sto "accum")) (Tup (D2E (Select env sto "merge"))))) -freezeRet (Ret e0 e1 e2) d = letBinds e0 $ EPair ext e1 (ELet ext d e2) - type Storage :: Symbol -> Type data Storage s where SAccum :: Storage "accum" -- ^ in the monad state as a mutable accumulator @@ -344,108 +384,176 @@ select s@SMerge (DPush des (_, SAccum)) = select s des select s@SAccum (DPush des (_, SMerge)) = select s des select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des) +freezeRet :: Descr env sto + -> Ret env sto t + -> (forall env'. Ex env' (D2 t)) -- the incoming cotangent value + -> Ex (D1E env) (TPair (D1 t) (TEVM (D2E (Select env sto "accum")) (Tup (D2E (Select env sto "merge"))))) +freezeRet descr (Ret e0 e1 sub e2) d = + letBinds e0 $ + EPair ext + e1 + (ELet ext d + (EMBind e2 + (EMReturn (d2e (select SAccum descr)) + (expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ))))) + d2e :: SList STy env -> SList STy (D2E env) d2e SNil = SNil d2e (SCons t ts) = SCons (d2 t) (d2e ts) -drev :: Descr env sto +drev :: forall env sto t. + Descr env sto -> (forall env' sto' t'. Descr env' sto' -> STy t' -> Some Storage) -> Ex env t -> Ret env sto t drev des policy = \case EVar _ t i -> - Ret BTop - (EVar ext (d1 t) (conv1Idx i)) - (case conv2Idx des i of - Left accumI -> - EMBind - (EMOne d2acc accumI (EVar ext (d2 t) IZ)) - (EMReturn d2acc (zeroTup (select SMerge des))) - Right tupI -> - EMReturn d2acc (onehotTup (select SMerge des) tupI (EVar ext (d2 t) IZ))) + case conv2Idx des i of + Left accumI -> + Ret BTop + (EVar ext (d1 t) (conv1Idx i)) + (subenvNone (select SMerge des)) + (EMOne d2acc accumI (EVar ext (d2 t) IZ)) + + Right tupI -> + Ret BTop + (EVar ext (d1 t) (conv1Idx i)) + (subenvOnehot (select SMerge des) tupI) + (EMReturn d2acc (EPair ext (ENil ext) (EVar ext (d2 t) IZ))) ELet _ rhs body - | Ret rhs0 rhs1 rhs2 <- drev des policy rhs + | Ret rhs0 rhs1 subRHS rhs2 <- drev des policy rhs , Some storage <- policy des (typeOf rhs) - , Ret body0 body1 body2 <- drev (des `DPush` (typeOf rhs, storage)) policy body -> + , Ret body0 body1 subBody body2 <- drev (des `DPush` (typeOf rhs, storage)) policy body -> weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 $ \body0' wbody0' -> - Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0') - (weakenExpr wbody0' body1) - (EMBind - (weakenExpr (WCopy wbody0') $ case storage of SAccum -> EMScope body2 ; SMerge -> body2) - (ELet ext (ESnd ext (EVar ext (STPair (tTup (d2e (select SMerge des))) (d2 (typeOf rhs))) IZ)) $ - EMBind - (weakenExpr (WCopy (wSinks @[_,_] .> WPop (sinkWithBindings body0'))) rhs2) - (EMReturn d2acc (plusTup (select SMerge des) - (EFst ext (EVar ext (STPair (tTup (d2e (select SMerge des))) (d2 (typeOf rhs))) (IS (IS IZ)))) - (EVar ext (tTup (d2e (select SMerge des))) IZ))))) + case storage of + SAccum -> + subenvUnion subRHS subBody $ \subBoth subRHS2Both subBody2Both -> + let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody))) (d2 (typeOf rhs)) in + Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0') + (weakenExpr wbody0' body1) + subBoth + (EMBind + (weakenExpr (WCopy wbody0') (EMScope body2)) + (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ + EMBind + (weakenExpr (WCopy (wSinks @[_,_] .> WPop (sinkWithBindings body0'))) rhs2) + (EMReturn d2acc (plusTup (subList (select SMerge des) subBoth) + (expandSubenvZeros + (subList (select SMerge des) subBoth) + subBody2Both + (EFst ext (EVar ext bodyResType (IS (IS IZ))))) + (expandSubenvZeros + (subList (select SMerge des) subBoth) + subRHS2Both + (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ)))))) + + SMerge -> case subBody of -- is the let-bound variable used in the body? + SENo subBody' -> -- it isn't, so the RHS was dead code? Let's not differentiate through the RHS then, in any case + Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0') + (weakenExpr wbody0' body1) + subBody' + (weakenExpr (WCopy wbody0') body2) -- we have no cotangent for the RHS, nothing to pass on, so it's just this + SEYes subBody' -> + subenvUnion subRHS subBody' $ \subBoth subRHS2Both subBody2Both -> + let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody'))) (d2 (typeOf rhs)) in + Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0') + (weakenExpr wbody0' body1) + subBoth + (EMBind + (weakenExpr (WCopy wbody0') body2) + (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ + EMBind + (weakenExpr (WCopy (wSinks @[_,_] .> WPop (sinkWithBindings body0'))) rhs2) + (EMReturn d2acc (plusTup (subList (select SMerge des) subBoth) + (expandSubenvZeros + (subList (select SMerge des) subBoth) + subBody2Both + (EFst ext (EVar ext bodyResType (IS (IS IZ))))) + (expandSubenvZeros + (subList (select SMerge des) subBoth) + subRHS2Both + (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ)))))) EPair _ a b - | Rets binds (RetPair a1 a2 `SCons` RetPair b1 b2 `SCons` SNil) + | Rets binds (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) <- retConcat $ drev des policy a `SCons` drev des policy b `SCons` SNil , let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) -> + subenvUnion subA subB $ \subBoth subA2Both subB2Both -> Ret binds (EPair ext a1 b1) + subBoth (ECase ext (EVar ext (STEither STNil (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ) - (EMReturn d2acc (zeroTup (select SMerge des))) + (EMReturn d2acc (zeroTup (subList (select SMerge des) subBoth))) (EMBind (ELet ext (EFst ext (EVar ext dt IZ)) (weakenExpr (WCopy (wSinks @[_,_])) a2)) $ EMBind (ELet ext (ESnd ext (EVar ext dt (IS IZ))) (weakenExpr (WCopy (wSinks @[_,_,_])) b2)) $ EMReturn d2acc - (plusTup (select SMerge des) - (EVar ext (tTup (d2e (select SMerge des))) (IS IZ)) - (EVar ext (tTup (d2e (select SMerge des))) IZ)))) + (plusTup (subList (select SMerge des) subBoth) + (expandSubenvZeros + (subList (select SMerge des) subBoth) + subA2Both + (EVar ext (tTup (d2e (subList (select SMerge des) subA))) (IS IZ))) + (expandSubenvZeros + (subList (select SMerge des) subBoth) + subB2Both + (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ))))) EFst _ e - | Ret e0 e1 e2 <- drev des policy e + | Ret e0 e1 sub e2 <- drev des policy e , STPair t1 t2 <- typeOf e -> Ret e0 (EFst ext e1) + sub (ELet ext (EInr ext STNil (EPair ext (EVar ext (d2 t1) IZ) (zero t2))) $ weakenExpr (WCopy WSink) e2) ESnd _ e - | Ret e0 e1 e2 <- drev des policy e + | Ret e0 e1 sub e2 <- drev des policy e , STPair t1 t2 <- typeOf e -> Ret e0 (ESnd ext e1) + sub (ELet ext (EInr ext STNil (EPair ext (zero t1) (EVar ext (d2 t2) IZ))) $ weakenExpr (WCopy WSink) e2) - ENil _ -> Ret BTop (ENil ext) (EMReturn d2acc (zeroTup (select SMerge des))) + ENil _ -> Ret BTop (ENil ext) (subenvNone (select SMerge des)) (EMReturn d2acc (ENil ext)) EInl _ t2 e - | Ret e0 e1 e2 <- drev des policy e -> + | Ret e0 e1 sub e2 <- drev des policy e -> Ret e0 (EInl ext (d1 t2) e1) + sub (ECase ext (EVar ext (STEither STNil (STEither (d2 (typeOf e)) (d2 t2))) IZ) - (EMReturn d2acc (zeroTup (select SMerge des))) + (EMReturn d2acc (zeroTup (subList (select SMerge des) sub))) (ECase ext (EVar ext (STEither (d2 (typeOf e)) (d2 t2)) IZ) (weakenExpr (WCopy (wSinks @[_,_])) e2) - (EError (STEVM d2acc (tTup (d2e (select SMerge des)))) "inl<-dinr"))) + (EError (STEVM d2acc (tTup (d2e (subList (select SMerge des) sub)))) "inl<-dinr"))) EInr _ t1 e - | Ret e0 e1 e2 <- drev des policy e -> + | Ret e0 e1 sub e2 <- drev des policy e -> Ret e0 (EInr ext (d1 t1) e1) + sub (ECase ext (EVar ext (STEither STNil (STEither (d2 t1) (d2 (typeOf e)))) IZ) - (EMReturn d2acc (zeroTup (select SMerge des))) + (EMReturn d2acc (zeroTup (subList (select SMerge des) sub))) (ECase ext (EVar ext (STEither (d2 t1) (d2 (typeOf e))) IZ) - (EError (STEVM d2acc (tTup (d2e (select SMerge des)))) "inr<-dinl") + (EError (STEVM d2acc (tTup (d2e (subList (select SMerge des) sub)))) "inr<-dinl") (weakenExpr (WCopy (wSinks @[_,_])) e2))) ECase _ e a b | STEither t1 t2 <- typeOf e - , Ret e0 e1 e2 <- drev des policy e + , Ret e0 e1 subE e2 <- drev des policy e , Some storageA <- policy des t1 , Some storageB <- policy des t2 - , Ret a0 a1 a2 <- drev (des `DPush` (t1, storageA)) policy a - , Ret b0 b1 b2 <- drev (des `DPush` (t2, storageB)) policy b + , Ret a0 a1 subA a2 <- drev (des `DPush` (t1, storageA)) policy a + , Ret b0 b1 subB b2 <- drev (des `DPush` (t2, storageB)) policy b , TupBinds tapeA collectA reconA <- tupBinds a0 , TupBinds tapeB collectB reconB <- tupBinds b0 , let tPrimal = STPair (d1 (typeOf a)) (STEither tapeA tapeB) -> weakenBindings weakenExpr (WCopy (WSink .> sinkWithBindings e0)) a0 $ \a0' wa0' -> weakenBindings weakenExpr (WCopy (WSink .> sinkWithBindings e0)) b0 $ \b0' wb0' -> + caseOutSubenv des t1 storageA t2 storageB Proxy Proxy subE subA subB $ \subOut subOutE expandA2 expandB2 -> Ret (e0 `BPush` (d1 (typeOf e), e1) `BPush` (tPrimal, @@ -453,6 +561,7 @@ drev des policy = \case (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0')))) (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0')))))) (EFst ext (EVar ext tPrimal IZ)) + subOut (EMBind (ECase ext (EVar ext (STEither (d1 t1) (d1 t2)) (IS (IS IZ))) (ECase ext (ESnd ext (EVar ext tPrimal (IS (IS IZ)))) @@ -460,10 +569,10 @@ drev des policy = \case TupBindsReconstruct rebinds wrebinds -> letBinds rebinds $ ELet ext (EVar ext (d2 (typeOf a)) (sinkWithBindings rebinds @> IS (IS IZ))) $ - EMBind (weakenExpr (WCopy wrebinds) $ case storageA of SAccum -> EMScope a2 ; SMerge -> a2) + EMBind (weakenExpr (WCopy wrebinds) $ expandA2 a2) (EMReturn d2acc (EInr ext STNil (EInl ext (d2 t2) - (ESnd ext (EVar ext (STPair (tTup (d2e (select SMerge des))) (d2 t1)) IZ)))))) + (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subOut))) (d2 t1)) IZ)))))) (EError (STEVM d2acc (STEither STNil (STEither (d2 t1) (d2 t2)))) "dcase l/rtape")) (ECase ext (ESnd ext (EVar ext tPrimal (IS (IS IZ)))) (EError (STEVM d2acc (STEither STNil (STEither (d2 t1) (d2 t2)))) "dcase r/ltape") @@ -471,28 +580,36 @@ drev des policy = \case TupBindsReconstruct rebinds wrebinds -> letBinds rebinds $ ELet ext (EVar ext (d2 (typeOf a)) (sinkWithBindings rebinds @> IS (IS IZ))) $ - EMBind (weakenExpr (WCopy wrebinds) $ case storageB of SAccum -> EMScope b2 ; SMerge -> b2) + EMBind (weakenExpr (WCopy wrebinds) $ expandB2 b2) (EMReturn d2acc (EInr ext STNil (EInr ext (d2 t1) - (ESnd ext (EVar ext (STPair (tTup (d2e (select SMerge des))) (d2 t2)) IZ)))))))) - (weakenExpr (WCopy (wSinks @[_,_,_])) e2)) + (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subOut))) (d2 t2)) IZ)))))))) + (weakenExpr (WCopy (wSinks @[_,_,_])) $ + EMBind e2 $ + EMReturn d2acc (expandSubenvZeros + (subList (select SMerge des) subOut) + subOutE + (EVar ext (tTup (d2e (subList (select SMerge des) subE))) IZ)))) EConst _ t val -> Ret BTop (EConst ext t val) - (EMReturn d2acc (zeroTup (select SMerge des))) + (subenvNone (select SMerge des)) + (EMReturn d2acc (ENil ext)) EOp _ op e - | Ret e0 e1 e2 <- drev des policy e -> + | Ret e0 e1 sub e2 <- drev des policy e -> case d2op op of Linear d2opfun -> Ret e0 (d1op op e1) + sub (ELet ext (d2opfun (EVar ext (d2 (opt2 op)) IZ)) (weakenExpr (WCopy WSink) e2)) Nonlinear d2opfun -> Ret (e0 `BPush` (d1 (typeOf e), e1)) (d1op op $ EVar ext (d1 (typeOf e)) IZ) + sub (ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ)) (EVar ext (d2 (opt2 op)) IZ)) (weakenExpr (WCopy (wSinks @[_,_])) e2)) @@ -501,3 +618,185 @@ drev des policy = \case where d2acc = d2e (select SAccum des) + +caseOutSubenv + :: Descr env sto + -> STy a1 -> Storage s1 + -> STy a2 -> Storage s2 + -> Proxy exenv2 -> Proxy exenv3 + -> Subenv (Select env sto "merge") envE + -> Subenv (Select (a1 : env) (s1 : sto) "merge") env1 + -> Subenv (Select (a2 : env) (s2 : sto) "merge") env2 + -> (forall envOut. + Subenv (Select env sto "merge") envOut + -> Subenv envOut envE + -> (Ex exenv2 (TEVM (D2E (Select (a1 : env) (s1 : sto) "accum")) (Tup (D2E env1))) + -> Ex exenv2 (TEVM (D2E (Select env sto "accum")) (TPair (Tup (D2E envOut)) (D2 a1)))) + -> (Ex exenv3 (TEVM (D2E (Select (a2 : env) (s2 : sto) "accum")) (Tup (D2E env2))) + -> Ex exenv3 (TEVM (D2E (Select env sto "accum")) (TPair (Tup (D2E envOut)) (D2 a2)))) + -> r) + -> r +caseOutSubenv des t1 s1 t2 s2 _ _ subE sub1 sub2 k = + case (s1, sub1, s2, sub2) of + (SAccum, _, SAccum, _) -> + subenvUnion sub1 sub2 $ \sub3 s31 s32 -> + subenvUnion subE sub3 $ \sub4 s4E s43 -> + k sub4 s4E + (\e1 -> EMBind (EMScope e1) $ + EMReturn (d2e (select SAccum des)) $ + let t = STPair (tTup (d2e (subList (select SMerge des) sub1))) (d2 t1) in + EPair ext (expandSubenvZeros (subList (select SMerge des) sub4) s43 $ + expandSubenvZeros (subList (select SMerge des) sub3) s31 $ + EFst ext (EVar ext t IZ)) + (ESnd ext (EVar ext t IZ))) + (\e2 -> EMBind (EMScope e2) $ + EMReturn (d2e (select SAccum des)) $ + let t = STPair (tTup (d2e (subList (select SMerge des) sub2))) (d2 t2) in + EPair ext (expandSubenvZeros (subList (select SMerge des) sub4) s43 $ + expandSubenvZeros (subList (select SMerge des) sub3) s32 $ + EFst ext (EVar ext t IZ)) + (ESnd ext (EVar ext t IZ))) + (SAccum, _, SMerge, SEYes sub2') -> + subenvUnion sub1 sub2' $ \sub3 s31 s32 -> + subenvUnion subE sub3 $ \sub4 s4E s43 -> + k sub4 s4E + (\e1 -> EMBind (EMScope e1) $ + EMReturn (d2e (select SAccum des)) $ + let t = STPair (tTup (d2e (subList (select SMerge des) sub1))) (d2 t1) in + EPair ext (expandSubenvZeros (subList (select SMerge des) sub4) s43 $ + expandSubenvZeros (subList (select SMerge des) sub3) s31 $ + EFst ext (EVar ext t IZ)) + (ESnd ext (EVar ext t IZ))) + (\e2 -> EMBind e2 $ + EMReturn (d2e (select SAccum des)) $ + let t = STPair (tTup (d2e (subList (select SMerge des) sub2'))) (d2 t2) in + EPair ext (expandSubenvZeros (subList (select SMerge des) sub4) s43 $ + expandSubenvZeros (subList (select SMerge des) sub3) s32 $ + EFst ext (EVar ext t IZ)) + (ESnd ext (EVar ext t IZ))) + (SAccum, _, SMerge, SENo sub2') -> + subenvUnion sub1 sub2' $ \sub3 s31 s32 -> + subenvUnion subE sub3 $ \sub4 s4E s43 -> + k sub4 s4E + (\e1 -> EMBind (EMScope e1) $ + EMReturn (d2e (select SAccum des)) $ + let t = STPair (tTup (d2e (subList (select SMerge des) sub1))) (d2 t1) in + EPair ext (expandSubenvZeros (subList (select SMerge des) sub4) s43 $ + expandSubenvZeros (subList (select SMerge des) sub3) s31 $ + EFst ext (EVar ext t IZ)) + (ESnd ext (EVar ext t IZ))) + (\e2 -> EMBind e2 $ + EMReturn (d2e (select SAccum des)) $ + let t = tTup (d2e (subList (select SMerge des) sub2')) in + EPair ext (expandSubenvZeros (subList (select SMerge des) sub4) s43 $ + expandSubenvZeros (subList (select SMerge des) sub3) s32 $ + EVar ext t IZ) + (zero t2)) + (SMerge, SEYes sub1', SAccum, _) -> + subenvUnion sub1' sub2 $ \sub3 s31 s32 -> + subenvUnion subE sub3 $ \sub4 s4E s43 -> + k sub4 s4E + (\e1 -> EMBind e1 $ + EMReturn (d2e (select SAccum des)) $ + let t = STPair (tTup (d2e (subList (select SMerge des) sub1'))) (d2 t1) in + EPair ext (expandSubenvZeros (subList (select SMerge des) sub4) s43 $ + expandSubenvZeros (subList (select SMerge des) sub3) s31 $ + EFst ext (EVar ext t IZ)) + (ESnd ext (EVar ext t IZ))) + (\e2 -> EMBind (EMScope e2) $ + EMReturn (d2e (select SAccum des)) $ + let t = STPair (tTup (d2e (subList (select SMerge des) sub2))) (d2 t2) in + EPair ext (expandSubenvZeros (subList (select SMerge des) sub4) s43 $ + expandSubenvZeros (subList (select SMerge des) sub3) s32 $ + EFst ext (EVar ext t IZ)) + (ESnd ext (EVar ext t IZ))) + (SMerge, SENo sub1', SAccum, _) -> + subenvUnion sub1' sub2 $ \sub3 s31 s32 -> + subenvUnion subE sub3 $ \sub4 s4E s43 -> + k sub4 s4E + (\e1 -> EMBind e1 $ + EMReturn (d2e (select SAccum des)) $ + let t = tTup (d2e (subList (select SMerge des) sub1')) in + EPair ext (expandSubenvZeros (subList (select SMerge des) sub4) s43 $ + expandSubenvZeros (subList (select SMerge des) sub3) s31 $ + EVar ext t IZ) + (zero t1)) + (\e2 -> EMBind (EMScope e2) $ + EMReturn (d2e (select SAccum des)) $ + let t = STPair (tTup (d2e (subList (select SMerge des) sub2))) (d2 t2) in + EPair ext (expandSubenvZeros (subList (select SMerge des) sub4) s43 $ + expandSubenvZeros (subList (select SMerge des) sub3) s32 $ + EFst ext (EVar ext t IZ)) + (ESnd ext (EVar ext t IZ))) + (SMerge, SEYes sub1', SMerge, SEYes sub2') -> + subenvUnion sub1' sub2' $ \sub3 s31 s32 -> + subenvUnion subE sub3 $ \sub4 s4E s43 -> + k sub4 s4E + (\e1 -> EMBind e1 $ + EMReturn (d2e (select SAccum des)) $ + let t = STPair (tTup (d2e (subList (select SMerge des) sub1'))) (d2 t1) in + EPair ext (expandSubenvZeros (subList (select SMerge des) sub4) s43 $ + expandSubenvZeros (subList (select SMerge des) sub3) s31 $ + EFst ext (EVar ext t IZ)) + (ESnd ext (EVar ext t IZ))) + (\e2 -> EMBind e2 $ + EMReturn (d2e (select SAccum des)) $ + let t = STPair (tTup (d2e (subList (select SMerge des) sub2'))) (d2 t2) in + EPair ext (expandSubenvZeros (subList (select SMerge des) sub4) s43 $ + expandSubenvZeros (subList (select SMerge des) sub3) s32 $ + EFst ext (EVar ext t IZ)) + (ESnd ext (EVar ext t IZ))) + (SMerge, SEYes sub1', SMerge, SENo sub2') -> + subenvUnion sub1' sub2' $ \sub3 s31 s32 -> + subenvUnion subE sub3 $ \sub4 s4E s43 -> + k sub4 s4E + (\e1 -> EMBind e1 $ + EMReturn (d2e (select SAccum des)) $ + let t = STPair (tTup (d2e (subList (select SMerge des) sub1'))) (d2 t1) in + EPair ext (expandSubenvZeros (subList (select SMerge des) sub4) s43 $ + expandSubenvZeros (subList (select SMerge des) sub3) s31 $ + EFst ext (EVar ext t IZ)) + (ESnd ext (EVar ext t IZ))) + (\e2 -> EMBind e2 $ + EMReturn (d2e (select SAccum des)) $ + let t = tTup (d2e (subList (select SMerge des) sub2')) in + EPair ext (expandSubenvZeros (subList (select SMerge des) sub4) s43 $ + expandSubenvZeros (subList (select SMerge des) sub3) s32 $ + EVar ext t IZ) + (zero t2)) + (SMerge, SENo sub1', SMerge, SEYes sub2') -> + subenvUnion sub1' sub2' $ \sub3 s31 s32 -> + subenvUnion subE sub3 $ \sub4 s4E s43 -> + k sub4 s4E + (\e1 -> EMBind e1 $ + EMReturn (d2e (select SAccum des)) $ + let t = tTup (d2e (subList (select SMerge des) sub1')) in + EPair ext (expandSubenvZeros (subList (select SMerge des) sub4) s43 $ + expandSubenvZeros (subList (select SMerge des) sub3) s31 $ + EVar ext t IZ) + (zero t1)) + (\e2 -> EMBind e2 $ + EMReturn (d2e (select SAccum des)) $ + let t = STPair (tTup (d2e (subList (select SMerge des) sub2'))) (d2 t2) in + EPair ext (expandSubenvZeros (subList (select SMerge des) sub4) s43 $ + expandSubenvZeros (subList (select SMerge des) sub3) s32 $ + EFst ext (EVar ext t IZ)) + (ESnd ext (EVar ext t IZ))) + (SMerge, SENo sub1', SMerge, SENo sub2') -> + subenvUnion sub1' sub2' $ \sub3 s31 s32 -> + subenvUnion subE sub3 $ \sub4 s4E s43 -> + k sub4 s4E + (\e1 -> EMBind e1 $ + EMReturn (d2e (select SAccum des)) $ + let t = tTup (d2e (subList (select SMerge des) sub1')) in + EPair ext (expandSubenvZeros (subList (select SMerge des) sub4) s43 $ + expandSubenvZeros (subList (select SMerge des) sub3) s31 $ + EVar ext t IZ) + (zero t1)) + (\e2 -> EMBind e2 $ + EMReturn (d2e (select SAccum des)) $ + let t = tTup (d2e (subList (select SMerge des) sub2')) in + EPair ext (expandSubenvZeros (subList (select SMerge des) sub4) s43 $ + expandSubenvZeros (subList (select SMerge des) sub3) s32 $ + EVar ext t IZ) + (zero t2)) diff --git a/src/Example.hs b/src/Example.hs index f3baedf..389248a 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -9,7 +9,7 @@ import CHAD import Simplify --- ppExpr senv5 $ simplifyN 20 $ freezeRet (drev (descr5 SAccum SAccum) (\_ _ -> Some SAccum) ex5) (EConst ext STF32 1.0) +-- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SAccum SAccum in freezeRet d (drev d (\_ _ -> Some SAccum) ex5) (EConst ext STF32 1.0) bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c |