1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module CHAD.Types where
import AST.Accum
import AST.Types
import Data
-- | Type-level translation of a source-program type (presumably the
-- "primal" half of the CHAD transformation -- confirm against the
-- transformation itself).
--
-- Every constructor listed here is mapped to itself with 'D1' applied to its
-- components, and scalar types are returned unchanged. There is no equation
-- for @TAccum@, so the family does not reduce on accumulator types.
type family D1 ty where
  D1 TNil = TNil
  D1 (TPair l r) = TPair (D1 l) (D1 r)
  D1 (TEither l r) = TEither (D1 l) (D1 r)
  D1 (TLEither l r) = TLEither (D1 l) (D1 r)
  D1 (TMaybe el) = TMaybe (D1 el)
  D1 (TArr k el) = TArr k (D1 el)
  D1 (TScal sc) = TScal sc
-- | Second type-level translation of a source-program type (presumably the
-- derivative/cotangent half of the CHAD transformation -- confirm).
--
-- Structure is preserved except in two places: @TEither@ is sent to
-- @TLEither@ (as is @TLEither@ itself), and scalar types are translated by
-- 'D2s'. There is no equation for @TAccum@.
type family D2 ty where
  D2 TNil = TNil
  D2 (TPair l r) = TPair (D2 l) (D2 r)
  D2 (TEither l r) = TLEither (D2 l) (D2 r)
  D2 (TLEither l r) = TLEither (D2 l) (D2 r)
  D2 (TMaybe el) = TMaybe (D2 el)
  D2 (TArr k el) = TArr k (D2 el)
  D2 (TScal sc) = D2s sc
-- | Scalar case of 'D2'. The two floating-point scalar types are kept (wrapped
-- in @TScal@); the integer and boolean scalar types collapse to @TNil@.
type family D2s sc where
  -- Floating-point scalars are kept as scalars.
  D2s TF32 = TScal TF32
  D2s TF64 = TScal TF64
  -- All remaining scalar types become the unit type.
  D2s TI32 = TNil
  D2s TI64 = TNil
  D2s TBool = TNil
