summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs211
1 files changed, 200 insertions, 11 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index f01ab9e..cd4445e 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -7,6 +7,8 @@
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
module CHAD where
import Data.Functor.Const
@@ -18,15 +20,16 @@ type Ex = Expr (Const ())
data Bindings f env env' where
BTop :: Bindings f env env
- BPush :: Bindings f env env' -> f env' t -> Bindings f env (t : env')
+ BPush :: Bindings f env env' -> (STy t, f env' t) -> Bindings f env (t : env')
deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env')
+infixl `BPush`
weakenBindings :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t)
-> env1 :> env2 -> Bindings f env1 env'
-> (forall env2'. Bindings f env2 env2' -> env' :> env2' -> r) -> r
weakenBindings _ w BTop k = k BTop w
-weakenBindings wf w (BPush b x) k =
- weakenBindings wf w b $ \b' w' -> k (BPush b' (wf w' x)) (WCopy w')
+weakenBindings wf w (BPush b (t, x)) k =
+ weakenBindings wf w b $ \b' w' -> k (BPush b' (t, wf w' x)) (WCopy w')
sinkWithBindings :: Bindings f env env' -> env :> env'
sinkWithBindings BTop = WId
@@ -34,22 +37,59 @@ sinkWithBindings (BPush b _) = WSink .> sinkWithBindings b
bconcat :: Bindings f env1 env2 -> Bindings f env2 env3 -> Bindings f env1 env3
bconcat b1 BTop = b1
-bconcat b1 (BPush b2 x) = BPush (bconcat b1 b2) x
+bconcat b1 (BPush b2 (t, x)) = BPush (bconcat b1 b2) (t, x)
bconcat' :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t)
-> Bindings f env env1 -> Bindings f env env2
-> (forall env12. Bindings f env env12 -> r) -> r
bconcat' wf b1 b2 k = weakenBindings wf (sinkWithBindings b1) b2 $ \b2' _ -> k (bconcat b1 b2')
+bsnoc :: STy t -> f env t -> Bindings f (t : env) env' -> Bindings f env env'
+bsnoc t x b = bconcat (BTop `BPush` (t, x)) b
+
+data TupBindsReconstruct f env1 env2 env3 =
+ forall env4.
+ TupBindsReconstruct (Bindings f env3 env4)
+ (env2 :> env4)
+
+data TupBinds f env1 env2 =
+ forall tape.
+ TupBinds (STy tape)
+ (forall env2'. env2 :> env2' -> Ex env2' tape)
+ (forall env3. env1 :> env3 -> Idx env3 tape -> TupBindsReconstruct f env1 env2 env3)
+
+tupBinds :: Bindings Ex env1 env2 -> TupBinds Ex env1 env2
+tupBinds BTop = TupBinds STNil (\_ -> ENil ext) (\w _ -> TupBindsReconstruct BTop w)
+tupBinds (BPush binds (t, _))
+ | TupBinds tape collect recon <- tupBinds binds
+ = TupBinds (STPair tape t)
+ (\w -> EPair ext (collect (w .> WSink))
+ (EVar ext t (w @> IZ)))
+ (\w tapeidx ->
+ case recon (WSink .> w) IZ of
+ TupBindsReconstruct rebinds wunder ->
+ let rebinds1 = bsnoc tape (EFst ext (EVar ext (STPair tape t) tapeidx)) rebinds
+ in TupBindsReconstruct
+ (rebinds1 `BPush`
+ (t, ESnd ext (EVar ext (STPair tape t)
+ (sinkWithBindings rebinds1 @> tapeidx))))
+ (WCopy wunder))
+
+letBinds :: Bindings Ex env env' -> Ex env' t -> Ex env t
+letBinds BTop = id
+letBinds (BPush b (_, rhs)) = letBinds b . ELet ext rhs
+
type family D1 t where
D1 TNil = TNil
D1 (TPair a b) = TPair (D1 a) (D1 b)
+ D1 (TEither a b) = TEither (D1 a) (D1 b)
D1 (TArr n t) = TArr n (D1 t)
D1 (TScal t) = TScal t
type family D2 t where
D2 TNil = TNil
- D2 (TPair a b) = TPair (D2 a) (D2 b)
+ D2 (TPair a b) = TEither TNil (TPair (D2 a) (D2 b))
+ D2 (TEither a b) = TEither TNil (TEither (D2 a) (D2 b))
-- D2 (TArr n t) = _
D2 (TScal t) = D2s t
@@ -71,13 +111,15 @@ type family D2E env where
d1 :: STy t -> STy (D1 t)
d1 STNil = STNil
d1 (STPair a b) = STPair (d1 a) (d1 b)
+d1 (STEither a b) = STEither (d1 a) (d1 b)
d1 (STArr n t) = STArr n (d1 t)
d1 (STScal t) = STScal t
d1 STEVM{} = error "EVM not allowed in input program"
d2 :: STy t -> STy (D2 t)
d2 STNil = STNil
-d2 (STPair a b) = STPair (d2 a) (d2 b)
+d2 (STPair a b) = STEither STNil (STPair (d2 a) (d2 b))
+d2 (STEither a b) = STEither STNil (STEither (d2 a) (d2 b))
d2 STArr{} = error "TODO arrays"
d2 (STScal t) = case t of
STI32 -> STNil
@@ -87,6 +129,10 @@ d2 (STScal t) = case t of
STBool -> STNil
d2 STEVM{} = error "EVM not allowed in input program"
+d2e :: SList STy list -> SList STy (D2E list)
+d2e SNil = SNil
+d2e (SCons t list) = SCons (d2 t) (d2e list)
+
d2list :: SList STy env -> SList STy (D2E env)
d2list SNil = SNil
d2list (SCons x l) = SCons (d2 x) (d2list l)
@@ -99,6 +145,19 @@ conv2Idx :: Idx env t -> Idx (D2E env) (D2 t)
conv2Idx IZ = IZ
conv2Idx (IS i) = IS (conv2Idx i)
+zero :: STy t -> Ex env (D2 t)
+zero STNil = ENil ext
+zero (STPair t1 t2) = EInl ext (STPair (d2 t1) (d2 t2)) (ENil ext)
+zero (STEither t1 t2) = EInl ext (STEither (d2 t1) (d2 t2)) (ENil ext)
+zero STArr{} = error "TODO arrays"
+zero (STScal t) = case t of
+ STI32 -> ENil ext
+ STI64 -> ENil ext
+ STF32 -> EConst ext STF32 0.0
+ STF64 -> EConst ext STF64 0.0
+ STBool -> ENil ext
+zero STEVM{} = error "EVM not allowed in input program"
+
data Ret env t =
forall env'.
Ret (Bindings Ex (D1E env) env')
@@ -106,21 +165,151 @@ data Ret env t =
(Ex (D2 t : env') (TEVM (D2E env) TNil))
deriving instance Show (Ret env t)
+data RetPair env0 env t =
+ RetPair (Ex env (D1 t))
+ (Ex (D2 t : env) (TEVM (D2E env0) TNil))
+ deriving (Show)
+
+data Rets env0 env list =
+ forall env'.
+ Rets (Bindings Ex env env')
+ (SList (RetPair env0 env') list)
+deriving instance Show (Rets env0 env list)
+
+-- d1W :: env :> env' -> D1E env :> D1E env'
+-- d1W WId = WId
+-- d1W WSink = WSink
+-- d1W (WCopy w) = WCopy (d1W w)
+-- d1W (WPop w) = WPop (d1W w)
+-- d1W (WThen u w) = WThen (d1W u) (d1W w)
+
+weakenRetPair :: env :> env' -> RetPair env0 env t -> RetPair env0 env' t
+weakenRetPair w (RetPair e1 e2) = RetPair (weakenExpr w e1) (weakenExpr (WCopy w) e2)
+
+weakenRets :: env :> env' -> Rets env0 env list -> Rets env0 env' list
+weakenRets w (Rets binds list) =
+ weakenBindings weakenExpr w binds $ \binds' wbinds' ->
+ Rets binds' (slistMap (weakenRetPair wbinds') list)
+
+retConcat :: forall env list. SList (Ret env) list -> Rets env (D1E env) list
+retConcat SNil = Rets BTop SNil
+retConcat (SCons (Ret (b :: Bindings Ex (D1E env) env2) p d) list)
+ | Rets binds pairs <- weakenRets (sinkWithBindings b) (retConcat list)
+ = Rets (bconcat b binds)
+ (SCons (RetPair (weakenExpr (sinkWithBindings binds) p)
+ (weakenExpr (WCopy (sinkWithBindings binds)) d))
+ pairs)
+
drev :: SList STy env -> Ex env t -> Ret env t
drev senv = \case
EVar _ t i ->
Ret BTop
(EVar ext (d1 t) (conv1Idx i))
(EMOne (d2list senv) (conv2Idx i) (EVar ext (d2 t) IZ))
+
ELet _ rhs body
| Ret rhs0 rhs1 rhs2 <- drev senv rhs
, Ret body0 body1 body2 <- drev (SCons (typeOf rhs) senv) body ->
weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 $ \body0' wbody0' ->
- Ret (bconcat (BPush rhs0 rhs1) body0')
+ Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0')
(weakenExpr wbody0' body1)
(EMBind (EMScope (weakenExpr (WCopy wbody0') body2))
- (ELet ext (ESnd ext STNil (EVar ext (STPair STNil (d2 (typeOf rhs))) IZ)) $
- weakenExpr (WCopy (WSink .> WSink .> WPop (sinkWithBindings body0'))) rhs2))
+ (ELet ext (ESnd ext (EVar ext (STPair STNil (d2 (typeOf rhs))) IZ)) $
+ weakenExpr (WCopy (wSinks @[_,_] .> WPop (sinkWithBindings body0'))) rhs2))
+
+ EPair _ a b
+ | Rets binds (RetPair a1 a2 `SCons` RetPair b1 b2 `SCons` SNil)
+ <- retConcat $ drev senv a `SCons` drev senv b `SCons` SNil
+ , let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) ->
+ Ret binds
+ (EPair ext a1 b1)
+ (ECase ext (EVar ext (STEither STNil (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ)
+ (EMReturn (d2e senv) (ENil ext))
+ (EMBind (ELet ext (EFst ext (EVar ext dt IZ))
+ (weakenExpr (WCopy (wSinks @[_,_])) a2))
+ (ELet ext (ESnd ext (EVar ext dt (IS IZ)))
+ (weakenExpr (WCopy (wSinks @[_,_,_])) b2))))
+
+ EFst _ e
+ | Ret e0 e1 e2 <- drev senv e
+ , STPair t1 t2 <- typeOf e ->
+ Ret e0
+ (EFst ext e1)
+ (ELet ext (EInr ext STNil (EPair ext (EVar ext (d2 t1) IZ) (zero t2))) $
+ weakenExpr (WCopy WSink) e2)
+
+ ESnd _ e
+ | Ret e0 e1 e2 <- drev senv e
+ , STPair t1 t2 <- typeOf e ->
+ Ret e0
+ (ESnd ext e1)
+ (ELet ext (EInr ext STNil (EPair ext (zero t1) (EVar ext (d2 t2) IZ))) $
+ weakenExpr (WCopy WSink) e2)
+
+ ENil _ -> Ret BTop (ENil ext) (EMReturn (d2e senv) (ENil ext))
+
+ EInl _ t2 e
+ | Ret e0 e1 e2 <- drev senv e ->
+ Ret e0
+ (EInl ext (d1 t2) e1)
+ (ECase ext (EVar ext (STEither STNil (STEither (d2 (typeOf e)) (d2 t2))) IZ)
+ (EMReturn (d2e senv) (ENil ext))
+ (ECase ext (EVar ext (STEither (d2 (typeOf e)) (d2 t2)) IZ)
+ (weakenExpr (WCopy (wSinks @[_,_])) e2)
+ (EError (STEVM (d2e senv) STNil) "inl<-dinr")))
+
+ EInr _ t1 e
+ | Ret e0 e1 e2 <- drev senv e ->
+ Ret e0
+ (EInr ext (d1 t1) e1)
+ (ECase ext (EVar ext (STEither STNil (STEither (d2 t1) (d2 (typeOf e)))) IZ)
+ (EMReturn (d2e senv) (ENil ext))
+ (ECase ext (EVar ext (STEither (d2 t1) (d2 (typeOf e))) IZ)
+ (EError (STEVM (d2e senv) STNil) "inr<-dinl")
+ (weakenExpr (WCopy (wSinks @[_,_])) e2)))
+
+ ECase _ e a b
+ | STEither t1 t2 <- typeOf e
+ , Ret e0 e1 e2 <- drev senv e
+ , Ret a0 a1 a2 <- drev (SCons t1 senv) a
+ , Ret b0 b1 b2 <- drev (SCons t2 senv) b
+ , TupBinds tapeA collectA reconA <- tupBinds a0
+ , TupBinds tapeB collectB reconB <- tupBinds b0
+ , let tPrimal = STPair (d1 (typeOf a)) (STEither tapeA tapeB) ->
+ weakenBindings weakenExpr (WCopy (WSink .> sinkWithBindings e0)) a0 $ \a0' wa0' ->
+ weakenBindings weakenExpr (WCopy (WSink .> sinkWithBindings e0)) b0 $ \b0' wb0' ->
+ Ret (e0 `BPush`
+ (d1 (typeOf e), e1) `BPush`
+ (tPrimal,
+ ECase ext (EVar ext (d1 (typeOf e)) IZ)
+ (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0'))))
+ (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0'))))))
+ (EFst ext (EVar ext tPrimal IZ))
+ (EMBind
+ (ECase ext (EVar ext (STEither (d1 t1) (d1 t2)) (IS (IS IZ)))
+ (ECase ext (ESnd ext (EVar ext tPrimal (IS (IS IZ))))
+ (case reconA (WSink .> WCopy (wSinks @[_,_,_] .> sinkWithBindings e0)) IZ of
+ TupBindsReconstruct rebinds wrebinds ->
+ letBinds rebinds $
+ ELet ext (EVar ext (d2 (typeOf a)) (sinkWithBindings rebinds @> IS (IS IZ))) $
+ EMBind (weakenExpr (WCopy wrebinds) (EMScope a2))
+ (EMReturn (d2e senv)
+ (EInr ext STNil (EInl ext (d2 t2)
+ (ESnd ext (EVar ext (STPair STNil (d2 t1)) IZ))))))
+ (EError (STEVM (d2e senv) (STEither STNil (STEither (d2 t1) (d2 t2)))) "dcase l/rtape"))
+ (ECase ext (ESnd ext (EVar ext tPrimal (IS (IS IZ))))
+ (EError (STEVM (d2e senv) (STEither STNil (STEither (d2 t1) (d2 t2)))) "dcase r/ltape")
+ (case reconB (WSink .> WCopy (wSinks @[_,_,_] .> sinkWithBindings e0)) IZ of
+ TupBindsReconstruct rebinds wrebinds ->
+ letBinds rebinds $
+ ELet ext (EVar ext (d2 (typeOf a)) (sinkWithBindings rebinds @> IS (IS IZ))) $
+ EMBind (weakenExpr (WCopy wrebinds) (EMScope b2))
+ (EMReturn (d2e senv)
+ (EInr ext STNil (EInr ext (d2 t1)
+ (ESnd ext (EVar ext (STPair STNil (d2 t2)) IZ))))))))
+ (weakenExpr (WCopy (wSinks @[_,_,_])) e2))
+
_ -> undefined
- where
- ext = Const ()
+
+ext :: Const () a
+ext = Const ()