aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/Interpreter.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-10 21:49:45 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-10 21:50:25 +0100
commit174af2ba568de66e0d890825b8bda930b8e7bb96 (patch)
tree5a20f52662e87ff7cf6a6bef5db0713aa6c7884e /src/CHAD/Interpreter.hs
parent92bca235e3aaa287286b6af082d3fce585825a35 (diff)
Move module hierarchy under CHAD.
Diffstat (limited to 'src/CHAD/Interpreter.hs')
-rw-r--r--src/CHAD/Interpreter.hs471
1 files changed, 471 insertions, 0 deletions
diff --git a/src/CHAD/Interpreter.hs b/src/CHAD/Interpreter.hs
new file mode 100644
index 0000000..a9421e6
--- /dev/null
+++ b/src/CHAD/Interpreter.hs
@@ -0,0 +1,471 @@
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE ImplicitParams #-}
+{-# LANGUAGE KindSignatures #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE MultiWayIf #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD.Interpreter (
+ interpret,
+ interpretOpen,
+ Value(..),
+) where
+
+import Control.Monad (foldM, join, when, forM_)
+import Control.Monad.Trans.Class (lift)
+import Control.Monad.Trans.State.Strict (runStateT, get, put)
+import Data.Bifunctor (bimap)
+import Data.Bitraversable (bitraverse)
+import Data.Char (isSpace)
+import Data.Functor.Identity
+import qualified Data.Functor.Product as Product
+import Data.Int (Int64)
+import Data.IORef
+import Data.Tuple (swap)
+import System.IO (hPutStrLn, stderr)
+import System.IO.Unsafe (unsafePerformIO)
+
+import Debug.Trace
+
+import CHAD.Array
+import CHAD.AST
+import CHAD.AST.Pretty
+import CHAD.AST.Sparse.Types
+import CHAD.Data
+import CHAD.Interpreter.Rep
+
+
+newtype AcM s a = AcM { unAcM :: IO a }
+ deriving newtype (Functor, Applicative, Monad)
+
+runAcM :: (forall s. AcM s a) -> a
+runAcM (AcM m) = unsafePerformIO m
+
+acmDebugLog :: String -> AcM s ()
+acmDebugLog s = AcM (hPutStrLn stderr s)
+
+data V t = V (STy t) (Rep t)
+
+interpret :: Ex '[] t -> Rep t
+interpret = interpretOpen False SNil SNil
+
+-- | Bool: whether to trace execution with debug prints (very verbose)
+interpretOpen :: Bool -> SList STy env -> SList Value env -> Ex env t -> Rep t
+interpretOpen prints env venv e =
+ runAcM $
+ let ?depth = 0
+ ?prints = prints
+ in interpret' (slistMap (\(Product.Pair t (Value v)) -> V t v) (slistZip env venv)) e
+
+interpret' :: forall env t s. (?prints :: Bool, ?depth :: Int)
+ => SList V env -> Ex env t -> AcM s (Rep t)
+interpret' env e = do
+ let tenv = slistMap (\(V t _) -> t) env
+ let dep = ?depth
+ let lenlimit = max 20 (100 - dep)
+ let replace a b = map (\c -> if c == a then b else c)
+ let trunc s | length s > lenlimit = take (lenlimit - 3) (replace '\n' ' ' s) ++ "..."
+ | otherwise = replace '\n' ' ' s
+ when ?prints $ acmDebugLog $ replicate dep ' ' ++ "ev: " ++ trunc (ppExpr tenv e)
+ res <- let ?depth = dep + 1 in interpret'Rec env e
+ when ?prints $ acmDebugLog $ replicate dep ' ' ++ "<- " ++ showValue 0 (typeOf e) res ""
+ return res
+
+interpret'Rec :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList V env -> Ex env t -> AcM s (Rep t)
+interpret'Rec env = \case
+ EVar _ _ i -> case slistIdx env i of V _ x -> return x
+ ELet _ a b -> do
+ x <- interpret' env a
+ let ?depth = ?depth - 1 in interpret' (V (typeOf a) x `SCons` env) b
+ expr | False && trace ("<i> " ++ takeWhile (not . isSpace) (show expr)) False -> undefined
+ EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b
+ EFst _ e -> fst <$> interpret' env e
+ ESnd _ e -> snd <$> interpret' env e
+ ENil _ -> return ()
+ EInl _ _ e -> Left <$> interpret' env e
+ EInr _ _ e -> Right <$> interpret' env e
+ ECase _ e a b ->
+ let STEither t1 t2 = typeOf e
+ in interpret' env e >>= \case
+ Left x -> interpret' (V t1 x `SCons` env) a
+ Right y -> interpret' (V t2 y `SCons` env) b
+ ENothing _ _ -> return Nothing
+ EJust _ e -> Just <$> interpret' env e
+ EMaybe _ a b e ->
+ let STMaybe t1 = typeOf e
+ in maybe (interpret' env a) (\x -> interpret' (V t1 x `SCons` env) b) =<< interpret' env e
+ ELNil _ _ _ -> return Nothing
+ ELInl _ _ e -> Just . Left <$> interpret' env e
+ ELInr _ _ e -> Just . Right <$> interpret' env e
+ ELCase _ e a b c ->
+ let STLEither t1 t2 = typeOf e
+ in interpret' env e >>= \case
+ Nothing -> interpret' env a
+ Just (Left x) -> interpret' (V t1 x `SCons` env) b
+ Just (Right y) -> interpret' (V t2 y `SCons` env) c
+ EConstArr _ _ _ v -> return v
+ EBuild _ dim a b -> do
+ sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a
+ arrayGenerateM sh (\idx -> interpret' (V (tTup (sreplicate dim tIx)) (tupRepIdx ixUncons dim idx) `SCons` env) b)
+ EMap _ a b -> do
+ let STArr _ t = typeOf b
+ arrayMapM (\x -> interpret' (V t x `SCons` env) a) =<< interpret' env b
+ EFold1Inner _ _ a b c -> do
+ let t = typeOf b
+ let f = \x -> interpret' (V (STPair t t) x `SCons` env) a
+ x0 <- interpret' env b
+ arr <- interpret' env c
+ let sh `ShCons` n = arrayShape arr
+ arrayGenerateM sh $ \idx -> foldM (curry f) x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
+ ESum1Inner _ e -> do
+ arr <- interpret' env e
+ let STArr _ (STScal t) = typeOf e
+ sh `ShCons` n = arrayShape arr
+ numericIsNum t $ return $ arrayGenerate sh $ \idx -> sum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
+ EUnit _ e -> arrayGenerateLinM ShNil (\_ -> interpret' env e)
+ EReplicate1Inner _ a b -> do
+ n <- fromIntegral @Int64 @Int <$> interpret' env a
+ arr <- interpret' env b
+ let sh = arrayShape arr
+ return $ arrayGenerate (sh `ShCons` n) (\(idx `IxCons` _) -> arrayIndex arr idx)
+ EMaximum1Inner _ e -> do
+ arr <- interpret' env e
+ let STArr _ (STScal t) = typeOf e
+ sh `ShCons` n = arrayShape arr
+ numericIsNum t $ return $
+ arrayGenerate sh (\idx -> maximum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]])
+ EMinimum1Inner _ e -> do
+ arr <- interpret' env e
+ let STArr _ (STScal t) = typeOf e
+ sh `ShCons` n = arrayShape arr
+ numericIsNum t $ return $
+ arrayGenerate sh (\idx -> minimum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]])
+ EReshape _ dim esh e -> do
+ sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env esh
+ arr <- interpret' env e
+ return $ arrayReshape sh arr
+ EZip _ a b -> do
+ arr1 <- interpret' env a
+ arr2 <- interpret' env b
+ let sh = arrayShape arr1
+ when (sh /= arrayShape arr2) $
+ error "Interpreter: mismatched shapes in EZip"
+ return $ arrayGenerateLin sh (\i -> (arr1 `arrayIndexLinear` i, arr2 `arrayIndexLinear` i))
+ EFold1InnerD1 _ _ a b c -> do
+ let t = typeOf b
+ let f = \x -> interpret' (V (STPair t t) x `SCons` env) a
+ x0 <- interpret' env b
+ arr <- interpret' env c
+ let sh `ShCons` n = arrayShape arr
+ -- TODO: this is very inefficient, even for an interpreter; with mutable
+ -- arrays this can be a lot better with no lists
+ res <- arrayGenerateM sh $ \idx -> do
+ (y, stores) <- mapAccumLM (curry f) x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
+ return (y, arrayFromList (ShNil `ShCons` n) stores)
+ return (arrayMap fst res
+ ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) ->
+ arrayIndexLinear (snd (arrayIndex res idx)) i)
+ EFold1InnerD2 _ _ ef ebog ed -> do
+ let STArr _ tB = typeOf ebog
+ STArr _ t2 = typeOf ed
+ let f = \tape ctg -> interpret' (V t2 ctg `SCons` V tB tape `SCons` env) ef
+ bog <- interpret' env ebog
+ arrctg <- interpret' env ed
+ let sh `ShCons` n = arrayShape bog
+ when (sh /= arrayShape arrctg) $ error "Interpreter: mismatched shapes in EFold1InnerD2"
+ res <- arrayGenerateM sh $ \idx -> do
+ let loop i !ctg !inpctgs | i < 0 = return (ctg, inpctgs)
+ loop i !ctg !inpctgs = do
+ let b = arrayIndex bog (idx `IxCons` i)
+ (ctg1, ctg2) <- f b ctg
+ loop (i - 1) ctg1 (ctg2 : inpctgs)
+ (x0ctg, inpctg) <- loop (n - 1) (arrayIndex arrctg idx) []
+ return (x0ctg, arrayFromList (ShNil `ShCons` n) inpctg)
+ return (arrayMap fst res
+ ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) ->
+ arrayIndexLinear (snd (arrayIndex res idx)) i)
+ EConst _ _ v -> return v
+ EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e
+ EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b)
+ EIdx _ a b ->
+ let STArr n _ = typeOf a
+ in arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b)
+ EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e
+ EOp _ op e -> interpretOp op <$> interpret' env e
+ ECustom _ t1 t2 _ pr _ _ e1 e2 -> do
+ e1' <- interpret' env e1
+ e2' <- interpret' env e2
+ interpret' (V t2 e2' `SCons` V t1 e1' `SCons` SNil) pr
+ ERecompute _ e -> interpret' env e
+ EWith _ t e1 e2 -> do
+ initval <- interpret' env e1
+ withAccum t (typeOf e2) initval $ \accum ->
+ interpret' (V (STAccum t) accum `SCons` env) e2
+ EAccum _ t p e1 sp e2 e3 -> do
+ idx <- interpret' env e1
+ val <- interpret' env e2
+ accum <- interpret' env e3
+ accumAddSparseD t p accum idx sp val
+ EZero _ t ezi -> do
+ zi <- interpret' env ezi
+ return $ zeroM t zi
+ EDeepZero _ t ezi -> do
+ zi <- interpret' env ezi
+ return $ deepZeroM t zi
+ EPlus _ t a b -> do
+ a' <- interpret' env a
+ b' <- interpret' env b
+ return $ addM t a' b'
+ EOneHot _ t p a b -> do
+ a' <- interpret' env a
+ b' <- interpret' env b
+ return $ onehotM p t a' b'
+ EError _ _ s -> error $ "Interpreter: Program threw error: " ++ s
+
+interpretOp :: SOp a t -> Rep a -> Rep t
+interpretOp op arg = case op of
+ OAdd st -> numericIsNum st $ uncurry (+) arg
+ OMul st -> numericIsNum st $ uncurry (*) arg
+ ONeg st -> numericIsNum st $ negate arg
+ OLt st -> numericIsNum st $ uncurry (<) arg
+ OLe st -> numericIsNum st $ uncurry (<=) arg
+ OEq st -> styIsEq st $ uncurry (==) arg
+ ONot -> not arg
+ OAnd -> uncurry (&&) arg
+ OOr -> uncurry (||) arg
+ OIf -> if arg then Left () else Right ()
+ ORound64 -> round arg
+ OToFl64 -> fromIntegral arg
+ ORecip st -> floatingIsFractional st $ recip arg
+ OExp st -> floatingIsFractional st $ exp arg
+ OLog st -> floatingIsFractional st $ log arg
+ OIDiv st -> integralIsIntegral st $ uncurry quot arg
+ OMod st -> integralIsIntegral st $ uncurry rem arg
+ where
+ styIsEq :: SScalTy t -> (Eq (Rep (TScal t)) => r) -> r
+ styIsEq STI32 = id
+ styIsEq STI64 = id
+ styIsEq STF32 = id
+ styIsEq STF64 = id
+ styIsEq STBool = id
+
+zeroM :: SMTy t -> Rep (ZeroInfo t) -> Rep t
+zeroM typ zi = case typ of
+ SMTNil -> ()
+ SMTPair t1 t2 -> (zeroM t1 (fst zi), zeroM t2 (snd zi))
+ SMTLEither _ _ -> Nothing
+ SMTMaybe _ -> Nothing
+ SMTArr _ t -> arrayMap (zeroM t) zi
+ SMTScal sty -> case sty of
+ STI32 -> 0
+ STI64 -> 0
+ STF32 -> 0.0
+ STF64 -> 0.0
+
+deepZeroM :: SMTy t -> Rep (DeepZeroInfo t) -> Rep t
+deepZeroM typ zi = case typ of
+ SMTNil -> ()
+ SMTPair t1 t2 -> (deepZeroM t1 (fst zi), deepZeroM t2 (snd zi))
+ SMTLEither t1 t2 -> fmap (bimap (deepZeroM t1) (deepZeroM t2)) zi
+ SMTMaybe t -> fmap (deepZeroM t) zi
+ SMTArr _ t -> arrayMap (deepZeroM t) zi
+ SMTScal sty -> case sty of
+ STI32 -> 0
+ STI64 -> 0
+ STF32 -> 0.0
+ STF64 -> 0.0
+
+addM :: SMTy t -> Rep t -> Rep t -> Rep t
+addM typ a b = case typ of
+ SMTNil -> ()
+ SMTPair t1 t2 -> (addM t1 (fst a) (fst b), addM t2 (snd a) (snd b))
+ SMTLEither t1 t2 -> case (a, b) of
+ (Nothing, _) -> b
+ (_, Nothing) -> a
+ (Just (Left x), Just (Left y)) -> Just (Left (addM t1 x y))
+ (Just (Right x), Just (Right y)) -> Just (Right (addM t2 x y))
+ _ -> error "Plus of inconsistent LEithers"
+ SMTMaybe t -> case (a, b) of
+ (Nothing, _) -> b
+ (_, Nothing) -> a
+ (Just x, Just y) -> Just (addM t x y)
+ SMTArr _ t ->
+ let sh1 = arrayShape a
+ sh2 = arrayShape b
+ in if | shapeSize sh1 == 0 -> b
+ | shapeSize sh2 == 0 -> a
+ | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> addM t (arrayIndexLinear a i) (arrayIndexLinear b i))
+ | otherwise -> error "Plus of inconsistently shaped arrays"
+ SMTScal sty -> numericIsNum sty $ a + b
+
+onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdxS p a) -> Rep b -> Rep a
+onehotM SAPHere _ _ val = val
+onehotM (SAPFst prj) (SMTPair a b) idx val = (onehotM prj a (fst idx) val, zeroM b (snd idx))
+onehotM (SAPSnd prj) (SMTPair a b) idx val = (zeroM a (fst idx), onehotM prj b (snd idx) val)
+onehotM (SAPLeft prj) (SMTLEither a _) idx val = Just (Left (onehotM prj a idx val))
+onehotM (SAPRight prj) (SMTLEither _ b) idx val = Just (Right (onehotM prj b idx val))
+onehotM (SAPJust prj) (SMTMaybe a) idx val = Just (onehotM prj a idx val)
+onehotM (SAPArrIdx prj) (SMTArr n a) idx val =
+ runIdentity $ onehotArray (\idx' -> Identity (onehotM prj a idx' val)) (\zi -> Identity (zeroM a zi)) n prj idx
+
+withAccum :: SMTy t -> STy a -> Rep t -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t)
+withAccum t _ initval f = AcM $ do
+ accum <- newAcDense t initval
+ out <- unAcM $ f accum
+ val <- readAc t accum
+ return (out, val)
+
+newAcDense :: SMTy a -> Rep a -> IO (RepAc a)
+newAcDense typ val = case typ of
+ SMTNil -> return ()
+ SMTPair t1 t2 -> bitraverse (newAcDense t1) (newAcDense t2) val
+ SMTLEither t1 t2 -> newIORef =<< traverse (bitraverse (newAcDense t1) (newAcDense t2)) val
+ SMTMaybe t1 -> newIORef =<< traverse (newAcDense t1) val
+ SMTArr _ t1 -> arrayMapM (newAcDense t1) val
+ SMTScal _ -> newIORef val
+
+onehotArray :: Monad m
+ => (Rep (AcIdxS p a) -> m v) -- ^ the "one"
+ -> (Rep (ZeroInfo a) -> m v) -- ^ the "zero"
+ -> SNat n -> SAcPrj p a b -> Rep (AcIdxS (APArrIdx p) (TArr n a)) -> m (Array n v)
+onehotArray mkone mkzero n _ ((arrindex', ziarr), idx) =
+ let arrindex = unTupRepIdx IxNil IxCons n arrindex'
+ arrsh = arrayShape ziarr
+ !linindex = toLinearIndex arrsh arrindex
+ in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero (ziarr `arrayIndexLinear` i))
+
+readAc :: SMTy t -> RepAc t -> IO (Rep t)
+readAc typ val = case typ of
+ SMTNil -> return ()
+ SMTPair t1 t2 -> bitraverse (readAc t1) (readAc t2) val
+ SMTLEither t1 t2 -> traverse (bitraverse (readAc t1) (readAc t2)) =<< readIORef val
+ SMTMaybe t -> traverse (readAc t) =<< readIORef val
+ SMTArr _ t -> traverse (readAc t) val
+ SMTScal _ -> readIORef val
+
+accumAddSparseD :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxD p a) -> Sparse b c -> Rep c -> AcM s ()
+accumAddSparseD typ prj ref idx sp val = case (typ, prj) of
+ (_, SAPHere) -> accumAddDense typ ref sp val
+
+ (SMTPair t1 _, SAPFst prj') -> accumAddSparseD t1 prj' (fst ref) idx sp val
+ (SMTPair _ t2, SAPSnd prj') -> accumAddSparseD t2 prj' (snd ref) idx sp val
+
+ (SMTLEither t1 _, SAPLeft prj') ->
+ realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)")
+ (\case Left ac1 -> accumAddSparseD t1 prj' ac1 idx sp val
+ Right{} -> error "Mismatched Either in accumAddSparseD (r +l)")
+ (SMTLEither _ t2, SAPRight prj') ->
+ realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)")
+ (\case Right ac2 -> accumAddSparseD t2 prj' ac2 idx sp val
+ Left{} -> error "Mismatched Either in accumAddSparseD (l +r)")
+
+ (SMTMaybe t1, SAPJust prj') ->
+ realiseMaybeSparse ref (error "Accumulating Just into Nothing (EWith requires EDeepZero)")
+ (\ac -> accumAddSparseD t1 prj' ac idx sp val)
+
+ (SMTArr n t1, SAPArrIdx prj') ->
+ let (arrindex', idx') = idx
+ arrindex = unTupRepIdx IxNil IxCons n arrindex'
+ arrsh = arrayShape ref
+ linindex = toLinearIndex arrsh arrindex
+ in accumAddSparseD t1 prj' (arrayIndexLinear ref linindex) idx' sp val
+
+accumAddDense :: SMTy a -> RepAc a -> Sparse a b -> Rep b -> AcM s ()
+accumAddDense typ ref sp val = case (typ, sp) of
+ (_, _) | isAbsent sp -> return ()
+ (_, SpAbsent) -> return ()
+ (_, SpSparse s) ->
+ case val of
+ Nothing -> return ()
+ Just val' -> accumAddDense typ ref s val'
+ (SMTPair t1 t2, SpPair s1 s2) -> do
+ accumAddDense t1 (fst ref) s1 (fst val)
+ accumAddDense t2 (snd ref) s2 (snd val)
+ (SMTLEither t1 t2, SpLEither s1 s2) ->
+ case val of
+ Nothing -> return ()
+ Just (Left val1) ->
+ realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)")
+ (\case Left ac1 -> accumAddDense t1 ac1 s1 val1
+ Right{} -> error "Mismatched Either in accumAddSparse (r +l)")
+ Just (Right val2) ->
+ realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)")
+ (\case Right ac2 -> accumAddDense t2 ac2 s2 val2
+ Left{} -> error "Mismatched Either in accumAddSparse (l +r)")
+ (SMTMaybe t, SpMaybe s) ->
+ case val of
+ Nothing -> return ()
+ Just val' ->
+ realiseMaybeSparse ref (error "Accumulating Just into Nothing (EAccum requires EDeepZero)")
+ (\ac -> accumAddDense t ac s val')
+ (SMTArr _ t1, SpArr s) ->
+ forM_ [0 .. arraySize ref - 1] $ \i ->
+ accumAddDense t1 (arrayIndexLinear ref i) s (arrayIndexLinear val i)
+ (SMTScal sty, SpScal) -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ()))
+
+-- TODO: makeval is always 'error' now. Simplify?
+realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s ()
+realiseMaybeSparse ref makeval modifyval =
+ -- Try modifying what's already in ref. The 'join' makes the snd
+ -- of the function's return value a _continuation_ that is run after
+ -- the critical section ends.
+ AcM $ join $ atomicModifyIORef' ref $ \ac -> case ac of
+ -- Oops, ref's contents was still sparse. Have to initialise
+ -- it first, then try again.
+ Nothing -> (ac, do val <- makeval
+ join $ atomicModifyIORef' ref $ \ac' -> case ac' of
+ Nothing -> (Just val, return ())
+ Just val' -> (ac', unAcM $ modifyval val'))
+ -- Yep, ref already had a value in there, so we can just add
+ -- val' to it recursively.
+ Just val -> (ac, unAcM $ modifyval val)
+
+
+numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r
+numericIsNum STI32 = id
+numericIsNum STI64 = id
+numericIsNum STF32 = id
+numericIsNum STF64 = id
+
+floatingIsFractional :: ScalIsFloating st ~ True => SScalTy st -> ((Floating (ScalRep st), Ord (ScalRep st), ScalIsNumeric st ~ True, ScalIsFloating st ~ True) => r) -> r
+floatingIsFractional STF32 = id
+floatingIsFractional STF64 = id
+
+integralIsIntegral :: ScalIsIntegral st ~ True => SScalTy st -> ((Integral (ScalRep st), Ord (ScalRep st), ScalIsNumeric st ~ True, ScalIsIntegral st ~ True) => r) -> r
+integralIsIntegral STI32 = id
+integralIsIntegral STI64 = id
+
+unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m))
+ -> SNat n -> Rep (Tup (Replicate n TIx)) -> f n
+unTupRepIdx nil _ SZ _ = nil
+unTupRepIdx nil cons (SS n) (idx, i) = unTupRepIdx nil cons n idx `cons` fromIntegral @Int64 @Int i
+
+tupRepIdx :: (forall m. f (S m) -> (f m, Int))
+ -> SNat n -> f n -> Rep (Tup (Replicate n TIx))
+tupRepIdx _ SZ _ = ()
+tupRepIdx uncons (SS n) tup =
+ let (tup', i) = uncons tup
+ in ((,) $! tupRepIdx uncons n tup') $! fromIntegral @Int @Int64 i
+
+ixUncons :: Index (S n) -> (Index n, Int)
+ixUncons (IxCons idx i) = (idx, i)
+
+shUncons :: Shape (S n) -> (Shape n, Int)
+shUncons (ShCons idx i) = (idx, i)
+
+mapAccumLM :: (Traversable t, Monad m) => (s -> a -> m (s, b)) -> s -> t a -> m (s, t b)
+mapAccumLM f s0 = fmap swap . flip runStateT s0 . traverse f'
+ where f' x = do
+ s <- get
+ (s', y) <- lift $ f s x
+ put s'
+ return y