{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed.Shape where

import Control.DeepSeq (NFData(..))
import Data.Bifunctor (first)
import Data.Coerce
import Data.Foldable qualified as Foldable
import Data.Functor.Const
import Data.Kind (Type, Constraint)
import Data.Monoid (Sum(..))
import Data.Proxy
import Data.Type.Equality
import GHC.Exts (withDict)
import GHC.Generics (Generic)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits

import Data.Array.Mixed.Types


-- | The length of a type-level list. If the argument is a shape, then the
-- result is the rank of that shape.
--
-- Defined by recursion on the list: the empty list has length 0 and every
-- cons cell adds 1.
type family Rank sh where
  Rank '[] = 0
  Rank (_ : sh) = Rank sh + 1


-- * Mixed lists

-- | A list indexed by a type-level list @sh@ of @Maybe Nat@s: the element
-- stored at each position has type @f n@, where @n@ is the entry of @sh@ at
-- that position.
type role ListX nominal representational
type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type
data ListX sh f where
  -- | The empty list.
  ZX :: ListX '[] f
  -- | Prepend one element to a list.
  (::%) :: f n -> ListX sh f -> ListX (n : sh) f
-- Equality and ordering need the element type to support them at every index.
deriving instance (forall n. Eq (f n)) => Eq (ListX sh f)
deriving instance (forall n. Ord (f n)) => Ord (ListX sh f)
infixr 3 ::%

-- | Rendered via 'listxShow' using each element's own 'Show' instance; the
-- precedence argument is ignored.
instance (forall n. Show (f n)) => Show (ListX sh f) where
  showsPrec _prec list = listxShow shows list

-- | Fully evaluates every element, front to back.
instance (forall n. NFData (f n)) => NFData (ListX sh f) where
  rnf list = case list of
    ZX -> ()
    hd ::% tl -> rnf hd `seq` rnf tl

-- | Result of successfully unconsing a 'ListX': the tail and the head,
-- packaged together with the evidence that the index list @sh1@ of the
-- original list is @n : sh@.
data UnconsListXRes f sh1 =
  forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n)
-- | Split a 'ListX' into its head and tail; 'Nothing' for the empty list.
listxUncons :: ListX sh1 f -> Maybe (UnconsListXRes f sh1)
listxUncons list = case list of
  ZX -> Nothing
  hd ::% tl -> Just (UnconsListXRes tl hd)

-- | Compares only the type-level indices of the two lists. If the elements
-- are not singletons, their values may still differ even when this succeeds.
-- This is the analogue of 'testEquality', but acting on the penultimate type
-- parameter.
listxEqType :: TestEquality f => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh')
listxEqType ZX ZX = Just Refl
listxEqType (x ::% xs) (y ::% ys) = do
  -- Both the heads and the tails must agree on their indices.
  Refl <- testEquality x y
  Refl <- listxEqType xs ys
  pure Refl
listxEqType _ _ = Nothing

-- | Checks that the two lists hold equal values, not merely equal type
-- indices. This is stronger than 'testEquality' and corresponds to @geq@ from
-- @Data.GADT.Compare@ in the @some@ package (except that it acts on the
-- penultimate type parameter).
listxEqual :: (TestEquality f, forall n. Eq (f n)) => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh')
listxEqual ZX ZX = Just Refl
listxEqual (x ::% xs) (y ::% ys) = do
  -- The value comparison only typechecks once the indices are known equal.
  Refl <- testEquality x y
  if x == y
    then do
      Refl <- listxEqual xs ys
      pure Refl
    else Nothing
listxEqual _ _ = Nothing

-- | Apply an index-preserving function to every element of the list.
listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g
listxFmap f list = case list of
  ZX -> ZX
  hd ::% tl -> f hd ::% listxFmap f tl

-- | Map every element to a monoid and combine the results from left to right
-- (nested to the right).
listxFold :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m
listxFold f list = case list of
  ZX -> mempty
  hd ::% tl -> f hd <> listxFold f tl

-- | The number of elements in the list, computed by counting 1 per element.
listxLength :: ListX sh f -> Int
listxLength list = getSum (listxFold (const (Sum 1)) list)

