diff options
-rw-r--r-- | src/AST.hs | 9 | ||||
-rw-r--r-- | src/Array.hs | 9 | ||||
-rw-r--r-- | src/Interpreter.hs | 171 | ||||
-rw-r--r-- | src/Interpreter/Rep.hs | 2 |
4 files changed, 166 insertions, 25 deletions
@@ -13,6 +13,7 @@ {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} module AST (module AST, module AST.Types, module AST.Weaken) where import Data.Functor.Const @@ -43,8 +44,12 @@ type family AcVal t i where AcVal (TPair a b) (S i) = TEither (AcVal a i) (AcVal b i) AcVal (TEither a b) (S i) = TEither (AcVal a i) (AcVal b i) AcVal (TMaybe t) (S i) = AcVal t i - AcVal (TArr Z t) (S i) = AcVal t i - AcVal (TArr (S n) t) (S i) = AcVal (TArr n t) i + AcVal (TArr n t) (S i) = TPair (Tup (Replicate n TIx)) (AcValArr n t (S i)) + +type family AcValArr n t i where + AcValArr n t Z = TArr n t + AcValArr Z t (S i) = AcVal t i + AcValArr (S n) t (S i) = AcValArr n t i -- General assumption: head of the list (whatever way it is associated) is the -- inner variable / inner array dimension. In pretty printing, the inner diff --git a/src/Array.hs b/src/Array.hs index d7dadbf..6473bf0 100644 --- a/src/Array.hs +++ b/src/Array.hs @@ -59,6 +59,9 @@ arraySize (Array sh _) = shapeSize sh emptyArray :: SNat n -> Array n t emptyArray n = Array (emptyShape n) V.empty +arrayUnit :: t -> Array Z t +arrayUnit x = Array ShNil (V.singleton x) + arrayIndex :: Array n t -> Index n -> t arrayIndex arr@(Array sh _) idx = arrayIndexLinear arr (toLinearIndex sh idx) @@ -80,6 +83,12 @@ arrayGenerateM sh f = arrayGenerateLinM sh (f . fromLinearIndex sh) arrayGenerateLinM :: Monad m => Shape n -> (Int -> m t) -> m (Array n t) arrayGenerateLinM sh f = Array sh <$> V.generateM (shapeSize sh) f +arrayMap :: (a -> b) -> Array n a -> Array n b +arrayMap f arr = arrayGenerateLin (arrayShape arr) (f . arrayIndexLinear arr) + +arrayMapM :: Monad m => (a -> m b) -> Array n a -> m (Array n b) +arrayMapM f arr = arrayGenerateLinM (arrayShape arr) (f . arrayIndexLinear arr) + -- | The Int is the linear index of the value. traverseArray_ :: Monad m => (Int -> t -> m ()) -> Array n t -> m () traverseArray_ f (Array _ v) = evalStateT (traverse_ (\x -> StateT (\i -> (,i+1) <$> f i x)) v) 0 diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 62160aa..d2b8074 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -1,15 +1,17 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE KindSignatures #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiWayIf #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE TupleSections #-} module Interpreter ( interpret, interpret', @@ -17,6 +19,7 @@ module Interpreter ( ) where import Control.Monad (foldM, join) +import Data.Kind (Type) import Data.Int (Int64) import Data.IORef import System.IO.Unsafe (unsafePerformIO) @@ -166,32 +169,141 @@ addD2s typ a b = case typ of withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t) withAccum t _ initval f = AcM $ do - accum <- newAcSparse t initval + accum <- newAcSparse t SZ () initval out <- case f accum of AcM m -> m val <- readAcSparse t accum return (out, val) -newAcSparse :: STy t -> Rep t -> IO (RepAcSparse t) -newAcSparse typ val = case typ of +newAcZero :: STy t -> IO (RepAcSparse t) +newAcZero = \case + STNil -> return () + STPair t1 t2 -> newIORef =<< (,) <$> newAcZero t1 <*> newAcZero t2 + STMaybe _ -> newIORef Nothing + STArr n _ -> newIORef (emptyArray n) + STScal sty -> case sty of + STI32 -> newIORef 0 + STI64 -> newIORef 0 + STF32 -> newIORef 0.0 + STF64 -> newIORef 0.0 + STBool -> error "Accumulator of Bool" + STAccum{} -> error "Nested accumulators" + STEither{} -> error "Bare Either in accumulator" + +-- | Inverted index: the outermost index is at the /outside/ of this list. +data PartialInvIndex n m where + PIIxEnd :: PartialInvIndex m m + PIIxCons :: Int -> PartialInvIndex n m -> PartialInvIndex (S n) m + +-- | Inverted shape: the outermost dimension is at the /outside/ of this list. +data PartialInvShape n m where + PIShEnd :: PartialInvShape m m + PIShCons :: Int -> PartialInvShape n m -> PartialInvShape (S n) m + +-- | Inverted shapey thing: the outermost dimension is at the /outside/ of this list. +data Inverted (f :: Nat -> Type) n where + InvNil :: Inverted f Z + InvCons :: Int -> Inverted f n -> Inverted f (S n) + +type InvShape = Inverted Shape +type InvIndex = Inverted Index + +pattern IIxNil :: () => n ~ Z => InvIndex n +pattern IIxNil = InvNil +pattern IIxCons :: () => S n ~ succn => Int -> InvIndex n -> InvIndex succn +pattern IIxCons i ix = InvCons i ix +{-# COMPLETE IIxNil, IIxCons #-} + +pattern IShNil :: () => n ~ Z => InvShape Z +pattern IShNil = InvNil +pattern IShCons :: () => S n ~ succn => Int -> InvShape n -> InvShape succn +pattern IShCons n sh = InvCons n sh +{-# COMPLETE IShNil, IShCons #-} + +class Shapey f where + shapeyNil :: f Z + shapeyCons :: f n -> Int -> f (S n) + shapeyCase :: f n -> (n ~ Z => r) -> (forall m. n ~ S m => f m -> Int -> r) -> r +instance Shapey Index where + shapeyNil = IxNil + shapeyCons = IxCons + shapeyCase IxNil k0 _ = k0 + shapeyCase (IxCons idx i) _ k1 = k1 idx i +instance Shapey Shape where + shapeyNil = ShNil + shapeyCons = ShCons + shapeyCase ShNil k0 _ = k0 + shapeyCase (ShCons sh n) _ k1 = k1 sh n + +invert :: forall f n. Shapey f => f n -> Inverted f n +invert | Refl <- lemPlusZero @n = flip go shapeyNil + where + go :: forall n' m. f n' -> Inverted f m -> Inverted f (n' + m) + go sh ish = shapeyCase sh + ish + (\sh' n -> case lemPlusSuccRight @n' @m of Refl -> go sh' (InvCons n ish)) + +uninvert :: forall f n. Shapey f => Inverted f n -> f n +uninvert = go shapeyNil + where + go :: forall n' m. f n' -> Inverted f m -> f (n' + m) + go sh InvNil | Refl <- lemPlusZero @n' = sh + go sh (InvCons n (ish :: Inverted f predm)) | Refl <- lemPlusSuccRight @n' @predm = go (shapeyCons sh n) ish + +piindexMatch :: PartialInvIndex n m -> InvIndex n -> Maybe (InvIndex m) +piindexMatch PIIxEnd ix = Just ix +piindexMatch (PIIxCons i pix) (IIxCons i' ix) + | i == i' = piindexMatch pix ix + | otherwise = Nothing + +newAcSparse :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t) +newAcSparse typ SZ () val = case typ of STNil -> return () - STPair{} -> newIORef =<< newAcDense typ val - STMaybe t -> newIORef =<< traverse (newAcDense t) val - STArr _ t -> newIORef =<< traverse (newAcSparse t) val + STPair{} -> newIORef =<< newAcDense typ SZ () val + STMaybe t -> newIORef =<< traverse (newAcDense t SZ ()) val + STArr{} -> newIORef =<< newAcDense typ SZ () val STScal{} -> newIORef val STAccum{} -> error "Nested accumulators" STEither{} -> error "Bare Either in accumulator" +newAcSparse typ (SS dep) idx val = case typ of + STNil -> return () + STPair{} -> newIORef =<< newAcDense typ (SS dep) idx val + STMaybe t -> newIORef =<< Just <$> newAcDense t dep idx val + STArr{} -> newIORef =<< newAcDense typ (SS dep) idx val + STScal{} -> error "Cannot index into scalar" + STAccum{} -> error "Nested accumulators" + STEither{} -> error "Bare Either in accumulator" -newAcDense :: STy t -> Rep t -> IO (RepAcDense t) -newAcDense typ val = case typ of +newAcDense :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t) +newAcDense typ SZ () val = case typ of STNil -> return () - STPair t1 t2 -> (,) <$> newAcSparse t1 (fst val) <*> newAcSparse t2 (snd val) + STPair t1 t2 -> (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val) STEither t1 t2 -> case val of - Left x -> Left <$> newAcSparse t1 x - Right y -> Right <$> newAcSparse t2 y - STMaybe t -> traverse (newAcSparse t) val - STArr _ t -> traverse (newAcSparse t) val + Left x -> Left <$> newAcSparse t1 SZ () x + Right y -> Right <$> newAcSparse t2 SZ () y + STMaybe t -> traverse (newAcSparse t SZ ()) val + STArr _ t -> traverse (newAcSparse t SZ ()) val STScal{} -> return val STAccum{} -> error "Nested accumulators" +newAcDense typ (SS dep) idx val = case typ of + STNil -> return () + STPair{} -> newAcDense typ (SS dep) idx val + STMaybe t -> Just <$> newAcSparse t dep idx val + STArr dim (t :: STy t) -> do + let sh = unTupRepIdx ShNil ShCons dim (fst val) + go (SS dep) dim idx (snd val) $ \arr position -> + arrayGenerateM sh (\i -> case uninvert <$> piindexMatch position (invert i) of + Just i' -> return $ arr `arrayIndex` i' + Nothing -> newAcZero t) + where + go :: SNat i -> SNat n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i) -> (forall m. Array m (RepAcSparse t) -> PartialInvIndex n m -> IO r) -> IO r + go SZ _ () val' k = arrayMapM (newAcSparse t SZ ()) val' >>= \arr -> k arr PIIxEnd + go (SS dep') SZ idx' val' k = newAcSparse t dep' idx' val' >>= \arr -> k (arrayUnit arr) PIIxEnd + go (SS dep') (SS dim') (i, idx') val' k = + go dep' dim' idx' val' $ \arr pish -> + k arr (PIIxCons (fromIntegral @Int64 @Int i) pish) + STScal{} -> error "Cannot index into scalar" + STAccum{} -> error "Nested accumulators" + STEither{} -> error "Bare Either in accumulator" readAcSparse :: STy t -> RepAcSparse t -> IO (Rep t) readAcSparse typ val = case typ of @@ -234,7 +346,7 @@ accumAddSparse typ SZ ref () val = case typ of AcM $ join $ atomicModifyIORef' ref $ \ac -> case ac of -- Oops, ref's contents was still sparse. Have to initialise -- it first, then try again. - Nothing -> (ac, do newac <- newAcDense t val' + Nothing -> (ac, do newac <- newAcDense t SZ () val' join $ atomicModifyIORef' ref $ \ac2 -> case ac2 of Nothing -> (Just newac, return ()) Just ac2' -> bimap Just unAcM (accumAddDense t SZ ac2' () val')) @@ -246,12 +358,12 @@ accumAddSparse typ SZ ref () val = case typ of case (shapeSize (arrayShape refs), shapeSize (arrayShape val)) of (_, 0) -> return () (0, _) -> do - newrefarr <- arrayGenerateLinM (arrayShape val) (\i -> newAcSparse t (arrayIndexLinear val i)) + newrefarr <- arrayGenerateLinM (arrayShape val) (\i -> newAcSparse t SZ () (arrayIndexLinear val i)) join $ atomicModifyIORef' ref $ \refarr -> if shapeSize (arrayShape refarr) == 0 then (newrefarr, return ()) else -- someone was faster than us in initialising the reference! - (refarr, unAcM $ accumAddSparse typ SZ ref () val) -- just try again from the start + (refarr, unAcM $ accumAddSparse typ SZ ref () val) -- just try again from the start (dropping newrefarr for the GC to clean up) _ | arrayShape refs == arrayShape val -> sequence_ [unAcM $ accumAddSparse t SZ (arrayIndexLinear refs i) () (arrayIndexLinear val i) | i <- [0 .. shapeSize (arrayShape val) - 1]] @@ -267,10 +379,25 @@ accumAddSparse typ SZ ref () val = case typ of accumAddSparse typ (SS dep) ref idx val = case typ of STNil -> return () - STPair t1 t2 -> _ ref idx val - STMaybe t -> _ ref idx val + STPair t1 t2 -> AcM $ do + (ref1, ref2) <- readIORef ref + case (idx, val) of + (Left idx', Left val') -> unAcM $ accumAddSparse t1 dep ref1 idx' val' + (Right idx', Right val') -> unAcM $ accumAddSparse t2 dep ref2 idx' val' + _ -> error "Index/value mismatch in pair accumulator add" + STMaybe t -> + AcM $ join $ atomicModifyIORef' ref $ \case + -- Oops, ref's contents was still sparse. Have to initialise + -- it first, then try again. + Nothing -> (Nothing, do newac <- newAcDense t dep idx val + join $ atomicModifyIORef' ref $ \ac2 -> case ac2 of + Nothing -> (Just newac, return ()) + Just ac2' -> bimap Just unAcM (accumAddDense t dep ac2' idx val)) + -- Yep, ref already had a value in there, so we can just add + -- val' to it recursively. + Just ac -> bimap Just unAcM (accumAddDense t dep ac idx val) STArr _ t -> _ ref idx val - STScal{} -> _ ref idx val + STScal{} -> error "Cannot index into scalar" STAccum{} -> error "Nested accumulators" STEither{} -> error "Bare Either in accumulator" diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs index 680196c..aa2fcc9 100644 --- a/src/Interpreter/Rep.hs +++ b/src/Interpreter/Rep.hs @@ -19,7 +19,7 @@ type family Rep t where Rep (TScal sty) = ScalRep sty Rep (TAccum t) = RepAcSparse t --- Mutable, and has an O(1) zero. +-- Mutable, and has a zero. The zero may not be O(1), but RepAcSparse (D2 t) will have an O(1) zero. type family RepAcSparse t where RepAcSparse TNil = () RepAcSparse (TPair a b) = IORef (RepAcDense (TPair a b)) |