diff options
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r-- | src/Interpreter.hs | 152 |
1 files changed, 150 insertions, 2 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs index afc50f9..7ffb14c 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -1,8 +1,156 @@ -module Interpreter where +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE FlexibleContexts #-} +module Interpreter ( + interpret, + interpret', + Value, + NoAccum(..), + unAccum, +) where + +import Data.Int (Int64) +import Data.Proxy import AST -import Interpreter.Array +import Data +import Array import Interpreter.Accum +import Interpreter.Rep +import Control.Monad (foldM) + + +interpret :: NoAccum t => Ex '[] t -> Rep t +interpret e = runAcM (go e) + where + go :: forall s t. NoAccum t => Ex '[] t -> AcM s (Rep t) + go e' | Refl <- noAccum (Proxy @s) (Proxy @t) = interpret' SNil e' + +newtype Value s t = Value (Rep' s t) + +interpret' :: forall env t s. SList (Value s) env -> Ex env t -> AcM s (Rep' s t) +interpret' env = \case + EVar _ _ i -> case slistIdx env i of Value x -> return x + ELet _ a b -> do + x <- interpret' env a + interpret' (Value x `SCons` env) b + EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b + EFst _ e -> fst <$> interpret' env e + ESnd _ e -> snd <$> interpret' env e + ENil _ -> return () + EInl _ _ e -> Left <$> interpret' env e + EInr _ _ e -> Right <$> interpret' env e + ECase _ e a b -> interpret' env e >>= \case + Left x -> interpret' (Value x `SCons` env) a + Right y -> interpret' (Value y `SCons` env) b + EConstArr _ _ _ v -> return v + EBuild1 _ a b -> do + n <- fromIntegral @Int64 @Int <$> interpret' env a + arrayGenerateLinM (ShNil `ShCons` n) + (\i -> interpret' (Value (fromIntegral @Int @Int64 i) `SCons` env) b) + EBuild _ dim a b -> do + sh <- unTupRepIdx (Proxy @s) ShNil ShCons dim <$> interpret' env a + arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx (Proxy @s) ixUncons dim idx) `SCons` env) b) + EFold1Inner _ a b -> do + let f = \x y -> interpret' (Value y `SCons` Value x `SCons` env) a + arr <- interpret' env b + let sh `ShCons` n = arrayShape arr + arrayGenerateM sh $ \idx -> foldl1M f [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] + ESum1Inner _ e -> do + arr <- interpret' env e + let STArr _ (STScal t) = typeOf e + sh `ShCons` n = arrayShape arr + numericIsNum t $ arrayGenerateM sh $ \idx -> return $ sum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] + EUnit _ e -> arrayGenerateLinM ShNil (\_ -> interpret' env e) + EReplicate1Inner _ a b -> do + n <- fromIntegral @Int64 @Int <$> interpret' env a + arr <- interpret' env b + let sh = arrayShape arr + arrayGenerateM (sh `ShCons` n) (\(idx `IxCons` _) -> return (arrayIndex arr idx)) + EConst _ _ v -> return v + EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e + EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b) + EIdx _ n a b -> arrayIndex <$> interpret' env a <*> (unTupRepIdx (Proxy @s) IxNil IxCons n <$> interpret' env b) + EShape _ e | STArr n _ <- typeOf e -> tupRepIdx (Proxy @s) shUncons n . arrayShape <$> interpret' env e + EOp _ op e -> interpretOp (Proxy @s) op <$> interpret' env e + EWith e1 e2 -> do + initval <- interpret' env e1 + withAccum (typeOf e1) initval $ \accum -> + interpret' (Value accum `SCons` env) e2 + EAccum i e1 e2 e3 -> do + idx <- interpret' env e1 + val <- interpret' env e2 + accum <- interpret' env e3 + accumAdd accum i idx val + EError _ s -> error $ "Interpreter: Program threw error: " ++ s + +interpretOp :: Proxy s -> SOp a t -> Rep' s a -> Rep' s t +interpretOp _ op arg = case op of + OAdd st -> numericIsNum st $ uncurry (+) arg + OMul st -> numericIsNum st $ uncurry (*) arg + ONeg st -> numericIsNum st $ negate arg + OLt st -> numericIsNum st $ uncurry (<) arg + OLe st -> numericIsNum st $ uncurry (<=) arg + OEq st -> numericIsNum st $ uncurry (==) arg + ONot -> not arg + OIf -> if arg then Left () else Right () + +numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r +numericIsNum STI32 = id +numericIsNum STI64 = id +numericIsNum STF32 = id +numericIsNum STF64 = id + +unTupRepIdx :: Proxy s -> f Z -> (forall m. f m -> Int -> f (S m)) + -> SNat n -> Rep' s (Tup (Replicate n TIx)) -> f n +unTupRepIdx _ nil _ SZ _ = nil +unTupRepIdx p nil cons (SS n) (idx, i) = unTupRepIdx p nil cons n idx `cons` fromIntegral @Int64 @Int i + +tupRepIdx :: Proxy s -> (forall m. f (S m) -> (f m, Int)) + -> SNat n -> f n -> Rep' s (Tup (Replicate n TIx)) +tupRepIdx _ _ SZ _ = () +tupRepIdx p uncons (SS n) tup = + let (tup', i) = uncons tup + in (tupRepIdx p uncons n tup', fromIntegral @Int @Int64 i) + +ixUncons :: Index (S n) -> (Index n, Int) +ixUncons (IxCons idx i) = (idx, i) + +shUncons :: Shape (S n) -> (Shape n, Int) +shUncons (ShCons idx i) = (idx, i) +class NoAccum t where + noAccum :: Proxy s -> Proxy t -> Rep' s t :~: Rep t +instance NoAccum TNil where + noAccum _ _ = Refl +instance (NoAccum a, NoAccum b) => NoAccum (TPair a b) where + noAccum p _ | Refl <- noAccum p (Proxy @a), Refl <- noAccum p (Proxy @b) = Refl +instance (NoAccum a, NoAccum b) => NoAccum (TEither a b) where + noAccum p _ | Refl <- noAccum p (Proxy @a), Refl <- noAccum p (Proxy @b) = Refl +instance NoAccum t => NoAccum (TArr n t) where + noAccum p _ | Refl <- noAccum p (Proxy @t) = Refl +instance NoAccum (TScal t) where + noAccum _ _ = Refl +unAccum :: Proxy s -> STy t -> Maybe (Dict (NoAccum t)) +unAccum _ STNil = Just Dict +unAccum p (STPair t1 t2) + | Just Dict <- unAccum p t1, Just Dict <- unAccum p t2 = Just Dict + | otherwise = Nothing +unAccum p (STEither t1 t2) + | Just Dict <- unAccum p t1, Just Dict <- unAccum p t2 = Just Dict + | otherwise = Nothing +unAccum p (STArr _ t) + | Just Dict <- unAccum p t = Just Dict + | otherwise = Nothing +unAccum _ STScal{} = Just Dict +unAccum _ STAccum{} = Nothing +foldl1M :: Monad m => (a -> a -> m a) -> [a] -> m a +foldl1M _ [] = error "foldl1M: empty list" +foldl1M f (tophead : toptail) = foldM f tophead toptail |