blob: e8ec0c98daad97afca75ee5da9c9d0fb58ff03c6 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module CHAD.Types where
import AST.Types
import Data
-- | Type-level transformation 'D1' over program types.
--
-- Every equation rebuilds the same constructor around 'D1' of its
-- components (leaves 'TNil' and 'TScal' are returned as-is), so on the
-- constructors covered here 'D1' is structurally the identity. GHC only
-- sees that by unfolding the recursion, which is why lemmas such as
-- 'indexTupD1Id' exist. There is no equation for accumulator types, so
-- 'D1' does not reduce on them.
type family D1 t where
  D1 (TScal s) = TScal s
  D1 (TArr n e) = TArr n (D1 e)
  D1 (TMaybe e) = TMaybe (D1 e)
  D1 (TEither l r) = TEither (D1 l) (D1 r)
  D1 (TPair l r) = TPair (D1 l) (D1 r)
  D1 TNil = TNil
-- | Type-level transformation 'D2' over program types.
--
-- * Products and sums are wrapped in an extra 'TMaybe' around the
--   transformed components.
-- * 'TMaybe' and 'TArr' keep their shape and transform their contents.
-- * Scalars are delegated to 'D2s'.
-- * 'TNil' maps to itself.
--
-- There is no equation for accumulator types.
type family D2 t where
  D2 (TScal s) = D2s s
  D2 (TArr n e) = TArr n (D2 e)
  D2 (TMaybe e) = TMaybe (D2 e)
  D2 (TEither l r) = TMaybe (TEither (D2 l) (D2 r))
  D2 (TPair l r) = TMaybe (TPair (D2 l) (D2 r))
  D2 TNil = TNil
-- | 'D2' restricted to scalar types.
--
-- The floating-point scalars ('TF32', 'TF64') map to themselves wrapped in
-- 'TScal'. The integer and boolean scalars ('TI32', 'TI64', 'TBool') map
-- to 'TNil'.
type family D2s t where
  D2s TF64 = TScal TF64
  D2s TF32 = TScal TF32
  D2s TBool = TNil
  D2s TI64 = TNil
  D2s TI32 = TNil
-- | Apply 'D1' to every entry of a type-level environment list.
type family D1E env where
  D1E (x ': xs) = D1 x ': D1E xs
  D1E '[] = '[]
-- | Apply 'D2' to every entry of a type-level environment list.
type family D2E env where
  D2E (x ': xs) = D2 x ': D2E xs
  D2E '[] = '[]
-- | Replace every entry @t@ of a type-level environment list by
-- @TAccum t@.
--
-- The accumulator is indexed by the entry type @t@ itself; 'D2' is not
-- applied here.
type family D2AcE env where
  D2AcE (x ': xs) = TAccum x ': D2AcE xs
  D2AcE '[] = '[]
-- | Term-level counterpart of 'D1'.
--
-- Maps a type singleton for @t@ to one for @D1 t@, mirroring the
-- type-family equations constructor by constructor. An accumulator type
-- has no 'D1' equation, so that case aborts with 'error'.
d1 :: STy t -> STy (D1 t)
d1 ty = case ty of
  STScal s -> STScal s
  STArr n e -> STArr n (d1 e)
  STMaybe e -> STMaybe (d1 e)
  STEither l r -> STEither (d1 l) (d1 r)
  STPair l r -> STPair (d1 l) (d1 r)
  STNil -> STNil
  STAccum{} -> error "Accumulators not allowed in input program"
-- | Term-level counterpart of 'D1E': apply 'd1' to every singleton in an
-- environment list.
d1e :: SList STy env -> SList STy (D1E env)
d1e list = case list of
  SCons t rest -> SCons (d1 t) (d1e rest)
  SNil -> SNil
-- | Term-level counterpart of 'D2'.
--
-- Mirrors the 'D2' and 'D2s' equations:
--
-- * Pairs and sums gain an outer 'STMaybe'.
-- * Floating-point scalars are kept.
-- * Integer and boolean scalars become 'STNil'.
--
-- An accumulator type has no 'D2' equation, so that case aborts with
-- 'error'.
d2 :: STy t -> STy (D2 t)
d2 ty = case ty of
  STScal STF64 -> STScal STF64
  STScal STF32 -> STScal STF32
  STScal STBool -> STNil
  STScal STI64 -> STNil
  STScal STI32 -> STNil
  STArr n e -> STArr n (d2 e)
  STMaybe e -> STMaybe (d2 e)
  STEither l r -> STMaybe (STEither (d2 l) (d2 r))
  STPair l r -> STMaybe (STPair (d2 l) (d2 r))
  STNil -> STNil
  STAccum{} -> error "Accumulators not allowed in input program"
-- | Term-level counterpart of 'D2E': apply 'd2' to every singleton in an
-- environment list.
d2e :: SList STy env -> SList STy (D2E env)
d2e list = case list of
  SCons t rest -> SCons (d2 t) (d2e rest)
  SNil -> SNil
-- | Term-level counterpart of 'D2AcE': wrap every singleton of an
-- environment list in 'STAccum'.
d2ace :: SList STy env -> SList STy (D2AcE env)
d2ace list = case list of
  SCons t rest -> SCons (STAccum t) (d2ace rest)
  SNil -> SNil
-- | Configuration flags for the CHAD transformation.
--
-- Each flag controls whether variables containing arrays are introduced in
-- accumulator mode at a particular kind of binding site.
data CHADConfig = CHADConfig
  { chcLetArrayAccum :: Bool
    -- ^ D[let] will bind variables containing arrays in accumulator mode.
  , chcCaseArrayAccum :: Bool
    -- ^ D[case] will bind variables containing arrays in accumulator mode.
  , chcArgArrayAccum :: Bool
    -- ^ Introduce top-level arguments containing arrays in accumulator mode.
  }
  deriving (Show)
-- | The baseline configuration: every accumulator-mode flag is off.
defaultConfig :: CHADConfig
defaultConfig =
  CHADConfig
    { chcArgArrayAccum = False
    , chcCaseArrayAccum = False
    , chcLetArrayAccum = False
    }
-- | Turn on all three accumulator-mode flags of the given configuration.
chcSetAccum :: CHADConfig -> CHADConfig
chcSetAccum cfg =
  cfg
    { chcArgArrayAccum = True
    , chcCaseArrayAccum = True
    , chcLetArrayAccum = True
    }
------------------------------------ LEMMAS ------------------------------------
-- | Evidence that 'D1' leaves an @n@-fold tuple of 'TIx' unchanged.
--
-- Proved by induction on the singleton @n@:
--
-- * The zero case needs no further evidence.
-- * The successor case first obtains the proof for the predecessor, which
--   brings the needed equality into scope for the outer 'Refl'.
indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx))
indexTupD1Id SZ = Refl
indexTupD1Id (SS m) = case indexTupD1Id m of
  Refl -> Refl
|