diff options
-rw-r--r-- | src/Array.hs | 5 | ||||
-rw-r--r-- | src/Interpreter.hs | 198 | ||||
-rw-r--r-- | src/Interpreter/Rep.hs | 37 |
3 files changed, 203 insertions, 37 deletions
diff --git a/src/Array.hs b/src/Array.hs index 0d585a9..d7dadbf 100644 --- a/src/Array.hs +++ b/src/Array.hs @@ -1,6 +1,7 @@ -{-# LANGUAGE KindSignatures #-} {-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TupleSections #-} module Array where @@ -47,7 +48,7 @@ emptyShape (SS m) = emptyShape m `ShCons` 0 -- | TODO: this Vector is a boxed vector, which is horrendously inefficient. data Array (n :: Nat) t = Array (Shape n) (Vector t) - deriving (Show) + deriving (Show, Functor, Foldable, Traversable) arrayShape :: Array n t -> Shape n arrayShape (Array sh _) = sh diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 8728ec0..f58cefb 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -9,15 +9,16 @@ {-# LANGUAGE TypeOperators #-} {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE TupleSections #-} module Interpreter ( interpret, interpret', Value, ) where -import Control.Monad (foldM) +import Control.Monad (foldM, join) import Data.Int (Int64) -import Data.Proxy +import Data.IORef import System.IO.Unsafe (unsafePerformIO) import Array @@ -25,9 +26,10 @@ import AST import CHAD.Types import Data import Interpreter.Rep +import Data.Bifunctor (first) -newtype AcM s a = AcM (IO a) +newtype AcM s a = AcM { unAcM :: IO a } deriving newtype (Functor, Applicative, Monad) runAcM :: (forall s. AcM s a) -> a @@ -53,9 +55,9 @@ interpret' env = \case ECase _ e a b -> interpret' env e >>= \case Left x -> interpret' (Value x `SCons` env) a Right y -> interpret' (Value y `SCons` env) b - ENothing _ _ -> _ - EJust _ _ -> _ - EMaybe _ _ _ _ -> _ + ENothing _ _ -> return Nothing + EJust _ e -> Just <$> interpret' env e + EMaybe _ a b e -> maybe (interpret' env a) (\x -> interpret' (Value x `SCons` env) b) =<< interpret' env e EConstArr _ _ _ v -> return v EBuild1 _ a b -> do n <- fromIntegral @Int64 @Int <$> interpret' env a @@ -88,19 +90,20 @@ interpret' env = \case EOp _ op e -> interpretOp op <$> interpret' env e EWith e1 e2 -> do initval <- interpret' env e1 - withAccum (typeOf e1) initval $ \accum -> + withAccum (typeOf e1) (typeOf e2) initval $ \accum -> interpret' (Value accum `SCons` env) e2 EAccum i e1 e2 e3 -> do + let STAccum t = typeOf e3 idx <- interpret' env e1 val <- interpret' env e2 accum <- interpret' env e3 - accumAdd accum i idx val + accumAddSparse t i accum idx val EZero t -> do - return $ makeZero t + return $ zeroD2 t EPlus t a b -> do a' <- interpret' env a b' <- interpret' env b - return $ makePlus t a' b' + return $ addD2s t a' b' EError _ s -> error $ "Interpreter: Program threw error: " ++ s interpretOp :: SOp a t -> Rep a -> Rep t @@ -114,8 +117,8 @@ interpretOp op arg = case op of ONot -> not arg OIf -> if arg then Left () else Right () -makeZero :: STy t -> Rep (D2 t) -makeZero typ = case typ of +zeroD2 :: STy t -> Rep (D2 t) +zeroD2 typ = case typ of STNil -> () STPair _ _ -> Left () STEither _ _ -> Left () @@ -129,25 +132,29 @@ makeZero typ = case typ of STBool -> () STAccum{} -> error "Zero of Accum" -makePlus :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t) -makePlus typ a b = case typ of +addD2s :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t) +addD2s typ a b = case typ of STNil -> () STPair t1 t2 -> case (a, b) of (Left (), _) -> b (_, Left ()) -> a - (Right (x1, x2), Right (y1, y2)) -> Right (makePlus t1 x1 y1, makePlus t2 x2 y2) + (Right (x1, x2), Right (y1, y2)) -> Right (addD2s t1 x1 y1, addD2s t2 x2 y2) STEither t1 t2 -> case (a, b) of (Left (), _) -> b (_, Left ()) -> a - (Right (Left x), Right (Left y)) -> Right (Left (makePlus t1 x y)) - (Right (Right x), Right (Right y)) -> Right (Right (makePlus t2 x y)) + (Right (Left x), Right (Left y)) -> Right (Left (addD2s t1 x y)) + (Right (Right x), Right (Right y)) -> Right (Right (addD2s t2 x y)) _ -> error "Plus of inconsistent Eithers" + STMaybe t -> case (a, b) of + (Nothing, _) -> b + (_, Nothing) -> a + (Just x, Just y) -> Just (addD2s t x y) STArr _ t -> let sh1 = arrayShape a sh2 = arrayShape b in if | shapeSize sh1 == 0 -> b | shapeSize sh2 == 0 -> a - | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> makePlus t (arrayIndexLinear a i) (arrayIndexLinear b i)) + | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> addD2s t (arrayIndexLinear a i) (arrayIndexLinear b i)) | otherwise -> error "Plus of inconsistently shaped arrays" STScal sty -> case sty of STI32 -> () @@ -157,6 +164,159 @@ makePlus typ a b = case typ of STBool -> () STAccum{} -> error "Plus of Accum" +withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t) +withAccum t _ initval f = AcM $ do + accum <- newAcSparse t initval + out <- case f accum of AcM m -> m + val <- readAcSparse t accum + return (out, val) + +newAcSparse :: STy t -> Rep t -> IO (RepAcSparse t) +newAcSparse typ val = case typ of + STNil -> return () + STPair{} -> newIORef =<<newAcDense typ val + STMaybe t -> newIORef =<< traverse (newAcDense t) val + STArr _ t -> newIORef =<< traverse (newAcSparse t) val + STScal{} -> newIORef val + STAccum{} -> error "Nested accumulators" + STEither{} -> error "Bare Either in accumulator" + +newAcDense :: STy t -> Rep t -> IO (RepAcDense t) +newAcDense typ val = case typ of + STNil -> return () + STPair t1 t2 -> (,) <$> newAcSparse t1 (fst val) <*> newAcSparse t2 (snd val) + STEither t1 t2 -> case val of + Left x -> Left <$> newAcSparse t1 x + Right y -> Right <$> newAcSparse t2 y + STMaybe t -> traverse (newAcSparse t) val + STArr _ t -> traverse (newAcSparse t) val + STScal{} -> return val + STAccum{} -> error "Nested accumulators" + +readAcSparse :: STy t -> RepAcSparse t -> IO (Rep t) +readAcSparse typ val = case typ of + STNil -> return () + STPair t1 t2 -> do + (a, b) <- readIORef val + (,) <$> readAcSparse t1 a <*> readAcSparse t2 b + STMaybe t -> traverse (readAcDense t) =<< readIORef val + STArr _ t -> traverse (readAcSparse t) =<< readIORef val + STScal{} -> readIORef val + STAccum{} -> error "Nested accumulators" + STEither{} -> error "Bare Either in accumulator" + +readAcDense :: STy t -> RepAcDense t -> IO (Rep t) +readAcDense typ val = case typ of + STNil -> return () + STPair t1 t2 -> (,) <$> readAcSparse t1 (fst val) <*> readAcSparse t2 (snd val) + STEither t1 t2 -> case val of + Left x -> Left <$> readAcSparse t1 x + Right y -> Right <$> readAcSparse t2 y + STMaybe t -> traverse (readAcSparse t) val + STArr _ t -> traverse (readAcSparse t) val + STScal{} -> return val + STAccum{} -> error "Nested accumulators" + +accumAddSparse :: STy t -> SNat i -> RepAcSparse t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s () +accumAddSparse typ SZ ref () val = case typ of + STNil -> return () + STPair t1 t2 -> AcM $ do + (r1, r2) <- readIORef ref + unAcM $ accumAddSparse t1 SZ r1 () (fst val) + unAcM $ accumAddSparse t2 SZ r2 () (snd val) + STMaybe t -> + join $ AcM $ atomicModifyIORef' ref $ \ac -> case (ac, val) of + (Nothing, _) -> (ac, _) + (Just{}, Nothing) -> (ac, return ()) + (Just ac', Just val') -> first Just (accumAddDense t SZ ac' () val') + STArr _ t -> _ ref val + STScal{} -> _ ref val + STAccum{} -> error "Nested accumulators" + STEither{} -> error "Bare Either in accumulator" +accumAddSparse typ (SS dep) ref idx val = case typ of + STNil -> return () + STPair t1 t2 -> _ ref idx val + STMaybe t -> _ ref idx val + STArr _ t -> _ ref idx val + STScal{} -> _ ref idx val + STAccum{} -> error "Nested accumulators" + STEither{} -> error "Bare Either in accumulator" + +accumAddDense :: forall t i s. STy t -> SNat i -> RepAcDense t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAcDense t, AcM s ()) +accumAddDense = _ + +-- accumAddVal :: forall t i s. STy t -> SNat i -> RepAc t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAc t, AcM s ()) +-- accumAddVal typ SZ ac () val = case typ of +-- STNil -> ((), return ()) +-- STPair t1 t2 -> +-- let (ac1', m1) = accumAddVal t1 SZ (fst ac) () (fst val) +-- (ac2', m2) = accumAddVal t2 SZ (snd ac) () (snd val) +-- in ((ac1', ac2'), m1 >> m2) +-- STMaybe t -> case t of +-- STEither t1 t2 -> (ac, accumAddValME t1 t2 ac val) +-- STNil -> def ; STPair{} -> def ; STMaybe{} -> def ; STArr{} -> def ; STScal{} -> def ; STAccum{} -> def +-- where def :: (t ~ TMaybe a, RepAc (TMaybe a) ~ IORef (Maybe (RepAc a))) => (RepAc t, AcM s ()) +-- def = (ac, accumAddValM t ac val) +-- STArr n t +-- | shapeSize (arrayShape ac) == 0 -> makeRepAc (STArr n t) val +-- STEither{} -> error "Bare Either in accumulator" +-- _ -> _ +-- accumAddVal typ (SS dep) ac idx val = case typ of +-- STNil -> ((), return ()) +-- STPair t1 t2 -> +-- case (idx, val) of +-- (Left idx', Left val') -> first (,snd ac) $ accumAddVal t1 dep (fst ac) idx' val' +-- (Right idx', Right val') -> first (fst ac,) $ accumAddVal t2 dep (snd ac) idx' val' +-- _ -> error "Inconsistent idx and val in accumulator add operation" +-- _ -> _ + +-- accumAddValME :: STy a -> STy b +-- -> IORef (Maybe (Either (RepAc a) (RepAc b))) +-- -> Maybe (Either (Rep a) (Rep b)) +-- -> AcM s () +-- accumAddValME t1 t2 ac val = +-- case val of +-- Nothing -> return () +-- Just val' -> +-- join $ AcM $ atomicModifyIORef' ac $ \ac' -> case (ac', val') of +-- (Nothing, _) -> (Nothing, AcM $ initAccumOrTryAgainME t1 t2 ac val' (unAcM $ accumAddValME t1 t2 ac val)) +-- (Just (Left x), Left val'1) -> first (Just . Left) $ accumAddVal t1 SZ x () val'1 +-- (Just (Right y), Right val'2) -> first (Just . Right) $ accumAddVal t2 SZ y () val'2 +-- _ -> error "Inconsistent accumulator and value in add operation on Maybe Either" + +-- initAccumOrTryAgainME :: STy a -> STy b +-- -> IORef (Maybe (Either (RepAc a) (RepAc b))) +-- -> Either (Rep a) (Rep b) +-- -> IO () +-- -> IO () +-- initAccumOrTryAgainME t1 t2 ac val onRace = do +-- newContents <- case val of Left x -> Left <$> makeRepAc t1 x +-- Right y -> Right <$> makeRepAc t2 y +-- join $ atomicModifyIORef' ac (\case Nothing -> (Just newContents, return ()) +-- value@Just{} -> (value, onRace)) + +-- accumAddValM :: STy t +-- -> IORef (Maybe (RepAc t)) +-- -> Maybe (Rep t) +-- -> AcM s () +-- accumAddValM t ac val = +-- case val of +-- Nothing -> return () +-- Just val' -> +-- join $ AcM $ atomicModifyIORef' ac $ \ac' -> case ac' of +-- Nothing -> (Nothing, AcM $ initAccumOrTryAgainM t ac val' (unAcM $ accumAddValM t ac val)) +-- Just x -> first Just $ accumAddVal t SZ x () val' + +-- initAccumOrTryAgainM :: STy t +-- -> IORef (Maybe (RepAc t)) +-- -> Rep t +-- -> IO () +-- -> IO () +-- initAccumOrTryAgainM t ac val onRace = do +-- newContents <- makeRepAc t val +-- join $ atomicModifyIORef' ac (\case Nothing -> (Just newContents, return ()) +-- value@Just{} -> (value, onRace)) + numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r numericIsNum STI32 = id numericIsNum STI64 = id @@ -166,7 +326,7 @@ numericIsNum STF64 = id unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m)) -> SNat n -> Rep (Tup (Replicate n TIx)) -> f n unTupRepIdx nil _ SZ _ = nil -unTupRepIdx nil cons (SS n) (idx, i) = unTupRepIdx p nil cons n idx `cons` fromIntegral @Int64 @Int i +unTupRepIdx nil cons (SS n) (idx, i) = unTupRepIdx nil cons n idx `cons` fromIntegral @Int64 @Int i tupRepIdx :: (forall m. f (S m) -> (f m, Int)) -> SNat n -> f n -> Rep (Tup (Replicate n TIx)) diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs index 7add442..680196c 100644 --- a/src/Interpreter/Rep.hs +++ b/src/Interpreter/Rep.hs @@ -1,9 +1,9 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} module Interpreter.Rep where import Data.IORef -import qualified Data.Vector.Mutable as MV import GHC.TypeError import Array @@ -17,19 +17,24 @@ type family Rep t where Rep (TMaybe t) = Maybe (Rep t) Rep (TArr n t) = Array n (Rep t) Rep (TScal sty) = ScalRep sty - Rep (TAccum t) = IORef (RepAc t) + Rep (TAccum t) = RepAcSparse t -type family RepAc t where - RepAc TNil = () - RepAc (TPair a b) = (RepAc a, RepAc b) - -- This is annoying when working with values of type 'RepAc t', because - -- failing a pattern match does not generate negative type information. - -- However, it works, saves us from having to defining a LEither type - -- first-class in the type system with - -- Rep (LEither a b) = Maybe (Either a b) - -- and it's not even incorrect, in a way. - RepAc (TMaybe (TEither a b)) = IORef (Maybe (Either (RepAc a) (RepAc b))) - RepAc (TMaybe t) = IORef (Maybe (RepAc t)) - RepAc (TArr n t) = (Shape n, MV.IOVector (RepAc t)) - RepAc (TScal sty) = IORef (ScalRep sty) - RepAc (TAccum t) = TypeError (Text "Nested accumulators") +-- Mutable, and has an O(1) zero. +type family RepAcSparse t where + RepAcSparse TNil = () + RepAcSparse (TPair a b) = IORef (RepAcDense (TPair a b)) + RepAcSparse (TEither a b) = TypeError (Text "Non-sparse coproduct is not a monoid") + RepAcSparse (TMaybe t) = IORef (Maybe (RepAcDense t)) -- allow the value to be dense, because the Maybe's zero can be used for the contents + RepAcSparse (TArr n t) = IORef (RepAcDense (TArr n t)) -- empty array is zero + RepAcSparse (TScal sty) = IORef (ScalRep sty) + RepAcSparse (TAccum t) = TypeError (Text "RepAcSparse: Nested accumulators") + +-- Immutable, and does not necessarily have a zero. +type family RepAcDense t where + RepAcDense TNil = () + RepAcDense (TPair a b) = (RepAcSparse a, RepAcSparse b) + RepAcDense (TEither a b) = Either (RepAcSparse a) (RepAcSparse b) + RepAcDense (TMaybe t) = Maybe (RepAcSparse t) + RepAcDense (TArr n t) = Array n (RepAcSparse t) + RepAcDense (TScal sty) = ScalRep sty + RepAcDense (TAccum t) = TypeError (Text "RepAcDense: Nested accumulators") |