{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TupleSections #-}
-- | Direct interpreter for 'Ex' expressions, producing host values of type
-- 'Rep'. Accumulator operations ('EWith' / 'EAccum') are implemented with
-- 'IORef's inside the 'AcM' monad.
module Interpreter (
  interpret,
  interpret',
  Value,
) where

import Control.Monad (foldM, join)
import Data.Int (Int64)
import Data.IORef
import System.IO.Unsafe (unsafePerformIO)

import Array
import AST
import CHAD.Types
import Data
import Interpreter.Rep
import Data.Bifunctor (first)


-- | Monad in which interpretation runs: a newtype around 'IO' with a phantom
-- parameter @s@.
newtype AcM s a = AcM { unAcM :: IO a }
  deriving newtype (Functor, Applicative, Monad)

-- | Run an 'AcM' computation to a pure value using 'unsafePerformIO'.
--
-- NOTE(review): the rank-2 @s@ presumably plays the same role as in @runST@
-- (keeping accumulator references from escaping) -- confirm; nothing in this
-- module enforces it beyond the type.
runAcM :: (forall s. AcM s a) -> a
runAcM (AcM m) = unsafePerformIO m

-- | Interpret a closed expression (empty environment).
interpret :: Ex '[] t -> Rep t
interpret e = runAcM (interpret' SNil e)

-- | Environment entry: the host representation of a value of type @t@.
newtype Value t = Value (Rep t)

-- | Interpret an expression in the given environment.
--
-- Binders push their value on the front of the environment with 'SCons';
-- 'EVar' looks values up with 'slistIdx'.
interpret' :: forall env t s. SList Value env -> Ex env t -> AcM s (Rep t)
interpret' env = \case
  EVar _ _ i -> case slistIdx env i of Value x -> return x
  ELet _ a b -> do
    x <- interpret' env a
    interpret' (Value x `SCons` env) b
  EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b
  EFst _ e -> fst <$> interpret' env e
  ESnd _ e -> snd <$> interpret' env e
  ENil _ -> return ()
  EInl _ _ e -> Left <$> interpret' env e
  EInr _ _ e -> Right <$> interpret' env e
  ECase _ e a b -> interpret' env e >>= \case
    Left x -> interpret' (Value x `SCons` env) a
    Right y -> interpret' (Value y `SCons` env) b
  ENothing _ _ -> return Nothing
  EJust _ e -> Just <$> interpret' env e
  -- 'a' is the Nothing branch; 'b' is the Just branch with the payload bound.
  EMaybe _ a b e -> maybe (interpret' env a) (\x -> interpret' (Value x `SCons` env) b) =<< interpret' env e
  EConstArr _ _ _ v -> return v
  EBuild1 _ a b -> do
    -- The length is an Int64 in the interpreted program; the host array
    -- functions take Int, and the element index is converted back to Int64.
    n <- fromIntegral @Int64 @Int <$> interpret' env a
    arrayGenerateLinM (ShNil `ShCons` n)
                      (\i -> interpret' (Value (fromIntegral @Int @Int64 i) `SCons` env) b)
  EBuild _ dim a b -> do
    sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a
    arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx ixUncons dim idx) `SCons` env) b)
  EFold1Inner _ a b -> do
    -- The combining function binds the running result 'x' and the next
    -- element 'y', with 'y' pushed last (front of the environment).
    let f = \x y -> interpret' (Value y `SCons` Value x `SCons` env) a
    arr <- interpret' env b
    let sh `ShCons` n = arrayShape arr
    -- With n == 0 the element list is empty and 'foldl1M' calls 'error'.
    arrayGenerateM sh $ \idx -> foldl1M f [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
  ESum1Inner _ e -> do
    arr <- interpret' env e
    let STArr _ (STScal t) = typeOf e
        sh `ShCons` n = arrayShape arr
    numericIsNum t $ arrayGenerateM sh $ \idx -> return $ sum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
  EUnit _ e -> arrayGenerateLinM ShNil (\_ -> interpret' env e)
  EReplicate1Inner _ a b -> do
    n <- fromIntegral @Int64 @Int <$> interpret' env a
    arr <- interpret' env b
    let sh = arrayShape arr
    -- The new last index component is ignored: every position along it
    -- reads the same source element.
    arrayGenerateM (sh `ShCons` n) (\(idx `IxCons` _) -> return (arrayIndex arr idx))
  EConst _ _ v -> return v
  EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e
  EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b)
  EIdx _ n a b -> arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b)
  EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e
  EOp _ op e -> interpretOp op <$> interpret' env e
  EWith e1 e2 -> do
    initval <- interpret' env e1
    withAccum (typeOf e1) (typeOf e2) initval $ \accum ->
      interpret' (Value accum `SCons` env) e2
  EAccum i e1 e2 e3 -> do
    let STAccum t = typeOf e3
    idx <- interpret' env e1
    val <- interpret' env e2
    accum <- interpret' env e3
    accumAddSparse t i accum idx val
  EZero t -> do
    return $ zeroD2 t
  EPlus t a b -> do
    a' <- interpret' env a
    b' <- interpret' env b
    return $ addD2s t a' b'
  EError _ s -> error $ "Interpreter: Program threw error: " ++ s

