blob: e16cf7c433baf5857bc19a92a0ce5fab865e8160 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
|
{-# LANGUAGE GADTs #-}
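{-| Combinators for building array expressions.

Several names defined here ('map', 'sum', 'zip') clash with the Prelude, so
this module is intended to be imported qualified, perhaps as @L@. -}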
module Language where
import AST
import Sink
-- | Apply an embedded function to every element of an array expression.
--
-- The array @e@ is bound once with 'Let'. The result is a 'Build' that takes
-- its shape from the bound array and whose element function indexes that
-- array at the current position and applies @f@ to the value found there.
map :: Exp env (a -> b) -> Exp env (Array sh a) -> Exp env (Array sh b)
map f e =
  -- Matched with @case@ instead of a @let@ pattern binding: GHC rejects
  -- pattern bindings on GADT constructors that bind existential type
  -- variables (presumably the case for 'TArray'), and a @case@ also makes
  -- the match strict instead of deferring a failure.
  case typeof e of
    arrTy@(TArray sht _) ->
      let ixTy = shapeTypeType sht
      in Let e
           (Build sht
              -- Directly under the 'Let' the bound array is variable 0.
              (Shape (Var arrTy Zero))
              (Lam ixTy
                 -- @f@ comes from the outer environment, so it is sunk past
                 -- the two binders ('Let' and 'Lam') introduced here.
                 (App (sinkExp2 f)
                      -- Under the 'Lam' the index is variable 0 and the
                      -- bound array is variable 1.
                      (Index (Var arrTy (Succ Zero))
                             (Var ixTy Zero)))))
-- | Add up all elements of an array of doubles whose shape type is
-- @(Int, ())@.
--
-- The array @e@ is bound once with 'Let' and then reduced with 'Ifold' over
-- the bound array's 'Shape', with the literal @0@ supplied alongside the step
-- function (presumably as the initial total). The step function receives a
-- pair: its first component is added, via 'CAddF', to the array element found
-- at the index given by its second component.
sum :: Exp env (Array (Int, ()) Double) -> Exp env Double
sum e =
  -- Matched with @case@ instead of a @let@ pattern binding: GHC rejects
  -- pattern bindings on GADT constructors that bind existential type
  -- variables (presumably the case for 'TArray'), and a @case@ also makes
  -- the match strict instead of deferring a failure.
  case typeof e of
    arrTy@(TArray sht _) ->
      -- Argument type of the step function: (running value, index).
      let stepTy = TPair TDouble (TPair TInt TNil)
      in Let e
           (Ifold sht
              (Lam stepTy
                 (App (Const CAddF)
                      (Pair
                         (Fst (Var stepTy Zero))
                         -- Under the 'Lam' the bound array is variable 1.
                         (Index (Var arrTy (Succ Zero))
                                (Snd (Var stepTy Zero))))))
              (Lit (LDouble 0))
              -- Directly under the 'Let' the bound array is variable 0.
              (Shape (Var arrTy Zero)))
-- | Pair up the elements of two arrays position by position.
--
-- The two input arrays are assumed to be the same size: the result takes its
-- shape from @a@ alone, and @b@ is indexed at every position of that shape
-- without any check that the position exists in @b@.
zip :: Exp env (Array sh a) -> Exp env (Array sh b) -> Exp env (Array sh (a, b))
zip a b =
  -- Matched with @case@ instead of a @let@ pattern binding: GHC rejects
  -- pattern bindings on GADT constructors that bind existential type
  -- variables (presumably the case for 'TArray'), and a @case@ also makes
  -- the match strict instead of deferring a failure.
  case typeof a of
    aTy@(TArray sht _) ->
      let ixTy = shapeTypeType sht
      in Let a
           (Build sht
              -- Directly under the 'Let' the bound array @a@ is variable 0.
              (Shape (Var aTy Zero))
              (Lam ixTy
                 (Pair
                    -- Under the 'Lam' the index is variable 0 and @a@ is
                    -- variable 1.
                    (Index (Var aTy (Succ Zero)) (Var ixTy Zero))
                    -- @b@ is not let-bound: it is sunk past both binders and
                    -- embedded in the element function as is.
                    -- NOTE(review): depending on the evaluator this may
                    -- re-evaluate @b@ per element; consider let-binding it
                    -- like @a@.
                    (Index (sinkExp2 b) (Var ixTy Zero)))))
-- | Build an array of shape @sh@ holding @1@ at position @idx@ and @0@ at
-- every other position.
--
-- Each element is a 'Cond' on the result of applying the 'CEq' constant to
-- the pair (current index, @idx@).
oneHot :: ShapeType sh -> Exp env sh -> Exp env sh -> Exp env (Array sh Double)
oneHot sht sh idx = Build sht sh (Lam ixTy element)
  where
    -- Type of an index into an array of this shape.
    ixTy = shapeTypeType sht

    -- The current index is variable 0 under the 'Lam'; @idx@ comes from the
    -- outer environment and is sunk past that single binder.
    isTarget = App (Const (CEq ixTy)) (Pair (Var ixTy Zero) (sinkExp1 idx))

    element = Cond isTarget (Lit (LDouble 1)) (Lit (LDouble 0))
|