-- | The length of the list as a singleton natural, i.e. the 'Rank' of its
-- index list.
listxRank :: ListX sh f -> SNat (Rank sh)
listxRank ZX = SNat
-- Matching the recursive result against 'SNat' brings the evidence for the
-- tail's rank into scope; the evidence for the rank plus one is presumably
-- derived by the KnownNat solver plugin enabled at the top of this file.
listxRank (_ ::% l) | SNat <- listxRank l = SNat

-- | Render the list as @[a,b,c]@, using the given function to show each
-- element.
listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS
listxShow showElem list = showString "[" . items "" list . showString "]"
  where
    -- The first argument is the separator to emit before the next element:
    -- empty for the first element, a comma afterwards.
    items :: String -> ListX sh' f -> ShowS
    items sep rest = case rest of
      ZX -> id
      x ::% xs -> showString sep . showElem x . items "," xs

-- | Build a 'ListX' of the length dictated by the static shape from a plain
-- list. Calls 'error' if the list does not have exactly that length.
listxFromList :: StaticShX sh -> [i] -> ListX sh (Const i)
listxFromList fullShape fullList = build fullShape fullList
  where
    build :: StaticShX sh' -> [i] -> ListX sh' (Const i)
    build ZKX [] = ZX
    build (_ :!% ssh) (x : xs) = Const x ::% build ssh xs
    -- Either the shape or the list ran out first; report the original lengths.
    build _ _ = error (concat
      [ "listxFromList: Mismatched list length (type says "
      , show (ssxLength fullShape)
      , ", list has length "
      , show (length fullList)
      , ")"
      ])

-- | Forget the type-level length and return the elements as a plain list.
listxToList :: ListX sh' (Const i) -> [i]
listxToList = listxFold (\(Const x) -> [x])

-- | Drop the first element of a list that is known to be non-empty.
listxTail :: ListX (n : sh) i -> ListX sh i
listxTail list = case list of
  _ ::% rest -> rest

-- | Concatenate two lists; the index lists are concatenated accordingly.
listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f
listxAppend front back = case front of
  ZX -> back
  x ::% front' -> x ::% listxAppend front' back

-- | Drop a prefix from the first list. The second list only determines how
-- many elements to drop; its element values are ignored.
listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f
listxDrop full prefix = case prefix of
  ZX -> full
  _ ::% prefix' -> case full of
    _ ::% full' -> listxDrop full' prefix'

-- | All elements of a non-empty list except the last one.
listxInit :: forall f n sh. ListX (n : sh) f -> ListX (Init (n : sh)) f
listxInit (x ::% rest) = case rest of
  ZX -> ZX
  _ ::% _ -> x ::% listxInit rest

-- | The last element of a non-empty list.
listxLast :: forall f n sh. ListX (n : sh) f -> f (Last (n : sh))
listxLast (x ::% rest) = case rest of
  ZX -> x
  _ ::% _ -> listxLast rest


-- * Mixed indices

-- | This is a newtype over 'ListX'.
--
-- Every component has the same type @i@ (stored as @'Const' i@), regardless
-- of the corresponding entry of the index list @sh@.
type role IxX nominal representational
type IxX :: [Maybe Nat] -> Type -> Type
newtype IxX sh i = IxX (ListX sh (Const i))
  deriving (Eq, Ord, Generic)

-- | The empty index. Matching on it reveals that @sh ~ '[]@.
pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i
pattern ZIX = IxX ZX

-- | Cons for 'IxX', as a bidirectional pattern synonym. Matching splits off
-- the first component and reveals that the index list has the form @n : sh@;
-- construction prepends a component.
pattern (:.%)
  :: forall {sh1} {i}.
     forall n sh. (n : sh ~ sh1)
  => i -> IxX sh i -> IxX sh1 i
-- Matching goes through 'listxUncons' as a view pattern; the tail is
-- rewrapped in 'IxX' and the head is unwrapped from its 'Const'.
pattern i :.% shl <- IxX (listxUncons -> Just (UnconsListXRes (IxX -> shl) (getConst -> i)))
  where i :.% IxX shl = IxX (Const i ::% shl)
infixr 3 :.%

