summaryrefslogtreecommitdiff
path: root/src/CHAD/Types.hs
blob: 44ac20e85b751933062663b103e32d572804c38d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module CHAD.Types where

import AST.Accum
import AST.Types
import Data


-- | The D1 image of a type. Every type constructor is mapped onto itself with
-- 'D1' pushed into its components, so for accumulator-free types @D1 t@ is
-- just @t@ (see 'd1Identity'). There is no equation for accumulator types.
type family D1 ty where
  D1 TNil = TNil
  D1 (TPair l r) = TPair (D1 l) (D1 r)
  D1 (TEither l r) = TEither (D1 l) (D1 r)
  D1 (TLEither l r) = TLEither (D1 l) (D1 r)
  D1 (TMaybe inner) = TMaybe (D1 inner)
  D1 (TArr n elt) = TArr n (D1 elt)
  D1 (TScal s) = TScal s

-- | The D2 image of a type. Products, 'TMaybe' and arrays are mapped
-- structurally. Both sum forms ('TEither' and 'TLEither') become a
-- 'TLEither'. Scalars are delegated to 'D2s'. There is no equation for
-- accumulator types.
type family D2 ty where
  D2 TNil = TNil
  D2 (TPair l r) = TPair (D2 l) (D2 r)
  D2 (TEither l r) = TLEither (D2 l) (D2 r)
  D2 (TLEither l r) = TLEither (D2 l) (D2 r)
  D2 (TMaybe inner) = TMaybe (D2 inner)
  D2 (TArr n elt) = TArr n (D2 elt)
  D2 (TScal s) = D2s s

-- | The D2 image of a scalar type. Floating-point scalars map to themselves
-- (wrapped in 'TScal'). Integral and boolean scalars collapse to 'TNil'.
-- Every left-hand side is a distinct concrete scalar, so equation order does
-- not matter.
type family D2s s where
  D2s TF32 = TScal TF32
  D2s TF64 = TScal TF64
  D2s TI32 = TNil
  D2s TI64 = TNil
  D2s TBool = TNil

