aboutsummaryrefslogtreecommitdiff
path: root/Language.hs
blob: 8ab6199a8389cb83a1d262f9c28ec9669ab390c1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}

{-| This module is intended to be imported qualified, perhaps as @L@. -}
module Language where

import AST
import Sink


-- | Array shapes are left-nested pairs of 'Int' extents terminated by @()@;
-- each @DIMn@ below is @(DIM(n-1), Int)@ written out in full.
-- Convention: matrices are represented in row-major: (((), y), x)
type DIM0 = ()
type DIM1 = ((), Int)
type DIM2 = (((), Int), Int)
type DIM3 = ((((), Int), Int), Int)

-- | Types whose 'Type' witness can be produced from the Haskell type alone,
-- so combinators can write 'infer' instead of spelling the witness out.
class InferType a where
    infer :: Type a

instance InferType Int where
    infer = TInt

instance InferType Bool where
    infer = TBool

instance InferType Double where
    infer = TDouble

instance (InferShapeType sh, InferType a) => InferType (Array sh a) where
    infer = TArray inferST infer

instance InferType () where
    infer = TNil

instance (InferType a, InferType b) => InferType (a, b) where
    infer = TPair infer infer

instance (InferType a, InferType b) => InferType (a -> b) where
    infer = TFun infer infer

-- | Shape types whose 'ShapeType' witness follows from the type: @()@ maps
-- to 'STZ' and every additional @(sh, Int)@ layer wraps it in 'STC'.
class InferShapeType sh where
    inferST :: ShapeType sh

instance InferShapeType () where
    inferST = STZ

instance InferShapeType sh => InferShapeType (sh, Int) where
    inferST = STC inferST

-- | Reference a variable by de Bruijn index, with its type witness
-- supplied by 'infer'.
var :: InferType a => Idx env a -> Exp env a
var ix = Var infer ix

-- | Apply a function to every element of an array, keeping its shape.
--
-- The array is let-bound once. Inside the element lambda the index is
-- @Zero@ and the bound array is @Succ Zero@.
map :: Exp env (a -> b) -> Exp env (Array sh a) -> Exp env (Array sh b)
map f e =
    Let e $
      Build shTy (Shape (Var arrTy Zero)) $
        Lam idxTy $
          -- 'f' now sits under two new binders (the Let and the Lam), hence sinkExp2.
          App (sinkExp2 f) (Index (Var arrTy (Succ Zero)) (Var idxTy Zero))
  where
    arrTy@(TArray shTy _) = typeof e
    idxTy = shapeTypeType shTy

-- | Sum the elements of a one-dimensional array of 'Double's, starting from 0.
sum :: Exp env (Array DIM1 Double) -> Exp env Double
sum e =
    -- The array is let-bound, so inside the fold step the environment is
    -- ((accumulator, index), (array, env)).
    Let e $
      Ifold inferST
            (Lam infer $
               App (Const CAddF) $
                 Pair (Fst (var Zero))
                      (Index (var (Succ Zero)) (Snd (var Zero))))
            (Lit (LDouble 0))
            (Shape (var Zero))

-- | Pair up corresponding elements of two arrays.
--
-- The two input arrays are assumed to be the same size; the result takes
-- its shape from the first array.
zip :: Exp env (Array sh a) -> Exp env (Array sh b) -> Exp env (Array sh (a, b))
zip a b =
    let tarrA@(TArray sht _) = typeof a
        tarrB = typeof b
        idxt = shapeTypeType sht
    in -- Bind both inputs once, outside the element lambda. Previously 'b' was
       -- sunk directly into the lambda body, so it was re-evaluated for every
       -- element of the result.
       Let a $
         Let (sinkExp1 b) $
           -- Environment here: (b, (a, env)).
           Build sht
                 (Shape (Var tarrA (Succ Zero)))
                 (Lam idxt
                    -- Environment here: (index, (b, (a, env))).
                    (Pair (Index (Var tarrA (Succ (Succ Zero))) (Var idxt Zero))
                          (Index (Var tarrB (Succ Zero)) (Var idxt Zero))))