-- 'ZIX' and '(:.%)' together cover every 'IxX'.
{-# COMPLETE ZIX, (:.%) #-}

-- | An 'IxX' whose components are 'Int's.
type IIxX sh = IxX sh Int

-- | Rendered as a bracketed, comma-separated list of the components; the
-- precedence argument is ignored.
instance Show i => Show (IxX sh i) where
  showsPrec _prec (IxX comps) = listxShow (shows . getConst) comps

-- | Maps over the components, leaving the length unchanged.
instance Functor (IxX sh) where
  fmap f (IxX comps) = IxX (listxFmap (\(Const x) -> Const (f x)) comps)

-- | Folds over the components from first to last.
instance Foldable (IxX sh) where
  foldMap f (IxX comps) = listxFold (\(Const x) -> f x) comps

-- No methods are given, so the class default for 'rnf' is used. 'IxX' derives
-- 'Generic', so this is presumably the Generic-based default from deepseq.
instance NFData i => NFData (IxX sh i)

-- | The all-zeros index with one component per dimension of the static shape.
ixxZero :: StaticShX sh -> IIxX sh
ixxZero ssh = case ssh of
  ZKX -> ZIX
  _ :!% rest -> 0 :.% ixxZero rest

-- | Like 'ixxZero', but takes a full shape instead of a static shape. The
-- dimension sizes are ignored.
ixxZero' :: IShX sh -> IIxX sh
ixxZero' sh = case sh of
  ZSX -> ZIX
  _ :$% rest -> 0 :.% ixxZero' rest

-- | Build an index from a plain list; see 'listxFromList' for the length
-- check that is performed.
ixxFromList :: forall sh i. StaticShX sh -> [i] -> IxX sh i
ixxFromList ssh comps = IxX (listxFromList ssh comps)

-- | Drop the first component of a non-empty index.
ixxTail :: IxX (n : sh) i -> IxX sh i
ixxTail (IxX (_ ::% rest)) = IxX rest

-- | Concatenate two indices.
ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i
ixxAppend (IxX front) (IxX back) = IxX (listxAppend front back)

-- | Drop a prefix from an index. The second argument only determines the
-- length of the prefix; its component values are ignored.
ixxDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i
ixxDrop (IxX full) (IxX prefix) = IxX (listxDrop full prefix)

-- | All components of a non-empty index except the last.
ixxInit :: forall n sh i. IxX (n : sh) i -> IxX (Init (n : sh)) i
ixxInit (IxX comps) = IxX (listxInit comps)

-- | The last component of a non-empty index.
ixxLast :: forall n sh i. IxX (n : sh) i -> i
ixxLast (IxX comps) = getConst (listxLast comps)

-- | Convert a linear index into a multi-dimensional index for an array of the
-- given shape. The last dimension varies fastest, matching 'ixxToLinear'.
--
-- Calls 'error' if the linear index is out of range, i.e. negative or not
-- smaller than the number of elements of the shape. In particular, every
-- index is out of range for a shape that has a dimension of size 0.
ixxFromLinear :: IShX sh -> Int -> IIxX sh
ixxFromLinear = \sh i ->
  -- Check the range up front: 'quotRem' truncates towards zero, so without
  -- this a small negative index would slip through with a quotient of 0, and
  -- a dimension of size 0 would raise an uninformative divide-by-zero.
  if i < 0 || i >= shxSize sh
    then outOfRange sh i
    else case go sh i of
      (idx, 0) -> idx
      _ -> outOfRange sh i
  where
    outOfRange :: IShX sh' -> Int -> a
    outOfRange sh i =
      error $ "ixxFromLinear: out of range (" ++ show i ++
              " in array of shape " ++ show sh ++ ")"

    -- returns (index in subarray, remaining index in enclosing array)
    go :: IShX sh -> Int -> (IIxX sh, Int)
    go ZSX i = (ZIX, i)
    go (n :$% sh) i =
      let (idx, i') = go sh i
          (upi, locali) = i' `quotRem` fromSMayNat' n
      in (locali :.% idx, upi)

-- | Convert a multi-dimensional index into a linear index for an array of the
-- given shape; the last dimension varies fastest. No bounds checking is done.
ixxToLinear :: IShX sh -> IIxX sh -> Int
ixxToLinear shape index =
  let (linear, _) = walk shape index
  in linear
  where
    -- yields (linear offset within this subarray, element count of this subarray)
    walk :: IShX sh -> IIxX sh -> (Int, Int)
    walk ZSX ZIX = (0, 1)
    walk (dim :$% dims) (i :.% is) =
      let (inner, innerSize) = walk dims is
      in (innerSize * i + inner, fromSMayNat' dim * innerSize)


-- * Mixed shapes

-- | A value indexed by a @Maybe Nat@: for index @Nothing@ it carries a
-- runtime value of type @i@, for index @Just n@ it carries an @f n@.
data SMayNat i f n where
  SUnknown :: i -> SMayNat i f Nothing
  SKnown :: f n -> SMayNat i f (Just n)
deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n)
deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n)
deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n)

