diff options
Diffstat (limited to 'src/CHAD/Interpreter.hs')
| -rw-r--r-- | src/CHAD/Interpreter.hs | 471 |
1 files changed, 471 insertions, 0 deletions
diff --git a/src/CHAD/Interpreter.hs b/src/CHAD/Interpreter.hs new file mode 100644 index 0000000..a9421e6 --- /dev/null +++ b/src/CHAD/Interpreter.hs @@ -0,0 +1,471 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiWayIf #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.Interpreter ( + interpret, + interpretOpen, + Value(..), +) where + +import Control.Monad (foldM, join, when, forM_) +import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.State.Strict (runStateT, get, put) +import Data.Bifunctor (bimap) +import Data.Bitraversable (bitraverse) +import Data.Char (isSpace) +import Data.Functor.Identity +import qualified Data.Functor.Product as Product +import Data.Int (Int64) +import Data.IORef +import Data.Tuple (swap) +import System.IO (hPutStrLn, stderr) +import System.IO.Unsafe (unsafePerformIO) + +import Debug.Trace + +import CHAD.Array +import CHAD.AST +import CHAD.AST.Pretty +import CHAD.AST.Sparse.Types +import CHAD.Data +import CHAD.Interpreter.Rep + + +newtype AcM s a = AcM { unAcM :: IO a } + deriving newtype (Functor, Applicative, Monad) + +runAcM :: (forall s. AcM s a) -> a +runAcM (AcM m) = unsafePerformIO m + +acmDebugLog :: String -> AcM s () +acmDebugLog s = AcM (hPutStrLn stderr s) + +data V t = V (STy t) (Rep t) + +interpret :: Ex '[] t -> Rep t +interpret = interpretOpen False SNil SNil + +-- | Bool: whether to trace execution with debug prints (very verbose) +interpretOpen :: Bool -> SList STy env -> SList Value env -> Ex env t -> Rep t +interpretOpen prints env venv e = + runAcM $ + let ?depth = 0 + ?prints = prints + in interpret' (slistMap (\(Product.Pair t (Value v)) -> V t v) (slistZip env venv)) e + +interpret' :: forall env t s. (?prints :: Bool, ?depth :: Int) + => SList V env -> Ex env t -> AcM s (Rep t) +interpret' env e = do + let tenv = slistMap (\(V t _) -> t) env + let dep = ?depth + let lenlimit = max 20 (100 - dep) + let replace a b = map (\c -> if c == a then b else c) + let trunc s | length s > lenlimit = take (lenlimit - 3) (replace '\n' ' ' s) ++ "..." + | otherwise = replace '\n' ' ' s + when ?prints $ acmDebugLog $ replicate dep ' ' ++ "ev: " ++ trunc (ppExpr tenv e) + res <- let ?depth = dep + 1 in interpret'Rec env e + when ?prints $ acmDebugLog $ replicate dep ' ' ++ "<- " ++ showValue 0 (typeOf e) res "" + return res + +interpret'Rec :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList V env -> Ex env t -> AcM s (Rep t) +interpret'Rec env = \case + EVar _ _ i -> case slistIdx env i of V _ x -> return x + ELet _ a b -> do + x <- interpret' env a + let ?depth = ?depth - 1 in interpret' (V (typeOf a) x `SCons` env) b + expr | False && trace ("<i> " ++ takeWhile (not . isSpace) (show expr)) False -> undefined + EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b + EFst _ e -> fst <$> interpret' env e + ESnd _ e -> snd <$> interpret' env e + ENil _ -> return () + EInl _ _ e -> Left <$> interpret' env e + EInr _ _ e -> Right <$> interpret' env e + ECase _ e a b -> + let STEither t1 t2 = typeOf e + in interpret' env e >>= \case + Left x -> interpret' (V t1 x `SCons` env) a + Right y -> interpret' (V t2 y `SCons` env) b + ENothing _ _ -> return Nothing + EJust _ e -> Just <$> interpret' env e + EMaybe _ a b e -> + let STMaybe t1 = typeOf e + in maybe (interpret' env a) (\x -> interpret' (V t1 x `SCons` env) b) =<< interpret' env e + ELNil _ _ _ -> return Nothing + ELInl _ _ e -> Just . Left <$> interpret' env e + ELInr _ _ e -> Just . Right <$> interpret' env e + ELCase _ e a b c -> + let STLEither t1 t2 = typeOf e + in interpret' env e >>= \case + Nothing -> interpret' env a + Just (Left x) -> interpret' (V t1 x `SCons` env) b + Just (Right y) -> interpret' (V t2 y `SCons` env) c + EConstArr _ _ _ v -> return v + EBuild _ dim a b -> do + sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a + arrayGenerateM sh (\idx -> interpret' (V (tTup (sreplicate dim tIx)) (tupRepIdx ixUncons dim idx) `SCons` env) b) + EMap _ a b -> do + let STArr _ t = typeOf b + arrayMapM (\x -> interpret' (V t x `SCons` env) a) =<< interpret' env b + EFold1Inner _ _ a b c -> do + let t = typeOf b + let f = \x -> interpret' (V (STPair t t) x `SCons` env) a + x0 <- interpret' env b + arr <- interpret' env c + let sh `ShCons` n = arrayShape arr + arrayGenerateM sh $ \idx -> foldM (curry f) x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] + ESum1Inner _ e -> do + arr <- interpret' env e + let STArr _ (STScal t) = typeOf e + sh `ShCons` n = arrayShape arr + numericIsNum t $ return $ arrayGenerate sh $ \idx -> sum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] + EUnit _ e -> arrayGenerateLinM ShNil (\_ -> interpret' env e) + EReplicate1Inner _ a b -> do + n <- fromIntegral @Int64 @Int <$> interpret' env a + arr <- interpret' env b + let sh = arrayShape arr + return $ arrayGenerate (sh `ShCons` n) (\(idx `IxCons` _) -> arrayIndex arr idx) + EMaximum1Inner _ e -> do + arr <- interpret' env e + let STArr _ (STScal t) = typeOf e + sh `ShCons` n = arrayShape arr + numericIsNum t $ return $ + arrayGenerate sh (\idx -> maximum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]]) + EMinimum1Inner _ e -> do + arr <- interpret' env e + let STArr _ (STScal t) = typeOf e + sh `ShCons` n = arrayShape arr + numericIsNum t $ return $ + arrayGenerate sh (\idx -> minimum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]]) + EReshape _ dim esh e -> do + sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env esh + arr <- interpret' env e + return $ arrayReshape sh arr + EZip _ a b -> do + arr1 <- interpret' env a + arr2 <- interpret' env b + let sh = arrayShape arr1 + when (sh /= arrayShape arr2) $ + error "Interpreter: mismatched shapes in EZip" + return $ arrayGenerateLin sh (\i -> (arr1 `arrayIndexLinear` i, arr2 `arrayIndexLinear` i)) + EFold1InnerD1 _ _ a b c -> do + let t = typeOf b + let f = \x -> interpret' (V (STPair t t) x `SCons` env) a + x0 <- interpret' env b + arr <- interpret' env c + let sh `ShCons` n = arrayShape arr + -- TODO: this is very inefficient, even for an interpreter; with mutable + -- arrays this can be a lot better with no lists + res <- arrayGenerateM sh $ \idx -> do + (y, stores) <- mapAccumLM (curry f) x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] + return (y, arrayFromList (ShNil `ShCons` n) stores) + return (arrayMap fst res + ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> + arrayIndexLinear (snd (arrayIndex res idx)) i) + EFold1InnerD2 _ _ ef ebog ed -> do + let STArr _ tB = typeOf ebog + STArr _ t2 = typeOf ed + let f = \tape ctg -> interpret' (V t2 ctg `SCons` V tB tape `SCons` env) ef + bog <- interpret' env ebog + arrctg <- interpret' env ed + let sh `ShCons` n = arrayShape bog + when (sh /= arrayShape arrctg) $ error "Interpreter: mismatched shapes in EFold1InnerD2" + res <- arrayGenerateM sh $ \idx -> do + let loop i !ctg !inpctgs | i < 0 = return (ctg, inpctgs) + loop i !ctg !inpctgs = do + let b = arrayIndex bog (idx `IxCons` i) + (ctg1, ctg2) <- f b ctg + loop (i - 1) ctg1 (ctg2 : inpctgs) + (x0ctg, inpctg) <- loop (n - 1) (arrayIndex arrctg idx) [] + return (x0ctg, arrayFromList (ShNil `ShCons` n) inpctg) + return (arrayMap fst res + ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> + arrayIndexLinear (snd (arrayIndex res idx)) i) + EConst _ _ v -> return v + EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e + EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b) + EIdx _ a b -> + let STArr n _ = typeOf a + in arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b) + EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e + EOp _ op e -> interpretOp op <$> interpret' env e + ECustom _ t1 t2 _ pr _ _ e1 e2 -> do + e1' <- interpret' env e1 + e2' <- interpret' env e2 + interpret' (V t2 e2' `SCons` V t1 e1' `SCons` SNil) pr + ERecompute _ e -> interpret' env e + EWith _ t e1 e2 -> do + initval <- interpret' env e1 + withAccum t (typeOf e2) initval $ \accum -> + interpret' (V (STAccum t) accum `SCons` env) e2 + EAccum _ t p e1 sp e2 e3 -> do + idx <- interpret' env e1 + val <- interpret' env e2 + accum <- interpret' env e3 + accumAddSparseD t p accum idx sp val + EZero _ t ezi -> do + zi <- interpret' env ezi + return $ zeroM t zi + EDeepZero _ t ezi -> do + zi <- interpret' env ezi + return $ deepZeroM t zi + EPlus _ t a b -> do + a' <- interpret' env a + b' <- interpret' env b + return $ addM t a' b' + EOneHot _ t p a b -> do + a' <- interpret' env a + b' <- interpret' env b + return $ onehotM p t a' b' + EError _ _ s -> error $ "Interpreter: Program threw error: " ++ s + +interpretOp :: SOp a t -> Rep a -> Rep t +interpretOp op arg = case op of + OAdd st -> numericIsNum st $ uncurry (+) arg + OMul st -> numericIsNum st $ uncurry (*) arg + ONeg st -> numericIsNum st $ negate arg + OLt st -> numericIsNum st $ uncurry (<) arg + OLe st -> numericIsNum st $ uncurry (<=) arg + OEq st -> styIsEq st $ uncurry (==) arg + ONot -> not arg + OAnd -> uncurry (&&) arg + OOr -> uncurry (||) arg + OIf -> if arg then Left () else Right () + ORound64 -> round arg + OToFl64 -> fromIntegral arg + ORecip st -> floatingIsFractional st $ recip arg + OExp st -> floatingIsFractional st $ exp arg + OLog st -> floatingIsFractional st $ log arg + OIDiv st -> integralIsIntegral st $ uncurry quot arg + OMod st -> integralIsIntegral st $ uncurry rem arg + where + styIsEq :: SScalTy t -> (Eq (Rep (TScal t)) => r) -> r + styIsEq STI32 = id + styIsEq STI64 = id + styIsEq STF32 = id + styIsEq STF64 = id + styIsEq STBool = id + +zeroM :: SMTy t -> Rep (ZeroInfo t) -> Rep t +zeroM typ zi = case typ of + SMTNil -> () + SMTPair t1 t2 -> (zeroM t1 (fst zi), zeroM t2 (snd zi)) + SMTLEither _ _ -> Nothing + SMTMaybe _ -> Nothing + SMTArr _ t -> arrayMap (zeroM t) zi + SMTScal sty -> case sty of + STI32 -> 0 + STI64 -> 0 + STF32 -> 0.0 + STF64 -> 0.0 + +deepZeroM :: SMTy t -> Rep (DeepZeroInfo t) -> Rep t +deepZeroM typ zi = case typ of + SMTNil -> () + SMTPair t1 t2 -> (deepZeroM t1 (fst zi), deepZeroM t2 (snd zi)) + SMTLEither t1 t2 -> fmap (bimap (deepZeroM t1) (deepZeroM t2)) zi + SMTMaybe t -> fmap (deepZeroM t) zi + SMTArr _ t -> arrayMap (deepZeroM t) zi + SMTScal sty -> case sty of + STI32 -> 0 + STI64 -> 0 + STF32 -> 0.0 + STF64 -> 0.0 + +addM :: SMTy t -> Rep t -> Rep t -> Rep t +addM typ a b = case typ of + SMTNil -> () + SMTPair t1 t2 -> (addM t1 (fst a) (fst b), addM t2 (snd a) (snd b)) + SMTLEither t1 t2 -> case (a, b) of + (Nothing, _) -> b + (_, Nothing) -> a + (Just (Left x), Just (Left y)) -> Just (Left (addM t1 x y)) + (Just (Right x), Just (Right y)) -> Just (Right (addM t2 x y)) + _ -> error "Plus of inconsistent LEithers" + SMTMaybe t -> case (a, b) of + (Nothing, _) -> b + (_, Nothing) -> a + (Just x, Just y) -> Just (addM t x y) + SMTArr _ t -> + let sh1 = arrayShape a + sh2 = arrayShape b + in if | shapeSize sh1 == 0 -> b + | shapeSize sh2 == 0 -> a + | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> addM t (arrayIndexLinear a i) (arrayIndexLinear b i)) + | otherwise -> error "Plus of inconsistently shaped arrays" + SMTScal sty -> numericIsNum sty $ a + b + +onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdxS p a) -> Rep b -> Rep a +onehotM SAPHere _ _ val = val +onehotM (SAPFst prj) (SMTPair a b) idx val = (onehotM prj a (fst idx) val, zeroM b (snd idx)) +onehotM (SAPSnd prj) (SMTPair a b) idx val = (zeroM a (fst idx), onehotM prj b (snd idx) val) +onehotM (SAPLeft prj) (SMTLEither a _) idx val = Just (Left (onehotM prj a idx val)) +onehotM (SAPRight prj) (SMTLEither _ b) idx val = Just (Right (onehotM prj b idx val)) +onehotM (SAPJust prj) (SMTMaybe a) idx val = Just (onehotM prj a idx val) +onehotM (SAPArrIdx prj) (SMTArr n a) idx val = + runIdentity $ onehotArray (\idx' -> Identity (onehotM prj a idx' val)) (\zi -> Identity (zeroM a zi)) n prj idx + +withAccum :: SMTy t -> STy a -> Rep t -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t) +withAccum t _ initval f = AcM $ do + accum <- newAcDense t initval + out <- unAcM $ f accum + val <- readAc t accum + return (out, val) + +newAcDense :: SMTy a -> Rep a -> IO (RepAc a) +newAcDense typ val = case typ of + SMTNil -> return () + SMTPair t1 t2 -> bitraverse (newAcDense t1) (newAcDense t2) val + SMTLEither t1 t2 -> newIORef =<< traverse (bitraverse (newAcDense t1) (newAcDense t2)) val + SMTMaybe t1 -> newIORef =<< traverse (newAcDense t1) val + SMTArr _ t1 -> arrayMapM (newAcDense t1) val + SMTScal _ -> newIORef val + +onehotArray :: Monad m + => (Rep (AcIdxS p a) -> m v) -- ^ the "one" + -> (Rep (ZeroInfo a) -> m v) -- ^ the "zero" + -> SNat n -> SAcPrj p a b -> Rep (AcIdxS (APArrIdx p) (TArr n a)) -> m (Array n v) +onehotArray mkone mkzero n _ ((arrindex', ziarr), idx) = + let arrindex = unTupRepIdx IxNil IxCons n arrindex' + arrsh = arrayShape ziarr + !linindex = toLinearIndex arrsh arrindex + in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero (ziarr `arrayIndexLinear` i)) + +readAc :: SMTy t -> RepAc t -> IO (Rep t) +readAc typ val = case typ of + SMTNil -> return () + SMTPair t1 t2 -> bitraverse (readAc t1) (readAc t2) val + SMTLEither t1 t2 -> traverse (bitraverse (readAc t1) (readAc t2)) =<< readIORef val + SMTMaybe t -> traverse (readAc t) =<< readIORef val + SMTArr _ t -> traverse (readAc t) val + SMTScal _ -> readIORef val + +accumAddSparseD :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxD p a) -> Sparse b c -> Rep c -> AcM s () +accumAddSparseD typ prj ref idx sp val = case (typ, prj) of + (_, SAPHere) -> accumAddDense typ ref sp val + + (SMTPair t1 _, SAPFst prj') -> accumAddSparseD t1 prj' (fst ref) idx sp val + (SMTPair _ t2, SAPSnd prj') -> accumAddSparseD t2 prj' (snd ref) idx sp val + + (SMTLEither t1 _, SAPLeft prj') -> + realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)") + (\case Left ac1 -> accumAddSparseD t1 prj' ac1 idx sp val + Right{} -> error "Mismatched Either in accumAddSparseD (r +l)") + (SMTLEither _ t2, SAPRight prj') -> + realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)") + (\case Right ac2 -> accumAddSparseD t2 prj' ac2 idx sp val + Left{} -> error "Mismatched Either in accumAddSparseD (l +r)") + + (SMTMaybe t1, SAPJust prj') -> + realiseMaybeSparse ref (error "Accumulating Just into Nothing (EWith requires EDeepZero)") + (\ac -> accumAddSparseD t1 prj' ac idx sp val) + + (SMTArr n t1, SAPArrIdx prj') -> + let (arrindex', idx') = idx + arrindex = unTupRepIdx IxNil IxCons n arrindex' + arrsh = arrayShape ref + linindex = toLinearIndex arrsh arrindex + in accumAddSparseD t1 prj' (arrayIndexLinear ref linindex) idx' sp val + +accumAddDense :: SMTy a -> RepAc a -> Sparse a b -> Rep b -> AcM s () +accumAddDense typ ref sp val = case (typ, sp) of + (_, _) | isAbsent sp -> return () + (_, SpAbsent) -> return () + (_, SpSparse s) -> + case val of + Nothing -> return () + Just val' -> accumAddDense typ ref s val' + (SMTPair t1 t2, SpPair s1 s2) -> do + accumAddDense t1 (fst ref) s1 (fst val) + accumAddDense t2 (snd ref) s2 (snd val) + (SMTLEither t1 t2, SpLEither s1 s2) -> + case val of + Nothing -> return () + Just (Left val1) -> + realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)") + (\case Left ac1 -> accumAddDense t1 ac1 s1 val1 + Right{} -> error "Mismatched Either in accumAddSparse (r +l)") + Just (Right val2) -> + realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)") + (\case Right ac2 -> accumAddDense t2 ac2 s2 val2 + Left{} -> error "Mismatched Either in accumAddSparse (l +r)") + (SMTMaybe t, SpMaybe s) -> + case val of + Nothing -> return () + Just val' -> + realiseMaybeSparse ref (error "Accumulating Just into Nothing (EAccum requires EDeepZero)") + (\ac -> accumAddDense t ac s val') + (SMTArr _ t1, SpArr s) -> + forM_ [0 .. arraySize ref - 1] $ \i -> + accumAddDense t1 (arrayIndexLinear ref i) s (arrayIndexLinear val i) + (SMTScal sty, SpScal) -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ())) + +-- TODO: makeval is always 'error' now. Simplify? +realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s () +realiseMaybeSparse ref makeval modifyval = + -- Try modifying what's already in ref. The 'join' makes the snd + -- of the function's return value a _continuation_ that is run after + -- the critical section ends. + AcM $ join $ atomicModifyIORef' ref $ \ac -> case ac of + -- Oops, ref's contents was still sparse. Have to initialise + -- it first, then try again. + Nothing -> (ac, do val <- makeval + join $ atomicModifyIORef' ref $ \ac' -> case ac' of + Nothing -> (Just val, return ()) + Just val' -> (ac', unAcM $ modifyval val')) + -- Yep, ref already had a value in there, so we can just add + -- val' to it recursively. + Just val -> (ac, unAcM $ modifyval val) + + +numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r +numericIsNum STI32 = id +numericIsNum STI64 = id +numericIsNum STF32 = id +numericIsNum STF64 = id + +floatingIsFractional :: ScalIsFloating st ~ True => SScalTy st -> ((Floating (ScalRep st), Ord (ScalRep st), ScalIsNumeric st ~ True, ScalIsFloating st ~ True) => r) -> r +floatingIsFractional STF32 = id +floatingIsFractional STF64 = id + +integralIsIntegral :: ScalIsIntegral st ~ True => SScalTy st -> ((Integral (ScalRep st), Ord (ScalRep st), ScalIsNumeric st ~ True, ScalIsIntegral st ~ True) => r) -> r +integralIsIntegral STI32 = id +integralIsIntegral STI64 = id + +unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m)) + -> SNat n -> Rep (Tup (Replicate n TIx)) -> f n +unTupRepIdx nil _ SZ _ = nil +unTupRepIdx nil cons (SS n) (idx, i) = unTupRepIdx nil cons n idx `cons` fromIntegral @Int64 @Int i + +tupRepIdx :: (forall m. f (S m) -> (f m, Int)) + -> SNat n -> f n -> Rep (Tup (Replicate n TIx)) +tupRepIdx _ SZ _ = () +tupRepIdx uncons (SS n) tup = + let (tup', i) = uncons tup + in ((,) $! tupRepIdx uncons n tup') $! fromIntegral @Int @Int64 i + +ixUncons :: Index (S n) -> (Index n, Int) +ixUncons (IxCons idx i) = (idx, i) + +shUncons :: Shape (S n) -> (Shape n, Int) +shUncons (ShCons idx i) = (idx, i) + +mapAccumLM :: (Traversable t, Monad m) => (s -> a -> m (s, b)) -> s -> t a -> m (s, t b) +mapAccumLM f s0 = fmap swap . flip runStateT s0 . traverse f' + where f' x = do + s <- get + (s', y) <- lift $ f s x + put s' + return y |
