{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}

-- | A direct, tree-walking interpreter for 'Ex' terms.
--
-- 'interpret' evaluates closed, accumulator-free programs to a plain 'Rep'.
-- 'interpret'' is the underlying evaluator; it runs in 'AcM' so that the
-- accumulator constructs ('EWith', 'EAccum') can call 'withAccum' and
-- 'accumAdd'.
module Interpreter (
  interpret,
  interpret',
  Value,
  NoAccum(..),
  unAccum,
) where

import Data.Int (Int64)
import Data.Proxy

import AST
import Data
import Array
import Interpreter.Accum
import Interpreter.Rep
import Control.Monad (foldM)

-- | Evaluate a closed expression whose type contains no accumulators.
--
-- The evaluator produces @Rep' s t@ for an 'AcM' parameter @s@. The 'NoAccum'
-- evidence proves that this equals @Rep t@, so the result can be returned
-- without mentioning @s@.
interpret :: NoAccum t => Ex '[] t -> Rep t
interpret e = runAcM (go e)
  where
    -- 'go' is polymorphic in s (presumably as required by runAcM; confirm in
    -- Interpreter.Accum). Matching on Refl rewrites the result type of
    -- interpret' from Rep' s t to Rep t.
    go :: forall s t. NoAccum t => Ex '[] t -> AcM s (Rep t)
    go e' | Refl <- noAccum (Proxy @s) (Proxy @t) = interpret' SNil e'

-- | An environment entry: the runtime representation of a value of type @t@.
-- Wrapping it in a newtype lets values of different types be stored in an
-- 'SList' indexed by the environment's type-level list.
newtype Value s t = Value (Rep' s t)

-- | Evaluate an expression in the 'AcM' monad, given one 'Value' for each
-- variable in scope.
--
-- Binding constructs ('ELet', 'ECase', 'EBuild1', 'EBuild', 'EFold1Inner',
-- 'EWith') extend the environment by consing the new value(s) onto the front
-- of @env@. 'EVar' looks a variable up with 'slistIdx'.
interpret' :: forall env t s. SList (Value s) env -> Ex env t -> AcM s (Rep' s t)
interpret' env = \case
  EVar _ _ i -> case slistIdx env i of Value x -> return x
  ELet _ a b -> do
    x <- interpret' env a
    interpret' (Value x `SCons` env) b
  EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b
  EFst _ e -> fst <$> interpret' env e
  ESnd _ e -> snd <$> interpret' env e
  ENil _ -> return ()
  EInl _ _ e -> Left <$> interpret' env e
  EInr _ _ e -> Right <$> interpret' env e
  -- The payload of the evaluated sum is bound as the new variable in
  -- whichever branch is taken.
  ECase _ e a b -> interpret' env e >>= \case
    Left x -> interpret' (Value x `SCons` env) a
    Right y -> interpret' (Value y `SCons` env) b
  EConstArr _ _ _ v -> return v
  -- One-dimensional build of shape [n]. The body binds the Int index that
  -- arrayGenerateLinM supplies, converted to Int64.
  EBuild1 _ a b -> do
    n <- fromIntegral @Int64 @Int <$> interpret' env a
    arrayGenerateLinM (ShNil `ShCons` n)
      (\i -> interpret' (Value (fromIntegral @Int @Int64 i) `SCons` env) b)
  -- Multi-dimensional build. The shape arrives as a nested tuple of Int64
  -- components and is converted to a Shape. Each generated Index is
  -- converted back to that tuple form before being bound for the body.
  EBuild _ dim a b -> do
    sh <- unTupRepIdx (Proxy @s) ShNil ShCons dim <$> interpret' env a
    arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx (Proxy @s) ixUncons dim idx) `SCons` env) b)
  -- Fold along the innermost dimension. foldl1M calls (f acc elem), so inside
  -- the combining body the element (y) is the most recently bound variable
  -- and the accumulator (x) is the one before it. If the innermost dimension
  -- has length 0, each per-position fold hits foldl1M's empty-list error.
  EFold1Inner _ a b -> do
    let f = \x y -> interpret' (Value y `SCons` Value x `SCons` env) a
    arr <- interpret' env b
    let sh `ShCons` n = arrayShape arr
    arrayGenerateM sh $ \idx -> foldl1M f [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
  -- Sum along the innermost dimension. The let-pattern on typeOf recovers
  -- the scalar element type, so numericIsNum can supply its Num instance.
  ESum1Inner _ e -> do
    arr <- interpret' env e
    let STArr _ (STScal t) = typeOf e
        sh `ShCons` n = arrayShape arr
    numericIsNum t $ arrayGenerateM sh $ \idx -> return $ sum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
  -- Wrap a single value in an array of empty shape (ShNil).
  EUnit _ e -> arrayGenerateLinM ShNil (\_ -> interpret' env e)
  -- Append a new innermost dimension of length n. Every position along it
  -- holds the input array's element at the remaining outer index.
  EReplicate1Inner _ a b -> do
    n <- fromIntegral @Int64 @Int <$> interpret' env a
    arr <- interpret' env b
    let sh = arrayShape arr
    arrayGenerateM (sh `ShCons` n) (\(idx `IxCons` _) -> return (arrayIndex arr idx))
  EConst _ _ v -> return v
  -- Read the element at linear position 0.
  EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e
  EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b)
  -- The index arrives as a nested tuple and is converted to an Index.
  EIdx _ n a b -> arrayIndex <$> interpret' env a <*> (unTupRepIdx (Proxy @s) IxNil IxCons n <$> interpret' env b)
  -- Return the shape in nested-tuple form. The guard recovers the rank
  -- singleton n from the array's type.
  EShape _ e | STArr n _ <- typeOf e -> tupRepIdx (Proxy @s) shUncons n . arrayShape <$> interpret' env e
  EOp _ op e -> interpretOp (Proxy @s) op <$> interpret' env e
  -- Evaluate the initial value, then hand it (with its type) to withAccum.
  -- The body sees the accumulator as its newly bound variable.
  EWith e1 e2 -> do
    initval <- interpret' env e1
    withAccum (typeOf e1) initval $ \accum -> interpret' (Value accum `SCons` env) e2
  -- Evaluate index, value and accumulator in that order, then add the value
  -- into the accumulator with accumAdd.
  EAccum i e1 e2 e3 -> do
    idx <- interpret' env e1
    val <- interpret' env e2
    accum <- interpret' env e3
    accumAdd accum i idx val
  -- An error raised by the interpreted program. The AcM action itself is
  -- bottom, so the exception fires as soon as this action is forced.
  EError _ s -> error $ "Interpreter: Program threw error: " ++ s

