diff options
Diffstat (limited to 'src/Data/Array/Nested/Shaped')
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Base.hs | 52 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 421 |
2 files changed, 248 insertions, 225 deletions
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs index 98f1241..4b119c4 100644 --- a/src/Data/Array/Nested/Shaped/Base.hs +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -26,7 +26,6 @@ import Data.Coerce (coerce) import Data.Kind (Type) import Data.List.NonEmpty (NonEmpty) import Data.Proxy -import Data.Type.Equality import Foreign.Storable (Storable) import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) import GHC.Generics (Generic) @@ -80,9 +79,12 @@ deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a)) instance Elt a => Elt (Shaped sh a) where + {-# INLINE mshape #-} mshape (M_Shaped arr) = mshape arr + {-# INLINE mindex #-} mindex (M_Shaped arr) i = Shaped (mindex arr i) + {-# INLINE mindexPartial #-} mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) mindexPartial (M_Shaped arr) i = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ @@ -97,6 +99,7 @@ instance Elt a => Elt (Shaped sh a) where mtoListOuter (M_Shaped arr) = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr) + {-# INLINE mlift #-} mlift :: forall sh1 sh2. StaticShX sh2 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) @@ -105,6 +108,7 @@ instance Elt a => Elt (Shaped sh a) where coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ mlift ssh2 f arr + {-# INLINE mlift2 #-} mlift2 :: forall sh1 sh2 sh3. StaticShX sh3 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) @@ -113,6 +117,7 @@ instance Elt a => Elt (Shaped sh a) where coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ mlift2 ssh3 f arr1 arr2 + {-# INLINE mliftL #-} mliftL :: forall sh1 sh2. StaticShX sh2 -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) @@ -132,7 +137,7 @@ instance Elt a => Elt (Shaped sh a) where type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) - mshapeTree (Shaped arr) = first shsFromShX (mshapeTree arr) + mshapeTree (Shaped arr) = first coerce (mshapeTree arr) mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -142,18 +147,19 @@ instance Elt a => Elt (Shaped sh a) where marrayStrides (M_Shaped arr) = marrayStrides arr - mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () - mvecsWrite sh idx (Shaped arr) vecs = - mvecsWrite sh idx arr + mvecsWriteLinear :: forall sh' s. Int -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () + mvecsWriteLinear idx (Shaped arr) vecs = + mvecsWriteLinear idx arr (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) vecs) - mvecsWritePartial :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) - -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) - -> ST s () - mvecsWritePartial sh idx arr vecs = - mvecsWritePartial sh idx + mvecsWritePartialLinear + :: forall sh1 sh2 s. + Proxy sh1 -> Int -> Mixed sh2 (Shaped sh a) + -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) + -> ST s () + mvecsWritePartialLinear proxy idx arr vecs = + mvecsWritePartialLinear proxy idx (coerce @(Mixed sh2 (Shaped sh a)) @(Mixed sh2 (Mixed (MapJust sh) a)) arr) @@ -169,6 +175,14 @@ instance Elt a => Elt (Shaped sh a) where (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) vecs) + mvecsUnsafeFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) + mvecsUnsafeFreeze sh vecs = + coerce @(Mixed sh' (Mixed (MapJust sh) a)) + @(Mixed sh' (Shaped sh a)) + <$> mvecsUnsafeFreeze sh + (coerce @(MixedVecs s sh' (Shaped sh a)) + @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) @@ -181,6 +195,10 @@ instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where | Dict <- lemKnownMapJust (Proxy @sh) = MV_Shaped <$> mvecsUnsafeNew idx arr + mvecsReplicate idx (Shaped arr) + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsReplicate idx arr + mvecsNewEmpty _ | Dict <- lemKnownMapJust (Proxy @sh) = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) @@ -242,14 +260,6 @@ satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped s satan2Array = liftShaped2 matan2Array +{-# INLINE sshape #-} sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh -sshape (Shaped arr) = shsFromShX (mshape arr) - --- Needed already here, but re-exported in Data.Array.Nested.Convert. -shsFromShX :: forall sh i. ShX (MapJust sh) i -> ShS sh -shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS -shsFromShX (SKnown n@SNat :$% (idx :: ShX mjshT i)) = - castWith (subst1 (sym (lemMapJustCons Refl))) $ - n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) - idx) -shsFromShX (SUnknown _ :$% _) = error "impossible" +sshape (Shaped arr) = coerce (mshape arr) diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index 0d90e91..c5e3202 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -1,10 +1,8 @@ -{-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE NoStarIsType #-} @@ -32,173 +30,157 @@ import Control.DeepSeq (NFData(..)) import Data.Array.Shape qualified as O import Data.Coerce (coerce) import Data.Foldable qualified as Foldable -import Data.Functor.Const -import Data.Functor.Product qualified as Fun import Data.Kind (Constraint, Type) -import Data.Monoid (Sum(..)) -import Data.Proxy import Data.Type.Equality -import GHC.Exts (Int(..), Int#, quotRemInt#, withDict, build) -import GHC.Generics (Generic) +import GHC.Exts (build, withDict) import GHC.IsList (IsList) import GHC.IsList qualified as IsList import GHC.TypeLits +import Unsafe.Coerce (unsafeCoerce) import Data.Array.Nested.Mixed.Shape -import Data.Array.Nested.Mixed.Shape.Internal import Data.Array.Nested.Permutation import Data.Array.Nested.Types -- * Shaped lists --- | Note: The 'KnownNat' constraint on '(::$)' is deprecated and should be --- removed in a future release. type role ListS nominal representational -type ListS :: [Nat] -> (Nat -> Type) -> Type -data ListS sh f where - ZS :: ListS '[] f - -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity - (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f -deriving instance (forall n. Eq (f n)) => Eq (ListS sh f) -deriving instance (forall n. Ord (f n)) => Ord (ListS sh f) +type ListS :: [Nat] -> Type -> Type +data ListS sh i where + ZS :: ListS '[] i + (::$) :: forall n sh {i}. i -> ListS sh i -> ListS (n : sh) i +deriving instance Eq i => Eq (ListS sh i) +deriving instance Ord i => Ord (ListS sh i) + infixr 3 ::$ #ifdef OXAR_DEFAULT_SHOW_INSTANCES -deriving instance (forall n. Show (f n)) => Show (ListS sh f) +deriving instance Show i => Show (ListS sh i) #else -instance (forall n. Show (f n)) => Show (ListS sh f) where +instance Show i => Show (ListS sh i) where showsPrec _ = listsShow shows #endif -instance (forall m. NFData (f m)) => NFData (ListS n f) where +instance NFData i => NFData (ListS n i) where rnf ZS = () rnf (x ::$ l) = rnf x `seq` rnf l -data UnconsListSRes f sh1 = - forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n) -listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1) +data UnconsListSRes i sh1 = + forall n sh. (n : sh ~ sh1) => UnconsListSRes (ListS sh i) i +listsUncons :: ListS sh1 i -> Maybe (UnconsListSRes i sh1) listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x) listsUncons ZS = Nothing --- | This checks only whether the types are equal; if the elements of the list --- are not singletons, their values may still differ. This corresponds to --- 'testEquality', except on the penultimate type parameter. -listsEqType :: TestEquality f => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh') -listsEqType ZS ZS = Just Refl -listsEqType (n ::$ sh) (m ::$ sh') - | Just Refl <- testEquality n m - , Just Refl <- listsEqType sh sh' - = Just Refl -listsEqType _ _ = Nothing - --- | This checks whether the two lists actually contain equal values. This is --- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ --- in the @some@ package (except on the penultimate type parameter). -listsEqual :: (TestEquality f, forall n. Eq (f n)) => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh') -listsEqual ZS ZS = Just Refl -listsEqual (n ::$ sh) (m ::$ sh') - | Just Refl <- testEquality n m - , n == m - , Just Refl <- listsEqual sh sh' - = Just Refl -listsEqual _ _ = Nothing - -{-# INLINE listsFmap #-} -listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g -listsFmap _ ZS = ZS -listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs - -{-# INLINE listsFoldMap #-} -listsFoldMap :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m -listsFoldMap _ ZS = mempty -listsFoldMap f (x ::$ xs) = f x <> listsFoldMap f xs - -listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS +listsShow :: forall sh i. (i -> ShowS) -> ListS sh i -> ShowS listsShow f l = showString "[" . go "" l . showString "]" where - go :: String -> ListS sh' f -> ShowS + go :: String -> ListS sh' i -> ShowS go _ ZS = id go prefix (x ::$ xs) = showString prefix . f x . go "," xs -listsLength :: ListS sh f -> Int -listsLength = getSum . listsFoldMap (\_ -> Sum 1) +instance Functor (ListS l) where + {-# INLINE fmap #-} + fmap _ ZS = ZS + fmap f (x ::$ xs) = f x ::$ fmap f xs + +instance Foldable (ListS l) where + {-# INLINE foldMap #-} + foldMap _ ZS = mempty + foldMap f (x ::$ xs) = f x <> foldMap f xs + {-# INLINE foldr #-} + foldr _ z ZS = z + foldr f z (x ::$ xs) = f x (foldr f z xs) + toList = listsToList + null ZS = False + null _ = True + +listsLength :: ListS sh i -> Int +listsLength = length -listsRank :: ListS sh f -> SNat (Rank sh) +listsRank :: ListS sh i -> SNat (Rank sh) listsRank ZS = SNat listsRank (_ ::$ sh) = snatSucc (listsRank sh) -listsFromList :: ShS sh -> [i] -> ListS sh (Const i) +listsFromList :: ShS sh -> [i] -> ListS sh i listsFromList topsh topl = go topsh topl where - go :: ShS sh' -> [i] -> ListS sh' (Const i) + go :: ShS sh' -> [i] -> ListS sh' i go ZSS [] = ZS - go (_ :$$ sh) (i : is) = Const i ::$ go sh is + go (_ :$$ sh) (i : is) = i ::$ go sh is go _ _ = error $ "listsFromList: Mismatched list length (type says " ++ show (shsLength topsh) ++ ", list has length " ++ show (length topl) ++ ")" +{-# INLINEABLE listsFromListS #-} +listsFromListS :: ListS sh i0 -> [i] -> ListS sh i +listsFromListS topl0 topl = go topl0 topl + where + go :: ListS sh i0 -> [i] -> ListS sh i + go ZS [] = ZS + go (_ ::$ l0) (i : is) = i ::$ go l0 is + go _ _ = error $ "listsFromListS: Mismatched list length (the model says " + ++ show (listsLength topl0) ++ ", list has length " + ++ show (length topl) ++ ")" + {-# INLINEABLE listsToList #-} -listsToList :: ListS sh (Const i) -> [i] +listsToList :: ListS sh i -> [i] listsToList list = build (\(cons :: i -> is -> is) (nil :: is) -> - let go :: ListS sh (Const i) -> is + let go :: ListS sh i -> is go ZS = nil - go (Const i ::$ is) = i `cons` go is + go (i ::$ is) = i `cons` go is in go list) -listsHead :: ListS (n : sh) f -> f n +listsHead :: ListS (n : sh) i -> i listsHead (i ::$ _) = i -listsTail :: ListS (n : sh) f -> ListS sh f +listsTail :: ListS (n : sh) i -> ListS sh i listsTail (_ ::$ sh) = sh -listsInit :: ListS (n : sh) f -> ListS (Init (n : sh)) f +listsInit :: ListS (n : sh) i -> ListS (Init (n : sh)) i listsInit (n ::$ sh@(_ ::$ _)) = n ::$ listsInit sh listsInit (_ ::$ ZS) = ZS -listsLast :: ListS (n : sh) f -> f (Last (n : sh)) +listsLast :: ListS (n : sh) i -> i listsLast (_ ::$ sh@(_ ::$ _)) = listsLast sh listsLast (n ::$ ZS) = n -listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f +listsAppend :: ListS sh i -> ListS sh' i -> ListS (sh ++ sh') i listsAppend ZS idx' = idx' listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx' -listsZip :: ListS sh f -> ListS sh g -> ListS sh (Fun.Product f g) +listsZip :: ListS sh i -> ListS sh j -> ListS sh (i, j) listsZip ZS ZS = ZS -listsZip (i ::$ is) (j ::$ js) = Fun.Pair i j ::$ listsZip is js +listsZip (i ::$ is) (j ::$ js) = (i, j) ::$ listsZip is js {-# INLINE listsZipWith #-} -listsZipWith :: (forall a. f a -> g a -> h a) -> ListS sh f -> ListS sh g - -> ListS sh h +listsZipWith :: (i -> j -> k) -> ListS sh i -> ListS sh j -> ListS sh k listsZipWith _ ZS ZS = ZS listsZipWith f (i ::$ is) (j ::$ js) = f i j ::$ listsZipWith f is js -listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f +listsTakeLenPerm :: forall i is sh. Perm is -> ListS sh i -> ListS (TakeLen is sh) i listsTakeLenPerm PNil _ = ZS listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh listsTakeLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" -listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f +listsDropLenPerm :: forall i is sh. Perm is -> ListS sh i -> ListS (DropLen is sh) i listsDropLenPerm PNil sh = sh listsDropLenPerm (_ `PCons` is) (_ ::$ sh) = listsDropLenPerm is sh listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" -listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f +listsPermute :: forall i is sh. Perm is -> ListS sh i -> ListS (Permute is sh) i listsPermute PNil _ = ZS listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) = - case listsIndex (Proxy @is') (Proxy @sh) i sh of - (item, SNat) -> item ::$ listsPermute is sh + case listsIndex i sh of + item -> item ::$ listsPermute is sh --- TODO: remove this SNat when the KnownNat constaint in ListS is removed -listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh)) -listsIndex _ _ SZ (n ::$ _) = (n, SNat) -listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) - | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = listsIndex p pT i sh -listsIndex _ _ _ ZS = error "Index into empty shape" +-- TODO: try to remove this SNat now that the KnownNat constraint in ListS is removed +listsIndex :: forall j i sh. SNat i -> ListS sh j -> j +listsIndex SZ (n ::$ _) = n +listsIndex (SS i) (_ ::$ sh) = listsIndex i sh +listsIndex _ ZS = error "Index into empty shape" -listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f +listsPermutePrefix :: forall i is sh. Perm is -> ListS sh i -> ListS (PermutePrefix is sh) i listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh) -- * Shaped indices @@ -206,8 +188,8 @@ listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm pe -- | An index into a shape-typed array. type role IxS nominal representational type IxS :: [Nat] -> Type -> Type -newtype IxS sh i = IxS (ListS sh (Const i)) - deriving (Eq, Ord, Generic) +newtype IxS sh i = IxS (ListS sh i) + deriving (Eq, Ord, NFData, Functor, Foldable) pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i pattern ZIS = IxS ZS @@ -216,10 +198,10 @@ pattern ZIS = IxS ZS -- removed in a future release. pattern (:.$) :: forall {sh1} {i}. - forall n sh. (KnownNat n, n : sh ~ sh1) + forall n sh. (n : sh ~ sh1) => i -> IxS sh i -> IxS sh1 i -pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i))) - where i :.$ IxS shl = IxS (Const i ::$ shl) +pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) i)) + where i :.$ IxS shl = IxS (i ::$ shl) infixr 3 :.$ {-# COMPLETE ZIS, (:.$) #-} @@ -232,25 +214,9 @@ type IIxS sh = IxS sh Int deriving instance Show i => Show (IxS sh i) #else instance Show i => Show (IxS sh i) where - showsPrec _ (IxS l) = listsShow (\(Const i) -> shows i) l + showsPrec _ (IxS l) = listsShow (\i -> shows i) l #endif -instance Functor (IxS sh) where - {-# INLINE fmap #-} - fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l) - -instance Foldable (IxS sh) where - {-# INLINE foldMap #-} - foldMap f (IxS l) = listsFoldMap (f . getConst) l - {-# INLINE foldr #-} - foldr _ z ZIS = z - foldr f z (x :.$ xs) = f x (foldr f z xs) - toList = ixsToList - null ZIS = False - null _ = True - -instance NFData i => NFData (IxS sh i) - ixsLength :: IxS sh i -> Int ixsLength (IxS l) = listsLength l @@ -260,16 +226,19 @@ ixsRank (IxS l) = listsRank l ixsFromList :: forall sh i. ShS sh -> [i] -> IxS sh i ixsFromList = coerce (listsFromList @_ @i) -{-# INLINEABLE ixsToList #-} -ixsToList :: forall sh i. IxS sh i -> [i] -ixsToList = coerce (listsToList @_ @i) +{-# INLINEABLE ixsFromIxS #-} +ixsFromIxS :: forall sh i0 i. IxS sh i0 -> [i] -> IxS sh i +ixsFromIxS = coerce (listsFromListS @_ @i0 @i) + +ixsToList :: IxS sh i -> [i] +ixsToList = Foldable.toList ixsZero :: ShS sh -> IIxS sh ixsZero ZSS = ZIS ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh ixsHead :: IxS (n : sh) i -> i -ixsHead (IxS list) = getConst (listsHead list) +ixsHead (IxS list) = listsHead list ixsTail :: IxS (n : sh) i -> IxS sh i ixsTail (IxS list) = IxS (listsTail list) @@ -278,16 +247,14 @@ ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i ixsInit (IxS list) = IxS (listsInit list) ixsLast :: IxS (n : sh) i -> i -ixsLast (IxS list) = getConst (listsLast list) +ixsLast (IxS list) = listsLast list --- TODO: this takes a ShS because there are KnownNats inside IxS. -ixsCast :: ShS sh' -> IxS sh i -> IxS sh' i -ixsCast ZSS ZIS = ZIS -ixsCast (_ :$$ sh) (i :.$ idx) = i :.$ ixsCast sh idx -ixsCast _ _ = error "ixsCast: ranks don't match" +ixsCast :: IxS sh i -> IxS sh i +ixsCast ZIS = ZIS +ixsCast (i :.$ idx) = i :.$ ixsCast idx ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i -ixsAppend = coerce (listsAppend @_ @(Const i)) +ixsAppend = coerce (listsAppend @_ @i) ixsZip :: IxS sh i -> IxS sh j -> IxS sh (i, j) ixsZip ZIS ZIS = ZIS @@ -299,8 +266,31 @@ ixsZipWith _ ZIS ZIS = ZIS ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i -ixsPermutePrefix = coerce (listsPermutePrefix @(Const i)) +ixsPermutePrefix = coerce (listsPermutePrefix @i) +-- | Given a multidimensional index, get the corresponding linear +-- index into the buffer. +{-# INLINEABLE ixsToLinear #-} +ixsToLinear :: Num i => ShS sh -> IxS sh i -> i +ixsToLinear (ShS sh) ix = ixxToLinear sh (ixxFromIxS ix) + +ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i +ixxFromIxS = unsafeCoerce -- TODO: switch to coerce once newtypes overhauled + +{-# INLINEABLE ixsFromLinear #-} +ixsFromLinear :: Num i => ShS sh -> Int -> IxS sh i +ixsFromLinear (ShS sh) i = ixsFromIxX $ ixxFromLinear sh i + +ixsFromIxX :: IxX (MapJust sh) i -> IxS sh i +ixsFromIxX = unsafeCoerce -- TODO: switch to coerce once newtypes overhauled + +shsEnum :: ShS sh -> [IIxS sh] +shsEnum = shsEnum' + +{-# INLINABLE shsEnum' #-} -- ensure this can be specialised at use site +shsEnum' :: Num i => ShS sh -> [IxS sh i] +shsEnum' (ShS sh) = (unsafeCoerce :: [IxX (MapJust sh) i] -> [IxS sh i]) $ shxEnum' sh + -- TODO: switch to coerce once newtypes overhauled -- * Shaped shapes @@ -310,21 +300,34 @@ ixsPermutePrefix = coerce (listsPermutePrefix @(Const i)) -- can also retrieve the array shape from a 'KnownShS' dictionary. type role ShS nominal type ShS :: [Nat] -> Type -newtype ShS sh = ShS (ListS sh SNat) - deriving (Generic) +newtype ShS sh = ShS (ShX (MapJust sh) Int) + deriving (NFData) instance Eq (ShS sh) where _ == _ = True instance Ord (ShS sh) where compare _ _ = EQ pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh -pattern ZSS = ShS ZS +pattern ZSS <- ShS (matchZSX -> Just Refl) + where ZSS = ShS ZSX + +matchZSX :: forall sh i. ShX (MapJust sh) i -> Maybe (sh :~: '[]) +matchZSX ZSX | Refl <- lemMapJustEmpty @sh Refl = Just Refl +matchZSX _ = Nothing pattern (:$$) :: forall {sh1}. - forall n sh. (KnownNat n, n : sh ~ sh1) + forall n sh. (n : sh ~ sh1) => SNat n -> ShS sh -> ShS sh1 -pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i)) - where i :$$ ShS shl = ShS (i ::$ shl) +pattern i :$$ shl <- (shsUncons -> Just (UnconsShSRes i shl)) + where i :$$ ShS shl = ShS (SKnown i :$% shl) + +data UnconsShSRes sh1 = + forall n sh. (n : sh ~ sh1) => UnconsShSRes (SNat n) (ShS sh) +shsUncons :: forall sh1. ShS sh1 -> Maybe (UnconsShSRes sh1) +shsUncons (ShS (SKnown x :$% sh')) + | Refl <- lemMapJustCons @sh1 Refl + = Just (UnconsShSRes x (ShS sh')) +shsUncons (ShS _) = Nothing infixr 3 :$$ @@ -334,15 +337,13 @@ infixr 3 :$$ deriving instance Show (ShS sh) #else instance Show (ShS sh) where - showsPrec _ (ShS l) = listsShow (shows . fromSNat) l + showsPrec d (ShS shx) = showsPrec d shx #endif -instance NFData (ShS sh) where - rnf (ShS ZS) = () - rnf (ShS (SNat ::$ l)) = rnf (ShS l) - instance TestEquality ShS where - testEquality (ShS l1) (ShS l2) = listsEqType l1 l2 + testEquality (ShS shx1) (ShS shx2) = case shxEqType shx1 shx2 of + Nothing -> Nothing + Just Refl -> Just unsafeCoerceRefl -- | @'shsEqual' = 'testEquality'@. (Because 'ShS' is a singleton, types are -- equal if and only if values are equal.) @@ -350,64 +351,106 @@ shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh') shsEqual = testEquality shsLength :: ShS sh -> Int -shsLength (ShS l) = listsLength l +shsLength (ShS shx) = shxLength shx -shsRank :: ShS sh -> SNat (Rank sh) -shsRank (ShS l) = listsRank l +shsRank :: forall sh. ShS sh -> SNat (Rank sh) +shsRank (ShS shx) = + gcastWith (unsafeCoerceRefl + :: Rank (MapJust sh) :~: Rank sh) $ + shxRank shx shsSize :: ShS sh -> Int -shsSize ZSS = 1 -shsSize (n :$$ sh) = fromSNat' n * shsSize sh +shsSize (ShS sh) = shxSize sh -- | This is a partial @const@ that fails when the second argument --- doesn't match the first. +-- doesn't match the first. We don't report the size of the list +-- in case of errors in order not to retain the list. +{-# INLINEABLE shsFromList #-} shsFromList :: ShS sh -> [Int] -> ShS sh -shsFromList topsh topl = go topsh topl `seq` topsh +shsFromList sh0@(ShS (ShX topsh)) topl = go topsh topl `seq` sh0 where - go :: ShS sh' -> [Int] -> () - go ZSS [] = () - go (sn :$$ sh) (i : is) + go :: ListH sh' Int -> [Int] -> () + go ZH [] = () + go ZH _ = error $ "shsFromList: List too long (type says " ++ show (listhLength topsh) ++ ")" + go (ConsKnown sn sh) (i : is) | i == fromSNat' sn = go sh is - | otherwise = error $ "shsFromList: Value does not match typing (type says " - ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" - go _ _ = error $ "shsFromList: Mismatched list length (type says " - ++ show (shsLength topsh) ++ ", list has length " - ++ show (length topl) ++ ")" + | otherwise = error $ "shsFromList: Value does not match typing" + go ConsUnknown{} _ = error "shsFromList: impossible case" + go _ _ = error $ "shsFromList: List too short (type says " ++ show (listhLength topsh) ++ ")" +-- This is equivalent to but faster than @coerce shxToList@. {-# INLINEABLE shsToList #-} shsToList :: ShS sh -> [Int] -shsToList topsh = build (\(cons :: Int -> is -> is) (nil :: is) -> - let go :: ShS sh -> is - go ZSS = nil - go (sn :$$ sh) = fromSNat' sn `cons` go sh - in go topsh) +shsToList (ShS (ShX l)) = build (\(cons :: i -> is -> is) (nil :: is) -> + let go :: ListH sh Int -> is + go ZH = nil + go ConsUnknown{} = error "shsToList: impossible case" + go (ConsKnown snat rest) = fromSNat' snat `cons` go rest + in go l) shsHead :: ShS (n : sh) -> SNat n -shsHead (ShS list) = listsHead list +shsHead (ShS shx) = case shxHead shx of + SKnown SNat -> SNat -shsTail :: ShS (n : sh) -> ShS sh -shsTail (ShS list) = ShS (listsTail list) +shsTail :: forall n sh. ShS (n : sh) -> ShS sh +shsTail = coerce (shxTail @_ @_ @Int) -shsInit :: ShS (n : sh) -> ShS (Init (n : sh)) -shsInit (ShS list) = ShS (listsInit list) +shsInit :: forall n sh. ShS (n : sh) -> ShS (Init (n : sh)) +shsInit = + gcastWith (unsafeCoerceRefl + :: Init (Just n : MapJust sh) :~: MapJust (Init (n : sh))) $ + coerce (shxInit @_ @_ @Int) -shsLast :: ShS (n : sh) -> SNat (Last (n : sh)) -shsLast (ShS list) = listsLast list +shsLast :: forall n sh. ShS (n : sh) -> SNat (Last (n : sh)) +shsLast (ShS shx) = + gcastWith (unsafeCoerceRefl + :: Last (Just n : MapJust sh) :~: Just (Last (n : sh))) $ + case shxLast shx of + SKnown SNat -> SNat shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh') -shsAppend = coerce (listsAppend @_ @SNat) +shsAppend = + gcastWith (unsafeCoerceRefl + :: MapJust sh ++ MapJust sh' :~: MapJust (sh ++ sh')) $ + coerce (shxAppend @_ @_ @Int) + +shsTakeLen :: forall is sh. Perm is -> ShS sh -> ShS (TakeLen is sh) +shsTakeLen = + gcastWith (unsafeCoerceRefl + :: TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)) $ + coerce shxTakeLen -shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh) -shsTakeLen = coerce (listsTakeLenPerm @SNat) +shsDropLen :: forall is sh. Perm is -> ShS sh -> ShS (DropLen is sh) +shsDropLen = + gcastWith (unsafeCoerceRefl + :: DropLen is (MapJust sh) :~: MapJust (DropLen is sh)) $ + coerce shxDropLen -shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh) -shsPermute = coerce (listsPermute @SNat) +shsPermute :: forall is sh. Perm is -> ShS sh -> ShS (Permute is sh) +shsPermute = + gcastWith (unsafeCoerceRefl + :: Permute is (MapJust sh) :~: MapJust (Permute is sh)) $ + coerce shxPermute -shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh) -shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh))) +shsIndex :: forall i sh. SNat i -> ShS sh -> SNat (Index i sh) +shsIndex i (ShS sh) = + gcastWith (unsafeCoerceRefl + :: Index i (MapJust sh) :~: Just (Index i sh)) $ + case shxIndex @_ @_ @Int i sh of + SKnown SNat -> SNat shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh) -shsPermutePrefix = coerce (listsPermutePrefix @SNat) +shsPermutePrefix perm (ShS shx) + {- TODO: here and elsewhere, solve the module dependency cycle and add this: + | Refl <- lemTakeLenMapJust perm sh + , Refl <- lemDropLenMapJust perm sh + , Refl <- lemPermuteMapJust perm sh + , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm sh)) (shsDropLen perm sh) -} + = gcastWith (unsafeCoerceRefl + :: Permute is (TakeLen is (MapJust sh)) + ++ DropLen is (MapJust sh) + :~: MapJust (Permute is (TakeLen is sh) ++ DropLen is sh)) $ + ShS (shxPermutePrefix perm shx) type family Product sh where Product '[] = 1 @@ -435,37 +478,10 @@ shsOrthotopeShape :: ShS sh -> Dict O.Shape sh shsOrthotopeShape ZSS = Dict shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict --- | This function is a hack made possible by the 'KnownNat' inside 'ListS'. --- This function may be removed in a future release. -shsFromListS :: ListS sh f -> ShS sh -shsFromListS ZS = ZSS -shsFromListS (_ ::$ l) = SNat :$$ shsFromListS l - --- | This function is a hack made possible by the 'KnownNat' inside 'IxS'. This --- function may be removed in a future release. -shsFromIxS :: IxS sh i -> ShS sh -shsFromIxS (IxS l) = shsFromListS l - -shsEnum :: ShS sh -> [IIxS sh] -shsEnum = shsEnum' - -{-# INLINABLE shsEnum' #-} -- ensure this can be specialised at use site -shsEnum' :: Num i => ShS sh -> [IxS sh i] -shsEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shsSize sh - 1]] - where - suffixes = drop 1 (scanr (*) 1 (shsToList sh)) - - fromLin :: Num i => ShS sh -> [Int] -> Int# -> IxS sh i - fromLin ZSS _ _ = ZIS - fromLin (_ :$$ sh') (I# suff# : suffs) i# = - let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shsSize sh' - in fromIntegral (I# q#) :.$ fromLin sh' suffs r# - fromLin _ _ _ = error "impossible" - -- | Untyped: length is checked at runtime. -instance KnownShS sh => IsList (ListS sh (Const i)) where - type Item (ListS sh (Const i)) = i +instance KnownShS sh => IsList (ListS sh i) where + type Item (ListS sh i) = i fromList = listsFromList (knownShS @sh) toList = listsToList @@ -480,6 +496,3 @@ instance KnownShS sh => IsList (ShS sh) where type Item (ShS sh) = Int fromList = shsFromList (knownShS @sh) toList = shsToList - -$(ixFromLinearStub "ixsFromLinear" [t| ShS |] [t| IxS |] [p| ZSS |] (\a b -> [p| (fromSNat' -> $a) :$$ $b |]) [| ZIS |] [| (:.$) |] [| shsToList |]) -{-# INLINEABLE ixsFromLinear #-} |
