1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module CHAD.Top where
import Analysis.Identity
import AST
import AST.Weaken.Auto
import CHAD
import CHAD.Accum
import CHAD.EnvDescr
import CHAD.Types
import Data
import qualified Data.VarMap as VarMap
-- | Storage assignment that tags every entry of an environment with the
-- @"merge"@ storage kind. The result has one tag per environment entry.
type family MergeEnv env where
  MergeEnv '[] = '[]
  MergeEnv (_ ': rest) = "merge" ': MergeEnv rest
-- | Descriptor that marks every environment entry as 'SMerge'. Its type
-- index is exactly 'MergeEnv' of the environment.
mergeDescr :: SList STy env -> Descr env (MergeEnv env)
mergeDescr = \case
  SNil -> DTop
  ty `SCons` rest -> DPush (mergeDescr rest) (ty, SMerge)
-- | Proof, by induction on the environment spine, that selecting the
-- @"accum"@ entries under the all-merge assignment 'MergeEnv' gives the
-- empty list.
mergeEnvNoAccum :: SList f env -> Select env (MergeEnv env) "accum" :~: '[]
mergeEnvNoAccum = \case
  SNil -> Refl
  _ `SCons` rest -> case mergeEnvNoAccum rest of
    Refl -> Refl
-- | Proof, by induction on the environment spine, that selecting the
-- @"merge"@ entries under the all-merge assignment 'MergeEnv' gives back the
-- whole environment.
mergeEnvOnlyMerge :: SList f env -> Select env (MergeEnv env) "merge" :~: env
mergeEnvOnlyMerge = \case
  SNil -> Refl
  _ `SCons` rest -> case mergeEnvOnlyMerge rest of
    Refl -> Refl
-- | Build a storage descriptor for the environment and pass it to the
-- continuation. An entry whose type satisfies 'hasArrays' is marked
-- 'SAccum'; every other entry is marked 'SMerge'.
--
-- Continuation-passing style is required because the resulting @sto@ index
-- depends on the value-level 'hasArrays' test, so it cannot appear in a
-- fixed return type.
accumDescr :: SList STy env -> (forall sto. Descr env sto -> r) -> r
accumDescr SNil cont = cont DTop
accumDescr (ty `SCons` rest) cont =
  accumDescr rest (\prefix ->
    -- The two branches produce descriptors with different @sto@ indices,
    -- so each one must call the continuation itself.
    case hasArrays ty of
      True -> cont (prefix `DPush` (ty, SAccum))
      False -> cont (prefix `DPush` (ty, SMerge)))
-- | Witness that 'D1' leaves every type admissible in an input program
-- unchanged. The proof is by structural recursion on the type.
--
-- Accumulator types ('STAccum') must not occur in input programs; meeting
-- one aborts with 'error'.
d1Identity :: STy t -> D1 t :~: t
d1Identity STNil = Refl
d1Identity (STPair l r)
  | Refl <- d1Identity l
  , Refl <- d1Identity r
  = Refl
d1Identity (STEither l r)
  | Refl <- d1Identity l
  , Refl <- d1Identity r
  = Refl
d1Identity (STMaybe elt)
  | Refl <- d1Identity elt
  = Refl
d1Identity (STArr _ elt)
  | Refl <- d1Identity elt
  = Refl
d1Identity (STScal _) = Refl
d1Identity STAccum{} = error "Accumulators not allowed in input program"
-- | Lift 'd1Identity' pointwise over an environment: 'D1E' is the identity
-- on environments built from input-program types.
d1eIdentity :: SList STy env -> D1E env :~: env
d1eIdentity = \case
  SNil -> Refl
  ty `SCons` rest -> case d1Identity ty of
    Refl -> case d1eIdentity rest of
      Refl -> Refl
-- | Rebuild the full 'D2E' tuple for an environment from its split form. The
-- input produces a pair of (tuple for the @"accum"@-selected entries, tuple
-- for the @"merge"@-selected entries).
--
-- Walking the descriptor, each entry is taken from the appropriate half:
--
-- * 'SAccum' entries come from the accum half.
-- * 'SMerge' entries come from the merge half.
-- * 'SDiscr' entries are filled with 'EZero'.
--
-- A tuple for @t : rest@ is a pair of (tuple for @rest@, component for @t@),
-- so the current entry is the second projection. The incoming pair
-- expression is let-bound, so it is not duplicated.
reassembleD2E :: Descr env sto
              -> Ex env' (TPair (Tup (D2E (Select env sto "accum"))) (Tup (D2E (Select env sto "merge"))))
              -> Ex env' (Tup (D2E env))
reassembleD2E DTop _ = ENil ext
reassembleD2E (rest `DPush` (_, SAccum)) pairE =
  let v = EVar ext (typeOf pairE) IZ
      accums = EFst ext v
      merges = ESnd ext v
  in ELet ext pairE $
       EPair ext (reassembleD2E rest (EPair ext (EFst ext accums) merges))
                 (ESnd ext accums)
reassembleD2E (rest `DPush` (_, SMerge)) pairE =
  let v = EVar ext (typeOf pairE) IZ
      accums = EFst ext v
      merges = ESnd ext v
  in ELet ext pairE $
       EPair ext (reassembleD2E rest (EPair ext accums (EFst ext merges)))
                 (ESnd ext merges)
reassembleD2E (rest `DPush` (ty, SDiscr)) pairE =
  EPair ext (reassembleD2E rest pairE) (EZero ext ty)
-- | Top-level CHAD entry point.
--
-- Inputs: the configuration, the environment's types and a term. The
-- output expression has the incoming @D2 t@ value bound as variable 0 in
-- front of @D1E env@. It returns a pair of (@D1 t@ value, tuple with one
-- @D2E@ component per environment entry).
--
-- Two strategies are selected by 'chcArgArrayAccum':
--
-- * 'True': 'accumDescr' marks entries whose type 'hasArrays' as 'SAccum'
--   and all others as 'SMerge'. Accumulators are created for the 'SAccum'
--   entries, and 'reassembleD2E' finally puts both halves back into
--   environment order.
-- * 'False': every entry is 'SMerge' ('mergeDescr'). The proofs
--   'mergeEnvNoAccum' and 'mergeEnvOnlyMerge' show the accum selection is
--   empty and the merge selection is the whole environment, so the result
--   of 'freezeRet' has the required type directly.
chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env)))
chad config env (term :: Ex env t)
  -- The pattern signature above brings @t@ into scope for @auto1 \@(D2 t)@.
  | True <- chcArgArrayAccum config
  = let ?config = config
    in accumDescr env $ \descr ->
         -- t1: (D1 of the result, tuple of D2 for the 'SMerge' entries).
         let t1 = STPair (d1 (typeOf term)) (tTup (d2e (select SMerge descr)))
             -- tvar: type of the let-bound variable below, which pairs t1
             -- with the tuple of D2 for the 'SAccum' entries.
             tvar = STPair t1 (tTup (d2e (select SAccum descr)))
         in ELet ext (uninvertTup (d2e (select SAccum descr)) t1 $
                        -- NOTE(review): presumably makeAccumulators binds the
                        -- accumulators in front of the environment and
                        -- uninvertTup reshapes its result into the tvar pair;
                        -- confirm against their definitions.
                        makeAccumulators (select SAccum descr) $
                          -- The drev/freezeRet term is laid out as
                          -- (d, acenv, tl). Reorder it to (acenv, d, tl),
                          -- i.e. accumulator slots first.
                          weakenExpr (autoWeak (#d (auto1 @(D2 t))
                                                &. #acenv (d2ace (select SAccum descr))
                                                &. #tl (d1e env))
                                               (#d :++: #acenv :++: #tl)
                                               (#acenv :++: #d :++: #tl)) $
                            freezeRet descr (drev descr VarMap.empty (identityAnalysis env term))) $
              -- Project out the D1 result, then rebuild the full D2E tuple
              -- from (accum tuple, merge tuple).
              EPair ext (EFst ext (EFst ext (EVar ext tvar IZ)))
                        (reassembleD2E descr (EPair ext (ESnd ext (EVar ext tvar IZ))
                                                        (ESnd ext (EFst ext (EVar ext tvar IZ)))))
  | False <- chcArgArrayAccum config
  , Refl <- mergeEnvNoAccum env
  , Refl <- mergeEnvOnlyMerge env
  = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty (identityAnalysis env term))
-- | Variant of 'chad' whose type is phrased in terms of the original types.
-- 'd1eIdentity' and 'd1Identity' rewrite @D1E env@ to @env@ and @D1 t@ to
-- @t@. The environment proof is forced first, then the result-type proof.
chad' :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
chad' config env term =
  case d1eIdentity env of
    Refl -> case d1Identity (typeOf term) of
      Refl -> chad config env term
|