summaryrefslogtreecommitdiff
path: root/src/Interpreter.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r--src/Interpreter.hs198
1 files changed, 179 insertions, 19 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 8728ec0..f58cefb 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -9,15 +9,16 @@
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE TupleSections #-}
module Interpreter (
interpret,
interpret',
Value,
) where
-import Control.Monad (foldM)
+import Control.Monad (foldM, join)
import Data.Int (Int64)
-import Data.Proxy
+import Data.IORef
import System.IO.Unsafe (unsafePerformIO)
import Array
@@ -25,9 +26,10 @@ import AST
import CHAD.Types
import Data
import Interpreter.Rep
+import Data.Bifunctor (first)
-newtype AcM s a = AcM (IO a)
+newtype AcM s a = AcM { unAcM :: IO a }
deriving newtype (Functor, Applicative, Monad)
runAcM :: (forall s. AcM s a) -> a
@@ -53,9 +55,9 @@ interpret' env = \case
ECase _ e a b -> interpret' env e >>= \case
Left x -> interpret' (Value x `SCons` env) a
Right y -> interpret' (Value y `SCons` env) b
- ENothing _ _ -> _
- EJust _ _ -> _
- EMaybe _ _ _ _ -> _
+ ENothing _ _ -> return Nothing
+ EJust _ e -> Just <$> interpret' env e
+ EMaybe _ a b e -> maybe (interpret' env a) (\x -> interpret' (Value x `SCons` env) b) =<< interpret' env e
EConstArr _ _ _ v -> return v
EBuild1 _ a b -> do
n <- fromIntegral @Int64 @Int <$> interpret' env a
@@ -88,19 +90,20 @@ interpret' env = \case
EOp _ op e -> interpretOp op <$> interpret' env e
EWith e1 e2 -> do
initval <- interpret' env e1
- withAccum (typeOf e1) initval $ \accum ->
+ withAccum (typeOf e1) (typeOf e2) initval $ \accum ->
interpret' (Value accum `SCons` env) e2
EAccum i e1 e2 e3 -> do
+ let STAccum t = typeOf e3
idx <- interpret' env e1
val <- interpret' env e2
accum <- interpret' env e3
- accumAdd accum i idx val
+ accumAddSparse t i accum idx val
EZero t -> do
- return $ makeZero t
+ return $ zeroD2 t
EPlus t a b -> do
a' <- interpret' env a
b' <- interpret' env b
- return $ makePlus t a' b'
+ return $ addD2s t a' b'
EError _ s -> error $ "Interpreter: Program threw error: " ++ s
interpretOp :: SOp a t -> Rep a -> Rep t
@@ -114,8 +117,8 @@ interpretOp op arg = case op of
ONot -> not arg
OIf -> if arg then Left () else Right ()
-makeZero :: STy t -> Rep (D2 t)
-makeZero typ = case typ of
+zeroD2 :: STy t -> Rep (D2 t)
+zeroD2 typ = case typ of
STNil -> ()
STPair _ _ -> Left ()
STEither _ _ -> Left ()
@@ -129,25 +132,29 @@ makeZero typ = case typ of
STBool -> ()
STAccum{} -> error "Zero of Accum"
-makePlus :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t)
-makePlus typ a b = case typ of
+addD2s :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t)
+addD2s typ a b = case typ of
STNil -> ()
STPair t1 t2 -> case (a, b) of
(Left (), _) -> b
(_, Left ()) -> a
- (Right (x1, x2), Right (y1, y2)) -> Right (makePlus t1 x1 y1, makePlus t2 x2 y2)
+ (Right (x1, x2), Right (y1, y2)) -> Right (addD2s t1 x1 y1, addD2s t2 x2 y2)
STEither t1 t2 -> case (a, b) of
(Left (), _) -> b
(_, Left ()) -> a
- (Right (Left x), Right (Left y)) -> Right (Left (makePlus t1 x y))
- (Right (Right x), Right (Right y)) -> Right (Right (makePlus t2 x y))
+ (Right (Left x), Right (Left y)) -> Right (Left (addD2s t1 x y))
+ (Right (Right x), Right (Right y)) -> Right (Right (addD2s t2 x y))
_ -> error "Plus of inconsistent Eithers"
+ STMaybe t -> case (a, b) of
+ (Nothing, _) -> b
+ (_, Nothing) -> a
+ (Just x, Just y) -> Just (addD2s t x y)
STArr _ t ->
let sh1 = arrayShape a
sh2 = arrayShape b
in if | shapeSize sh1 == 0 -> b
| shapeSize sh2 == 0 -> a
- | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> makePlus t (arrayIndexLinear a i) (arrayIndexLinear b i))
+ | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> addD2s t (arrayIndexLinear a i) (arrayIndexLinear b i))
| otherwise -> error "Plus of inconsistently shaped arrays"
STScal sty -> case sty of
STI32 -> ()
@@ -157,6 +164,159 @@ makePlus typ a b = case typ of
STBool -> ()
STAccum{} -> error "Plus of Accum"
+withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t)
+withAccum t _ initval f = AcM $ do
+ accum <- newAcSparse t initval
+ out <- case f accum of AcM m -> m
+ val <- readAcSparse t accum
+ return (out, val)
+
+newAcSparse :: STy t -> Rep t -> IO (RepAcSparse t)
+newAcSparse typ val = case typ of
+ STNil -> return ()
+ STPair{} -> newIORef =<<newAcDense typ val
+ STMaybe t -> newIORef =<< traverse (newAcDense t) val
+ STArr _ t -> newIORef =<< traverse (newAcSparse t) val
+ STScal{} -> newIORef val
+ STAccum{} -> error "Nested accumulators"
+ STEither{} -> error "Bare Either in accumulator"
+
+newAcDense :: STy t -> Rep t -> IO (RepAcDense t)
+newAcDense typ val = case typ of
+ STNil -> return ()
+ STPair t1 t2 -> (,) <$> newAcSparse t1 (fst val) <*> newAcSparse t2 (snd val)
+ STEither t1 t2 -> case val of
+ Left x -> Left <$> newAcSparse t1 x
+ Right y -> Right <$> newAcSparse t2 y
+ STMaybe t -> traverse (newAcSparse t) val
+ STArr _ t -> traverse (newAcSparse t) val
+ STScal{} -> return val
+ STAccum{} -> error "Nested accumulators"
+
+readAcSparse :: STy t -> RepAcSparse t -> IO (Rep t)
+readAcSparse typ val = case typ of
+ STNil -> return ()
+ STPair t1 t2 -> do
+ (a, b) <- readIORef val
+ (,) <$> readAcSparse t1 a <*> readAcSparse t2 b
+ STMaybe t -> traverse (readAcDense t) =<< readIORef val
+ STArr _ t -> traverse (readAcSparse t) =<< readIORef val
+ STScal{} -> readIORef val
+ STAccum{} -> error "Nested accumulators"
+ STEither{} -> error "Bare Either in accumulator"
+
+readAcDense :: STy t -> RepAcDense t -> IO (Rep t)
+readAcDense typ val = case typ of
+ STNil -> return ()
+ STPair t1 t2 -> (,) <$> readAcSparse t1 (fst val) <*> readAcSparse t2 (snd val)
+ STEither t1 t2 -> case val of
+ Left x -> Left <$> readAcSparse t1 x
+ Right y -> Right <$> readAcSparse t2 y
+ STMaybe t -> traverse (readAcSparse t) val
+ STArr _ t -> traverse (readAcSparse t) val
+ STScal{} -> return val
+ STAccum{} -> error "Nested accumulators"
+
+accumAddSparse :: STy t -> SNat i -> RepAcSparse t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s ()
+accumAddSparse typ SZ ref () val = case typ of
+ STNil -> return ()
+ STPair t1 t2 -> AcM $ do
+ (r1, r2) <- readIORef ref
+ unAcM $ accumAddSparse t1 SZ r1 () (fst val)
+ unAcM $ accumAddSparse t2 SZ r2 () (snd val)
+ STMaybe t ->
+ join $ AcM $ atomicModifyIORef' ref $ \ac -> case (ac, val) of
+ (Nothing, _) -> (ac, _)
+ (Just{}, Nothing) -> (ac, return ())
+ (Just ac', Just val') -> first Just (accumAddDense t SZ ac' () val')
+ STArr _ t -> _ ref val
+ STScal{} -> _ ref val
+ STAccum{} -> error "Nested accumulators"
+ STEither{} -> error "Bare Either in accumulator"
+accumAddSparse typ (SS dep) ref idx val = case typ of
+ STNil -> return ()
+ STPair t1 t2 -> _ ref idx val
+ STMaybe t -> _ ref idx val
+ STArr _ t -> _ ref idx val
+ STScal{} -> _ ref idx val
+ STAccum{} -> error "Nested accumulators"
+ STEither{} -> error "Bare Either in accumulator"
+
+accumAddDense :: forall t i s. STy t -> SNat i -> RepAcDense t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAcDense t, AcM s ())
+accumAddDense = _
+
+-- accumAddVal :: forall t i s. STy t -> SNat i -> RepAc t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAc t, AcM s ())
+-- accumAddVal typ SZ ac () val = case typ of
+-- STNil -> ((), return ())
+-- STPair t1 t2 ->
+-- let (ac1', m1) = accumAddVal t1 SZ (fst ac) () (fst val)
+-- (ac2', m2) = accumAddVal t2 SZ (snd ac) () (snd val)
+-- in ((ac1', ac2'), m1 >> m2)
+-- STMaybe t -> case t of
+-- STEither t1 t2 -> (ac, accumAddValME t1 t2 ac val)
+-- STNil -> def ; STPair{} -> def ; STMaybe{} -> def ; STArr{} -> def ; STScal{} -> def ; STAccum{} -> def
+-- where def :: (t ~ TMaybe a, RepAc (TMaybe a) ~ IORef (Maybe (RepAc a))) => (RepAc t, AcM s ())
+-- def = (ac, accumAddValM t ac val)
+-- STArr n t
+-- | shapeSize (arrayShape ac) == 0 -> makeRepAc (STArr n t) val
+-- STEither{} -> error "Bare Either in accumulator"
+-- _ -> _
+-- accumAddVal typ (SS dep) ac idx val = case typ of
+-- STNil -> ((), return ())
+-- STPair t1 t2 ->
+-- case (idx, val) of
+-- (Left idx', Left val') -> first (,snd ac) $ accumAddVal t1 dep (fst ac) idx' val'
+-- (Right idx', Right val') -> first (fst ac,) $ accumAddVal t2 dep (snd ac) idx' val'
+-- _ -> error "Inconsistent idx and val in accumulator add operation"
+-- _ -> _
+
+-- accumAddValME :: STy a -> STy b
+-- -> IORef (Maybe (Either (RepAc a) (RepAc b)))
+-- -> Maybe (Either (Rep a) (Rep b))
+-- -> AcM s ()
+-- accumAddValME t1 t2 ac val =
+-- case val of
+-- Nothing -> return ()
+-- Just val' ->
+-- join $ AcM $ atomicModifyIORef' ac $ \ac' -> case (ac', val') of
+-- (Nothing, _) -> (Nothing, AcM $ initAccumOrTryAgainME t1 t2 ac val' (unAcM $ accumAddValME t1 t2 ac val))
+-- (Just (Left x), Left val'1) -> first (Just . Left) $ accumAddVal t1 SZ x () val'1
+-- (Just (Right y), Right val'2) -> first (Just . Right) $ accumAddVal t2 SZ y () val'2
+-- _ -> error "Inconsistent accumulator and value in add operation on Maybe Either"
+
+-- initAccumOrTryAgainME :: STy a -> STy b
+-- -> IORef (Maybe (Either (RepAc a) (RepAc b)))
+-- -> Either (Rep a) (Rep b)
+-- -> IO ()
+-- -> IO ()
+-- initAccumOrTryAgainME t1 t2 ac val onRace = do
+-- newContents <- case val of Left x -> Left <$> makeRepAc t1 x
+-- Right y -> Right <$> makeRepAc t2 y
+-- join $ atomicModifyIORef' ac (\case Nothing -> (Just newContents, return ())
+-- value@Just{} -> (value, onRace))
+
+-- accumAddValM :: STy t
+-- -> IORef (Maybe (RepAc t))
+-- -> Maybe (Rep t)
+-- -> AcM s ()
+-- accumAddValM t ac val =
+-- case val of
+-- Nothing -> return ()
+-- Just val' ->
+-- join $ AcM $ atomicModifyIORef' ac $ \ac' -> case ac' of
+-- Nothing -> (Nothing, AcM $ initAccumOrTryAgainM t ac val' (unAcM $ accumAddValM t ac val))
+-- Just x -> first Just $ accumAddVal t SZ x () val'
+
+-- initAccumOrTryAgainM :: STy t
+-- -> IORef (Maybe (RepAc t))
+-- -> Rep t
+-- -> IO ()
+-- -> IO ()
+-- initAccumOrTryAgainM t ac val onRace = do
+-- newContents <- makeRepAc t val
+-- join $ atomicModifyIORef' ac (\case Nothing -> (Just newContents, return ())
+-- value@Just{} -> (value, onRace))
+
numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r
numericIsNum STI32 = id
numericIsNum STI64 = id
@@ -166,7 +326,7 @@ numericIsNum STF64 = id
unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m))
-> SNat n -> Rep (Tup (Replicate n TIx)) -> f n
unTupRepIdx nil _ SZ _ = nil
-unTupRepIdx nil cons (SS n) (idx, i) = unTupRepIdx p nil cons n idx `cons` fromIntegral @Int64 @Int i
+unTupRepIdx nil cons (SS n) (idx, i) = unTupRepIdx nil cons n idx `cons` fromIntegral @Int64 @Int i
tupRepIdx :: (forall m. f (S m) -> (f m, Int))
-> SNat n -> f n -> Rep (Tup (Replicate n TIx))