diff options
author | Tom Smeding <tom@tomsmeding.com> | 2021-06-27 18:34:35 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2021-06-27 18:34:35 +0200 |
commit | d4abcc3b2dfefbbcb7cd4a182eec64f1da42d951 (patch) | |
tree | 1ab301617043ac6df228ef617afa22633a01a671 /Language.hs | |
parent | 0fefe4822c65bde81ec4c0da1b5b32a9b411ca79 (diff) |
Diffstat (limited to 'Language.hs')
-rw-r--r-- | Language.hs | 78 |
1 files changed, 68 insertions, 10 deletions
diff --git a/Language.hs b/Language.hs index e16cf7c..8ab6199 100644 --- a/Language.hs +++ b/Language.hs @@ -1,4 +1,6 @@ +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE TypeOperators #-} {-| This module is intended to be imported qualified, perhaps as @L@. -} module Language where @@ -7,6 +9,28 @@ import AST import Sink +-- Convention: matrices are represented in row-major: (((), y), x) +type DIM0 = () +type DIM1 = (DIM0, Int) +type DIM2 = (DIM1, Int) +type DIM3 = (DIM2, Int) + +class InferType a where infer :: Type a +instance InferType Int where infer = TInt +instance InferType Bool where infer = TBool +instance InferType Double where infer = TDouble +instance (InferType a, InferShapeType sh) => InferType (Array sh a) where infer = TArray inferST infer +instance InferType () where infer = TNil +instance (InferType a, InferType b) => InferType (a, b) where infer = TPair infer infer +instance (InferType a, InferType b) => InferType (a -> b) where infer = TFun infer infer + +class InferShapeType sh where inferST :: ShapeType sh +instance InferShapeType () where inferST = STZ +instance InferShapeType sh => InferShapeType (sh, Int) where inferST = STC inferST + +var :: InferType a => Idx env a -> Exp env a +var = Var infer + map :: Exp env (a -> b) -> Exp env (Array sh a) -> Exp env (Array sh b) map f e = let ty@(TArray sht _) = typeof e @@ -18,17 +42,16 @@ map f e = (Index (Var ty (Succ Zero)) (Var sht' Zero))))) -sum :: Exp env (Array (Int, ()) Double) -> Exp env Double +sum :: Exp env (Array DIM1 Double) -> Exp env Double sum e = - let ty@(TArray sht _) = typeof e - in Let e - (Ifold sht - (Lam (TPair TDouble (TPair TInt TNil)) - (App (Const CAddF) (Pair - (Fst (Var (TPair TDouble (TPair TInt TNil)) Zero)) - (Index (Var ty (Succ Zero)) (Snd (Var (TPair TDouble (TPair TInt TNil)) Zero)))))) - (Lit (LDouble 0)) - (Shape (Var ty Zero))) + Let e + (Ifold inferST + (Lam (TPair TDouble (TPair TNil TInt)) + (App (Const CAddF) (Pair + (Fst (var Zero)) + (Index (var (Succ Zero)) (Snd (var Zero)))))) + (Lit (LDouble 0)) + (Shape (var Zero))) -- | The two input arrays are assumed to be the same size. zip :: Exp env (Array sh a) -> Exp env (Array sh b) -> Exp env (Array sh (a, b)) @@ -50,3 +73,38 @@ oneHot sht sh idx = (Cond (App (Const (CEq idxt)) (Pair (Var idxt Zero) (sinkExp1 idx))) (Lit (LDouble 1)) (Lit (LDouble 0)))) + +transpose :: Type a -> Exp env (Array DIM2 a -> Array DIM2 a) +transpose ty = + Lam (TArray inferST ty) + (Build inferST (Shape (Var (TArray inferST ty) Zero)) + (Lam infer (Index (Var (TArray inferST ty) (Succ Zero)) (Var infer Zero)))) + +eye :: Exp env (Int -> Array DIM2 Double) +eye = + Lam infer + (Build inferST (Pair (Pair (Lit LNil) (var Zero)) (var Zero)) + (Lam infer + (Cond (App (Const (CEq infer)) (Pair (Snd (var Zero)) (Snd (Fst (var Zero))))) + (Lit (LDouble 1)) + (Lit (LDouble 0))))) + +length :: Type a -> Exp env (Array DIM1 a -> Int) +length ty = Lam (TArray inferST ty) + (Snd (Shape (Var (TArray inferST ty) Zero))) + +vmmul :: Exp env (Array DIM1 Double -> Array DIM2 Double -> Array DIM1 Double) +vmmul = + Lam infer $ Lam infer $ + Build inferST + (Pair (Lit LNil) (Snd (Shape (var Zero)))) + (Lam infer $ + Ifold inferST + (Lam infer $ + App (Const CAddF) (Pair + (Fst (var Zero)) + (App (Const CMulF) (Pair + (Index (var (Succ (Succ (Succ Zero)))) (Snd (var Zero))) + (Index (var (Succ (Succ Zero))) (Pair (Pair (Lit LNil) (Snd (Snd (var Zero)))) (Snd (var (Succ Zero))))))))) + (Lit (LDouble 0)) + (Shape (var (Succ (Succ Zero))))) |