diff options
author | Tom Smeding <tom@tomsmeding.com> | 2021-06-24 23:14:54 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2021-06-24 23:14:54 +0200 |
commit | 0fefe4822c65bde81ec4c0da1b5b32a9b411ca79 (patch) | |
tree | 0efeffb8b1b6d6126bc806209a2f5a64fb32c96f /Language.hs |
Initial
Diffstat (limited to 'Language.hs')
-rw-r--r-- | Language.hs | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/Language.hs b/Language.hs new file mode 100644 index 0000000..e16cf7c --- /dev/null +++ b/Language.hs @@ -0,0 +1,52 @@ +{-# LANGUAGE GADTs #-} + +{-| This module is intended to be imported qualified, perhaps as @L@. -} +module Language where + +import AST +import Sink + + +map :: Exp env (a -> b) -> Exp env (Array sh a) -> Exp env (Array sh b) +map f e = + let ty@(TArray sht _) = typeof e + sht' = shapeTypeType sht + in Let e + (Build sht (Shape (Var ty Zero)) + (Lam sht' + (App (sinkExp2 f) + (Index (Var ty (Succ Zero)) + (Var sht' Zero))))) + +sum :: Exp env (Array (Int, ()) Double) -> Exp env Double +sum e = + let ty@(TArray sht _) = typeof e + in Let e + (Ifold sht + (Lam (TPair TDouble (TPair TInt TNil)) + (App (Const CAddF) (Pair + (Fst (Var (TPair TDouble (TPair TInt TNil)) Zero)) + (Index (Var ty (Succ Zero)) (Snd (Var (TPair TDouble (TPair TInt TNil)) Zero)))))) + (Lit (LDouble 0)) + (Shape (Var ty Zero))) + +-- | The two input arrays are assumed to be the same size. +zip :: Exp env (Array sh a) -> Exp env (Array sh b) -> Exp env (Array sh (a, b)) +zip a b = + let tarr@(TArray sht _) = typeof a + idxt = shapeTypeType sht + in Let a + (Build sht + (Shape (Var tarr Zero)) + (Lam idxt + (Pair (Index (Var tarr (Succ Zero)) (Var idxt Zero)) + (Index (sinkExp2 b) (Var idxt Zero))))) + +oneHot :: ShapeType sh -> Exp env sh -> Exp env sh -> Exp env (Array sh Double) +oneHot sht sh idx = + let idxt = shapeTypeType sht + in Build sht sh + (Lam idxt + (Cond (App (Const (CEq idxt)) (Pair (Var idxt Zero) (sinkExp1 idx))) + (Lit (LDouble 1)) + (Lit (LDouble 0)))) |