-- | Fully evaluates whichever payload is present.
instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where
  rnf smn = case smn of
    SKnown x -> rnf x
    SUnknown i -> rnf i

-- | Two unknowns always have equal indices (their runtime values are not
-- compared); two knowns are compared through @f@; mixed cases are unequal.
instance TestEquality f => TestEquality (SMayNat i f) where
  testEquality lhs rhs = case (lhs, rhs) of
    (SUnknown _, SUnknown _) -> Just Refl
    (SKnown n, SKnown m) -> do
      Refl <- testEquality n m
      pure Refl
    _ -> Nothing

-- | Eliminator for 'SMayNat': the first function handles the unknown case,
-- the second the known case. Each handler may use what its case reveals
-- about the index @n@.
fromSMayNat :: (n ~ Nothing => i -> r)
            -> (forall m. n ~ Just m => f m -> r)
            -> SMayNat i f n -> r
fromSMayNat onUnknown onKnown smn = case smn of
  SUnknown i -> onUnknown i
  SKnown s -> onKnown s

-- | The runtime value of a dimension, whether statically known or not.
fromSMayNat' :: SMayNat Int SNat n -> Int
fromSMayNat' smn = case smn of
  SUnknown i -> i
  SKnown sn -> fromSNat' sn

-- | Addition of two type-level @Maybe Nat@s: the result is @Nothing@ as soon
-- as either argument is @Nothing@, and the sum otherwise.
type family AddMaybe n m where
  AddMaybe Nothing _ = Nothing
  AddMaybe (Just _) Nothing = Nothing
  AddMaybe (Just n) (Just m) = Just (n + m)

-- | Value-level counterpart of 'AddMaybe': the sum is statically known only
-- if both summands are.
smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m)
smnAddMaybe lhs rhs = case lhs of
  SUnknown n -> SUnknown (n + fromSMayNat' rhs)
  SKnown n -> case rhs of
    SUnknown m -> SUnknown (fromSNat' n + m)
    SKnown m -> SKnown (snatPlus n m)


-- | This is a newtype over 'ListX'.
--
-- Each dimension is an @'SMayNat' i 'SNat'@: a dimension whose entry in @sh@
-- is @Just n@ is stored as an @SNat n@, and one whose entry is @Nothing@ is
-- stored as a runtime value of type @i@.
type role ShX nominal representational
type ShX :: [Maybe Nat] -> Type -> Type
newtype ShX sh i = ShX (ListX sh (SMayNat i SNat))
  deriving (Eq, Ord, Generic)

-- | The empty shape. Matching on it reveals that @sh ~ '[]@.
pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i
pattern ZSX = ShX ZX

-- | Cons for 'ShX', as a bidirectional pattern synonym. Matching splits off
-- the first dimension and reveals that the index list has the form @n : sh@;
-- construction prepends a dimension.
pattern (:$%)
  :: forall {sh1} {i}.
     forall n sh. (n : sh ~ sh1)
  => SMayNat i SNat n -> ShX sh i -> ShX sh1 i
-- Matching goes through 'listxUncons' as a view pattern; the tail is
-- rewrapped in 'ShX'.
pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i))
  where i :$% ShX shl = ShX (i ::% shl)
infixr 3 :$%

