summaryrefslogtreecommitdiff
path: root/src/Language/AST.hs
blob: 387915b511ef66660791ef89cb962bbc3566e9e3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Language.AST where

import Data.Kind (Type)
import Data.Type.Equality
import GHC.OverloadedLabels
import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..))

import Array
import AST
import CHAD.Types
import Data


-- | Named surface syntax for expressions.
--
-- The index @env@ is the scope: a list of (variable name, type) pairs in which
-- the innermost binding is at the head (binders such as 'NELet' cons onto it).
-- The index @t@ is the type of the expression. 'fromNamedExpr' converts these
-- terms to the De Bruijn representation 'Ex'.
type NExpr :: [(Symbol, Ty)] -> Ty -> Type
data NExpr env t where
  -- lambda calculus
  -- Reference to a variable; 'Lookup' requires the name to be in scope with
  -- type @t@ (the innermost binding of that name wins).
  NEVar :: Lookup name env ~ t => Var name t -> NExpr env t
  -- Let binding: the body sees the bound variable as its innermost entry.
  NELet :: Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t

  -- environment management
  -- The subexpression is typed in @env@ with the entry at position @i@
  -- removed ('Z' = innermost); see 'DropNth'.
  NEDrop :: SNat i -> NExpr (DropNth i env) t -> NExpr env t

  -- base types
  NEPair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b)
  NEFst :: NExpr env (TPair a b) -> NExpr env a
  NESnd :: NExpr env (TPair a b) -> NExpr env b
  NENil :: NExpr env TNil
  -- Injections into a sum; the singleton gives the type of the other summand.
  NEInl :: STy b -> NExpr env a -> NExpr env (TEither a b)
  NEInr :: STy a -> NExpr env b -> NExpr env (TEither a b)
  -- Case analysis on a sum; each branch binds the payload as its innermost
  -- variable.
  NECase :: NExpr env (TEither a b) -> Var name1 a -> NExpr ('(name1, a) : env) c -> Var name2 b -> NExpr ('(name2, b) : env) c -> NExpr env c

  -- array operations
  -- Array literal of rank @n@ with scalar elements.
  NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t))
  -- Array of rank @n@ from a shape (an n-tuple of indices) and a body that
  -- sees an index variable of that same tuple type.
  NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t)
  -- Fold removing one dimension (rank @S n@ to @n@): a combining body over two
  -- variables (@name2@ innermost), an initial value, and the array.
  NEFold1Inner :: Var name1 t -> Var name2 t -> NExpr ('(name2, t) : '(name1, t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
  -- Sum along one dimension of a numeric array (rank @S n@ to @n@).
  NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
  -- Wrap a value as a rank-0 array.
  NEUnit :: NExpr env t -> NExpr env (TArr Z t)
  -- Add one dimension to an array (rank @n@ to @S n@). NOTE(review):
  -- presumably the index argument is the length of the new dimension; confirm.
  NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t)
  -- Maximum / minimum along one dimension of a numeric array.
  NEMaximum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
  NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))

  -- expression operations
  -- Scalar literal.
  NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t)
  -- The element of a rank-0 array.
  NEIdx0 :: NExpr env (TArr Z t) -> NExpr env t
  -- Index one dimension of an array, lowering its rank by one.
  NEIdx1 :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t)
  -- Index an array with a full n-tuple of indices.
  NEIdx :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t
  -- Shape of an array, as an n-tuple of indices.
  NEShape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx))
  -- Primitive operation applied to its (possibly tupled) argument.
  NEOp :: SOp a t -> NExpr env a -> NExpr env t

  -- custom derivatives
  -- The three bodies are closed: each environment contains only that body's
  -- own two variables. The final two arguments are the operands, of types
  -- @a@ and @b@.
  NECustom :: Var n1 a -> Var n2 b -> NExpr ['(n2, b), '(n1, a)] t  -- ^ regular operation
           -> Var nf1 (D1 a) -> Var nf2 (D1 b) -> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape)  -- ^ CHAD forward pass
           -> Var nr1 tape -> Var nr2 (D2 t) -> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b)  -- ^ CHAD reverse derivative
           -> NExpr env a -> NExpr env b
           -> NExpr env t

  -- partiality
  -- Error term carrying a message, at the given type.
  NEError :: STy a -> String -> NExpr env a

  -- embedded unnamed expressions
  -- A De Bruijn expression whose free variables (@unenv@) are supplied, in
  -- order, by a list of named expressions.
  NEUnnamed :: Ex unenv t -> SList (NExpr env) unenv -> NExpr env t
deriving instance Show (NExpr env t)

