{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Nested.Internal.Shape where

import Data.Array.Shape qualified as O
import Data.Array.Mixed.Types
import Data.Coerce (coerce)
import Data.Foldable qualified as Foldable
import Data.Functor.Const
import Data.Kind (Type, Constraint)
import Data.Monoid (Sum(..))
import Data.Proxy
import Data.Type.Equality
import GHC.Exts (withDict)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
import GHC.TypeNats qualified as TN

import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Permutation
import Data.Array.Mixed.Shape


-- | A list whose length is tracked at the type level by the 'Nat' index @n@.
-- 'ZR' is the empty list and ':::' prepends an element.
--
-- The length index has a nominal role, so 'coerce' may change the element
-- type (when the elements are coercible) but never the length. In ':::' the
-- element type is an inferred variable (@{i}@), so a visible type application
-- @(:::) \@n@ only fixes the length of the tail.
type role ListR nominal representational
type ListR :: Nat -> Type -> Type
data ListR n i where
  ZR :: ListR 0 i
  (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i
deriving instance Eq i => Eq (ListR n i)
deriving instance Ord i => Ord (ListR n i)
deriving instance Functor (ListR n)
deriving instance Foldable (ListR n)
infixr 3 :::

-- | Renders the list as @[x1,x2,...]@ via 'listrShow'; the precedence
-- argument is not used.
instance Show i => Show (ListR n i) where
  showsPrec _prec xs = listrShow shows xs

-- | Result of 'listrUncons': the tail and the head of a non-empty 'ListR',
-- packaged with evidence that the original length @n1@ is the tail length
-- plus one.
data UnconsListRRes i n1 =
  forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i

-- | Split a 'ListR' into its tail and head; 'Nothing' for the empty list.
listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1)
listrUncons list = case list of
  ZR -> Nothing
  x ::: rest -> Just (UnconsListRRes rest x)

-- | Compare only the lengths (ranks) of two lists, not their contents. On
-- success the two type-level lengths are proven equal.
listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n')
listrEqRank ZR ZR = Just Refl
listrEqRank (_ ::: xs) (_ ::: ys) = case listrEqRank xs ys of
  Just Refl -> Just Refl
  Nothing -> Nothing
listrEqRank _ _ = Nothing

-- | Compare two lists for value equality. On success the two type-level
-- lengths are proven equal as well.
--
-- The tails are compared before the heads.
listrEqual :: Eq i => ListR n i -> ListR n' i -> Maybe (n :~: n')
listrEqual ZR ZR = Just Refl
listrEqual (x ::: xs) (y ::: ys) = case listrEqual xs ys of
  Just Refl | x == y -> Just Refl
  _ -> Nothing
listrEqual _ _ = Nothing

-- | Render a 'ListR' as @[x1,x2,...]@, showing each element with the given
-- function.
listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS
listrShow showElem list = showString "[" . body list . showString "]"
  where
    -- The first element is emitted without a separator.
    body :: ListR n' i -> ShowS
    body ZR = id
    body (x ::: rest) = showElem x . commaPrefixed rest

    -- Every later element is preceded by a comma.
    commaPrefixed :: ListR n' i -> ShowS
    commaPrefixed ZR = id
    commaPrefixed (x ::: rest) = showString "," . showElem x . commaPrefixed rest

-- | Concatenate two lists; the type-level lengths add up.
listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i
listrAppend front back = case front of
  ZR -> back
  x ::: rest -> x ::: listrAppend rest back

-- | Convert an ordinary list into a 'ListR' of statically unknown length,
-- which is handed to the continuation.
listrFromList :: [i] -> (forall n. ListR n i -> r) -> r
listrFromList xs k = case xs of
  [] -> k ZR
  y : ys -> listrFromList ys (\rest -> k (y ::: rest))

-- | First element of a non-empty list.
listrHead :: ListR (n + 1) i -> i
listrHead list = case list of
  x ::: _ -> x
  ZR -> error "unreachable"

-- | All elements except the first.
listrTail :: ListR (n + 1) i -> ListR n i
listrTail list = case list of
  _ ::: rest -> rest
  ZR -> error "unreachable"

-- | All elements except the last.
listrInit :: ListR (n + 1) i -> ListR n i
listrInit list = case list of
  x ::: rest@(_ ::: _) -> x ::: listrInit rest
  _ ::: ZR -> ZR
  ZR -> error "unreachable"

-- | Last element of a non-empty list.
listrLast :: ListR (n + 1) i -> i
listrLast list = case list of
  _ ::: rest@(_ ::: _) -> listrLast rest
  x ::: ZR -> x
  ZR -> error "unreachable"