-- 'ZSX' and '(:$%)' together cover every 'ShX'.
{-# COMPLETE ZSX, (:$%) #-}

-- | A 'ShX' whose unknown dimensions are stored as 'Int's.
type IShX sh = ShX sh Int

-- | Rendered as a bracketed, comma-separated list of the dimension sizes; the
-- precedence argument is ignored.
instance Show i => Show (ShX sh i) where
  showsPrec _prec (ShX dims) = listxShow showDim dims
    where
      -- Unknown dimensions show their runtime value, known ones their
      -- type-level size.
      showDim :: SMayNat i SNat n -> ShowS
      showDim (SUnknown i) = shows i
      showDim (SKnown sn) = shows (fromSNat sn)

-- | Maps over the runtime values of the unknown dimensions; known dimensions
-- are left untouched.
instance Functor (ShX sh) where
  fmap f (ShX dims) =
    ShX (listxFmap (fromSMayNat (\i -> SUnknown (f i)) (\sn -> SKnown sn)) dims)

-- | Fully evaluates the runtime value of every unknown dimension, and forces
-- the singleton of every known dimension.
instance NFData i => NFData (ShX sh i) where
  rnf (ShX dims) = go dims
    where
      go :: ListX sh' (SMayNat i SNat) -> ()
      go ZX = ()
      go (SUnknown i ::% rest) = rnf i `seq` go rest
      go (SKnown SNat ::% rest) = go rest

-- | The number of dimensions of the shape.
shxLength :: ShX sh i -> Int
shxLength sh = case sh of
  ShX dims -> listxLength dims

-- | The number of dimensions of the shape, as a singleton natural.
shxRank :: ShX sh i -> SNat (Rank sh)
shxRank sh = case sh of
  ShX dims -> listxRank dims

-- | Compares only the type-level shapes: known dimensions must have the same
-- size, while unknown dimensions are accepted regardless of their runtime
-- values. This is the analogue of 'testEquality', but acting on the
-- penultimate type parameter.
shxEqType :: ShX sh i -> ShX sh' i -> Maybe (sh :~: sh')
shxEqType ZSX ZSX = Just Refl
shxEqType (SKnown n@SNat :$% rest) (SKnown m@SNat :$% rest') = do
  Refl <- sameNat n m
  Refl <- shxEqType rest rest'
  pure Refl
shxEqType (SUnknown _ :$% rest) (SUnknown _ :$% rest') = do
  Refl <- shxEqType rest rest'
  pure Refl
shxEqType _ _ = Nothing

-- | Checks that every dimension has the same value in both shapes, including
-- the runtime values of unknown dimensions. This is stronger than
-- 'testEquality' and corresponds to @geq@ from @Data.GADT.Compare@ in the
-- @some@ package (except that it acts on the penultimate type parameter).
shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh')
shxEqual ZSX ZSX = Just Refl
shxEqual (SKnown n@SNat :$% rest) (SKnown m@SNat :$% rest') = do
  Refl <- sameNat n m
  Refl <- shxEqual rest rest'
  pure Refl
shxEqual (SUnknown i :$% rest) (SUnknown j :$% rest') =
  if i == j
    then do
      Refl <- shxEqual rest rest'
      pure Refl
    else Nothing
shxEqual _ _ = Nothing

-- | The total number of elements of an array with this shape, i.e. the
-- product of all dimension sizes (1 for the empty shape).
shxSize :: IShX sh -> Int
shxSize sh = product (shxToList sh)

-- | Build a shape from a plain list of dimension sizes. Calls 'error' if the
-- list length does not match the static shape, or if a size in the list
-- disagrees with a statically known dimension.
shxFromList :: StaticShX sh -> [Int] -> ShX sh Int
shxFromList fullShape fullList = build fullShape fullList
  where
    build :: StaticShX sh' -> [Int] -> ShX sh' Int
    build ZKX [] = ZSX
    build (SUnknown () :!% ssh) (d : ds) = SUnknown d :$% build ssh ds
    build (SKnown sn :!% ssh) (d : ds) =
      if d == fromSNat' sn
        then SKnown sn :$% build ssh ds
        else error (concat
          [ "shxFromList: Value does not match typing (type says "
          , show (fromSNat' sn)
          , ", list contains "
          , show d
          , ")"
          ])
    -- Either the shape or the list ran out first; report the original lengths.
    build _ _ = error (concat
      [ "shxFromList: Mismatched list length (type says "
      , show (ssxLength fullShape)
      , ", list has length "
      , show (length fullList)
      , ")"
      ])

-- | The dimension sizes of the shape as a plain list.
shxToList :: IShX sh -> [Int]
shxToList sh = case sh of
  ZSX -> []
  dim :$% rest -> fromSMayNat' dim : shxToList rest

-- | Concatenate two shapes.
shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i
shxAppend (ShX front) (ShX back) = ShX (listxAppend front back)

-- | Drop the first dimension of a non-empty shape.
shxTail :: ShX (n : sh) i -> ShX sh i
shxTail (ShX (_ ::% rest)) = ShX rest

-- | Drop a prefix from a shape; the static shape determines how many
-- dimensions to drop.
shxDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i
shxDropSSX (ShX full) (StaticShX prefix) = ShX (listxDrop full prefix)

-- | Drop a prefix from a shape; the index determines how many dimensions to
-- drop, and its component values are ignored.
shxDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i
shxDropIx (ShX full) (IxX prefix) = ShX (listxDrop full prefix)

-- | Drop a prefix from a shape; the second shape determines how many
-- dimensions to drop, and its dimension sizes are ignored.
shxDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i
shxDropSh (ShX full) (ShX prefix) = ShX (listxDrop full prefix)

-- | All dimensions of a non-empty shape except the last.
shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i
shxInit (ShX dims) = ShX (listxInit dims)

-- | The last dimension of a non-empty shape.
shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh))
shxLast (ShX dims) = listxLast dims

-- | Keep only a prefix of a shape; the static shape determines how many
-- dimensions to keep. The proxy only fixes the type of the discarded suffix.
shxTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i
shxTakeSSX _ full prefix = takePrefix prefix full
  where
    takePrefix :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i
    takePrefix ZKX _ = ZSX
    takePrefix (_ :!% ssh1) (dim :$% rest) = dim :$% takePrefix ssh1 rest

-- | Turn a static shape into a full shape by giving every unknown dimension
-- the size 0. (This is a weird operation, so it has a long name.)
shxCompleteZeros :: StaticShX sh -> IShX sh
shxCompleteZeros ssh = case ssh of
  ZKX -> ZSX
  SKnown n :!% rest -> SKnown n :$% shxCompleteZeros rest
  SUnknown () :!% rest -> SUnknown 0 :$% shxCompleteZeros rest

-- | Split a shape into a prefix and the remaining suffix; the static shape
-- determines the length of the prefix.
shxSplitApp :: Proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i)
shxSplitApp _ ZKX whole = (ZSX, whole)
shxSplitApp p (_ :!% ssh) (dim :$% rest) =
  let (prefix, suffix) = shxSplitApp p ssh rest
  in (dim :$% prefix, suffix)

-- | All valid indices into an array of the given shape, enumerated with the
-- last dimension varying fastest.
shxEnum :: IShX sh -> [IIxX sh]
shxEnum shape = enumerate shape id []
  where
    -- The function argument completes a partial index built from the
    -- dimensions handled so far; the result prepends to a list of results.
    enumerate :: IShX sh -> (IIxX sh -> a) -> [a] -> [a]
    enumerate ZSX k = (k ZIX :)
    enumerate (dim :$% dims) k =
      foldr (.) id
        (map (\i -> enumerate dims (k . (i :.%))) [0 .. fromSMayNat' dim - 1])

-- | Re-type a shape according to a different static shape. This succeeds only
-- if both have the same number of dimensions and, wherever the target has a
-- statically known dimension, the actual size of the source dimension is
-- equal to it. Where the target dimension is unknown, the source size is
-- kept as a runtime value.
shxCast :: IShX sh -> StaticShX sh' -> Maybe (IShX sh')
shxCast ZSX ZKX = Just ZSX
shxCast (dim :$% sh) (sdim :!% ssh) = case (dim, sdim) of
  (SKnown n, SKnown m)
    | Just Refl <- testEquality n m -> (SKnown n :$%) <$> shxCast sh ssh
  (SUnknown n, SKnown m)
    | n == fromSNat' m -> (SKnown m :$%) <$> shxCast sh ssh
  (SKnown n, SUnknown ()) -> (SUnknown (fromSNat' n) :$%) <$> shxCast sh ssh
  (SUnknown n, SUnknown ()) -> (SUnknown n :$%) <$> shxCast sh ssh
  -- A known target dimension whose size differs from the source.
  _ -> Nothing
shxCast _ _ = Nothing

-- | Partial version of 'shxCast': calls 'error' where 'shxCast' would return
-- 'Nothing'.
shxCast' :: IShX sh -> StaticShX sh' -> IShX sh'
shxCast' sh ssh = maybe mismatch id (shxCast sh ssh)
  where
    mismatch = error $ "shxCast': Mismatch: (" ++ show sh ++ ") does not match (" ++ show ssh ++ ")"


-- * Static mixed shapes

-- | The part of a shape that is statically known. (A newtype over 'ListX'.)
--
-- Known dimensions are stored as an 'SNat'; unknown dimensions carry only a
-- unit value.
type StaticShX :: [Maybe Nat] -> Type
newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat))
  deriving (Eq, Ord)

-- | The empty static shape. Matching on it reveals that @sh ~ '[]@.
pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh
pattern ZKX = StaticShX ZX

-- | Cons for 'StaticShX', as a bidirectional pattern synonym. Matching splits
-- off the first dimension and reveals that the index list has the form
-- @n : sh@; construction prepends a dimension.
pattern (:!%)
  :: forall {sh1}.
     forall n sh. (n : sh ~ sh1)
  => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1
-- Matching goes through 'listxUncons' as a view pattern; the tail is
-- rewrapped in 'StaticShX'.
pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i))
  where i :!% StaticShX shl = StaticShX (i ::% shl)
infixr 3 :!%

-- 'ZKX' and '(:!%)' together cover every 'StaticShX'.
{-# COMPLETE ZKX, (:!%) #-}

-- | Rendered as a bracketed, comma-separated list: known dimensions show
-- their size, unknown dimensions show as @()@. The precedence argument is
-- ignored.
instance Show (StaticShX sh) where
  showsPrec _prec (StaticShX dims) = listxShow showDim dims
    where
      showDim :: SMayNat () SNat n -> ShowS
      showDim (SUnknown u) = shows u
      showDim (SKnown sn) = shows (fromSNat sn)

-- | Two static shapes are equal exactly when their underlying lists have
-- equal type indices (see 'listxEqType').
instance TestEquality StaticShX where
  testEquality lhs rhs = case (lhs, rhs) of
    (StaticShX dims, StaticShX dims') -> listxEqType dims dims'

-- | The number of dimensions of the static shape.
ssxLength :: StaticShX sh -> Int
ssxLength ssh = case ssh of
  StaticShX dims -> listxLength dims

-- | The same as 'testEquality' on 'StaticShX'; this name exists for
-- consistency with the other @*EqType@ functions.
ssxEqType :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh')
ssxEqType ssh ssh' = testEquality ssh ssh'

-- | Concatenate two static shapes.
ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
ssxAppend (StaticShX front) (StaticShX back) = StaticShX (listxAppend front back)

-- | Drop the first dimension of a non-empty static shape.
ssxTail :: StaticShX (n : sh) -> StaticShX sh
ssxTail (StaticShX (_ ::% rest)) = StaticShX rest

-- | Drop a prefix from a static shape; the index determines how many
-- dimensions to drop, and its component values are ignored.
ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh'
ssxDropIx (StaticShX full) (IxX prefix) = StaticShX (listxDrop full prefix)

-- | All dimensions of a non-empty static shape except the last.
ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh))
ssxInit (StaticShX dims) = StaticShX (listxInit dims)

-- | The last dimension of a non-empty static shape.
ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () SNat (Last (n : sh))
ssxLast (StaticShX dims) = listxLast dims

-- | Convert a static shape to a full shape. This is only possible if every
-- dimension is statically known; the result is 'Nothing' if @sh@ contains a
-- @Nothing@.
ssxToShX' :: StaticShX sh -> Maybe (IShX sh)
ssxToShX' ssh = case ssh of
  ZKX -> Just ZSX
  SUnknown _ :!% _ -> Nothing
  SKnown n :!% rest -> fmap (SKnown n :$%) (ssxToShX' rest)

-- | The static shape with the given number of dimensions, all of them
-- unknown.
ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing)
ssxReplicate SZ = ZKX
ssxReplicate (SS (n :: SNat n'))
  -- The lemma (defined elsewhere) supplies the type equality needed for the
  -- result to typecheck as a cons; presumably it unfolds 'Replicate' at a
  -- successor by one step.
  | Refl <- lemReplicateSucc @(Nothing @Nat) @n'
  = SUnknown () :!% ssxReplicate n

-- | Consecutive integers starting at the given value, one per dimension of
-- the static shape.
ssxIotaFrom :: Int -> StaticShX sh -> [Int]
ssxIotaFrom start ssh = take (ssxLength ssh) (iterate (+ 1) start)

-- | Forget the runtime sizes of the unknown dimensions, keeping only the
-- statically known part of the shape.
ssxFromShape :: IShX sh -> StaticShX sh
ssxFromShape ZSX = ZKX
ssxFromShape (SUnknown _ :$% rest) = SUnknown () :!% ssxFromShape rest
ssxFromShape (SKnown n :$% rest) = SKnown n :!% ssxFromShape rest

-- | The static shape with the given number of dimensions, all of them
-- unknown. This computes the same result as 'ssxReplicate'.
ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing)
ssxFromSNat sn = ssxReplicate sn


-- | Evidence for the static part of a shape. This pops up only when you are
-- polymorphic in the element type of an array.
type KnownShX :: [Maybe Nat] -> Constraint
class KnownShX sh where knownShX :: StaticShX sh
-- The empty shape.
instance KnownShX '[] where knownShX = ZKX
-- A known dimension: its singleton is obtained from the 'KnownNat' constraint.
instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX
-- An unknown dimension carries no static information.
instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX

-- | Run a computation that requires a 'KnownShX' constraint, using the given
-- 'StaticShX' as the evidence for it.
withKnownShX :: forall sh r. StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX sh = withDict @(KnownShX sh) sh


-- * Flattening

-- | The product of all dimensions of a shape if they are all known, and
-- @Nothing@ otherwise.
type Flatten sh = Flatten' 1 sh

-- | Worker for 'Flatten': @acc@ is the product of the dimensions seen so far.
-- The first unknown dimension makes the whole result @Nothing@.
type family Flatten' acc sh where
  Flatten' acc '[] = Just acc
  Flatten' acc (Nothing : sh) = Nothing
  Flatten' acc (Just n : sh) = Flatten' (acc * n) sh

-- | Value-level counterpart of 'Flatten' for static shapes: the product of
-- all dimensions if they are all known, and an unknown dimension otherwise.
--
-- This function is currently unused.
ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh)
ssxFlatten ssh0 = worker (SNat @1) ssh0
  where
    -- The first argument is the product of the dimensions seen so far.
    worker :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh)
    worker acc ssh = case ssh of
      ZKX -> SKnown acc
      SUnknown () :!% _ -> SUnknown ()
      SKnown sn :!% rest -> worker (snatMul acc sn) rest

