diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-09-22 23:11:37 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-09-22 23:11:37 +0200 | 
| commit | 7bc10684870e2249efbdcdddb4950f52d8527699 (patch) | |
| tree | 8ff3ffd5966ead77edd9f66b61df9e92dc237a47 /src/Interpreter.hs | |
| parent | 1d14fbd9665b25aef5672e0652d5e7e27bcd4908 (diff) | |
Interpreter typechecks, at the cost of compositionality of RepAc
Diffstat (limited to 'src/Interpreter.hs')
| -rw-r--r-- | src/Interpreter.hs | 239 | 
1 files changed, 46 insertions, 193 deletions
| diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 8fb4a78..01d15f1 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -29,7 +29,7 @@ import AST  import CHAD.Types  import Data  import Interpreter.Rep -import Data.Bifunctor (first, bimap) +import Data.Bifunctor (bimap)  newtype AcM s a = AcM { unAcM :: IO a } @@ -194,11 +194,6 @@ data PartialInvIndex n m where    PIIxEnd :: PartialInvIndex m m    PIIxCons :: Int -> PartialInvIndex n m -> PartialInvIndex (S n) m --- | Inverted shape: the outermost dimension is at the /outside/ of this list. -data PartialInvShape n m where -  PIShEnd :: PartialInvShape m m -  PIShCons :: Int -> PartialInvShape n m -> PartialInvShape (S n) m -  -- | Inverted shapey thing: the outermost dimension is at the /outside/ of this list.  data Inverted (f :: Nat -> Type) n where    InvNil :: Inverted f Z @@ -269,52 +264,54 @@ piindexConcat (PIIxCons i pix) ix = IIxCons i (piindexConcat pix ix)  newAcSparse :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t)  newAcSparse typ SZ () val = case typ of    STNil -> return () -  STPair{} -> newIORef =<< newAcDense typ SZ () val +  STPair t1 t2 -> newIORef =<< (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val)    STMaybe t -> newIORef =<< traverse (newAcDense t SZ ()) val -  STArr{} -> newIORef =<< newAcDense typ SZ () val +  STArr _ t -> newIORef =<< traverse (newAcSparse t SZ ()) val    STScal{} -> newIORef val    STAccum{} -> error "Nested accumulators"    STEither{} -> error "Bare Either in accumulator"  newAcSparse typ (SS dep) idx val = case typ of    STNil -> return () -  STPair{} -> newIORef =<< newAcDense typ (SS dep) idx val +  STPair t1 t2 -> newIORef =<< case (idx, val) of +                                 (Left idx', Left val') -> (,) <$> newAcSparse t1 dep idx' val' <*> newAcZero t2 +                                 (Right idx', Right val') -> (,) <$> newAcZero t1 <*> newAcSparse t2 dep idx' val' +                                 _ -> error "Index/value mismatch in newAc pair"    STMaybe t -> newIORef =<< Just <$> newAcDense t dep idx val -  STArr{} -> newIORef =<< newAcDense typ (SS dep) idx val +  STArr dim (t :: STy t) -> newIORef =<< newAcArray dim t (SS dep) idx val    STScal{} -> error "Cannot index into scalar"    STAccum{} -> error "Nested accumulators"    STEither{} -> error "Bare Either in accumulator" +newAcArray :: SNat n -> STy t -> SNat i -> Rep (AcIdx (TArr n t) i) -> Rep (AcVal (TArr n t) i) -> IO (Array n (RepAcSparse t)) +newAcArray _ t SZ _ val = +  traverse (newAcSparse t SZ ()) val +newAcArray dim (t :: STy t) dep@SS{} idx val = do +  let sh = unTupRepIdx ShNil ShCons dim (fst val) +  go dep dim idx (snd val) $ \arr position -> +    arrayGenerateM sh (\i -> case uninvert <$> piindexMatch position (invert i) of +                               Just i' -> return $ arr `arrayIndex` i' +                               Nothing -> newAcZero t) +  where +    go :: SNat i -> SNat n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i) -> (forall m. Array m (RepAcSparse t) -> PartialInvIndex n m -> IO r) -> IO r +    go SZ _ () val' k = arrayMapM (newAcSparse t SZ ()) val' >>= \arr -> k arr PIIxEnd +    go (SS dep') SZ idx' val' k = newAcSparse t dep' idx' val' >>= \arr -> k (arrayUnit arr) PIIxEnd +    go (SS dep') (SS dim') (i, idx') val' k = +      go dep' dim' idx' val' $ \arr pish -> +        k arr (PIIxCons (fromIntegral @Int64 @Int i) pish) +  newAcDense :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t)  newAcDense typ SZ () val = case typ of -  STNil -> return () -  STPair t1 t2 -> (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val)    STEither t1 t2 -> case val of      Left x -> Left <$> newAcSparse t1 SZ () x      Right y -> Right <$> newAcSparse t2 SZ () y -  STMaybe t -> traverse (newAcSparse t SZ ()) val -  STArr _ t -> traverse (newAcSparse t SZ ()) val -  STScal{} -> return val -  STAccum{} -> error "Nested accumulators" +  _ -> error "newAcDense: invalid dense type"  newAcDense typ (SS dep) idx val = case typ of -  STNil -> return () -  STPair{} -> newAcDense typ (SS dep) idx val -  STMaybe t -> Just <$> newAcSparse t dep idx val -  STArr dim (t :: STy t) -> do -    let sh = unTupRepIdx ShNil ShCons dim (fst val) -    go (SS dep) dim idx (snd val) $ \arr position -> -      arrayGenerateM sh (\i -> case uninvert <$> piindexMatch position (invert i) of -                                 Just i' -> return $ arr `arrayIndex` i' -                                 Nothing -> newAcZero t) -    where -      go :: SNat i -> SNat n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i) -> (forall m. Array m (RepAcSparse t) -> PartialInvIndex n m -> IO r) -> IO r -      go SZ _ () val' k = arrayMapM (newAcSparse t SZ ()) val' >>= \arr -> k arr PIIxEnd -      go (SS dep') SZ idx' val' k = newAcSparse t dep' idx' val' >>= \arr -> k (arrayUnit arr) PIIxEnd -      go (SS dep') (SS dim') (i, idx') val' k = -        go dep' dim' idx' val' $ \arr pish -> -          k arr (PIIxCons (fromIntegral @Int64 @Int i) pish) -  STScal{} -> error "Cannot index into scalar" -  STAccum{} -> error "Nested accumulators" -  STEither{} -> error "Bare Either in accumulator" +  STEither t1 t2 -> +    case (idx, val) of +      (Left idx', Left val') -> Left <$> newAcSparse t1 dep idx' val' +      (Right idx', Right val') -> Right <$> newAcSparse t2 dep idx' val' +      _ -> error "Index/value mismatch in newAc either" +  _ -> error "newAcDense: invalid dense type"  readAcSparse :: STy t -> RepAcSparse t -> IO (Rep t)  readAcSparse typ val = case typ of @@ -330,15 +327,10 @@ readAcSparse typ val = case typ of  readAcDense :: STy t -> RepAcDense t -> IO (Rep t)  readAcDense typ val = case typ of -  STNil -> return () -  STPair t1 t2 -> (,) <$> readAcSparse t1 (fst val) <*> readAcSparse t2 (snd val)    STEither t1 t2 -> case val of      Left x -> Left <$> readAcSparse t1 x      Right y -> Right <$> readAcSparse t2 y -  STMaybe t -> traverse (readAcSparse t) val -  STArr _ t -> traverse (readAcSparse t) val -  STScal{} -> return val -  STAccum{} -> error "Nested accumulators" +  _ -> error "readAcDense: invalid dense type"  accumAddSparse :: STy t -> SNat i -> RepAcSparse t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s ()  accumAddSparse typ SZ ref () val = case typ of @@ -369,7 +361,7 @@ accumAddSparse typ SZ ref () val = case typ of      case (shapeSize (arrayShape refs), shapeSize (arrayShape val)) of        (_, 0) -> return ()        (0, _) -> do -        newrefarr <- newAcDense typ SZ () val +        newrefarr <- traverse (newAcSparse t SZ ()) val          join $ atomicModifyIORef' ref $ \refarr ->                   if shapeSize (arrayShape refarr) == 0                     then (newrefarr, return ()) @@ -410,7 +402,7 @@ accumAddSparse typ (SS dep) ref idx val = case typ of    STArr dim (t :: STy t) -> AcM $ do      refs <- readIORef ref      if shapeSize (arrayShape refs) == 0 -      then do newrefarr <- newAcDense typ (SS dep) idx val +      then do newrefarr <- newAcArray dim t (SS dep) idx val                join $ atomicModifyIORef' ref $ \refarr ->                         if shapeSize (arrayShape refarr) == 0                           then (newrefarr, return ()) @@ -440,160 +432,21 @@ accumAddSparse typ (SS dep) ref idx val = case typ of  accumAddDense :: forall t i s. STy t -> SNat i -> RepAcDense t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAcDense t, AcM s ())  accumAddDense typ SZ ref () val = case typ of -  STNil -> ((), return ()) -  STPair t1 t2 -> -    (ref, do accumAddSparse t1 SZ (fst ref) () (fst val) -             accumAddSparse t2 SZ (snd ref) () (snd val)) -  STMaybe t -> +  STEither t1 t2 ->      case (ref, val) of -      (_, Nothing) -> (ref, return ()) -      (Nothing, Just val') -> _ val' -      (Just ref', Just val') -> _ ref' val' -  STArr _ t -> AcM $ do -    refs <- readIORef ref -    case (shapeSize (arrayShape refs), shapeSize (arrayShape val)) of -      (_, 0) -> return () -      (0, _) -> do -        newrefarr <- newAcDense typ SZ () val -        join $ atomicModifyIORef' ref $ \refarr -> -                 if shapeSize (arrayShape refarr) == 0 -                   then (newrefarr, return ()) -                   else -- someone was faster than us in initialising the reference! -                        (refarr, unAcM $ accumAddSparse typ SZ ref () val)  -- just try again from the start (dropping newrefarr for the GC to clean up) -      _ | arrayShape refs == arrayShape val -> -            sequence_ [unAcM $ accumAddSparse t SZ (arrayIndexLinear refs i) () (arrayIndexLinear val i) -                      | i <- [0 .. shapeSize (arrayShape val) - 1]] -        | otherwise -> error "Array shape mismatch in accum add" -  STScal sty -> AcM $ case sty of -    STI32 -> atomicModifyIORef' ref (\x -> (x + val, ())) -    STI64 -> atomicModifyIORef' ref (\x -> (x + val, ())) -    STF32 -> atomicModifyIORef' ref (\x -> (x + val, ())) -    STF64 -> atomicModifyIORef' ref (\x -> (x + val, ())) -    STBool -> error "Accumulator of Bool" -  STAccum{} -> error "Nested accumulators" -  STEither{} -> error "Bare Either in accumulator" +      (Left ref', Left val') -> (ref, accumAddSparse t1 SZ ref' () val') +      (Right ref', Right val') -> (ref, accumAddSparse t2 SZ ref' () val') +      _ -> error "Mismatched Either in accumAdd" +  _ -> error "accumAddDense: invalid dense type"  accumAddDense typ (SS dep) ref idx val = case typ of -  STNil -> return () -  STPair t1 t2 -> AcM $ do -    (ref1, ref2) <- readIORef ref -    case (idx, val) of -      (Left idx', Left val') -> unAcM $ accumAddSparse t1 dep ref1 idx' val' -      (Right idx', Right val') -> unAcM $ accumAddSparse t2 dep ref2 idx' val' -      _ -> error "Index/value mismatch in pair accumulator add" -  STMaybe t -> -    AcM $ join $ atomicModifyIORef' ref $ \case -                   -- Oops, ref's contents was still sparse. Have to initialise -                   -- it first, then try again. -                   Nothing -> (Nothing, do newac <- newAcDense t dep idx val -                                           join $ atomicModifyIORef' ref $ \ac2 -> case ac2 of -                                                    Nothing -> (Just newac, return ()) -                                                    Just ac2' -> bimap Just unAcM (accumAddDense t dep ac2' idx val)) -                   -- Yep, ref already had a value in there, so we can just add -                   -- val' to it recursively. -                   Just ac -> bimap Just unAcM (accumAddDense t dep ac idx val) -  STArr dim (t :: STy t) -> AcM $ do -    refs <- readIORef ref -    if shapeSize (arrayShape refs) == 0 -      then do newrefarr <- newAcDense typ (SS dep) idx val -              join $ atomicModifyIORef' ref $ \refarr -> -                       if shapeSize (arrayShape refarr) == 0 -                         then (newrefarr, return ()) -                         else -- someone was faster than us in initialising the reference! -                              (refarr, unAcM $ accumAddSparse typ (SS dep) ref idx val)  -- just try again from the start (dropping newrefarr for the GC to clean up) -      else do let sh = unTupRepIdx ShNil ShCons dim (fst val) -              go (SS dep) (invert sh) idx (snd val) -                (\j index idxj valj -> unAcM $ accumAddSparse t j (refs `arrayIndex` index) idxj valj) -                (\piix subsh val' -> unAcM $ sequence_ -                                       [accumAddSparse t SZ (refs `arrayIndex` uninvert (piindexConcat piix (invert subix))) -                                                       () (val' `arrayIndex` subix) -                                       | subix <- enumShape subsh]) -    where -      go :: SNat i -> InvShape n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i) -         -> (forall j. SNat j -> Index n -> Rep (AcIdx t j) -> Rep (AcVal t j) -> r)  -- ^ Indexing into element of the array -         -> (forall m. PartialInvIndex n m -> Shape m -> Rep (TArr m t) -> r)  -- ^ Accumulating onto a subarray -         -> r -      go SZ        ish             ()        val' _  k0 = k0 PIIxEnd (uninvert ish) val'  -- ^ Ran out of AcIdx: accumulating onto subarray -      go (SS dep') IShNil          idx'      val' kj _  = kj dep' IxNil idx' val'  -- ^ Ran out of array dimensions: accumulating into (part of) element -      go (SS dep') (IShCons _ ish) (i, idx') val' kj k0 = -        go dep' ish idx' val' -          (\j index idxj valj -> kj j (IxCons index (fromIntegral @Int64 @Int i)) idxj valj) -          (\pidxm shm valm -> k0 (PIIxCons (fromIntegral @Int64 @Int i) pidxm) shm valm) -  STScal{} -> error "Cannot index into scalar" -  STAccum{} -> error "Nested accumulators" -  STEither{} -> error "Bare Either in accumulator" - --- accumAddVal :: forall t i s. STy t -> SNat i -> RepAc t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAc t, AcM s ()) --- accumAddVal typ SZ ac () val = case typ of ---   STNil -> ((), return ()) ---   STPair t1 t2 -> ---     let (ac1', m1) = accumAddVal t1 SZ (fst ac) () (fst val) ---         (ac2', m2) = accumAddVal t2 SZ (snd ac) () (snd val) ---     in ((ac1', ac2'), m1 >> m2) ---   STMaybe t -> case t of ---     STEither t1 t2 -> (ac, accumAddValME t1 t2 ac val) ---     STNil -> def ; STPair{} -> def ; STMaybe{} -> def ; STArr{} -> def ; STScal{} -> def ; STAccum{} -> def ---     where def :: (t ~ TMaybe a, RepAc (TMaybe a) ~ IORef (Maybe (RepAc a))) => (RepAc t, AcM s ()) ---           def = (ac, accumAddValM t ac val) ---   STArr n t ---     | shapeSize (arrayShape ac) == 0 -> makeRepAc (STArr n t) val ---   STEither{} -> error "Bare Either in accumulator" ---   _ -> _ --- accumAddVal typ (SS dep) ac idx val = case typ of ---   STNil -> ((), return ()) ---   STPair t1 t2 -> ---     case (idx, val) of ---       (Left idx', Left val') -> first (,snd ac) $ accumAddVal t1 dep (fst ac) idx' val' ---       (Right idx', Right val') -> first (fst ac,) $ accumAddVal t2 dep (snd ac) idx' val' ---       _ -> error "Inconsistent idx and val in accumulator add operation" ---   _ -> _ - --- accumAddValME :: STy a -> STy b ---               -> IORef (Maybe (Either (RepAc a) (RepAc b))) ---               -> Maybe (Either (Rep a) (Rep b)) ---               -> AcM s () --- accumAddValME t1 t2 ac val = ---   case val of ---     Nothing -> return () ---     Just val' -> ---       join $ AcM $ atomicModifyIORef' ac $ \ac' -> case (ac', val') of ---                      (Nothing, _) -> (Nothing, AcM $ initAccumOrTryAgainME t1 t2 ac val' (unAcM $ accumAddValME t1 t2 ac val)) ---                      (Just (Left x), Left val'1) -> first (Just . Left) $ accumAddVal t1 SZ x () val'1 ---                      (Just (Right y), Right val'2) -> first (Just . Right) $ accumAddVal t2 SZ y () val'2 ---                      _ -> error "Inconsistent accumulator and value in add operation on Maybe Either" - --- initAccumOrTryAgainME :: STy a -> STy b ---                       -> IORef (Maybe (Either (RepAc a) (RepAc b))) ---                       -> Either (Rep a) (Rep b) ---                       -> IO () ---                       -> IO () --- initAccumOrTryAgainME t1 t2 ac val onRace = do ---   newContents <- case val of Left x -> Left <$> makeRepAc t1 x ---                              Right y -> Right <$> makeRepAc t2 y ---   join $ atomicModifyIORef' ac (\case Nothing -> (Just newContents, return ()) ---                                       value@Just{} -> (value, onRace)) - --- accumAddValM :: STy t ---              -> IORef (Maybe (RepAc t)) ---              -> Maybe (Rep t) ---              -> AcM s () --- accumAddValM t ac val = ---   case val of ---     Nothing -> return () ---     Just val' -> ---       join $ AcM $ atomicModifyIORef' ac $ \ac' -> case ac' of ---                      Nothing -> (Nothing, AcM $ initAccumOrTryAgainM t ac val' (unAcM $ accumAddValM t ac val)) ---                      Just x -> first Just $ accumAddVal t SZ x () val' +  STEither t1 t2 -> +    case (ref, idx, val) of +      (Left ref', Left idx', Left val') -> (Left ref', accumAddSparse t1 dep ref' idx' val') +      (Right ref', Right idx', Right val') -> (Right ref', accumAddSparse t2 dep ref' idx' val') +      _ -> error "Mismatched Either in accumAdd" +  _ -> error "accumAddDense: invalid dense type" --- initAccumOrTryAgainM :: STy t ---                      -> IORef (Maybe (RepAc t)) ---                      -> Rep t ---                      -> IO () ---                      -> IO () --- initAccumOrTryAgainM t ac val onRace = do ---   newContents <- makeRepAc t val ---   join $ atomicModifyIORef' ac (\case Nothing -> (Just newContents, return ()) ---                                       value@Just{} -> (value, onRace))  numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r  numericIsNum STI32 = id | 
