aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Shaped
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Shaped')
-rw-r--r--src/Data/Array/Nested/Shaped/Base.hs52
-rw-r--r--src/Data/Array/Nested/Shaped/Shape.hs421
2 files changed, 248 insertions, 225 deletions
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs
index 98f1241..4b119c4 100644
--- a/src/Data/Array/Nested/Shaped/Base.hs
+++ b/src/Data/Array/Nested/Shaped/Base.hs
@@ -26,7 +26,6 @@ import Data.Coerce (coerce)
import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
-import Data.Type.Equality
import Foreign.Storable (Storable)
import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
import GHC.Generics (Generic)
@@ -80,9 +79,12 @@ deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped
newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a))
instance Elt a => Elt (Shaped sh a) where
+ {-# INLINE mshape #-}
mshape (M_Shaped arr) = mshape arr
+ {-# INLINE mindex #-}
mindex (M_Shaped arr) i = Shaped (mindex arr i)
+ {-# INLINE mindexPartial #-}
mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
mindexPartial (M_Shaped arr) i =
coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
@@ -97,6 +99,7 @@ instance Elt a => Elt (Shaped sh a) where
mtoListOuter (M_Shaped arr)
= coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr)
+ {-# INLINE mlift #-}
mlift :: forall sh1 sh2.
StaticShX sh2
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
@@ -105,6 +108,7 @@ instance Elt a => Elt (Shaped sh a) where
coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
mlift ssh2 f arr
+ {-# INLINE mlift2 #-}
mlift2 :: forall sh1 sh2 sh3.
StaticShX sh3
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
@@ -113,6 +117,7 @@ instance Elt a => Elt (Shaped sh a) where
coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $
mlift2 ssh3 f arr1 arr2
+ {-# INLINE mliftL #-}
mliftL :: forall sh1 sh2.
StaticShX sh2
-> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
@@ -132,7 +137,7 @@ instance Elt a => Elt (Shaped sh a) where
type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a)
- mshapeTree (Shaped arr) = first shsFromShX (mshapeTree arr)
+ mshapeTree (Shaped arr) = first coerce (mshapeTree arr)
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
@@ -142,18 +147,19 @@ instance Elt a => Elt (Shaped sh a) where
marrayStrides (M_Shaped arr) = marrayStrides arr
- mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
- mvecsWrite sh idx (Shaped arr) vecs =
- mvecsWrite sh idx arr
+ mvecsWriteLinear :: forall sh' s. Int -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
+ mvecsWriteLinear idx (Shaped arr) vecs =
+ mvecsWriteLinear idx arr
(coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
vecs)
- mvecsWritePartial :: forall sh1 sh2 s.
- IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
- -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
- -> ST s ()
- mvecsWritePartial sh idx arr vecs =
- mvecsWritePartial sh idx
+ mvecsWritePartialLinear
+ :: forall sh1 sh2 s.
+ Proxy sh1 -> Int -> Mixed sh2 (Shaped sh a)
+ -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
+ -> ST s ()
+ mvecsWritePartialLinear proxy idx arr vecs =
+ mvecsWritePartialLinear proxy idx
(coerce @(Mixed sh2 (Shaped sh a))
@(Mixed sh2 (Mixed (MapJust sh) a))
arr)
@@ -169,6 +175,14 @@ instance Elt a => Elt (Shaped sh a) where
(coerce @(MixedVecs s sh' (Shaped sh a))
@(MixedVecs s sh' (Mixed (MapJust sh) a))
vecs)
+ mvecsUnsafeFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
+ mvecsUnsafeFreeze sh vecs =
+ coerce @(Mixed sh' (Mixed (MapJust sh) a))
+ @(Mixed sh' (Shaped sh a))
+ <$> mvecsUnsafeFreeze sh
+ (coerce @(MixedVecs s sh' (Shaped sh a))
+ @(MixedVecs s sh' (Mixed (MapJust sh) a))
+ vecs)
instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where
memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)
@@ -181,6 +195,10 @@ instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where
| Dict <- lemKnownMapJust (Proxy @sh)
= MV_Shaped <$> mvecsUnsafeNew idx arr
+ mvecsReplicate idx (Shaped arr)
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = MV_Shaped <$> mvecsReplicate idx arr
+
mvecsNewEmpty _
| Dict <- lemKnownMapJust (Proxy @sh)
= MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a))
@@ -242,14 +260,6 @@ satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped s
satan2Array = liftShaped2 matan2Array
+{-# INLINE sshape #-}
sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
-sshape (Shaped arr) = shsFromShX (mshape arr)
-
--- Needed already here, but re-exported in Data.Array.Nested.Convert.
-shsFromShX :: forall sh i. ShX (MapJust sh) i -> ShS sh
-shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
-shsFromShX (SKnown n@SNat :$% (idx :: ShX mjshT i)) =
- castWith (subst1 (sym (lemMapJustCons Refl))) $
- n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
- idx)
-shsFromShX (SUnknown _ :$% _) = error "impossible"
+sshape (Shaped arr) = coerce (mshape arr)
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs
index 0d90e91..c5e3202 100644
--- a/src/Data/Array/Nested/Shaped/Shape.hs
+++ b/src/Data/Array/Nested/Shaped/Shape.hs
@@ -1,10 +1,8 @@
-{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveGeneric #-}
-{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NoStarIsType #-}
@@ -32,173 +30,157 @@ import Control.DeepSeq (NFData(..))
import Data.Array.Shape qualified as O
import Data.Coerce (coerce)
import Data.Foldable qualified as Foldable
-import Data.Functor.Const
-import Data.Functor.Product qualified as Fun
import Data.Kind (Constraint, Type)
-import Data.Monoid (Sum(..))
-import Data.Proxy
import Data.Type.Equality
-import GHC.Exts (Int(..), Int#, quotRemInt#, withDict, build)
-import GHC.Generics (Generic)
+import GHC.Exts (build, withDict)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
+import Unsafe.Coerce (unsafeCoerce)
import Data.Array.Nested.Mixed.Shape
-import Data.Array.Nested.Mixed.Shape.Internal
import Data.Array.Nested.Permutation
import Data.Array.Nested.Types
-- * Shaped lists
--- | Note: The 'KnownNat' constraint on '(::$)' is deprecated and should be
--- removed in a future release.
type role ListS nominal representational
-type ListS :: [Nat] -> (Nat -> Type) -> Type
-data ListS sh f where
- ZS :: ListS '[] f
- -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity
- (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f
-deriving instance (forall n. Eq (f n)) => Eq (ListS sh f)
-deriving instance (forall n. Ord (f n)) => Ord (ListS sh f)
+type ListS :: [Nat] -> Type -> Type
+data ListS sh i where
+ ZS :: ListS '[] i
+ (::$) :: forall n sh {i}. i -> ListS sh i -> ListS (n : sh) i
+deriving instance Eq i => Eq (ListS sh i)
+deriving instance Ord i => Ord (ListS sh i)
+
infixr 3 ::$
#ifdef OXAR_DEFAULT_SHOW_INSTANCES
-deriving instance (forall n. Show (f n)) => Show (ListS sh f)
+deriving instance Show i => Show (ListS sh i)
#else
-instance (forall n. Show (f n)) => Show (ListS sh f) where
+instance Show i => Show (ListS sh i) where
showsPrec _ = listsShow shows
#endif
-instance (forall m. NFData (f m)) => NFData (ListS n f) where
+instance NFData i => NFData (ListS n i) where
rnf ZS = ()
rnf (x ::$ l) = rnf x `seq` rnf l
-data UnconsListSRes f sh1 =
- forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n)
-listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1)
+data UnconsListSRes i sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsListSRes (ListS sh i) i
+listsUncons :: ListS sh1 i -> Maybe (UnconsListSRes i sh1)
listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x)
listsUncons ZS = Nothing
--- | This checks only whether the types are equal; if the elements of the list
--- are not singletons, their values may still differ. This corresponds to
--- 'testEquality', except on the penultimate type parameter.
-listsEqType :: TestEquality f => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh')
-listsEqType ZS ZS = Just Refl
-listsEqType (n ::$ sh) (m ::$ sh')
- | Just Refl <- testEquality n m
- , Just Refl <- listsEqType sh sh'
- = Just Refl
-listsEqType _ _ = Nothing
-
--- | This checks whether the two lists actually contain equal values. This is
--- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@
--- in the @some@ package (except on the penultimate type parameter).
-listsEqual :: (TestEquality f, forall n. Eq (f n)) => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh')
-listsEqual ZS ZS = Just Refl
-listsEqual (n ::$ sh) (m ::$ sh')
- | Just Refl <- testEquality n m
- , n == m
- , Just Refl <- listsEqual sh sh'
- = Just Refl
-listsEqual _ _ = Nothing
-
-{-# INLINE listsFmap #-}
-listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g
-listsFmap _ ZS = ZS
-listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs
-
-{-# INLINE listsFoldMap #-}
-listsFoldMap :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m
-listsFoldMap _ ZS = mempty
-listsFoldMap f (x ::$ xs) = f x <> listsFoldMap f xs
-
-listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS
+listsShow :: forall sh i. (i -> ShowS) -> ListS sh i -> ShowS
listsShow f l = showString "[" . go "" l . showString "]"
where
- go :: String -> ListS sh' f -> ShowS
+ go :: String -> ListS sh' i -> ShowS
go _ ZS = id
go prefix (x ::$ xs) = showString prefix . f x . go "," xs
-listsLength :: ListS sh f -> Int
-listsLength = getSum . listsFoldMap (\_ -> Sum 1)
+instance Functor (ListS l) where
+ {-# INLINE fmap #-}
+ fmap _ ZS = ZS
+ fmap f (x ::$ xs) = f x ::$ fmap f xs
+
+instance Foldable (ListS l) where
+ {-# INLINE foldMap #-}
+ foldMap _ ZS = mempty
+ foldMap f (x ::$ xs) = f x <> foldMap f xs
+ {-# INLINE foldr #-}
+ foldr _ z ZS = z
+ foldr f z (x ::$ xs) = f x (foldr f z xs)
+ toList = listsToList
+ null ZS = False
+ null _ = True
+
+listsLength :: ListS sh i -> Int
+listsLength = length
-listsRank :: ListS sh f -> SNat (Rank sh)
+listsRank :: ListS sh i -> SNat (Rank sh)
listsRank ZS = SNat
listsRank (_ ::$ sh) = snatSucc (listsRank sh)
-listsFromList :: ShS sh -> [i] -> ListS sh (Const i)
+listsFromList :: ShS sh -> [i] -> ListS sh i
listsFromList topsh topl = go topsh topl
where
- go :: ShS sh' -> [i] -> ListS sh' (Const i)
+ go :: ShS sh' -> [i] -> ListS sh' i
go ZSS [] = ZS
- go (_ :$$ sh) (i : is) = Const i ::$ go sh is
+ go (_ :$$ sh) (i : is) = i ::$ go sh is
go _ _ = error $ "listsFromList: Mismatched list length (type says "
++ show (shsLength topsh) ++ ", list has length "
++ show (length topl) ++ ")"
+{-# INLINEABLE listsFromListS #-}
+listsFromListS :: ListS sh i0 -> [i] -> ListS sh i
+listsFromListS topl0 topl = go topl0 topl
+ where
+ go :: ListS sh i0 -> [i] -> ListS sh i
+ go ZS [] = ZS
+ go (_ ::$ l0) (i : is) = i ::$ go l0 is
+ go _ _ = error $ "listsFromListS: Mismatched list length (the model says "
+ ++ show (listsLength topl0) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+
{-# INLINEABLE listsToList #-}
-listsToList :: ListS sh (Const i) -> [i]
+listsToList :: ListS sh i -> [i]
listsToList list = build (\(cons :: i -> is -> is) (nil :: is) ->
- let go :: ListS sh (Const i) -> is
+ let go :: ListS sh i -> is
go ZS = nil
- go (Const i ::$ is) = i `cons` go is
+ go (i ::$ is) = i `cons` go is
in go list)
-listsHead :: ListS (n : sh) f -> f n
+listsHead :: ListS (n : sh) i -> i
listsHead (i ::$ _) = i
-listsTail :: ListS (n : sh) f -> ListS sh f
+listsTail :: ListS (n : sh) i -> ListS sh i
listsTail (_ ::$ sh) = sh
-listsInit :: ListS (n : sh) f -> ListS (Init (n : sh)) f
+listsInit :: ListS (n : sh) i -> ListS (Init (n : sh)) i
listsInit (n ::$ sh@(_ ::$ _)) = n ::$ listsInit sh
listsInit (_ ::$ ZS) = ZS
-listsLast :: ListS (n : sh) f -> f (Last (n : sh))
+listsLast :: ListS (n : sh) i -> i
listsLast (_ ::$ sh@(_ ::$ _)) = listsLast sh
listsLast (n ::$ ZS) = n
-listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f
+listsAppend :: ListS sh i -> ListS sh' i -> ListS (sh ++ sh') i
listsAppend ZS idx' = idx'
listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx'
-listsZip :: ListS sh f -> ListS sh g -> ListS sh (Fun.Product f g)
+listsZip :: ListS sh i -> ListS sh j -> ListS sh (i, j)
listsZip ZS ZS = ZS
-listsZip (i ::$ is) (j ::$ js) = Fun.Pair i j ::$ listsZip is js
+listsZip (i ::$ is) (j ::$ js) = (i, j) ::$ listsZip is js
{-# INLINE listsZipWith #-}
-listsZipWith :: (forall a. f a -> g a -> h a) -> ListS sh f -> ListS sh g
- -> ListS sh h
+listsZipWith :: (i -> j -> k) -> ListS sh i -> ListS sh j -> ListS sh k
listsZipWith _ ZS ZS = ZS
listsZipWith f (i ::$ is) (j ::$ js) = f i j ::$ listsZipWith f is js
-listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f
+listsTakeLenPerm :: forall i is sh. Perm is -> ListS sh i -> ListS (TakeLen is sh) i
listsTakeLenPerm PNil _ = ZS
listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh
listsTakeLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"
-listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f
+listsDropLenPerm :: forall i is sh. Perm is -> ListS sh i -> ListS (DropLen is sh) i
listsDropLenPerm PNil sh = sh
listsDropLenPerm (_ `PCons` is) (_ ::$ sh) = listsDropLenPerm is sh
listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"
-listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f
+listsPermute :: forall i is sh. Perm is -> ListS sh i -> ListS (Permute is sh) i
listsPermute PNil _ = ZS
listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) =
- case listsIndex (Proxy @is') (Proxy @sh) i sh of
- (item, SNat) -> item ::$ listsPermute is sh
+ case listsIndex i sh of
+ item -> item ::$ listsPermute is sh
--- TODO: remove this SNat when the KnownNat constaint in ListS is removed
-listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh))
-listsIndex _ _ SZ (n ::$ _) = (n, SNat)
-listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f))
- | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
- = listsIndex p pT i sh
-listsIndex _ _ _ ZS = error "Index into empty shape"
+-- TODO: try to remove this SNat now that the KnownNat constraint in ListS is removed
+listsIndex :: forall j i sh. SNat i -> ListS sh j -> j
+listsIndex SZ (n ::$ _) = n
+listsIndex (SS i) (_ ::$ sh) = listsIndex i sh
+listsIndex _ ZS = error "Index into empty shape"
-listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f
+listsPermutePrefix :: forall i is sh. Perm is -> ListS sh i -> ListS (PermutePrefix is sh) i
listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh)
-- * Shaped indices
@@ -206,8 +188,8 @@ listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm pe
-- | An index into a shape-typed array.
type role IxS nominal representational
type IxS :: [Nat] -> Type -> Type
-newtype IxS sh i = IxS (ListS sh (Const i))
- deriving (Eq, Ord, Generic)
+newtype IxS sh i = IxS (ListS sh i)
+ deriving (Eq, Ord, NFData, Functor, Foldable)
pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i
pattern ZIS = IxS ZS
@@ -216,10 +198,10 @@ pattern ZIS = IxS ZS
-- removed in a future release.
pattern (:.$)
:: forall {sh1} {i}.
- forall n sh. (KnownNat n, n : sh ~ sh1)
+ forall n sh. (n : sh ~ sh1)
=> i -> IxS sh i -> IxS sh1 i
-pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i)))
- where i :.$ IxS shl = IxS (Const i ::$ shl)
+pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) i))
+ where i :.$ IxS shl = IxS (i ::$ shl)
infixr 3 :.$
{-# COMPLETE ZIS, (:.$) #-}
@@ -232,25 +214,9 @@ type IIxS sh = IxS sh Int
deriving instance Show i => Show (IxS sh i)
#else
instance Show i => Show (IxS sh i) where
- showsPrec _ (IxS l) = listsShow (\(Const i) -> shows i) l
+ showsPrec _ (IxS l) = listsShow (\i -> shows i) l
#endif
-instance Functor (IxS sh) where
- {-# INLINE fmap #-}
- fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l)
-
-instance Foldable (IxS sh) where
- {-# INLINE foldMap #-}
- foldMap f (IxS l) = listsFoldMap (f . getConst) l
- {-# INLINE foldr #-}
- foldr _ z ZIS = z
- foldr f z (x :.$ xs) = f x (foldr f z xs)
- toList = ixsToList
- null ZIS = False
- null _ = True
-
-instance NFData i => NFData (IxS sh i)
-
ixsLength :: IxS sh i -> Int
ixsLength (IxS l) = listsLength l
@@ -260,16 +226,19 @@ ixsRank (IxS l) = listsRank l
ixsFromList :: forall sh i. ShS sh -> [i] -> IxS sh i
ixsFromList = coerce (listsFromList @_ @i)
-{-# INLINEABLE ixsToList #-}
-ixsToList :: forall sh i. IxS sh i -> [i]
-ixsToList = coerce (listsToList @_ @i)
+{-# INLINEABLE ixsFromIxS #-}
+ixsFromIxS :: forall sh i0 i. IxS sh i0 -> [i] -> IxS sh i
+ixsFromIxS = coerce (listsFromListS @_ @i0 @i)
+
+ixsToList :: IxS sh i -> [i]
+ixsToList = Foldable.toList
ixsZero :: ShS sh -> IIxS sh
ixsZero ZSS = ZIS
ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh
ixsHead :: IxS (n : sh) i -> i
-ixsHead (IxS list) = getConst (listsHead list)
+ixsHead (IxS list) = listsHead list
ixsTail :: IxS (n : sh) i -> IxS sh i
ixsTail (IxS list) = IxS (listsTail list)
@@ -278,16 +247,14 @@ ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i
ixsInit (IxS list) = IxS (listsInit list)
ixsLast :: IxS (n : sh) i -> i
-ixsLast (IxS list) = getConst (listsLast list)
+ixsLast (IxS list) = listsLast list
--- TODO: this takes a ShS because there are KnownNats inside IxS.
-ixsCast :: ShS sh' -> IxS sh i -> IxS sh' i
-ixsCast ZSS ZIS = ZIS
-ixsCast (_ :$$ sh) (i :.$ idx) = i :.$ ixsCast sh idx
-ixsCast _ _ = error "ixsCast: ranks don't match"
+ixsCast :: IxS sh i -> IxS sh i
+ixsCast ZIS = ZIS
+ixsCast (i :.$ idx) = i :.$ ixsCast idx
ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i
-ixsAppend = coerce (listsAppend @_ @(Const i))
+ixsAppend = coerce (listsAppend @_ @i)
ixsZip :: IxS sh i -> IxS sh j -> IxS sh (i, j)
ixsZip ZIS ZIS = ZIS
@@ -299,8 +266,31 @@ ixsZipWith _ ZIS ZIS = ZIS
ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js
ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i
-ixsPermutePrefix = coerce (listsPermutePrefix @(Const i))
+ixsPermutePrefix = coerce (listsPermutePrefix @i)
+-- | Given a multidimensional index, get the corresponding linear
+-- index into the buffer.
+{-# INLINEABLE ixsToLinear #-}
+ixsToLinear :: Num i => ShS sh -> IxS sh i -> i
+ixsToLinear (ShS sh) ix = ixxToLinear sh (ixxFromIxS ix)
+
+ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i
+ixxFromIxS = unsafeCoerce -- TODO: switch to coerce once newtypes overhauled
+
+{-# INLINEABLE ixsFromLinear #-}
+ixsFromLinear :: Num i => ShS sh -> Int -> IxS sh i
+ixsFromLinear (ShS sh) i = ixsFromIxX $ ixxFromLinear sh i
+
+ixsFromIxX :: IxX (MapJust sh) i -> IxS sh i
+ixsFromIxX = unsafeCoerce -- TODO: switch to coerce once newtypes overhauled
+
+shsEnum :: ShS sh -> [IIxS sh]
+shsEnum = shsEnum'
+
+{-# INLINABLE shsEnum' #-} -- ensure this can be specialised at use site
+shsEnum' :: Num i => ShS sh -> [IxS sh i]
+shsEnum' (ShS sh) = (unsafeCoerce :: [IxX (MapJust sh) i] -> [IxS sh i]) $ shxEnum' sh
+ -- TODO: switch to coerce once newtypes overhauled
-- * Shaped shapes
@@ -310,21 +300,34 @@ ixsPermutePrefix = coerce (listsPermutePrefix @(Const i))
-- can also retrieve the array shape from a 'KnownShS' dictionary.
type role ShS nominal
type ShS :: [Nat] -> Type
-newtype ShS sh = ShS (ListS sh SNat)
- deriving (Generic)
+newtype ShS sh = ShS (ShX (MapJust sh) Int)
+ deriving (NFData)
instance Eq (ShS sh) where _ == _ = True
instance Ord (ShS sh) where compare _ _ = EQ
pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh
-pattern ZSS = ShS ZS
+pattern ZSS <- ShS (matchZSX -> Just Refl)
+ where ZSS = ShS ZSX
+
+matchZSX :: forall sh i. ShX (MapJust sh) i -> Maybe (sh :~: '[])
+matchZSX ZSX | Refl <- lemMapJustEmpty @sh Refl = Just Refl
+matchZSX _ = Nothing
pattern (:$$)
:: forall {sh1}.
- forall n sh. (KnownNat n, n : sh ~ sh1)
+ forall n sh. (n : sh ~ sh1)
=> SNat n -> ShS sh -> ShS sh1
-pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i))
- where i :$$ ShS shl = ShS (i ::$ shl)
+pattern i :$$ shl <- (shsUncons -> Just (UnconsShSRes i shl))
+ where i :$$ ShS shl = ShS (SKnown i :$% shl)
+
+data UnconsShSRes sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsShSRes (SNat n) (ShS sh)
+shsUncons :: forall sh1. ShS sh1 -> Maybe (UnconsShSRes sh1)
+shsUncons (ShS (SKnown x :$% sh'))
+ | Refl <- lemMapJustCons @sh1 Refl
+ = Just (UnconsShSRes x (ShS sh'))
+shsUncons (ShS _) = Nothing
infixr 3 :$$
@@ -334,15 +337,13 @@ infixr 3 :$$
deriving instance Show (ShS sh)
#else
instance Show (ShS sh) where
- showsPrec _ (ShS l) = listsShow (shows . fromSNat) l
+ showsPrec d (ShS shx) = showsPrec d shx
#endif
-instance NFData (ShS sh) where
- rnf (ShS ZS) = ()
- rnf (ShS (SNat ::$ l)) = rnf (ShS l)
-
instance TestEquality ShS where
- testEquality (ShS l1) (ShS l2) = listsEqType l1 l2
+ testEquality (ShS shx1) (ShS shx2) = case shxEqType shx1 shx2 of
+ Nothing -> Nothing
+ Just Refl -> Just unsafeCoerceRefl
-- | @'shsEqual' = 'testEquality'@. (Because 'ShS' is a singleton, types are
-- equal if and only if values are equal.)
@@ -350,64 +351,106 @@ shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh')
shsEqual = testEquality
shsLength :: ShS sh -> Int
-shsLength (ShS l) = listsLength l
+shsLength (ShS shx) = shxLength shx
-shsRank :: ShS sh -> SNat (Rank sh)
-shsRank (ShS l) = listsRank l
+shsRank :: forall sh. ShS sh -> SNat (Rank sh)
+shsRank (ShS shx) =
+ gcastWith (unsafeCoerceRefl
+ :: Rank (MapJust sh) :~: Rank sh) $
+ shxRank shx
shsSize :: ShS sh -> Int
-shsSize ZSS = 1
-shsSize (n :$$ sh) = fromSNat' n * shsSize sh
+shsSize (ShS sh) = shxSize sh
-- | This is a partial @const@ that fails when the second argument
--- doesn't match the first.
+-- doesn't match the first. We don't report the size of the list
+-- in case of errors in order not to retain the list.
+{-# INLINEABLE shsFromList #-}
shsFromList :: ShS sh -> [Int] -> ShS sh
-shsFromList topsh topl = go topsh topl `seq` topsh
+shsFromList sh0@(ShS (ShX topsh)) topl = go topsh topl `seq` sh0
where
- go :: ShS sh' -> [Int] -> ()
- go ZSS [] = ()
- go (sn :$$ sh) (i : is)
+ go :: ListH sh' Int -> [Int] -> ()
+ go ZH [] = ()
+ go ZH _ = error $ "shsFromList: List too long (type says " ++ show (listhLength topsh) ++ ")"
+ go (ConsKnown sn sh) (i : is)
| i == fromSNat' sn = go sh is
- | otherwise = error $ "shsFromList: Value does not match typing (type says "
- ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")"
- go _ _ = error $ "shsFromList: Mismatched list length (type says "
- ++ show (shsLength topsh) ++ ", list has length "
- ++ show (length topl) ++ ")"
+ | otherwise = error $ "shsFromList: Value does not match typing"
+ go ConsUnknown{} _ = error "shsFromList: impossible case"
+ go _ _ = error $ "shsFromList: List too short (type says " ++ show (listhLength topsh) ++ ")"
+-- This is equivalent to but faster than @coerce shxToList@.
{-# INLINEABLE shsToList #-}
shsToList :: ShS sh -> [Int]
-shsToList topsh = build (\(cons :: Int -> is -> is) (nil :: is) ->
- let go :: ShS sh -> is
- go ZSS = nil
- go (sn :$$ sh) = fromSNat' sn `cons` go sh
- in go topsh)
+shsToList (ShS (ShX l)) = build (\(cons :: i -> is -> is) (nil :: is) ->
+ let go :: ListH sh Int -> is
+ go ZH = nil
+ go ConsUnknown{} = error "shsToList: impossible case"
+ go (ConsKnown snat rest) = fromSNat' snat `cons` go rest
+ in go l)
shsHead :: ShS (n : sh) -> SNat n
-shsHead (ShS list) = listsHead list
+shsHead (ShS shx) = case shxHead shx of
+ SKnown SNat -> SNat
-shsTail :: ShS (n : sh) -> ShS sh
-shsTail (ShS list) = ShS (listsTail list)
+shsTail :: forall n sh. ShS (n : sh) -> ShS sh
+shsTail = coerce (shxTail @_ @_ @Int)
-shsInit :: ShS (n : sh) -> ShS (Init (n : sh))
-shsInit (ShS list) = ShS (listsInit list)
+shsInit :: forall n sh. ShS (n : sh) -> ShS (Init (n : sh))
+shsInit =
+ gcastWith (unsafeCoerceRefl
+ :: Init (Just n : MapJust sh) :~: MapJust (Init (n : sh))) $
+ coerce (shxInit @_ @_ @Int)
-shsLast :: ShS (n : sh) -> SNat (Last (n : sh))
-shsLast (ShS list) = listsLast list
+shsLast :: forall n sh. ShS (n : sh) -> SNat (Last (n : sh))
+shsLast (ShS shx) =
+ gcastWith (unsafeCoerceRefl
+ :: Last (Just n : MapJust sh) :~: Just (Last (n : sh))) $
+ case shxLast shx of
+ SKnown SNat -> SNat
shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh')
-shsAppend = coerce (listsAppend @_ @SNat)
+shsAppend =
+ gcastWith (unsafeCoerceRefl
+ :: MapJust sh ++ MapJust sh' :~: MapJust (sh ++ sh')) $
+ coerce (shxAppend @_ @_ @Int)
+
+shsTakeLen :: forall is sh. Perm is -> ShS sh -> ShS (TakeLen is sh)
+shsTakeLen =
+ gcastWith (unsafeCoerceRefl
+ :: TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)) $
+ coerce shxTakeLen
-shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh)
-shsTakeLen = coerce (listsTakeLenPerm @SNat)
+shsDropLen :: forall is sh. Perm is -> ShS sh -> ShS (DropLen is sh)
+shsDropLen =
+ gcastWith (unsafeCoerceRefl
+ :: DropLen is (MapJust sh) :~: MapJust (DropLen is sh)) $
+ coerce shxDropLen
-shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh)
-shsPermute = coerce (listsPermute @SNat)
+shsPermute :: forall is sh. Perm is -> ShS sh -> ShS (Permute is sh)
+shsPermute =
+ gcastWith (unsafeCoerceRefl
+ :: Permute is (MapJust sh) :~: MapJust (Permute is sh)) $
+ coerce shxPermute
-shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh)
-shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh)))
+shsIndex :: forall i sh. SNat i -> ShS sh -> SNat (Index i sh)
+shsIndex i (ShS sh) =
+ gcastWith (unsafeCoerceRefl
+ :: Index i (MapJust sh) :~: Just (Index i sh)) $
+ case shxIndex @_ @_ @Int i sh of
+ SKnown SNat -> SNat
shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh)
-shsPermutePrefix = coerce (listsPermutePrefix @SNat)
+shsPermutePrefix perm (ShS shx)
+ {- TODO: here and elsewhere, solve the module dependency cycle and add this:
+ | Refl <- lemTakeLenMapJust perm sh
+ , Refl <- lemDropLenMapJust perm sh
+ , Refl <- lemPermuteMapJust perm sh
+ , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm sh)) (shsDropLen perm sh) -}
+ = gcastWith (unsafeCoerceRefl
+ :: Permute is (TakeLen is (MapJust sh))
+ ++ DropLen is (MapJust sh)
+ :~: MapJust (Permute is (TakeLen is sh) ++ DropLen is sh)) $
+ ShS (shxPermutePrefix perm shx)
type family Product sh where
Product '[] = 1
@@ -435,37 +478,10 @@ shsOrthotopeShape :: ShS sh -> Dict O.Shape sh
shsOrthotopeShape ZSS = Dict
shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict
--- | This function is a hack made possible by the 'KnownNat' inside 'ListS'.
--- This function may be removed in a future release.
-shsFromListS :: ListS sh f -> ShS sh
-shsFromListS ZS = ZSS
-shsFromListS (_ ::$ l) = SNat :$$ shsFromListS l
-
--- | This function is a hack made possible by the 'KnownNat' inside 'IxS'. This
--- function may be removed in a future release.
-shsFromIxS :: IxS sh i -> ShS sh
-shsFromIxS (IxS l) = shsFromListS l
-
-shsEnum :: ShS sh -> [IIxS sh]
-shsEnum = shsEnum'
-
-{-# INLINABLE shsEnum' #-} -- ensure this can be specialised at use site
-shsEnum' :: Num i => ShS sh -> [IxS sh i]
-shsEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shsSize sh - 1]]
- where
- suffixes = drop 1 (scanr (*) 1 (shsToList sh))
-
- fromLin :: Num i => ShS sh -> [Int] -> Int# -> IxS sh i
- fromLin ZSS _ _ = ZIS
- fromLin (_ :$$ sh') (I# suff# : suffs) i# =
- let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shsSize sh'
- in fromIntegral (I# q#) :.$ fromLin sh' suffs r#
- fromLin _ _ _ = error "impossible"
-
-- | Untyped: length is checked at runtime.
-instance KnownShS sh => IsList (ListS sh (Const i)) where
- type Item (ListS sh (Const i)) = i
+instance KnownShS sh => IsList (ListS sh i) where
+ type Item (ListS sh i) = i
fromList = listsFromList (knownShS @sh)
toList = listsToList
@@ -480,6 +496,3 @@ instance KnownShS sh => IsList (ShS sh) where
type Item (ShS sh) = Int
fromList = shsFromList (knownShS @sh)
toList = shsToList
-
-$(ixFromLinearStub "ixsFromLinear" [t| ShS |] [t| IxS |] [p| ZSS |] (\a b -> [p| (fromSNat' -> $a) :$$ $b |]) [| ZIS |] [| (:.$) |] [| shsToList |])
-{-# INLINEABLE ixsFromLinear #-}