{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}
{-|
Helper combinators that build 'Exp' terms.

This module is intended to be imported qualified, perhaps as @L@: several of
its names (@map@, @sum@, @zip@, @length@) coincide with Prelude functions.
-}
module Language where

import AST
import Sink


-- Shapes (and indices into arrays of that shape) are nested pairs built up
-- from @()@, with one 'Int' per dimension.
-- Convention: matrices are represented in row-major: (((), y), x)
type DIM0 = ()
type DIM1 = (DIM0, Int)
type DIM2 = (DIM1, Int)
type DIM3 = (DIM2, Int)


-- | Types whose 'Type' representation can be produced from the Haskell type
-- alone, without passing it explicitly.
class InferType a where
  infer :: Type a

instance InferType Int    where infer = TInt
instance InferType Bool   where infer = TBool
instance InferType Double where infer = TDouble
instance InferType ()     where infer = TNil

instance (InferType a, InferShapeType sh) => InferType (Array sh a) where
  infer = TArray inferST infer

instance (InferType a, InferType b) => InferType (a, b) where
  infer = TPair infer infer

instance (InferType a, InferType b) => InferType (a -> b) where
  infer = TFun infer infer


-- | Shape types whose 'ShapeType' representation can be produced from the
-- Haskell type alone.
class InferShapeType sh where
  inferST :: ShapeType sh

-- The empty shape.
instance InferShapeType () where
  inferST = STZ

-- One more 'Int' dimension on top of an inferable shape.
instance InferShapeType sh => InferShapeType (sh, Int) where
  inferST = STC inferST


-- | A variable reference whose 'Type' is obtained through 'InferType'
-- instead of being supplied by the caller.
var :: InferType a => Idx env a -> Exp env a
var idx = Var infer idx

-- | Apply a function to every element of an array. The array expression is
-- let-bound once and then indexed from inside the element-building lambda.
map :: Exp env (a -> b) -> Exp env (Array sh a) -> Exp env (Array sh b)
map f e = Let e (Build shTy (Shape (Var arrTy Zero)) (Lam idxTy elemAt))
  where
    arrTy@(TArray shTy _) = typeof e
    idxTy = shapeTypeType shTy
    -- Under the lambda the environment has grown by the bound array and the
    -- index, so the array is one variable out and @f@ is weakened with
    -- 'sinkExp2' to account for both binders.
    elemAt = App (sinkExp2 f) (Index (Var arrTy (Succ Zero)) (Var idxTy Zero))

-- | Sum of a one-dimensional array of 'Double's, computed with 'Ifold'
-- starting from an accumulator of @0@.
sum :: Exp env (Array DIM1 Double) -> Exp env Double
sum e = Let e (Ifold inferST addElem (Lit (LDouble 0)) (Shape (var Zero)))
  where
    -- The folded lambda receives a pair (accumulator, index); the
    -- let-bound array is the next variable out.
    addElem =
      Lam (TPair TDouble (TPair TNil TInt)) $
        App (Const CAddF)
            (Pair (Fst (var Zero)) (Index (var (Succ Zero)) (Snd (var Zero))))

