1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Language.AST where
import Data.Kind (Type)
import Data.Type.Equality
import GHC.OverloadedLabels
import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(Text))
import Array
import AST
import CHAD.Types
import Data
-- | Expressions with named variables (the surface syntax).
--
-- @env@ lists the variables in scope as (name, type) pairs. Binders cons
-- their variable onto the front of @env@, so the innermost binding is at the
-- head of the list. @t@ is the type of the expression.
--
-- 'fromNamedExpr' converts these to De Bruijn-indexed 'Ex' terms.
type NExpr :: [(Symbol, Ty)] -> Ty -> Type
data NExpr env t where
  -- lambda calculus
  -- | A variable reference. 'Lookup' requires the innermost binding of
  -- @name@ in @env@ to have type @t@.
  NEVar :: Lookup name env ~ t => Var name t -> NExpr env t
  -- | Let binding: the variable is in scope in the body only.
  NELet :: Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t

  -- environment management
  -- | The subexpression lives in @env@ with its @i@-th entry removed
  -- (0 = head, i.e. innermost; see 'DropNth').
  NEDrop :: SNat i -> NExpr (DropNth i env) t -> NExpr env t

  -- base types
  NEPair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b)
  NEFst :: NExpr env (TPair a b) -> NExpr env a
  NESnd :: NExpr env (TPair a b) -> NExpr env b
  NENil :: NExpr env TNil
  NEInl :: STy b -> NExpr env a -> NExpr env (TEither a b)
  NEInr :: STy a -> NExpr env b -> NExpr env (TEither a b)
  -- | Case analysis on a sum; each branch binds its payload under its own
  -- variable.
  NECase :: NExpr env (TEither a b) -> Var name1 a -> NExpr ('(name1, a) : env) c -> Var name2 b -> NExpr ('(name2, b) : env) c -> NExpr env c

  -- array operations
  NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t))
  -- | Rank-@n@ array construction. The first expression is an index tuple
  -- (presumably the shape — confirm against 'EBuild'); the body sees an
  -- index tuple bound to the given variable.
  NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t)
  -- | Fold that lowers the array rank by one. The combining function binds
  -- two variables of the element type; the second variable is innermost.
  NEFold1Inner :: Var name1 t -> Var name2 t -> NExpr ('(name2, t) : '(name1, t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
  NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
  NEUnit :: NExpr env t -> NExpr env (TArr Z t)
  NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t)
  NEMaximum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
  NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))

  -- expression operations
  NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t)
  NEIdx0 :: NExpr env (TArr Z t) -> NExpr env t
  NEIdx1 :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t)
  NEIdx :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t
  NEShape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx))
  NEOp :: SOp a t -> NExpr env a -> NExpr env t

  -- custom derivatives
  -- | An operation with user-supplied derivative code. The three function
  -- bodies are closed: each environment contains only that body's own two
  -- parameters. The last two arguments are the operands, of types @a@ and
  -- @b@, in the surrounding environment.
  NECustom :: Var n1 a -> Var n2 b -> NExpr ['(n2, b), '(n1, a)] t  -- ^ regular operation
           -> Var nf1 (D1 a) -> Var nf2 (D1 b) -> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape)  -- ^ CHAD forward pass
           -> Var nr1 tape -> Var nr2 (D2 t) -> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b)  -- ^ CHAD reverse derivative
           -> NExpr env a -> NExpr env b
           -> NExpr env t

  -- partiality
  -- | An error term carrying a message (converted to 'EError').
  NEError :: STy a -> String -> NExpr env a

  -- embedded unnamed expressions
  -- | Embeds a De Bruijn term. The list supplies one named expression per
  -- free variable of the term, ordered like @unenv@; 'fromNamedExpr'
  -- let-binds them around the embedded term.
  NEUnnamed :: Ex unenv t -> SList (NExpr env) unenv -> NExpr env t
