aboutsummaryrefslogtreecommitdiff
path: root/Eval.hs
diff options
context:
space:
mode:
Diffstat (limited to 'Eval.hs')
-rw-r--r--Eval.hs128
1 files changed, 128 insertions, 0 deletions
diff --git a/Eval.hs b/Eval.hs
new file mode 100644
index 0000000..19265bc
--- /dev/null
+++ b/Eval.hs
@@ -0,0 +1,128 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeOperators #-}
+module Eval (
+ eval,
+) where
+
+import Data.List (foldl')
+import qualified Data.Vector as V
+
+import AST
+
+
+data Val env where
+ Top :: Val '[]
+ Push :: Val env -> a -> Val (a ': env)
+
+prj :: Val env -> Idx env a -> a
+prj (Push _ x) Zero = x
+prj (Push env _) (Succ i) = prj env i
+
+eval :: Exp '[] a -> a
+eval = eval' Top
+
+eval' :: forall env a. Val env -> Exp env a -> a
+eval' env = \case
+ App f a -> rec f (rec a)
+ Lam _ e -> \x -> eval' (Push env x) e
+ Var _ i -> prj env i
+ Let a e -> eval' (Push env (rec a)) e
+ Lit l -> evalL l
+ Cond c a b -> if rec c then rec a else rec b
+ Const c -> evalC c
+ Pair a b -> (rec a, rec b)
+ Fst e -> fst (rec e)
+ Snd e -> snd (rec e)
+ Build sht she fe ->
+ let TFun _ ty = typeof fe
+ in build ty sht (rec she) (rec fe)
+ Ifold sht fe e0 she -> ifold sht (rec fe) (rec e0) (rec she)
+ Index a i -> index (rec a) (rec i)
+ Shape a -> shape (rec a)
+ Undef t -> error ("eval: Undef of type " ++ show t)
+ where rec :: Exp env t -> t
+ rec = eval' env
+
+evalL :: Literal a -> a
+evalL (LInt n) = n
+evalL (LBool b) = b
+evalL (LDouble d) = d
+evalL (LArray a) = a
+evalL (LShape Z) = ()
+evalL (LShape sh) = unshape sh
+evalL LNil = ()
+evalL (LPair a b) = (evalL a, evalL b)
+
+evalC :: Constant a -> a
+evalC CAddI = uncurry (+)
+evalC CSubI = uncurry (-)
+evalC CMulI = uncurry (*)
+evalC CDivI = uncurry div
+evalC CAddF = uncurry (+)
+evalC CSubF = uncurry (-)
+evalC CMulF = uncurry (*)
+evalC CDivF = uncurry (/)
+evalC CLog = log
+evalC CExp = exp
+evalC CtoF = fromIntegral
+evalC CRound = round
+evalC CLtI = uncurry (<)
+evalC CLeI = uncurry (<=)
+evalC CLtF = uncurry (<)
+evalC (CEq t) | Just Has <- typeHasEq t = uncurry (==)
+ | otherwise = error ("eval: Cannot Eq compare values of type " ++ show t)
+evalC CAnd = uncurry (&&)
+evalC COr = uncurry (||)
+evalC CNot = not
+
+build :: Type a -> ShapeType sh -> sh -> (sh -> a) -> Array sh a
+build ty sht sh f =
+ let sh' = toshape sht sh
+ in Array sh' ty (V.generate (shapesize sh') (\i -> f (fromlinear sh' i)))
+
+ifold :: ShapeType sh -> ((a, sh) -> a) -> a -> sh -> a
+ifold sht f x0 sh = foldl' (curry f) x0 (enumshape (toshape sht sh))
+
+index :: Array sh a -> sh -> a
+index (Array sh _ v) idx = v V.! tolinear sh idx
+
+shape :: Array sh a -> sh
+shape (Array sh _ _) = unshape sh
+
+enumshape :: Shape sh -> [sh]
+enumshape sh = take (shapesize sh) (iterate (next sh) (zeroshape sh))
+ where
+ next :: Shape sh -> sh -> sh
+ next Z () = ()
+ next (sh' :. n) (idx, i)
+ | i < n = (idx, i + 1)
+ | otherwise = (next sh' idx, 0)
+
+ zeroshape :: Shape sh -> sh
+ zeroshape Z = ()
+ zeroshape (sh' :. _) = (zeroshape sh', 0)
+
+unshape :: Shape sh -> sh
+unshape Z = ()
+unshape (sh :. n) = (unshape sh, n)
+
+toshape :: ShapeType sh -> sh -> Shape sh
+toshape STZ () = Z
+toshape (STC sht) (sh, n) = toshape sht sh :. n
+
+tolinear :: Shape sh -> sh -> Int
+tolinear Z () = 0
+tolinear (sh :. n) (idx, i) = n * tolinear sh idx + i
+
+fromlinear :: Shape sh -> Int -> sh
+fromlinear Z _ = ()
+fromlinear (sh :. n) i =
+ let (q, r) = i `divMod` n
+ in (fromlinear sh q, r)
+
+shapesize :: Shape sh -> Int
+shapesize Z = 1
+shapesize (sh :. n) = n * shapesize sh