diff options
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r-- | src/Interpreter.hs | 198 |
1 files changed, 179 insertions, 19 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 8728ec0..f58cefb 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -9,15 +9,16 @@ {-# LANGUAGE TypeOperators #-} {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE TupleSections #-} module Interpreter ( interpret, interpret', Value, ) where -import Control.Monad (foldM) +import Control.Monad (foldM, join) import Data.Int (Int64) -import Data.Proxy +import Data.IORef import System.IO.Unsafe (unsafePerformIO) import Array @@ -25,9 +26,10 @@ import AST import CHAD.Types import Data import Interpreter.Rep +import Data.Bifunctor (first) -newtype AcM s a = AcM (IO a) +newtype AcM s a = AcM { unAcM :: IO a } deriving newtype (Functor, Applicative, Monad) runAcM :: (forall s. AcM s a) -> a @@ -53,9 +55,9 @@ interpret' env = \case ECase _ e a b -> interpret' env e >>= \case Left x -> interpret' (Value x `SCons` env) a Right y -> interpret' (Value y `SCons` env) b - ENothing _ _ -> _ - EJust _ _ -> _ - EMaybe _ _ _ _ -> _ + ENothing _ _ -> return Nothing + EJust _ e -> Just <$> interpret' env e + EMaybe _ a b e -> maybe (interpret' env a) (\x -> interpret' (Value x `SCons` env) b) =<< interpret' env e EConstArr _ _ _ v -> return v EBuild1 _ a b -> do n <- fromIntegral @Int64 @Int <$> interpret' env a @@ -88,19 +90,20 @@ interpret' env = \case EOp _ op e -> interpretOp op <$> interpret' env e EWith e1 e2 -> do initval <- interpret' env e1 - withAccum (typeOf e1) initval $ \accum -> + withAccum (typeOf e1) (typeOf e2) initval $ \accum -> interpret' (Value accum `SCons` env) e2 EAccum i e1 e2 e3 -> do + let STAccum t = typeOf e3 idx <- interpret' env e1 val <- interpret' env e2 accum <- interpret' env e3 - accumAdd accum i idx val + accumAddSparse t i accum idx val EZero t -> do - return $ makeZero t + return $ zeroD2 t EPlus t a b -> do a' <- interpret' env a b' <- interpret' env b - return $ makePlus t a' b' + return $ addD2s t a' b' EError _ s -> error $ "Interpreter: Program threw error: " ++ s interpretOp :: SOp a t -> Rep a -> Rep t @@ -114,8 +117,8 @@ interpretOp op arg = case op of ONot -> not arg OIf -> if arg then Left () else Right () -makeZero :: STy t -> Rep (D2 t) -makeZero typ = case typ of +zeroD2 :: STy t -> Rep (D2 t) +zeroD2 typ = case typ of STNil -> () STPair _ _ -> Left () STEither _ _ -> Left () @@ -129,25 +132,29 @@ makeZero typ = case typ of STBool -> () STAccum{} -> error "Zero of Accum" -makePlus :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t) -makePlus typ a b = case typ of +addD2s :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t) +addD2s typ a b = case typ of STNil -> () STPair t1 t2 -> case (a, b) of (Left (), _) -> b (_, Left ()) -> a - (Right (x1, x2), Right (y1, y2)) -> Right (makePlus t1 x1 y1, makePlus t2 x2 y2) + (Right (x1, x2), Right (y1, y2)) -> Right (addD2s t1 x1 y1, addD2s t2 x2 y2) STEither t1 t2 -> case (a, b) of (Left (), _) -> b (_, Left ()) -> a - (Right (Left x), Right (Left y)) -> Right (Left (makePlus t1 x y)) - (Right (Right x), Right (Right y)) -> Right (Right (makePlus t2 x y)) + (Right (Left x), Right (Left y)) -> Right (Left (addD2s t1 x y)) + (Right (Right x), Right (Right y)) -> Right (Right (addD2s t2 x y)) _ -> error "Plus of inconsistent Eithers" + STMaybe t -> case (a, b) of + (Nothing, _) -> b + (_, Nothing) -> a + (Just x, Just y) -> Just (addD2s t x y) STArr _ t -> let sh1 = arrayShape a sh2 = arrayShape b in if | shapeSize sh1 == 0 -> b | shapeSize sh2 == 0 -> a - | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> makePlus t (arrayIndexLinear a i) (arrayIndexLinear b i)) + | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> addD2s t (arrayIndexLinear a i) (arrayIndexLinear b i)) | otherwise -> error "Plus of inconsistently shaped arrays" STScal sty -> case sty of STI32 -> () @@ -157,6 +164,159 @@ makePlus typ a b = case typ of STBool -> () STAccum{} -> error "Plus of Accum" +withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t) +withAccum t _ initval f = AcM $ do + accum <- newAcSparse t initval + out <- case f accum of AcM m -> m + val <- readAcSparse t accum + return (out, val) + +newAcSparse :: STy t -> Rep t -> IO (RepAcSparse t) +newAcSparse typ val = case typ of + STNil -> return () + STPair{} -> newIORef =<<newAcDense typ val + STMaybe t -> newIORef =<< traverse (newAcDense t) val + STArr _ t -> newIORef =<< traverse (newAcSparse t) val + STScal{} -> newIORef val + STAccum{} -> error "Nested accumulators" + STEither{} -> error "Bare Either in accumulator" + +newAcDense :: STy t -> Rep t -> IO (RepAcDense t) +newAcDense typ val = case typ of + STNil -> return () + STPair t1 t2 -> (,) <$> newAcSparse t1 (fst val) <*> newAcSparse t2 (snd val) + STEither t1 t2 -> case val of + Left x -> Left <$> newAcSparse t1 x + Right y -> Right <$> newAcSparse t2 y + STMaybe t -> traverse (newAcSparse t) val + STArr _ t -> traverse (newAcSparse t) val + STScal{} -> return val + STAccum{} -> error "Nested accumulators" + +readAcSparse :: STy t -> RepAcSparse t -> IO (Rep t) +readAcSparse typ val = case typ of + STNil -> return () + STPair t1 t2 -> do + (a, b) <- readIORef val + (,) <$> readAcSparse t1 a <*> readAcSparse t2 b + STMaybe t -> traverse (readAcDense t) =<< readIORef val + STArr _ t -> traverse (readAcSparse t) =<< readIORef val + STScal{} -> readIORef val + STAccum{} -> error "Nested accumulators" + STEither{} -> error "Bare Either in accumulator" + +readAcDense :: STy t -> RepAcDense t -> IO (Rep t) +readAcDense typ val = case typ of + STNil -> return () + STPair t1 t2 -> (,) <$> readAcSparse t1 (fst val) <*> readAcSparse t2 (snd val) + STEither t1 t2 -> case val of + Left x -> Left <$> readAcSparse t1 x + Right y -> Right <$> readAcSparse t2 y + STMaybe t -> traverse (readAcSparse t) val + STArr _ t -> traverse (readAcSparse t) val + STScal{} -> return val + STAccum{} -> error "Nested accumulators" + +accumAddSparse :: STy t -> SNat i -> RepAcSparse t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s () +accumAddSparse typ SZ ref () val = case typ of + STNil -> return () + STPair t1 t2 -> AcM $ do + (r1, r2) <- readIORef ref + unAcM $ accumAddSparse t1 SZ r1 () (fst val) + unAcM $ accumAddSparse t2 SZ r2 () (snd val) + STMaybe t -> + join $ AcM $ atomicModifyIORef' ref $ \ac -> case (ac, val) of + (Nothing, _) -> (ac, _) + (Just{}, Nothing) -> (ac, return ()) + (Just ac', Just val') -> first Just (accumAddDense t SZ ac' () val') + STArr _ t -> _ ref val + STScal{} -> _ ref val + STAccum{} -> error "Nested accumulators" + STEither{} -> error "Bare Either in accumulator" +accumAddSparse typ (SS dep) ref idx val = case typ of + STNil -> return () + STPair t1 t2 -> _ ref idx val + STMaybe t -> _ ref idx val + STArr _ t -> _ ref idx val + STScal{} -> _ ref idx val + STAccum{} -> error "Nested accumulators" + STEither{} -> error "Bare Either in accumulator" + +accumAddDense :: forall t i s. STy t -> SNat i -> RepAcDense t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAcDense t, AcM s ()) +accumAddDense = _ + +-- accumAddVal :: forall t i s. STy t -> SNat i -> RepAc t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAc t, AcM s ()) +-- accumAddVal typ SZ ac () val = case typ of +-- STNil -> ((), return ()) +-- STPair t1 t2 -> +-- let (ac1', m1) = accumAddVal t1 SZ (fst ac) () (fst val) +-- (ac2', m2) = accumAddVal t2 SZ (snd ac) () (snd val) +-- in ((ac1', ac2'), m1 >> m2) +-- STMaybe t -> case t of +-- STEither t1 t2 -> (ac, accumAddValME t1 t2 ac val) +-- STNil -> def ; STPair{} -> def ; STMaybe{} -> def ; STArr{} -> def ; STScal{} -> def ; STAccum{} -> def +-- where def :: (t ~ TMaybe a, RepAc (TMaybe a) ~ IORef (Maybe (RepAc a))) => (RepAc t, AcM s ()) +-- def = (ac, accumAddValM t ac val) +-- STArr n t +-- | shapeSize (arrayShape ac) == 0 -> makeRepAc (STArr n t) val +-- STEither{} -> error "Bare Either in accumulator" +-- _ -> _ +-- accumAddVal typ (SS dep) ac idx val = case typ of +-- STNil -> ((), return ()) +-- STPair t1 t2 -> +-- case (idx, val) of +-- (Left idx', Left val') -> first (,snd ac) $ accumAddVal t1 dep (fst ac) idx' val' +-- (Right idx', Right val') -> first (fst ac,) $ accumAddVal t2 dep (snd ac) idx' val' +-- _ -> error "Inconsistent idx and val in accumulator add operation" +-- _ -> _ + +-- accumAddValME :: STy a -> STy b +-- -> IORef (Maybe (Either (RepAc a) (RepAc b))) +-- -> Maybe (Either (Rep a) (Rep b)) +-- -> AcM s () +-- accumAddValME t1 t2 ac val = +-- case val of +-- Nothing -> return () +-- Just val' -> +-- join $ AcM $ atomicModifyIORef' ac $ \ac' -> case (ac', val') of +-- (Nothing, _) -> (Nothing, AcM $ initAccumOrTryAgainME t1 t2 ac val' (unAcM $ accumAddValME t1 t2 ac val)) +-- (Just (Left x), Left val'1) -> first (Just . Left) $ accumAddVal t1 SZ x () val'1 +-- (Just (Right y), Right val'2) -> first (Just . Right) $ accumAddVal t2 SZ y () val'2 +-- _ -> error "Inconsistent accumulator and value in add operation on Maybe Either" + +-- initAccumOrTryAgainME :: STy a -> STy b +-- -> IORef (Maybe (Either (RepAc a) (RepAc b))) +-- -> Either (Rep a) (Rep b) +-- -> IO () +-- -> IO () +-- initAccumOrTryAgainME t1 t2 ac val onRace = do +-- newContents <- case val of Left x -> Left <$> makeRepAc t1 x +-- Right y -> Right <$> makeRepAc t2 y +-- join $ atomicModifyIORef' ac (\case Nothing -> (Just newContents, return ()) +-- value@Just{} -> (value, onRace)) + +-- accumAddValM :: STy t +-- -> IORef (Maybe (RepAc t)) +-- -> Maybe (Rep t) +-- -> AcM s () +-- accumAddValM t ac val = +-- case val of +-- Nothing -> return () +-- Just val' -> +-- join $ AcM $ atomicModifyIORef' ac $ \ac' -> case ac' of +-- Nothing -> (Nothing, AcM $ initAccumOrTryAgainM t ac val' (unAcM $ accumAddValM t ac val)) +-- Just x -> first Just $ accumAddVal t SZ x () val' + +-- initAccumOrTryAgainM :: STy t +-- -> IORef (Maybe (RepAc t)) +-- -> Rep t +-- -> IO () +-- -> IO () +-- initAccumOrTryAgainM t ac val onRace = do +-- newContents <- makeRepAc t val +-- join $ atomicModifyIORef' ac (\case Nothing -> (Just newContents, return ()) +-- value@Just{} -> (value, onRace)) + numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r numericIsNum STI32 = id numericIsNum STI64 = id @@ -166,7 +326,7 @@ numericIsNum STF64 = id unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m)) -> SNat n -> Rep (Tup (Replicate n TIx)) -> f n unTupRepIdx nil _ SZ _ = nil -unTupRepIdx nil cons (SS n) (idx, i) = unTupRepIdx p nil cons n idx `cons` fromIntegral @Int64 @Int i +unTupRepIdx nil cons (SS n) (idx, i) = unTupRepIdx nil cons n idx `cons` fromIntegral @Int64 @Int i tupRepIdx :: (forall m. f (S m) -> (f m, Int)) -> SNat n -> f n -> Rep (Tup (Replicate n TIx)) |