{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-- | Source transformation 'drev', which splits an expression into a forward
-- pass (a chain of 'Bindings' plus the primal result) and a reverse pass that
-- consumes the cotangent of the result (see 'Ret').
module CHAD where

import Data.Functor.Const

import AST

-- | Expressions with the trivial annotation @Const ()@. Every node built in
-- this module receives @Const ()@ (bound as @ext@ in 'drev') as its first
-- field.
type Ex = Expr (Const ())

-- | A sequence of bindings that extends environment @env@ to @env'@.
--
-- Each 'BPush' appends one value of type @t@. That value is typed in the
-- environment produced by the preceding bindings, and @t@ is pushed onto the
-- front of the environment.
data Bindings f env env' where
  BTop  :: Bindings f env env
  BPush :: Bindings f env env' -> f env' t -> Bindings f env (t : env')

deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env')

-- | Rebase a sequence of bindings onto a larger base environment.
--
-- Every binding is weakened with the supplied per-binding weakening function.
-- The continuation receives:
--
--   * the rebased bindings;
--   * a weakening from the old extended environment @env'@ to the new one.
--
-- The new extended environment is existential, hence the continuation style.
weakenBindings
  :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t)
  -> env1 :> env2
  -> Bindings f env1 env'
  -> (forall env2'. Bindings f env2 env2' -> env' :> env2' -> r)
  -> r
weakenBindings _ w BTop k = k BTop w
weakenBindings wf w (BPush b x) k =
  weakenBindings wf w b $ \b' w' ->
    -- The top binding is weakened by the weakening of its own environment
    -- (w'). The resulting weakening must also cover the newly pushed
    -- variable, hence 'WCopy'.
    k (BPush b' (wf w' x)) (WCopy w')

-- | Weakening from a base environment to the environment extended by the
-- given bindings: one 'WSink' per 'BPush'.
sinkWithBindings :: Bindings f env env' -> env :> env'
sinkWithBindings BTop = WId
sinkWithBindings (BPush b _) = WSink .> sinkWithBindings b

-- | Append two binding sequences, where the second starts in the environment
-- the first one ends in.
bconcat :: Bindings f env1 env2 -> Bindings f env2 env3 -> Bindings f env1 env3
bconcat b1 BTop = b1
bconcat b1 (BPush b2 x) = BPush (bconcat b1 b2) x

-- | Concatenate two binding sequences that share the base @env@.
--
-- The second sequence is first weakened past the first (via
-- 'sinkWithBindings') and then appended with 'bconcat'. The combined
-- environment is existential, hence the continuation style.
bconcat'
  :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t)
  -> Bindings f env env1
  -> Bindings f env env2
  -> (forall env12. Bindings f env env12 -> r)
  -> r
bconcat' wf b1 b2 k =
  weakenBindings wf (sinkWithBindings b1) b2 $ \b2' _ ->
    k (bconcat b1 b2')

-- | Type computed by the forward pass for a value of type @t@. It has the
-- same structure as @t@. There is no equation for the monad type; compare the
-- 'STEVM' case of 'd1'.
type family D1 t where
  D1 TNil = TNil
  D1 (TPair a b) = TPair (D1 a) (D1 b)
  D1 (TArr n t) = TArr n (D1 t)
  D1 (TScal t) = TScal t

-- | Type of the cotangent that the reverse pass receives for a value of type
-- @t@ (see 'Ret').
type family D2 t where
  D2 TNil = TNil
  D2 (TPair a b) = TPair (D2 a) (D2 b)
  -- TODO: arrays have no cotangent type yet; 'd2' reports "TODO arrays".
  -- D2 (TArr n t) = _
  D2 (TScal t) = D2s t

-- | Cotangent type of a scalar. Floating-point scalars map to themselves;
-- integral and boolean scalars map to 'TNil'.
type family D2s t where
  D2s TI32 = TNil
  D2s TI64 = TNil
  D2s TF32 = TScal TF32
  D2s TF64 = TScal TF64
  D2s TBool = TNil

-- | 'D1' applied to every entry of a type-level environment.
type family D1E env where
  D1E '[] = '[]
  D1E (t : env) = D1 t : D1E env

-- | 'D2' applied to every entry of a type-level environment.
type family D2E env where
  D2E '[] = '[]
  D2E (t : env) = D2 t : D2E env

-- | Term-level counterpart of 'D1' on type singletons. Fails on the monad
-- type, which may not occur in input programs.
d1 :: STy t -> STy (D1 t)
d1 STNil = STNil
d1 (STPair a b) = STPair (d1 a) (d1 b)
d1 (STArr n t) = STArr n (d1 t)
d1 (STScal t) = STScal t
d1 STEVM{} = error "EVM not allowed in input program"

-- | Term-level counterpart of 'D2' on type singletons. Arrays are not
-- supported yet, and the monad type may not occur in input programs.
d2 :: STy t -> STy (D2 t)
d2 STNil = STNil
d2 (STPair a b) = STPair (d2 a) (d2 b)
d2 STArr{} = error "TODO arrays"
d2 (STScal t) = case t of
  STI32 -> STNil
  STI64 -> STNil
  STF32 -> STScal STF32
  STF64 -> STScal STF64
  STBool -> STNil
d2 STEVM{} = error "EVM not allowed in input program"

-- | 'd2' mapped over a list of type singletons, mirroring 'D2E'.
d2list :: SList STy env -> SList STy (D2E env)
d2list SNil = SNil
d2list (SCons x l) = SCons (d2 x) (d2list l)

-- | Transport a typed de Bruijn index through 'D1E'. The position is
-- unchanged.
conv1Idx :: Idx env t -> Idx (D1E env) (D1 t)
conv1Idx IZ = IZ
conv1Idx (IS i) = IS (conv1Idx i)

-- | Transport a typed de Bruijn index through 'D2E'. The position is
-- unchanged.
conv2Idx :: Idx env t -> Idx (D2E env) (D2 t)
conv2Idx IZ = IZ
conv2Idx (IS i) = IS (conv2Idx i)

-- | Result of 'drev' for an expression of type @t@ in environment @env@.
--
-- The existential @env'@ is the environment after the forward pass. The
-- three fields are:
--
--   1. the forward-pass bindings, extending @D1E env@ to @env'@;
--   2. the primal result, typed in @env'@;
--   3. the reverse pass. With the result's cotangent (type @D2 t@) on top of
--      @env'@, it is a @TEVM (D2E env) TNil@ computation over the cotangents
--      of @env@. Presumably it accumulates into them (cf. 'EMOne' in 'drev');
--      confirm against the semantics in "AST".
data Ret env t =
  forall env'.
    Ret (Bindings Ex (D1E env) env')
        (Ex env' (D1 t))
        (Ex (D2 t : env') (TEVM (D2E env) TNil))

deriving instance Show (Ret env t)

-- | Split an expression into its forward and reverse passes (see 'Ret').
--
-- @senv@ gives the types of the variables in @env@. It is used to build the
-- cotangent environment for 'EMOne' ('d2list') and is extended under
-- binders.
--
-- Only 'EVar' and 'ELet' are implemented. Any other constructor is reported
-- with a runtime 'error'.
drev :: SList STy env -> Ex env t -> Ret env t
drev senv = \case
  -- Variable: no forward bindings, and the variable itself is the primal
  -- result. The reverse pass hands the incoming cotangent (variable 'IZ',
  -- of type @D2 t@) to 'EMOne' for variable @i@.
  EVar _ t i ->
    Ret BTop
        (EVar ext (d1 t) (conv1Idx i))
        (EMOne (d2list senv) (conv2Idx i) (EVar ext (d2 t) IZ))

  -- let x = rhs in body
  --   forward: rhs's bindings, then rhs's primal (the value of x), then
  --            body's bindings rebased on top of those.
  --   reverse: run body's reverse pass in an 'EMScope' that presumably adds
  --            x's cotangent to the monad environment. Bind the resulting
  --            @TPair TNil (D2 x)@, project out x's cotangent with 'ESnd',
  --            and feed it to rhs's reverse pass.
  ELet _ rhs body
    | Ret rhs0 rhs1 rhs2 <- drev senv rhs
    , Ret body0 body1 body2 <- drev (SCons (typeOf rhs) senv) body ->
        weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 $ \body0' wbody0' ->
          Ret (bconcat (BPush rhs0 rhs1) body0')
              (weakenExpr wbody0' body1)
              (EMBind
                 (EMScope (weakenExpr (WCopy wbody0') body2))
                 (ELet ext (ESnd ext STNil (EVar ext (STPair STNil (d2 (typeOf rhs))) IZ)) $
                    -- rhs2 expects x's cotangent on top of rhs's forward
                    -- environment. The weakening works as follows:
                    --   * keep the let-bound cotangent ('WCopy');
                    --   * skip the bound pair and body's incoming cotangent
                    --     (two 'WSink's);
                    --   * reach the full forward environment through body0'
                    --     ('WPop' drops the D1 x that body0' starts from).
                    weakenExpr (WCopy (WSink .> WSink .> WPop (sinkWithBindings body0'))) rhs2))

  -- Any constructor not handled above.
  _ -> error "drev: unsupported expression (only EVar and ELet are implemented)"
  where
    ext = Const ()