deriving instance Show (NExpr env t)
-- | Type of the innermost (first-in-list) binding of @name@ in @env@.
--
-- Two cases are reported as custom type errors rather than left stuck:
--
-- * using the reserved name @"_"@;
-- * a name with no binding at all, i.e. the search reaches @'[]@.
--
-- Without the second equation the family would not reduce. GHC would then
-- report an opaque mismatch mentioning @Lookup name '[]@ instead.
type family Lookup name env where
  Lookup "_" _ = TypeError (Text "Attempt to use variable with name '_'")
  Lookup _ '[] = TypeError (Text "Variable not in scope in named expression")
  Lookup name ('(name, t) : env) = t
  Lookup name (_ : env) = Lookup name env
-- | The list @xs@ with its element at zero-based position @k@ removed, where
-- @k@ is a type-level 'Z' / 'S' number. The family does not reduce when @k@
-- is past the end of the list.
type family DropNth k xs where
  DropNth Z (_ : rest) = rest
  DropNth (S k) (x : rest) = x : DropNth k rest
-- | A named, typed variable: a singleton for its type-level name paired
-- with a singleton for its type.
data Var name t where
  Var :: SSymbol name -> STy t -> Var name t
  deriving (Show)
-- | Arithmetic on numeric scalar expressions, built from primitive
-- operations. There are no primitives for 'abs' and 'signum', so they error.
-- Subtraction uses the class default, @x + negate y@.
instance (t ~ TScal st, ScalIsNumeric st ~ True, KnownScalTy st, Num (ScalRep st)) => Num (NExpr env t) where
  (+) x y = NEOp (OAdd knownScalTy) $ NEPair x y
  (*) x y = NEOp (OMul knownScalTy) $ NEPair x y
  negate = NEOp (ONeg knownScalTy)
  abs = error "abs undefined for NExpr"
  signum = error "signum undefined for NExpr"
  -- Integer literals become constants. 'scalRepIsShow' provides the Show
  -- dictionary that 'NEConst' needs.
  fromInteger n = case scalRepIsShow sty of
      Dict -> NEConst sty (fromInteger n)
    where
      sty = knownScalTy
-- | Reciprocal via its primitive operation. Division uses the class default,
-- @x * recip y@. Rational literals become scalar constants.
instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Fractional (ScalRep st))
      => Fractional (NExpr env t) where
  recip = NEOp (ORecip knownScalTy)
  fromRational r = case scalRepIsShow sty of
      Dict -> NEConst sty (fromRational r)
    where
      sty = knownScalTy
-- | Transcendental functions on floating-point scalar expressions.
--
-- Only 'pi', 'exp' and 'log' are supported directly. The class defaults
-- then derive @(**)@, 'logBase' and 'sqrt' from 'exp' and 'log'. The
-- trigonometric and hyperbolic functions have no primitive here; they raise
-- an error naming the missing function.
instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Floating (ScalRep st))
      => Floating (NExpr env t) where
  pi =
    let ty = knownScalTy
    in case scalRepIsShow ty of
         Dict -> NEConst ty pi
  exp = NEOp (OExp knownScalTy)
  -- Must be the logarithm primitive. The default (**), logBase and sqrt are
  -- all defined in terms of 'log'.
  log = NEOp (OLog knownScalTy)
  sin = error "sin undefined for NExpr"
  cos = error "cos undefined for NExpr"
  tan = error "tan undefined for NExpr"
  asin = error "asin undefined for NExpr"
  acos = error "acos undefined for NExpr"
  atan = error "atan undefined for NExpr"
  sinh = error "sinh undefined for NExpr"
  cosh = error "cosh undefined for NExpr"
  asinh = error "asinh undefined for NExpr"
  acosh = error "acosh undefined for NExpr"
  atanh = error "atanh undefined for NExpr"
-- | @#foo@ at type @Var "foo" t@. The name singleton comes from the label
-- and the type singleton from 'KnownTy'.
instance (KnownTy t, KnownSymbol name, name ~ n') => IsLabel name (Var n' t) where
  fromLabel = Var nameSing tySing
    where
      nameSing = symbolSing
      tySing = knownTy
-- | @#foo@ at type @NExpr env t@ references the variable @foo@. 'Lookup'
-- checks that the innermost binding of @foo@ in @env@ has type @t@.
instance (KnownTy t, KnownSymbol name, Lookup name env ~ t) => IsLabel name (NExpr env t) where
  -- The annotation pins the symbol to the label's name; nothing else would.
  fromLabel = NEVar (Var (symbolSing :: SSymbol name) knownTy)