-- | Apply a primitive operation to its (already evaluated) argument.
-- Binary operations receive their operands as a pair. 'OIf' turns a boolean
-- into @Left ()@ (true) or @Right ()@ (false).
interpretOp :: SOp a t -> Rep a -> Rep t
interpretOp op arg = case op of
  OAdd st -> numericIsNum st $ uncurry (+) arg
  OMul st -> numericIsNum st $ uncurry (*) arg
  ONeg st -> numericIsNum st $ negate arg
  OLt st -> numericIsNum st $ uncurry (<) arg
  OLe st -> numericIsNum st $ uncurry (<=) arg
  OEq st -> numericIsNum st $ uncurry (==) arg
  ONot -> not arg
  OIf -> if arg then Left () else Right ()

-- | Zero value of @Rep (D2 t)@: @Left ()@ for pairs and eithers, 'Nothing'
-- for maybes, an empty array for arrays, @0.0@ for floating scalars and @()@
-- for the remaining scalars. Calls 'error' for accumulator types.
zeroD2 :: STy t -> Rep (D2 t)
zeroD2 typ = case typ of
  STNil -> ()
  STPair _ _ -> Left ()
  STEither _ _ -> Left ()
  STMaybe _ -> Nothing
  STArr n _ -> emptyArray n
  STScal sty -> case sty of
    STI32 -> ()
    STI64 -> ()
    STF32 -> 0.0
    STF64 -> 0.0
    STBool -> ()
  STAccum{} -> error "Zero of Accum"

-- | Add two values of @Rep (D2 t)@.
--
-- The zero encodings produced by 'zeroD2' (@Left ()@, 'Nothing', an array of
-- size 0) act as identities: the other operand is returned unchanged.
-- Calls 'error' on eithers with mismatched sides, on arrays with different
-- non-empty shapes, and on accumulator types.
addD2s :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t)
addD2s typ a b = case typ of
  STNil -> ()
  STPair t1 t2 -> case (a, b) of
    (Left (), _) -> b
    (_, Left ()) -> a
    (Right (x1, x2), Right (y1, y2)) -> Right (addD2s t1 x1 y1, addD2s t2 x2 y2)
  STEither t1 t2 -> case (a, b) of
    (Left (), _) -> b
    (_, Left ()) -> a
    (Right (Left x), Right (Left y)) -> Right (Left (addD2s t1 x y))
    (Right (Right x), Right (Right y)) -> Right (Right (addD2s t2 x y))
    _ -> error "Plus of inconsistent Eithers"
  STMaybe t -> case (a, b) of
    (Nothing, _) -> b
    (_, Nothing) -> a
    (Just x, Just y) -> Just (addD2s t x y)
  STArr _ t ->
    let sh1 = arrayShape a
        sh2 = arrayShape b
    in if | shapeSize sh1 == 0 -> b
          | shapeSize sh2 == 0 -> a
          | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> addD2s t (arrayIndexLinear a i) (arrayIndexLinear b i))
          | otherwise -> error "Plus of inconsistently shaped arrays"
  STScal sty -> case sty of
    STI32 -> ()
    STI64 -> ()
    STF32 -> a + b
    STF64 -> a + b
    STBool -> ()
  STAccum{} -> error "Plus of Accum"