-- | Apply a primitive operator to its already-evaluated argument.
--
-- Binary operators receive their operands as a pair, hence 'uncurry'. The
-- proxy only fixes the 'AcM' parameter @s@ of the representation types.
interpretOp :: Proxy s -> SOp a t -> Rep' s a -> Rep' s t
interpretOp _ op arg = case op of
  OAdd st -> numericIsNum st $ uncurry (+) arg
  OMul st -> numericIsNum st $ uncurry (*) arg
  ONeg st -> numericIsNum st $ negate arg
  OLt st -> numericIsNum st $ uncurry (<) arg
  OLe st -> numericIsNum st $ uncurry (<=) arg
  OEq st -> numericIsNum st $ uncurry (==) arg
  ONot -> not arg
  -- Encode a Bool as a sum: True becomes Left (), False becomes Right ().
  OIf -> if arg then Left () else Right ()

-- | Bring 'Num' and 'Ord' instances for the representation of a numeric
-- scalar type into scope by matching on its singleton.
--
-- Each equation refines @st@ to a concrete scalar type, and GHC then finds
-- the instances for it. The @ScalIsNumeric st ~ True@ constraint restricts
-- callers to numeric scalar types; presumably this is what makes these four
-- equations exhaustive.
numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r
numericIsNum STI32 = id
numericIsNum STI64 = id
numericIsNum STF32 = id
numericIsNum STF64 = id

