diff options
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r-- | src/Interpreter.hs | 142 |
1 files changed, 86 insertions, 56 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 7ffb14c..8728ec0 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -1,39 +1,44 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiWayIf #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} module Interpreter ( interpret, interpret', Value, - NoAccum(..), - unAccum, ) where +import Control.Monad (foldM) import Data.Int (Int64) import Data.Proxy +import System.IO.Unsafe (unsafePerformIO) +import Array import AST +import CHAD.Types import Data -import Array -import Interpreter.Accum import Interpreter.Rep -import Control.Monad (foldM) -interpret :: NoAccum t => Ex '[] t -> Rep t -interpret e = runAcM (go e) - where - go :: forall s t. NoAccum t => Ex '[] t -> AcM s (Rep t) - go e' | Refl <- noAccum (Proxy @s) (Proxy @t) = interpret' SNil e' +newtype AcM s a = AcM (IO a) + deriving newtype (Functor, Applicative, Monad) -newtype Value s t = Value (Rep' s t) +runAcM :: (forall s. AcM s a) -> a +runAcM (AcM m) = unsafePerformIO m -interpret' :: forall env t s. SList (Value s) env -> Ex env t -> AcM s (Rep' s t) +interpret :: Ex '[] t -> Rep t +interpret e = runAcM (interpret' SNil e) + +newtype Value t = Value (Rep t) + +interpret' :: forall env t s. SList Value env -> Ex env t -> AcM s (Rep t) interpret' env = \case EVar _ _ i -> case slistIdx env i of Value x -> return x ELet _ a b -> do @@ -48,14 +53,17 @@ interpret' env = \case ECase _ e a b -> interpret' env e >>= \case Left x -> interpret' (Value x `SCons` env) a Right y -> interpret' (Value y `SCons` env) b + ENothing _ _ -> _ + EJust _ _ -> _ + EMaybe _ _ _ _ -> _ EConstArr _ _ _ v -> return v EBuild1 _ a b -> do n <- fromIntegral @Int64 @Int <$> interpret' env a arrayGenerateLinM (ShNil `ShCons` n) (\i -> interpret' (Value (fromIntegral @Int @Int64 i) `SCons` env) b) EBuild _ dim a b -> do - sh <- unTupRepIdx (Proxy @s) ShNil ShCons dim <$> interpret' env a - arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx (Proxy @s) ixUncons dim idx) `SCons` env) b) + sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a + arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx ixUncons dim idx) `SCons` env) b) EFold1Inner _ a b -> do let f = \x y -> interpret' (Value y `SCons` Value x `SCons` env) a arr <- interpret' env b @@ -75,9 +83,9 @@ interpret' env = \case EConst _ _ v -> return v EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b) - EIdx _ n a b -> arrayIndex <$> interpret' env a <*> (unTupRepIdx (Proxy @s) IxNil IxCons n <$> interpret' env b) - EShape _ e | STArr n _ <- typeOf e -> tupRepIdx (Proxy @s) shUncons n . arrayShape <$> interpret' env e - EOp _ op e -> interpretOp (Proxy @s) op <$> interpret' env e + EIdx _ n a b -> arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b) + EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e + EOp _ op e -> interpretOp op <$> interpret' env e EWith e1 e2 -> do initval <- interpret' env e1 withAccum (typeOf e1) initval $ \accum -> @@ -87,10 +95,16 @@ interpret' env = \case val <- interpret' env e2 accum <- interpret' env e3 accumAdd accum i idx val + EZero t -> do + return $ makeZero t + EPlus t a b -> do + a' <- interpret' env a + b' <- interpret' env b + return $ makePlus t a' b' EError _ s -> error $ "Interpreter: Program threw error: " ++ s -interpretOp :: Proxy s -> SOp a t -> Rep' s a -> Rep' s t -interpretOp _ op arg = case op of +interpretOp :: SOp a t -> Rep a -> Rep t +interpretOp op arg = case op of OAdd st -> numericIsNum st $ uncurry (+) arg OMul st -> numericIsNum st $ uncurry (*) arg ONeg st -> numericIsNum st $ negate arg @@ -100,23 +114,66 @@ interpretOp _ op arg = case op of ONot -> not arg OIf -> if arg then Left () else Right () +makeZero :: STy t -> Rep (D2 t) +makeZero typ = case typ of + STNil -> () + STPair _ _ -> Left () + STEither _ _ -> Left () + STMaybe _ -> Nothing + STArr n _ -> emptyArray n + STScal sty -> case sty of + STI32 -> () + STI64 -> () + STF32 -> 0.0 + STF64 -> 0.0 + STBool -> () + STAccum{} -> error "Zero of Accum" + +makePlus :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t) +makePlus typ a b = case typ of + STNil -> () + STPair t1 t2 -> case (a, b) of + (Left (), _) -> b + (_, Left ()) -> a + (Right (x1, x2), Right (y1, y2)) -> Right (makePlus t1 x1 y1, makePlus t2 x2 y2) + STEither t1 t2 -> case (a, b) of + (Left (), _) -> b + (_, Left ()) -> a + (Right (Left x), Right (Left y)) -> Right (Left (makePlus t1 x y)) + (Right (Right x), Right (Right y)) -> Right (Right (makePlus t2 x y)) + _ -> error "Plus of inconsistent Eithers" + STArr _ t -> + let sh1 = arrayShape a + sh2 = arrayShape b + in if | shapeSize sh1 == 0 -> b + | shapeSize sh2 == 0 -> a + | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> makePlus t (arrayIndexLinear a i) (arrayIndexLinear b i)) + | otherwise -> error "Plus of inconsistently shaped arrays" + STScal sty -> case sty of + STI32 -> () + STI64 -> () + STF32 -> a + b + STF64 -> a + b + STBool -> () + STAccum{} -> error "Plus of Accum" + numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r numericIsNum STI32 = id numericIsNum STI64 = id numericIsNum STF32 = id numericIsNum STF64 = id -unTupRepIdx :: Proxy s -> f Z -> (forall m. f m -> Int -> f (S m)) - -> SNat n -> Rep' s (Tup (Replicate n TIx)) -> f n -unTupRepIdx _ nil _ SZ _ = nil -unTupRepIdx p nil cons (SS n) (idx, i) = unTupRepIdx p nil cons n idx `cons` fromIntegral @Int64 @Int i +unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m)) + -> SNat n -> Rep (Tup (Replicate n TIx)) -> f n +unTupRepIdx nil _ SZ _ = nil +unTupRepIdx nil cons (SS n) (idx, i) = unTupRepIdx p nil cons n idx `cons` fromIntegral @Int64 @Int i -tupRepIdx :: Proxy s -> (forall m. f (S m) -> (f m, Int)) - -> SNat n -> f n -> Rep' s (Tup (Replicate n TIx)) -tupRepIdx _ _ SZ _ = () -tupRepIdx p uncons (SS n) tup = +tupRepIdx :: (forall m. f (S m) -> (f m, Int)) + -> SNat n -> f n -> Rep (Tup (Replicate n TIx)) +tupRepIdx _ SZ _ = () +tupRepIdx uncons (SS n) tup = let (tup', i) = uncons tup - in (tupRepIdx p uncons n tup', fromIntegral @Int @Int64 i) + in (tupRepIdx uncons n tup', fromIntegral @Int @Int64 i) ixUncons :: Index (S n) -> (Index n, Int) ixUncons (IxCons idx i) = (idx, i) @@ -124,33 +181,6 @@ ixUncons (IxCons idx i) = (idx, i) shUncons :: Shape (S n) -> (Shape n, Int) shUncons (ShCons idx i) = (idx, i) -class NoAccum t where - noAccum :: Proxy s -> Proxy t -> Rep' s t :~: Rep t -instance NoAccum TNil where - noAccum _ _ = Refl -instance (NoAccum a, NoAccum b) => NoAccum (TPair a b) where - noAccum p _ | Refl <- noAccum p (Proxy @a), Refl <- noAccum p (Proxy @b) = Refl -instance (NoAccum a, NoAccum b) => NoAccum (TEither a b) where - noAccum p _ | Refl <- noAccum p (Proxy @a), Refl <- noAccum p (Proxy @b) = Refl -instance NoAccum t => NoAccum (TArr n t) where - noAccum p _ | Refl <- noAccum p (Proxy @t) = Refl -instance NoAccum (TScal t) where - noAccum _ _ = Refl - -unAccum :: Proxy s -> STy t -> Maybe (Dict (NoAccum t)) -unAccum _ STNil = Just Dict -unAccum p (STPair t1 t2) - | Just Dict <- unAccum p t1, Just Dict <- unAccum p t2 = Just Dict - | otherwise = Nothing -unAccum p (STEither t1 t2) - | Just Dict <- unAccum p t1, Just Dict <- unAccum p t2 = Just Dict - | otherwise = Nothing -unAccum p (STArr _ t) - | Just Dict <- unAccum p t = Just Dict - | otherwise = Nothing -unAccum _ STScal{} = Just Dict -unAccum _ STAccum{} = Nothing - foldl1M :: Monad m => (a -> a -> m a) -> [a] -> m a foldl1M _ [] = error "foldl1M: empty list" foldl1M f (tophead : toptail) = foldM f tophead toptail |