summaryrefslogtreecommitdiff
path: root/src/Language.hs
blob: c2b844e656355b918ebd4c753e4eb6e16bce34f4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE TypeOperators #-}
module Language (
  fromNamed,
  NExpr,
  module Language,
  Lookup,
) where

import Array
import AST
import Data
import Language.AST


-- | Attaches a binder to the term it scopes over, written @v :-> e@. It is
-- used for the branch and builder arguments of 'case_', 'build1', 'build'
-- and 'fold1i'. It is right-associative at precedence 0, so
-- @v1 :-> v2 :-> e@ groups as @v1 :-> (v2 :-> e)@.
data (:->) a b = a :-> b
  deriving (Show)

infixr 0 :->


-- | Ends a chain of 'lambda' binders with the function's result expression.
-- Both environment indices of the resulting 'NFun' are the body's @env@.
body :: NExpr env t -> NFun env env t
body e = NBody e

-- | Adds one parameter in front of a function. The variable is pushed onto
-- the outer environment of the inner function; the second environment index
-- @env'@ is passed through unchanged.
lambda :: forall a name env env' t. Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t
lambda param rest = NLam param rest


-- | @let_ v rhs e@ binds @v@ to @rhs@ in @e@. The environment of @e@ is
-- extended with @v@.
let_ :: Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t
let_ var rhs scope = NELet var rhs scope

-- | Builds a 'TPair' from two expressions.
pair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b)
pair l r = NEPair l r

-- | Projects the first component of a 'TPair'.
fst_ :: NExpr env (TPair a b) -> NExpr env a
fst_ p = NEFst p

-- | Projects the second component of a 'TPair'.
snd_ :: NExpr env (TPair a b) -> NExpr env b
snd_ p = NESnd p

-- | The expression of type 'TNil'.
nil :: NExpr env TNil
nil = NENil

-- | Injects into the left side of a 'TEither'. The type of the right side
-- is supplied explicitly as an 'STy'.
inl :: STy b -> NExpr env a -> NExpr env (TEither a b)
inl otherTy e = NEInl otherTy e

-- | Injects into the right side of a 'TEither'. The type of the left side
-- is supplied explicitly as an 'STy'.
inr :: STy a -> NExpr env b -> NExpr env (TEither a b)
inr otherTy e = NEInr otherTy e

-- | Eliminates a 'TEither'. Each branch is written @v :-> e@, where @v@ is
-- bound in @e@ to the payload of the corresponding side. Both branches must
-- produce the same type @c@.
case_ :: NExpr env (TEither a b) -> (Var name1 a :-> NExpr ('(name1, a) : env) c) -> (Var name2 b :-> NExpr ('(name2, b) : env) c) -> NExpr env c
case_ scrut (lVar :-> lBody) (rVar :-> rBody) =
  NECase scrut lVar lBody rVar rBody

-- | Embeds a constant array. The rank comes from 'KnownNat' and the scalar
-- type from 'KnownScalTy'. Matching on the 'Dict' returned by
-- 'scalRepIsShow' brings its constraint into scope for 'NEConstArr'; judging
-- by the name, this is presumably @Show@ on the scalar representation.
constArr_ :: (KnownNat n, KnownScalTy t) => Array n (ScalRep t) -> NExpr env (TArr n (TScal t))
constArr_ arr =
  case scalRepIsShow sty of
    Dict -> NEConstArr knownNat sty arr
  where
    -- One shared singleton, so both uses agree on the scalar type.
    sty = knownScalTy

