aboutsummaryrefslogtreecommitdiff
path: root/Language.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2021-06-24 23:14:54 +0200
committerTom Smeding <tom@tomsmeding.com>2021-06-24 23:14:54 +0200
commit0fefe4822c65bde81ec4c0da1b5b32a9b411ca79 (patch)
tree0efeffb8b1b6d6126bc806209a2f5a64fb32c96f /Language.hs
Initial
Diffstat (limited to 'Language.hs')
-rw-r--r--Language.hs52
1 files changed, 52 insertions, 0 deletions
diff --git a/Language.hs b/Language.hs
new file mode 100644
index 0000000..e16cf7c
--- /dev/null
+++ b/Language.hs
@@ -0,0 +1,52 @@
+{-# LANGUAGE GADTs #-}
+
+{-| This module is intended to be imported qualified, perhaps as @L@. -}
+module Language where
+
+import AST
+import Sink
+
+
+map :: Exp env (a -> b) -> Exp env (Array sh a) -> Exp env (Array sh b)
+map f e =
+ let ty@(TArray sht _) = typeof e
+ sht' = shapeTypeType sht
+ in Let e
+ (Build sht (Shape (Var ty Zero))
+ (Lam sht'
+ (App (sinkExp2 f)
+ (Index (Var ty (Succ Zero))
+ (Var sht' Zero)))))
+
+sum :: Exp env (Array (Int, ()) Double) -> Exp env Double
+sum e =
+ let ty@(TArray sht _) = typeof e
+ in Let e
+ (Ifold sht
+ (Lam (TPair TDouble (TPair TInt TNil))
+ (App (Const CAddF) (Pair
+ (Fst (Var (TPair TDouble (TPair TInt TNil)) Zero))
+ (Index (Var ty (Succ Zero)) (Snd (Var (TPair TDouble (TPair TInt TNil)) Zero))))))
+ (Lit (LDouble 0))
+ (Shape (Var ty Zero)))
+
+-- | The two input arrays are assumed to be the same size.
+zip :: Exp env (Array sh a) -> Exp env (Array sh b) -> Exp env (Array sh (a, b))
+zip a b =
+ let tarr@(TArray sht _) = typeof a
+ idxt = shapeTypeType sht
+ in Let a
+ (Build sht
+ (Shape (Var tarr Zero))
+ (Lam idxt
+ (Pair (Index (Var tarr (Succ Zero)) (Var idxt Zero))
+ (Index (sinkExp2 b) (Var idxt Zero)))))
+
+oneHot :: ShapeType sh -> Exp env sh -> Exp env sh -> Exp env (Array sh Double)
+oneHot sht sh idx =
+ let idxt = shapeTypeType sht
+ in Build sht sh
+ (Lam idxt
+ (Cond (App (Const (CEq idxt)) (Pair (Var idxt Zero) (sinkExp1 idx)))
+ (Lit (LDouble 1))
+ (Lit (LDouble 0))))