-- | Index into a 'ListR' with a type-level index @k@. The constraint
-- @k + 1 <= n@ guarantees statically that the index is in bounds, so the
-- 'ZR' equation is only reached if that constraint was violated.
listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
listrIndex SZ (x ::: _) = x
-- Recursing into the tail needs the bound for the predecessor index and
-- length; 'lemLeqSuccSucc' presumably provides exactly that fact.
listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs
listrIndex _ ZR = error "k + 1 <= 0"

-- | The length of a 'ListR' as a type-level singleton.
listrRank :: ListR n i -> SNat n
listrRank list = case list of
  ZR -> SNat
  _ ::: rest -> snatSucc (listrRank rest)

-- | Apply a permutation to a prefix of the list.
--
-- The permutation is a list of @k@ indices, each of which must lie in
-- @[0, k)@. The first @k@ elements of the list are rearranged so that output
-- position @j@ holds the input element at position @perm !! j@. The remaining
-- elements are left untouched.
--
-- Calls 'error' when:
--
-- * the permutation is longer than the list, or
-- * an index is negative, or
-- * an index is not smaller than the permutation length.
--
-- The function does not check that the indices are distinct.
listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
listrPermutePrefix = \perm sh ->
  listrFromList perm $ \sperm ->
    case (listrRank sperm, listrRank sh) of
      (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of
        -- LTI and EQI carry different type-level evidence, so these
        -- identical bodies cannot share one alternative.
        LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
        EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
        GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
                       ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
  where
    -- Split off the first m elements; the constraint m <= n' rules out
    -- running off the end of the list.
    listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
    listrSplitAt SZ sh = (ZR, sh)
    listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
    listrSplitAt SS{} ZR = error "m' + 1 <= 0"

    -- Output position j receives element (perm !! j) of the length-m prefix.
    applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i
    applyPermRFull _ ZR _ = ZR
    applyPermRFull sm@SNat (i ::: perm) l
      -- Reject negative indices up front. Converting them with 'fromIntegral'
      -- to the 'Natural' expected by 'TN.withSomeSNat' would otherwise throw
      -- an uninformative arithmetic underflow exception.
      | i < 0 = error "listrPermutePrefix: Index in permutation out of range"
      | otherwise =
          TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
            case cmpNat (SNat @(idx + 1)) sm of
              LTI -> listrIndex si l ::: applyPermRFull sm perm l
              EQI -> listrIndex si l ::: applyPermRFull sm perm l
              GTI -> error "listrPermutePrefix: Index in permutation out of range"


-- | An index into a rank-typed array.
--
-- This is a newtype around 'ListR', so the rank @n@ is tracked at the type
-- level. Because the role of @n@ is nominal, 'coerce' can never change the
-- rank of an index.
type role IxR nominal representational
type IxR :: Nat -> Type -> Type
newtype IxR n i = IxR (ListR n i)
  deriving (Eq, Ord)
  deriving newtype (Functor, Foldable)

-- | The empty, rank-0 index. Matching on 'ZIR' provides @n ~ 0@.
pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
pattern ZIR = IxR ZR

-- | Prepend a component to an index. Matching on @i ':.:' rest@ provides
-- the rank @n@ of @rest@ together with @n + 1 ~ n1@. Matching goes through
-- 'listrUncons'; construction wraps ':::' directly.
pattern (:.:)
  :: forall {n1} {i}.
     forall n. (n + 1 ~ n1)
  => i -> IxR n i -> IxR n1 i
pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i))
  where i :.: IxR sh = IxR (i ::: sh)
infixr 3 :.:

-- 'ZIR' and ':.:' together cover every 'IxR' value.
{-# COMPLETE ZIR, (:.:) #-}

-- | A rank-typed index with 'Int' components.
type IIxR n = IxR n Int

-- | Renders the index as @[i1,i2,...]@; the precedence argument is unused.
instance Show i => Show (IxR n i) where
  showsPrec _ ix = case ix of
    IxR list -> listrShow shows list

-- | The index of the given rank whose components are all zero.
ixrZero :: SNat n -> IIxR n
ixrZero sn = case sn of
  SZ -> ZIR
  SS sn' -> 0 :.: ixrZero sn'

-- | Convert a mixed index to a rank-typed index. Only the number of
-- components is kept at the type level.
ixCvtXR :: IIxX sh -> IIxR (Rank sh)
ixCvtXR idx = case idx of
  ZIX -> ZIR
  i :.% rest -> i :.: ixCvtXR rest

-- | Convert a rank-typed index to a mixed index in which every dimension is
-- statically unknown ('Nothing').
ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
ixCvtRX ZIR = ZIX
ixCvtRX (i :.: (rest :: IxR m Int)) =
  -- Rewrite the result type with 'lemReplicateSucc' for the tail rank m.
  castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m))
    (i :.% ixCvtRX rest)