-- | An array of shape @sh@ that is 1 at index @idx@ and 0 everywhere else.
oneHot :: ShapeType sh -> Exp env sh -> Exp env sh -> Exp env (Array sh Double)
oneHot sht sh idx =
    let idxt = shapeTypeType sht
    in -- Bind 'idx' once rather than sinking it into the element lambda, where it
       -- would be recomputed for every element. 'sh' moves under the new binder.
       Let idx $
         Build sht (sinkExp1 sh) $
           Lam idxt $
             -- Environment here: (current index, (idx, env)).
             Cond (App (Const (CEq idxt)) (Pair (Var idxt Zero) (Var idxt (Succ Zero))))
                  (Lit (LDouble 1))
                  (Lit (LDouble 0))

-- | Transpose a row-major two-dimensional array.
--
-- For an input of shape @(((), rows), cols)@ the result has shape
-- @(((), cols), rows)@, and result element @(((), i), j)@ is input element
-- @(((), j), i)@. The previous version reused the input shape and index
-- unchanged, which made it the identity function.
transpose :: Type a -> Exp env (Array DIM2 a -> Array DIM2 a)
transpose ty =
    Lam arrTy
        (Build inferST
               -- Result shape: the input's two extents exchanged.
               (swapIdx (Shape (Var arrTy Zero)))
               -- Element function; var Zero is the result index, Succ Zero the input array.
               (Lam infer (Index (Var arrTy (Succ Zero)) (swapIdx (var Zero)))))
  where
    arrTy = TArray inferST ty

    -- Exchange the two components of a DIM2 value. The argument expression is
    -- duplicated, so only pass cheap expressions such as variables or shapes.
    swapIdx :: Exp env' DIM2 -> Exp env' DIM2
    swapIdx ix = Pair (Pair (Lit LNil) (Snd ix)) (Snd (Fst ix))

-- | @eye n@ is the @n x n@ identity matrix of 'Double's.
eye :: Exp env (Int -> Array DIM2 Double)
eye =
    Lam infer $
      -- var Zero is the size n; the result shape is (((), n), n).
      Build inferST (Pair (Pair (Lit LNil) (var Zero)) (var Zero)) $
        Lam infer $
          -- var Zero is now the index (((), row), col); 1 where the components match.
          Cond (App (Const (CEq infer)) (Pair (Snd (var Zero)) (Snd (Fst (var Zero)))))
               (Lit (LDouble 1))
               (Lit (LDouble 0))

-- | Number of elements of a one-dimensional array, read from the single
-- extent of its shape.
length :: Type a -> Exp env (Array DIM1 a -> Int)
length ty =
    Lam arrTy (Snd (Shape (Var arrTy Zero)))
  where
    arrTy = TArray inferST ty

-- | Vector-matrix product. Takes a vector @v@ and a row-major matrix @m@.
-- The result has one element per column of @m@, and element @j@ accumulates
-- @v[i] * m[i][j]@ over the indices @i@ of @v@.
-- (Assumes CAddF / CMulF are Double addition / multiplication.)
--
-- NOTE(review): the fold iterates over the shape of @v@ and indexes row @i@
-- of @m@; nothing here checks that @v@'s length equals @m@'s row count.
vmmul :: Exp env (Array DIM1 Double -> Array DIM2 Double -> Array DIM1 Double)
vmmul =
    -- The outer lambda binds v (the vector), the inner one m (the matrix).
    Lam infer $ Lam infer $
      Build inferST
        -- Result shape ((), cols): var Zero is m, and Snd of its shape is the column count.
        (Pair (Lit LNil) (Snd (Shape (var Zero))))
        -- Element function; var Zero is the output index ((), j).
        (Lam infer $
           Ifold inferST
             -- Fold step; var Zero is (acc, ((), i)). Indices from here:
             -- Zero = (acc, ((), i)), Succ Zero = ((), j), Succ^2 Zero = m, Succ^3 Zero = v.
             -- Adds v[i] * m at index (((), i), j), i.e. row i, column j.
             (Lam infer $
                App (Const CAddF) (Pair
                  (Fst (var Zero))
                  (App (Const CMulF) (Pair
                     (Index (var (Succ (Succ (Succ Zero)))) (Snd (var Zero)))
                     (Index (var (Succ (Succ Zero))) (Pair (Pair (Lit LNil) (Snd (Snd (var Zero)))) (Snd (var (Succ Zero)))))))))
             (Lit (LDouble 0))
             -- Outside the step lambda, var (Succ (Succ Zero)) is v: iterate over v's indices.
             (Shape (var (Succ (Succ Zero)))))