summaryrefslogtreecommitdiff
path: root/src/CHAD
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD')
-rw-r--r--src/CHAD/Top.hs53
1 files changed, 53 insertions, 0 deletions
diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs
new file mode 100644
index 0000000..9df5412
--- /dev/null
+++ b/src/CHAD/Top.hs
@@ -0,0 +1,53 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD.Top where
+
+import AST
+import CHAD
+import CHAD.Types
+import Data
+
+
+type family MergeEnv env where
+ MergeEnv '[] = '[]
+ MergeEnv (t : ts) = "merge" : MergeEnv ts
+
+mergeDescr :: SList STy env -> Descr env (MergeEnv env)
+mergeDescr SNil = DTop
+mergeDescr (t `SCons` env) = mergeDescr env `DPush` (t, SMerge)
+
+mergeEnvNoAccum :: SList f env -> Select env (MergeEnv env) "accum" :~: '[]
+mergeEnvNoAccum SNil = Refl
+mergeEnvNoAccum (_ `SCons` env) | Refl <- mergeEnvNoAccum env = Refl
+
+mergeEnvOnlyMerge :: SList f env -> Select env (MergeEnv env) "merge" :~: env
+mergeEnvOnlyMerge SNil = Refl
+mergeEnvOnlyMerge (_ `SCons` env) | Refl <- mergeEnvOnlyMerge env = Refl
+
+d1Identity :: STy t -> D1 t :~: t
+d1Identity = \case
+ STNil -> Refl
+ STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
+ STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
+ STMaybe t | Refl <- d1Identity t -> Refl
+ STArr _ t | Refl <- d1Identity t -> Refl
+ STScal _ -> Refl
+ STAccum{} -> error "Accumulators not allowed in input program"
+
+d1eIdentity :: SList STy env -> D1E env :~: env
+d1eIdentity SNil = Refl
+d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl
+
+chad :: SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env)))
+chad env term
+ | Refl <- mergeEnvNoAccum env
+ , Refl <- mergeEnvOnlyMerge env
+ = freezeRet (mergeDescr env) (drev (mergeDescr env) term)
+
+chad' :: SList STy env -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
+chad' env term
+ | Refl <- d1eIdentity env, Refl <- d1Identity (typeOf term)
+ = chad env term