aboutsummaryrefslogtreecommitdiff
path: root/Language.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2021-06-27 18:34:35 +0200
committerTom Smeding <tom@tomsmeding.com>2021-06-27 18:34:35 +0200
commitd4abcc3b2dfefbbcb7cd4a182eec64f1da42d951 (patch)
tree1ab301617043ac6df228ef617afa22633a01a671 /Language.hs
parent0fefe4822c65bde81ec4c0da1b5b32a9b411ca79 (diff)
Diffstat (limited to 'Language.hs')
-rw-r--r--Language.hs78
1 files changed, 68 insertions, 10 deletions
diff --git a/Language.hs b/Language.hs
index e16cf7c..8ab6199 100644
--- a/Language.hs
+++ b/Language.hs
@@ -1,4 +1,6 @@
+{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE TypeOperators #-}
{-| This module is intended to be imported qualified, perhaps as @L@. -}
module Language where
@@ -7,6 +9,28 @@ import AST
import Sink
+-- Convention: matrices are represented in row-major: (((), y), x)
+type DIM0 = ()
+type DIM1 = (DIM0, Int)
+type DIM2 = (DIM1, Int)
+type DIM3 = (DIM2, Int)
+
+class InferType a where infer :: Type a
+instance InferType Int where infer = TInt
+instance InferType Bool where infer = TBool
+instance InferType Double where infer = TDouble
+instance (InferType a, InferShapeType sh) => InferType (Array sh a) where infer = TArray inferST infer
+instance InferType () where infer = TNil
+instance (InferType a, InferType b) => InferType (a, b) where infer = TPair infer infer
+instance (InferType a, InferType b) => InferType (a -> b) where infer = TFun infer infer
+
+class InferShapeType sh where inferST :: ShapeType sh
+instance InferShapeType () where inferST = STZ
+instance InferShapeType sh => InferShapeType (sh, Int) where inferST = STC inferST
+
+var :: InferType a => Idx env a -> Exp env a
+var = Var infer
+
map :: Exp env (a -> b) -> Exp env (Array sh a) -> Exp env (Array sh b)
map f e =
let ty@(TArray sht _) = typeof e
@@ -18,17 +42,16 @@ map f e =
(Index (Var ty (Succ Zero))
(Var sht' Zero)))))
-sum :: Exp env (Array (Int, ()) Double) -> Exp env Double
+sum :: Exp env (Array DIM1 Double) -> Exp env Double
sum e =
- let ty@(TArray sht _) = typeof e
- in Let e
- (Ifold sht
- (Lam (TPair TDouble (TPair TInt TNil))
- (App (Const CAddF) (Pair
- (Fst (Var (TPair TDouble (TPair TInt TNil)) Zero))
- (Index (Var ty (Succ Zero)) (Snd (Var (TPair TDouble (TPair TInt TNil)) Zero))))))
- (Lit (LDouble 0))
- (Shape (Var ty Zero)))
+ Let e
+ (Ifold inferST
+ (Lam (TPair TDouble (TPair TNil TInt))
+ (App (Const CAddF) (Pair
+ (Fst (var Zero))
+ (Index (var (Succ Zero)) (Snd (var Zero))))))
+ (Lit (LDouble 0))
+ (Shape (var Zero)))
-- | The two input arrays are assumed to be the same size.
zip :: Exp env (Array sh a) -> Exp env (Array sh b) -> Exp env (Array sh (a, b))
@@ -50,3 +73,38 @@ oneHot sht sh idx =
(Cond (App (Const (CEq idxt)) (Pair (Var idxt Zero) (sinkExp1 idx)))
(Lit (LDouble 1))
(Lit (LDouble 0))))
+
+transpose :: Type a -> Exp env (Array DIM2 a -> Array DIM2 a)
+transpose ty =
+ Lam (TArray inferST ty)
+ (Build inferST (Shape (Var (TArray inferST ty) Zero))
+ (Lam infer (Index (Var (TArray inferST ty) (Succ Zero)) (Var infer Zero))))
+
+eye :: Exp env (Int -> Array DIM2 Double)
+eye =
+ Lam infer
+ (Build inferST (Pair (Pair (Lit LNil) (var Zero)) (var Zero))
+ (Lam infer
+ (Cond (App (Const (CEq infer)) (Pair (Snd (var Zero)) (Snd (Fst (var Zero)))))
+ (Lit (LDouble 1))
+ (Lit (LDouble 0)))))
+
+length :: Type a -> Exp env (Array DIM1 a -> Int)
+length ty = Lam (TArray inferST ty)
+ (Snd (Shape (Var (TArray inferST ty) Zero)))
+
+vmmul :: Exp env (Array DIM1 Double -> Array DIM2 Double -> Array DIM1 Double)
+vmmul =
+ Lam infer $ Lam infer $
+ Build inferST
+ (Pair (Lit LNil) (Snd (Shape (var Zero))))
+ (Lam infer $
+ Ifold inferST
+ (Lam infer $
+ App (Const CAddF) (Pair
+ (Fst (var Zero))
+ (App (Const CMulF) (Pair
+ (Index (var (Succ (Succ (Succ Zero)))) (Snd (var Zero)))
+ (Index (var (Succ (Succ Zero))) (Pair (Pair (Lit LNil) (Snd (Snd (var Zero)))) (Snd (var (Succ Zero)))))))))
+ (Lit (LDouble 0))
+ (Shape (var (Succ (Succ Zero)))))