summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/Array.hs5
-rw-r--r--src/Interpreter.hs198
-rw-r--r--src/Interpreter/Rep.hs37
3 files changed, 203 insertions, 37 deletions
diff --git a/src/Array.hs b/src/Array.hs
index 0d585a9..d7dadbf 100644
--- a/src/Array.hs
+++ b/src/Array.hs
@@ -1,6 +1,7 @@
-{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TupleSections #-}
module Array where
@@ -47,7 +48,7 @@ emptyShape (SS m) = emptyShape m `ShCons` 0
-- | TODO: this Vector is a boxed vector, which is horrendously inefficient.
data Array (n :: Nat) t = Array (Shape n) (Vector t)
- deriving (Show)
+ deriving (Show, Functor, Foldable, Traversable)
arrayShape :: Array n t -> Shape n
arrayShape (Array sh _) = sh
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 8728ec0..f58cefb 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -9,15 +9,16 @@
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE TupleSections #-}
module Interpreter (
interpret,
interpret',
Value,
) where
-import Control.Monad (foldM)
+import Control.Monad (foldM, join)
import Data.Int (Int64)
-import Data.Proxy
+import Data.IORef
import System.IO.Unsafe (unsafePerformIO)
import Array
@@ -25,9 +26,10 @@ import AST
import CHAD.Types
import Data
import Interpreter.Rep
+import Data.Bifunctor (first)
-newtype AcM s a = AcM (IO a)
+newtype AcM s a = AcM { unAcM :: IO a }
deriving newtype (Functor, Applicative, Monad)
runAcM :: (forall s. AcM s a) -> a
@@ -53,9 +55,9 @@ interpret' env = \case
ECase _ e a b -> interpret' env e >>= \case
Left x -> interpret' (Value x `SCons` env) a
Right y -> interpret' (Value y `SCons` env) b
- ENothing _ _ -> _
- EJust _ _ -> _
- EMaybe _ _ _ _ -> _
+ ENothing _ _ -> return Nothing
+ EJust _ e -> Just <$> interpret' env e
+ EMaybe _ a b e -> maybe (interpret' env a) (\x -> interpret' (Value x `SCons` env) b) =<< interpret' env e
EConstArr _ _ _ v -> return v
EBuild1 _ a b -> do
n <- fromIntegral @Int64 @Int <$> interpret' env a
@@ -88,19 +90,20 @@ interpret' env = \case
EOp _ op e -> interpretOp op <$> interpret' env e
EWith e1 e2 -> do
initval <- interpret' env e1
- withAccum (typeOf e1) initval $ \accum ->
+ withAccum (typeOf e1) (typeOf e2) initval $ \accum ->
interpret' (Value accum `SCons` env) e2
EAccum i e1 e2 e3 -> do
+ let STAccum t = typeOf e3
idx <- interpret' env e1
val <- interpret' env e2
accum <- interpret' env e3
- accumAdd accum i idx val
+ accumAddSparse t i accum idx val
EZero t -> do
- return $ makeZero t
+ return $ zeroD2 t
EPlus t a b -> do
a' <- interpret' env a
b' <- interpret' env b
- return $ makePlus t a' b'
+ return $ addD2s t a' b'
EError _ s -> error $ "Interpreter: Program threw error: " ++ s
interpretOp :: SOp a t -> Rep a -> Rep t
@@ -114,8 +117,8 @@ interpretOp op arg = case op of
ONot -> not arg
OIf -> if arg then Left () else Right ()
-makeZero :: STy t -> Rep (D2 t)
-makeZero typ = case typ of
+zeroD2 :: STy t -> Rep (D2 t)
+zeroD2 typ = case typ of
STNil -> ()
STPair _ _ -> Left ()
STEither _ _ -> Left ()
@@ -129,25 +132,29 @@ makeZero typ = case typ of
STBool -> ()
STAccum{} -> error "Zero of Accum"
-makePlus :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t)
-makePlus typ a b = case typ of
+addD2s :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t)
+addD2s typ a b = case typ of
STNil -> ()
STPair t1 t2 -> case (a, b) of
(Left (), _) -> b
(_, Left ()) -> a
- (Right (x1, x2), Right (y1, y2)) -> Right (makePlus t1 x1 y1, makePlus t2 x2 y2)
+ (Right (x1, x2), Right (y1, y2)) -> Right (addD2s t1 x1 y1, addD2s t2 x2 y2)
STEither t1 t2 -> case (a, b) of
(Left (), _) -> b
(_, Left ()) -> a
- (Right (Left x), Right (Left y)) -> Right (Left (makePlus t1 x y))
- (Right (Right x), Right (Right y)) -> Right (Right (makePlus t2 x y))
+ (Right (Left x), Right (Left y)) -> Right (Left (addD2s t1 x y))
+ (Right (Right x), Right (Right y)) -> Right (Right (addD2s t2 x y))
_ -> error "Plus of inconsistent Eithers"
+ STMaybe t -> case (a, b) of
+ (Nothing, _) -> b
+ (_, Nothing) -> a
+ (Just x, Just y) -> Just (addD2s t x y)
STArr _ t ->
let sh1 = arrayShape a
sh2 = arrayShape b
in if | shapeSize sh1 == 0 -> b
| shapeSize sh2 == 0 -> a
- | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> makePlus t (arrayIndexLinear a i) (arrayIndexLinear b i))
+ | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> addD2s t (arrayIndexLinear a i) (arrayIndexLinear b i))
| otherwise -> error "Plus of inconsistently shaped arrays"
STScal sty -> case sty of
STI32 -> ()
@@ -157,6 +164,159 @@ makePlus typ a b = case typ of
STBool -> ()
STAccum{} -> error "Plus of Accum"
+withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t)
+withAccum t _ initval f = AcM $ do
+ accum <- newAcSparse t initval
+ out <- case f accum of AcM m -> m
+ val <- readAcSparse t accum
+ return (out, val)
+
+newAcSparse :: STy t -> Rep t -> IO (RepAcSparse t)
+newAcSparse typ val = case typ of
+ STNil -> return ()
+ STPair{} -> newIORef =<<newAcDense typ val
+ STMaybe t -> newIORef =<< traverse (newAcDense t) val
+ STArr _ t -> newIORef =<< traverse (newAcSparse t) val
+ STScal{} -> newIORef val
+ STAccum{} -> error "Nested accumulators"
+ STEither{} -> error "Bare Either in accumulator"
+
+newAcDense :: STy t -> Rep t -> IO (RepAcDense t)
+newAcDense typ val = case typ of
+ STNil -> return ()
+ STPair t1 t2 -> (,) <$> newAcSparse t1 (fst val) <*> newAcSparse t2 (snd val)
+ STEither t1 t2 -> case val of
+ Left x -> Left <$> newAcSparse t1 x
+ Right y -> Right <$> newAcSparse t2 y
+ STMaybe t -> traverse (newAcSparse t) val
+ STArr _ t -> traverse (newAcSparse t) val
+ STScal{} -> return val
+ STAccum{} -> error "Nested accumulators"
+
+readAcSparse :: STy t -> RepAcSparse t -> IO (Rep t)
+readAcSparse typ val = case typ of
+ STNil -> return ()
+ STPair t1 t2 -> do
+ (a, b) <- readIORef val
+ (,) <$> readAcSparse t1 a <*> readAcSparse t2 b
+ STMaybe t -> traverse (readAcDense t) =<< readIORef val
+ STArr _ t -> traverse (readAcSparse t) =<< readIORef val
+ STScal{} -> readIORef val
+ STAccum{} -> error "Nested accumulators"
+ STEither{} -> error "Bare Either in accumulator"
+
+readAcDense :: STy t -> RepAcDense t -> IO (Rep t)
+readAcDense typ val = case typ of
+ STNil -> return ()
+ STPair t1 t2 -> (,) <$> readAcSparse t1 (fst val) <*> readAcSparse t2 (snd val)
+ STEither t1 t2 -> case val of
+ Left x -> Left <$> readAcSparse t1 x
+ Right y -> Right <$> readAcSparse t2 y
+ STMaybe t -> traverse (readAcSparse t) val
+ STArr _ t -> traverse (readAcSparse t) val
+ STScal{} -> return val
+ STAccum{} -> error "Nested accumulators"
+
+accumAddSparse :: STy t -> SNat i -> RepAcSparse t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s ()
+accumAddSparse typ SZ ref () val = case typ of
+ STNil -> return ()
+ STPair t1 t2 -> AcM $ do
+ (r1, r2) <- readIORef ref
+ unAcM $ accumAddSparse t1 SZ r1 () (fst val)
+ unAcM $ accumAddSparse t2 SZ r2 () (snd val)
+ STMaybe t ->
+ join $ AcM $ atomicModifyIORef' ref $ \ac -> case (ac, val) of
+ (Nothing, _) -> (ac, _)
+ (Just{}, Nothing) -> (ac, return ())
+ (Just ac', Just val') -> first Just (accumAddDense t SZ ac' () val')
+ STArr _ t -> _ ref val
+ STScal{} -> _ ref val
+ STAccum{} -> error "Nested accumulators"
+ STEither{} -> error "Bare Either in accumulator"
+accumAddSparse typ (SS dep) ref idx val = case typ of
+ STNil -> return ()
+ STPair t1 t2 -> _ ref idx val
+ STMaybe t -> _ ref idx val
+ STArr _ t -> _ ref idx val
+ STScal{} -> _ ref idx val
+ STAccum{} -> error "Nested accumulators"
+ STEither{} -> error "Bare Either in accumulator"
+
+accumAddDense :: forall t i s. STy t -> SNat i -> RepAcDense t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAcDense t, AcM s ())
+accumAddDense = _
+
+-- accumAddVal :: forall t i s. STy t -> SNat i -> RepAc t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAc t, AcM s ())
+-- accumAddVal typ SZ ac () val = case typ of
+-- STNil -> ((), return ())
+-- STPair t1 t2 ->
+-- let (ac1', m1) = accumAddVal t1 SZ (fst ac) () (fst val)
+-- (ac2', m2) = accumAddVal t2 SZ (snd ac) () (snd val)
+-- in ((ac1', ac2'), m1 >> m2)
+-- STMaybe t -> case t of
+-- STEither t1 t2 -> (ac, accumAddValME t1 t2 ac val)
+-- STNil -> def ; STPair{} -> def ; STMaybe{} -> def ; STArr{} -> def ; STScal{} -> def ; STAccum{} -> def
+-- where def :: (t ~ TMaybe a, RepAc (TMaybe a) ~ IORef (Maybe (RepAc a))) => (RepAc t, AcM s ())
+-- def = (ac, accumAddValM t ac val)
+-- STArr n t
+-- | shapeSize (arrayShape ac) == 0 -> makeRepAc (STArr n t) val
+-- STEither{} -> error "Bare Either in accumulator"
+-- _ -> _
+-- accumAddVal typ (SS dep) ac idx val = case typ of
+-- STNil -> ((), return ())
+-- STPair t1 t2 ->
+-- case (idx, val) of
+-- (Left idx', Left val') -> first (,snd ac) $ accumAddVal t1 dep (fst ac) idx' val'
+-- (Right idx', Right val') -> first (fst ac,) $ accumAddVal t2 dep (snd ac) idx' val'
+-- _ -> error "Inconsistent idx and val in accumulator add operation"
+-- _ -> _
+
+-- accumAddValME :: STy a -> STy b
+-- -> IORef (Maybe (Either (RepAc a) (RepAc b)))
+-- -> Maybe (Either (Rep a) (Rep b))
+-- -> AcM s ()
+-- accumAddValME t1 t2 ac val =
+-- case val of
+-- Nothing -> return ()
+-- Just val' ->
+-- join $ AcM $ atomicModifyIORef' ac $ \ac' -> case (ac', val') of
+-- (Nothing, _) -> (Nothing, AcM $ initAccumOrTryAgainME t1 t2 ac val' (unAcM $ accumAddValME t1 t2 ac val))
+-- (Just (Left x), Left val'1) -> first (Just . Left) $ accumAddVal t1 SZ x () val'1
+-- (Just (Right y), Right val'2) -> first (Just . Right) $ accumAddVal t2 SZ y () val'2
+-- _ -> error "Inconsistent accumulator and value in add operation on Maybe Either"
+
+-- initAccumOrTryAgainME :: STy a -> STy b
+-- -> IORef (Maybe (Either (RepAc a) (RepAc b)))
+-- -> Either (Rep a) (Rep b)
+-- -> IO ()
+-- -> IO ()
+-- initAccumOrTryAgainME t1 t2 ac val onRace = do
+-- newContents <- case val of Left x -> Left <$> makeRepAc t1 x
+-- Right y -> Right <$> makeRepAc t2 y
+-- join $ atomicModifyIORef' ac (\case Nothing -> (Just newContents, return ())
+-- value@Just{} -> (value, onRace))
+
+-- accumAddValM :: STy t
+-- -> IORef (Maybe (RepAc t))
+-- -> Maybe (Rep t)
+-- -> AcM s ()
+-- accumAddValM t ac val =
+-- case val of
+-- Nothing -> return ()
+-- Just val' ->
+-- join $ AcM $ atomicModifyIORef' ac $ \ac' -> case ac' of
+-- Nothing -> (Nothing, AcM $ initAccumOrTryAgainM t ac val' (unAcM $ accumAddValM t ac val))
+-- Just x -> first Just $ accumAddVal t SZ x () val'
+
+-- initAccumOrTryAgainM :: STy t
+-- -> IORef (Maybe (RepAc t))
+-- -> Rep t
+-- -> IO ()
+-- -> IO ()
+-- initAccumOrTryAgainM t ac val onRace = do
+-- newContents <- makeRepAc t val
+-- join $ atomicModifyIORef' ac (\case Nothing -> (Just newContents, return ())
+-- value@Just{} -> (value, onRace))
+
numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r
numericIsNum STI32 = id
numericIsNum STI64 = id
@@ -166,7 +326,7 @@ numericIsNum STF64 = id
unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m))
-> SNat n -> Rep (Tup (Replicate n TIx)) -> f n
unTupRepIdx nil _ SZ _ = nil
-unTupRepIdx nil cons (SS n) (idx, i) = unTupRepIdx p nil cons n idx `cons` fromIntegral @Int64 @Int i
+unTupRepIdx nil cons (SS n) (idx, i) = unTupRepIdx nil cons n idx `cons` fromIntegral @Int64 @Int i
tupRepIdx :: (forall m. f (S m) -> (f m, Int))
-> SNat n -> f n -> Rep (Tup (Replicate n TIx))
diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs
index 7add442..680196c 100644
--- a/src/Interpreter/Rep.hs
+++ b/src/Interpreter/Rep.hs
@@ -1,9 +1,9 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE UndecidableInstances #-}
module Interpreter.Rep where
import Data.IORef
-import qualified Data.Vector.Mutable as MV
import GHC.TypeError
import Array
@@ -17,19 +17,24 @@ type family Rep t where
Rep (TMaybe t) = Maybe (Rep t)
Rep (TArr n t) = Array n (Rep t)
Rep (TScal sty) = ScalRep sty
- Rep (TAccum t) = IORef (RepAc t)
+ Rep (TAccum t) = RepAcSparse t
-type family RepAc t where
- RepAc TNil = ()
- RepAc (TPair a b) = (RepAc a, RepAc b)
- -- This is annoying when working with values of type 'RepAc t', because
- -- failing a pattern match does not generate negative type information.
- -- However, it works, saves us from having to defining a LEither type
- -- first-class in the type system with
- -- Rep (LEither a b) = Maybe (Either a b)
- -- and it's not even incorrect, in a way.
- RepAc (TMaybe (TEither a b)) = IORef (Maybe (Either (RepAc a) (RepAc b)))
- RepAc (TMaybe t) = IORef (Maybe (RepAc t))
- RepAc (TArr n t) = (Shape n, MV.IOVector (RepAc t))
- RepAc (TScal sty) = IORef (ScalRep sty)
- RepAc (TAccum t) = TypeError (Text "Nested accumulators")
+-- Mutable, and has an O(1) zero.
+type family RepAcSparse t where
+ RepAcSparse TNil = ()
+ RepAcSparse (TPair a b) = IORef (RepAcDense (TPair a b))
+ RepAcSparse (TEither a b) = TypeError (Text "Non-sparse coproduct is not a monoid")
+ RepAcSparse (TMaybe t) = IORef (Maybe (RepAcDense t)) -- allow the value to be dense, because the Maybe's zero can be used for the contents
+ RepAcSparse (TArr n t) = IORef (RepAcDense (TArr n t)) -- empty array is zero
+ RepAcSparse (TScal sty) = IORef (ScalRep sty)
+ RepAcSparse (TAccum t) = TypeError (Text "RepAcSparse: Nested accumulators")
+
+-- Immutable, and does not necessarily have a zero.
+type family RepAcDense t where
+ RepAcDense TNil = ()
+ RepAcDense (TPair a b) = (RepAcSparse a, RepAcSparse b)
+ RepAcDense (TEither a b) = Either (RepAcSparse a) (RepAcSparse b)
+ RepAcDense (TMaybe t) = Maybe (RepAcSparse t)
+ RepAcDense (TArr n t) = Array n (RepAcSparse t)
+ RepAcDense (TScal sty) = ScalRep sty
+ RepAcDense (TAccum t) = TypeError (Text "RepAcDense: Nested accumulators")