-- | Allocate a sparse accumulator initialised from @initval@, run @f@ on it,
-- then read the accumulator back. Returns the result of @f@ paired with the
-- final accumulator contents. The second type argument is ignored.
withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t)
withAccum t _ initval f = AcM $ do
  accum <- newAcSparse t initval
  out <- case f accum of AcM m -> m
  val <- readAcSparse t accum
  return (out, val)

-- | Allocate the sparse accumulator representation for a value. Pairs,
-- maybes, arrays and scalars are placed in a fresh 'IORef'; accumulator
-- types and bare eithers call 'error'.
--
-- NOTE(review): the 'STPair' and 'STMaybe' arms were reconstructed from a
-- garbled line so that they produce what 'readAcSparse' and 'accumAddSparse'
-- read back (an 'IORef' around the dense representation) -- confirm against
-- the definition of 'RepAcSparse' in "Interpreter.Rep".
newAcSparse :: STy t -> Rep t -> IO (RepAcSparse t)
newAcSparse typ val = case typ of
  STNil -> return ()
  STPair{} -> newIORef =<< newAcDense typ val
  STMaybe t -> newIORef =<< traverse (newAcDense t) val
  STArr _ t -> newIORef =<< traverse (newAcSparse t) val
  STScal{} -> newIORef val
  STAccum{} -> error "Nested accumulators"
  STEither{} -> error "Bare Either in accumulator"

-- | Allocate the dense accumulator representation for a value: the outer
-- constructor stays a plain value and each component is allocated with
-- 'newAcSparse'. Scalars are returned as-is. Accumulator types call 'error'.
newAcDense :: STy t -> Rep t -> IO (RepAcDense t)
newAcDense typ val = case typ of
  STNil -> return ()
  STPair t1 t2 -> (,) <$> newAcSparse t1 (fst val) <*> newAcSparse t2 (snd val)
  STEither t1 t2 -> case val of
    Left x -> Left <$> newAcSparse t1 x
    Right y -> Right <$> newAcSparse t2 y
  STMaybe t -> traverse (newAcSparse t) val
  STArr _ t -> traverse (newAcSparse t) val
  STScal{} -> return val
  STAccum{} -> error "Nested accumulators"

-- | Read the current contents of a sparse accumulator back into a plain
-- value. Accumulator types and bare eithers call 'error'.
readAcSparse :: STy t -> RepAcSparse t -> IO (Rep t)
readAcSparse typ val = case typ of
  STNil -> return ()
  STPair t1 t2 -> do
    (a, b) <- readIORef val
    (,) <$> readAcSparse t1 a <*> readAcSparse t2 b
  STMaybe t -> traverse (readAcDense t) =<< readIORef val
  STArr _ t -> traverse (readAcSparse t) =<< readIORef val
  STScal{} -> readIORef val
  STAccum{} -> error "Nested accumulators"
  STEither{} -> error "Bare Either in accumulator"

-- | Read the current contents of a dense accumulator back into a plain
-- value, reading each component with 'readAcSparse'. Accumulator types call
-- 'error'.
readAcDense :: STy t -> RepAcDense t -> IO (Rep t)
readAcDense typ val = case typ of
  STNil -> return ()
  STPair t1 t2 -> (,) <$> readAcSparse t1 (fst val) <*> readAcSparse t2 (snd val)
  STEither t1 t2 -> case val of
    Left x -> Left <$> readAcSparse t1 x
    Right y -> Right <$> readAcSparse t2 y
  STMaybe t -> traverse (readAcSparse t) val
  STArr _ t -> traverse (readAcSparse t) val
  STScal{} -> return val
  STAccum{} -> error "Nested accumulators"

