diff options
Diffstat (limited to 'src/Data/Array/Nested/Shaped')
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Base.hs | 8 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 53 |
2 files changed, 32 insertions, 29 deletions
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs index e5dd852..e2ec416 100644 --- a/src/Data/Array/Nested/Shaped/Base.hs +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -170,6 +170,14 @@ instance Elt a => Elt (Shaped sh a) where (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) vecs) + mvecsUnsafeFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) + mvecsUnsafeFreeze sh vecs = + coerce @(Mixed sh' (Mixed (MapJust sh) a)) + @(Mixed sh' (Shaped sh a)) + <$> mvecsUnsafeFreeze sh + (coerce @(MixedVecs s sh' (Shaped sh a)) + @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index 0d90e91..18bd2e9 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -38,7 +38,7 @@ import Data.Kind (Constraint, Type) import Data.Monoid (Sum(..)) import Data.Proxy import Data.Type.Equality -import GHC.Exts (Int(..), Int#, quotRemInt#, withDict, build) +import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict) import GHC.Generics (Generic) import GHC.IsList (IsList) import GHC.IsList qualified as IsList @@ -52,16 +52,14 @@ import Data.Array.Nested.Types -- * Shaped lists --- | Note: The 'KnownNat' constraint on '(::$)' is deprecated and should be --- removed in a future release. type role ListS nominal representational type ListS :: [Nat] -> (Nat -> Type) -> Type data ListS sh f where ZS :: ListS '[] f - -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity - (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f + (::$) :: forall n sh {f}. f n -> ListS sh f -> ListS (n : sh) f deriving instance (forall n. Eq (f n)) => Eq (ListS sh f) deriving instance (forall n. Ord (f n)) => Ord (ListS sh f) + infixr 3 ::$ #ifdef OXAR_DEFAULT_SHOW_INSTANCES @@ -76,7 +74,7 @@ instance (forall m. NFData (f m)) => NFData (ListS n f) where rnf (x ::$ l) = rnf x `seq` rnf l data UnconsListSRes f sh1 = - forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n) + forall n sh. (n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n) listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1) listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x) listsUncons ZS = Nothing @@ -188,11 +186,11 @@ listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f listsPermute PNil _ = ZS listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) = case listsIndex (Proxy @is') (Proxy @sh) i sh of - (item, SNat) -> item ::$ listsPermute is sh + item -> item ::$ listsPermute is sh --- TODO: remove this SNat when the KnownNat constaint in ListS is removed -listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh)) -listsIndex _ _ SZ (n ::$ _) = (n, SNat) +-- TODO: try to remove this SNat now that the KnownNat constraint in ListS is removed +listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> f (Index i sh) +listsIndex _ _ SZ (n ::$ _) = n listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') = listsIndex p pT i sh @@ -216,7 +214,7 @@ pattern ZIS = IxS ZS -- removed in a future release. pattern (:.$) :: forall {sh1} {i}. - forall n sh. (KnownNat n, n : sh ~ sh1) + forall n sh. (n : sh ~ sh1) => i -> IxS sh i -> IxS sh1 i pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i))) where i :.$ IxS shl = IxS (Const i ::$ shl) @@ -280,11 +278,9 @@ ixsInit (IxS list) = IxS (listsInit list) ixsLast :: IxS (n : sh) i -> i ixsLast (IxS list) = getConst (listsLast list) --- TODO: this takes a ShS because there are KnownNats inside IxS. -ixsCast :: ShS sh' -> IxS sh i -> IxS sh' i -ixsCast ZSS ZIS = ZIS -ixsCast (_ :$$ sh) (i :.$ idx) = i :.$ ixsCast sh idx -ixsCast _ _ = error "ixsCast: ranks don't match" +ixsCast :: IxS sh i -> IxS sh i +ixsCast ZIS = ZIS +ixsCast (i :.$ idx) = i :.$ ixsCast idx ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i ixsAppend = coerce (listsAppend @_ @(Const i)) @@ -301,6 +297,16 @@ ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i ixsPermutePrefix = coerce (listsPermutePrefix @(Const i)) +-- | Given a multidimensional index, get the corresponding linear +-- index into the buffer. +{-# INLINEABLE ixsToLinear #-} +ixsToLinear :: Num i => ShS sh -> IxS sh i -> i +ixsToLinear = \sh i -> go sh i 0 + where + go :: Num i => ShS sh -> IxS sh i -> i -> i + go ZSS ZIS a = a + go (n :$$ sh) (i :.$ ix) a = go sh ix (fromIntegral (fromSNat' n) * a + i) + -- * Shaped shapes @@ -321,7 +327,7 @@ pattern ZSS = ShS ZS pattern (:$$) :: forall {sh1}. - forall n sh. (KnownNat n, n : sh ~ sh1) + forall n sh. (n : sh ~ sh1) => SNat n -> ShS sh -> ShS sh1 pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i)) where i :$$ ShS shl = ShS (i ::$ shl) @@ -404,7 +410,7 @@ shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh) shsPermute = coerce (listsPermute @SNat) shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh) -shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh))) +shsIndex pis pshT i sh = coerce (listsIndex @SNat pis pshT i (coerce sh)) shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh) shsPermutePrefix = coerce (listsPermutePrefix @SNat) @@ -435,17 +441,6 @@ shsOrthotopeShape :: ShS sh -> Dict O.Shape sh shsOrthotopeShape ZSS = Dict shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict --- | This function is a hack made possible by the 'KnownNat' inside 'ListS'. --- This function may be removed in a future release. -shsFromListS :: ListS sh f -> ShS sh -shsFromListS ZS = ZSS -shsFromListS (_ ::$ l) = SNat :$$ shsFromListS l - --- | This function is a hack made possible by the 'KnownNat' inside 'IxS'. This --- function may be removed in a future release. -shsFromIxS :: IxS sh i -> ShS sh -shsFromIxS (IxS l) = shsFromListS l - shsEnum :: ShS sh -> [IIxS sh] shsEnum = shsEnum' |
