summaryrefslogtreecommitdiff
path: root/src/Interpreter.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r--src/Interpreter.hs267
1 files changed, 155 insertions, 112 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index d80a76e..11caac0 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -21,6 +21,7 @@ module Interpreter (
import Control.Monad (foldM, join, when)
import Data.Bifunctor (bimap)
+import Data.Bitraversable (bitraverse)
import Data.Char (isSpace)
import Data.Functor.Identity
import Data.Kind (Type)
@@ -134,26 +135,25 @@ interpret'Rec env = \case
e1' <- interpret' env e1
e2' <- interpret' env e2
interpret' (Value e2' `SCons` Value e1' `SCons` SNil) pr
- EWith _ e1 e2 -> do
+ EWith _ t e1 e2 -> do
initval <- interpret' env e1
- withAccum (typeOf e1) (typeOf e2) initval $ \accum ->
+ withAccum t (typeOf e2) initval $ \accum ->
interpret' (Value accum `SCons` env) e2
- EAccum _ i e1 e2 e3 -> do
- let STAccum t = typeOf e3
+ EAccum _ t p e1 e2 e3 -> do
idx <- interpret' env e1
val <- interpret' env e2
accum <- interpret' env e3
- accumAddSparse t i accum idx val
+ accumAddSparse t p accum idx val
EZero _ t -> do
return $ zeroD2 t
EPlus _ t a b -> do
a' <- interpret' env a
b' <- interpret' env b
return $ addD2s t a' b'
- EOneHot _ t i a b -> do
+ EOneHot _ t p a b -> do
a' <- interpret' env a
b' <- interpret' env b
- return $ onehotD2 i t a' b'
+ return $ onehotD2 p t a' b'
EError _ _ s -> error $ "Interpreter: Program threw error: " ++ s
interpretOp :: SOp a t -> Rep a -> Rep t
@@ -230,44 +230,37 @@ addD2s typ a b = case typ of
STBool -> ()
STAccum{} -> error "Plus of Accum"
-onehotD2 :: SNat i -> STy t -> Rep (AcIdx (D2 t) i) -> Rep (AcVal (D2 t) i) -> Rep (D2 t)
-onehotD2 SZ _ () v = v
-onehotD2 _ STNil _ _ = ()
-onehotD2 (SS SZ ) (STPair _ _ ) () val = Just val
-onehotD2 (SS (SS i)) (STPair t1 t2) (Left idx) (Left val) = Just (onehotD2 i t1 idx val, zeroD2 t2)
-onehotD2 (SS (SS i)) (STPair t1 t2) (Right idx) (Right val) = Just (zeroD2 t1, onehotD2 i t2 idx val)
-onehotD2 (SS _ ) (STPair _ _ ) _ _ = error "onehotD2: pair: mismatched index and value"
-onehotD2 (SS SZ ) (STEither _ _ ) () val = Just val
-onehotD2 (SS (SS i)) (STEither t1 _ ) (Left idx) (Left val) = Just (Left (onehotD2 i t1 idx val))
-onehotD2 (SS (SS i)) (STEither _ t2) (Right idx) (Right val) = Just (Right (onehotD2 i t2 idx val))
-onehotD2 (SS _ ) (STEither _ _ ) _ _ = error "onehotD2: either: mismatched index and value"
-onehotD2 (SS i ) (STMaybe t) idx val = Just (onehotD2 i t idx val)
-onehotD2 (SS i ) (STArr n t) idx val = runIdentity $
- onehotArray (d2 t) (\i' idx' v' -> Identity (onehotD2 i' t idx' v')) (Identity (zeroD2 t)) n (SS i) idx val
-onehotD2 SS{} STScal{} _ _ = error "onehotD2: cannot index into scalar"
-onehotD2 _ STAccum{} _ _ = error "onehotD2: cannot index into accumulator"
-
-withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t)
+onehotD2 :: SAcPrj p a b -> STy a -> Rep (AcIdx p a) -> Rep (D2 b) -> Rep (D2 a)
+onehotD2 SAPHere _ _ val = val
+onehotD2 (SAPFst prj) (STPair a b) idx val = Just (onehotD2 prj a idx val, zeroD2 b)
+onehotD2 (SAPSnd prj) (STPair a b) idx val = Just (zeroD2 a, onehotD2 prj b idx val)
+onehotD2 (SAPLeft prj) (STEither a _) idx val = Just (Left (onehotD2 prj a idx val))
+onehotD2 (SAPRight prj) (STEither _ b) idx val = Just (Right (onehotD2 prj b idx val))
+onehotD2 (SAPJust prj) (STMaybe a) idx val = Just (onehotD2 prj a idx val)
+onehotD2 (SAPArrIdx prj _) (STArr n a) idx val =
+ runIdentity $ onehotArray (\idx' -> Identity (onehotD2 prj a idx' val)) (Identity (zeroD2 a)) n prj idx
+
+withAccum :: STy t -> STy a -> Rep (D2 t) -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep (D2 t))
withAccum t _ initval f = AcM $ do
- accum <- newAcSparse t SZ () initval
+ accum <- newAcSparse t SAPHere () initval
out <- case f accum of AcM m -> m
val <- readAcSparse t accum
return (out, val)
-newAcZero :: STy t -> IO (RepAcSparse t)
+newAcZero :: STy t -> IO (RepAc t)
newAcZero = \case
STNil -> return ()
- STPair t1 t2 -> newIORef =<< (,) <$> newAcZero t1 <*> newAcZero t2
+ STPair{} -> newIORef Nothing
+ STEither{} -> newIORef Nothing
STMaybe _ -> newIORef Nothing
STArr n _ -> newIORef (emptyArray n)
STScal sty -> case sty of
- STI32 -> newIORef 0
- STI64 -> newIORef 0
+ STI32 -> return ()
+ STI64 -> return ()
STF32 -> newIORef 0.0
STF64 -> newIORef 0.0
- STBool -> error "Accumulator of Bool"
+ STBool -> return ()
STAccum{} -> error "Nested accumulators"
- STEither{} -> error "Bare Either in accumulator"
-- | Inverted index: the outermost index is at the /outside/ of this list.
data PartialInvIndex n m where
@@ -322,95 +315,144 @@ piindexConcat :: PartialInvIndex n m -> InvIndex m -> InvIndex n
piindexConcat PIIxEnd ix = ix
piindexConcat (PIIxCons i pix) ix = InvCons i (piindexConcat pix ix)
-newAcSparse :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t)
-newAcSparse typ SZ () val = case typ of
- STNil -> return ()
- STPair t1 t2 -> newIORef =<< (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val)
- STMaybe t -> newIORef =<< traverse (newAcDense t SZ ()) val
- STArr _ t -> newIORef =<< traverse (newAcSparse t SZ ()) val
- STScal{} -> newIORef val
- STAccum{} -> error "Nested accumulators"
- STEither{} -> error "Bare Either in accumulator"
-newAcSparse typ (SS dep) idx val = case typ of
- STNil -> return ()
- STPair t1 t2 -> newIORef =<< case (idx, val) of
- (Left idx', Left val') -> (,) <$> newAcSparse t1 dep idx' val' <*> newAcZero t2
- (Right idx', Right val') -> (,) <$> newAcZero t1 <*> newAcSparse t2 dep idx' val'
- _ -> error "Index/value mismatch in newAc pair"
- STMaybe t -> newIORef =<< Just <$> newAcDense t dep idx val
- STArr dim (t :: STy t) -> newIORef =<< newAcArray dim t (SS dep) idx val
- STScal{} -> error "Cannot index into scalar"
- STAccum{} -> error "Nested accumulators"
- STEither{} -> error "Bare Either in accumulator"
+newAcSparse :: STy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep (D2 b) -> IO (RepAc a)
+newAcSparse typ prj idx val = case (typ, prj) of
+ (STNil, SAPHere) -> return ()
+ (STPair t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val
+ (STEither t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val
+ (STMaybe t1, SAPHere) -> newIORef =<< traverse (newAcSparse t1 SAPHere ()) val
+ (STArr _ t1, SAPHere) -> newIORef =<< traverse (newAcSparse t1 SAPHere ()) val
+ (STScal sty, SAPHere) -> case sty of
+ STI32 -> return ()
+ STI64 -> return ()
+ STF32 -> newIORef val
+ STF64 -> newIORef val
+ STBool -> return ()
-newAcArray :: SNat n -> STy t -> SNat i -> Rep (AcIdx (TArr n t) i) -> Rep (AcVal (TArr n t) i) -> IO (Array n (RepAcSparse t))
-newAcArray n t = onehotArray t (newAcSparse t) (newAcZero t) n
+ (STPair t1 t2, SAPFst prj') ->
+ newIORef . Just =<< (,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2
+ (STPair t1 t2, SAPSnd prj') ->
+ newIORef . Just =<< (,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val
-onehotArray :: Monad m
- => STy t
- -> (forall n'. SNat n' -> Rep (AcIdx t n') -> Rep (AcVal t n') -> m v) -- ^ the "one"
- -> m v -- ^ generate a zero value for elsewhere
- -> SNat n -> SNat i -> Rep (AcIdx (TArr n t) i) -> Rep (AcVal (TArr n t) i) -> m (Array n v)
-onehotArray _ mkone _ _ SZ _ val =
- traverse (mkone SZ ()) val
-onehotArray (_ :: STy t) mkone mkzero dim dep@SS{} idx val = do
- let sh = unTupRepIdx ShNil ShCons dim (fst val)
- go mkone dep dim idx (snd val) $ \arr position ->
- arrayGenerateM sh (\i -> case uninvert <$> piindexMatch position (invert i) of
- Just i' -> return $ arr `arrayIndex` i'
- Nothing -> mkzero)
- where
- go :: Monad m
- => (forall n'. SNat n' -> Rep (AcIdx t n') -> Rep (AcVal t n') -> m v)
- -> SNat i -> SNat n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i)
- -> (forall n'. Array n' v -> PartialInvIndex n n' -> m r) -> m r
- go mk SZ _ () val' k = arrayMapM (mk SZ ()) val' >>= \arr -> k arr PIIxEnd
- go mk (SS dep') SZ idx' val' k = mk dep' idx' val' >>= \arr -> k (arrayUnit arr) PIIxEnd
- go mk (SS dep') (SS dim') (i, idx') val' k =
- go mk dep' dim' idx' val' $ \arr pish ->
- k arr (PIIxCons (fromIntegral @Int64 @Int i) pish)
-
-newAcDense :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t)
-newAcDense typ SZ () val = case typ of
- STPair t1 t2 -> (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val)
- STEither t1 t2 -> case val of
- Left x -> Left <$> newAcSparse t1 SZ () x
- Right y -> Right <$> newAcSparse t2 SZ () y
- _ -> error "newAcDense: invalid dense type"
-newAcDense typ (SS dep) idx val = case typ of
- STPair t1 t2 ->
- case (idx, val) of
- (Left idx', Left val') -> (,) <$> newAcSparse t1 dep idx' val' <*> newAcZero t2
- (Right idx', Right val') -> (,) <$> newAcZero t1 <*> newAcSparse t2 dep idx' val'
- _ -> error "Index/value mismatch in newAc pair"
- STEither t1 t2 ->
- case (idx, val) of
- (Left idx', Left val') -> Left <$> newAcSparse t1 dep idx' val'
- (Right idx', Right val') -> Right <$> newAcSparse t2 dep idx' val'
- _ -> error "Index/value mismatch in newAc either"
- _ -> error "newAcDense: invalid dense type"
+ (STEither t1 _, SAPLeft prj') -> newIORef . Just . Left =<< newAcSparse t1 prj' idx val
+ (STEither _ t2, SAPRight prj') -> newIORef . Just . Right =<< newAcSparse t2 prj' idx val
+
+ (STMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val
+
+ (STArr n t, SAPArrIdx prj' _) -> newIORef =<< newAcArray n t prj' idx val
+
+ (STAccum{}, _) -> error "Accumulators not allowed in source program"
-readAcSparse :: STy t -> RepAcSparse t -> IO (Rep t)
+newAcArray :: SNat n -> STy a -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> Rep (D2 b) -> IO (Array n (RepAc a))
+newAcArray n t prj idx val = onehotArray (\idx' -> newAcSparse t prj idx' val) (newAcZero t) n prj idx
+
+onehotArray :: Monad m
+ => (Rep (AcIdx p a) -> m v) -- ^ the "one"
+ -> m v -- ^ the "zero"
+ -> SNat n -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> m (Array n v)
+onehotArray mkone mkzero n _ ((arrindex', arrsh'), idx) =
+ let arrindex = unTupRepIdx IxNil IxCons n arrindex'
+ arrsh = unTupRepIdx ShNil ShCons n arrsh'
+ in arrayGenerateM arrsh (\i -> if i == arrindex then mkone idx else mkzero)
+
+-- newAcDense :: STy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep (D2 b) -> IO (RepAcDense (D2 a))
+-- newAcDense typ SZ () val = case typ of
+-- STPair t1 t2 -> (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val)
+-- STEither t1 t2 -> case val of
+-- Left x -> Left <$> newAcSparse t1 SZ () x
+-- Right y -> Right <$> newAcSparse t2 SZ () y
+-- _ -> error "newAcDense: invalid dense type"
+-- newAcDense typ (SS dep) idx val = case typ of
+-- STPair t1 t2 ->
+-- case (idx, val) of
+-- (Left idx', Left val') -> (,) <$> newAcSparse t1 dep idx' val' <*> newAcZero t2
+-- (Right idx', Right val') -> (,) <$> newAcZero t1 <*> newAcSparse t2 dep idx' val'
+-- _ -> error "Index/value mismatch in newAc pair"
+-- STEither t1 t2 ->
+-- case (idx, val) of
+-- (Left idx', Left val') -> Left <$> newAcSparse t1 dep idx' val'
+-- (Right idx', Right val') -> Right <$> newAcSparse t2 dep idx' val'
+-- _ -> error "Index/value mismatch in newAc either"
+-- _ -> error "newAcDense: invalid dense type"
+
+readAcSparse :: STy t -> RepAc t -> IO (Rep (D2 t))
readAcSparse typ val = case typ of
STNil -> return ()
- STPair t1 t2 -> do
- (a, b) <- readIORef val
- (,) <$> readAcSparse t1 a <*> readAcSparse t2 b
- STMaybe t -> traverse (readAcDense t) =<< readIORef val
+ STPair t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val
+ STEither t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val
+ STMaybe t -> traverse (readAcSparse t) =<< readIORef val
STArr _ t -> traverse (readAcSparse t) =<< readIORef val
- STScal{} -> readIORef val
+ STScal sty -> case sty of
+ STI32 -> return ()
+ STI64 -> return ()
+ STF32 -> readIORef val
+ STF64 -> readIORef val
+ STBool -> return ()
STAccum{} -> error "Nested accumulators"
- STEither{} -> error "Bare Either in accumulator"
-readAcDense :: STy t -> RepAcDense t -> IO (Rep t)
-readAcDense typ val = case typ of
- STPair t1 t2 -> (,) <$> readAcSparse t1 (fst val) <*> readAcSparse t2 (snd val)
- STEither t1 t2 -> case val of
- Left x -> Left <$> readAcSparse t1 x
- Right y -> Right <$> readAcSparse t2 y
- _ -> error "readAcDense: invalid dense type"
+-- readAcDense :: STy t -> RepAcDense t -> IO (Rep t)
+-- readAcDense typ val = case typ of
+-- STPair t1 t2 -> (,) <$> readAcSparse t1 (fst val) <*> readAcSparse t2 (snd val)
+-- STEither t1 t2 -> case val of
+-- Left x -> Left <$> readAcSparse t1 x
+-- Right y -> Right <$> readAcSparse t2 y
+-- _ -> error "readAcDense: invalid dense type"
+
+accumAddSparse :: STy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep (D2 b) -> AcM s ()
+accumAddSparse typ prj ref idx val = case (typ, prj) of
+ (STNil, SAPHere) -> return ()
-accumAddSparse :: STy t -> SNat i -> RepAcSparse t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s ()
+ (STPair t1 t2, SAPHere) ->
+ case val of
+ Nothing -> return ()
+ Just (val1, val2) ->
+ AcM $ realiseMaybeSparse ref ((,) <$> newAcSparse t1 SAPHere () val1
+ <*> newAcSparse t2 SAPHere () val2)
+ (\(ac1, ac2) -> do unAcM $ accumAddSparse t1 SAPHere ac1 () val1
+ unAcM $ accumAddSparse t2 SAPHere ac2 () val2)
+ (STPair t1 t2, SAPFst prj') ->
+ AcM $ realiseMaybeSparse ref ((,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2)
+ (\(ac1, _) -> do unAcM $ accumAddSparse t1 prj' ac1 idx val)
+ (STPair t1 t2, SAPSnd prj') ->
+ AcM $ realiseMaybeSparse ref ((,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val)
+ (\(_, ac2) -> do unAcM $ accumAddSparse t2 prj' ac2 idx val)
+
+ (STEither t1 t2, SAPHere) -> _ ref val
+ (STEither t1 _, SAPLeft prj') -> _ ref idx val
+ (STEither _ t2, SAPRight prj') -> _ ref idx val
+
+ (STMaybe t1, SAPHere) -> _ ref val
+ (STMaybe t1, SAPJust prj') -> _ ref idx val
+
+ (STArr _ t1, SAPHere) -> _ ref val
+ (STArr n t, SAPArrIdx prj' _) -> _ ref idx val
+
+ (STScal sty, SAPHere) -> AcM $ case sty of
+ STI32 -> return ()
+ STI64 -> return ()
+ STF32 -> atomicModifyIORef' ref (\x -> (x + val, ()))
+ STF64 -> atomicModifyIORef' ref (\x -> (x + val, ()))
+ STBool -> return ()
+
+ (STAccum{}, _) -> error "Accumulators not allowed in source program"
+
+realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> IO ()) -> IO ()
+realiseMaybeSparse ref makeval modifyval =
+ -- Try modifying what's already in ref. The 'join' makes the snd
+ -- of the function's return value a _continuation_ that is run after
+ -- the critical section ends.
+ join $ atomicModifyIORef' ref $ \ac -> case ac of
+ -- Oops, ref's contents was still sparse. Have to initialise
+ -- it first, then try again.
+ Nothing -> (ac, do val <- makeval
+ join $ atomicModifyIORef' ref $ \ac' -> case ac' of
+ Nothing -> (Just val, return ())
+ Just val' -> (ac', modifyval val'))
+ -- Yep, ref already had a value in there, so we can just add
+ -- val' to it recursively.
+ Just val -> (ac, modifyval val)
+
+{-
accumAddSparse typ SZ ref () val = case typ of
STNil -> return ()
STPair t1 t2 -> AcM $ do
@@ -532,6 +574,7 @@ accumAddDense typ (SS dep) ref idx val = case typ of
(Right ref', Right idx', Right val') -> (Right ref', accumAddSparse t2 dep ref' idx' val')
_ -> error "Mismatched Either in accumAddDense either"
_ -> error "accumAddDense: invalid dense type"
+-}
numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r