-- | First component of a non-empty index.
ixrHead :: IxR (n + 1) i -> i
ixrHead ix = case ix of IxR list -> listrHead list

-- | All components except the first.
ixrTail :: IxR (n + 1) i -> IxR n i
ixrTail ix = case ix of IxR list -> IxR (listrTail list)

-- | All components except the last.
ixrInit :: IxR (n + 1) i -> IxR n i
ixrInit ix = case ix of IxR list -> IxR (listrInit list)

-- | Last component of a non-empty index.
ixrLast :: IxR (n + 1) i -> i
ixrLast ix = case ix of IxR list -> listrLast list

-- | Concatenate two indices; the ranks add up.
ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i
ixrAppend (IxR front) (IxR back) = IxR (listrAppend front back)

-- | The rank of an index as a type-level singleton.
ixrRank :: IxR n i -> SNat n
ixrRank = \(IxR list) -> listrRank list

-- | 'listrPermutePrefix' lifted to indices.
ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
ixrPermutePrefix perm (IxR list) = IxR (listrPermutePrefix perm list)


-- | The shape of a rank-typed array: one size per dimension, with the rank
-- @n@ tracked at the type level.
--
-- This is a newtype around 'ListR'. Because the role of @n@ is nominal,
-- 'coerce' cannot change the rank.
type role ShR nominal representational
type ShR :: Nat -> Type -> Type
newtype ShR n i = ShR (ListR n i)
  deriving (Eq, Ord)
  deriving newtype (Functor, Foldable)

-- | The empty, rank-0 shape. Matching on 'ZSR' provides @n ~ 0@.
pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
pattern ZSR = ShR ZR

-- | Prepend a dimension to a shape. Matching on @i ':$:' rest@ provides the
-- rank @n@ of @rest@ together with @n + 1 ~ n1@. Matching goes through
-- 'listrUncons'; construction wraps ':::' directly.
pattern (:$:)
  :: forall {n1} {i}.
     forall n. (n + 1 ~ n1)
  => i -> ShR n i -> ShR n1 i
pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i))
  where i :$: ShR sh = ShR (i ::: sh)
infixr 3 :$:

-- 'ZSR' and ':$:' together cover every 'ShR' value.
{-# COMPLETE ZSR, (:$:) #-}

-- | A rank-typed shape with 'Int' sizes.
type IShR n = ShR n Int

-- | Renders the shape as @[d1,d2,...]@; the precedence argument is unused.
instance Show i => Show (ShR n i) where
  showsPrec _ sh = case sh of
    ShR list -> listrShow shows list

-- | Convert a mixed shape whose dimensions are all statically unknown
-- (@'Replicate' n 'Nothing'@) into a rank-typed shape of rank @n@. The runtime
-- sizes are kept.
--
-- GHC cannot derive the required equations about 'Replicate' and 'Rank' on
-- its own, so they are asserted with 'unsafeCoerceRefl' (presumably an
-- unchecked 'Refl'). Their soundness rests on the argument type being exactly
-- @'Replicate' n 'Nothing'@.
shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n
-- An empty mixed shape means Replicate n Nothing is empty, so n ~ 0.
shCvtXR' ZSX =
  castWith (subst2 (unsafeCoerceRefl :: 0 :~: n))
    ZSR
-- For a non-empty shape: convert the tail recursively (after recasting its
-- type with lem2), then cast the rebuilt shape back to rank n with lem1.
shCvtXR' (n :$% (idx :: IShX sh))
  | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) =
  castWith (subst2 (lem1 @sh Refl))
    (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx))
  where
    -- If k : sh' equals Replicate n' Nothing, the tail sh' has rank n' - 1.
    lem1 :: forall sh' n' k.
            k : sh' :~: Replicate n' Nothing
         -> Rank sh' + 1 :~: n'
    lem1 Refl = unsafeCoerceRefl

    -- The tail of an all-Nothing list is itself all-Nothing, with length
    -- equal to its rank.
    lem2 :: k : sh :~: Replicate n Nothing
         -> sh :~: Replicate (Rank sh) Nothing
    lem2 Refl = unsafeCoerceRefl