-- | Type of the variable @name@ in a named environment.
--
-- The equations are tried in order. The first (innermost) binding of a name
-- therefore wins, so inner bindings shadow outer ones. The reserved name @"_"@
-- and names not present in the environment are reported as custom type errors.
type family Lookup name env where
  Lookup "_" _ = TypeError (Text "Attempt to use variable with name '_'")
  Lookup name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope")
  Lookup name ('(name, t) : env) = t
  Lookup name (_ : env) = Lookup name env

-- | Remove the element at Peano position @i@ from a type-level list. @Z@
-- removes the head, i.e. the innermost entry of an environment. There is no
-- equation for the empty list, so an out-of-range index leaves the family
-- unreduced.
type family DropNth i env where
  DropNth Z (_ : env) = env
  DropNth (S i) (p : env) = p : DropNth i env

-- | A named, typed variable: the singleton of its name together with the
-- singleton of its type. It can be written with overloaded-label syntax
-- (@#x@) through an 'IsLabel' instance.
data Var name t = Var (SSymbol name) (STy t)
  deriving (Show)

-- | Ordinary arithmetic syntax for numeric scalar expressions.
--
-- '+' and '*' build binary 'NEOp' nodes whose operands are paired with
-- 'NEPair'. 'negate' builds a unary 'NEOp'. Integer literals become 'NEConst'
-- nodes. '-' comes from the class default (@x + negate y@). 'abs' and 'signum'
-- are not supported and fail with 'error' when evaluated.
instance (t ~ TScal st, ScalIsNumeric st ~ True, KnownScalTy st, Num (ScalRep st)) => Num (NExpr env t) where
  (+) lhs rhs = NEOp (OAdd knownScalTy) (NEPair lhs rhs)
  (*) lhs rhs = NEOp (OMul knownScalTy) (NEPair lhs rhs)
  negate = NEOp (ONeg knownScalTy)
  abs = error "abs undefined for NExpr"
  signum = error "signum undefined for NExpr"
  fromInteger n =
    -- NEConst needs a Show dictionary for the scalar representation.
    case scalRepIsShow sty of
      Dict -> NEConst sty (fromInteger n)
    where
      sty = knownScalTy

-- | Division syntax for floating-point scalar expressions.
--
-- 'recip' builds an 'NEOp' node, and rational literals become 'NEConst' nodes.
-- '/' comes from the class default (@x * recip y@).
instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Fractional (ScalRep st))
      => Fractional (NExpr env t) where
  recip = NEOp (ORecip knownScalTy)
  fromRational r =
    -- NEConst needs a Show dictionary for the scalar representation.
    case scalRepIsShow sty of
      Dict -> NEConst sty (fromRational r)
    where
      sty = knownScalTy

-- | Transcendental functions on floating-point scalar expressions.
--
-- 'pi' becomes an 'NEConst' literal. 'exp' and 'log' become the primitive
-- operations 'OExp' and 'OLog'. The class defaults for 'sqrt', '(**)' and
-- 'logBase' are written in terms of 'exp' and 'log', so they are available too.
-- The trigonometric and hyperbolic functions have no primitive counterpart and
-- fail with a named 'error' when evaluated.
instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Floating (ScalRep st))
      => Floating (NExpr env t) where
  pi =
    let ty = knownScalTy
    in case scalRepIsShow ty of
         Dict -> NEConst ty pi
  exp = NEOp (OExp knownScalTy)
  -- Must be the logarithm primitive. The defaults for sqrt, (**) and logBase
  -- are built on top of it.
  log = NEOp (OLog knownScalTy)
  sin = error "sin undefined for NExpr"
  cos = error "cos undefined for NExpr"
  tan = error "tan undefined for NExpr"
  asin = error "asin undefined for NExpr"
  acos = error "acos undefined for NExpr"
  atan = error "atan undefined for NExpr"
  sinh = error "sinh undefined for NExpr"
  cosh = error "cosh undefined for NExpr"
  asinh = error "asinh undefined for NExpr"
  acosh = error "acosh undefined for NExpr"
  atanh = error "atanh undefined for NExpr"

-- | Overloaded-label syntax for variables: @#x@ builds a 'Var' named @"x"@.
-- The type singleton comes from 'KnownTy'. The head mentions a separate @n'@
-- with @name ~ n'@ in the context, so the instance matches any 'Var' and then
-- forces its name to equal the label.
instance (KnownTy t, KnownSymbol name, name ~ n') => IsLabel name (Var n' t) where
  fromLabel = Var nameSing tySing
    where
      nameSing = symbolSing
      tySing = knownTy

-- | Overloaded-label syntax for variable references. @#x@ used as an
-- expression is 'NEVar' of the label @#x@ read as a 'Var'. It only
-- typechecks when 'Lookup' finds @x@ in @env@ with type @t@.
instance (KnownTy t, KnownSymbol name, Lookup name env ~ t) => IsLabel name (NExpr env t) where
  fromLabel = NEVar labelVar
    where
      labelVar = fromLabel @name

