summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs121
1 files changed, 121 insertions, 0 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
new file mode 100644
index 0000000..17ee12b
--- /dev/null
+++ b/src/CHAD.hs
@@ -0,0 +1,121 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD where
+
+import Data.Functor.Const
+
+import AST
+
+
+type Ex = Expr (Const ())
+
+data Bindings f env env' where
+ BTop :: Bindings f env env
+ BPush :: Bindings f env env' -> f env' t -> Bindings f env (t : env')
+deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env')
+
+weakenBindings :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t)
+ -> env1 :> env2 -> Bindings f env1 env'
+ -> (forall env2'. Bindings f env2 env2' -> env' :> env2' -> r) -> r
+weakenBindings _ w BTop k = k BTop w
+weakenBindings wf w (BPush b x) k =
+ weakenBindings wf w b $ \b' w' -> k (BPush b' (wf w' x)) (WCopy w')
+
+sinkWithBindings :: Bindings f env env' -> env :> env'
+sinkWithBindings BTop = WId
+sinkWithBindings (BPush b _) = WSink .> sinkWithBindings b
+
+bconcat :: Bindings f env1 env2 -> Bindings f env2 env3 -> Bindings f env1 env3
+bconcat b1 BTop = b1
+bconcat b1 (BPush b2 x) = BPush (bconcat b1 b2) x
+
+bconcat' :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t)
+ -> Bindings f env env1 -> Bindings f env env2 -> (forall env12. Bindings env env12 -> r) -> r
+bconcat' wf b1 b2 = weakenBindings
+-- bconcat :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t)
+-- -> Bindings f env env1 -> Bindings f env env2 -> env :> env'
+-- -> (forall env'12. Bindings f env' env'12 -> r) -> r
+-- bconcat wf BTop b w k = weakenBindings wf w b $ \b' _ -> k b'
+-- bconcat wf (BPush b x) b2 w k =
+-- bconcat wf
+
+type family D1 t where
+ D1 TNil = TNil
+ D1 (TPair a b) = TPair (D1 a) (D1 b)
+ D1 (TArr n t) = TArr n (D1 t)
+ D1 (TScal t) = TScal t
+
+type family D2 t where
+ D2 TNil = TNil
+ D2 (TPair a b) = TPair (D2 a) (D2 b)
+ -- D2 (TArr n t) = _
+ D2 (TScal t) = D2s t
+
+type family D2s t where
+ D2s TI32 = TNil
+ D2s TI64 = TNil
+ D2s TF32 = TScal TF32
+ D2s TF64 = TScal TF64
+ D2s TBool = TNil
+
+type family D1E env where
+ D1E '[] = '[]
+ D1E (t : env) = D1 t : D1E env
+
+type family D2E env where
+ D2E '[] = '[]
+ D2E (t : env) = D2 t : D2E env
+
+data Ret env t =
+ forall env'.
+ Ret (Bindings Ex (D1E env) env')
+ (Ex env' (D1 t))
+ (Ex (D2 t : env') (TEVM (D2E env) TNil))
+deriving instance Show (Ret env t)
+
+d1 :: STy t -> STy (D1 t)
+d1 STNil = STNil
+d1 (STPair a b) = STPair (d1 a) (d1 b)
+d1 (STArr n t) = STArr n (d1 t)
+d1 (STScal t) = STScal t
+d1 STEVM{} = error "EVM not allowed in input program"
+
+d2 :: STy t -> STy (D2 t)
+d2 STNil = STNil
+d2 (STPair a b) = STPair (d2 a) (d2 b)
+d2 STArr{} = error "TODO arrays"
+d2 (STScal t) = case t of
+ STI32 -> STNil
+ STI64 -> STNil
+ STF32 -> STScal STF32
+ STF64 -> STScal STF64
+ STBool -> STNil
+d2 STEVM{} = error "EVM not allowed in input program"
+
+conv1Idx :: Idx env t -> Idx (D1E env) (D1 t)
+conv1Idx IZ = IZ
+conv1Idx (IS i) = IS (conv1Idx i)
+
+conv2Idx :: Idx env t -> Idx (D2E env) (D2 t)
+conv2Idx IZ = IZ
+conv2Idx (IS i) = IS (conv2Idx i)
+
+drev :: Ex env t -> Ret env t
+drev = \case
+ EVar _ t i ->
+ Ret BTop
+ (EVar ext (d1 t) (conv1Idx i))
+ (EMOne (conv2Idx i) (EVar ext (d2 t) IZ))
+ ELet _ rhs body ->
+ Ret _
+ _
+ _
+ where
+ ext = Const ()