-- | Add @val@ into a sparse accumulator at the position described by the
-- index depth and @idx@.
--
-- Work in progress: the typed holes (@_@) mark cases that are not
-- implemented yet; only the depth-zero 'STNil', 'STPair' and (partially)
-- 'STMaybe' cases have code.
accumAddSparse :: STy t -> SNat i -> RepAcSparse t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s ()
accumAddSparse typ SZ ref () val = case typ of
  STNil -> return ()
  STPair t1 t2 -> AcM $ do
    (r1, r2) <- readIORef ref
    unAcM $ accumAddSparse t1 SZ r1 () (fst val)
    unAcM $ accumAddSparse t2 SZ r2 () (snd val)
  STMaybe t ->
    -- The modification function returns the new reference contents together
    -- with a follow-up action, which 'join' then runs.
    join $ AcM $ atomicModifyIORef' ref $ \ac -> case (ac, val) of
      (Nothing, _) -> (ac, _)
      (Just{}, Nothing) -> (ac, return ())
      (Just ac', Just val') -> first Just (accumAddDense t SZ ac' () val')
  STArr _ t -> _ ref val
  STScal{} -> _ ref val
  STAccum{} -> error "Nested accumulators"
  STEither{} -> error "Bare Either in accumulator"
accumAddSparse typ (SS dep) ref idx val = case typ of
  STNil -> return ()
  STPair t1 t2 -> _ ref idx val
  STMaybe t -> _ ref idx val
  STArr _ t -> _ ref idx val
  STScal{} -> _ ref idx val
  STAccum{} -> error "Nested accumulators"
  STEither{} -> error "Bare Either in accumulator"

-- | Dense counterpart of 'accumAddSparse': returns the updated dense
-- accumulator together with a follow-up action. Not implemented yet (typed
-- hole).
accumAddDense :: forall t i s. STy t -> SNat i -> RepAcDense t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAcDense t, AcM s ())
accumAddDense = _

-- Commented-out draft code follows. It refers to RepAc and makeRepAc, which
-- the live code above does not use.