-- | Convert a rank-typed shape to a mixed shape in which every dimension is
-- marked statically unknown ('SUnknown').
shCvtRX :: IShR n -> IShX (Replicate n Nothing)
shCvtRX ZSR = ZSX
shCvtRX (size :$: (rest :: ShR m Int)) =
  -- Rewrite the result type with 'lemReplicateSucc' for the tail rank m.
  castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m))
    (SUnknown size :$% shCvtRX rest)

-- | Compare only the ranks of two shapes, not their sizes. On success the
-- type-level ranks are proven equal.
shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n')
shrEqRank (ShR a) (ShR b) = a `listrEqRank` b

-- | Compare two shapes for value equality. On success the type-level ranks
-- are proven equal as well.
shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n')
shrEqual (ShR a) (ShR b) = a `listrEqual` b

-- | The number of elements in an array described by this shape: the product
-- of all dimension sizes (1 for the rank-0 shape).
shrSize :: IShR n -> Int
shrSize (ShR list) = foldr (*) 1 list

-- | Size of the first dimension.
shrHead :: ShR (n + 1) i -> i
shrHead sh = case sh of ShR list -> listrHead list

-- | All dimensions except the first.
shrTail :: ShR (n + 1) i -> ShR n i
shrTail sh = case sh of ShR list -> ShR (listrTail list)

-- | All dimensions except the last.
shrInit :: ShR (n + 1) i -> ShR n i
shrInit sh = case sh of ShR list -> ShR (listrInit list)

-- | Size of the last dimension.
shrLast :: ShR (n + 1) i -> i
shrLast sh = case sh of ShR list -> listrLast list

-- | Concatenate two shapes; the ranks add up.
shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i
shrAppend (ShR front) (ShR back) = ShR (listrAppend front back)

-- | The rank of a shape as a type-level singleton.
--
-- This function can also be used to conjure up a 'KnownNat' dictionary:
-- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern
-- synonym yields 'KnownNat' evidence.
shrRank :: ShR n i -> SNat n
shrRank = \(ShR list) -> listrRank list

-- | 'listrPermutePrefix' lifted to shapes.
shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i
shrPermutePrefix perm (ShR list) = ShR (listrPermutePrefix perm list)


-- | Untyped: length is checked at runtime.
instance KnownNat n => IsList (ListR n i) where
  type Item (ListR n i) = i

  -- Walk the static length and the input list in lockstep. Any length
  -- mismatch is reported together with both lengths.
  fromList topl = build (SNat @n) topl
    where
      build :: SNat n' -> [i] -> ListR n' i
      build sn xs = case (sn, xs) of
        (SZ, []) -> ZR
        (SS sn', y : ys) -> y ::: build sn' ys
        _ -> error $ "IsList(ListR): Mismatched list length (type says "
                       ++ show (fromSNat (SNat @n)) ++ ", list has length "
                       ++ show (length topl) ++ ")"

  toList = Foldable.toList

-- | Untyped: length is checked at runtime (by the 'ListR' instance).
instance KnownNat n => IsList (IxR n i) where
  type Item (IxR n i) = i
  fromList xs = IxR (IsList.fromList xs)
  toList (IxR list) = Foldable.toList list

-- | Untyped: length is checked at runtime (by the 'ListR' instance).
instance KnownNat n => IsList (ShR n i) where
  type Item (ShR n i) = i
  fromList xs = ShR (IsList.fromList xs)
  toList (ShR list) = Foldable.toList list


-- | A list holding one element of type @f n@ for every 'Nat' @n@ in the
-- type-level list @sh@. 'ZS' is the empty list and '::$' prepends an element.
--
-- Every cons cell also stores the 'KnownNat' dictionary of its index. The
-- role of @sh@ is nominal, so 'coerce' cannot change the type-level list.
type role ListS nominal representational
type ListS :: [Nat] -> (Nat -> Type) -> Type
data ListS sh f where
  ZS :: ListS '[] f
  -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity
  (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f
deriving instance (forall n. Eq (f n)) => Eq (ListS sh f)
deriving instance (forall n. Ord (f n)) => Ord (ListS sh f)
infixr 3 ::$

-- | Renders the list as @[x1,x2,...]@ via 'listsShow'; the precedence
-- argument is not used.
instance (forall n. Show (f n)) => Show (ListS sh f) where
  showsPrec _prec list = listsShow shows list

-- | Result of 'listsUncons': the tail and head of a non-empty 'ListS',
-- packaged with the head's 'KnownNat' and evidence that @n : sh ~ sh1@.
data UnconsListSRes f sh1 =
  forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n)

-- | Split a 'ListS' into its tail and head; 'Nothing' for the empty list.
listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1)
listsUncons list = case list of
  ZS -> Nothing
  x ::$ rest -> Just (UnconsListSRes rest x)