-- | Apply 'D1' to every entry of a type-level environment.
type family D1E env where
  D1E '[] = '[]
  D1E (x ': xs) = D1 x ': D1E xs

-- | Apply 'D2' to every entry of a type-level environment.
type family D2E env where
  D2E '[] = '[]
  D2E (x ': xs) = D2 x ': D2E xs

-- | Replace every entry @x@ of a type-level environment by an accumulator
-- over its D2 image, i.e. @TAccum (D2 x)@.
type family D2AcE env where
  D2AcE '[] = '[]
  D2AcE (x ': xs) = TAccum (D2 x) ': D2AcE xs

-- | Singleton-level counterpart of the 'D1' type family. Each constructor is
-- rebuilt with 'd1' applied to its components. Accumulator types have no
-- 'D1' equation and are rejected with an error.
d1 :: STy t -> STy (D1 t)
d1 = \case
  STNil -> STNil
  STPair l r -> STPair (d1 l) (d1 r)
  STEither l r -> STEither (d1 l) (d1 r)
  STLEither l r -> STLEither (d1 l) (d1 r)
  STMaybe inner -> STMaybe (d1 inner)
  STArr n elt -> STArr n (d1 elt)
  STScal s -> STScal s
  STAccum{} -> error "Accumulators not allowed in input program"

-- | Singleton-level 'D1E': apply 'd1' to each type of an environment,
-- preserving order.
d1e :: SList STy env -> SList STy (D1E env)
d1e = \case
  SNil -> SNil
  SCons t rest -> SCons (d1 t) (d1e rest)

-- | Build the 'SMTy' singleton for @D2 t@, mirroring the 'D2' type family
-- equation by equation. Both sum forms become 'SMTLEither'. Scalars are
-- handled by 'scalarM', which follows 'D2s'. Accumulator types have no 'D2'
-- equation and are rejected with an error.
d2M :: STy t -> SMTy (D2 t)
d2M = \case
    STNil -> SMTNil
    STPair l r -> SMTPair (d2M l) (d2M r)
    STEither l r -> SMTLEither (d2M l) (d2M r)
    STLEither l r -> SMTLEither (d2M l) (d2M r)
    STMaybe inner -> SMTMaybe (d2M inner)
    STArr n elt -> SMTArr n (d2M elt)
    STScal s -> scalarM s
    STAccum{} -> error "Accumulators not allowed in input program"
  where
    -- Integral and boolean scalars collapse to SMTNil; floats keep their type.
    scalarM :: SScalTy s -> SMTy (D2s s)
    scalarM STI32 = SMTNil
    scalarM STI64 = SMTNil
    scalarM STF32 = SMTScal STF32
    scalarM STF64 = SMTScal STF64
    scalarM STBool = SMTNil

-- | The 'STy' singleton for @D2 t@, obtained by erasing the 'SMTy' produced
-- by 'd2M' with 'fromSMTy'.
d2 :: STy t -> STy (D2 t)
d2 t = fromSMTy (d2M t)

-- | Singleton-level 'D2E' as 'SMTy' values: apply 'd2M' to each type of an
-- environment, preserving order.
d2eM :: SList STy env -> SList SMTy (D2E env)
d2eM = \case
  SNil -> SNil
  SCons t rest -> SCons (d2M t) (d2eM rest)

-- | Singleton-level 'D2E' as 'STy' values: compute 'd2eM' and erase each
-- entry with 'fromSMTy'.
d2e :: SList STy env -> SList STy (D2E env)
d2e env = slistMap fromSMTy (d2eM env)

-- | Singleton-level 'D2AcE': turn each type @t@ of an environment into
-- @STAccum (d2M t)@, preserving order.
d2ace :: SList STy env -> SList STy (D2AcE env)
d2ace = \case
  SNil -> SNil
  SCons t rest -> SCons (STAccum (d2M t)) (d2ace rest)


-- | Configuration flags for CHAD. Each flag independently switches on one
-- accumulator-related treatment of array-typed variables.
data CHADConfig = CHADConfig
  { chcLetArrayAccum :: Bool
    -- ^ D[let] will bind variables containing arrays in accumulator mode.
  , chcCaseArrayAccum :: Bool
    -- ^ D[case] will bind variables containing arrays in accumulator mode.
  , chcArgArrayAccum :: Bool
    -- ^ Introduce top-level arguments containing arrays in accumulator mode.
  , chcSmartWith :: Bool
    -- ^ Place with-blocks around array variable scopes, and redirect
    -- accumulations there.
  }
  deriving Show

-- | The baseline configuration, with every flag disabled.
defaultConfig :: CHADConfig
defaultConfig =
  CHADConfig
    { chcSmartWith = False
    , chcArgArrayAccum = False
    , chcCaseArrayAccum = False
    , chcLetArrayAccum = False
    }

-- | Turn on all four accumulator-related flags of the given configuration.
-- Every other field of the input is left as it was.
chcSetAccum :: CHADConfig -> CHADConfig
chcSetAccum cfg =
  cfg
    { chcSmartWith = True
    , chcArgArrayAccum = True
    , chcCaseArrayAccum = True
    , chcLetArrayAccum = True
    }


------------------------------------ LEMMAS ------------------------------------

-- | 'D1' leaves a tuple of @n@ 'TIx' components unchanged. The proof is by
-- induction on @n@: the successor case reuses the proof for the
-- predecessor.
indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx))
indexTupD1Id = \case
  SZ -> Refl
  SS m -> case indexTupD1Id m of
    Refl -> Refl

-- | For every scalar type, @ZeroInfo (D2s t)@ is 'TNil'. The proof is by case
-- analysis on the scalar singleton.
lemZeroInfoScal :: SScalTy t -> ZeroInfo (D2s t) :~: TNil
lemZeroInfoScal = \case
  STI32 -> Refl
  STI64 -> Refl
  STF32 -> Refl
  STF64 -> Refl
  STBool -> Refl

-- | For every scalar type, @DeepZeroInfo (D2s t)@ is 'TNil'. The proof is by
-- case analysis on the scalar singleton.
lemDeepZeroInfoScal :: SScalTy t -> DeepZeroInfo (D2s t) :~: TNil
lemDeepZeroInfoScal = \case
  STI32 -> Refl
  STI64 -> Refl
  STF32 -> Refl
  STF64 -> Refl
  STBool -> Refl

-- | 'D1' is the identity on every accumulator-free type. The proof is by
-- structural induction over the type singleton. Accumulator types are
-- rejected with an error, as in 'd1'.
d1Identity :: STy t -> D1 t :~: t
d1Identity STNil = Refl
d1Identity (STPair l r)
  | Refl <- d1Identity l
  , Refl <- d1Identity r
  = Refl
d1Identity (STEither l r)
  | Refl <- d1Identity l
  , Refl <- d1Identity r
  = Refl
d1Identity (STLEither l r)
  | Refl <- d1Identity l
  , Refl <- d1Identity r
  = Refl
d1Identity (STMaybe inner)
  | Refl <- d1Identity inner
  = Refl
d1Identity (STArr _ elt)
  | Refl <- d1Identity elt
  = Refl
d1Identity (STScal _) = Refl
d1Identity STAccum{} = error "Accumulators not allowed in input program"

-- | 'D1E' is the identity on an environment of accumulator-free types. Each
-- entry is discharged with 'd1Identity' before moving on to the rest of the
-- list.
d1eIdentity :: SList STy env -> D1E env :~: env
d1eIdentity = \case
  SNil -> Refl
  SCons t rest -> case d1Identity t of
    Refl -> case d1eIdentity rest of
      Refl -> Refl