-- | Convert an index or shape from its nested-tuple representation into an
-- @f n@ built with the supplied nil and cons. It is used for both 'Shape' and
-- 'Index'.
--
-- For rank 2 the tuple is @(((), i1), i2)@ with 'Int64' components: @i1@ is
-- consed first and @i2@ (the outermost pair's second component) last. Each
-- component is converted to 'Int'.
unTupRepIdx :: Proxy s -> f Z -> (forall m. f m -> Int -> f (S m)) -> SNat n -> Rep' s (Tup (Replicate n TIx)) -> f n
unTupRepIdx _ nil _ SZ _ = nil
unTupRepIdx p nil cons (SS n) (idx, i) = unTupRepIdx p nil cons n idx `cons` fromIntegral @Int64 @Int i

-- | Inverse of 'unTupRepIdx'. It repeatedly splits off the last-consed
-- component with the supplied uncons and rebuilds the nested tuple,
-- converting each component from 'Int' to 'Int64'.
tupRepIdx :: Proxy s -> (forall m. f (S m) -> (f m, Int)) -> SNat n -> f n -> Rep' s (Tup (Replicate n TIx))
tupRepIdx _ _ SZ _ = ()
tupRepIdx p uncons (SS n) tup =
  let (tup', i) = uncons tup
  in (tupRepIdx p uncons n tup', fromIntegral @Int @Int64 i)

-- | Split a non-empty 'Index' into its prefix and its last-consed component.
ixUncons :: Index (S n) -> (Index n, Int)
ixUncons (IxCons idx i) = (idx, i)

-- | Split a non-empty 'Shape' into its prefix and its last-consed component.
shUncons :: Shape (S n) -> (Shape n, Int)
shUncons (ShCons idx i) = (idx, i)

-- | Types whose runtime representation does not depend on the 'AcM'
-- parameter @s@.
--
-- 'noAccum' is the proof: @Rep' s t@ equals @Rep t@ for every @s@. The
-- instances follow the type structure, and no instance is provided here for
-- accumulator types.
class NoAccum t where
  noAccum :: Proxy s -> Proxy t -> Rep' s t :~: Rep t

instance NoAccum TNil where
  noAccum _ _ = Refl

-- Pairs, sums and arrays are accumulator-free when their components are;
-- the component proofs are matched to make the outer Refl typecheck.
instance (NoAccum a, NoAccum b) => NoAccum (TPair a b) where
  noAccum p _
    | Refl <- noAccum p (Proxy @a), Refl <- noAccum p (Proxy @b) = Refl

instance (NoAccum a, NoAccum b) => NoAccum (TEither a b) where
  noAccum p _
    | Refl <- noAccum p (Proxy @a), Refl <- noAccum p (Proxy @b) = Refl

instance NoAccum t => NoAccum (TArr n t) where
  noAccum p _
    | Refl <- noAccum p (Proxy @t) = Refl

instance NoAccum (TScal t) where
  noAccum _ _ = Refl

-- | Decide from a type singleton whether @t@ contains no accumulators.
--
-- Returns the 'NoAccum' dictionary if so, and 'Nothing' if any component is
-- an 'STAccum'. The proxy argument is only passed along to recursive calls;
-- it is never inspected.
unAccum :: Proxy s -> STy t -> Maybe (Dict (NoAccum t))
unAccum _ STNil = Just Dict
unAccum p (STPair t1 t2)
  | Just Dict <- unAccum p t1, Just Dict <- unAccum p t2 = Just Dict
  | otherwise = Nothing
unAccum p (STEither t1 t2)
  | Just Dict <- unAccum p t1, Just Dict <- unAccum p t2 = Just Dict
  | otherwise = Nothing
unAccum p (STArr _ t)
  | Just Dict <- unAccum p t = Just Dict
  | otherwise = Nothing
unAccum _ STScal{} = Just Dict
unAccum _ STAccum{} = Nothing

-- | Monadic left fold that uses the first list element as the initial
-- accumulator, so @f@ is called as @f acc elem@. Calls 'error' on an empty
-- list.
foldl1M :: Monad m => (a -> a -> m a) -> [a] -> m a
foldl1M _ [] = error "foldl1M: empty list"
foldl1M f (tophead : toptail) = foldM f tophead toptail