-- accumAddVal :: forall t i s. STy t -> SNat i -> RepAc t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAc t, AcM s ())
-- accumAddVal typ SZ ac () val = case typ of
--   STNil -> ((), return ())
--   STPair t1 t2 ->
--     let (ac1', m1) = accumAddVal t1 SZ (fst ac) () (fst val)
--         (ac2', m2) = accumAddVal t2 SZ (snd ac) () (snd val)
--     in ((ac1', ac2'), m1 >> m2)
--   STMaybe t -> case t of
--     STEither t1 t2 -> (ac, accumAddValME t1 t2 ac val)
--     STNil -> def ; STPair{} -> def ; STMaybe{} -> def ; STArr{} -> def ; STScal{} -> def ; STAccum{} -> def
--     where def :: (t ~ TMaybe a, RepAc (TMaybe a) ~ IORef (Maybe (RepAc a))) => (RepAc t, AcM s ())
--           def = (ac, accumAddValM t ac val)
--   STArr n t
--     | shapeSize (arrayShape ac) == 0 -> makeRepAc (STArr n t) val
--   STEither{} -> error "Bare Either in accumulator"
--   _ -> _
-- accumAddVal typ (SS dep) ac idx val = case typ of
--   STNil -> ((), return ())
--   STPair t1 t2 ->
--     case (idx, val) of
--       (Left idx', Left val') -> first (,snd ac) $ accumAddVal t1 dep (fst ac) idx' val'
--       (Right idx', Right val') -> first (fst ac,) $ accumAddVal t2 dep (snd ac) idx' val'
--       _ -> error "Inconsistent idx and val in accumulator add operation"
--   _ -> _

-- accumAddValME :: STy a -> STy b
--               -> IORef (Maybe (Either (RepAc a) (RepAc b)))
--               -> Maybe (Either (Rep a) (Rep b))
--               -> AcM s ()
-- accumAddValME t1 t2 ac val =
--   case val of
--     Nothing -> return ()
--     Just val' ->
--       join $ AcM $ atomicModifyIORef' ac $ \ac' -> case (ac', val') of
--         (Nothing, _) -> (Nothing, AcM $ initAccumOrTryAgainME t1 t2 ac val' (unAcM $ accumAddValME t1 t2 ac val))
--         (Just (Left x), Left val'1) -> first (Just . Left) $ accumAddVal t1 SZ x () val'1
--         (Just (Right y), Right val'2) -> first (Just . Right) $ accumAddVal t2 SZ y () val'2
--         _ -> error "Inconsistent accumulator and value in add operation on Maybe Either"

-- initAccumOrTryAgainME :: STy a -> STy b
--                       -> IORef (Maybe (Either (RepAc a) (RepAc b)))
--                       -> Either (Rep a) (Rep b)
--                       -> IO ()
--                       -> IO ()
-- initAccumOrTryAgainME t1 t2 ac val onRace = do
--   newContents <- case val of Left x -> Left <$> makeRepAc t1 x
--                              Right y -> Right <$> makeRepAc t2 y
--   join $ atomicModifyIORef' ac (\case Nothing -> (Just newContents, return ())
--                                       value@Just{} -> (value, onRace))

-- accumAddValM :: STy t
--              -> IORef (Maybe (RepAc t))
--              -> Maybe (Rep t)
--              -> AcM s ()
-- accumAddValM t ac val =
--   case val of
--     Nothing -> return ()
--     Just val' ->
--       join $ AcM $ atomicModifyIORef' ac $ \ac' -> case ac' of
--         Nothing -> (Nothing, AcM $ initAccumOrTryAgainM t ac val' (unAcM $ accumAddValM t ac val))
--         Just x -> first Just $ accumAddVal t SZ x () val'

-- initAccumOrTryAgainM :: STy t
--                      -> IORef (Maybe (RepAc t))
--                      -> Rep t
--                      -> IO ()
--                      -> IO ()
-- initAccumOrTryAgainM t ac val onRace = do
--   newContents <- makeRepAc t val
--   join $ atomicModifyIORef' ac (\case Nothing -> (Just newContents, return ())
--                                       value@Just{} -> (value, onRace))


-- | Bring the 'Num' and 'Ord' instances of a numeric scalar type into scope
-- for the continuation, by matching on the scalar type singleton.
numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r
numericIsNum STI32 = id
numericIsNum STI64 = id
numericIsNum STF32 = id
numericIsNum STF64 = id

-- | Convert the nested-tuple representation of an @n@-component index into
-- an @f n@, starting from @nil@ and adding one component per level with
-- @cons@. Components are converted from 'Int64' to 'Int'.
unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m))
            -> SNat n -> Rep (Tup (Replicate n TIx)) -> f n
unTupRepIdx nil _ SZ _ = nil
unTupRepIdx nil cons (SS n) (idx, i) = unTupRepIdx nil cons n idx `cons` fromIntegral @Int64 @Int i

-- | Inverse direction of 'unTupRepIdx': take an @f n@ apart with @uncons@
-- and build the nested-tuple representation, converting components from
-- 'Int' to 'Int64'.
tupRepIdx :: (forall m. f (S m) -> (f m, Int))
          -> SNat n -> f n -> Rep (Tup (Replicate n TIx))
tupRepIdx _ SZ _ = ()
tupRepIdx uncons (SS n) tup =
  let (tup', i) = uncons tup
  in (tupRepIdx uncons n tup', fromIntegral @Int @Int64 i)

-- | Split a non-empty index into its prefix and last component.
ixUncons :: Index (S n) -> (Index n, Int)
ixUncons (IxCons idx i) = (idx, i)

-- | Split a non-empty shape into its prefix and last component.
shUncons :: Shape (S n) -> (Shape n, Int)
shUncons (ShCons idx i) = (idx, i)

-- | Monadic left fold over a non-empty list, seeded with the first element.
-- Calls 'error' on an empty list.
foldl1M :: Monad m => (a -> a -> m a) -> [a] -> m a
foldl1M _ [] = error "foldl1M: empty list"
foldl1M f (tophead : toptail) = foldM f tophead toptail