-- | Check only whether the types are equal. If the elements of the list are
-- not singletons, their values may still differ. This corresponds to
-- 'testEquality', except that it acts on the penultimate type parameter.
listsEqType :: TestEquality f => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh')
listsEqType ZS ZS = Just Refl
listsEqType (x ::$ xs) (y ::$ ys) = case testEquality x y of
  Just Refl -> case listsEqType xs ys of
    Just Refl -> Just Refl
    Nothing -> Nothing
  Nothing -> Nothing
listsEqType _ _ = Nothing

-- | Check whether the two lists actually contain equal values. This is more
-- than 'testEquality'. It corresponds to @geq@ from @Data.GADT.Compare@ in the
-- @some@ package, except that it acts on the penultimate type parameter.
listsEqual :: (TestEquality f, forall n. Eq (f n)) => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh')
listsEqual ZS ZS = Just Refl
listsEqual (x ::$ xs) (y ::$ ys) = case testEquality x y of
  Just Refl
    | x == y
    , Just Refl <- listsEqual xs ys
    -> Just Refl
  _ -> Nothing
listsEqual _ _ = Nothing

-- | Apply an index-preserving function to every element.
listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g
listsFmap f list = case list of
  ZS -> ZS
  x ::$ xs -> f x ::$ listsFmap f xs

-- | Map every element into a monoid and combine the results left to right.
listsFold :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m
listsFold f list = case list of
  ZS -> mempty
  x ::$ xs -> f x <> listsFold f xs

-- | Render a 'ListS' as @[x1,x2,...]@, showing each element with the given
-- function.
listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS
listsShow showElem list = showString "[" . body list . showString "]"
  where
    -- The first element is emitted without a separator.
    body :: ListS sh' f -> ShowS
    body ZS = id
    body (x ::$ rest) = showElem x . commaPrefixed rest

    -- Every later element is preceded by a comma.
    commaPrefixed :: ListS sh' f -> ShowS
    commaPrefixed ZS = id
    commaPrefixed (x ::$ rest) = showString "," . showElem x . commaPrefixed rest

-- | The elements of a 'Const'-valued list, in order.
listsToList :: ListS sh (Const i) -> [i]
listsToList = listsFold (\(Const x) -> [x])

-- | First element of a non-empty list.
listsHead :: ListS (n : sh) f -> f n
listsHead list = case list of
  x ::$ _ -> x

-- | All elements except the first.
listsTail :: ListS (n : sh) f -> ListS sh f
listsTail list = case list of
  _ ::$ rest -> rest

-- | All elements except the last.
listsInit :: ListS (n : sh) f -> ListS (Init (n : sh)) f
listsInit list = case list of
  x ::$ rest@(_ ::$ _) -> x ::$ listsInit rest
  _ ::$ ZS -> ZS

-- | Last element of a non-empty list.
listsLast :: ListS (n : sh) f -> f (Last (n : sh))
listsLast list = case list of
  _ ::$ rest@(_ ::$ _) -> listsLast rest
  x ::$ ZS -> x

-- | Concatenate two lists; the type-level lists are appended.
listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f
listsAppend front back = case front of
  ZS -> back
  x ::$ rest -> x ::$ listsAppend rest back

-- | The length of a 'ListS' as a type-level singleton.
listsRank :: ListS sh f -> SNat (Rank sh)
listsRank list = case list of
  ZS -> SNat
  _ ::$ rest -> snatSucc (listsRank rest)

-- | Keep as many leading elements as the permutation has entries.
listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f
listsTakeLenPerm perm list = case perm of
  PNil -> ZS
  _ `PCons` rest -> case list of
    x ::$ xs -> x ::$ listsTakeLenPerm rest xs
    ZS -> error "Permutation longer than shape"

-- | Drop as many leading elements as the permutation has entries.
listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f
listsDropLenPerm perm list = case perm of
  PNil -> list
  _ `PCons` rest -> case list of
    _ ::$ xs -> listsDropLenPerm rest xs
    ZS -> error "Permutation longer than shape"

