summaryrefslogtreecommitdiff
path: root/src/CHAD/Top.hs
blob: 9df541231a1035bd7f3c2c8d14979c09c8dc9373 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module CHAD.Top where

import AST
import CHAD
import CHAD.Types
import Data


-- | Tags every entry of a type-level environment with the storage kind
-- @"merge"@. The result is a list with one @"merge"@ per entry of @env@.
type family MergeEnv env where
  MergeEnv '[] = '[]
  MergeEnv (_ ': rest) = "merge" ': MergeEnv rest

-- | Builds the descriptor in which every variable of the environment is
-- tagged 'SMerge'. Its storage index is therefore 'MergeEnv' @env@.
-- The head of the list is pushed last, on top of the descriptor built for
-- the tail.
mergeDescr :: SList STy env -> Descr env (MergeEnv env)
mergeDescr = \case
  SNil -> DTop
  SCons ty rest -> DPush (mergeDescr rest) (ty, SMerge)

-- | Proof that an all-@"merge"@ environment has no @"accum"@ entries:
-- selecting the @"accum"@ entries yields the empty list. Proceeds by
-- induction on the spine of the environment.
mergeEnvNoAccum :: SList f env -> Select env (MergeEnv env) "accum" :~: '[]
mergeEnvNoAccum = \case
  SNil -> Refl
  SCons _ rest -> case mergeEnvNoAccum rest of
    Refl -> Refl

-- | Proof that selecting the @"merge"@ entries of an all-@"merge"@
-- environment gives back the whole environment. Proceeds by induction on
-- the spine of the environment.
mergeEnvOnlyMerge :: SList f env -> Select env (MergeEnv env) "merge" :~: env
mergeEnvOnlyMerge = \case
  SNil -> Refl
  SCons _ rest -> case mergeEnvOnlyMerge rest of
    Refl -> Refl

-- | Proof that 'D1' maps every source-program type to itself. It recurses
-- structurally over the type singleton. Accumulator types are rejected at
-- runtime with an 'error', because they may not appear in an input program.
d1Identity :: STy t -> D1 t :~: t
d1Identity STNil = Refl
d1Identity (STPair l r) = case (d1Identity l, d1Identity r) of
  (Refl, Refl) -> Refl
d1Identity (STEither l r) = case (d1Identity l, d1Identity r) of
  (Refl, Refl) -> Refl
d1Identity (STMaybe elt) = case d1Identity elt of
  Refl -> Refl
d1Identity (STArr _ elt) = case d1Identity elt of
  Refl -> Refl
d1Identity (STScal _) = Refl
d1Identity STAccum{} = error "Accumulators not allowed in input program"

-- | Lifts 'd1Identity' pointwise to a whole environment, proving that
-- 'D1E' is the identity on it. For each entry, the head type's proof is
-- forced before the tail's.
d1eIdentity :: SList STy env -> D1E env :~: env
d1eIdentity = \case
  SNil -> Refl
  SCons ty rest -> case (d1Identity ty, d1eIdentity rest) of
    (Refl, Refl) -> Refl

-- | Top-level CHAD transformation of a term over the environment @env@.
--
-- Every free variable is tagged with the @"merge"@ storage kind via
-- 'mergeDescr', so no variable is selected as @"accum"@. The resulting
-- expression has an extra innermost variable of type @D2 t@ (presumably
-- the incoming cotangent of the result — confirm against 'drev'). It
-- returns the @D1@ primal result paired with a tuple of @D2E env@ values,
-- one per environment entry.
chad :: SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env)))
chad env term
  -- These proofs about the all-"merge" storage index are brought into
  -- scope so that the types of 'drev' / 'freezeRet' reduce to the
  -- signature above (NOTE(review): exact requirement lives in module CHAD).
  | Refl <- mergeEnvNoAccum env
  , Refl <- mergeEnvOnlyMerge env
  = freezeRet descr (drev descr term)
  where
    -- Build the all-"merge" descriptor once and share it between 'drev'
    -- and 'freezeRet', instead of constructing it twice.
    descr = mergeDescr env

-- | Variant of 'chad' for source programs. It uses 'd1eIdentity' and
-- 'd1Identity' to rewrite @D1E env@ to @env@ and @D1 t@ to @t@ in the
-- result type. The environment proof is forced first, then the proof for
-- the term's type.
chad' :: SList STy env -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
chad' env term =
  case (d1eIdentity env, d1Identity (typeOf term)) of
    (Refl, Refl) -> chad env term