summaryrefslogtreecommitdiff
path: root/src/Interpreter.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-09-22 23:11:37 +0200
committerTom Smeding <tom@tomsmeding.com>2024-09-22 23:11:37 +0200
commit7bc10684870e2249efbdcdddb4950f52d8527699 (patch)
tree8ff3ffd5966ead77edd9f66b61df9e92dc237a47 /src/Interpreter.hs
parent1d14fbd9665b25aef5672e0652d5e7e27bcd4908 (diff)
Interpreter typechecks, at the cost of compositionality of RepAc
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r--src/Interpreter.hs239
1 files changed, 46 insertions, 193 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 8fb4a78..01d15f1 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -29,7 +29,7 @@ import AST
import CHAD.Types
import Data
import Interpreter.Rep
-import Data.Bifunctor (first, bimap)
+import Data.Bifunctor (bimap)
newtype AcM s a = AcM { unAcM :: IO a }
@@ -194,11 +194,6 @@ data PartialInvIndex n m where
PIIxEnd :: PartialInvIndex m m
PIIxCons :: Int -> PartialInvIndex n m -> PartialInvIndex (S n) m
--- | Inverted shape: the outermost dimension is at the /outside/ of this list.
-data PartialInvShape n m where
- PIShEnd :: PartialInvShape m m
- PIShCons :: Int -> PartialInvShape n m -> PartialInvShape (S n) m
-
-- | Inverted shapey thing: the outermost dimension is at the /outside/ of this list.
data Inverted (f :: Nat -> Type) n where
InvNil :: Inverted f Z
@@ -269,52 +264,54 @@ piindexConcat (PIIxCons i pix) ix = IIxCons i (piindexConcat pix ix)
newAcSparse :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t)
newAcSparse typ SZ () val = case typ of
STNil -> return ()
- STPair{} -> newIORef =<< newAcDense typ SZ () val
+ STPair t1 t2 -> newIORef =<< (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val)
STMaybe t -> newIORef =<< traverse (newAcDense t SZ ()) val
- STArr{} -> newIORef =<< newAcDense typ SZ () val
+ STArr _ t -> newIORef =<< traverse (newAcSparse t SZ ()) val
STScal{} -> newIORef val
STAccum{} -> error "Nested accumulators"
STEither{} -> error "Bare Either in accumulator"
newAcSparse typ (SS dep) idx val = case typ of
STNil -> return ()
- STPair{} -> newIORef =<< newAcDense typ (SS dep) idx val
+ STPair t1 t2 -> newIORef =<< case (idx, val) of
+ (Left idx', Left val') -> (,) <$> newAcSparse t1 dep idx' val' <*> newAcZero t2
+ (Right idx', Right val') -> (,) <$> newAcZero t1 <*> newAcSparse t2 dep idx' val'
+ _ -> error "Index/value mismatch in newAc pair"
STMaybe t -> newIORef =<< Just <$> newAcDense t dep idx val
- STArr{} -> newIORef =<< newAcDense typ (SS dep) idx val
+ STArr dim (t :: STy t) -> newIORef =<< newAcArray dim t (SS dep) idx val
STScal{} -> error "Cannot index into scalar"
STAccum{} -> error "Nested accumulators"
STEither{} -> error "Bare Either in accumulator"
+newAcArray :: SNat n -> STy t -> SNat i -> Rep (AcIdx (TArr n t) i) -> Rep (AcVal (TArr n t) i) -> IO (Array n (RepAcSparse t))
+newAcArray _ t SZ _ val =
+ traverse (newAcSparse t SZ ()) val
+newAcArray dim (t :: STy t) dep@SS{} idx val = do
+ let sh = unTupRepIdx ShNil ShCons dim (fst val)
+ go dep dim idx (snd val) $ \arr position ->
+ arrayGenerateM sh (\i -> case uninvert <$> piindexMatch position (invert i) of
+ Just i' -> return $ arr `arrayIndex` i'
+ Nothing -> newAcZero t)
+ where
+ go :: SNat i -> SNat n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i) -> (forall m. Array m (RepAcSparse t) -> PartialInvIndex n m -> IO r) -> IO r
+ go SZ _ () val' k = arrayMapM (newAcSparse t SZ ()) val' >>= \arr -> k arr PIIxEnd
+ go (SS dep') SZ idx' val' k = newAcSparse t dep' idx' val' >>= \arr -> k (arrayUnit arr) PIIxEnd
+ go (SS dep') (SS dim') (i, idx') val' k =
+ go dep' dim' idx' val' $ \arr pish ->
+ k arr (PIIxCons (fromIntegral @Int64 @Int i) pish)
+
newAcDense :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t)
newAcDense typ SZ () val = case typ of
- STNil -> return ()
- STPair t1 t2 -> (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val)
STEither t1 t2 -> case val of
Left x -> Left <$> newAcSparse t1 SZ () x
Right y -> Right <$> newAcSparse t2 SZ () y
- STMaybe t -> traverse (newAcSparse t SZ ()) val
- STArr _ t -> traverse (newAcSparse t SZ ()) val
- STScal{} -> return val
- STAccum{} -> error "Nested accumulators"
+ _ -> error "newAcDense: invalid dense type"
newAcDense typ (SS dep) idx val = case typ of
- STNil -> return ()
- STPair{} -> newAcDense typ (SS dep) idx val
- STMaybe t -> Just <$> newAcSparse t dep idx val
- STArr dim (t :: STy t) -> do
- let sh = unTupRepIdx ShNil ShCons dim (fst val)
- go (SS dep) dim idx (snd val) $ \arr position ->
- arrayGenerateM sh (\i -> case uninvert <$> piindexMatch position (invert i) of
- Just i' -> return $ arr `arrayIndex` i'
- Nothing -> newAcZero t)
- where
- go :: SNat i -> SNat n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i) -> (forall m. Array m (RepAcSparse t) -> PartialInvIndex n m -> IO r) -> IO r
- go SZ _ () val' k = arrayMapM (newAcSparse t SZ ()) val' >>= \arr -> k arr PIIxEnd
- go (SS dep') SZ idx' val' k = newAcSparse t dep' idx' val' >>= \arr -> k (arrayUnit arr) PIIxEnd
- go (SS dep') (SS dim') (i, idx') val' k =
- go dep' dim' idx' val' $ \arr pish ->
- k arr (PIIxCons (fromIntegral @Int64 @Int i) pish)
- STScal{} -> error "Cannot index into scalar"
- STAccum{} -> error "Nested accumulators"
- STEither{} -> error "Bare Either in accumulator"
+ STEither t1 t2 ->
+ case (idx, val) of
+ (Left idx', Left val') -> Left <$> newAcSparse t1 dep idx' val'
+ (Right idx', Right val') -> Right <$> newAcSparse t2 dep idx' val'
+ _ -> error "Index/value mismatch in newAc either"
+ _ -> error "newAcDense: invalid dense type"
readAcSparse :: STy t -> RepAcSparse t -> IO (Rep t)
readAcSparse typ val = case typ of
@@ -330,15 +327,10 @@ readAcSparse typ val = case typ of
readAcDense :: STy t -> RepAcDense t -> IO (Rep t)
readAcDense typ val = case typ of
- STNil -> return ()
- STPair t1 t2 -> (,) <$> readAcSparse t1 (fst val) <*> readAcSparse t2 (snd val)
STEither t1 t2 -> case val of
Left x -> Left <$> readAcSparse t1 x
Right y -> Right <$> readAcSparse t2 y
- STMaybe t -> traverse (readAcSparse t) val
- STArr _ t -> traverse (readAcSparse t) val
- STScal{} -> return val
- STAccum{} -> error "Nested accumulators"
+ _ -> error "readAcDense: invalid dense type"
accumAddSparse :: STy t -> SNat i -> RepAcSparse t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s ()
accumAddSparse typ SZ ref () val = case typ of
@@ -369,7 +361,7 @@ accumAddSparse typ SZ ref () val = case typ of
case (shapeSize (arrayShape refs), shapeSize (arrayShape val)) of
(_, 0) -> return ()
(0, _) -> do
- newrefarr <- newAcDense typ SZ () val
+ newrefarr <- traverse (newAcSparse t SZ ()) val
join $ atomicModifyIORef' ref $ \refarr ->
if shapeSize (arrayShape refarr) == 0
then (newrefarr, return ())
@@ -410,7 +402,7 @@ accumAddSparse typ (SS dep) ref idx val = case typ of
STArr dim (t :: STy t) -> AcM $ do
refs <- readIORef ref
if shapeSize (arrayShape refs) == 0
- then do newrefarr <- newAcDense typ (SS dep) idx val
+ then do newrefarr <- newAcArray dim t (SS dep) idx val
join $ atomicModifyIORef' ref $ \refarr ->
if shapeSize (arrayShape refarr) == 0
then (newrefarr, return ())
@@ -440,160 +432,21 @@ accumAddSparse typ (SS dep) ref idx val = case typ of
accumAddDense :: forall t i s. STy t -> SNat i -> RepAcDense t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAcDense t, AcM s ())
accumAddDense typ SZ ref () val = case typ of
- STNil -> ((), return ())
- STPair t1 t2 ->
- (ref, do accumAddSparse t1 SZ (fst ref) () (fst val)
- accumAddSparse t2 SZ (snd ref) () (snd val))
- STMaybe t ->
+ STEither t1 t2 ->
case (ref, val) of
- (_, Nothing) -> (ref, return ())
- (Nothing, Just val') -> _ val'
- (Just ref', Just val') -> _ ref' val'
- STArr _ t -> AcM $ do
- refs <- readIORef ref
- case (shapeSize (arrayShape refs), shapeSize (arrayShape val)) of
- (_, 0) -> return ()
- (0, _) -> do
- newrefarr <- newAcDense typ SZ () val
- join $ atomicModifyIORef' ref $ \refarr ->
- if shapeSize (arrayShape refarr) == 0
- then (newrefarr, return ())
- else -- someone was faster than us in initialising the reference!
- (refarr, unAcM $ accumAddSparse typ SZ ref () val) -- just try again from the start (dropping newrefarr for the GC to clean up)
- _ | arrayShape refs == arrayShape val ->
- sequence_ [unAcM $ accumAddSparse t SZ (arrayIndexLinear refs i) () (arrayIndexLinear val i)
- | i <- [0 .. shapeSize (arrayShape val) - 1]]
- | otherwise -> error "Array shape mismatch in accum add"
- STScal sty -> AcM $ case sty of
- STI32 -> atomicModifyIORef' ref (\x -> (x + val, ()))
- STI64 -> atomicModifyIORef' ref (\x -> (x + val, ()))
- STF32 -> atomicModifyIORef' ref (\x -> (x + val, ()))
- STF64 -> atomicModifyIORef' ref (\x -> (x + val, ()))
- STBool -> error "Accumulator of Bool"
- STAccum{} -> error "Nested accumulators"
- STEither{} -> error "Bare Either in accumulator"
+ (Left ref', Left val') -> (ref, accumAddSparse t1 SZ ref' () val')
+ (Right ref', Right val') -> (ref, accumAddSparse t2 SZ ref' () val')
+ _ -> error "Mismatched Either in accumAdd"
+ _ -> error "accumAddDense: invalid dense type"
accumAddDense typ (SS dep) ref idx val = case typ of
- STNil -> return ()
- STPair t1 t2 -> AcM $ do
- (ref1, ref2) <- readIORef ref
- case (idx, val) of
- (Left idx', Left val') -> unAcM $ accumAddSparse t1 dep ref1 idx' val'
- (Right idx', Right val') -> unAcM $ accumAddSparse t2 dep ref2 idx' val'
- _ -> error "Index/value mismatch in pair accumulator add"
- STMaybe t ->
- AcM $ join $ atomicModifyIORef' ref $ \case
- -- Oops, ref's contents was still sparse. Have to initialise
- -- it first, then try again.
- Nothing -> (Nothing, do newac <- newAcDense t dep idx val
- join $ atomicModifyIORef' ref $ \ac2 -> case ac2 of
- Nothing -> (Just newac, return ())
- Just ac2' -> bimap Just unAcM (accumAddDense t dep ac2' idx val))
- -- Yep, ref already had a value in there, so we can just add
- -- val' to it recursively.
- Just ac -> bimap Just unAcM (accumAddDense t dep ac idx val)
- STArr dim (t :: STy t) -> AcM $ do
- refs <- readIORef ref
- if shapeSize (arrayShape refs) == 0
- then do newrefarr <- newAcDense typ (SS dep) idx val
- join $ atomicModifyIORef' ref $ \refarr ->
- if shapeSize (arrayShape refarr) == 0
- then (newrefarr, return ())
- else -- someone was faster than us in initialising the reference!
- (refarr, unAcM $ accumAddSparse typ (SS dep) ref idx val) -- just try again from the start (dropping newrefarr for the GC to clean up)
- else do let sh = unTupRepIdx ShNil ShCons dim (fst val)
- go (SS dep) (invert sh) idx (snd val)
- (\j index idxj valj -> unAcM $ accumAddSparse t j (refs `arrayIndex` index) idxj valj)
- (\piix subsh val' -> unAcM $ sequence_
- [accumAddSparse t SZ (refs `arrayIndex` uninvert (piindexConcat piix (invert subix)))
- () (val' `arrayIndex` subix)
- | subix <- enumShape subsh])
- where
- go :: SNat i -> InvShape n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i)
- -> (forall j. SNat j -> Index n -> Rep (AcIdx t j) -> Rep (AcVal t j) -> r) -- ^ Indexing into element of the array
- -> (forall m. PartialInvIndex n m -> Shape m -> Rep (TArr m t) -> r) -- ^ Accumulating onto a subarray
- -> r
- go SZ ish () val' _ k0 = k0 PIIxEnd (uninvert ish) val' -- ^ Ran out of AcIdx: accumulating onto subarray
- go (SS dep') IShNil idx' val' kj _ = kj dep' IxNil idx' val' -- ^ Ran out of array dimensions: accumulating into (part of) element
- go (SS dep') (IShCons _ ish) (i, idx') val' kj k0 =
- go dep' ish idx' val'
- (\j index idxj valj -> kj j (IxCons index (fromIntegral @Int64 @Int i)) idxj valj)
- (\pidxm shm valm -> k0 (PIIxCons (fromIntegral @Int64 @Int i) pidxm) shm valm)
- STScal{} -> error "Cannot index into scalar"
- STAccum{} -> error "Nested accumulators"
- STEither{} -> error "Bare Either in accumulator"
+ STEither t1 t2 ->
+ case (ref, idx, val) of
+ (Left ref', Left idx', Left val') -> (Left ref', accumAddSparse t1 dep ref' idx' val')
+ (Right ref', Right idx', Right val') -> (Right ref', accumAddSparse t2 dep ref' idx' val')
+ _ -> error "Mismatched Either in accumAdd"
+ _ -> error "accumAddDense: invalid dense type"
--- accumAddVal :: forall t i s. STy t -> SNat i -> RepAc t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAc t, AcM s ())
--- accumAddVal typ SZ ac () val = case typ of
--- STNil -> ((), return ())
--- STPair t1 t2 ->
--- let (ac1', m1) = accumAddVal t1 SZ (fst ac) () (fst val)
--- (ac2', m2) = accumAddVal t2 SZ (snd ac) () (snd val)
--- in ((ac1', ac2'), m1 >> m2)
--- STMaybe t -> case t of
--- STEither t1 t2 -> (ac, accumAddValME t1 t2 ac val)
--- STNil -> def ; STPair{} -> def ; STMaybe{} -> def ; STArr{} -> def ; STScal{} -> def ; STAccum{} -> def
--- where def :: (t ~ TMaybe a, RepAc (TMaybe a) ~ IORef (Maybe (RepAc a))) => (RepAc t, AcM s ())
--- def = (ac, accumAddValM t ac val)
--- STArr n t
--- | shapeSize (arrayShape ac) == 0 -> makeRepAc (STArr n t) val
--- STEither{} -> error "Bare Either in accumulator"
--- _ -> _
--- accumAddVal typ (SS dep) ac idx val = case typ of
--- STNil -> ((), return ())
--- STPair t1 t2 ->
--- case (idx, val) of
--- (Left idx', Left val') -> first (,snd ac) $ accumAddVal t1 dep (fst ac) idx' val'
--- (Right idx', Right val') -> first (fst ac,) $ accumAddVal t2 dep (snd ac) idx' val'
--- _ -> error "Inconsistent idx and val in accumulator add operation"
--- _ -> _
-
--- accumAddValME :: STy a -> STy b
--- -> IORef (Maybe (Either (RepAc a) (RepAc b)))
--- -> Maybe (Either (Rep a) (Rep b))
--- -> AcM s ()
--- accumAddValME t1 t2 ac val =
--- case val of
--- Nothing -> return ()
--- Just val' ->
--- join $ AcM $ atomicModifyIORef' ac $ \ac' -> case (ac', val') of
--- (Nothing, _) -> (Nothing, AcM $ initAccumOrTryAgainME t1 t2 ac val' (unAcM $ accumAddValME t1 t2 ac val))
--- (Just (Left x), Left val'1) -> first (Just . Left) $ accumAddVal t1 SZ x () val'1
--- (Just (Right y), Right val'2) -> first (Just . Right) $ accumAddVal t2 SZ y () val'2
--- _ -> error "Inconsistent accumulator and value in add operation on Maybe Either"
-
--- initAccumOrTryAgainME :: STy a -> STy b
--- -> IORef (Maybe (Either (RepAc a) (RepAc b)))
--- -> Either (Rep a) (Rep b)
--- -> IO ()
--- -> IO ()
--- initAccumOrTryAgainME t1 t2 ac val onRace = do
--- newContents <- case val of Left x -> Left <$> makeRepAc t1 x
--- Right y -> Right <$> makeRepAc t2 y
--- join $ atomicModifyIORef' ac (\case Nothing -> (Just newContents, return ())
--- value@Just{} -> (value, onRace))
-
--- accumAddValM :: STy t
--- -> IORef (Maybe (RepAc t))
--- -> Maybe (Rep t)
--- -> AcM s ()
--- accumAddValM t ac val =
--- case val of
--- Nothing -> return ()
--- Just val' ->
--- join $ AcM $ atomicModifyIORef' ac $ \ac' -> case ac' of
--- Nothing -> (Nothing, AcM $ initAccumOrTryAgainM t ac val' (unAcM $ accumAddValM t ac val))
--- Just x -> first Just $ accumAddVal t SZ x () val'
-
--- initAccumOrTryAgainM :: STy t
--- -> IORef (Maybe (RepAc t))
--- -> Rep t
--- -> IO ()
--- -> IO ()
--- initAccumOrTryAgainM t ac val onRace = do
--- newContents <- makeRepAc t val
--- join $ atomicModifyIORef' ac (\case Nothing -> (Just newContents, return ())
--- value@Just{} -> (value, onRace))
numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r
numericIsNum STI32 = id