-- | Value-level mirror of a named environment, built by pushing variables one
-- at a time. The innermost (most recently pushed) variable is on the outside,
-- on the right. For example, @NTop `NPush` x `NPush` y@ has @y@'s entry at the
-- head of the type-level list.
data NEnv env where
  NTop :: NEnv '[]
  NPush :: NEnv env -> Var name t -> NEnv ('(name, t) : env)

-- | A named function: a chain of lambdas ('NLam') ending in a body ('NBody').
-- The first (outermost) parameter is on the outside, on the left.
--
-- * @env@: environment in scope at this point of the function. Each 'NLam'
--   conses its parameter onto it, so it grows as you go deeper inside lambdas.
-- * @env'@: environment of the body of the function ('NBody' requires @env@
--   and @env'@ to coincide).
-- * @t@: result type of the body.
--
-- The parameters are the entries of @env'@ that are not in @env@. Because each
-- lambda conses onto the environment, the last (innermost) parameter ends up
-- at the head of @env'@.
data NFun env env' t where
  NLam :: Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t
  NBody :: NExpr env' t -> NFun env' env' t

-- | Erase the names from a named environment, keeping the types in the same
-- order.
type family UnName env where
  UnName '[] = '[]
  UnName ('(name, t) : env) = t : UnName env

-- | The list of type singletons of a named environment, innermost first.
envFromNEnv :: NEnv env -> SList STy (UnName env)
envFromNEnv = \case
  NTop -> SNil
  NPush rest (Var _ ty) -> SCons ty (envFromNEnv rest)

-- | Apply a closed named function to arguments by embedding its De Bruijn form
-- as an 'NEUnnamed' node. The arguments are given in the order of
-- @UnName envB@, i.e. innermost (last) parameter first.
inlineNFun :: NFun '[] envB t -> SList (NExpr env) (UnName envB) -> NExpr env t
inlineNFun fun = NEUnnamed (fromNamed fun)

-- | Convert a closed named function into a De Bruijn expression. The
-- expression's environment consists of the function's parameters with their
-- names erased.
fromNamed :: NFun '[] env t -> Ex (UnName env) t
fromNamed fun = fromNamedFun NTop fun

-- | Convert a partially-entered named function. The parameters already in the
-- environment have been consumed. Each remaining lambda pushes its parameter
-- onto the environment, and the body is converted in the result.
--
-- >   [] `fromNamedFun` λx y z. E
-- > = []:x `fromNamedFun` λy z. E
-- > = []:x:y `fromNamedFun` λz. E
-- > = []:x:y:z `fromNamedFun` λ. E
-- > = []:x:y:z `fromNamedExpr` E
fromNamedFun :: NEnv env -> NFun env env' t -> Ex (UnName env') t
fromNamedFun env = \case
  NLam param rest -> fromNamedFun (NPush env param) rest
  NBody body -> fromNamedExpr env body

-- | Convert a named expression into the De Bruijn representation 'Ex'.
--
-- @val@ is a value-level copy of the named environment @env@. It is used to
-- turn variable names into De Bruijn indices (the position in @val@, with the
-- innermost entry being 'IZ'). It is also used to build the weakenings needed
-- by 'NEDrop' and 'NEUnnamed'. The result lives in @env@ with the names erased.
fromNamedExpr :: forall env t. NEnv env -> NExpr env t -> Ex (UnName env) t
fromNamedExpr val = \case
  -- Resolve the variable by searching 'val' for a matching name and type.
  -- NOTE(review): the 'Lookup' constraint on 'NEVar' presumably makes the
  -- error branch unreachable; this is not checked here.
  NEVar var@(Var _ ty)
    | Just idx <- find var val -> EVar ext ty idx
    | otherwise -> error "Variable out of scope in conversion from surface \
                         \expression to De Bruijn expression"
  -- Binders convert their body with the bound variable pushed onto 'val'.
  NELet n a b -> ELet ext (go a) (lambda val n b)

  -- Convert the body in the environment with entry @i@ removed, then weaken
  -- the result back into the full environment.
  NEDrop i e -> weakenExpr (dropNthW i val) (fromNamedExpr (dropNth i val) e)

  -- The following constructors map directly onto their 'Ex' counterparts.
  -- Subexpressions are converted with 'go', or with 'lambda' / 'lambda2' under
  -- binders.
  NEPair a b -> EPair ext (go a) (go b)
  NEFst e -> EFst ext (go e)
  NESnd e -> ESnd ext (go e)
  NENil -> ENil ext
  NEInl t e -> EInl ext t (go e)
  NEInr t e -> EInr ext t (go e)
  NECase e n1 a n2 b -> ECase ext (go e) (lambda val n1 a) (lambda val n2 b)

  NEConstArr n t x -> EConstArr ext n t x
  NEBuild k a n b -> EBuild ext k (go a) (lambda val n b)
  NEFold1Inner n1 n2 a b c -> EFold1Inner ext (lambda2 val n1 n2 a) (go b) (go c)
  NESum1Inner e -> ESum1Inner ext (go e)
  NEUnit e -> EUnit ext (go e)
  NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b)
  NEMaximum1Inner e -> EMaximum1Inner ext (go e)
  NEMinimum1Inner e -> EMinimum1Inner ext (go e)

  NEConst t x -> EConst ext t x
  NEIdx0 e -> EIdx0 ext (go e)
  NEIdx1 a b -> EIdx1 ext (go a) (go b)
  NEIdx a b -> EIdx ext (go a) (go b)
  NEShape e -> EShape ext (go e)
  NEOp op e -> EOp ext op (go e)

  -- The three custom-derivative bodies are closed. Each is therefore
  -- converted in a fresh environment holding only its own two variables,
  -- not in 'val'.
  NECustom n1@(Var _ ta) n2@(Var _ tb) a nf1 nf2 b nr1@(Var _ ttape) nr2 c e1 e2 ->
    ECustom ext ta tb ttape
            (fromNamedExpr (NTop `NPush` n1 `NPush` n2) a)
            (fromNamedExpr (NTop `NPush` nf1 `NPush` nf2) b)
            (fromNamedExpr (NTop `NPush` nr1 `NPush` nr2) c)
            (go e1) (go e2)

  NEError t s -> EError ext t s

  -- First weaken the embedded expression from @unenv@ to @unenv@ followed by
  -- the (name-erased) surrounding environment, then let-bind each argument.
  NEUnnamed e args -> injectWrapLet (weakenExpr (wRaiseAbove args (envFromNEnv val)) e) args
  where
    -- Convert a subexpression in the unchanged environment.
    go :: NExpr env t' -> Ex (UnName env) t'
    go = fromNamedExpr val

    -- Index of the first (innermost) entry whose name and type both equal
    -- those of the variable, or 'Nothing' if there is no such entry.
    find :: Var name t' -> NEnv env' -> Maybe (Idx (UnName env') t')
    find _ NTop = Nothing
    find var@(Var s ty) (val' `NPush` Var s' ty')
      | Just Refl <- testEquality s s'
      , Just Refl <- testEquality ty ty'
      = Just IZ
      | otherwise
      = IS <$> find var val'

    -- Convert a body under one extra binder.
    lambda :: NEnv env' -> Var name a -> NExpr ('(name, a) : env') b -> Ex (a : UnName env') b
    lambda val' var e = fromNamedExpr (val' `NPush` var) e

    -- Convert a body under two extra binders; the second one is innermost.
    lambda2 :: NEnv env' -> Var name1 a -> Var name2 b -> NExpr ('(name2, b) : '(name1, a) : env') c -> Ex (b : a : UnName env') c
    lambda2 val' var1 var2 e = fromNamedExpr (val' `NPush` var1 `NPush` var2) e

    -- Wrap @e@ in one 'ELet' per argument. The environment of @e@ is @unenv@
    -- followed by the surrounding environment. Each argument is converted in
    -- 'val' and then weakened with 'wSinks' past the binders of the arguments
    -- after it in the list, because those are bound further out. As a result,
    -- the head of the list becomes the innermost 'ELet'.
    injectWrapLet :: Ex (Append unenv (UnName env)) t -> SList (NExpr env) unenv -> Ex (UnName env) t
    injectWrapLet e SNil = e
    injectWrapLet e (arg `SCons` args) =
      injectWrapLet (ELet ext (weakenExpr (wSinks args) $ fromNamedExpr val arg) e)
                    args

-- | Remove the entry at position @i@ ('SZ' = innermost) from a value-level
-- environment, mirroring 'DropNth'. Fails with an error if @i@ is out of range.
dropNth :: SNat i -> NEnv env -> NEnv (DropNth i env)
dropNth idx env = case idx of
  SZ | NPush outer _ <- env -> outer
  SS idx' | NPush outer entry <- env -> NPush (dropNth idx' outer) entry
  _ -> error "DropNth: index out of range"

-- | The weakening that re-inserts the entry removed by 'dropNth'. Entries
-- inside position @i@ are copied and the one at @i@ is sunk. Fails with an
-- error if @i@ is out of range.
dropNthW :: SNat i -> NEnv env -> UnName (DropNth i env) :> UnName env
dropNthW idx env = case idx of
  SZ | NPush _ _ <- env -> WSink
  SS idx' | NPush outer _ <- env -> WCopy (dropNthW idx' outer)
  _ -> error "DropNth: index out of range"