diff options
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r-- | src/CHAD.hs | 211 |
1 files changed, 200 insertions, 11 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index f01ab9e..cd4445e 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -7,6 +7,8 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} module CHAD where import Data.Functor.Const @@ -18,15 +20,16 @@ type Ex = Expr (Const ()) data Bindings f env env' where BTop :: Bindings f env env - BPush :: Bindings f env env' -> f env' t -> Bindings f env (t : env') + BPush :: Bindings f env env' -> (STy t, f env' t) -> Bindings f env (t : env') deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env') +infixl `BPush` weakenBindings :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t) -> env1 :> env2 -> Bindings f env1 env' -> (forall env2'. Bindings f env2 env2' -> env' :> env2' -> r) -> r weakenBindings _ w BTop k = k BTop w -weakenBindings wf w (BPush b x) k = - weakenBindings wf w b $ \b' w' -> k (BPush b' (wf w' x)) (WCopy w') +weakenBindings wf w (BPush b (t, x)) k = + weakenBindings wf w b $ \b' w' -> k (BPush b' (t, wf w' x)) (WCopy w') sinkWithBindings :: Bindings f env env' -> env :> env' sinkWithBindings BTop = WId @@ -34,22 +37,59 @@ sinkWithBindings (BPush b _) = WSink .> sinkWithBindings b bconcat :: Bindings f env1 env2 -> Bindings f env2 env3 -> Bindings f env1 env3 bconcat b1 BTop = b1 -bconcat b1 (BPush b2 x) = BPush (bconcat b1 b2) x +bconcat b1 (BPush b2 (t, x)) = BPush (bconcat b1 b2) (t, x) bconcat' :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t) -> Bindings f env env1 -> Bindings f env env2 -> (forall env12. Bindings f env env12 -> r) -> r bconcat' wf b1 b2 k = weakenBindings wf (sinkWithBindings b1) b2 $ \b2' _ -> k (bconcat b1 b2') +bsnoc :: STy t -> f env t -> Bindings f (t : env) env' -> Bindings f env env' +bsnoc t x b = bconcat (BTop `BPush` (t, x)) b + +data TupBindsReconstruct f env1 env2 env3 = + forall env4. + TupBindsReconstruct (Bindings f env3 env4) + (env2 :> env4) + +data TupBinds f env1 env2 = + forall tape. + TupBinds (STy tape) + (forall env2'. env2 :> env2' -> Ex env2' tape) + (forall env3. env1 :> env3 -> Idx env3 tape -> TupBindsReconstruct f env1 env2 env3) + +tupBinds :: Bindings Ex env1 env2 -> TupBinds Ex env1 env2 +tupBinds BTop = TupBinds STNil (\_ -> ENil ext) (\w _ -> TupBindsReconstruct BTop w) +tupBinds (BPush binds (t, _)) + | TupBinds tape collect recon <- tupBinds binds + = TupBinds (STPair tape t) + (\w -> EPair ext (collect (w .> WSink)) + (EVar ext t (w @> IZ))) + (\w tapeidx -> + case recon (WSink .> w) IZ of + TupBindsReconstruct rebinds wunder -> + let rebinds1 = bsnoc tape (EFst ext (EVar ext (STPair tape t) tapeidx)) rebinds + in TupBindsReconstruct + (rebinds1 `BPush` + (t, ESnd ext (EVar ext (STPair tape t) + (sinkWithBindings rebinds1 @> tapeidx)))) + (WCopy wunder)) + +letBinds :: Bindings Ex env env' -> Ex env' t -> Ex env t +letBinds BTop = id +letBinds (BPush b (_, rhs)) = letBinds b . ELet ext rhs + type family D1 t where D1 TNil = TNil D1 (TPair a b) = TPair (D1 a) (D1 b) + D1 (TEither a b) = TEither (D1 a) (D1 b) D1 (TArr n t) = TArr n (D1 t) D1 (TScal t) = TScal t type family D2 t where D2 TNil = TNil - D2 (TPair a b) = TPair (D2 a) (D2 b) + D2 (TPair a b) = TEither TNil (TPair (D2 a) (D2 b)) + D2 (TEither a b) = TEither TNil (TEither (D2 a) (D2 b)) -- D2 (TArr n t) = _ D2 (TScal t) = D2s t @@ -71,13 +111,15 @@ type family D2E env where d1 :: STy t -> STy (D1 t) d1 STNil = STNil d1 (STPair a b) = STPair (d1 a) (d1 b) +d1 (STEither a b) = STEither (d1 a) (d1 b) d1 (STArr n t) = STArr n (d1 t) d1 (STScal t) = STScal t d1 STEVM{} = error "EVM not allowed in input program" d2 :: STy t -> STy (D2 t) d2 STNil = STNil -d2 (STPair a b) = STPair (d2 a) (d2 b) +d2 (STPair a b) = STEither STNil (STPair (d2 a) (d2 b)) +d2 (STEither a b) = STEither STNil (STEither (d2 a) (d2 b)) d2 STArr{} = error "TODO arrays" d2 (STScal t) = case t of STI32 -> STNil @@ -87,6 +129,10 @@ d2 (STScal t) = case t of STBool -> STNil d2 STEVM{} = error "EVM not allowed in input program" +d2e :: SList STy list -> SList STy (D2E list) +d2e SNil = SNil +d2e (SCons t list) = SCons (d2 t) (d2e list) + d2list :: SList STy env -> SList STy (D2E env) d2list SNil = SNil d2list (SCons x l) = SCons (d2 x) (d2list l) @@ -99,6 +145,19 @@ conv2Idx :: Idx env t -> Idx (D2E env) (D2 t) conv2Idx IZ = IZ conv2Idx (IS i) = IS (conv2Idx i) +zero :: STy t -> Ex env (D2 t) +zero STNil = ENil ext +zero (STPair t1 t2) = EInl ext (STPair (d2 t1) (d2 t2)) (ENil ext) +zero (STEither t1 t2) = EInl ext (STEither (d2 t1) (d2 t2)) (ENil ext) +zero STArr{} = error "TODO arrays" +zero (STScal t) = case t of + STI32 -> ENil ext + STI64 -> ENil ext + STF32 -> EConst ext STF32 0.0 + STF64 -> EConst ext STF64 0.0 + STBool -> ENil ext +zero STEVM{} = error "EVM not allowed in input program" + data Ret env t = forall env'. Ret (Bindings Ex (D1E env) env') @@ -106,21 +165,151 @@ data Ret env t = (Ex (D2 t : env') (TEVM (D2E env) TNil)) deriving instance Show (Ret env t) +data RetPair env0 env t = + RetPair (Ex env (D1 t)) + (Ex (D2 t : env) (TEVM (D2E env0) TNil)) + deriving (Show) + +data Rets env0 env list = + forall env'. + Rets (Bindings Ex env env') + (SList (RetPair env0 env') list) +deriving instance Show (Rets env0 env list) + +-- d1W :: env :> env' -> D1E env :> D1E env' +-- d1W WId = WId +-- d1W WSink = WSink +-- d1W (WCopy w) = WCopy (d1W w) +-- d1W (WPop w) = WPop (d1W w) +-- d1W (WThen u w) = WThen (d1W u) (d1W w) + +weakenRetPair :: env :> env' -> RetPair env0 env t -> RetPair env0 env' t +weakenRetPair w (RetPair e1 e2) = RetPair (weakenExpr w e1) (weakenExpr (WCopy w) e2) + +weakenRets :: env :> env' -> Rets env0 env list -> Rets env0 env' list +weakenRets w (Rets binds list) = + weakenBindings weakenExpr w binds $ \binds' wbinds' -> + Rets binds' (slistMap (weakenRetPair wbinds') list) + +retConcat :: forall env list. SList (Ret env) list -> Rets env (D1E env) list +retConcat SNil = Rets BTop SNil +retConcat (SCons (Ret (b :: Bindings Ex (D1E env) env2) p d) list) + | Rets binds pairs <- weakenRets (sinkWithBindings b) (retConcat list) + = Rets (bconcat b binds) + (SCons (RetPair (weakenExpr (sinkWithBindings binds) p) + (weakenExpr (WCopy (sinkWithBindings binds)) d)) + pairs) + drev :: SList STy env -> Ex env t -> Ret env t drev senv = \case EVar _ t i -> Ret BTop (EVar ext (d1 t) (conv1Idx i)) (EMOne (d2list senv) (conv2Idx i) (EVar ext (d2 t) IZ)) + ELet _ rhs body | Ret rhs0 rhs1 rhs2 <- drev senv rhs , Ret body0 body1 body2 <- drev (SCons (typeOf rhs) senv) body -> weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 $ \body0' wbody0' -> - Ret (bconcat (BPush rhs0 rhs1) body0') + Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0') (weakenExpr wbody0' body1) (EMBind (EMScope (weakenExpr (WCopy wbody0') body2)) - (ELet ext (ESnd ext STNil (EVar ext (STPair STNil (d2 (typeOf rhs))) IZ)) $ - weakenExpr (WCopy (WSink .> WSink .> WPop (sinkWithBindings body0'))) rhs2)) + (ELet ext (ESnd ext (EVar ext (STPair STNil (d2 (typeOf rhs))) IZ)) $ + weakenExpr (WCopy (wSinks @[_,_] .> WPop (sinkWithBindings body0'))) rhs2)) + + EPair _ a b + | Rets binds (RetPair a1 a2 `SCons` RetPair b1 b2 `SCons` SNil) + <- retConcat $ drev senv a `SCons` drev senv b `SCons` SNil + , let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) -> + Ret binds + (EPair ext a1 b1) + (ECase ext (EVar ext (STEither STNil (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ) + (EMReturn (d2e senv) (ENil ext)) + (EMBind (ELet ext (EFst ext (EVar ext dt IZ)) + (weakenExpr (WCopy (wSinks @[_,_])) a2)) + (ELet ext (ESnd ext (EVar ext dt (IS IZ))) + (weakenExpr (WCopy (wSinks @[_,_,_])) b2)))) + + EFst _ e + | Ret e0 e1 e2 <- drev senv e + , STPair t1 t2 <- typeOf e -> + Ret e0 + (EFst ext e1) + (ELet ext (EInr ext STNil (EPair ext (EVar ext (d2 t1) IZ) (zero t2))) $ + weakenExpr (WCopy WSink) e2) + + ESnd _ e + | Ret e0 e1 e2 <- drev senv e + , STPair t1 t2 <- typeOf e -> + Ret e0 + (ESnd ext e1) + (ELet ext (EInr ext STNil (EPair ext (zero t1) (EVar ext (d2 t2) IZ))) $ + weakenExpr (WCopy WSink) e2) + + ENil _ -> Ret BTop (ENil ext) (EMReturn (d2e senv) (ENil ext)) + + EInl _ t2 e + | Ret e0 e1 e2 <- drev senv e -> + Ret e0 + (EInl ext (d1 t2) e1) + (ECase ext (EVar ext (STEither STNil (STEither (d2 (typeOf e)) (d2 t2))) IZ) + (EMReturn (d2e senv) (ENil ext)) + (ECase ext (EVar ext (STEither (d2 (typeOf e)) (d2 t2)) IZ) + (weakenExpr (WCopy (wSinks @[_,_])) e2) + (EError (STEVM (d2e senv) STNil) "inl<-dinr"))) + + EInr _ t1 e + | Ret e0 e1 e2 <- drev senv e -> + Ret e0 + (EInr ext (d1 t1) e1) + (ECase ext (EVar ext (STEither STNil (STEither (d2 t1) (d2 (typeOf e)))) IZ) + (EMReturn (d2e senv) (ENil ext)) + (ECase ext (EVar ext (STEither (d2 t1) (d2 (typeOf e))) IZ) + (EError (STEVM (d2e senv) STNil) "inr<-dinl") + (weakenExpr (WCopy (wSinks @[_,_])) e2))) + + ECase _ e a b + | STEither t1 t2 <- typeOf e + , Ret e0 e1 e2 <- drev senv e + , Ret a0 a1 a2 <- drev (SCons t1 senv) a + , Ret b0 b1 b2 <- drev (SCons t2 senv) b + , TupBinds tapeA collectA reconA <- tupBinds a0 + , TupBinds tapeB collectB reconB <- tupBinds b0 + , let tPrimal = STPair (d1 (typeOf a)) (STEither tapeA tapeB) -> + weakenBindings weakenExpr (WCopy (WSink .> sinkWithBindings e0)) a0 $ \a0' wa0' -> + weakenBindings weakenExpr (WCopy (WSink .> sinkWithBindings e0)) b0 $ \b0' wb0' -> + Ret (e0 `BPush` + (d1 (typeOf e), e1) `BPush` + (tPrimal, + ECase ext (EVar ext (d1 (typeOf e)) IZ) + (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0')))) + (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0')))))) + (EFst ext (EVar ext tPrimal IZ)) + (EMBind + (ECase ext (EVar ext (STEither (d1 t1) (d1 t2)) (IS (IS IZ))) + (ECase ext (ESnd ext (EVar ext tPrimal (IS (IS IZ)))) + (case reconA (WSink .> WCopy (wSinks @[_,_,_] .> sinkWithBindings e0)) IZ of + TupBindsReconstruct rebinds wrebinds -> + letBinds rebinds $ + ELet ext (EVar ext (d2 (typeOf a)) (sinkWithBindings rebinds @> IS (IS IZ))) $ + EMBind (weakenExpr (WCopy wrebinds) (EMScope a2)) + (EMReturn (d2e senv) + (EInr ext STNil (EInl ext (d2 t2) + (ESnd ext (EVar ext (STPair STNil (d2 t1)) IZ)))))) + (EError (STEVM (d2e senv) (STEither STNil (STEither (d2 t1) (d2 t2)))) "dcase l/rtape")) + (ECase ext (ESnd ext (EVar ext tPrimal (IS (IS IZ)))) + (EError (STEVM (d2e senv) (STEither STNil (STEither (d2 t1) (d2 t2)))) "dcase r/ltape") + (case reconB (WSink .> WCopy (wSinks @[_,_,_] .> sinkWithBindings e0)) IZ of + TupBindsReconstruct rebinds wrebinds -> + letBinds rebinds $ + ELet ext (EVar ext (d2 (typeOf a)) (sinkWithBindings rebinds @> IS (IS IZ))) $ + EMBind (weakenExpr (WCopy wrebinds) (EMScope b2)) + (EMReturn (d2e senv) + (EInr ext STNil (EInr ext (d2 t1) + (ESnd ext (EVar ext (STPair STNil (d2 t2)) IZ)))))))) + (weakenExpr (WCopy (wSinks @[_,_,_])) e2)) + _ -> undefined - where - ext = Const () + +ext :: Const () a +ext = Const () |