-- | Reorder a list according to a type-level permutation. Output position
-- @j@ holds the input element at the @j@-th index of the permutation, which
-- is looked up with 'listsIndex'.
listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f
listsPermute PNil _ = ZS
listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) =
  -- Matching the returned 'SNat' brings the element's 'KnownNat' into
  -- scope, which the '::$' constructor requires.
  case listsIndex (Proxy @is') (Proxy @sh) i sh of
    (item, SNat) -> item ::$ listsPermute is sh

-- | Look up the element at type-level position @i@. The element is returned
-- together with its 'SNat', which is rebuilt from the 'KnownNat' stored in
-- the cons cell. The two 'Proxy' arguments are never inspected; they are only
-- passed along in the recursive call.
--
-- TODO: remove this SNat when the KnownNat constraint in ListS is removed
listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh))
listsIndex _ _ SZ (n ::$ _) = (n, SNat)
-- Stepping into the tail relies on 'lemIndexSucc', presumably stating that
-- Index (i' + 1) (n : sh') equals Index i' sh'.
listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f))
  | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
  = listsIndex p pT i sh
listsIndex _ _ _ ZS = error "Index into empty shape"

-- | Permute the leading elements (as many as the permutation has entries)
-- and keep the remaining elements unchanged after them.
listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f
listsPermutePrefix perm sh =
  let permuted = listsPermute perm (listsTakeLenPerm perm sh)
      untouched = listsDropLenPerm perm sh
  in listsAppend permuted untouched


-- | An index into a shape-typed array.
--
-- For convenience, this contains regular 'Int's instead of bounded integers
-- (traditionally called \"@Fin@\"). Index bounds are therefore not enforced
-- by the type.
--
-- It is represented as a 'ListS' of 'Const' values, so the type only fixes
-- that there is one component per dimension in @sh@.
type role IxS nominal representational
type IxS :: [Nat] -> Type -> Type
newtype IxS sh i = IxS (ListS sh (Const i))
  deriving (Eq, Ord)

-- | The empty index. Matching on 'ZIS' provides @sh ~ '[]@.
pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i
pattern ZIS = IxS ZS

-- | Prepend a component to an index. Matching on @i ':.$' rest@ provides the
-- 'KnownNat' of the head dimension @n@ together with @n : sh ~ sh1@.
pattern (:.$)
  :: forall {sh1} {i}.
     forall n sh. (KnownNat n, n : sh ~ sh1)
  => i -> IxS sh i -> IxS sh1 i
pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i)))
  where i :.$ IxS shl = IxS (Const i ::$ shl)
infixr 3 :.$

-- 'ZIS' and ':.$' together cover every 'IxS' value.
{-# COMPLETE ZIS, (:.$) #-}

-- | A shape-typed index with 'Int' components.
type IIxS sh = IxS sh Int

-- | Renders the index as @[i1,i2,...]@; the precedence argument is unused.
instance Show i => Show (IxS sh i) where
  showsPrec _ (IxS l) = listsShow (shows . getConst) l

-- | Maps over the components; the shape is unchanged.
instance Functor (IxS sh) where
  fmap f (IxS l) = IxS (listsFmap (\(Const x) -> Const (f x)) l)

-- | Folds over the components from first to last.
instance Foldable (IxS sh) where
  foldMap f (IxS l) = listsFold (\(Const x) -> f x) l

-- | The index for the given shape whose components are all zero.
ixsZero :: ShS sh -> IIxS sh
ixsZero shape = case shape of
  ZSS -> ZIS
  _ :$$ rest -> 0 :.$ ixsZero rest

-- | Convert a mixed index over a fully static shape to a shape-typed index.
ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh
ixCvtXS ZSS ZIX = ZIS
ixCvtXS (_ :$$ restSh) (i :.% restIx) = i :.$ ixCvtXS restSh restIx

-- | Convert a shape-typed index to a mixed index over a fully static shape.
ixCvtSX :: IIxS sh -> IIxX (MapJust sh)
ixCvtSX idx = case idx of
  ZIS -> ZIX
  i :.$ rest -> i :.% ixCvtSX rest

-- | First component of a non-empty index.
ixsHead :: IxS (n : sh) i -> i
ixsHead ix = case ix of IxS list -> getConst (listsHead list)

-- | All components except the first.
ixsTail :: IxS (n : sh) i -> IxS sh i
ixsTail ix = case ix of IxS list -> IxS (listsTail list)

-- | All components except the last.
ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i
ixsInit ix = case ix of IxS list -> IxS (listsInit list)

-- | Last component of a non-empty index.
ixsLast :: IxS (n : sh) i -> i
ixsLast ix = case ix of IxS list -> getConst (listsLast list)

-- | Concatenate two indices; the shapes are appended.
ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i
ixsAppend (IxS front) (IxS back) = IxS (listsAppend front back)

-- | 'listsPermutePrefix' lifted to indices.
ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i
ixsPermutePrefix perm (IxS list) = IxS (listsPermutePrefix perm list)


-- | The shape of a shape-typed array given as a list of 'SNat' values.
--
-- Note that because the shape of a shape-typed array is known statically, you
-- can also retrieve the array shape from a 'KnownShS' dictionary.
--
-- Since every element is an 'SNat' singleton, the value of an 'ShS' is fully
-- determined by its type. The role of @sh@ is nominal.
type role ShS nominal
type ShS :: [Nat] -> Type
newtype ShS sh = ShS (ListS sh SNat)
  deriving (Eq, Ord)

-- | The empty shape. Matching on 'ZSS' provides @sh ~ '[]@.
pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh
pattern ZSS = ShS ZS

-- | Prepend a dimension to a shape. Matching on @n ':$$' rest@ provides the
-- 'KnownNat' of the head dimension together with @n : sh ~ sh1@.
pattern (:$$)
  :: forall {sh1}.
     forall n sh. (KnownNat n, n : sh ~ sh1)
  => SNat n -> ShS sh -> ShS sh1
pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i))
  where i :$$ ShS shl = ShS (i ::$ shl)

