1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
-- | Surface combinators for writing named expressions ('NExpr') and named
-- functions ('NFun').
--
-- Most definitions in this module are thin wrappers around constructors such
-- as 'NELet', 'NEPair' and 'NEOp', giving them lower-case names and, where a
-- variable is bound, taking the binder and its scope as a single @var ':->' body@
-- argument.
module Language (
fromNamed,
NExpr,
Ex,
module Language,
module AST.Types,
module Data,
Lookup,
) where
import Array
import AST
import AST.Types
import CHAD.Types
import Data
import Language.AST
-- | A binder together with the thing it scopes over, written @var :-> body@.
--
-- Several combinators in this module (for example 'case_', 'build' and
-- 'fold1i') take arguments of this shape, with a 'Var' on the left.
--
-- The operator is right-associative with the lowest precedence, so
-- @v1 :-> v2 :-> e@ reads as @v1 :-> (v2 :-> e)@.
data a :-> b = a :-> b
  deriving (Show)
infixr 0 :->
-- | Turn an expression into a parameterless function body ('NBody').
--
-- Both environment indices of the resulting 'NFun' are the environment of the
-- expression itself.
body :: NExpr env t -> NFun env env t
body expr = NBody expr
-- | Add one named parameter to a function under construction ('NLam').
--
-- The parameter type @a@ comes first in the @forall@, so it can be supplied
-- with a type application, e.g. @lambda \@(TScal TF64) #x@.
lambda
  :: forall a name env env' t
   . Var name a
  -> NFun ('(name, a) : env) env' t
  -> NFun env env' t
lambda var fun = NLam var fun
-- | Apply a function whose outer environment is empty to a list of argument
-- expressions, producing an expression. Delegates to 'inlineNFun'.
--
-- The argument list can be assembled with '.$'.
inline
  :: NFun '[] params t
  -> SList (NExpr env) (UnName params)
  -> NExpr env t
inline fun args = inlineNFun fun args
-- | Append one element to an 'SList'; this is 'SCons' with its arguments
-- swapped. Intended for building the argument list passed to 'inline'.
--
-- > let fun = lambda @(TScal TF64) #x $ lambda @(TScal TF64) #y $ body $ #x + #y
-- > in inline fun (SNil .$ 16 .$ 26)
(.$) :: SList f list -> f a -> SList f (a : list)
rest .$ item = SCons item rest
-- | Let-binding ('NELet'): the body sees the environment extended with the
-- given variable, whose type is that of the bound expression.
let_
  :: Var name a
  -> NExpr env a
  -> NExpr ('(name, a) : env) t
  -> NExpr env t
let_ var rhs scope = NELet var rhs scope
-- | Build a 'TPair' from two expressions ('NEPair').
pair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b)
pair l r = NEPair l r
-- | First component of a 'TPair' ('NEFst').
fst_ :: NExpr env (TPair a b) -> NExpr env a
fst_ p = NEFst p
-- | Second component of a 'TPair' ('NESnd').
snd_ :: NExpr env (TPair a b) -> NExpr env b
snd_ p = NESnd p
-- | The expression of type 'TNil' ('NENil'). Used in this module as the
-- innermost component of index and shape tuples (see 'build1' and 'build2').
nil :: NExpr env TNil
nil = NENil
-- | Left injection into 'TEither' ('NEInl'). The type of the right-hand
-- alternative is passed along via 'knownTy'.
inl :: KnownTy b => NExpr env a -> NExpr env (TEither a b)
inl e = NEInl knownTy e
-- | Right injection into 'TEither' ('NEInr'). The type of the left-hand
-- alternative is passed along via 'knownTy'.
inr :: KnownTy a => NExpr env b -> NExpr env (TEither a b)
inr e = NEInr knownTy e
-- | Case analysis on a 'TEither' ('NECase').
--
-- The second argument handles the left alternative and the third the right
-- alternative; each is given as @var ':->' branch@, where the branch sees the
-- environment extended with @var@.
case_
  :: NExpr env (TEither a b)
  -> (Var name1 a :-> NExpr ('(name1, a) : env) c)
  -> (Var name2 b :-> NExpr ('(name2, b) : env) c)
  -> NExpr env c
case_ scrut (lvar :-> lbranch) (rvar :-> rbranch) =
  NECase scrut lvar lbranch rvar rbranch
-- | Embed a constant array of scalars as an expression ('NEConstArr').
--
-- The rank and scalar type singletons are obtained from 'knownNat' and
-- 'knownScalTy'. Matching on the 'Dict' returned by 'scalRepIsShow' brings
-- into scope the instance that 'NEConstArr' requires (presumably
-- @Show (ScalRep t)@, going by the name -- confirm against its definition).
constArr_ :: forall t n env. (KnownNat n, KnownScalTy t) => Array n (ScalRep t) -> NExpr env (TArr n (TScal t))
constArr_ x =
  let ty = knownScalTy
  in case scalRepIsShow ty of
       Dict -> NEConstArr knownNat ty x
-- | Rank-1 specialisation of 'NEBuild' taking a single 'TIx' size and an
-- element expression over a single 'TIx' variable.
--
-- The size tuple handed to 'NEBuild' is @pair nil size@. The index tuple is
-- bound to @#idx@, and the caller's variable is let-bound to @snd_ #idx@.
-- The element expression is weakened with @NEDrop (SS SZ)@ because @#idx@ now
-- sits directly below the caller's variable in the environment.
build1
  :: NExpr env TIx
  -> (Var name TIx :-> NExpr ('(name, TIx) : env) t)
  -> NExpr env (TArr (S Z) t)
build1 size (ixVar :-> elt) =
  NEBuild (SS SZ) (pair nil size) #idx $
    let_ ixVar (snd_ #idx) $
      NEDrop (SS SZ) elt
-- | Rank-2 specialisation of 'NEBuild' taking two 'TIx' sizes and an element
-- expression over two 'TIx' variables.
--
-- The size tuple handed to 'NEBuild' is @pair (pair nil a1) a2@ and the index
-- tuple is bound to @#idx@. The first variable is bound to
-- @snd_ (fst_ #idx)@ and the second to @snd_ #idx@.
build2 :: NExpr env TIx -> NExpr env TIx
       -> (Var name1 TIx :-> Var name2 TIx :-> NExpr ('(name2, TIx) : '(name1, TIx) : env) t)
       -> NExpr env (TArr (S (S Z)) t)
build2 a1 a2 (v1 :-> v2 :-> b) =
  NEBuild (SS (SS SZ))
          (pair (pair nil a1) a2)
          #idx
          (let_ v1 (snd_ (fst_ #idx)) $
           -- @snd_ #idx@ is written for the environment without @v1@, so it
           -- is weakened past @v1@ with @NEDrop SZ@.
           let_ v2 (NEDrop SZ (snd_ #idx)) $
             -- The body expects @v2 : v1 : env@; @#idx@ sits at position 2.
             NEDrop (SS (SS SZ)) b)
-- | General array construction ('NEBuild') at the rank given by the 'SNat'.
--
-- Takes a tuple of @n@ 'TIx' values and an element expression in which the
-- bound variable holds a tuple of the same type.
build
  :: SNat n
  -> NExpr env (Tup (Replicate n TIx))
  -> (Var name (Tup (Replicate n TIx)) :-> NExpr ('(name, Tup (Replicate n TIx)) : env) t)
  -> NExpr env (TArr n t)
build rank size (ixVar :-> elt) = NEBuild rank size ixVar elt
-- | Apply an element expression to every element of an array, expressed in
-- terms of 'build'.
--
-- The array is let-bound to @#arg@. The result is built from @shape #arg@,
-- and for each index @#i@ the caller's variable is bound to @#arg ! #i@.
--
-- Matching on the 'Dict' from 'styKnown' brings a 'KnownTy' instance for the
-- index tuple type into scope (presumably needed by the @#i@ label -- confirm
-- against the label instances).
map_ :: forall n a b env name. (KnownNat n, KnownTy a)
     => (Var name a :-> NExpr ('(name, a) : env) b)
     -> NExpr env (TArr n a) -> NExpr env (TArr n b)
map_ (v :-> a) b
  | Dict <- styKnown (tTup (sreplicate (knownNat @n) tIx)) =
      let_ #arg b $
      build knownNat (shape #arg) $ #i :->
        let_ v (#arg ! #i) $
          -- The body expects @v : env@; weaken it past @#i@ and @#arg@,
          -- which now sit between @v@ and @env@.
          NEDrop (SS SZ) (NEDrop (SS SZ) a)
-- | Wrapper around 'NEFold1Inner'.
--
-- Takes a combining expression over two variables of the element type, a
-- further expression of the element type (presumably the initial value of the
-- fold -- confirm against 'NEFold1Inner'), and an array of rank @S n@. The
-- result has rank @n@.
fold1i
  :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t)
  -> NExpr env t
  -> NExpr env (TArr (S n) t)
  -> NExpr env (TArr n t)
fold1i (var1 :-> var2 :-> combine) e0 arr =
  NEFold1Inner var1 var2 combine e0 arr
-- | Wrapper around 'NESum1Inner' for arrays of numeric scalars; the result
-- has one dimension fewer than the argument.
sum1i
  :: ScalIsNumeric t ~ True
  => NExpr env (TArr (S n) (TScal t))
  -> NExpr env (TArr n (TScal t))
sum1i = NESum1Inner
-- | Wrap a value as a rank-zero array ('NEUnit').
unit :: NExpr env t -> NExpr env (TArr Z t)
unit e = NEUnit e
-- | Wrapper around 'NEReplicate1Inner' for arrays of numeric scalars: takes a
-- 'TIx' and an array of rank @n@, and yields an array of rank @S n@.
replicate1i
  :: ScalIsNumeric t ~ True
  => NExpr env TIx
  -> NExpr env (TArr n (TScal t))
  -> NExpr env (TArr (S n) (TScal t))
replicate1i = NEReplicate1Inner
-- | Wrapper around 'NEMaximum1Inner' for arrays of numeric scalars; the
-- result has one dimension fewer than the argument.
maximum1i
  :: ScalIsNumeric t ~ True
  => NExpr env (TArr (S n) (TScal t))
  -> NExpr env (TArr n (TScal t))
maximum1i = NEMaximum1Inner
-- | Wrapper around 'NEMinimum1Inner' for arrays of numeric scalars; the
-- result has one dimension fewer than the argument.
minimum1i
  :: ScalIsNumeric t ~ True
  => NExpr env (TArr (S n) (TScal t))
  -> NExpr env (TArr n (TScal t))
minimum1i = NEMinimum1Inner
-- | Embed a scalar constant as an expression ('NEConst').
--
-- The scalar type singleton is obtained from 'knownScalTy'. Matching on the
-- 'Dict' returned by 'scalRepIsShow' brings into scope the instance that
-- 'NEConst' requires (presumably @Show (ScalRep t)@, going by the name --
-- confirm against its definition).
const_ :: KnownScalTy t => ScalRep t -> NExpr env (TScal t)
const_ x =
  let ty = knownScalTy
  in case scalRepIsShow ty of
       Dict -> NEConst ty x
-- | Extract the value held by a rank-zero array ('NEIdx0').
idx0 :: NExpr env (TArr Z t) -> NExpr env t
idx0 arr = NEIdx0 arr
-- (.!) :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t)
-- (.!) = NEIdx1
-- infixl 9 .!
-- | Index an array of rank @n@ with a tuple of @n@ 'TIx' values ('NEIdx').
(!) :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t
arr ! ix = NEIdx arr ix
infixl 9 !
-- | The shape of an array of rank @n@, as a tuple of @n@ 'TIx' values
-- ('NEShape').
shape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx))
shape arr = NEShape arr
-- | Apply a primitive operation to its argument ('NEOp').
oper :: SOp a t -> NExpr env a -> NExpr env t
oper op arg = NEOp op arg
-- | Apply a primitive operation that takes a 'TPair' to two separate
-- arguments, pairing them up first.
oper2 :: SOp (TPair a b) t -> NExpr env a -> NExpr env b -> NExpr env t
oper2 op lhs rhs = oper op (lhs `pair` rhs)
-- | An expression of any known type that carries the given message
-- ('NEError'). The result type is passed along via 'knownTy'.
error_ :: KnownTy t => String -> NExpr env t
error_ = NEError knownTy
-- | Wrapper around 'NECustom', applied to two argument expressions of types
-- @a@ and @b@.
--
-- It takes three two-variable expressions. Each is closed: its environment
-- holds only its own two variables, not the surrounding @env@.
--
-- 1. From @a@ and @b@ to the result type @t@.
-- 2. From @D1 a@ and @D1 b@ to a pair of @D1 t@ and a value of type @tape@.
-- 3. From @tape@ and @D2 t@ to @D2 b@.
--
-- NOTE(review): presumably the second and third are the forward and reverse
-- derivative passes for the first -- confirm against 'NECustom'.
custom :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t)
       -> (Var nf1 (D1 a) :-> Var nf2 (D1 b) :-> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape))
       -> (Var nr1 tape :-> Var nr2 (D2 t) :-> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b))
       -> NExpr env a -> NExpr env b
       -> NExpr env t
custom (n1 :-> n2 :-> a) (nf1 :-> nf2 :-> b) (nr1 :-> nr2 :-> c) e1 e2 =
  NECustom n1 n2 a nf1 nf2 b nr1 nr2 c e1 e2
-- | Equality on numeric scalars, via the 'OEq' primitive.
(.==)
  :: (KnownScalTy st, ScalIsNumeric st ~ True)
  => NExpr env (TScal st)
  -> NExpr env (TScal st)
  -> NExpr env (TScal TBool)
lhs .== rhs = oper2 (OEq knownScalTy) lhs rhs
infix 4 .==
-- | Strict less-than on numeric scalars, via the 'OLt' primitive.
(.<)
  :: (KnownScalTy st, ScalIsNumeric st ~ True)
  => NExpr env (TScal st)
  -> NExpr env (TScal st)
  -> NExpr env (TScal TBool)
lhs .< rhs = oper2 (OLt knownScalTy) lhs rhs
infix 4 .<
-- | Strict greater-than on numeric scalars: '.<' with the operands swapped.
(.>)
  :: (KnownScalTy st, ScalIsNumeric st ~ True)
  => NExpr env (TScal st)
  -> NExpr env (TScal st)
  -> NExpr env (TScal TBool)
lhs .> rhs = rhs .< lhs
infix 4 .>
-- | Less-than-or-equal on numeric scalars, via the 'OLe' primitive.
(.<=)
  :: (KnownScalTy st, ScalIsNumeric st ~ True)
  => NExpr env (TScal st)
  -> NExpr env (TScal st)
  -> NExpr env (TScal TBool)
lhs .<= rhs = oper2 (OLe knownScalTy) lhs rhs
infix 4 .<=
-- | Greater-than-or-equal on numeric scalars: '.<=' with the operands
-- swapped.
(.>=)
  :: (KnownScalTy st, ScalIsNumeric st ~ True)
  => NExpr env (TScal st)
  -> NExpr env (TScal st)
  -> NExpr env (TScal TBool)
lhs .>= rhs = rhs .<= lhs
infix 4 .>=
-- | Boolean negation, via the 'ONot' primitive.
not_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool)
not_ cond = oper ONot cond
-- | Boolean conjunction, via the 'OAnd' primitive.
and_
  :: NExpr env (TScal TBool)
  -> NExpr env (TScal TBool)
  -> NExpr env (TScal TBool)
and_ lhs rhs = oper2 OAnd lhs rhs
-- | Boolean disjunction, via the 'OOr' primitive.
or_
  :: NExpr env (TScal TBool)
  -> NExpr env (TScal TBool)
  -> NExpr env (TScal TBool)
or_ lhs rhs = oper2 OOr lhs rhs
-- | Conditional. The first alternative is the True case; the second is the
-- False case.
--
-- Expressed as a 'case_' on the result of the 'OIf' primitive. Both branches
-- bind their variable to the label @#_@ and do not use it, so each
-- alternative is weakened past that binder with @NEDrop SZ@.
if_ :: NExpr env (TScal TBool) -> NExpr env t -> NExpr env t -> NExpr env t
if_ cond onTrue onFalse =
  case_ (oper OIf cond)
        (#_ :-> NEDrop SZ onTrue)
        (#_ :-> NEDrop SZ onFalse)
-- | Convert a 'TF64' scalar to a 'TI64' scalar, via the 'ORound64' primitive.
round_ :: NExpr env (TScal TF64) -> NExpr env (TScal TI64)
round_ x = oper ORound64 x
-- | Convert a 'TI64' scalar to a 'TF64' scalar, via the 'OToFl64' primitive.
toFloat_ :: NExpr env (TScal TI64) -> NExpr env (TScal TF64)
toFloat_ x = oper OToFl64 x
-- | Division on integral scalars, via the 'OIDiv' primitive. The scalar type
-- singleton is obtained from 'knownScalTy'.
idiv
  :: (KnownScalTy t, ScalIsIntegral t ~ True)
  => NExpr env (TScal t)
  -> NExpr env (TScal t)
  -> NExpr env (TScal t)
idiv lhs rhs = oper2 (OIDiv knownScalTy) lhs rhs
|