1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
module Language (
fromNamed,
NExpr,
module Language,
Lookup,
) where
import Array
import AST
import CHAD.Types
import Data
import Language.AST
-- | Infix pairing of two values. The combinators in this module pattern-match
-- on it to receive a variable (left operand) together with the expression it
-- scopes over (right operand).
--
-- The operator is right-associative at the lowest precedence, so
-- @v1 :-> v2 :-> e@ parses as @v1 :-> (v2 :-> e)@.
data a :-> b = a :-> b
  deriving (Show)
infixr 0 :->
-- | Wrap an expression as the 'NBody' of an 'NFun' whose two environment
-- indices are both @env@.
body
  :: NExpr env t
  -> NFun env env t
body expr = NBody expr
-- | Add a named parameter to a function using 'NLam'. The wrapped function
-- has the variable @'(name, a)@ at the head of its first environment index.
lambda
  :: forall a name env env' t.
     Var name a
  -> NFun ('(name, a) : env) env' t
  -> NFun env env' t
lambda var fun = NLam var fun
-- | Instantiate a function that starts from the empty environment with a list
-- of argument expressions, delegating to 'inlineNFun'.
inline
  :: NFun '[] params t
  -> SList (NExpr env) (UnName params)
  -> NExpr env t
inline fun args = inlineNFun fun args
-- | 'SCons' with its arguments flipped: @xs .$ x@ is @SCons x xs@.
(.$)
  :: SList f list
  -> f a
  -> SList f (a : list)
xs .$ x = SCons x xs
-- | Let-binding ('NELet'): the third argument is typed in the environment
-- extended with the variable, whose type matches the second argument.
let_
  :: Var name a
  -> NExpr env a
  -> NExpr ('(name, a) : env) t
  -> NExpr env t
let_ var rhs inner = NELet var rhs inner
-- | Combine two expressions into one of type 'TPair' using 'NEPair'.
pair
  :: NExpr env a
  -> NExpr env b
  -> NExpr env (TPair a b)
pair l r = NEPair l r
-- | Project the first component out of a 'TPair' expression ('NEFst').
fst_
  :: NExpr env (TPair a b)
  -> NExpr env a
fst_ p = NEFst p
-- | Project the second component out of a 'TPair' expression ('NESnd').
snd_
  :: NExpr env (TPair a b)
  -> NExpr env b
snd_ p = NESnd p
-- | The expression of type 'TNil', built with the 'NENil' constructor.
nil :: NExpr env TNil
nil = NENil
-- | Left injection into 'TEither' ('NEInl'). The 'STy' argument is the type
-- witness for the /other/ (right) alternative @b@.
inl
  :: STy b
  -> NExpr env a
  -> NExpr env (TEither a b)
inl rightTy e = NEInl rightTy e
-- | Right injection into 'TEither' ('NEInr'). The 'STy' argument is the type
-- witness for the /other/ (left) alternative @a@.
inr
  :: STy a
  -> NExpr env b
  -> NExpr env (TEither a b)
inr leftTy e = NEInr leftTy e
-- | Case analysis on a 'TEither' expression ('NECase'). Each branch is given
-- as @var :-> body@, where the body is typed in the environment extended with
-- that branch's variable; the first branch handles the @a@ alternative and the
-- second the @b@ alternative.
case_
  :: NExpr env (TEither a b)
  -> (Var name1 a :-> NExpr ('(name1, a) : env) c)
  -> (Var name2 b :-> NExpr ('(name2, b) : env) c)
  -> NExpr env c
case_ scrut (lvar :-> lbody) (rvar :-> rbody) =
  NECase scrut lvar lbody rvar rbody
-- | Embed a constant array of scalars as an expression ('NEConstArr').
--
-- The scalar type witness comes from 'knownScalTy' and the rank witness from
-- 'knownNat'. Matching on the 'Dict' returned by 'scalRepIsShow' brings an
-- instance into scope for 'NEConstArr' (presumably @Show (ScalRep t)@, going
-- by the helper's name -- confirm against its definition).
constArr_
  :: (KnownNat n, KnownScalTy t)
  => Array n (ScalRep t)
  -> NExpr env (TArr n (TScal t))
constArr_ x =
  let ty = knownScalTy
  in case scalRepIsShow ty of
       Dict -> NEConstArr knownNat ty x
-- | Rank-1 specialisation of 'NEBuild'.
--
-- The 'TIx' argument is packed into the shape tuple as @pair nil extent@. The
-- index tuple is bound to @#idx@, its 'snd_' component is rebound to the
-- caller's variable, and the caller's body is wrapped in @NEDrop (SS SZ)@ so
-- that it is typed without the @idx@ entry (presumably 'NEDrop' removes the
-- environment entry at that position -- confirm against its definition).
build1
  :: NExpr env TIx
  -> (Var name TIx :-> NExpr ('(name, TIx) : env) t)
  -> NExpr env (TArr (S Z) t)
build1 extent (ix :-> elt) =
  NEBuild (SS SZ) (pair nil extent) #idx $
    let_ ix (snd_ #idx) (NEDrop (SS SZ) elt)
-- | Rank-2 specialisation of 'NEBuild'.
--
-- The two 'TIx' arguments are packed into the shape tuple as
-- @pair (pair nil a1) a2@, and the index tuple is bound to @#idx@. The first
-- caller variable receives @snd_ (fst_ #idx)@ and the second @snd_ #idx@.
--
-- The 'NEDrop' wrappers adjust environments (presumably each removes the
-- entry at the given position -- confirm against its definition):
-- @NEDrop SZ@ lets the second binding's right-hand side be typed without the
-- just-bound @v1@, and @NEDrop (SS (SS SZ))@ lets the caller's body be typed
-- without the @idx@ entry.
build2
  :: NExpr env TIx
  -> NExpr env TIx
  -> (Var name1 TIx :-> Var name2 TIx :-> NExpr ('(name2, TIx) : '(name1, TIx) : env) t)
  -> NExpr env (TArr (S (S Z)) t)
build2 a1 a2 (v1 :-> v2 :-> b) =
  NEBuild (SS (SS SZ))
          (pair (pair nil a1) a2)
          #idx
          (let_ v1 (snd_ (fst_ #idx)) $
           let_ v2 (NEDrop SZ (snd_ #idx)) $
             NEDrop (SS (SS SZ)) b)
-- | General-rank array construction ('NEBuild'). Takes the rank witness, the
-- shape as a tuple of @n@ 'TIx' values, and @var :-> body@ where the variable
-- has the same index-tuple type and the body produces one element.
build
  :: SNat n
  -> NExpr env (Tup (Replicate n TIx))
  -> (Var name (Tup (Replicate n TIx)) :-> NExpr ('(name, Tup (Replicate n TIx)) : env) t)
  -> NExpr env (TArr n t)
build rank shp (ix :-> elt) = NEBuild rank shp ix elt
-- | Wrapper around 'NEFold1Inner', turning a rank @S n@ array into a rank @n@
-- array. The first argument is a two-variable function written as
-- @v1 :-> v2 :-> body@; the second is an expression of the element type
-- (presumably the initial value of the fold -- confirm against
-- 'NEFold1Inner'); the third is the array.
fold1i
  :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t)
  -> NExpr env t
  -> NExpr env (TArr (S n) t)
  -> NExpr env (TArr n t)
fold1i (accVar :-> eltVar :-> combine) start arr =
  NEFold1Inner accVar eltVar combine start arr
-- | Wrapper around 'NESum1Inner' for numeric scalar arrays; the result has
-- one dimension fewer than the argument.
sum1i
  :: ScalIsNumeric t ~ True
  => NExpr env (TArr (S n) (TScal t))
  -> NExpr env (TArr n (TScal t))
sum1i = NESum1Inner
-- | Wrap an expression as a rank-zero array ('NEUnit').
unit
  :: NExpr env t
  -> NExpr env (TArr Z t)
unit e = NEUnit e
-- | Wrapper around 'NEReplicate1Inner' for numeric scalar arrays: takes a
-- 'TIx' expression and a rank @n@ array and yields a rank @S n@ array.
replicate1i
  :: ScalIsNumeric t ~ True
  => NExpr env TIx
  -> NExpr env (TArr n (TScal t))
  -> NExpr env (TArr (S n) (TScal t))
replicate1i = NEReplicate1Inner
-- | Wrapper around 'NEMaximum1Inner' for numeric scalar arrays; the result
-- has one dimension fewer than the argument.
maximum1i
  :: ScalIsNumeric t ~ True
  => NExpr env (TArr (S n) (TScal t))
  -> NExpr env (TArr n (TScal t))
maximum1i = NEMaximum1Inner
-- | Wrapper around 'NEMinimum1Inner' for numeric scalar arrays; the result
-- has one dimension fewer than the argument.
minimum1i
  :: ScalIsNumeric t ~ True
  => NExpr env (TArr (S n) (TScal t))
  -> NExpr env (TArr n (TScal t))
minimum1i = NEMinimum1Inner
-- | Embed a scalar constant as an expression ('NEConst').
--
-- The scalar type witness comes from 'knownScalTy'. Matching on the 'Dict'
-- returned by 'scalRepIsShow' brings an instance into scope for 'NEConst'
-- (presumably @Show (ScalRep t)@, going by the helper's name -- confirm
-- against its definition).
const_
  :: KnownScalTy t
  => ScalRep t
  -> NExpr env (TScal t)
const_ x =
  let ty = knownScalTy
  in case scalRepIsShow ty of
       Dict -> NEConst ty x
-- | Extract the element of a rank-zero array ('NEIdx0').
idx0
  :: NExpr env (TArr Z t)
  -> NExpr env t
idx0 arr = NEIdx0 arr
-- | Index an array with a single 'TIx' ('NEIdx1'), yielding an array of one
-- rank lower.
(.!)
  :: NExpr env (TArr (S n) t)
  -> NExpr env TIx
  -> NExpr env (TArr n t)
arr .! i = NEIdx1 arr i
infixl 9 .!
-- | Index an array with a full tuple of @n@ 'TIx' values ('NEIdx'), yielding
-- a single element.
(!)
  :: NExpr env (TArr n t)
  -> NExpr env (Tup (Replicate n TIx))
  -> NExpr env t
(!) arr ix = NEIdx arr ix
infixl 9 !
-- | The shape of an array as a tuple of @n@ 'TIx' values ('NEShape').
shape
  :: NExpr env (TArr n t)
  -> NExpr env (Tup (Replicate n TIx))
shape arr = NEShape arr
-- | Apply a primitive operation ('SOp') to its argument expression ('NEOp').
oper
  :: SOp a t
  -> NExpr env a
  -> NExpr env t
oper op arg = NEOp op arg
-- | An 'NEError' expression of any known type, carrying the given message.
-- The type witness is supplied by 'knownTy'.
error_
  :: KnownTy t
  => String
  -> NExpr env t
error_ = NEError knownTy
-- | Wrapper around 'NECustom'. Each of the first three arguments is a closed
-- two-variable function written as @v1 :-> v2 :-> body@:
--
-- 1. from @a@ and @b@ to @t@;
-- 2. from @D1 a@ and @D1 b@ to a pair of @D1 t@ and a @tape@;
-- 3. from @tape@ and @D2 t@ to @D2 b@.
--
-- Judging by the types these are presumably the primal computation, its
-- forward pass, and its reverse pass -- confirm against 'NECustom'. The last
-- two arguments are the actual @a@ and @b@ expressions in the ambient
-- environment.
custom
  :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t)
  -> (Var nf1 (D1 a) :-> Var nf2 (D1 b) :-> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape))
  -> (Var nr1 tape :-> Var nr2 (D2 t) :-> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b))
  -> NExpr env a
  -> NExpr env b
  -> NExpr env t
custom (n1 :-> n2 :-> a) (nf1 :-> nf2 :-> b) (nr1 :-> nr2 :-> c) e1 e2 =
  NECustom n1 n2 a nf1 nf2 b nr1 nr2 c e1 e2
-- | Comparison of two numeric scalars via the 'OEq' operation, applied to the
-- pair of operands.
(.==)
  :: (KnownScalTy st, ScalIsNumeric st ~ True)
  => NExpr env (TScal st)
  -> NExpr env (TScal st)
  -> NExpr env (TScal TBool)
lhs .== rhs = oper (OEq knownScalTy) $ pair lhs rhs
infix 4 .==
-- | Comparison of two numeric scalars via the 'OLt' operation, applied to the
-- pair of operands.
(.<)
  :: (KnownScalTy st, ScalIsNumeric st ~ True)
  => NExpr env (TScal st)
  -> NExpr env (TScal st)
  -> NExpr env (TScal TBool)
lhs .< rhs = oper (OLt knownScalTy) $ pair lhs rhs
infix 4 .<
-- | '.<' with the operands swapped.
(.>)
  :: (KnownScalTy st, ScalIsNumeric st ~ True)
  => NExpr env (TScal st)
  -> NExpr env (TScal st)
  -> NExpr env (TScal TBool)
lhs .> rhs = rhs .< lhs
infix 4 .>
-- | Comparison of two numeric scalars via the 'OLe' operation, applied to the
-- pair of operands.
(.<=)
  :: (KnownScalTy st, ScalIsNumeric st ~ True)
  => NExpr env (TScal st)
  -> NExpr env (TScal st)
  -> NExpr env (TScal TBool)
lhs .<= rhs = oper (OLe knownScalTy) $ pair lhs rhs
infix 4 .<=
-- | '.<=' with the operands swapped.
(.>=)
  :: (KnownScalTy st, ScalIsNumeric st ~ True)
  => NExpr env (TScal st)
  -> NExpr env (TScal st)
  -> NExpr env (TScal TBool)
lhs .>= rhs = rhs .<= lhs
infix 4 .>=
-- | Apply the 'ONot' operation to a boolean expression.
not_
  :: NExpr env (TScal TBool)
  -> NExpr env (TScal TBool)
not_ e = oper ONot e
-- | Conditional expression. The first alternative is the True case; the
-- second is the False case.
--
-- Implemented as a 'case_' over the result of the 'OIf' operation. Each
-- branch binds an unused variable @#_@ and wraps the caller's expression in
-- @NEDrop SZ@ so that it is typed without that binding.
if_
  :: NExpr env (TScal TBool)
  -> NExpr env t
  -> NExpr env t
  -> NExpr env t
if_ cond onTrue onFalse =
  case_ (oper OIf cond)
        (#_ :-> NEDrop SZ onTrue)
        (#_ :-> NEDrop SZ onFalse)
|