-- | @build1 len (i :-> e)@ builds a rank-1 array. The element expression @e@
-- sees the index variable @i@ of type 'TIx'. The first argument presumably
-- gives the length.
build1 :: NExpr env TIx -> (Var name TIx :-> NExpr ('(name, TIx) : env) t) -> NExpr env (TArr (S Z) t)
build1 len (ixVar :-> elemExpr) = NEBuild1 len ixVar elemExpr

-- | Rank-@n@ version of 'build1'. The extent argument and the bound index
-- variable are both tuples of @n@ 'TIx' values. The first argument
-- presumably gives the shape.
build :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> (Var name (Tup (Replicate n TIx)) :-> NExpr ('(name, Tup (Replicate n TIx)) : env) t) -> NExpr env (TArr n t)
build rank extent (ixVar :-> elemExpr) = NEBuild rank extent ixVar elemExpr

-- | Reduces a @TArr (S n) t@ to a @TArr n t@ (along the inner dimension, per
-- the constructor name 'NEFold1Inner'). The combining expression binds two
-- variables of type @t@; the second one is the innermost binding. The
-- standalone @t@ argument is presumably the starting value.
fold1i :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
fold1i (xVar :-> yVar :-> combine) start arr =
  NEFold1Inner xVar yVar combine start arr

-- | Removes one dimension (the inner one, per 'NESum1Inner') of an array of
-- numeric scalars, presumably by summing along it.
sum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
sum1i = NESum1Inner

-- | Wraps a value as a rank-0 array.
unit :: NExpr env t -> NExpr env (TArr Z t)
unit e = NEUnit e

-- | Adds one dimension (the inner one, per 'NEReplicate1Inner') to an array
-- of numeric scalars. The 'TIx' argument presumably gives the new
-- dimension's length.
replicate1i :: ScalIsNumeric t ~ True => NExpr env TIx -> NExpr env (TArr n (TScal t)) -> NExpr env (TArr (S n) (TScal t))
replicate1i = NEReplicate1Inner

-- | Embeds a scalar constant whose type comes from 'KnownScalTy'. Matching
-- on the 'Dict' returned by 'scalRepIsShow' brings its constraint into scope
-- for 'NEConst'; judging by the name, this is presumably @Show@ on the
-- scalar representation.
const_ :: KnownScalTy t => ScalRep t -> NExpr env (TScal t)
const_ val =
  case scalRepIsShow sty of
    Dict -> NEConst sty val
  where
    -- One shared singleton, so both uses agree on the scalar type.
    sty = knownScalTy

-- | Reads the element of a rank-0 array.
idx0 :: NExpr env (TArr Z t) -> NExpr env t
idx0 arr = NEIdx0 arr

-- | Indexes a single dimension with a 'TIx', lowering the rank by one.
(.!) :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t)
arr .! i = NEIdx1 arr i
infixl 9 .!

-- | Full indexing: selects one element of a rank-@n@ array using a tuple of
-- @n@ 'TIx' values. The rank singleton comes from 'KnownNat'.
(!) :: KnownNat n => NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t
(!) arr ix = NEIdx knownNat arr ix
infixl 9 !

-- | The shape of a rank-@n@ array, as a tuple of @n@ 'TIx' values.
shape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx))
shape arr = NEShape arr

-- | Applies a primitive operation to its (single, possibly tupled) argument.
oper :: SOp a t -> NExpr env a -> NExpr env t
oper op arg = NEOp op arg

-- | An error expression carrying the given message. Its result type comes
-- from 'KnownTy'.
error_ :: KnownTy t => String -> NExpr env t
error_ = NEError knownTy

-- | Equality comparison of two numeric scalars, via 'OEq' on their pair.
(.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
(.==) lhs rhs = oper (OEq knownScalTy) $ pair lhs rhs
infix 4 .==

-- | Less-than comparison of two numeric scalars, via 'OLt' on their pair.
(.<) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
(.<) lhs rhs = oper (OLt knownScalTy) $ pair lhs rhs
infix 4 .<

-- | Greater-than, expressed as '.<' with the operands swapped.
(.>) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
lhs .> rhs = rhs .< lhs
infix 4 .>

-- | Less-than-or-equal comparison of two numeric scalars, via 'OLe' on
-- their pair.
(.<=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
(.<=) lhs rhs = oper (OLe knownScalTy) $ pair lhs rhs
infix 4 .<=

-- | Greater-than-or-equal, expressed as '.<=' with the operands swapped.
(.>=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
lhs .>= rhs = rhs .<= lhs
infix 4 .>=

-- | Boolean negation, via the 'ONot' primitive.
not_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool)
not_ b = oper ONot b

-- | Conditional on a boolean scalar. 'OIf' turns the condition into a
-- @TEither TNil TNil@, which 'case_' then eliminates. The first branch
-- handles the left alternative and the second the right one (presumably
-- true and false respectively; confirm against 'OIf').
--
-- Each branch sees an extra variable named @_@ of type 'TNil'. That
-- variable is unusable and should be ignored. A weakening function on
-- 'NExpr's would allow hiding it.
if_ :: NExpr env (TScal TBool) -> NExpr ('("_", TNil) : env) t -> NExpr ('("_", TNil) : env) t -> NExpr env t
if_ cond whenL whenR =
  case_ (oper OIf cond) (#_ :-> whenL) (#_ :-> whenR)