summaryrefslogtreecommitdiff
path: root/src/Interpreter.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r--src/Interpreter.hs152
1 files changed, 150 insertions, 2 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index afc50f9..7ffb14c 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -1,8 +1,156 @@
-module Interpreter where
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE FlexibleContexts #-}
+module Interpreter (
+ interpret,
+ interpret',
+ Value,
+ NoAccum(..),
+ unAccum,
+) where
+
+import Data.Int (Int64)
+import Data.Proxy
import AST
-import Interpreter.Array
+import Data
+import Array
import Interpreter.Accum
+import Interpreter.Rep
+import Control.Monad (foldM)
+
+
+interpret :: NoAccum t => Ex '[] t -> Rep t
+interpret e = runAcM (go e)
+ where
+ go :: forall s t. NoAccum t => Ex '[] t -> AcM s (Rep t)
+ go e' | Refl <- noAccum (Proxy @s) (Proxy @t) = interpret' SNil e'
+
+newtype Value s t = Value (Rep' s t)
+
+interpret' :: forall env t s. SList (Value s) env -> Ex env t -> AcM s (Rep' s t)
+interpret' env = \case
+ EVar _ _ i -> case slistIdx env i of Value x -> return x
+ ELet _ a b -> do
+ x <- interpret' env a
+ interpret' (Value x `SCons` env) b
+ EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b
+ EFst _ e -> fst <$> interpret' env e
+ ESnd _ e -> snd <$> interpret' env e
+ ENil _ -> return ()
+ EInl _ _ e -> Left <$> interpret' env e
+ EInr _ _ e -> Right <$> interpret' env e
+ ECase _ e a b -> interpret' env e >>= \case
+ Left x -> interpret' (Value x `SCons` env) a
+ Right y -> interpret' (Value y `SCons` env) b
+ EConstArr _ _ _ v -> return v
+ EBuild1 _ a b -> do
+ n <- fromIntegral @Int64 @Int <$> interpret' env a
+ arrayGenerateLinM (ShNil `ShCons` n)
+ (\i -> interpret' (Value (fromIntegral @Int @Int64 i) `SCons` env) b)
+ EBuild _ dim a b -> do
+ sh <- unTupRepIdx (Proxy @s) ShNil ShCons dim <$> interpret' env a
+ arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx (Proxy @s) ixUncons dim idx) `SCons` env) b)
+ EFold1Inner _ a b -> do
+ let f = \x y -> interpret' (Value y `SCons` Value x `SCons` env) a
+ arr <- interpret' env b
+ let sh `ShCons` n = arrayShape arr
+ arrayGenerateM sh $ \idx -> foldl1M f [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
+ ESum1Inner _ e -> do
+ arr <- interpret' env e
+ let STArr _ (STScal t) = typeOf e
+ sh `ShCons` n = arrayShape arr
+ numericIsNum t $ arrayGenerateM sh $ \idx -> return $ sum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
+ EUnit _ e -> arrayGenerateLinM ShNil (\_ -> interpret' env e)
+ EReplicate1Inner _ a b -> do
+ n <- fromIntegral @Int64 @Int <$> interpret' env a
+ arr <- interpret' env b
+ let sh = arrayShape arr
+ arrayGenerateM (sh `ShCons` n) (\(idx `IxCons` _) -> return (arrayIndex arr idx))
+ EConst _ _ v -> return v
+ EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e
+ EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b)
+ EIdx _ n a b -> arrayIndex <$> interpret' env a <*> (unTupRepIdx (Proxy @s) IxNil IxCons n <$> interpret' env b)
+ EShape _ e | STArr n _ <- typeOf e -> tupRepIdx (Proxy @s) shUncons n . arrayShape <$> interpret' env e
+ EOp _ op e -> interpretOp (Proxy @s) op <$> interpret' env e
+ EWith e1 e2 -> do
+ initval <- interpret' env e1
+ withAccum (typeOf e1) initval $ \accum ->
+ interpret' (Value accum `SCons` env) e2
+ EAccum i e1 e2 e3 -> do
+ idx <- interpret' env e1
+ val <- interpret' env e2
+ accum <- interpret' env e3
+ accumAdd accum i idx val
+ EError _ s -> error $ "Interpreter: Program threw error: " ++ s
+
+interpretOp :: Proxy s -> SOp a t -> Rep' s a -> Rep' s t
+interpretOp _ op arg = case op of
+ OAdd st -> numericIsNum st $ uncurry (+) arg
+ OMul st -> numericIsNum st $ uncurry (*) arg
+ ONeg st -> numericIsNum st $ negate arg
+ OLt st -> numericIsNum st $ uncurry (<) arg
+ OLe st -> numericIsNum st $ uncurry (<=) arg
+ OEq st -> numericIsNum st $ uncurry (==) arg
+ ONot -> not arg
+ OIf -> if arg then Left () else Right ()
+
+numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r
+numericIsNum STI32 = id
+numericIsNum STI64 = id
+numericIsNum STF32 = id
+numericIsNum STF64 = id
+
+unTupRepIdx :: Proxy s -> f Z -> (forall m. f m -> Int -> f (S m))
+ -> SNat n -> Rep' s (Tup (Replicate n TIx)) -> f n
+unTupRepIdx _ nil _ SZ _ = nil
+unTupRepIdx p nil cons (SS n) (idx, i) = unTupRepIdx p nil cons n idx `cons` fromIntegral @Int64 @Int i
+
+tupRepIdx :: Proxy s -> (forall m. f (S m) -> (f m, Int))
+ -> SNat n -> f n -> Rep' s (Tup (Replicate n TIx))
+tupRepIdx _ _ SZ _ = ()
+tupRepIdx p uncons (SS n) tup =
+ let (tup', i) = uncons tup
+ in (tupRepIdx p uncons n tup', fromIntegral @Int @Int64 i)
+
+ixUncons :: Index (S n) -> (Index n, Int)
+ixUncons (IxCons idx i) = (idx, i)
+
+shUncons :: Shape (S n) -> (Shape n, Int)
+shUncons (ShCons idx i) = (idx, i)
+class NoAccum t where
+ noAccum :: Proxy s -> Proxy t -> Rep' s t :~: Rep t
+instance NoAccum TNil where
+ noAccum _ _ = Refl
+instance (NoAccum a, NoAccum b) => NoAccum (TPair a b) where
+ noAccum p _ | Refl <- noAccum p (Proxy @a), Refl <- noAccum p (Proxy @b) = Refl
+instance (NoAccum a, NoAccum b) => NoAccum (TEither a b) where
+ noAccum p _ | Refl <- noAccum p (Proxy @a), Refl <- noAccum p (Proxy @b) = Refl
+instance NoAccum t => NoAccum (TArr n t) where
+ noAccum p _ | Refl <- noAccum p (Proxy @t) = Refl
+instance NoAccum (TScal t) where
+ noAccum _ _ = Refl
+unAccum :: Proxy s -> STy t -> Maybe (Dict (NoAccum t))
+unAccum _ STNil = Just Dict
+unAccum p (STPair t1 t2)
+ | Just Dict <- unAccum p t1, Just Dict <- unAccum p t2 = Just Dict
+ | otherwise = Nothing
+unAccum p (STEither t1 t2)
+ | Just Dict <- unAccum p t1, Just Dict <- unAccum p t2 = Just Dict
+ | otherwise = Nothing
+unAccum p (STArr _ t)
+ | Just Dict <- unAccum p t = Just Dict
+ | otherwise = Nothing
+unAccum _ STScal{} = Just Dict
+unAccum _ STAccum{} = Nothing
+foldl1M :: Monad m => (a -> a -> m a) -> [a] -> m a
+foldl1M _ [] = error "foldl1M: empty list"
+foldl1M f (tophead : toptail) = foldM f tophead toptail