summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/AST.hs9
-rw-r--r--src/Array.hs9
-rw-r--r--src/Interpreter.hs171
-rw-r--r--src/Interpreter/Rep.hs2
4 files changed, 166 insertions, 25 deletions
diff --git a/src/AST.hs b/src/AST.hs
index ed2039b..8dfea68 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -13,6 +13,7 @@
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
module AST (module AST, module AST.Types, module AST.Weaken) where
import Data.Functor.Const
@@ -43,8 +44,12 @@ type family AcVal t i where
AcVal (TPair a b) (S i) = TEither (AcVal a i) (AcVal b i)
AcVal (TEither a b) (S i) = TEither (AcVal a i) (AcVal b i)
AcVal (TMaybe t) (S i) = AcVal t i
- AcVal (TArr Z t) (S i) = AcVal t i
- AcVal (TArr (S n) t) (S i) = AcVal (TArr n t) i
+ AcVal (TArr n t) (S i) = TPair (Tup (Replicate n TIx)) (AcValArr n t (S i))
+
+type family AcValArr n t i where
+ AcValArr n t Z = TArr n t
+ AcValArr Z t (S i) = AcVal t i
+ AcValArr (S n) t (S i) = AcValArr n t i
-- General assumption: head of the list (whatever way it is associated) is the
-- inner variable / inner array dimension. In pretty printing, the inner
diff --git a/src/Array.hs b/src/Array.hs
index d7dadbf..6473bf0 100644
--- a/src/Array.hs
+++ b/src/Array.hs
@@ -59,6 +59,9 @@ arraySize (Array sh _) = shapeSize sh
emptyArray :: SNat n -> Array n t
emptyArray n = Array (emptyShape n) V.empty
+arrayUnit :: t -> Array Z t
+arrayUnit x = Array ShNil (V.singleton x)
+
arrayIndex :: Array n t -> Index n -> t
arrayIndex arr@(Array sh _) idx = arrayIndexLinear arr (toLinearIndex sh idx)
@@ -80,6 +83,12 @@ arrayGenerateM sh f = arrayGenerateLinM sh (f . fromLinearIndex sh)
arrayGenerateLinM :: Monad m => Shape n -> (Int -> m t) -> m (Array n t)
arrayGenerateLinM sh f = Array sh <$> V.generateM (shapeSize sh) f
+arrayMap :: (a -> b) -> Array n a -> Array n b
+arrayMap f arr = arrayGenerateLin (arrayShape arr) (f . arrayIndexLinear arr)
+
+arrayMapM :: Monad m => (a -> m b) -> Array n a -> m (Array n b)
+arrayMapM f arr = arrayGenerateLinM (arrayShape arr) (f . arrayIndexLinear arr)
+
-- | The Int is the linear index of the value.
traverseArray_ :: Monad m => (Int -> t -> m ()) -> Array n t -> m ()
traverseArray_ f (Array _ v) = evalStateT (traverse_ (\x -> StateT (\i -> (,i+1) <$> f i x)) v) 0
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 62160aa..d2b8074 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -1,15 +1,17 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
+{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE DerivingStrategies #-}
-{-# LANGUAGE GeneralizedNewtypeDeriving #-}
-{-# LANGUAGE TupleSections #-}
module Interpreter (
interpret,
interpret',
@@ -17,6 +19,7 @@ module Interpreter (
) where
import Control.Monad (foldM, join)
+import Data.Kind (Type)
import Data.Int (Int64)
import Data.IORef
import System.IO.Unsafe (unsafePerformIO)
@@ -166,32 +169,141 @@ addD2s typ a b = case typ of
withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t)
withAccum t _ initval f = AcM $ do
- accum <- newAcSparse t initval
+ accum <- newAcSparse t SZ () initval
out <- case f accum of AcM m -> m
val <- readAcSparse t accum
return (out, val)
-newAcSparse :: STy t -> Rep t -> IO (RepAcSparse t)
-newAcSparse typ val = case typ of
+newAcZero :: STy t -> IO (RepAcSparse t)
+newAcZero = \case
+ STNil -> return ()
+ STPair t1 t2 -> newIORef =<< (,) <$> newAcZero t1 <*> newAcZero t2
+ STMaybe _ -> newIORef Nothing
+ STArr n _ -> newIORef (emptyArray n)
+ STScal sty -> case sty of
+ STI32 -> newIORef 0
+ STI64 -> newIORef 0
+ STF32 -> newIORef 0.0
+ STF64 -> newIORef 0.0
+ STBool -> error "Accumulator of Bool"
+ STAccum{} -> error "Nested accumulators"
+ STEither{} -> error "Bare Either in accumulator"
+
+-- | Inverted index: the outermost index is at the /outside/ of this list.
+data PartialInvIndex n m where
+ PIIxEnd :: PartialInvIndex m m
+ PIIxCons :: Int -> PartialInvIndex n m -> PartialInvIndex (S n) m
+
+-- | Inverted shape: the outermost dimension is at the /outside/ of this list.
+data PartialInvShape n m where
+ PIShEnd :: PartialInvShape m m
+ PIShCons :: Int -> PartialInvShape n m -> PartialInvShape (S n) m
+
+-- | Inverted shapey thing: the outermost dimension is at the /outside/ of this list.
+data Inverted (f :: Nat -> Type) n where
+ InvNil :: Inverted f Z
+ InvCons :: Int -> Inverted f n -> Inverted f (S n)
+
+type InvShape = Inverted Shape
+type InvIndex = Inverted Index
+
+pattern IIxNil :: () => n ~ Z => InvIndex n
+pattern IIxNil = InvNil
+pattern IIxCons :: () => S n ~ succn => Int -> InvIndex n -> InvIndex succn
+pattern IIxCons i ix = InvCons i ix
+{-# COMPLETE IIxNil, IIxCons #-}
+
+pattern IShNil :: () => n ~ Z => InvShape Z
+pattern IShNil = InvNil
+pattern IShCons :: () => S n ~ succn => Int -> InvShape n -> InvShape succn
+pattern IShCons n sh = InvCons n sh
+{-# COMPLETE IShNil, IShCons #-}
+
+class Shapey f where
+ shapeyNil :: f Z
+ shapeyCons :: f n -> Int -> f (S n)
+ shapeyCase :: f n -> (n ~ Z => r) -> (forall m. n ~ S m => f m -> Int -> r) -> r
+instance Shapey Index where
+ shapeyNil = IxNil
+ shapeyCons = IxCons
+ shapeyCase IxNil k0 _ = k0
+ shapeyCase (IxCons idx i) _ k1 = k1 idx i
+instance Shapey Shape where
+ shapeyNil = ShNil
+ shapeyCons = ShCons
+ shapeyCase ShNil k0 _ = k0
+ shapeyCase (ShCons sh n) _ k1 = k1 sh n
+
+invert :: forall f n. Shapey f => f n -> Inverted f n
+invert | Refl <- lemPlusZero @n = flip go shapeyNil
+ where
+ go :: forall n' m. f n' -> Inverted f m -> Inverted f (n' + m)
+ go sh ish = shapeyCase sh
+ ish
+ (\sh' n -> case lemPlusSuccRight @n' @m of Refl -> go sh' (InvCons n ish))
+
+uninvert :: forall f n. Shapey f => Inverted f n -> f n
+uninvert = go shapeyNil
+ where
+ go :: forall n' m. f n' -> Inverted f m -> f (n' + m)
+ go sh InvNil | Refl <- lemPlusZero @n' = sh
+ go sh (InvCons n (ish :: Inverted f predm)) | Refl <- lemPlusSuccRight @n' @predm = go (shapeyCons sh n) ish
+
+piindexMatch :: PartialInvIndex n m -> InvIndex n -> Maybe (InvIndex m)
+piindexMatch PIIxEnd ix = Just ix
+piindexMatch (PIIxCons i pix) (IIxCons i' ix)
+ | i == i' = piindexMatch pix ix
+ | otherwise = Nothing
+
+newAcSparse :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t)
+newAcSparse typ SZ () val = case typ of
STNil -> return ()
- STPair{} -> newIORef =<< newAcDense typ val
- STMaybe t -> newIORef =<< traverse (newAcDense t) val
- STArr _ t -> newIORef =<< traverse (newAcSparse t) val
+ STPair{} -> newIORef =<< newAcDense typ SZ () val
+ STMaybe t -> newIORef =<< traverse (newAcDense t SZ ()) val
+ STArr{} -> newIORef =<< newAcDense typ SZ () val
STScal{} -> newIORef val
STAccum{} -> error "Nested accumulators"
STEither{} -> error "Bare Either in accumulator"
+newAcSparse typ (SS dep) idx val = case typ of
+ STNil -> return ()
+ STPair{} -> newIORef =<< newAcDense typ (SS dep) idx val
+ STMaybe t -> newIORef =<< Just <$> newAcDense t dep idx val
+ STArr{} -> newIORef =<< newAcDense typ (SS dep) idx val
+ STScal{} -> error "Cannot index into scalar"
+ STAccum{} -> error "Nested accumulators"
+ STEither{} -> error "Bare Either in accumulator"
-newAcDense :: STy t -> Rep t -> IO (RepAcDense t)
-newAcDense typ val = case typ of
+newAcDense :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t)
+newAcDense typ SZ () val = case typ of
STNil -> return ()
- STPair t1 t2 -> (,) <$> newAcSparse t1 (fst val) <*> newAcSparse t2 (snd val)
+ STPair t1 t2 -> (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val)
STEither t1 t2 -> case val of
- Left x -> Left <$> newAcSparse t1 x
- Right y -> Right <$> newAcSparse t2 y
- STMaybe t -> traverse (newAcSparse t) val
- STArr _ t -> traverse (newAcSparse t) val
+ Left x -> Left <$> newAcSparse t1 SZ () x
+ Right y -> Right <$> newAcSparse t2 SZ () y
+ STMaybe t -> traverse (newAcSparse t SZ ()) val
+ STArr _ t -> traverse (newAcSparse t SZ ()) val
STScal{} -> return val
STAccum{} -> error "Nested accumulators"
+newAcDense typ (SS dep) idx val = case typ of
+ STNil -> return ()
+ STPair{} -> newAcDense typ (SS dep) idx val
+ STMaybe t -> Just <$> newAcSparse t dep idx val
+ STArr dim (t :: STy t) -> do
+ let sh = unTupRepIdx ShNil ShCons dim (fst val)
+ go (SS dep) dim idx (snd val) $ \arr position ->
+ arrayGenerateM sh (\i -> case uninvert <$> piindexMatch position (invert i) of
+ Just i' -> return $ arr `arrayIndex` i'
+ Nothing -> newAcZero t)
+ where
+ go :: SNat i -> SNat n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i) -> (forall m. Array m (RepAcSparse t) -> PartialInvIndex n m -> IO r) -> IO r
+ go SZ _ () val' k = arrayMapM (newAcSparse t SZ ()) val' >>= \arr -> k arr PIIxEnd
+ go (SS dep') SZ idx' val' k = newAcSparse t dep' idx' val' >>= \arr -> k (arrayUnit arr) PIIxEnd
+ go (SS dep') (SS dim') (i, idx') val' k =
+ go dep' dim' idx' val' $ \arr pish ->
+ k arr (PIIxCons (fromIntegral @Int64 @Int i) pish)
+ STScal{} -> error "Cannot index into scalar"
+ STAccum{} -> error "Nested accumulators"
+ STEither{} -> error "Bare Either in accumulator"
readAcSparse :: STy t -> RepAcSparse t -> IO (Rep t)
readAcSparse typ val = case typ of
@@ -234,7 +346,7 @@ accumAddSparse typ SZ ref () val = case typ of
AcM $ join $ atomicModifyIORef' ref $ \ac -> case ac of
-- Oops, ref's contents was still sparse. Have to initialise
-- it first, then try again.
- Nothing -> (ac, do newac <- newAcDense t val'
+ Nothing -> (ac, do newac <- newAcDense t SZ () val'
join $ atomicModifyIORef' ref $ \ac2 -> case ac2 of
Nothing -> (Just newac, return ())
Just ac2' -> bimap Just unAcM (accumAddDense t SZ ac2' () val'))
@@ -246,12 +358,12 @@ accumAddSparse typ SZ ref () val = case typ of
case (shapeSize (arrayShape refs), shapeSize (arrayShape val)) of
(_, 0) -> return ()
(0, _) -> do
- newrefarr <- arrayGenerateLinM (arrayShape val) (\i -> newAcSparse t (arrayIndexLinear val i))
+ newrefarr <- arrayGenerateLinM (arrayShape val) (\i -> newAcSparse t SZ () (arrayIndexLinear val i))
join $ atomicModifyIORef' ref $ \refarr ->
if shapeSize (arrayShape refarr) == 0
then (newrefarr, return ())
else -- someone was faster than us in initialising the reference!
- (refarr, unAcM $ accumAddSparse typ SZ ref () val) -- just try again from the start
+ (refarr, unAcM $ accumAddSparse typ SZ ref () val) -- just try again from the start (dropping newrefarr for the GC to clean up)
_ | arrayShape refs == arrayShape val ->
sequence_ [unAcM $ accumAddSparse t SZ (arrayIndexLinear refs i) () (arrayIndexLinear val i)
| i <- [0 .. shapeSize (arrayShape val) - 1]]
@@ -267,10 +379,25 @@ accumAddSparse typ SZ ref () val = case typ of
accumAddSparse typ (SS dep) ref idx val = case typ of
STNil -> return ()
- STPair t1 t2 -> _ ref idx val
- STMaybe t -> _ ref idx val
+ STPair t1 t2 -> AcM $ do
+ (ref1, ref2) <- readIORef ref
+ case (idx, val) of
+ (Left idx', Left val') -> unAcM $ accumAddSparse t1 dep ref1 idx' val'
+ (Right idx', Right val') -> unAcM $ accumAddSparse t2 dep ref2 idx' val'
+ _ -> error "Index/value mismatch in pair accumulator add"
+ STMaybe t ->
+ AcM $ join $ atomicModifyIORef' ref $ \case
+ -- Oops, ref's contents was still sparse. Have to initialise
+ -- it first, then try again.
+ Nothing -> (Nothing, do newac <- newAcDense t dep idx val
+ join $ atomicModifyIORef' ref $ \ac2 -> case ac2 of
+ Nothing -> (Just newac, return ())
+ Just ac2' -> bimap Just unAcM (accumAddDense t dep ac2' idx val))
+ -- Yep, ref already had a value in there, so we can just add
+ -- val' to it recursively.
+ Just ac -> bimap Just unAcM (accumAddDense t dep ac idx val)
STArr _ t -> _ ref idx val
- STScal{} -> _ ref idx val
+ STScal{} -> error "Cannot index into scalar"
STAccum{} -> error "Nested accumulators"
STEither{} -> error "Bare Either in accumulator"
diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs
index 680196c..aa2fcc9 100644
--- a/src/Interpreter/Rep.hs
+++ b/src/Interpreter/Rep.hs
@@ -19,7 +19,7 @@ type family Rep t where
Rep (TScal sty) = ScalRep sty
Rep (TAccum t) = RepAcSparse t
--- Mutable, and has an O(1) zero.
+-- Mutable, and has a zero. The zero may not be O(1), but RepAcSparse (D2 t) will have an O(1) zero.
type family RepAcSparse t where
RepAcSparse TNil = ()
RepAcSparse (TPair a b) = IORef (RepAcDense (TPair a b))