diff options
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r-- | src/Interpreter.hs | 171 |
1 files changed, 149 insertions, 22 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 62160aa..d2b8074 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -1,15 +1,17 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE KindSignatures #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiWayIf #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE TupleSections #-} module Interpreter ( interpret, interpret', @@ -17,6 +19,7 @@ module Interpreter ( ) where import Control.Monad (foldM, join) +import Data.Kind (Type) import Data.Int (Int64) import Data.IORef import System.IO.Unsafe (unsafePerformIO) @@ -166,32 +169,141 @@ addD2s typ a b = case typ of withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t) withAccum t _ initval f = AcM $ do - accum <- newAcSparse t initval + accum <- newAcSparse t SZ () initval out <- case f accum of AcM m -> m val <- readAcSparse t accum return (out, val) -newAcSparse :: STy t -> Rep t -> IO (RepAcSparse t) -newAcSparse typ val = case typ of +newAcZero :: STy t -> IO (RepAcSparse t) +newAcZero = \case + STNil -> return () + STPair t1 t2 -> newIORef =<< (,) <$> newAcZero t1 <*> newAcZero t2 + STMaybe _ -> newIORef Nothing + STArr n _ -> newIORef (emptyArray n) + STScal sty -> case sty of + STI32 -> newIORef 0 + STI64 -> newIORef 0 + STF32 -> newIORef 0.0 + STF64 -> newIORef 0.0 + STBool -> error "Accumulator of Bool" + STAccum{} -> error "Nested accumulators" + STEither{} -> error "Bare Either in accumulator" + +-- | Inverted index: the outermost index is at the /outside/ of this list. +data PartialInvIndex n m where + PIIxEnd :: PartialInvIndex m m + PIIxCons :: Int -> PartialInvIndex n m -> PartialInvIndex (S n) m + +-- | Inverted shape: the outermost dimension is at the /outside/ of this list. +data PartialInvShape n m where + PIShEnd :: PartialInvShape m m + PIShCons :: Int -> PartialInvShape n m -> PartialInvShape (S n) m + +-- | Inverted shapey thing: the outermost dimension is at the /outside/ of this list. +data Inverted (f :: Nat -> Type) n where + InvNil :: Inverted f Z + InvCons :: Int -> Inverted f n -> Inverted f (S n) + +type InvShape = Inverted Shape +type InvIndex = Inverted Index + +pattern IIxNil :: () => n ~ Z => InvIndex n +pattern IIxNil = InvNil +pattern IIxCons :: () => S n ~ succn => Int -> InvIndex n -> InvIndex succn +pattern IIxCons i ix = InvCons i ix +{-# COMPLETE IIxNil, IIxCons #-} + +pattern IShNil :: () => n ~ Z => InvShape Z +pattern IShNil = InvNil +pattern IShCons :: () => S n ~ succn => Int -> InvShape n -> InvShape succn +pattern IShCons n sh = InvCons n sh +{-# COMPLETE IShNil, IShCons #-} + +class Shapey f where + shapeyNil :: f Z + shapeyCons :: f n -> Int -> f (S n) + shapeyCase :: f n -> (n ~ Z => r) -> (forall m. n ~ S m => f m -> Int -> r) -> r +instance Shapey Index where + shapeyNil = IxNil + shapeyCons = IxCons + shapeyCase IxNil k0 _ = k0 + shapeyCase (IxCons idx i) _ k1 = k1 idx i +instance Shapey Shape where + shapeyNil = ShNil + shapeyCons = ShCons + shapeyCase ShNil k0 _ = k0 + shapeyCase (ShCons sh n) _ k1 = k1 sh n + +invert :: forall f n. Shapey f => f n -> Inverted f n +invert | Refl <- lemPlusZero @n = flip go shapeyNil + where + go :: forall n' m. f n' -> Inverted f m -> Inverted f (n' + m) + go sh ish = shapeyCase sh + ish + (\sh' n -> case lemPlusSuccRight @n' @m of Refl -> go sh' (InvCons n ish)) + +uninvert :: forall f n. Shapey f => Inverted f n -> f n +uninvert = go shapeyNil + where + go :: forall n' m. f n' -> Inverted f m -> f (n' + m) + go sh InvNil | Refl <- lemPlusZero @n' = sh + go sh (InvCons n (ish :: Inverted f predm)) | Refl <- lemPlusSuccRight @n' @predm = go (shapeyCons sh n) ish + +piindexMatch :: PartialInvIndex n m -> InvIndex n -> Maybe (InvIndex m) +piindexMatch PIIxEnd ix = Just ix +piindexMatch (PIIxCons i pix) (IIxCons i' ix) + | i == i' = piindexMatch pix ix + | otherwise = Nothing + +newAcSparse :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t) +newAcSparse typ SZ () val = case typ of STNil -> return () - STPair{} -> newIORef =<< newAcDense typ val - STMaybe t -> newIORef =<< traverse (newAcDense t) val - STArr _ t -> newIORef =<< traverse (newAcSparse t) val + STPair{} -> newIORef =<< newAcDense typ SZ () val + STMaybe t -> newIORef =<< traverse (newAcDense t SZ ()) val + STArr{} -> newIORef =<< newAcDense typ SZ () val STScal{} -> newIORef val STAccum{} -> error "Nested accumulators" STEither{} -> error "Bare Either in accumulator" +newAcSparse typ (SS dep) idx val = case typ of + STNil -> return () + STPair{} -> newIORef =<< newAcDense typ (SS dep) idx val + STMaybe t -> newIORef =<< Just <$> newAcDense t dep idx val + STArr{} -> newIORef =<< newAcDense typ (SS dep) idx val + STScal{} -> error "Cannot index into scalar" + STAccum{} -> error "Nested accumulators" + STEither{} -> error "Bare Either in accumulator" -newAcDense :: STy t -> Rep t -> IO (RepAcDense t) -newAcDense typ val = case typ of +newAcDense :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t) +newAcDense typ SZ () val = case typ of STNil -> return () - STPair t1 t2 -> (,) <$> newAcSparse t1 (fst val) <*> newAcSparse t2 (snd val) + STPair t1 t2 -> (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val) STEither t1 t2 -> case val of - Left x -> Left <$> newAcSparse t1 x - Right y -> Right <$> newAcSparse t2 y - STMaybe t -> traverse (newAcSparse t) val - STArr _ t -> traverse (newAcSparse t) val + Left x -> Left <$> newAcSparse t1 SZ () x + Right y -> Right <$> newAcSparse t2 SZ () y + STMaybe t -> traverse (newAcSparse t SZ ()) val + STArr _ t -> traverse (newAcSparse t SZ ()) val STScal{} -> return val STAccum{} -> error "Nested accumulators" +newAcDense typ (SS dep) idx val = case typ of + STNil -> return () + STPair{} -> newAcDense typ (SS dep) idx val + STMaybe t -> Just <$> newAcSparse t dep idx val + STArr dim (t :: STy t) -> do + let sh = unTupRepIdx ShNil ShCons dim (fst val) + go (SS dep) dim idx (snd val) $ \arr position -> + arrayGenerateM sh (\i -> case uninvert <$> piindexMatch position (invert i) of + Just i' -> return $ arr `arrayIndex` i' + Nothing -> newAcZero t) + where + go :: SNat i -> SNat n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i) -> (forall m. Array m (RepAcSparse t) -> PartialInvIndex n m -> IO r) -> IO r + go SZ _ () val' k = arrayMapM (newAcSparse t SZ ()) val' >>= \arr -> k arr PIIxEnd + go (SS dep') SZ idx' val' k = newAcSparse t dep' idx' val' >>= \arr -> k (arrayUnit arr) PIIxEnd + go (SS dep') (SS dim') (i, idx') val' k = + go dep' dim' idx' val' $ \arr pish -> + k arr (PIIxCons (fromIntegral @Int64 @Int i) pish) + STScal{} -> error "Cannot index into scalar" + STAccum{} -> error "Nested accumulators" + STEither{} -> error "Bare Either in accumulator" readAcSparse :: STy t -> RepAcSparse t -> IO (Rep t) readAcSparse typ val = case typ of @@ -234,7 +346,7 @@ accumAddSparse typ SZ ref () val = case typ of AcM $ join $ atomicModifyIORef' ref $ \ac -> case ac of -- Oops, ref's contents was still sparse. Have to initialise -- it first, then try again. - Nothing -> (ac, do newac <- newAcDense t val' + Nothing -> (ac, do newac <- newAcDense t SZ () val' join $ atomicModifyIORef' ref $ \ac2 -> case ac2 of Nothing -> (Just newac, return ()) Just ac2' -> bimap Just unAcM (accumAddDense t SZ ac2' () val')) @@ -246,12 +358,12 @@ accumAddSparse typ SZ ref () val = case typ of case (shapeSize (arrayShape refs), shapeSize (arrayShape val)) of (_, 0) -> return () (0, _) -> do - newrefarr <- arrayGenerateLinM (arrayShape val) (\i -> newAcSparse t (arrayIndexLinear val i)) + newrefarr <- arrayGenerateLinM (arrayShape val) (\i -> newAcSparse t SZ () (arrayIndexLinear val i)) join $ atomicModifyIORef' ref $ \refarr -> if shapeSize (arrayShape refarr) == 0 then (newrefarr, return ()) else -- someone was faster than us in initialising the reference! - (refarr, unAcM $ accumAddSparse typ SZ ref () val) -- just try again from the start + (refarr, unAcM $ accumAddSparse typ SZ ref () val) -- just try again from the start (dropping newrefarr for the GC to clean up) _ | arrayShape refs == arrayShape val -> sequence_ [unAcM $ accumAddSparse t SZ (arrayIndexLinear refs i) () (arrayIndexLinear val i) | i <- [0 .. shapeSize (arrayShape val) - 1]] @@ -267,10 +379,25 @@ accumAddSparse typ SZ ref () val = case typ of accumAddSparse typ (SS dep) ref idx val = case typ of STNil -> return () - STPair t1 t2 -> _ ref idx val - STMaybe t -> _ ref idx val + STPair t1 t2 -> AcM $ do + (ref1, ref2) <- readIORef ref + case (idx, val) of + (Left idx', Left val') -> unAcM $ accumAddSparse t1 dep ref1 idx' val' + (Right idx', Right val') -> unAcM $ accumAddSparse t2 dep ref2 idx' val' + _ -> error "Index/value mismatch in pair accumulator add" + STMaybe t -> + AcM $ join $ atomicModifyIORef' ref $ \case + -- Oops, ref's contents was still sparse. Have to initialise + -- it first, then try again. + Nothing -> (Nothing, do newac <- newAcDense t dep idx val + join $ atomicModifyIORef' ref $ \ac2 -> case ac2 of + Nothing -> (Just newac, return ()) + Just ac2' -> bimap Just unAcM (accumAddDense t dep ac2' idx val)) + -- Yep, ref already had a value in there, so we can just add + -- val' to it recursively. + Just ac -> bimap Just unAcM (accumAddDense t dep ac idx val) STArr _ t -> _ ref idx val - STScal{} -> _ ref idx val + STScal{} -> error "Cannot index into scalar" STAccum{} -> error "Nested accumulators" STEither{} -> error "Bare Either in accumulator" |