summaryrefslogtreecommitdiff
path: root/src/Interpreter.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r--src/Interpreter.hs171
1 files changed, 149 insertions, 22 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 62160aa..d2b8074 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -1,15 +1,17 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
+{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE DerivingStrategies #-}
-{-# LANGUAGE GeneralizedNewtypeDeriving #-}
-{-# LANGUAGE TupleSections #-}
module Interpreter (
interpret,
interpret',
@@ -17,6 +19,7 @@ module Interpreter (
) where
import Control.Monad (foldM, join)
+import Data.Kind (Type)
import Data.Int (Int64)
import Data.IORef
import System.IO.Unsafe (unsafePerformIO)
@@ -166,32 +169,141 @@ addD2s typ a b = case typ of
withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t)
withAccum t _ initval f = AcM $ do
- accum <- newAcSparse t initval
+ accum <- newAcSparse t SZ () initval
out <- case f accum of AcM m -> m
val <- readAcSparse t accum
return (out, val)
-newAcSparse :: STy t -> Rep t -> IO (RepAcSparse t)
-newAcSparse typ val = case typ of
+newAcZero :: STy t -> IO (RepAcSparse t)
+newAcZero = \case
+ STNil -> return ()
+ STPair t1 t2 -> newIORef =<< (,) <$> newAcZero t1 <*> newAcZero t2
+ STMaybe _ -> newIORef Nothing
+ STArr n _ -> newIORef (emptyArray n)
+ STScal sty -> case sty of
+ STI32 -> newIORef 0
+ STI64 -> newIORef 0
+ STF32 -> newIORef 0.0
+ STF64 -> newIORef 0.0
+ STBool -> error "Accumulator of Bool"
+ STAccum{} -> error "Nested accumulators"
+ STEither{} -> error "Bare Either in accumulator"
+
+-- | Inverted index: the outermost index is at the /outside/ of this list.
+data PartialInvIndex n m where
+ PIIxEnd :: PartialInvIndex m m
+ PIIxCons :: Int -> PartialInvIndex n m -> PartialInvIndex (S n) m
+
+-- | Inverted shape: the outermost dimension is at the /outside/ of this list.
+data PartialInvShape n m where
+ PIShEnd :: PartialInvShape m m
+ PIShCons :: Int -> PartialInvShape n m -> PartialInvShape (S n) m
+
+-- | Inverted shapey thing: the outermost dimension is at the /outside/ of this list.
+data Inverted (f :: Nat -> Type) n where
+ InvNil :: Inverted f Z
+ InvCons :: Int -> Inverted f n -> Inverted f (S n)
+
+type InvShape = Inverted Shape
+type InvIndex = Inverted Index
+
+pattern IIxNil :: () => n ~ Z => InvIndex n
+pattern IIxNil = InvNil
+pattern IIxCons :: () => S n ~ succn => Int -> InvIndex n -> InvIndex succn
+pattern IIxCons i ix = InvCons i ix
+{-# COMPLETE IIxNil, IIxCons #-}
+
+pattern IShNil :: () => n ~ Z => InvShape Z
+pattern IShNil = InvNil
+pattern IShCons :: () => S n ~ succn => Int -> InvShape n -> InvShape succn
+pattern IShCons n sh = InvCons n sh
+{-# COMPLETE IShNil, IShCons #-}
+
+class Shapey f where
+ shapeyNil :: f Z
+ shapeyCons :: f n -> Int -> f (S n)
+ shapeyCase :: f n -> (n ~ Z => r) -> (forall m. n ~ S m => f m -> Int -> r) -> r
+instance Shapey Index where
+ shapeyNil = IxNil
+ shapeyCons = IxCons
+ shapeyCase IxNil k0 _ = k0
+ shapeyCase (IxCons idx i) _ k1 = k1 idx i
+instance Shapey Shape where
+ shapeyNil = ShNil
+ shapeyCons = ShCons
+ shapeyCase ShNil k0 _ = k0
+ shapeyCase (ShCons sh n) _ k1 = k1 sh n
+
+invert :: forall f n. Shapey f => f n -> Inverted f n
+invert | Refl <- lemPlusZero @n = flip go shapeyNil
+ where
+ go :: forall n' m. f n' -> Inverted f m -> Inverted f (n' + m)
+ go sh ish = shapeyCase sh
+ ish
+ (\sh' n -> case lemPlusSuccRight @n' @m of Refl -> go sh' (InvCons n ish))
+
+uninvert :: forall f n. Shapey f => Inverted f n -> f n
+uninvert = go shapeyNil
+ where
+ go :: forall n' m. f n' -> Inverted f m -> f (n' + m)
+ go sh InvNil | Refl <- lemPlusZero @n' = sh
+ go sh (InvCons n (ish :: Inverted f predm)) | Refl <- lemPlusSuccRight @n' @predm = go (shapeyCons sh n) ish
+
+piindexMatch :: PartialInvIndex n m -> InvIndex n -> Maybe (InvIndex m)
+piindexMatch PIIxEnd ix = Just ix
+piindexMatch (PIIxCons i pix) (IIxCons i' ix)
+ | i == i' = piindexMatch pix ix
+ | otherwise = Nothing
+
+newAcSparse :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t)
+newAcSparse typ SZ () val = case typ of
STNil -> return ()
- STPair{} -> newIORef =<< newAcDense typ val
- STMaybe t -> newIORef =<< traverse (newAcDense t) val
- STArr _ t -> newIORef =<< traverse (newAcSparse t) val
+ STPair{} -> newIORef =<< newAcDense typ SZ () val
+ STMaybe t -> newIORef =<< traverse (newAcDense t SZ ()) val
+ STArr{} -> newIORef =<< newAcDense typ SZ () val
STScal{} -> newIORef val
STAccum{} -> error "Nested accumulators"
STEither{} -> error "Bare Either in accumulator"
+newAcSparse typ (SS dep) idx val = case typ of
+ STNil -> return ()
+ STPair{} -> newIORef =<< newAcDense typ (SS dep) idx val
+ STMaybe t -> newIORef =<< Just <$> newAcDense t dep idx val
+ STArr{} -> newIORef =<< newAcDense typ (SS dep) idx val
+ STScal{} -> error "Cannot index into scalar"
+ STAccum{} -> error "Nested accumulators"
+ STEither{} -> error "Bare Either in accumulator"
-newAcDense :: STy t -> Rep t -> IO (RepAcDense t)
-newAcDense typ val = case typ of
+newAcDense :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t)
+newAcDense typ SZ () val = case typ of
STNil -> return ()
- STPair t1 t2 -> (,) <$> newAcSparse t1 (fst val) <*> newAcSparse t2 (snd val)
+ STPair t1 t2 -> (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val)
STEither t1 t2 -> case val of
- Left x -> Left <$> newAcSparse t1 x
- Right y -> Right <$> newAcSparse t2 y
- STMaybe t -> traverse (newAcSparse t) val
- STArr _ t -> traverse (newAcSparse t) val
+ Left x -> Left <$> newAcSparse t1 SZ () x
+ Right y -> Right <$> newAcSparse t2 SZ () y
+ STMaybe t -> traverse (newAcSparse t SZ ()) val
+ STArr _ t -> traverse (newAcSparse t SZ ()) val
STScal{} -> return val
STAccum{} -> error "Nested accumulators"
+newAcDense typ (SS dep) idx val = case typ of
+ STNil -> return ()
+ STPair{} -> newAcDense typ (SS dep) idx val
+ STMaybe t -> Just <$> newAcSparse t dep idx val
+ STArr dim (t :: STy t) -> do
+ let sh = unTupRepIdx ShNil ShCons dim (fst val)
+ go (SS dep) dim idx (snd val) $ \arr position ->
+ arrayGenerateM sh (\i -> case uninvert <$> piindexMatch position (invert i) of
+ Just i' -> return $ arr `arrayIndex` i'
+ Nothing -> newAcZero t)
+ where
+ go :: SNat i -> SNat n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i) -> (forall m. Array m (RepAcSparse t) -> PartialInvIndex n m -> IO r) -> IO r
+ go SZ _ () val' k = arrayMapM (newAcSparse t SZ ()) val' >>= \arr -> k arr PIIxEnd
+ go (SS dep') SZ idx' val' k = newAcSparse t dep' idx' val' >>= \arr -> k (arrayUnit arr) PIIxEnd
+ go (SS dep') (SS dim') (i, idx') val' k =
+ go dep' dim' idx' val' $ \arr pish ->
+ k arr (PIIxCons (fromIntegral @Int64 @Int i) pish)
+ STScal{} -> error "Cannot index into scalar"
+ STAccum{} -> error "Nested accumulators"
+ STEither{} -> error "Bare Either in accumulator"
readAcSparse :: STy t -> RepAcSparse t -> IO (Rep t)
readAcSparse typ val = case typ of
@@ -234,7 +346,7 @@ accumAddSparse typ SZ ref () val = case typ of
AcM $ join $ atomicModifyIORef' ref $ \ac -> case ac of
-- Oops, ref's contents was still sparse. Have to initialise
-- it first, then try again.
- Nothing -> (ac, do newac <- newAcDense t val'
+ Nothing -> (ac, do newac <- newAcDense t SZ () val'
join $ atomicModifyIORef' ref $ \ac2 -> case ac2 of
Nothing -> (Just newac, return ())
Just ac2' -> bimap Just unAcM (accumAddDense t SZ ac2' () val'))
@@ -246,12 +358,12 @@ accumAddSparse typ SZ ref () val = case typ of
case (shapeSize (arrayShape refs), shapeSize (arrayShape val)) of
(_, 0) -> return ()
(0, _) -> do
- newrefarr <- arrayGenerateLinM (arrayShape val) (\i -> newAcSparse t (arrayIndexLinear val i))
+ newrefarr <- arrayGenerateLinM (arrayShape val) (\i -> newAcSparse t SZ () (arrayIndexLinear val i))
join $ atomicModifyIORef' ref $ \refarr ->
if shapeSize (arrayShape refarr) == 0
then (newrefarr, return ())
else -- someone was faster than us in initialising the reference!
- (refarr, unAcM $ accumAddSparse typ SZ ref () val) -- just try again from the start
+ (refarr, unAcM $ accumAddSparse typ SZ ref () val) -- just try again from the start (dropping newrefarr for the GC to clean up)
_ | arrayShape refs == arrayShape val ->
sequence_ [unAcM $ accumAddSparse t SZ (arrayIndexLinear refs i) () (arrayIndexLinear val i)
| i <- [0 .. shapeSize (arrayShape val) - 1]]
@@ -267,10 +379,25 @@ accumAddSparse typ SZ ref () val = case typ of
accumAddSparse typ (SS dep) ref idx val = case typ of
STNil -> return ()
- STPair t1 t2 -> _ ref idx val
- STMaybe t -> _ ref idx val
+ STPair t1 t2 -> AcM $ do
+ (ref1, ref2) <- readIORef ref
+ case (idx, val) of
+ (Left idx', Left val') -> unAcM $ accumAddSparse t1 dep ref1 idx' val'
+ (Right idx', Right val') -> unAcM $ accumAddSparse t2 dep ref2 idx' val'
+ _ -> error "Index/value mismatch in pair accumulator add"
+ STMaybe t ->
+ AcM $ join $ atomicModifyIORef' ref $ \case
+ -- Oops, ref's contents was still sparse. Have to initialise
+ -- it first, then try again.
+ Nothing -> (Nothing, do newac <- newAcDense t dep idx val
+ join $ atomicModifyIORef' ref $ \ac2 -> case ac2 of
+ Nothing -> (Just newac, return ())
+ Just ac2' -> bimap Just unAcM (accumAddDense t dep ac2' idx val))
+ -- Yep, ref already had a value in there, so we can just add
+ -- val' to it recursively.
+ Just ac -> bimap Just unAcM (accumAddDense t dep ac idx val)
STArr _ t -> _ ref idx val
- STScal{} -> _ ref idx val
+ STScal{} -> error "Cannot index into scalar"
STAccum{} -> error "Nested accumulators"
STEither{} -> error "Bare Either in accumulator"