infixr 3 :$$

-- 'ZSS' and ':$$' together cover every 'ShS' value.
{-# COMPLETE ZSS, (:$$) #-}

-- | Renders the dimensions as a bracketed list of numbers.
instance Show (ShS sh) where
  showsPrec _ (ShS l) = listsShow (\sn -> shows (fromSNat sn)) l

-- | Two shapes have equal types exactly when their dimension lists agree.
instance TestEquality ShS where
  testEquality (ShS a) (ShS b) = a `listsEqType` b

-- | @'shsEqual' = 'testEquality'@. Because 'ShS' is a singleton, types are
-- equal if and only if values are equal.
shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh')
shsEqual a b = testEquality a b

-- | The number of dimensions, as an 'Int'.
shsLength :: ShS sh -> Int
shsLength (ShS l) = getSum (listsFold (const (Sum 1)) l)

-- | The number of dimensions, as a type-level singleton.
shsRank :: ShS sh -> SNat (Rank sh)
shsRank = \(ShS l) -> listsRank l

-- | The dimension sizes as plain 'Int's.
shsToList :: ShS sh -> [Int]
shsToList sh = case sh of
  ZSS -> []
  sn :$$ rest -> fromSNat' sn : shsToList rest

-- | Convert a mixed shape in which every dimension is statically known
-- (@'MapJust' sh@) into a shape-typed 'ShS'.
--
-- GHC cannot derive the required equations about 'MapJust' and 'Tail' on
-- its own, so they are asserted with 'unsafeCoerceRefl' (presumably an
-- unchecked 'Refl'). The 'SUnknown' case cannot occur for an argument of
-- type @'IShX' ('MapJust' sh)@, hence the @error \"impossible\"@.
shCvtXS' :: forall sh. IShX (MapJust sh) -> ShS sh
-- An empty mixed shape means MapJust sh is empty, hence sh ~ '[].
shCvtXS' ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
shCvtXS' (SKnown n@SNat :$% (idx :: IShX mjshT)) =
  castWith (subst1 (lem Refl)) $
    n :$$ shCvtXS' @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
                                 idx)
  where
    -- If Just n : sh1 equals MapJust sh', then sh' is n consed onto its tail.
    lem :: forall sh1 sh' n.
           Just n : sh1 :~: MapJust sh'
        -> n : Tail sh' :~: sh'
    lem Refl = unsafeCoerceRefl
shCvtXS' (SUnknown _ :$% _) = error "impossible"

-- | Convert a shape-typed shape to a mixed shape in which every dimension is
-- marked statically known ('SKnown').
shCvtSX :: ShS sh -> IShX (MapJust sh)
shCvtSX sh = case sh of
  ZSS -> ZSX
  sn :$$ rest -> SKnown sn :$% shCvtSX rest

-- | The first dimension.
shsHead :: ShS (n : sh) -> SNat n
shsHead sh = case sh of ShS list -> listsHead list

-- | All dimensions except the first.
shsTail :: ShS (n : sh) -> ShS sh
shsTail sh = case sh of ShS list -> ShS (listsTail list)

-- | All dimensions except the last.
shsInit :: ShS (n : sh) -> ShS (Init (n : sh))
shsInit sh = case sh of ShS list -> ShS (listsInit list)

-- | The last dimension.
shsLast :: ShS (n : sh) -> SNat (Last (n : sh))
shsLast sh = case sh of ShS list -> listsLast list

-- | Concatenate two shapes.
shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh')
shsAppend (ShS front) (ShS back) = ShS (listsAppend front back)

