diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-03-15 11:32:34 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-15 11:32:34 +0100 | 
| commit | 6da98aedf2f28ec8848d1cb8f5605b0c7e64d644 (patch) | |
| tree | c5723c920a200c001fd5f156f5f80f4a6eb11455 /src/Interpreter.hs | |
| parent | 095e7be937c2414cd34eb6288bd2c0856be63def (diff) | |
Complete accumulator revamp!
Diffstat (limited to 'src/Interpreter.hs')
| -rw-r--r-- | src/Interpreter.hs | 316 | 
1 files changed, 76 insertions, 240 deletions
| diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 11caac0..11184c9 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE BangPatterns #-}  {-# LANGUAGE DataKinds #-}  {-# LANGUAGE DerivingStrategies #-}  {-# LANGUAGE FlexibleContexts #-} @@ -19,12 +20,10 @@ module Interpreter (    Value(..),  ) where -import Control.Monad (foldM, join, when) -import Data.Bifunctor (bimap) +import Control.Monad (foldM, join, when, forM_)  import Data.Bitraversable (bitraverse)  import Data.Char (isSpace)  import Data.Functor.Identity -import Data.Kind (Type)  import Data.Int (Int64)  import Data.IORef  import System.IO (hPutStrLn, stderr) @@ -64,8 +63,9 @@ interpret' :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList Value en  interpret' env e = do    let dep = ?depth    let lenlimit = max 20 (100 - dep) -  let trunc s | length s > lenlimit = take (lenlimit - 3) s ++ "..." -              | otherwise           = s +  let replace a b = map (\c -> if c == a then b else c) +  let trunc s | length s > lenlimit = take (lenlimit - 3) (replace '\n' ' ' s) ++ "..." +              | otherwise           = replace '\n' ' ' s    when ?prints $ acmDebugLog $ replicate dep ' ' ++ "ev: " ++ trunc (ppExpr env e)    res <- let ?depth = dep + 1 in interpret'Rec env e    when ?prints $ acmDebugLog $ replicate dep ' ' ++ "<- " ++ showValue 0 (typeOf e) res "" @@ -243,7 +243,7 @@ onehotD2 (SAPArrIdx prj _) (STArr n a) idx val =  withAccum :: STy t -> STy a -> Rep (D2 t) -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep (D2 t))  withAccum t _ initval f = AcM $ do    accum <- newAcSparse t SAPHere () initval -  out <- case f accum of AcM m -> m +  out <- unAcM $ f accum    val <- readAcSparse t accum    return (out, val) @@ -262,59 +262,6 @@ newAcZero = \case      STBool -> return ()    STAccum{} -> error "Nested accumulators" --- | Inverted index: the outermost index is at the /outside/ of this list. -data PartialInvIndex n m where -  PIIxEnd :: PartialInvIndex m m -  PIIxCons :: Int -> PartialInvIndex n m -> PartialInvIndex (S n) m - --- | Inverted shapey thing: the outermost dimension is at the /outside/ of this list. -data Inverted (f :: Nat -> Type) n where -  InvNil :: Inverted f Z -  InvCons :: Int -> Inverted f n -> Inverted f (S n) - -type InvShape = Inverted Shape -type InvIndex = Inverted Index - -class Shapey f where -  shapeyNil :: f Z -  shapeyCons :: f n -> Int -> f (S n) -  shapeyCase :: f n -> (n ~ Z => r) -> (forall m. n ~ S m => f m -> Int -> r) -> r -instance Shapey Index where -  shapeyNil = IxNil -  shapeyCons = IxCons -  shapeyCase IxNil k0 _ = k0 -  shapeyCase (IxCons idx i) _ k1 = k1 idx i -instance Shapey Shape where -  shapeyNil = ShNil -  shapeyCons = ShCons -  shapeyCase ShNil k0 _ = k0 -  shapeyCase (ShCons sh n) _ k1 = k1 sh n - -invert :: forall f n. Shapey f => f n -> Inverted f n -invert | Refl <- lemPlusZero @n = flip go InvNil -  where -    go :: forall n' m. f n' -> Inverted f m -> Inverted f (n' + m) -    go sh ish = shapeyCase sh -                  ish -                  (\sh' n -> case lemPlusSuccRight @n' @m of Refl -> go sh' (InvCons n ish)) - -uninvert :: forall f n. Shapey f => Inverted f n -> f n -uninvert = go shapeyNil -  where -    go :: forall n' m. f n' -> Inverted f m -> f (n' + m) -    go sh InvNil | Refl <- lemPlusZero @n' = sh -    go sh (InvCons n (ish :: Inverted f predm)) | Refl <- lemPlusSuccRight @n' @predm = go (shapeyCons sh n) ish - -piindexMatch :: PartialInvIndex n m -> InvIndex n -> Maybe (InvIndex m) -piindexMatch PIIxEnd ix = Just ix -piindexMatch (PIIxCons i pix) (InvCons i' ix) -  | i == i' = piindexMatch pix ix -  | otherwise = Nothing - -piindexConcat :: PartialInvIndex n m -> InvIndex m -> InvIndex n -piindexConcat PIIxEnd ix = ix -piindexConcat (PIIxCons i pix) ix = InvCons i (piindexConcat pix ix) -  newAcSparse :: STy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep (D2 b) -> IO (RepAc a)  newAcSparse typ prj idx val = case (typ, prj) of    (STNil, SAPHere) -> return () @@ -353,27 +300,8 @@ onehotArray :: Monad m  onehotArray mkone mkzero n _ ((arrindex', arrsh'), idx) =    let arrindex = unTupRepIdx IxNil IxCons n arrindex'        arrsh = unTupRepIdx ShNil ShCons n arrsh' -  in arrayGenerateM arrsh (\i -> if i == arrindex then mkone idx else mkzero) - --- newAcDense :: STy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep (D2 b) -> IO (RepAcDense (D2 a)) --- newAcDense typ SZ () val = case typ of ---   STPair t1 t2 -> (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val) ---   STEither t1 t2 -> case val of ---     Left x -> Left <$> newAcSparse t1 SZ () x ---     Right y -> Right <$> newAcSparse t2 SZ () y ---   _ -> error "newAcDense: invalid dense type" --- newAcDense typ (SS dep) idx val = case typ of ---   STPair t1 t2 -> ---     case (idx, val) of ---       (Left idx', Left val') -> (,) <$> newAcSparse t1 dep idx' val' <*> newAcZero t2 ---       (Right idx', Right val') -> (,) <$> newAcZero t1 <*> newAcSparse t2 dep idx' val' ---       _ -> error "Index/value mismatch in newAc pair" ---   STEither t1 t2 -> ---     case (idx, val) of ---       (Left idx', Left val') -> Left <$> newAcSparse t1 dep idx' val' ---       (Right idx', Right val') -> Right <$> newAcSparse t2 dep idx' val' ---       _ -> error "Index/value mismatch in newAc either" ---   _ -> error "newAcDense: invalid dense type" +      !linindex = toLinearIndex arrsh arrindex +  in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero)  readAcSparse :: STy t -> RepAc t -> IO (Rep (D2 t))  readAcSparse typ val = case typ of @@ -390,14 +318,6 @@ readAcSparse typ val = case typ of      STBool -> return ()    STAccum{} -> error "Nested accumulators" --- readAcDense :: STy t -> RepAcDense t -> IO (Rep t) --- readAcDense typ val = case typ of ---   STPair t1 t2 -> (,) <$> readAcSparse t1 (fst val) <*> readAcSparse t2 (snd val) ---   STEither t1 t2 -> case val of ---     Left x -> Left <$> readAcSparse t1 x ---     Right y -> Right <$> readAcSparse t2 y ---   _ -> error "readAcDense: invalid dense type" -  accumAddSparse :: STy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep (D2 b) -> AcM s ()  accumAddSparse typ prj ref idx val = case (typ, prj) of    (STNil, SAPHere) -> return () @@ -406,26 +326,66 @@ accumAddSparse typ prj ref idx val = case (typ, prj) of      case val of        Nothing -> return ()        Just (val1, val2) -> -        AcM $ realiseMaybeSparse ref ((,) <$> newAcSparse t1 SAPHere () val1 -                                          <*> newAcSparse t2 SAPHere () val2) -                                     (\(ac1, ac2) -> do unAcM $ accumAddSparse t1 SAPHere ac1 () val1 -                                                        unAcM $ accumAddSparse t2 SAPHere ac2 () val2) +        realiseMaybeSparse ref ((,) <$> newAcSparse t1 SAPHere () val1 +                                    <*> newAcSparse t2 SAPHere () val2) +                               (\(ac1, ac2) -> do accumAddSparse t1 SAPHere ac1 () val1 +                                                  accumAddSparse t2 SAPHere ac2 () val2)    (STPair t1 t2, SAPFst prj') -> -    AcM $ realiseMaybeSparse ref ((,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2) -                                 (\(ac1, _) -> do unAcM $ accumAddSparse t1 prj' ac1 idx val) +    realiseMaybeSparse ref ((,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2) +                           (\(ac1, _) -> do accumAddSparse t1 prj' ac1 idx val)    (STPair t1 t2, SAPSnd prj') -> -    AcM $ realiseMaybeSparse ref ((,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val) -                                 (\(_, ac2) -> do unAcM $ accumAddSparse t2 prj' ac2 idx val) +    realiseMaybeSparse ref ((,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val) +                           (\(_, ac2) -> do accumAddSparse t2 prj' ac2 idx val) -  (STEither t1 t2, SAPHere) -> _ ref val -  (STEither t1 _, SAPLeft prj') -> _ ref idx val -  (STEither _ t2, SAPRight prj') -> _ ref idx val +  (STEither{}, SAPHere) -> +    case val of +      Nothing -> return () +      Just (Left val1) -> accumAddSparse typ (SAPLeft SAPHere) ref () val1 +      Just (Right val2) -> accumAddSparse typ (SAPRight SAPHere) ref () val2 +  (STEither t1 _, SAPLeft prj') -> +    realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val) +                           (\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val +                                  Right{} -> error "Mismatched Either in accumAddSparse (r +l)") +  (STEither _ t2, SAPRight prj') -> +    realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val) +                           (\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val +                                  Left{} -> error "Mismatched Either in accumAddSparse (l +r)") -  (STMaybe t1, SAPHere) -> _ ref val -  (STMaybe t1, SAPJust prj') -> _ ref idx val +  (STMaybe{}, SAPHere) -> +    case val of +      Nothing -> return () +      Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val' +  (STMaybe t1, SAPJust prj') -> +      realiseMaybeSparse ref (newAcSparse t1 prj' idx val) +                             (\ac -> accumAddSparse t1 prj' ac idx val) -  (STArr _ t1, SAPHere) -> _ ref val -  (STArr n t, SAPArrIdx prj' _) -> _ ref idx val +  (STArr _ t1, SAPHere) -> +    let add ac = forM_ [0 .. arraySize ac - 1] $ \i -> +                   unAcM $ accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val i) +    in if arraySize val == 0 +         then return () +         else AcM $ join $ atomicModifyIORef' ref $ \ac -> +                if arraySize ac == 0 +                  then (ac, do newac <- arrayMapM (newAcSparse t1 SAPHere ()) val +                               join $ atomicModifyIORef' ref $ \ac' -> +                                 if arraySize ac == 0 +                                   then (newac, return ()) +                                   else (ac', add ac')) +                  else (ac, add ac) +  (STArr n t1, SAPArrIdx prj' _) -> +    let ((arrindex', arrsh'), idx') = idx +        arrindex = unTupRepIdx IxNil IxCons n arrindex' +        arrsh = unTupRepIdx ShNil ShCons n arrsh' +        linindex = toLinearIndex arrsh arrindex +        add ac = unAcM $ accumAddSparse t1 prj' (arrayIndexLinear ac linindex) idx' val +    in AcM $ join $ atomicModifyIORef' ref $ \ac -> +         if arraySize ac == 0 +           then (ac, do newac <- onehotArray (\_ -> newAcSparse t1 prj' idx' val) (newAcZero t1) n prj' idx +                        join $ atomicModifyIORef' ref $ \ac' -> +                          if arraySize ac == 0 +                            then (newac, return ()) +                            else (ac', add ac')) +           else (ac, add ac)    (STScal sty, SAPHere) -> AcM $ case sty of      STI32 -> return () @@ -436,145 +396,21 @@ accumAddSparse typ prj ref idx val = case (typ, prj) of    (STAccum{}, _) -> error "Accumulators not allowed in source program" -realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> IO ()) -> IO () +realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s ()  realiseMaybeSparse ref makeval modifyval =    -- Try modifying what's already in ref. The 'join' makes the snd    -- of the function's return value a _continuation_ that is run after    -- the critical section ends. -  join $ atomicModifyIORef' ref $ \ac -> case ac of -           -- Oops, ref's contents was still sparse. Have to initialise -           -- it first, then try again. -           Nothing -> (ac, do val <- makeval -                              join $ atomicModifyIORef' ref $ \ac' -> case ac' of -                                       Nothing -> (Just val, return ()) -                                       Just val' -> (ac', modifyval val')) -           -- Yep, ref already had a value in there, so we can just add -           -- val' to it recursively. -           Just val -> (ac, modifyval val) - -{- -accumAddSparse typ SZ ref () val = case typ of -  STNil -> return () -  STPair t1 t2 -> AcM $ do -    (r1, r2) <- readIORef ref -    unAcM $ accumAddSparse t1 SZ r1 () (fst val) -    unAcM $ accumAddSparse t2 SZ r2 () (snd val) -  STMaybe t -> -    case val of -      Nothing -> return () -      Just val' -> -        -- Try adding val' to what's already in ref. The 'join' makes the snd -        -- of the function's return value a _continuation_ that is run after -        -- the critical section ends. -        AcM $ join $ atomicModifyIORef' ref $ \ac -> case ac of -                       -- Oops, ref's contents was still sparse. Have to initialise -                       -- it first, then try again. -                       Nothing -> (ac, do newac <- newAcDense t SZ () val' -                                          join $ atomicModifyIORef' ref $ \ac2 -> case ac2 of -                                                   Nothing -> (Just newac, return ()) -                                                   Just ac2' -> bimap Just unAcM (accumAddDense t SZ ac2' () val')) -                       -- Yep, ref already had a value in there, so we can just add -                       -- val' to it recursively. -                       Just ac' -> bimap Just unAcM (accumAddDense t SZ ac' () val') -  STArr _ t -> AcM $ do -    refs <- readIORef ref -    case (shapeSize (arrayShape refs), shapeSize (arrayShape val)) of -      (_, 0) -> return () -      (0, _) -> do -        newrefarr <- traverse (newAcSparse t SZ ()) val -        join $ atomicModifyIORef' ref $ \refarr -> -                 if shapeSize (arrayShape refarr) == 0 -                   then (newrefarr, return ()) -                   else -- someone was faster than us in initialising the reference! -                        (refarr, unAcM $ accumAddSparse typ SZ ref () val)  -- just try again from the start (dropping newrefarr for the GC to clean up) -      _ | arrayShape refs == arrayShape val -> -            sequence_ [unAcM $ accumAddSparse t SZ (arrayIndexLinear refs i) () (arrayIndexLinear val i) -                      | i <- [0 .. shapeSize (arrayShape val) - 1]] -        | otherwise -> error "Array shape mismatch in accum add" -  STScal sty -> AcM $ case sty of -    STI32 -> atomicModifyIORef' ref (\x -> (x + val, ())) -    STI64 -> atomicModifyIORef' ref (\x -> (x + val, ())) -    STF32 -> atomicModifyIORef' ref (\x -> (x + val, ())) -    STF64 -> atomicModifyIORef' ref (\x -> (x + val, ())) -    STBool -> error "Accumulator of Bool" -  STAccum{} -> error "Nested accumulators" -  STEither{} -> error "Bare Either in accumulator" - -accumAddSparse typ (SS dep) ref idx val = case typ of -  STNil -> return () -  STPair t1 t2 -> AcM $ do -    (ref1, ref2) <- readIORef ref -    case (idx, val) of -      (Left idx', Left val') -> unAcM $ accumAddSparse t1 dep ref1 idx' val' -      (Right idx', Right val') -> unAcM $ accumAddSparse t2 dep ref2 idx' val' -      _ -> error "Index/value mismatch in pair accumulator add" -  STMaybe t -> -    AcM $ join $ atomicModifyIORef' ref $ \case -                   -- Oops, ref's contents was still sparse. Have to initialise -                   -- it first, then try again. -                   Nothing -> (Nothing, do newac <- newAcDense t dep idx val -                                           join $ atomicModifyIORef' ref $ \ac2 -> case ac2 of -                                                    Nothing -> (Just newac, return ()) -                                                    Just ac2' -> bimap Just unAcM (accumAddDense t dep ac2' idx val)) -                   -- Yep, ref already had a value in there, so we can just add -                   -- val' to it recursively. -                   Just ac -> bimap Just unAcM (accumAddDense t dep ac idx val) -  STArr dim (t :: STy t) -> AcM $ do -    refs <- readIORef ref -    if shapeSize (arrayShape refs) == 0 -      then do newrefarr <- newAcArray dim t (SS dep) idx val -              join $ atomicModifyIORef' ref $ \refarr -> -                       if shapeSize (arrayShape refarr) == 0 -                         then (newrefarr, return ()) -                         else -- someone was faster than us in initialising the reference! -                              (refarr, unAcM $ accumAddSparse typ (SS dep) ref idx val)  -- just try again from the start (dropping newrefarr for the GC to clean up) -      else do let sh = unTupRepIdx ShNil ShCons dim (fst val) -              go (SS dep) (invert sh) idx (snd val) -                (\j index idxj valj -> unAcM $ accumAddSparse t j (refs `arrayIndex` index) idxj valj) -                (\piix subsh val' -> unAcM $ sequence_ -                                       [accumAddSparse t SZ (refs `arrayIndex` uninvert (piindexConcat piix (invert subix))) -                                                       () (val' `arrayIndex` subix) -                                       | subix <- enumShape subsh]) -    where -      go :: SNat i -> InvShape n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i) -         -> (forall j. SNat j -> Index n -> Rep (AcIdx t j) -> Rep (AcVal t j) -> r)  -- ^ Indexing into element of the array -         -> (forall m. PartialInvIndex n m -> Shape m -> Rep (TArr m t) -> r)  -- ^ Accumulating onto a subarray -         -> r -      go SZ        ish             ()        val' _  k0 = k0 PIIxEnd (uninvert ish) val'  -- ^ Ran out of AcIdx: accumulating onto subarray -      go (SS dep') InvNil          idx'      val' kj _  = kj dep' IxNil idx' val'  -- ^ Ran out of array dimensions: accumulating into (part of) element -      go (SS dep') (InvCons _ ish) (i, idx') val' kj k0 = -        go dep' ish idx' val' -          (\j index idxj valj -> kj j (IxCons index (fromIntegral @Int64 @Int i)) idxj valj) -          (\pidxm shm valm -> k0 (PIIxCons (fromIntegral @Int64 @Int i) pidxm) shm valm) -  STScal{} -> error "Cannot index into scalar" -  STAccum{} -> error "Nested accumulators" -  STEither{} -> error "Bare Either in accumulator" - -accumAddDense :: forall t i s. STy t -> SNat i -> RepAcDense t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAcDense t, AcM s ()) -accumAddDense typ SZ ref () val = case typ of -  STPair t1 t2 -> -    (ref, do accumAddSparse t1 SZ (fst ref) () (fst val) -             accumAddSparse t2 SZ (snd ref) () (snd val)) -  STEither t1 t2 -> -    case (ref, val) of -      (Left ref', Left val') -> (ref, accumAddSparse t1 SZ ref' () val') -      (Right ref', Right val') -> (ref, accumAddSparse t2 SZ ref' () val') -      _ -> error "Mismatched Either in accumAddDense either" -  _ -> error "accumAddDense: invalid dense type" - -accumAddDense typ (SS dep) ref idx val = case typ of -  STPair t1 t2 -> -    case (idx, val) of -      (Left idx', Left val') -> (ref, accumAddSparse t1 dep (fst ref) idx' val') -      (Right idx', Right val') -> (ref, accumAddSparse t2 dep (snd ref) idx' val') -      _ -> error "Mismatched Either in accumAddDense pair" -  STEither t1 t2 -> -    case (ref, idx, val) of -      (Left ref', Left idx', Left val') -> (Left ref', accumAddSparse t1 dep ref' idx' val') -      (Right ref', Right idx', Right val') -> (Right ref', accumAddSparse t2 dep ref' idx' val') -      _ -> error "Mismatched Either in accumAddDense either" -  _ -> error "accumAddDense: invalid dense type" --} +  AcM $ join $ atomicModifyIORef' ref $ \ac -> case ac of +    -- Oops, ref's contents was still sparse. Have to initialise +    -- it first, then try again. +    Nothing -> (ac, do val <- makeval +                       join $ atomicModifyIORef' ref $ \ac' -> case ac' of +                                Nothing -> (Just val, return ()) +                                Just val' -> (ac', unAcM $ modifyval val')) +    -- Yep, ref already had a value in there, so we can just add +    -- val' to it recursively. +    Just val -> (ac, unAcM $ modifyval val)  numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r @@ -593,12 +429,12 @@ integralIsIntegral STI64 = id  unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m))              -> SNat n -> Rep (Tup (Replicate n TIx)) -> f n -unTupRepIdx nil _    SZ _ = nil +unTupRepIdx nil _    SZ     _        = nil  unTupRepIdx nil cons (SS n) (idx, i) = unTupRepIdx nil cons n idx `cons` fromIntegral @Int64 @Int i  tupRepIdx :: (forall m. f (S m) -> (f m, Int))            -> SNat n -> f n -> Rep (Tup (Replicate n TIx)) -tupRepIdx _      SZ _ = () +tupRepIdx _      SZ     _   = ()  tupRepIdx uncons (SS n) tup =    let (tup', i) = uncons tup    in ((,) $! tupRepIdx uncons n tup') $! fromIntegral @Int @Int64 i | 
