aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Mixed/Shape.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Mixed/Shape.hs')
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs611
1 files changed, 611 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
new file mode 100644
index 0000000..5f4775c
--- /dev/null
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -0,0 +1,611 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE NoStarIsType #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE RoleAnnotations #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE StrictData #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Nested.Mixed.Shape where
+
+import Control.DeepSeq (NFData(..))
+import Data.Bifunctor (first)
+import Data.Coerce
+import Data.Foldable qualified as Foldable
+import Data.Functor.Const
+import Data.Functor.Product
+import Data.Kind (Constraint, Type)
+import Data.Monoid (Sum(..))
+import Data.Proxy
+import Data.Type.Equality
+import GHC.Exts (withDict)
+import GHC.Generics (Generic)
+import GHC.IsList (IsList)
+import GHC.IsList qualified as IsList
+import GHC.TypeLits
+
+import Data.Array.Mixed.Types
+
+
+-- | The length of a type-level list. If the argument is a shape, then the
+-- result is the rank of that shape.
+type family Rank sh where
+ Rank '[] = 0
+ Rank (_ : sh) = Rank sh + 1
+
+
+-- * Mixed lists
+
+type role ListX nominal representational
+type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type
+data ListX sh f where
+ ZX :: ListX '[] f
+ (::%) :: f n -> ListX sh f -> ListX (n : sh) f
+deriving instance (forall n. Eq (f n)) => Eq (ListX sh f)
+deriving instance (forall n. Ord (f n)) => Ord (ListX sh f)
+infixr 3 ::%
+
+instance (forall n. Show (f n)) => Show (ListX sh f) where
+ showsPrec _ = listxShow shows
+
+instance (forall n. NFData (f n)) => NFData (ListX sh f) where
+ rnf ZX = ()
+ rnf (x ::% l) = rnf x `seq` rnf l
+
+data UnconsListXRes f sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n)
+listxUncons :: ListX sh1 f -> Maybe (UnconsListXRes f sh1)
+listxUncons (i ::% shl') = Just (UnconsListXRes shl' i)
+listxUncons ZX = Nothing
+
+-- | This checks only whether the types are equal; if the elements of the list
+-- are not singletons, their values may still differ. This corresponds to
+-- 'testEquality', except on the penultimate type parameter.
+listxEqType :: TestEquality f => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh')
+listxEqType ZX ZX = Just Refl
+listxEqType (n ::% sh) (m ::% sh')
+ | Just Refl <- testEquality n m
+ , Just Refl <- listxEqType sh sh'
+ = Just Refl
+listxEqType _ _ = Nothing
+
+-- | This checks whether the two lists actually contain equal values. This is
+-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@
+-- in the @some@ package (except on the penultimate type parameter).
+listxEqual :: (TestEquality f, forall n. Eq (f n)) => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh')
+listxEqual ZX ZX = Just Refl
+listxEqual (n ::% sh) (m ::% sh')
+ | Just Refl <- testEquality n m
+ , n == m
+ , Just Refl <- listxEqual sh sh'
+ = Just Refl
+listxEqual _ _ = Nothing
+
+listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g
+listxFmap _ ZX = ZX
+listxFmap f (x ::% xs) = f x ::% listxFmap f xs
+
+listxFold :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m
+listxFold _ ZX = mempty
+listxFold f (x ::% xs) = f x <> listxFold f xs
+
+listxLength :: ListX sh f -> Int
+listxLength = getSum . listxFold (\_ -> Sum 1)
+
+listxRank :: ListX sh f -> SNat (Rank sh)
+listxRank ZX = SNat
+listxRank (_ ::% l) | SNat <- listxRank l = SNat
+
+listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS
+listxShow f l = showString "[" . go "" l . showString "]"
+ where
+ go :: String -> ListX sh' f -> ShowS
+ go _ ZX = id
+ go prefix (x ::% xs) = showString prefix . f x . go "," xs
+
+listxFromList :: StaticShX sh -> [i] -> ListX sh (Const i)
+listxFromList topssh topl = go topssh topl
+ where
+ go :: StaticShX sh' -> [i] -> ListX sh' (Const i)
+ go ZKX [] = ZX
+ go (_ :!% sh) (i : is) = Const i ::% go sh is
+ go _ _ = error $ "listxFromList: Mismatched list length (type says "
+ ++ show (ssxLength topssh) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+
+listxToList :: ListX sh' (Const i) -> [i]
+listxToList ZX = []
+listxToList (Const i ::% is) = i : listxToList is
+
+listxHead :: ListX (mn ': sh) f -> f mn
+listxHead (i ::% _) = i
+
+listxTail :: ListX (n : sh) i -> ListX sh i
+listxTail (_ ::% sh) = sh
+
+listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f
+listxAppend ZX idx' = idx'
+listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx'
+
+listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f
+listxDrop long ZX = long
+listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short
+
+listxInit :: forall f n sh. ListX (n : sh) f -> ListX (Init (n : sh)) f
+listxInit (i ::% sh@(_ ::% _)) = i ::% listxInit sh
+listxInit (_ ::% ZX) = ZX
+
+listxLast :: forall f n sh. ListX (n : sh) f -> f (Last (n : sh))
+listxLast (_ ::% sh@(_ ::% _)) = listxLast sh
+listxLast (x ::% ZX) = x
+
+listxZip :: ListX sh f -> ListX sh g -> ListX sh (Product f g)
+listxZip ZX ZX = ZX
+listxZip (i ::% irest) (j ::% jrest) =
+ Pair i j ::% listxZip irest jrest
+
+listxZipWith :: (forall a. f a -> g a -> h a) -> ListX sh f -> ListX sh g
+ -> ListX sh h
+listxZipWith _ ZX ZX = ZX
+listxZipWith f (i ::% is) (j ::% js) =
+ f i j ::% listxZipWith f is js
+
+
+-- * Mixed indices
+
+-- | This is a newtype over 'ListX'.
+type role IxX nominal representational
+type IxX :: [Maybe Nat] -> Type -> Type
+newtype IxX sh i = IxX (ListX sh (Const i))
+ deriving (Eq, Ord, Generic)
+
+pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i
+pattern ZIX = IxX ZX
+
+pattern (:.%)
+ :: forall {sh1} {i}.
+ forall n sh. (n : sh ~ sh1)
+ => i -> IxX sh i -> IxX sh1 i
+pattern i :.% shl <- IxX (listxUncons -> Just (UnconsListXRes (IxX -> shl) (getConst -> i)))
+ where i :.% IxX shl = IxX (Const i ::% shl)
+infixr 3 :.%
+
+{-# COMPLETE ZIX, (:.%) #-}
+
+type IIxX sh = IxX sh Int
+
+instance Show i => Show (IxX sh i) where
+ showsPrec _ (IxX l) = listxShow (\(Const i) -> shows i) l
+
+instance Functor (IxX sh) where
+ fmap f (IxX l) = IxX (listxFmap (Const . f . getConst) l)
+
+instance Foldable (IxX sh) where
+ foldMap f (IxX l) = listxFold (f . getConst) l
+
+instance NFData i => NFData (IxX sh i)
+
+ixxLength :: IxX sh i -> Int
+ixxLength (IxX l) = listxLength l
+
+ixxRank :: IxX sh i -> SNat (Rank sh)
+ixxRank (IxX l) = listxRank l
+
+ixxZero :: StaticShX sh -> IIxX sh
+ixxZero ZKX = ZIX
+ixxZero (_ :!% ssh) = 0 :.% ixxZero ssh
+
+ixxZero' :: IShX sh -> IIxX sh
+ixxZero' ZSX = ZIX
+ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh
+
+ixxFromList :: forall sh i. StaticShX sh -> [i] -> IxX sh i
+ixxFromList = coerce (listxFromList @_ @i)
+
+ixxHead :: IxX (n : sh) i -> i
+ixxHead (IxX list) = getConst (listxHead list)
+
+ixxTail :: IxX (n : sh) i -> IxX sh i
+ixxTail (IxX list) = IxX (listxTail list)
+
+ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i
+ixxAppend = coerce (listxAppend @_ @(Const i))
+
+ixxDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i
+ixxDrop = coerce (listxDrop @(Const i) @(Const i))
+
+ixxInit :: forall n sh i. IxX (n : sh) i -> IxX (Init (n : sh)) i
+ixxInit = coerce (listxInit @(Const i))
+
+ixxLast :: forall n sh i. IxX (n : sh) i -> i
+ixxLast = coerce (listxLast @(Const i))
+
+ixxZip :: IxX sh i -> IxX sh j -> IxX sh (i, j)
+ixxZip ZIX ZIX = ZIX
+ixxZip (i :.% is) (j :.% js) = (i, j) :.% ixxZip is js
+
+ixxZipWith :: (i -> j -> k) -> IxX sh i -> IxX sh j -> IxX sh k
+ixxZipWith _ ZIX ZIX = ZIX
+ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js
+
+ixxFromLinear :: IShX sh -> Int -> IIxX sh
+ixxFromLinear = \sh i -> case go sh i of
+ (idx, 0) -> idx
+ _ -> error $ "ixxFromLinear: out of range (" ++ show i ++
+ " in array of shape " ++ show sh ++ ")"
+ where
+ -- returns (index in subarray, remaining index in enclosing array)
+ go :: IShX sh -> Int -> (IIxX sh, Int)
+ go ZSX i = (ZIX, i)
+ go (n :$% sh) i =
+ let (idx, i') = go sh i
+ (upi, locali) = i' `quotRem` fromSMayNat' n
+ in (locali :.% idx, upi)
+
+ixxToLinear :: IShX sh -> IIxX sh -> Int
+ixxToLinear = \sh i -> fst (go sh i)
+ where
+ -- returns (index in subarray, size of subarray)
+ go :: IShX sh -> IIxX sh -> (Int, Int)
+ go ZSX ZIX = (0, 1)
+ go (n :$% sh) (i :.% ix) =
+ let (lidx, sz) = go sh ix
+ in (sz * i + lidx, fromSMayNat' n * sz)
+
+
+-- * Mixed shapes
+
+data SMayNat i f n where
+ SUnknown :: i -> SMayNat i f Nothing
+ SKnown :: f n -> SMayNat i f (Just n)
+deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n)
+deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n)
+deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n)
+
+instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where
+ rnf (SUnknown i) = rnf i
+ rnf (SKnown x) = rnf x
+
+instance TestEquality f => TestEquality (SMayNat i f) where
+ testEquality SUnknown{} SUnknown{} = Just Refl
+ testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl
+ testEquality _ _ = Nothing
+
+fromSMayNat :: (n ~ Nothing => i -> r)
+ -> (forall m. n ~ Just m => f m -> r)
+ -> SMayNat i f n -> r
+fromSMayNat f _ (SUnknown i) = f i
+fromSMayNat _ g (SKnown s) = g s
+
+fromSMayNat' :: SMayNat Int SNat n -> Int
+fromSMayNat' = fromSMayNat id fromSNat'
+
+type family AddMaybe n m where
+ AddMaybe Nothing _ = Nothing
+ AddMaybe (Just _) Nothing = Nothing
+ AddMaybe (Just n) (Just m) = Just (n + m)
+
+smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m)
+smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m)
+smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m)
+smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m)
+
+
+-- | This is a newtype over 'ListX'.
+type role ShX nominal representational
+type ShX :: [Maybe Nat] -> Type -> Type
+newtype ShX sh i = ShX (ListX sh (SMayNat i SNat))
+ deriving (Eq, Ord, Generic)
+
+pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i
+pattern ZSX = ShX ZX
+
+pattern (:$%)
+ :: forall {sh1} {i}.
+ forall n sh. (n : sh ~ sh1)
+ => SMayNat i SNat n -> ShX sh i -> ShX sh1 i
+pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i))
+ where i :$% ShX shl = ShX (i ::% shl)
+infixr 3 :$%
+
+{-# COMPLETE ZSX, (:$%) #-}
+
+type IShX sh = ShX sh Int
+
+instance Show i => Show (ShX sh i) where
+ showsPrec _ (ShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l
+
+instance Functor (ShX sh) where
+ fmap f (ShX l) = ShX (listxFmap (fromSMayNat (SUnknown . f) SKnown) l)
+
+instance NFData i => NFData (ShX sh i) where
+ rnf (ShX ZX) = ()
+ rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l)
+ rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l)
+
+-- | This checks only whether the types are equal; unknown dimensions might
+-- still differ. This corresponds to 'testEquality', except on the penultimate
+-- type parameter.
+shxEqType :: ShX sh i -> ShX sh' i -> Maybe (sh :~: sh')
+shxEqType ZSX ZSX = Just Refl
+shxEqType (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh')
+ | Just Refl <- sameNat n m
+ , Just Refl <- shxEqType sh sh'
+ = Just Refl
+shxEqType (SUnknown _ :$% sh) (SUnknown _ :$% sh')
+ | Just Refl <- shxEqType sh sh'
+ = Just Refl
+shxEqType _ _ = Nothing
+
+-- | This checks whether all dimensions have the same value. This is more than
+-- 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ in the
+-- @some@ package (except on the penultimate type parameter).
+shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh')
+shxEqual ZSX ZSX = Just Refl
+shxEqual (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh')
+ | Just Refl <- sameNat n m
+ , Just Refl <- shxEqual sh sh'
+ = Just Refl
+shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh')
+ | i == j
+ , Just Refl <- shxEqual sh sh'
+ = Just Refl
+shxEqual _ _ = Nothing
+
+shxLength :: ShX sh i -> Int
+shxLength (ShX l) = listxLength l
+
+shxRank :: ShX sh i -> SNat (Rank sh)
+shxRank (ShX l) = listxRank l
+
+-- | The number of elements in an array described by this shape.
+shxSize :: IShX sh -> Int
+shxSize ZSX = 1
+shxSize (n :$% sh) = fromSMayNat' n * shxSize sh
+
+shxFromList :: StaticShX sh -> [Int] -> ShX sh Int
+shxFromList topssh topl = go topssh topl
+ where
+ go :: StaticShX sh' -> [Int] -> ShX sh' Int
+ go ZKX [] = ZSX
+ go (SKnown sn :!% sh) (i : is)
+ | i == fromSNat' sn = SKnown sn :$% go sh is
+ | otherwise = error $ "shxFromList: Value does not match typing (type says "
+ ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")"
+ go (SUnknown () :!% sh) (i : is) = SUnknown i :$% go sh is
+ go _ _ = error $ "shxFromList: Mismatched list length (type says "
+ ++ show (ssxLength topssh) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+
+shxToList :: IShX sh -> [Int]
+shxToList ZSX = []
+shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh
+
+shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i
+shxAppend = coerce (listxAppend @_ @(SMayNat i SNat))
+
+shxHead :: ShX (n : sh) i -> SMayNat i SNat n
+shxHead (ShX list) = listxHead list
+
+shxTail :: ShX (n : sh) i -> ShX sh i
+shxTail (ShX list) = ShX (listxTail list)
+
+shxDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i
+shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat))
+
+shxDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i
+shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j))
+
+shxDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i
+shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat))
+
+shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i
+shxInit = coerce (listxInit @(SMayNat i SNat))
+
+shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh))
+shxLast = coerce (listxLast @(SMayNat i SNat))
+
+shxTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i
+shxTakeSSX _ = flip go
+ where
+ go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i
+ go ZKX _ = ZSX
+ go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh
+
+shxZipWith :: (forall n. SMayNat i SNat n -> SMayNat j SNat n -> SMayNat k SNat n)
+ -> ShX sh i -> ShX sh j -> ShX sh k
+shxZipWith _ ZSX ZSX = ZSX
+shxZipWith f (i :$% is) (j :$% js) = f i j :$% shxZipWith f is js
+
+-- This is a weird operation, so it has a long name
+shxCompleteZeros :: StaticShX sh -> IShX sh
+shxCompleteZeros ZKX = ZSX
+shxCompleteZeros (SUnknown () :!% ssh) = SUnknown 0 :$% shxCompleteZeros ssh
+shxCompleteZeros (SKnown n :!% ssh) = SKnown n :$% shxCompleteZeros ssh
+
+shxSplitApp :: Proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i)
+shxSplitApp _ ZKX idx = (ZSX, idx)
+shxSplitApp p (_ :!% ssh) (i :$% idx) = first (i :$%) (shxSplitApp p ssh idx)
+
+shxEnum :: IShX sh -> [IIxX sh]
+shxEnum = \sh -> go sh id []
+ where
+ go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a]
+ go ZSX f = (f ZIX :)
+ go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]]
+
+shxCast :: IShX sh -> StaticShX sh' -> Maybe (IShX sh')
+shxCast ZSX ZKX = Just ZSX
+shxCast (SKnown n :$% sh) (SKnown m :!% ssh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast sh ssh
+shxCast (SUnknown n :$% sh) (SKnown m :!% ssh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast sh ssh
+shxCast (SKnown n :$% sh) (SUnknown () :!% ssh) = (SUnknown (fromSNat' n) :$%) <$> shxCast sh ssh
+shxCast (SUnknown n :$% sh) (SUnknown () :!% ssh) = (SUnknown n :$%) <$> shxCast sh ssh
+shxCast _ _ = Nothing
+
+-- | Partial version of 'shxCast'.
+shxCast' :: IShX sh -> StaticShX sh' -> IShX sh'
+shxCast' sh ssh = case shxCast sh ssh of
+ Just sh' -> sh'
+ Nothing -> error $ "shxCast': Mismatch: (" ++ show sh ++ ") does not match (" ++ show ssh ++ ")"
+
+
+-- * Static mixed shapes
+
+-- | The part of a shape that is statically known. (A newtype over 'ListX'.)
+type StaticShX :: [Maybe Nat] -> Type
+newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat))
+ deriving (Eq, Ord)
+
+pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh
+pattern ZKX = StaticShX ZX
+
+pattern (:!%)
+ :: forall {sh1}.
+ forall n sh. (n : sh ~ sh1)
+ => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1
+pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i))
+ where i :!% StaticShX shl = StaticShX (i ::% shl)
+infixr 3 :!%
+
+{-# COMPLETE ZKX, (:!%) #-}
+
+instance Show (StaticShX sh) where
+ showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l
+
+instance NFData (StaticShX sh) where
+ rnf (StaticShX ZX) = ()
+ rnf (StaticShX (SUnknown () ::% l)) = rnf (StaticShX l)
+ rnf (StaticShX (SKnown SNat ::% l)) = rnf (StaticShX l)
+
+instance TestEquality StaticShX where
+ testEquality (StaticShX l1) (StaticShX l2) = listxEqType l1 l2
+
+ssxLength :: StaticShX sh -> Int
+ssxLength (StaticShX l) = listxLength l
+
+ssxRank :: StaticShX sh -> SNat (Rank sh)
+ssxRank (StaticShX l) = listxRank l
+
+-- | @ssxEqType = 'testEquality'@. Provided for consistency.
+ssxEqType :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh')
+ssxEqType = testEquality
+
+ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
+ssxAppend ZKX sh' = sh'
+ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh'
+
+ssxHead :: StaticShX (n : sh) -> SMayNat () SNat n
+ssxHead (StaticShX list) = listxHead list
+
+ssxTail :: StaticShX (n : sh) -> StaticShX sh
+ssxTail (_ :!% ssh) = ssh
+
+ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh'
+ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i))
+
+ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh))
+ssxInit = coerce (listxInit @(SMayNat () SNat))
+
+ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () SNat (Last (n : sh))
+ssxLast = coerce (listxLast @(SMayNat () SNat))
+
+-- | This may fail if @sh@ has @Nothing@s in it.
+ssxToShX' :: StaticShX sh -> Maybe (IShX sh)
+ssxToShX' ZKX = Just ZSX
+ssxToShX' (SKnown n :!% sh) = (SKnown n :$%) <$> ssxToShX' sh
+ssxToShX' (SUnknown _ :!% _) = Nothing
+
+ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing)
+ssxReplicate SZ = ZKX
+ssxReplicate (SS (n :: SNat n'))
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n'
+ = SUnknown () :!% ssxReplicate n
+
+ssxIotaFrom :: Int -> StaticShX sh -> [Int]
+ssxIotaFrom _ ZKX = []
+ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh
+
+ssxFromShape :: IShX sh -> StaticShX sh
+ssxFromShape ZSX = ZKX
+ssxFromShape (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShape sh
+
+ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing)
+ssxFromSNat SZ = ZKX
+ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n
+
+
+-- | Evidence for the static part of a shape. This pops up only when you are
+-- polymorphic in the element type of an array.
+type KnownShX :: [Maybe Nat] -> Constraint
+class KnownShX sh where knownShX :: StaticShX sh
+instance KnownShX '[] where knownShX = ZKX
+instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX
+instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX
+
+withKnownShX :: forall sh r. StaticShX sh -> (KnownShX sh => r) -> r
+withKnownShX k = withDict @(KnownShX sh) k
+
+
+-- * Flattening
+
+type Flatten sh = Flatten' 1 sh
+
+type family Flatten' acc sh where
+ Flatten' acc '[] = Just acc
+ Flatten' acc (Nothing : sh) = Nothing
+ Flatten' acc (Just n : sh) = Flatten' (acc * n) sh
+
+-- This function is currently unused
+ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh)
+ssxFlatten = go (SNat @1)
+ where
+ go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh)
+ go acc ZKX = SKnown acc
+ go _ (SUnknown () :!% _) = SUnknown ()
+ go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh
+
+shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh)
+shxFlatten = go (SNat @1)
+ where
+ go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh)
+ go acc ZSX = SKnown acc
+ go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh)
+ go acc (SKnown sn :$% sh) = go (snatMul acc sn) sh
+
+ goUnknown :: Int -> IShX sh -> Int
+ goUnknown acc ZSX = acc
+ goUnknown acc (SUnknown n :$% sh) = goUnknown (acc * n) sh
+ goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh
+
+
+-- | Very untyped: only length is checked (at runtime).
+instance KnownShX sh => IsList (ListX sh (Const i)) where
+ type Item (ListX sh (Const i)) = i
+ fromList = listxFromList (knownShX @sh)
+ toList = listxToList
+
+-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
+instance KnownShX sh => IsList (IxX sh i) where
+ type Item (IxX sh i) = i
+ fromList = IxX . IsList.fromList
+ toList = Foldable.toList
+
+-- | Untyped: length and known dimensions are checked (at runtime).
+instance KnownShX sh => IsList (ShX sh Int) where
+ type Item (ShX sh Int) = Int
+ fromList = shxFromList (knownShX @sh)
+ toList = shxToList