aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Shaped
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Shaped')
-rw-r--r--src/Data/Array/Nested/Shaped/Base.hs8
-rw-r--r--src/Data/Array/Nested/Shaped/Shape.hs53
2 files changed, 32 insertions, 29 deletions
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs
index e5dd852..e2ec416 100644
--- a/src/Data/Array/Nested/Shaped/Base.hs
+++ b/src/Data/Array/Nested/Shaped/Base.hs
@@ -170,6 +170,14 @@ instance Elt a => Elt (Shaped sh a) where
(coerce @(MixedVecs s sh' (Shaped sh a))
@(MixedVecs s sh' (Mixed (MapJust sh) a))
vecs)
+ mvecsUnsafeFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
+ mvecsUnsafeFreeze sh vecs =
+ coerce @(Mixed sh' (Mixed (MapJust sh) a))
+ @(Mixed sh' (Shaped sh a))
+ <$> mvecsUnsafeFreeze sh
+ (coerce @(MixedVecs s sh' (Shaped sh a))
+ @(MixedVecs s sh' (Mixed (MapJust sh) a))
+ vecs)
instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where
memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs
index 0d90e91..18bd2e9 100644
--- a/src/Data/Array/Nested/Shaped/Shape.hs
+++ b/src/Data/Array/Nested/Shaped/Shape.hs
@@ -38,7 +38,7 @@ import Data.Kind (Constraint, Type)
import Data.Monoid (Sum(..))
import Data.Proxy
import Data.Type.Equality
-import GHC.Exts (Int(..), Int#, quotRemInt#, withDict, build)
+import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict)
import GHC.Generics (Generic)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
@@ -52,16 +52,14 @@ import Data.Array.Nested.Types
-- * Shaped lists
--- | Note: The 'KnownNat' constraint on '(::$)' is deprecated and should be
--- removed in a future release.
type role ListS nominal representational
type ListS :: [Nat] -> (Nat -> Type) -> Type
data ListS sh f where
ZS :: ListS '[] f
- -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity
- (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f
+ (::$) :: forall n sh {f}. f n -> ListS sh f -> ListS (n : sh) f
deriving instance (forall n. Eq (f n)) => Eq (ListS sh f)
deriving instance (forall n. Ord (f n)) => Ord (ListS sh f)
+
infixr 3 ::$
#ifdef OXAR_DEFAULT_SHOW_INSTANCES
@@ -76,7 +74,7 @@ instance (forall m. NFData (f m)) => NFData (ListS n f) where
rnf (x ::$ l) = rnf x `seq` rnf l
data UnconsListSRes f sh1 =
- forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n)
+ forall n sh. (n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n)
listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1)
listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x)
listsUncons ZS = Nothing
@@ -188,11 +186,11 @@ listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f
listsPermute PNil _ = ZS
listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) =
case listsIndex (Proxy @is') (Proxy @sh) i sh of
- (item, SNat) -> item ::$ listsPermute is sh
+ item -> item ::$ listsPermute is sh
--- TODO: remove this SNat when the KnownNat constaint in ListS is removed
-listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh))
-listsIndex _ _ SZ (n ::$ _) = (n, SNat)
+-- TODO: try to remove this SNat now that the KnownNat constraint in ListS is removed
+listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> f (Index i sh)
+listsIndex _ _ SZ (n ::$ _) = n
listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f))
| Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
= listsIndex p pT i sh
@@ -216,7 +214,7 @@ pattern ZIS = IxS ZS
-- removed in a future release.
pattern (:.$)
:: forall {sh1} {i}.
- forall n sh. (KnownNat n, n : sh ~ sh1)
+ forall n sh. (n : sh ~ sh1)
=> i -> IxS sh i -> IxS sh1 i
pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i)))
where i :.$ IxS shl = IxS (Const i ::$ shl)
@@ -280,11 +278,9 @@ ixsInit (IxS list) = IxS (listsInit list)
ixsLast :: IxS (n : sh) i -> i
ixsLast (IxS list) = getConst (listsLast list)
--- TODO: this takes a ShS because there are KnownNats inside IxS.
-ixsCast :: ShS sh' -> IxS sh i -> IxS sh' i
-ixsCast ZSS ZIS = ZIS
-ixsCast (_ :$$ sh) (i :.$ idx) = i :.$ ixsCast sh idx
-ixsCast _ _ = error "ixsCast: ranks don't match"
+ixsCast :: IxS sh i -> IxS sh i
+ixsCast ZIS = ZIS
+ixsCast (i :.$ idx) = i :.$ ixsCast idx
ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i
ixsAppend = coerce (listsAppend @_ @(Const i))
@@ -301,6 +297,16 @@ ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js
ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i
ixsPermutePrefix = coerce (listsPermutePrefix @(Const i))
+-- | Given a multidimensional index, get the corresponding linear
+-- index into the buffer.
+{-# INLINEABLE ixsToLinear #-}
+ixsToLinear :: Num i => ShS sh -> IxS sh i -> i
+ixsToLinear = \sh i -> go sh i 0
+ where
+ go :: Num i => ShS sh -> IxS sh i -> i -> i
+ go ZSS ZIS a = a
+ go (n :$$ sh) (i :.$ ix) a = go sh ix (fromIntegral (fromSNat' n) * a + i)
+
-- * Shaped shapes
@@ -321,7 +327,7 @@ pattern ZSS = ShS ZS
pattern (:$$)
:: forall {sh1}.
- forall n sh. (KnownNat n, n : sh ~ sh1)
+ forall n sh. (n : sh ~ sh1)
=> SNat n -> ShS sh -> ShS sh1
pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i))
where i :$$ ShS shl = ShS (i ::$ shl)
@@ -404,7 +410,7 @@ shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh)
shsPermute = coerce (listsPermute @SNat)
shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh)
-shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh)))
+shsIndex pis pshT i sh = coerce (listsIndex @SNat pis pshT i (coerce sh))
shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh)
shsPermutePrefix = coerce (listsPermutePrefix @SNat)
@@ -435,17 +441,6 @@ shsOrthotopeShape :: ShS sh -> Dict O.Shape sh
shsOrthotopeShape ZSS = Dict
shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict
--- | This function is a hack made possible by the 'KnownNat' inside 'ListS'.
--- This function may be removed in a future release.
-shsFromListS :: ListS sh f -> ShS sh
-shsFromListS ZS = ZSS
-shsFromListS (_ ::$ l) = SNat :$$ shsFromListS l
-
--- | This function is a hack made possible by the 'KnownNat' inside 'IxS'. This
--- function may be removed in a future release.
-shsFromIxS :: IxS sh i -> ShS sh
-shsFromIxS (IxS l) = shsFromListS l
-
shsEnum :: ShS sh -> [IIxS sh]
shsEnum = shsEnum'