aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Shaped
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Shaped')
-rw-r--r--src/Data/Array/Nested/Shaped/Base.hs255
-rw-r--r--src/Data/Array/Nested/Shaped/Shape.hs425
2 files changed, 680 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs
new file mode 100644
index 0000000..879e6b5
--- /dev/null
+++ b/src/Data/Array/Nested/Shaped/Base.hs
@@ -0,0 +1,255 @@
+{-# LANGUAGE CPP #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# OPTIONS_HADDOCK not-home #-}
+module Data.Array.Nested.Shaped.Base where
+
+import Prelude hiding (mappend, mconcat)
+
+import Control.DeepSeq (NFData(..))
+import Control.Monad.ST
+import Data.Bifunctor (first)
+import Data.Coerce (coerce)
+import Data.Kind (Type)
+import Data.List.NonEmpty (NonEmpty)
+import Data.Proxy
+import Data.Type.Equality
+import Foreign.Storable (Storable)
+import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
+import GHC.Generics (Generic)
+import GHC.TypeLits
+
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Shaped.Shape
+import Data.Array.Nested.Types
+import Data.Array.Strided.Arith
+import Data.Array.XArray (XArray)
+
+
+-- | A shape-typed array: the full shape of the array (the sizes of its
+-- dimensions) is represented on the type level as a list of 'Nat's. Note that
+-- these are "GHC.TypeLits" naturals, because we do not need induction over
+-- them and we want very large arrays to be possible.
+--
+-- Like for 'Ranked', the valid elements are described by the 'Elt' type class,
+-- and 'Shaped' itself is again an instance of 'Elt' as well.
+--
+-- 'Shaped' is a newtype around a 'Mixed' of 'Just's.
+type Shaped :: [Nat] -> Type -> Type
+newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a)
+#endif
+deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a)
+deriving instance Ord (Mixed (MapJust sh) a) => Ord (Shaped sh a)
+
+#ifndef OXAR_DEFAULT_SHOW_INSTANCES
+instance (Show a, Elt a) => Show (Shaped n a) where
+ showsPrec d arr@(Shaped marr) =
+ let sh = show (shsToList (sshape arr))
+ in showsMixedArray ("sfromListLinear " ++ sh) ("sreplicate " ++ sh) d marr
+#endif
+
+instance Elt a => NFData (Shaped sh a) where
+ rnf (Shaped arr) = rnf arr
+
+-- just unwrap the newtype and defer to the general instance for nested arrays
+newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a))
+ deriving (Generic)
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show (Mixed sh (Mixed (MapJust sh') a)) => Show (Mixed sh (Shaped sh' a))
+#endif
+
+deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped sh' a))
+
+newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a))
+
+instance Elt a => Elt (Shaped sh a) where
+ mshape (M_Shaped arr) = mshape arr
+ mindex (M_Shaped arr) i = Shaped (mindex arr i)
+
+ mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
+ mindexPartial (M_Shaped arr) i =
+ coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
+ mindexPartial arr i
+
+ mscalar (Shaped x) = M_Shaped (M_Nest ZSX x)
+
+ mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a)
+ mfromListOuter l = M_Shaped (mfromListOuter (coerce l))
+
+ mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)]
+ mtoListOuter (M_Shaped arr)
+ = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr)
+
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a)
+ mlift ssh2 f (M_Shaped arr) =
+ coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
+ mlift ssh2 f arr
+
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
+ -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a)
+ mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) =
+ coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $
+ mlift2 ssh3 f arr1 arr2
+
+ mliftL :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
+ -> NonEmpty (Mixed sh1 (Shaped sh a)) -> NonEmpty (Mixed sh2 (Shaped sh a))
+ mliftL ssh2 f l =
+ coerce @(NonEmpty (Mixed sh2 (Mixed (MapJust sh) a)))
+ @(NonEmpty (Mixed sh2 (Shaped sh a))) $
+ mliftL ssh2 f (coerce l)
+
+ mcastPartial ssh1 ssh2 psh' (M_Shaped arr) = M_Shaped (mcastPartial ssh1 ssh2 psh' arr)
+
+ mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr)
+
+ mconcat l = M_Shaped (mconcat (coerce l))
+
+ mrnf (M_Shaped arr) = mrnf arr
+
+ type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a)
+
+ mshapeTree (Shaped arr) = first shsFromShX (mshapeTree arr)
+
+ mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
+
+ mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+
+ mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
+
+ marrayStrides (M_Shaped arr) = marrayStrides arr
+
+ mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
+ mvecsWrite sh idx (Shaped arr) vecs =
+ mvecsWrite sh idx arr
+ (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
+ vecs)
+
+ mvecsWritePartial :: forall sh1 sh2 s.
+ IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
+ -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
+ -> ST s ()
+ mvecsWritePartial sh idx arr vecs =
+ mvecsWritePartial sh idx
+ (coerce @(Mixed sh2 (Shaped sh a))
+ @(Mixed sh2 (Mixed (MapJust sh) a))
+ arr)
+ (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a))
+ @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))
+ vecs)
+
+ mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
+ mvecsFreeze sh vecs =
+ coerce @(Mixed sh' (Mixed (MapJust sh) a))
+ @(Mixed sh' (Shaped sh a))
+ <$> mvecsFreeze sh
+ (coerce @(MixedVecs s sh' (Shaped sh a))
+ @(MixedVecs s sh' (Mixed (MapJust sh) a))
+ vecs)
+
+instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where
+ memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)
+ memptyArrayUnsafe i
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
+ memptyArrayUnsafe i
+
+ mvecsUnsafeNew idx (Shaped arr)
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = MV_Shaped <$> mvecsUnsafeNew idx arr
+
+ mvecsNewEmpty _
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a))
+
+
+liftShaped1 :: forall sh a b.
+ (Mixed (MapJust sh) a -> Mixed (MapJust sh) b)
+ -> Shaped sh a -> Shaped sh b
+liftShaped1 = coerce
+
+liftShaped2 :: forall sh a b c.
+ (Mixed (MapJust sh) a -> Mixed (MapJust sh) b -> Mixed (MapJust sh) c)
+ -> Shaped sh a -> Shaped sh b -> Shaped sh c
+liftShaped2 = coerce
+
+instance (NumElt a, PrimElt a) => Num (Shaped sh a) where
+ (+) = liftShaped2 (+)
+ (-) = liftShaped2 (-)
+ (*) = liftShaped2 (*)
+ negate = liftShaped1 negate
+ abs = liftShaped1 abs
+ signum = liftShaped1 signum
+ fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicateScal"
+
+instance (FloatElt a, PrimElt a) => Fractional (Shaped sh a) where
+ fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal"
+ recip = liftShaped1 recip
+ (/) = liftShaped2 (/)
+
+instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where
+ pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal"
+ exp = liftShaped1 exp
+ log = liftShaped1 log
+ sqrt = liftShaped1 sqrt
+ (**) = liftShaped2 (**)
+ logBase = liftShaped2 logBase
+ sin = liftShaped1 sin
+ cos = liftShaped1 cos
+ tan = liftShaped1 tan
+ asin = liftShaped1 asin
+ acos = liftShaped1 acos
+ atan = liftShaped1 atan
+ sinh = liftShaped1 sinh
+ cosh = liftShaped1 cosh
+ tanh = liftShaped1 tanh
+ asinh = liftShaped1 asinh
+ acosh = liftShaped1 acosh
+ atanh = liftShaped1 atanh
+ log1p = liftShaped1 GHC.Float.log1p
+ expm1 = liftShaped1 GHC.Float.expm1
+ log1pexp = liftShaped1 GHC.Float.log1pexp
+ log1mexp = liftShaped1 GHC.Float.log1mexp
+
+squotArray, sremArray :: (IntElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a
+squotArray = liftShaped2 mquotArray
+sremArray = liftShaped2 mremArray
+
+satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a
+satan2Array = liftShaped2 matan2Array
+
+
+sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
+sshape (Shaped arr) = shsFromShX (mshape arr)
+
+-- Needed already here, but re-exported in Data.Array.Nested.Convert.
+shsFromShX :: forall sh i. ShX (MapJust sh) i -> ShS sh
+shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
+shsFromShX (SKnown n@SNat :$% (idx :: ShX mjshT i)) =
+ castWith (subst1 (sym (lemMapJustCons Refl))) $
+ n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
+ idx)
+shsFromShX (SUnknown _ :$% _) = error "impossible"
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs
new file mode 100644
index 0000000..5f9ba79
--- /dev/null
+++ b/src/Data/Array/Nested/Shaped/Shape.hs
@@ -0,0 +1,425 @@
+{-# LANGUAGE CPP #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE NoStarIsType #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE RoleAnnotations #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE StrictData #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Nested.Shaped.Shape where
+
+import Control.DeepSeq (NFData(..))
+import Data.Array.Shape qualified as O
+import Data.Coerce (coerce)
+import Data.Foldable qualified as Foldable
+import Data.Functor.Const
+import Data.Functor.Product qualified as Fun
+import Data.Kind (Constraint, Type)
+import Data.Monoid (Sum(..))
+import Data.Proxy
+import Data.Type.Equality
+import GHC.Exts (withDict)
+import GHC.Generics (Generic)
+import GHC.IsList (IsList)
+import GHC.IsList qualified as IsList
+import GHC.TypeLits
+
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Types
+
+
+-- * Shaped lists
+
+-- | Note: The 'KnownNat' constraint on '(::$)' is deprecated and should be
+-- removed in a future release.
+type role ListS nominal representational
+type ListS :: [Nat] -> (Nat -> Type) -> Type
+data ListS sh f where
+ ZS :: ListS '[] f
+ -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity
+ (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f
+deriving instance (forall n. Eq (f n)) => Eq (ListS sh f)
+deriving instance (forall n. Ord (f n)) => Ord (ListS sh f)
+infixr 3 ::$
+
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance (forall n. Show (f n)) => Show (ListS sh f)
+#else
+instance (forall n. Show (f n)) => Show (ListS sh f) where
+ showsPrec _ = listsShow shows
+#endif
+
+instance (forall m. NFData (f m)) => NFData (ListS n f) where
+ rnf ZS = ()
+ rnf (x ::$ l) = rnf x `seq` rnf l
+
+data UnconsListSRes f sh1 =
+ forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n)
+listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1)
+listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x)
+listsUncons ZS = Nothing
+
+-- | This checks only whether the types are equal; if the elements of the list
+-- are not singletons, their values may still differ. This corresponds to
+-- 'testEquality', except on the penultimate type parameter.
+listsEqType :: TestEquality f => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh')
+listsEqType ZS ZS = Just Refl
+listsEqType (n ::$ sh) (m ::$ sh')
+ | Just Refl <- testEquality n m
+ , Just Refl <- listsEqType sh sh'
+ = Just Refl
+listsEqType _ _ = Nothing
+
+-- | This checks whether the two lists actually contain equal values. This is
+-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@
+-- in the @some@ package (except on the penultimate type parameter).
+listsEqual :: (TestEquality f, forall n. Eq (f n)) => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh')
+listsEqual ZS ZS = Just Refl
+listsEqual (n ::$ sh) (m ::$ sh')
+ | Just Refl <- testEquality n m
+ , n == m
+ , Just Refl <- listsEqual sh sh'
+ = Just Refl
+listsEqual _ _ = Nothing
+
+listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g
+listsFmap _ ZS = ZS
+listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs
+
+listsFold :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m
+listsFold _ ZS = mempty
+listsFold f (x ::$ xs) = f x <> listsFold f xs
+
+listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS
+listsShow f l = showString "[" . go "" l . showString "]"
+ where
+ go :: String -> ListS sh' f -> ShowS
+ go _ ZS = id
+ go prefix (x ::$ xs) = showString prefix . f x . go "," xs
+
+listsLength :: ListS sh f -> Int
+listsLength = getSum . listsFold (\_ -> Sum 1)
+
+listsRank :: ListS sh f -> SNat (Rank sh)
+listsRank ZS = SNat
+listsRank (_ ::$ sh) = snatSucc (listsRank sh)
+
+listsToList :: ListS sh (Const i) -> [i]
+listsToList ZS = []
+listsToList (Const i ::$ is) = i : listsToList is
+
+listsHead :: ListS (n : sh) f -> f n
+listsHead (i ::$ _) = i
+
+listsTail :: ListS (n : sh) f -> ListS sh f
+listsTail (_ ::$ sh) = sh
+
+listsInit :: ListS (n : sh) f -> ListS (Init (n : sh)) f
+listsInit (n ::$ sh@(_ ::$ _)) = n ::$ listsInit sh
+listsInit (_ ::$ ZS) = ZS
+
+listsLast :: ListS (n : sh) f -> f (Last (n : sh))
+listsLast (_ ::$ sh@(_ ::$ _)) = listsLast sh
+listsLast (n ::$ ZS) = n
+
+listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f
+listsAppend ZS idx' = idx'
+listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx'
+
+listsZip :: ListS sh f -> ListS sh g -> ListS sh (Fun.Product f g)
+listsZip ZS ZS = ZS
+listsZip (i ::$ is) (j ::$ js) =
+ Fun.Pair i j ::$ listsZip is js
+
+listsZipWith :: (forall a. f a -> g a -> h a) -> ListS sh f -> ListS sh g
+ -> ListS sh h
+listsZipWith _ ZS ZS = ZS
+listsZipWith f (i ::$ is) (j ::$ js) =
+ f i j ::$ listsZipWith f is js
+
+listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f
+listsTakeLenPerm PNil _ = ZS
+listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh
+listsTakeLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"
+
+listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f
+listsDropLenPerm PNil sh = sh
+listsDropLenPerm (_ `PCons` is) (_ ::$ sh) = listsDropLenPerm is sh
+listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"
+
+listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f
+listsPermute PNil _ = ZS
+listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) =
+ case listsIndex (Proxy @is') (Proxy @sh) i sh of
+ (item, SNat) -> item ::$ listsPermute is sh
+
+-- TODO: remove this SNat when the KnownNat constaint in ListS is removed
+listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh))
+listsIndex _ _ SZ (n ::$ _) = (n, SNat)
+listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f))
+ | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
+ = listsIndex p pT i sh
+listsIndex _ _ _ ZS = error "Index into empty shape"
+
+listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f
+listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh)
+
+-- * Shaped indices
+
+-- | An index into a shape-typed array.
+type role IxS nominal representational
+type IxS :: [Nat] -> Type -> Type
+newtype IxS sh i = IxS (ListS sh (Const i))
+ deriving (Eq, Ord, Generic)
+
+pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i
+pattern ZIS = IxS ZS
+
+-- | Note: The 'KnownNat' constraint on '(:.$)' is deprecated and should be
+-- removed in a future release.
+pattern (:.$)
+ :: forall {sh1} {i}.
+ forall n sh. (KnownNat n, n : sh ~ sh1)
+ => i -> IxS sh i -> IxS sh1 i
+pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i)))
+ where i :.$ IxS shl = IxS (Const i ::$ shl)
+infixr 3 :.$
+
+{-# COMPLETE ZIS, (:.$) #-}
+
+-- For convenience, this contains regular 'Int's instead of bounded integers
+-- (traditionally called \"@Fin@\").
+type IIxS sh = IxS sh Int
+
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show i => Show (IxS sh i)
+#else
+instance Show i => Show (IxS sh i) where
+ showsPrec _ (IxS l) = listsShow (\(Const i) -> shows i) l
+#endif
+
+instance Functor (IxS sh) where
+ fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l)
+
+instance Foldable (IxS sh) where
+ foldMap f (IxS l) = listsFold (f . getConst) l
+
+instance NFData i => NFData (IxS sh i)
+
+ixsLength :: IxS sh i -> Int
+ixsLength (IxS l) = listsLength l
+
+ixsRank :: IxS sh i -> SNat (Rank sh)
+ixsRank (IxS l) = listsRank l
+
+ixsZero :: ShS sh -> IIxS sh
+ixsZero ZSS = ZIS
+ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh
+
+ixsHead :: IxS (n : sh) i -> i
+ixsHead (IxS list) = getConst (listsHead list)
+
+ixsTail :: IxS (n : sh) i -> IxS sh i
+ixsTail (IxS list) = IxS (listsTail list)
+
+ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i
+ixsInit (IxS list) = IxS (listsInit list)
+
+ixsLast :: IxS (n : sh) i -> i
+ixsLast (IxS list) = getConst (listsLast list)
+
+-- TODO: this takes a ShS because there are KnownNats inside IxS.
+ixsCast :: ShS sh' -> IxS sh i -> IxS sh' i
+ixsCast ZSS ZIS = ZIS
+ixsCast (_ :$$ sh) (i :.$ idx) = i :.$ ixsCast sh idx
+ixsCast _ _ = error "ixsCast: ranks don't match"
+
+ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i
+ixsAppend = coerce (listsAppend @_ @(Const i))
+
+ixsZip :: IxS n i -> IxS n j -> IxS n (i, j)
+ixsZip ZIS ZIS = ZIS
+ixsZip (i :.$ is) (j :.$ js) = (i, j) :.$ ixsZip is js
+
+ixsZipWith :: (i -> j -> k) -> IxS n i -> IxS n j -> IxS n k
+ixsZipWith _ ZIS ZIS = ZIS
+ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js
+
+ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i
+ixsPermutePrefix = coerce (listsPermutePrefix @(Const i))
+
+
+-- * Shaped shapes
+
+-- | The shape of a shape-typed array given as a list of 'SNat' values.
+--
+-- Note that because the shape of a shape-typed array is known statically, you
+-- can also retrieve the array shape from a 'KnownShS' dictionary.
+type role ShS nominal
+type ShS :: [Nat] -> Type
+newtype ShS sh = ShS (ListS sh SNat)
+ deriving (Eq, Ord, Generic)
+
+pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh
+pattern ZSS = ShS ZS
+
+pattern (:$$)
+ :: forall {sh1}.
+ forall n sh. (KnownNat n, n : sh ~ sh1)
+ => SNat n -> ShS sh -> ShS sh1
+pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i))
+ where i :$$ ShS shl = ShS (i ::$ shl)
+
+infixr 3 :$$
+
+{-# COMPLETE ZSS, (:$$) #-}
+
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show (ShS sh)
+#else
+instance Show (ShS sh) where
+ showsPrec _ (ShS l) = listsShow (shows . fromSNat) l
+#endif
+
+instance NFData (ShS sh) where
+ rnf (ShS ZS) = ()
+ rnf (ShS (SNat ::$ l)) = rnf (ShS l)
+
+instance TestEquality ShS where
+ testEquality (ShS l1) (ShS l2) = listsEqType l1 l2
+
+-- | @'shsEqual' = 'testEquality'@. (Because 'ShS' is a singleton, types are
+-- equal if and only if values are equal.)
+shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh')
+shsEqual = testEquality
+
+shsLength :: ShS sh -> Int
+shsLength (ShS l) = listsLength l
+
+shsRank :: ShS sh -> SNat (Rank sh)
+shsRank (ShS l) = listsRank l
+
+shsSize :: ShS sh -> Int
+shsSize ZSS = 1
+shsSize (n :$$ sh) = fromSNat' n * shsSize sh
+
+shsToList :: ShS sh -> [Int]
+shsToList ZSS = []
+shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh
+
+shsHead :: ShS (n : sh) -> SNat n
+shsHead (ShS list) = listsHead list
+
+shsTail :: ShS (n : sh) -> ShS sh
+shsTail (ShS list) = ShS (listsTail list)
+
+shsInit :: ShS (n : sh) -> ShS (Init (n : sh))
+shsInit (ShS list) = ShS (listsInit list)
+
+shsLast :: ShS (n : sh) -> SNat (Last (n : sh))
+shsLast (ShS list) = listsLast list
+
+shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh')
+shsAppend = coerce (listsAppend @_ @SNat)
+
+shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh)
+shsTakeLen = coerce (listsTakeLenPerm @SNat)
+
+shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh)
+shsPermute = coerce (listsPermute @SNat)
+
+shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh)
+shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh)))
+
+shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh)
+shsPermutePrefix = coerce (listsPermutePrefix @SNat)
+
+type family Product sh where
+ Product '[] = 1
+ Product (n : ns) = n * Product ns
+
+shsProduct :: ShS sh -> SNat (Product sh)
+shsProduct ZSS = SNat
+shsProduct (n :$$ sh) = n `snatMul` shsProduct sh
+
+-- | Evidence for the static part of a shape. This pops up only when you are
+-- polymorphic in the element type of an array.
+type KnownShS :: [Nat] -> Constraint
+class KnownShS sh where knownShS :: ShS sh
+instance KnownShS '[] where knownShS = ZSS
+instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS
+
+withKnownShS :: forall sh r. ShS sh -> (KnownShS sh => r) -> r
+withKnownShS = withDict @(KnownShS sh)
+
+shsKnownShS :: ShS sh -> Dict KnownShS sh
+shsKnownShS ZSS = Dict
+shsKnownShS (SNat :$$ sh) | Dict <- shsKnownShS sh = Dict
+
+shsOrthotopeShape :: ShS sh -> Dict O.Shape sh
+shsOrthotopeShape ZSS = Dict
+shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict
+
+-- | This function is a hack made possible by the 'KnownNat' inside 'ListS'.
+-- This function may be removed in a future release.
+shsFromListS :: ListS sh f -> ShS sh
+shsFromListS ZS = ZSS
+shsFromListS (_ ::$ l) = SNat :$$ shsFromListS l
+
+-- | This function is a hack made possible by the 'KnownNat' inside 'IxS'. This
+-- function may be removed in a future release.
+shsFromIxS :: IxS sh i -> ShS sh
+shsFromIxS (IxS l) = shsFromListS l
+
+
+-- | Untyped: length is checked at runtime.
+instance KnownShS sh => IsList (ListS sh (Const i)) where
+ type Item (ListS sh (Const i)) = i
+ fromList topl = go (knownShS @sh) topl
+ where
+ go :: ShS sh' -> [i] -> ListS sh' (Const i)
+ go ZSS [] = ZS
+ go (_ :$$ sh) (i : is) = Const i ::$ go sh is
+ go _ _ = error $ "IsList(ListS): Mismatched list length (type says "
+ ++ show (shsLength (knownShS @sh)) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+ toList = listsToList
+
+-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
+instance KnownShS sh => IsList (IxS sh i) where
+ type Item (IxS sh i) = i
+ fromList = IxS . IsList.fromList
+ toList = Foldable.toList
+
+-- | Untyped: length and values are checked at runtime.
+instance KnownShS sh => IsList (ShS sh) where
+ type Item (ShS sh) = Int
+ fromList topl = ShS (go (knownShS @sh) topl)
+ where
+ go :: ShS sh' -> [Int] -> ListS sh' SNat
+ go ZSS [] = ZS
+ go (sn :$$ sh) (i : is)
+ | i == fromSNat' sn = sn ::$ go sh is
+ | otherwise = error $ "IsList(ShS): Value does not match typing (type says "
+ ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")"
+ go _ _ = error $ "IsList(ShS): Mismatched list length (type says "
+ ++ show (shsLength (knownShS @sh)) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+ toList = shsToList