-- | Value-level counterpart of 'Flatten' for shapes: the product of all
-- dimension sizes. The result is statically known only if every dimension
-- is; otherwise it is the runtime product.
shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh)
shxFlatten sh0 = typed (SNat @1) sh0
  where
    -- While all dimensions so far were known, track the product in the types.
    typed :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh)
    typed acc ZSX = SKnown acc
    typed acc (SKnown sn :$% sh) = typed (snatMul acc sn) sh
    typed acc (SUnknown n :$% sh) = SUnknown (untyped (fromSNat' acc * n) sh)

    -- After the first unknown dimension, only a runtime product remains.
    untyped :: Int -> IShX sh -> Int
    untyped acc sh = Foldable.foldl' (*) acc (shxToList sh)


-- | Very untyped: the only check is a runtime check that the list has the
-- length demanded by the type.
instance KnownShX sh => IsList (ListX sh (Const i)) where
  type Item (ListX sh (Const i)) = i
  fromList elems = listxFromList (knownShX :: StaticShX sh) elems
  toList list = listxToList list

-- | Very untyped: the only check is a runtime check that the list has the
-- length demanded by the type; index bounds are __not checked__.
instance KnownShX sh => IsList (IxX sh i) where
  type Item (IxX sh i) = i
  fromList comps = ixxFromList (knownShX :: StaticShX sh) comps
  toList (IxX comps) = listxToList comps

-- | Untyped: the list length and the sizes of the statically known
-- dimensions are checked at runtime.
instance KnownShX sh => IsList (ShX sh Int) where
  type Item (ShX sh Int) = Int
  fromList dims = shxFromList (knownShX :: StaticShX sh) dims
  toList sh = shxToList sh