1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module CHAD.Top where
import Analysis.Identity
import AST
import AST.Env
import AST.Sparse
import AST.SplitLets
import AST.Weaken.Auto
import CHAD
import CHAD.Accum
import CHAD.EnvDescr
import CHAD.Types
import Data
import qualified Data.VarMap as VarMap
-- | Type-level map that replaces every element of the list @env@ by the
-- symbol @"merge"@. The result has the same length as @env@; it is used as
-- the second index of the 'Descr' built by 'mergeDescr'.
type family MergeEnv env where
  MergeEnv '[]         = '[]
  MergeEnv (_ ': rest) = "merge" ': MergeEnv rest
-- | Build a descriptor for @env@ in which every variable is pushed with the
-- 'SMerge' tag (and 'Nothing' as the middle component of its entry).
-- The descriptor's tag list is therefore exactly @MergeEnv env@.
mergeDescr :: SList STy env -> Descr env (MergeEnv env)
mergeDescr = \case
  SNil         -> DTop
  SCons ty tys -> DPush (mergeDescr tys) (ty, Nothing, SMerge)
-- | Proof, by induction on the environment, that selecting the @"accum"@
-- entries out of an all-@"merge"@ tag list ('MergeEnv') yields the empty list.
mergeEnvNoAccum :: SList f env -> Select env (MergeEnv env) "accum" :~: '[]
mergeEnvNoAccum SNil = Refl
mergeEnvNoAccum (SCons _ rest) =
  -- The head is tagged "merge", so the claim reduces to the one for the tail.
  case mergeEnvNoAccum rest of
    Refl -> Refl
-- | Proof, by induction on the environment, that selecting the @"merge"@
-- entries out of an all-@"merge"@ tag list ('MergeEnv') gives back the whole
-- environment unchanged.
mergeEnvOnlyMerge :: SList f env -> Select env (MergeEnv env) "merge" :~: env
mergeEnvOnlyMerge SNil = Refl
mergeEnvOnlyMerge (SCons _ rest) =
  -- The head is tagged "merge" and hence kept; recurse for the tail.
  case mergeEnvOnlyMerge rest of
    Refl -> Refl
-- | Build a descriptor for @env@ and hand it to the continuation.
--
-- Each variable whose type satisfies 'hasArrays' is pushed with 'SAccum';
-- every other variable is pushed with 'SMerge'. Because the resulting tag
-- list @sto@ depends on the run-time types, it cannot appear in the
-- signature and is instead quantified in the continuation's argument.
accumDescr :: SList STy env -> (forall sto. Descr env sto -> r) -> r
accumDescr SNil k = k DTop
accumDescr (SCons ty tys) k =
  accumDescr tys $ \descr ->
    -- The two branches produce descriptors of different types, so the
    -- continuation has to be invoked separately in each of them.
    case hasArrays ty of
      True  -> k (DPush descr (ty, Nothing, SAccum))
      False -> k (DPush descr (ty, Nothing, SMerge))
-- | Combine the tuple of @"accum"@-selected components and the tuple of
-- @"merge"@-selected components (given together as one pair expression)
-- into a single tuple covering the whole environment, in the order
-- prescribed by the descriptor.
--
-- The weakening argument embeds @D1E env@ into the environment of the
-- expression; it is only consulted in the 'SDiscr' case, where the variable
-- itself is referenced.
reassembleD2E :: Descr env sto
              -> D1E env :> env'
              -> Ex env' (TPair (Tup (D2E (Select env sto "accum"))) (Tup (D2E (Select env sto "merge"))))
              -> Ex env' (Tup (D2E env))
reassembleD2E DTop _ _ = ENil ext
reassembleD2E (DPush descr (_, _, SAccum)) wk both =
  -- The top variable lives in the "accum" tuple: split that tuple into the
  -- part for the remaining variables and the component for this variable.
  eunPair both $ \wkOuter accums merges ->
    eunPair accums $ \wkInner accumRest accumTop ->
      let remaining = EPair ext accumRest (weakenExpr wkInner merges)
      in EPair ext
           (reassembleD2E descr (wkInner .> wkOuter .> WPop wk) remaining)
           accumTop
reassembleD2E (DPush descr (_, _, SMerge)) wk both =
  -- Mirror image of the case above: the top variable lives in the "merge"
  -- tuple, so that is the one being taken apart.
  eunPair both $ \wkOuter accums merges ->
    eunPair merges $ \wkInner mergeRest mergeTop ->
      let remaining = EPair ext (weakenExpr wkInner accums) mergeRest
      in EPair ext
           (reassembleD2E descr (wkInner .> wkOuter .> WPop wk) remaining)
           mergeTop
reassembleD2E (DPush descr (ty, _, SDiscr)) wk both =
  -- An 'SDiscr' variable has no component in either input tuple, so the
  -- input is passed on as-is and a zero is produced for this variable,
  -- using zero-info derived from the variable's own value.
  let topVar  = EVar ext (d1 ty) (wk @> IZ)
      zeroTop = EZero ext (d2M ty) (d2zeroInfo ty topVar)
  in EPair ext (reassembleD2E descr (WPop wk) both) zeroTop
-- | Top-level entry point: differentiate @term@ with 'drev' and package the
-- result as a single expression.
--
-- The produced expression lives in the environment @D2 t : D1E env@
-- (presumably the extra top variable is the incoming cotangent -- confirm
-- against 'drev') and returns a pair of the @D1 t@ result and a tuple with
-- one @D2@ component per variable of @env@.
--
-- Before differentiation the term is passed through 'splitLets' and then
-- 'identityAnalysis'.
--
-- Two strategies are selected by 'chcArgArrayAccum':
--
-- * enabled: the descriptor comes from 'accumDescr'; accumulators are set up
--   with 'makeAccumulators' for the 'SAccum'-selected variables, and the
--   separate \"accum\" and \"merge\" results are recombined with
--   'reassembleD2E';
--
-- * disabled: every variable uses 'SMerge' ('mergeDescr'), and the two
--   equality proofs show that the output of 'freezeRet' already has the
--   required type.
chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env)))
chad config env (term :: Ex env t) =
  let ?config = config
  in case chcArgArrayAccum config of
       True ->
         accumDescr env $ \descr ->
           let accumTys = select SAccum descr
               resTy    = typeOf term
               -- Type of what 'freezeRet' returns: result plus merge tuple.
               innerTy  = STPair (d1 resTy) (tTup (d2e (select SMerge descr)))
               -- Type of the let-bound value: the above plus the accum tuple.
               outerTy  = STPair innerTy (tTup (d2e accumTys))
               -- Reference to the let-bound value inside the let body.
               outcome  = EVar ext outerTy IZ
           in ELet ext
                (uninvertTup (d2e accumTys) innerTy $
                   makeAccumulators (WSink .> wUndoSubenv (subenvD1E (selectSub SAccum descr))) accumTys $
                     -- Reorder the environment from (accumulators, d, rest)
                     -- as produced below to (d, accumulators, rest).
                     weakenExpr (autoWeak (#d (auto1 @(D2 t))
                                           &. #acenv (d2ace accumTys)
                                           &. #tl (d1e env))
                                          (#d :++: #acenv :++: #tl)
                                          (#acenv :++: #d :++: #tl)) $
                       freezeRet descr (drev descr VarMap.empty (spDense (d2M resTy)) term'))
                (EPair ext
                   (EFst ext (EFst ext outcome))
                   (reassembleD2E descr (WSink .> WSink)
                      (EPair ext (ESnd ext outcome)
                                 (ESnd ext (EFst ext outcome)))))
       False
         | Refl <- mergeEnvNoAccum env
         , Refl <- mergeEnvOnlyMerge env ->
             let descr = mergeDescr env
             in freezeRet descr (drev descr VarMap.empty (spDense (d2M (typeOf term))) term')
  where
    term' = identityAnalysis env (splitLets term)
-- | Same as 'chad', but with @D1E env@ and @D1 t@ replaced by plain @env@
-- and @t@ in the type. The equalities that justify this are obtained from
-- 'd1eIdentity' and 'd1Identity'.
chad' :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
chad' config env term =
  case d1eIdentity env of
    Refl ->
      case d1Identity (typeOf term) of
        Refl -> chad config env term
|