-- | Element-wise pairing of two arrays.
-- The two input arrays are assumed to be the same size.
-- The result takes its shape from the first array. Both inputs are let-bound
-- once, outside the per-element lambda, so neither input expression is
-- placed inside the loop body (previously the second array was sunk into the
-- lambda and therefore rebuilt for every element).
zip :: Exp env (Array sh a) -> Exp env (Array sh b) -> Exp env (Array sh (a, b))
zip a b =
  let arrTyA@(TArray shTy _) = typeof a
      arrTyB = typeof b
      idxTy = shapeTypeType shTy
  in Let a $                 -- binds a
     Let (sinkExp1 b) $      -- binds b (weakened past a's binder)
     Build shTy (Shape (Var arrTyA (Succ Zero))) $
     Lam idxTy $             -- binds the element index
     Pair (Index (Var arrTyA (Succ (Succ Zero))) (Var idxTy Zero))
          (Index (Var arrTyB (Succ Zero)) (Var idxTy Zero))

-- | @oneHot sht sh idx@ builds an array of shape @sh@ whose element at index
-- @i@ is @1@ when @i@ and @idx@ compare equal under 'CEq', and @0@ otherwise.
--
-- NOTE(review): @idx@ is sunk into the per-element lambda, so depending on
-- the evaluator it may be recomputed for every element; consider let-binding
-- it the way @zip@ binds its inputs.
oneHot :: ShapeType sh -> Exp env sh -> Exp env sh -> Exp env (Array sh Double)
oneHot sht sh idx =
  let idxt = shapeTypeType sht
  in Build sht sh $
     Lam idxt $
     Cond (App (Const (CEq idxt)) (Pair (Var idxt Zero) (sinkExp1 idx)))
          (Lit (LDouble 1))
          (Lit (LDouble 0))

-- | Matrix transpose. For an input of shape @(((), r), c)@ the result has
-- shape @(((), c), r)@, and its element at @(((), i), j)@ is the input's
-- element at @(((), j), i)@.
--
-- Previously both the shape and the element index were passed through
-- unchanged, which made this function the identity.
transpose :: Type a -> Exp env (Array DIM2 a -> Array DIM2 a)
transpose ty =
  Lam arrTy $
  Build inferST (swapDim2 (Shape (Var arrTy Zero))) $
  Lam infer $
  Index (Var arrTy (Succ Zero)) (swapDim2 (Var infer Zero))
  where
    arrTy = TArray inferST ty
    -- Exchange the two components of a DIM2 shape or index:
    -- (((), y), x)  ~>  (((), x), y)
    swapDim2 :: Exp env' DIM2 -> Exp env' DIM2
    swapDim2 ix = Pair (Pair (Lit LNil) (Snd ix)) (Snd (Fst ix))

-- | @eye n@ is the @n x n@ matrix holding @1@ where the row and column
-- indices are equal (under 'CEq') and @0@ elsewhere.
eye :: Exp env (Int -> Array DIM2 Double)
eye =
  Lam infer $ -- binds n
  Build inferST (Pair (Pair (Lit LNil) (var Zero)) (var Zero)) $
  Lam infer $ -- binds the index (((), y), x)
  Cond (App (Const (CEq infer)) (Pair (Snd (var Zero)) (Snd (Fst (var Zero)))))
       (Lit (LDouble 1))
       (Lit (LDouble 0))

-- | The extent of a one-dimensional array, i.e. its number of elements.
length :: Type a -> Exp env (Array DIM1 a -> Int)
length ty = Lam arrTy (Snd (Shape (Var arrTy Zero)))
  where
    arrTy = TArray inferST ty

-- | Vector-matrix product. For a vector @v@ and a row-major matrix @a@, the
-- result has one element per column of @a@; element @j@ folds over the
-- indices @i@ of @v@, accumulating @v[i] * a[i][j]@ with 'CAddF' / 'CMulF'
-- starting from @0@.
vmmul :: Exp env (Array DIM1 Double -> Array DIM2 Double -> Array DIM1 Double)
vmmul =
  Lam infer $ -- binds v
  Lam infer $ -- binds a
  Build inferST (Pair (Lit LNil) (Snd (Shape (var Zero)))) $ -- column count of a
  Lam infer $ -- binds j
  Ifold inferST
    (Lam infer $ -- binds (acc, i); then j, a, v are 1, 2, 3 variables out
       App (Const CAddF)
           (Pair (Fst (var Zero))
                 (App (Const CMulF)
                      (Pair (Index (var (Succ (Succ (Succ Zero))))
                                   (Snd (var Zero)))
                            (Index (var (Succ (Succ Zero)))
                                   (Pair (Pair (Lit LNil) (Snd (Snd (var Zero))))
                                         (Snd (var (Succ Zero)))))))))
    (Lit (LDouble 0))
    (Shape (var (Succ (Succ Zero)))) -- iterate over the indices of v