-- | 'D1' applied to every entry of a type-level list (an environment).
type family D1E tys where
  D1E '[] = '[]
  D1E (ty ': rest) = D1 ty ': D1E rest
-- | 'D2' applied to every entry of a type-level list (an environment).
type family D2E tys where
  D2E '[] = '[]
  D2E (ty ': rest) = D2 ty ': D2E rest
-- | Like 'D2E', but every translated entry is additionally wrapped in
-- @TAccum@.
type family D2AcE tys where
  D2AcE '[] = '[]
  D2AcE (ty ': rest) = TAccum (D2 ty) ': D2AcE rest
-- | Singleton-level counterpart of 'D1': from the singleton for @t@, build
-- the singleton for @D1 t@.
--
-- Calls 'error' on 'STAccum', for which 'D1' has no equation.
d1 :: STy t -> STy (D1 t)
d1 = \case
  STNil -> STNil
  STPair l r -> STPair (d1 l) (d1 r)
  STEither l r -> STEither (d1 l) (d1 r)
  STLEither l r -> STLEither (d1 l) (d1 r)
  STMaybe el -> STMaybe (d1 el)
  STArr k el -> STArr k (d1 el)
  STScal sc -> STScal sc
  STAccum{} -> error "Accumulators not allowed in input program"
-- | Apply 'd1' to every singleton in an environment, yielding the singleton
-- list for @D1E env@.
d1e :: SList STy env -> SList STy (D1E env)
d1e = \case
  SNil -> SNil
  SCons ty rest -> SCons (d1 ty) (d1e rest)
-- | Singleton-level counterpart of 'D2', producing the result as an 'SMTy'
-- witness rather than a plain 'STy' (see 'd2' for the latter).
--
-- The scalar alternatives mirror 'D2s': the two floating-point scalars stay
-- scalars, the integer and boolean scalars become 'SMTNil'.
--
-- Calls 'error' on 'STAccum', for which 'D2' has no equation.
d2M :: STy t -> SMTy (D2 t)
d2M = \case
  STNil -> SMTNil
  STPair l r -> SMTPair (d2M l) (d2M r)
  -- Both flavours of sum map to SMTLEither, matching the D2 equations.
  STEither l r -> SMTLEither (d2M l) (d2M r)
  STLEither l r -> SMTLEither (d2M l) (d2M r)
  STMaybe el -> SMTMaybe (d2M el)
  STArr k el -> SMTArr k (d2M el)
  STScal STI32 -> SMTNil
  STScal STI64 -> SMTNil
  STScal STF32 -> SMTScal STF32
  STScal STF64 -> SMTScal STF64
  STScal STBool -> SMTNil
  STAccum{} -> error "Accumulators not allowed in input program"
-- | Singleton for @D2 t@ as a plain 'STy', obtained by converting the result
-- of 'd2M' with 'fromSMTy'.
d2 :: STy t -> STy (D2 t)
d2 ty = fromSMTy (d2M ty)
-- | Apply 'd2M' to every singleton in an environment, yielding a list of
-- 'SMTy' witnesses indexed by @D2E env@.
d2eM :: SList STy env -> SList SMTy (D2E env)
d2eM = \case
  SNil -> SNil
  SCons ty rest -> SCons (d2M ty) (d2eM rest)
-- | Singleton list for @D2E env@ in terms of plain 'STy': computes 'd2eM' and
-- then converts each entry with 'fromSMTy'.
d2e :: SList STy env -> SList STy (D2E env)
d2e env = slistMap fromSMTy (d2eM env)
-- | Singleton list for @D2AcE env@: each entry is the 'd2M' witness of the
-- corresponding input type, wrapped in 'STAccum'.
d2ace :: SList STy env -> SList STy (D2AcE env)
d2ace = \case
  SNil -> SNil
  SCons ty rest -> SCons (STAccum (d2M ty)) (d2ace rest)
-- | Configuration record consisting of four independent boolean flags.
--
-- Per the field comments, the first three select the places where variables
-- containing arrays are handled in accumulator mode, and the fourth enables
-- with-blocks around array variable scopes. 'defaultConfig' has every flag
-- off; 'chcSetAccum' turns every flag on.
data CHADConfig = CHADConfig
  { -- | D[let] will bind variables containing arrays in accumulator mode.
    chcLetArrayAccum :: Bool
  , -- | D[case] will bind variables containing arrays in accumulator mode.
    chcCaseArrayAccum :: Bool
  , -- | Introduce top-level arguments containing arrays in accumulator mode.
    chcArgArrayAccum :: Bool
  , -- | Place with-blocks around array variable scopes, and redirect accumulations there.
    chcSmartWith :: Bool
  }
  deriving (Show)
-- | Baseline configuration: every flag of 'CHADConfig' is 'False'.
defaultConfig :: CHADConfig
defaultConfig =
  CHADConfig
    { chcLetArrayAccum = off
    , chcCaseArrayAccum = off
    , chcArgArrayAccum = off
    , chcSmartWith = off
    }
  where
    off = False
-- | Turn on all four flags of the given configuration: the three
-- array-accumulator flags and 'chcSmartWith'.
--
-- Since every field of 'CHADConfig' is overwritten, the field values of the
-- result do not depend on those of the argument.
chcSetAccum :: CHADConfig -> CHADConfig
chcSetAccum cfg =
  cfg
    { chcLetArrayAccum = True
    , chcCaseArrayAccum = True
    , chcArgArrayAccum = True
    , chcSmartWith = True
    }
------------------------------------ LEMMAS ------------------------------------
-- | Lemma: 'D1' leaves the tuple type built from @n@ copies of @TIx@
-- unchanged. Proved by induction on @n@.
indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx))
indexTupD1Id = \case
  SZ -> Refl
  -- Inductive step: the proof for the predecessor brings the needed
  -- equality into scope.
  SS prev ->
    case indexTupD1Id prev of
      Refl -> Refl
-- | Lemma: for every scalar type @t@, @ZeroInfo (D2s t)@ is @TNil@.
--
-- Each scalar singleton is matched separately so that the type families
-- involved can reduce for that concrete scalar type.
lemZeroInfoScal :: SScalTy t -> ZeroInfo (D2s t) :~: TNil
lemZeroInfoScal = \case
  STI32 -> Refl
  STI64 -> Refl
  STF32 -> Refl
  STF64 -> Refl
  STBool -> Refl
-- | Lemma: for every scalar type @t@, @DeepZeroInfo (D2s t)@ is @TNil@.
--
-- Each scalar singleton is matched separately so that the type families
-- involved can reduce for that concrete scalar type.
lemDeepZeroInfoScal :: SScalTy t -> DeepZeroInfo (D2s t) :~: TNil
lemDeepZeroInfoScal = \case
  STI32 -> Refl
  STI64 -> Refl
  STF32 -> Refl
  STF64 -> Refl
  STBool -> Refl
-- | Lemma: 'D1' is the identity on every type whose singleton contains no
-- accumulator. Proved by structural induction on the singleton.
--
-- Calls 'error' on 'STAccum', for which 'D1' has no equation.
d1Identity :: STy t -> D1 t :~: t
d1Identity STNil = Refl
d1Identity (STPair l r) =
  case (d1Identity l, d1Identity r) of
    (Refl, Refl) -> Refl
d1Identity (STEither l r) =
  case (d1Identity l, d1Identity r) of
    (Refl, Refl) -> Refl
d1Identity (STLEither l r) =
  case (d1Identity l, d1Identity r) of
    (Refl, Refl) -> Refl
d1Identity (STMaybe el) =
  case d1Identity el of
    Refl -> Refl
d1Identity (STArr _ el) =
  case d1Identity el of
    Refl -> Refl
d1Identity (STScal _) = Refl
d1Identity STAccum{} = error "Accumulators not allowed in input program"
-- | Lemma: 'D1E' is the identity on environments, by induction on the list
-- and using 'd1Identity' for each entry (so it inherits that function's
-- 'error' on accumulator types).
d1eIdentity :: SList STy env -> D1E env :~: env
d1eIdentity = \case
  SNil -> Refl
  SCons ty rest ->
    case (d1Identity ty, d1eIdentity rest) of
      (Refl, Refl) -> Refl
|