-- | The number of elements in an array of this shape: the product of the
-- dimensions (1 for the empty shape).
shsSize :: ShS sh -> Int
shsSize sh = case sh of
  ZSS -> 1
  sn :$$ rest -> fromSNat' sn * shsSize rest

-- | 'listsTakeLenPerm' lifted to shapes.
shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh)
shsTakeLen perm (ShS list) = ShS (listsTakeLenPerm perm list)

-- | 'listsPermute' lifted to shapes.
shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh)
shsPermute perm (ShS list) = ShS (listsPermute perm list)

-- | The dimension at type-level position @i@ (see 'listsIndex').
shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh)
shsIndex pis pshT i (ShS list) = fst (listsIndex pis pshT i list)

-- | 'listsPermutePrefix' lifted to shapes.
shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh)
shsPermutePrefix perm (ShS list) = ShS (listsPermutePrefix perm list)

-- | Type-level product of the dimensions of a shape; 1 for the empty shape.
type family Product sh where
  Product '[] = 1
  Product (n : ns) = n * Product ns

-- | The number of elements of a shape, as a type-level singleton.
shsProduct :: ShS sh -> SNat (Product sh)
shsProduct sh = case sh of
  ZSS -> SNat
  sn :$$ rest -> snatMul sn (shsProduct rest)

-- | Evidence for the static part of a shape. This pops up only when you are
-- polymorphic in the element type of an array.
--
-- The instances build the 'ShS' singleton from the 'KnownNat' of every
-- dimension. The class has a single method, which 'withKnownShS' relies on
-- (it uses 'withDict').
type KnownShS :: [Nat] -> Constraint
class KnownShS sh where knownShS :: ShS sh
instance KnownShS '[] where knownShS = ZSS
instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS

-- | Make the given shape available as a 'KnownShS' constraint inside the
-- continuation.
withKnownShS :: forall sh r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS sh = withDict @(KnownShS sh) sh

-- | Build a 'KnownShS' dictionary by walking the shape. Every dimension
-- contributes its 'KnownNat'.
shsKnownShS :: ShS sh -> Dict KnownShS sh
shsKnownShS sh = case sh of
  ZSS -> Dict
  SNat :$$ rest -> case shsKnownShS rest of
    Dict -> Dict

-- | Build an orthotope 'O.Shape' dictionary by walking the shape.
shsOrthotopeShape :: ShS sh -> Dict O.Shape sh
shsOrthotopeShape sh = case sh of
  ZSS -> Dict
  SNat :$$ rest -> case shsOrthotopeShape rest of
    Dict -> Dict


-- | Untyped: length is checked at runtime.
instance KnownShS sh => IsList (ListS sh (Const i)) where
  type Item (ListS sh (Const i)) = i

  -- Walk the statically known shape and the input list in lockstep. Any
  -- length mismatch is reported together with both lengths.
  fromList topl = build (knownShS @sh) topl
    where
      build :: ShS sh' -> [i] -> ListS sh' (Const i)
      build shape xs = case (shape, xs) of
        (ZSS, []) -> ZS
        (_ :$$ rest, y : ys) -> Const y ::$ build rest ys
        _ -> error $ "IsList(ListS): Mismatched list length (type says "
                       ++ show (shsLength (knownShS @sh)) ++ ", list has length "
                       ++ show (length topl) ++ ")"

  toList = listsToList

-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
instance KnownShS sh => IsList (IxS sh i) where
  type Item (IxS sh i) = i
  fromList xs = IxS (IsList.fromList xs)
  toList ix = Foldable.toList ix

-- | Untyped: length and values are checked at runtime.
instance KnownShS sh => IsList (ShS sh) where
  type Item (ShS sh) = Int

  -- The expected shape comes from the 'KnownShS' dictionary. Every list
  -- element must equal the corresponding static dimension.
  fromList topl = ShS (build (knownShS @sh) topl)
    where
      build :: ShS sh' -> [Int] -> ListS sh' SNat
      build shape xs = case (shape, xs) of
        (ZSS, []) -> ZS
        (sn :$$ rest, y : ys)
          | y /= fromSNat' sn
          -> error $ "IsList(ShS): Value does not match typing (type says "
                       ++ show (fromSNat' sn) ++ ", list contains " ++ show y ++ ")"
          | otherwise -> sn ::$ build rest ys
        _ -> error $ "IsList(ShS): Mismatched list length (type says "
                       ++ show (shsLength (knownShS @sh)) ++ ", list has length "
                       ++ show (length topl) ++ ")"

  toList = shsToList