-- | The named variables in scope, as a snoc-list. 'NPush' adds a variable
-- on the right, and that variable becomes the head (innermost entry) of the
-- type-level @env@. In other words, the innermost variable sits on the
-- outside of the term, on the right.
data NEnv env where
  NTop :: NEnv '[]
  NPush :: NEnv env -> Var name t -> NEnv ('(name, t) : env)
-- | A function in named form: a chain of 'NLam' parameter binders ending in
-- an 'NBody'. The first (outermost) parameter is on the outside, on the
-- left.
--
-- * @env@: environment of this function (grows as you go deeper inside
--   lambdas).
-- * @env'@: environment of the body of the function.
-- * The parameters are the entries @env'@ adds to @env@. Each 'NLam' conses
--   its binder onto the environment, so the last parameter ends up at the
--   head of @env'@.
data NFun env env' t where
  NLam :: Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t
  NBody :: NExpr env' t -> NFun env' env' t
-- | Drops the names from a named environment, keeping the types in order.
type family UnName env where
  UnName '[] = '[]
  UnName ('(_, ty) : rest) = ty : UnName rest
-- | Type singletons of a named environment, innermost variable first.
envFromNEnv :: NEnv env -> SList STy (UnName env)
envFromNEnv = \case
  NTop -> SNil
  NPush rest (Var _ ty) -> SCons ty (envFromNEnv rest)
-- | Applies a closed named function to arguments written in any named
-- environment @env@. The arguments are ordered like @UnName envB@, so the
-- innermost (last) parameter comes first.
inlineNFun :: NFun '[] envB t -> SList (NExpr env) (UnName envB) -> NExpr env t
inlineNFun fun = NEUnnamed (fromNamed fun)
-- | Converts a closed named function to a De Bruijn expression. The
-- resulting environment consists of exactly the function's parameters.
fromNamed :: NFun '[] env t -> Ex (UnName env) t
fromNamed fun = fromNamedFun NTop fun
-- | Converts a named function to a De Bruijn expression.
--
-- @env@ holds the parameters already moved into the environment; the
-- remaining ones are still 'NLam' binders. Each 'NLam' pushes its binder
-- onto the environment. At the 'NBody', the body is converted with
-- 'fromNamedExpr':
--
-- > [] `fromNamedFun` λx y z. E
-- >   = []:x `fromNamedFun` λy z. E
-- >   = []:x:y `fromNamedFun` λz. E
-- >   = []:x:y:z `fromNamedFun` λ. E
-- >   = []:x:y:z `fromNamedExpr` E
fromNamedFun :: NEnv env -> NFun env env' t -> Ex (UnName env') t
fromNamedFun env = \case
  NBody body -> fromNamedExpr env body
  NLam var rest -> fromNamedFun (NPush env var) rest
-- | Converts a named expression to a De Bruijn-indexed 'Ex', given the named
-- environment that binds its free variables.
--
-- A variable is resolved by searching the environment from the innermost
-- binding outwards for an entry with the same name and type. An unresolved
-- variable is a runtime 'error'.
--
-- Binders ('NELet', 'NECase', 'NEBuild', 'NEFold1Inner') extend the
-- environment for their bodies. The three 'NECustom' subterms are instead
-- converted in fresh environments holding only their own two parameters.
fromNamedExpr :: forall env t. NEnv env -> NExpr env t -> Ex (UnName env) t
fromNamedExpr val = \case
    NEVar var@(Var _ ty)
      | Just idx <- find var val -> EVar ext ty idx
      | otherwise -> error "Variable out of scope in conversion from surface \
                           \expression to De Bruijn expression"
    NELet n a b -> ELet ext (go a) (lambda val n b)
    -- Convert the body in the environment with entry i removed, then weaken
    -- the result back into the full environment.
    NEDrop i e -> weakenExpr (dropNthW i val) (fromNamedExpr (dropNth i val) e)
    NEPair a b -> EPair ext (go a) (go b)
    NEFst e -> EFst ext (go e)
    NESnd e -> ESnd ext (go e)
    NENil -> ENil ext
    NEInl t e -> EInl ext t (go e)
    NEInr t e -> EInr ext t (go e)
    NECase e n1 a n2 b -> ECase ext (go e) (lambda val n1 a) (lambda val n2 b)
    NEConstArr n t x -> EConstArr ext n t x
    NEBuild k a n b -> EBuild ext k (go a) (lambda val n b)
    NEFold1Inner n1 n2 a b c -> EFold1Inner ext (lambda2 val n1 n2 a) (go b) (go c)
    NESum1Inner e -> ESum1Inner ext (go e)
    NEUnit e -> EUnit ext (go e)
    NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b)
    NEMaximum1Inner e -> EMaximum1Inner ext (go e)
    NEMinimum1Inner e -> EMinimum1Inner ext (go e)
    NEConst t x -> EConst ext t x
    NEIdx0 e -> EIdx0 ext (go e)
    NEIdx1 a b -> EIdx1 ext (go a) (go b)
    NEIdx a b -> EIdx ext (go a) (go b)
    NEShape e -> EShape ext (go e)
    NEOp op e -> EOp ext op (go e)
    -- The three custom-derivative bodies are closed: each is converted in an
    -- environment holding only its own two parameters.
    NECustom n1@(Var _ ta) n2@(Var _ tb) a nf1 nf2 b nr1@(Var _ ttape) nr2 c e1 e2 ->
      ECustom ext ta tb ttape
        (fromNamedExpr (NTop `NPush` n1 `NPush` n2) a)
        (fromNamedExpr (NTop `NPush` nf1 `NPush` nf2) b)
        (fromNamedExpr (NTop `NPush` nr1 `NPush` nr2) c)
        (go e1) (go e2)
    NEError t s -> EError ext t s
    -- First move the embedded term into Append unenv (UnName env) (the type
    -- injectWrapLet expects). Then let-bind each argument expression.
    NEUnnamed e args -> injectWrapLet (weakenExpr (wRaiseAbove args (envFromNEnv val)) e) args
  where
    -- Recursive conversion in the current environment.
    go :: NExpr env t' -> Ex (UnName env) t'
    go = fromNamedExpr val

    -- De Bruijn index of the innermost entry whose name AND type both
    -- match. Entries with the same name but a different type are skipped.
    -- Returns Nothing if no entry matches.
    find :: Var name t' -> NEnv env' -> Maybe (Idx (UnName env') t')
    find _ NTop = Nothing
    find var@(Var s ty) (val' `NPush` Var s' ty')
      | Just Refl <- testEquality s s'
      , Just Refl <- testEquality ty ty'
      = Just IZ
      | otherwise
      = IS <$> find var val'

    -- Converts a body under one extra binder.
    lambda :: NEnv env' -> Var name a -> NExpr ('(name, a) : env') b -> Ex (a : UnName env') b
    lambda val' var e = fromNamedExpr (val' `NPush` var) e

    -- Converts a body under two extra binders; the second is innermost.
    lambda2 :: NEnv env' -> Var name1 a -> Var name2 b -> NExpr ('(name2, b) : '(name1, a) : env') c -> Ex (b : a : UnName env') c
    lambda2 val' var1 var2 e = fromNamedExpr (val' `NPush` var1 `NPush` var2) e

    -- Discharges the embedded term's free variables one at a time. The head
    -- argument is let-bound innermost; its right-hand side is converted in
    -- the outer environment and weakened with wSinks over the bindings of
    -- the remaining arguments. NOTE: env and t here are the scoped type
    -- variables of fromNamedExpr, not fresh ones.
    injectWrapLet :: Ex (Append unenv (UnName env)) t -> SList (NExpr env) unenv -> Ex (UnName env) t
    injectWrapLet e SNil = e
    injectWrapLet e (arg `SCons` args) =
      injectWrapLet (ELet ext (weakenExpr (wSinks args) $ fromNamedExpr val arg) e)
                    args
-- | Removes the entry at zero-based position @i@ from a named environment,
-- counting from the innermost binding. Calls 'error' if @i@ runs past the
-- end of the environment.
dropNth :: SNat i -> NEnv env -> NEnv (DropNth i env)
dropNth SZ = \case
  NPush rest _ -> rest
  NTop -> error "DropNth: index out of range"
dropNth (SS j) = \case
  NPush rest var -> dropNth j rest `NPush` var
  NTop -> error "DropNth: index out of range"
-- | Weakening from the environment with entry @i@ removed back into the
-- full environment. It is built from one 'WCopy' per entry in front of
-- position @i@ and a 'WSink' at position @i@. Calls 'error' if @i@ runs past
-- the end of the environment.
dropNthW :: SNat i -> NEnv env -> UnName (DropNth i env) :> UnName env
dropNthW SZ = \case
  NPush _ _ -> WSink
  NTop -> error "DropNth: index out of range"
dropNthW (SS j) = \case
  NPush rest _ -> WCopy (dropNthW j rest)
  NTop -> error "DropNth: index out of range"
|