aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-18 13:24:32 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-18 13:24:32 +0200
commit2e28993ef478ff8c1eed549010383baf51ddec90 (patch)
tree7b9ee20fe2b17b5cbf7d3798f6b80d095257a24c
parent4adbbd8e2e635cc4c647be40f0dd258668dd2452 (diff)
More WIP
-rw-r--r--src/Data/Array/Mixed.hs27
-rw-r--r--src/Data/Array/Nested/Internal.hs1053
2 files changed, 580 insertions, 500 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index df506d6..33d9f56 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -10,6 +10,7 @@
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
@@ -24,6 +25,7 @@ module Data.Array.Mixed where
import qualified Data.Array.RankedS as S
import qualified Data.Array.Ranked as ORB
+import Data.Bifunctor (first)
import Data.Coerce
import Data.Functor.Const
import Data.Kind
@@ -90,6 +92,7 @@ type family Replicate n a where
Replicate n a = a : Replicate (n - 1) a
+type role ListX nominal representational
type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type
data ListX sh f where
ZX :: ListX '[] f
@@ -114,6 +117,7 @@ foldListX _ ZX = mempty
foldListX f (x ::% xs) = f x <> foldListX f xs
+type role IxX nominal representational
type IxX :: [Maybe Nat] -> Type -> Type
newtype IxX sh i = IxX (ListX sh (Const i))
deriving (Show, Eq, Ord)
@@ -154,6 +158,7 @@ fromSMayNat _ g (SKnown s) = g s
fromSMayNat' :: SMayNat Int SNat n -> Int
fromSMayNat' = fromSMayNat id fromSNat'
+type role ShX nominal representational
type ShX :: [Maybe Nat] -> Type -> Type
newtype ShX sh i = ShX (ListX sh (SMayNat i SNat))
deriving (Show, Eq, Ord)
@@ -249,6 +254,10 @@ shTail (_ :$% sh) = sh
ssxTail :: StaticShX (n : sh) -> StaticShX sh
ssxTail (_ :!% ssh) = ssh
+shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh')
+shAppSplit _ ZKX idx = (ZSX, idx)
+shAppSplit p (_ :!% ssh) (i :$% idx) = first (i :$%) (shAppSplit p ssh idx)
+
ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
ssxAppend ZKX sh' = sh'
ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh'
@@ -354,6 +363,24 @@ toVector (XArray arr) = S.toVector arr
scalar :: Storable a => a -> XArray '[] a
scalar = XArray . S.scalar
+eqShX :: IShX sh1 -> IShX sh2 -> Bool
+eqShX ZSX ZSX = True
+eqShX (n :$% sh1) (m :$% sh2) = fromSMayNat' n == fromSMayNat' m && eqShX sh1 sh2
+eqShX _ _ = False
+
+-- | Will throw if the array does not have the casted-to shape.
+cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2
+ => StaticShX sh1 -> IShX sh2 -> StaticShX sh'
+ -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
+cast ssh1 sh2 ssh' (XArray arr)
+ | Refl <- lemRankApp ssh1 ssh'
+ , Refl <- lemRankApp (staticShapeFrom sh2) ssh'
+ = let arrsh :: IShX sh1
+ (arrsh, _) = shAppSplit (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr))
+ in if eqShX arrsh sh2
+ then XArray arr
+ else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")"
+
unScalar :: Storable a => XArray '[] a -> a
unScalar (XArray a) = S.unScalar a
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index e7e2fd6..7d98975 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -1,5 +1,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
+{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
@@ -9,6 +10,7 @@
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
@@ -79,6 +81,7 @@ import qualified Data.Array.RankedS as S
import Data.Bifunctor (first)
import Data.Coerce (coerce, Coercible)
import Data.Foldable (toList)
+import Data.Functor.Const
import Data.Kind
import Data.List.NonEmpty (NonEmpty(..))
import Data.Proxy
@@ -87,6 +90,8 @@ import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as VSM
import Foreign.Storable (Storable)
import GHC.TypeLits
+import qualified GHC.TypeNats as TypeNats
+import Unsafe.Coerce
import Data.Array.Mixed
import qualified Data.Array.Mixed as X
@@ -145,30 +150,32 @@ knownNatSucc :: KnownNat n => Dict KnownNat (n + 1)
knownNatSucc = Dict
--- lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing)
--- lemKnownReplicate _ = X.lemKnownShapeX (go (natSing @n))
--- where
--- go :: SNat m -> StaticShX (Replicate m Nothing)
--- go SZ = ZKSX
--- go (SS (n :: SNat nm1)) | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1 = () :!$? go n
+lemKnownShX :: StaticShX sh -> Dict KnownShX sh
+lemKnownShX ZKX = Dict
+lemKnownShX (SKnown GHC_SNat :!% ssh) | Dict <- lemKnownShX ssh = Dict
+lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict
-lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n
-lemRankReplicate _ = go (natSing @n)
- where
- go :: forall m. SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m
- go SZ = Refl
- go (SS (n :: SNat nm1))
- | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1
- , Refl <- go n
- = Refl
+ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing)
+ssxFromSNat SZ = ZKX
+ssxFromSNat (SS (n :: SNat nm1)) | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n
+
+lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing)
+lemKnownReplicate sn = lemKnownShX (ssxFromSNat sn)
+
+lemRankReplicate :: SNat n -> X.Rank (Replicate n (Nothing @Nat)) :~: n
+lemRankReplicate SZ = Refl
+lemRankReplicate (SS (n :: SNat nm1))
+ | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1
+ , Refl <- lemRankReplicate n
+ = Refl
lemRankMapJust :: forall sh. ShS sh -> X.Rank (MapJust sh) :~: X.Rank sh
lemRankMapJust ZSS = Refl
lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl
-lemReplicatePlusApp :: forall n m a. KnownNat n => Proxy n -> Proxy m -> Proxy a
+lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a
-> Replicate (n + m) a :~: Replicate n a ++ Replicate m a
-lemReplicatePlusApp _ _ _ = go (natSing @n)
+lemReplicatePlusApp sn _ _ = go sn
where
go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a
go SZ = Refl
@@ -177,9 +184,150 @@ lemReplicatePlusApp _ _ _ = go (natSing @n)
, Refl <- go n
= sym (X.lemReplicateSucc @a @(n'm1 + m))
-shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh')
-shAppSplit _ ZKX idx = (ZSX, idx)
-shAppSplit p (_ :!% ssh) (i :$% idx) = first (i :$%) (shAppSplit p ssh idx)
+
+-- === NEW INDEX TYPES === --
+
+type role ListR nominal representational
+type ListR :: Nat -> Type -> Type
+data ListR n i where
+ ZR :: ListR 0 i
+ (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i
+deriving instance Show i => Show (ListR n i)
+deriving instance Eq i => Eq (ListR n i)
+deriving instance Ord i => Ord (ListR n i)
+deriving instance Functor (ListR n)
+deriving instance Foldable (ListR n)
+infixr 3 :::
+
+data UnconsListRRes i n1 =
+ forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i
+unconsListR :: ListR n1 i -> Maybe (UnconsListRRes i n1)
+unconsListR (i ::: sh') = Just (UnconsListRRes sh' i)
+unconsListR ZR = Nothing
+
+
+-- | An index into a rank-typed array.
+type role IxR nominal representational
+type IxR :: Nat -> Type -> Type
+newtype IxR n i = IxR (ListR n i)
+ deriving (Show, Eq, Ord)
+ deriving newtype (Functor, Foldable)
+
+pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
+pattern ZIR = IxR ZR
+
+pattern (:.:)
+ :: forall {n1} {i}.
+ forall n. (n + 1 ~ n1)
+ => i -> IxR n i -> IxR n1 i
+pattern i :.: sh <- IxR (unconsListR -> Just (UnconsListRRes (IxR -> sh) i))
+ where i :.: IxR sh = IxR (i ::: sh)
+infixr 3 :.:
+
+{-# COMPLETE ZIR, (:.:) #-}
+
+type IIxR n = IxR n Int
+
+
+type role ShR nominal representational
+type ShR :: Nat -> Type -> Type
+newtype ShR n i = ShR (ListR n i)
+ deriving (Show, Eq, Ord)
+ deriving newtype (Functor, Foldable)
+
+type IShR n = ShR n Int
+
+pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
+pattern ZSR = ShR ZR
+
+pattern (:$:)
+ :: forall {n1} {i}.
+ forall n. (n + 1 ~ n1)
+ => i -> ShR n i -> ShR n1 i
+pattern i :$: sh <- ShR (unconsListR -> Just (UnconsListRRes (ShR -> sh) i))
+ where i :$: (ShR sh) = ShR (i ::: sh)
+infixr 3 :$:
+
+{-# COMPLETE ZSR, (:$:) #-}
+
+
+type role ListS nominal representational
+type ListS :: [Nat] -> (Nat -> Type) -> Type
+data ListS sh f where
+ ZS :: ListS '[] f
+ (::$) :: forall n sh {f}. f n -> ListS sh f -> ListS (n : sh) f
+deriving instance (forall n. Show (f n)) => Show (ListS sh f)
+deriving instance (forall n. Eq (f n)) => Eq (ListS sh f)
+deriving instance (forall n. Ord (f n)) => Ord (ListS sh f)
+infixr 3 ::$
+
+data UnconsListSRes f sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n)
+unconsListS :: ListS sh1 f -> Maybe (UnconsListSRes f sh1)
+unconsListS (x ::$ sh') = Just (UnconsListSRes sh' x)
+unconsListS ZS = Nothing
+
+fmapListS :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g
+fmapListS _ ZS = ZS
+fmapListS f (x ::$ xs) = f x ::$ fmapListS f xs
+
+foldListS :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m
+foldListS _ ZS = mempty
+foldListS f (x ::$ xs) = f x <> foldListS f xs
+
+
+-- | An index into a shape-typed array.
+--
+-- For convenience, this contains regular 'Int's instead of bounded integers
+-- (traditionally called \"@Fin@\"). Note that because the shape of a
+-- shape-typed array is known statically, you can also retrieve the array shape
+-- from a 'KnownShape' dictionary.
+type role IxS nominal representational
+type IxS :: [Nat] -> Type -> Type
+newtype IxS sh i = IxS (ListS sh (Const i))
+ deriving (Show, Eq, Ord)
+
+pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i
+pattern ZIS = IxS ZS
+
+pattern (:.$)
+ :: forall {sh1} {i}.
+ forall n sh. (n : sh ~ sh1)
+ => i -> IxS sh i -> IxS sh1 i
+pattern i :.$ shl <- IxS (unconsListS -> Just (UnconsListSRes (IxS -> shl) (getConst -> i)))
+ where i :.$ IxS shl = IxS (Const i ::$ shl)
+infixr 3 :.$
+
+{-# COMPLETE ZIS, (:.$) #-}
+
+type IIxS sh = IxS sh Int
+
+instance Functor (IxS sh) where
+ fmap f (IxS l) = IxS (fmapListS (Const . f . getConst) l)
+
+instance Foldable (IxS sh) where
+ foldMap f (IxS l) = foldListS (f . getConst) l
+
+
+-- | The shape of a shape-typed array given as a list of 'SNat' values.
+type role ShS nominal
+type ShS :: [Nat] -> Type
+newtype ShS sh = ShS (ListS sh SNat)
+ deriving (Show, Eq, Ord)
+
+pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh
+pattern ZSS = ShS ZS
+
+pattern (:$$)
+ :: forall {sh1}.
+ forall n sh. (n : sh ~ sh1)
+ => SNat n -> ShS sh -> ShS sh1
+pattern i :$$ shl <- ShS (unconsListS -> Just (UnconsListSRes (ShS -> shl) i))
+ where i :$$ ShS shl = ShS (i ::$ shl)
+
+infixr 3 :$$
+
+{-# COMPLETE ZSS, (:$$) #-}
-- | Wrapper type used as a tag to attach instances on. The instances on arrays
@@ -239,7 +387,7 @@ data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b)
deriving instance (Show (Mixed sh a), Show (Mixed sh b)) => Show (Mixed sh (a, b))
-- etc.
-data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(StaticShX sh1) !(Mixed (sh1 ++ sh2) a)
+data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(IShX sh1) !(Mixed (sh1 ++ sh2) a)
deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a))
@@ -276,7 +424,7 @@ type family ShapeTree a where
ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a)
--- | Allowable scalar types in a mixed array, and by extension in a 'Ranked' or
+-- | Allowable element types in a mixed array, and by extension in a 'Ranked' or
-- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive'
-- a@; see the documentation for 'Primitive' for more details.
class Elt a where
@@ -296,7 +444,7 @@ class Elt a where
--
-- If you want a single-dimensional array from your list, map 'mscalar'
-- first.
- mfromList1 :: forall n sh. SNat n -> NonEmpty (Mixed sh a) -> Mixed (Just n : sh) a
+ mfromList1 :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a
mtoList1 :: Mixed (n : sh) a -> [Mixed sh a]
@@ -316,10 +464,10 @@ class Elt a where
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
-> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a
- -- ====== PRIVATE METHODS ====== --
+ mcast :: forall sh1 sh2 sh'. X.Rank sh1 ~ X.Rank sh2
+ => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a
- -- | Create an empty array. The given shape must have size zero; this may or may not be checked.
- memptyArray :: IShX sh -> Mixed sh a
+ -- ====== PRIVATE METHODS ====== --
mshapeTree :: a -> ShapeTree a
@@ -329,12 +477,6 @@ class Elt a where
mshowShapeTree :: Proxy a -> ShapeTree a -> String
- -- | Create uninitialised vectors for this array type, given the shape of
- -- this vector and an example for the contents.
- mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a)
-
- mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a)
-
-- | Given the shape of this array, an index and a value, write the value at
-- that index in the vectors.
mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
@@ -347,14 +489,32 @@ class Elt a where
mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
+-- | Element types for which we have evidence of the (static part of the) shape
+-- in a type class constraint. Compare the instance contexts of the instances
+-- of this class with those of 'Elt': some instances have an additional
+-- "known-shape" constraint.
+--
+-- This class is (currently) only required for 'mgenerate' / 'rgenerate' /
+-- 'sgenerate'.
+class Elt a => KnownElt a where
+ -- | Create an empty array. The given shape must have size zero; this may or may not be checked.
+ memptyArray :: IShX sh -> Mixed sh a
+
+ -- | Create uninitialised vectors for this array type, given the shape of
+ -- this vector and an example for the contents.
+ mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a)
+
+ mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a)
+
+
-- Arrays of scalars are basically just arrays of scalars.
instance Storable a => Elt (Primitive a) where
mshape (M_Primitive sh _) = sh
mindex (M_Primitive _ a) i = Primitive (X.index a i)
mindexPartial (M_Primitive sh a) i = M_Primitive (X.shDropIx sh i) (X.indexPartial a i)
mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)
- mfromList1 sn l@(arr1 :| _) =
- let sh = SKnown sn :$% mshape arr1
+ mfromList1 l@(arr1 :| _) =
+ let sh = SUnknown (length l) :$% mshape arr1
in M_Primitive sh (X.fromList1 (X.staticShapeFrom sh) (map (\(M_Primitive _ a) -> a) (toList l)))
mtoList1 (M_Primitive sh arr) = map (M_Primitive (X.shTail sh)) (X.toList1 arr)
@@ -379,13 +539,16 @@ instance Storable a => Elt (Primitive a) where
, let result = f ZKX a b
= M_Primitive (X.shape ssh3 result) result
- memptyArray sh = M_Primitive sh (X.empty sh)
+ mcast :: forall sh1 sh2 sh'. X.Rank sh1 ~ X.Rank sh2
+ => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a)
+ mcast ssh1 sh2 _ (M_Primitive sh1' arr) =
+ let (_, sh') = shAppSplit (Proxy @sh') ssh1 sh1'
+ in M_Primitive (shAppend sh2 sh') (X.cast ssh1 sh2 (X.staticShapeFrom sh') arr)
+
mshapeTree _ = ()
mshapeTreeEq _ () () = True
mshapeTreeEmpty _ () = False
mshowShapeTree _ () = "()"
- mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh)
- mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0
mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (X.toLinearIdx sh i) x
-- TODO: this use of toVector is suboptimal
@@ -404,26 +567,36 @@ deriving via Primitive Int instance Elt Int
deriving via Primitive Double instance Elt Double
deriving via Primitive () instance Elt ()
+instance Storable a => KnownElt (Primitive a) where
+ memptyArray sh = M_Primitive sh (X.empty sh)
+ mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh)
+ mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+deriving via Primitive Int instance KnownElt Int
+deriving via Primitive Double instance KnownElt Double
+deriving via Primitive () instance KnownElt ()
+
-- Arrays of pairs are pairs of arrays.
instance (Elt a, Elt b) => Elt (a, b) where
mshape (M_Tup2 a _) = mshape a
mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y)
- mfromList1 n l =
- M_Tup2 (mfromList1 n ((\(M_Tup2 x _) -> x) <$> l))
- (mfromList1 n ((\(M_Tup2 _ y) -> y) <$> l))
+ mfromList1 l =
+ M_Tup2 (mfromList1 ((\(M_Tup2 x _) -> x) <$> l))
+ (mfromList1 ((\(M_Tup2 _ y) -> y) <$> l))
mtoList1 (M_Tup2 a b) = zipWith M_Tup2 (mtoList1 a) (mtoList1 b)
mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b)
mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y)
- memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh)
+ mcast ssh1 sh2 psh' (M_Tup2 a b) =
+ M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b)
+
mshapeTree (x, y) = (mshapeTree x, mshapeTree y)
mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2'
mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2
mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")"
- mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
- mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b)
mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
mvecsWrite sh i x a
mvecsWrite sh i y b
@@ -432,48 +605,48 @@ instance (Elt a, Elt b) => Elt (a, b) where
mvecsWritePartial sh i y b
mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b
--- | Evidence for the static part of a shape. This pops up only when you are
--- polymorphic in the element type of an array.
-type KnownShX :: [Maybe Nat] -> Constraint
-class KnownShX sh where knownShX :: StaticShX sh
-instance KnownShX '[] where knownShX = ZKX
-instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX
-instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX
+instance (KnownElt a, KnownElt b) => KnownElt (a, b) where
+ memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh)
+ mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
+ mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b)
-- Arrays of arrays are just arrays, but with more dimensions.
-instance (Elt a, KnownShX sh') => Elt (Mixed sh' a) where
+instance Elt a => Elt (Mixed sh' a) where
-- TODO: this is quadratic in the nesting depth because it repeatedly
-- truncates the shape vector to one a little shorter. Fix with a
-- moverlongShape method, a prefix of which is mshape.
mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh
- mshape (M_Nest ssh arr)
- = fst (shAppSplit (Proxy @sh') ssh (mshape arr))
+ mshape (M_Nest sh arr)
+ = fst (shAppSplit (Proxy @sh') (X.staticShapeFrom sh) (mshape arr))
mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a
mindex (M_Nest _ arr) i = mindexPartial arr i
mindexPartial :: forall sh1 sh2.
Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
- mindexPartial (M_Nest ssh arr) i
+ mindexPartial (M_Nest sh arr) i
| Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
- = M_Nest (X.ssxDropIx ssh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
+ = M_Nest (X.shDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
- mscalar = M_Nest ZKX
+ mscalar = M_Nest ZSX
- mfromList1 :: forall n sh. SNat n -> NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Just n : sh) (Mixed sh' a)
- mfromList1 sn l@(arr :| _) =
- M_Nest (SKnown sn :!% X.staticShapeFrom (mshape arr))
- (mfromList1 sn ((\(M_Nest _ a) -> a) <$> l))
+ mfromList1 :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a)
+ mfromList1 l@(arr :| _) =
+ M_Nest (SUnknown (length l) :$% mshape arr)
+ (mfromList1 ((\(M_Nest _ a) -> a) <$> l))
- mtoList1 (M_Nest ssh arr) = map (M_Nest (X.ssxTail ssh)) (mtoList1 arr)
+ mtoList1 (M_Nest sh arr) = map (M_Nest (X.shTail sh)) (mtoList1 arr)
mlift :: forall sh1 sh2.
StaticShX sh2
-> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b)
-> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)
- mlift ssh2 f (M_Nest ssh1 arr) = M_Nest ssh2 (mlift (X.ssxAppend ssh2 ssh') f' arr)
+ mlift ssh2 f (M_Nest sh1 arr) =
+ let result = mlift (X.ssxAppend ssh2 ssh') f' arr
+ (sh2, _) = shAppSplit (Proxy @sh') ssh2 (mshape result)
+ in M_Nest sh2 result
where
- ssh' = X.staticShapeFrom (snd (shAppSplit (Proxy @sh') ssh1 (mshape arr)))
+ ssh' = X.staticShapeFrom (snd (shAppSplit (Proxy @sh') (X.staticShapeFrom sh1) (mshape arr)))
f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b
f' sshT
@@ -485,9 +658,12 @@ instance (Elt a, KnownShX sh') => Elt (Mixed sh' a) where
StaticShX sh3
-> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b)
-> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a)
- mlift2 ssh3 f (M_Nest ssh1 arr1) (M_Nest _ arr2) = M_Nest ssh3 (mlift2 (X.ssxAppend ssh3 ssh') f' arr1 arr2)
+ mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) =
+ let result = mlift2 (X.ssxAppend ssh3 ssh') f' arr1 arr2
+ (sh3, _) = shAppSplit (Proxy @sh') ssh3 (mshape result)
+ in M_Nest sh3 result
where
- ssh' = X.staticShapeFrom (snd (shAppSplit (Proxy @sh') ssh1 (mshape arr1)))
+ ssh' = X.staticShapeFrom (snd (shAppSplit (Proxy @sh') (X.staticShapeFrom sh1) (mshape arr1)))
f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b
f' sshT
@@ -496,7 +672,13 @@ instance (Elt a, KnownShX sh') => Elt (Mixed sh' a) where
, Refl <- X.lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT)
= f (X.ssxAppend ssh' sshT)
- memptyArray sh = M_Nest (X.staticShapeFrom sh) (memptyArray (X.shAppend sh (X.completeShXzeros (knownShX @sh'))))
+ mcast :: forall sh1 sh2 shT. X.Rank sh1 ~ X.Rank sh2
+ => StaticShX sh1 -> IShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a)
+ mcast ssh1 sh2 _ (M_Nest sh1T arr)
+ | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh')
+ , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh')
+ = let (_, shT) = shAppSplit (Proxy @shT) ssh1 sh1T
+ in M_Nest (shAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr)
mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a)
mshapeTree arr = (mshape arr, mshapeTree (mindex arr (X.zeroIxX (X.staticShapeFrom (mshape arr)))))
@@ -507,14 +689,6 @@ instance (Elt a, KnownShX sh') => Elt (Mixed sh' a) where
mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
- mvecsUnsafeNew sh example
- | X.shapeSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a))
- | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.shAppend sh sh') (mindex example (X.zeroIxX (X.staticShapeFrom sh')))
- where
- sh' = mshape example
-
- mvecsNewEmpty _ = MV_Nest (X.completeShXzeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a)
-
mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.shAppend sh sh') idx val vecs
mvecsWritePartial :: forall sh1 sh2 s.
@@ -525,7 +699,26 @@ instance (Elt a, KnownShX sh') => Elt (Mixed sh' a) where
| Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
= mvecsWritePartial (X.shAppend sh12 sh') idx arr vecs
- mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest (X.staticShapeFrom sh) <$> mvecsFreeze (X.shAppend sh sh') vecs
+ mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (X.shAppend sh sh') vecs
+
+-- | Evidence for the static part of a shape. This pops up only when you are
+-- polymorphic in the element type of an array.
+type KnownShX :: [Maybe Nat] -> Constraint
+class KnownShX sh where knownShX :: StaticShX sh
+instance KnownShX '[] where knownShX = ZKX
+instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX
+instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX
+
+instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
+ memptyArray sh = M_Nest sh (memptyArray (X.shAppend sh (X.completeShXzeros (knownShX @sh'))))
+
+ mvecsUnsafeNew sh example
+ | X.shapeSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a))
+ | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.shAppend sh sh') (mindex example (X.zeroIxX (X.staticShapeFrom sh')))
+ where
+ sh' = mshape example
+
+ mvecsNewEmpty _ = MV_Nest (X.completeShXzeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a)
-- | Create an array given a size and a function that computes the element at a
@@ -545,7 +738,7 @@ instance (Elt a, KnownShX sh') => Elt (Mixed sh' a) where
-- the entire hierarchy (after distributing out tuples) must be a rectangular
-- array. The type of 'mgenerate' allows this requirement to be broken very
-- easily, hence the runtime check.
-mgenerate :: forall sh a. Elt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a
+mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a
mgenerate sh f = case X.enumShape sh of
[] -> memptyArray sh
firstidx : restidxs ->
@@ -601,8 +794,8 @@ mtoVectorP (M_Primitive _ v) = X.toVector v
mtoVector :: (Storable a, PrimElt a) => Mixed sh a -> VS.Vector a
mtoVector arr = mtoVectorP (coerce toPrimitive arr)
-mfromList :: Elt a => SNat n -> NonEmpty a -> Mixed '[Just n] a
-mfromList sn = mfromList1 sn . fmap mscalar
+mfromList :: Elt a => NonEmpty a -> Mixed '[Nothing] a
+mfromList = mfromList1 . fmap mscalar
mtoList :: Elt a => Mixed '[n] a -> [a]
mtoList = map munScalar . mtoList1
@@ -620,7 +813,7 @@ mconstant sh x = fromPrimitive (mconstantP sh x)
mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a
mslice i n arr =
let _ :$% sh = mshape arr
- in withKnownNat n $ mlift (SKnown n :!% X.staticShapeFrom sh) (\_ -> X.slice i n) arr
+ in mlift (SKnown n :!% X.staticShapeFrom sh) (\_ -> X.slice i n) arr
msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a
msliceU i n arr = mlift (X.staticShapeFrom (mshape arr)) (\_ -> X.sliceU i n) arr
@@ -640,10 +833,10 @@ masXArrayPrimP (M_Primitive sh arr) = (sh, arr)
masXArrayPrim :: PrimElt a => Mixed sh a -> (IShX sh, XArray sh a)
masXArrayPrim = masXArrayPrimP . toPrimitive
-mfromXArrayPrimP :: IShX sh -> XArray sh a -> Mixed sh (Primitive a)
-mfromXArrayPrimP = M_Primitive
+mfromXArrayPrimP :: StaticShX sh -> XArray sh a -> Mixed sh (Primitive a)
+mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr
-mfromXArrayPrim :: PrimElt a => IShX sh -> XArray sh a -> Mixed sh a
+mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a
mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP
mliftPrim :: (Storable a, PrimElt a)
@@ -703,59 +896,46 @@ newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixe
newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh' ) a))
-{-
-- 'Ranked' and 'Shaped' can already be used at the top level of an array nest;
-- these instances allow them to also be used as elements of arrays, thus
-- making them first-class in the API.
-instance (Elt a, KnownNat n) => Elt (Ranked n a) where
- mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr
- mindex (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i)
+instance Elt a => Elt (Ranked n a) where
+ mshape (M_Ranked arr) = mshape arr
+ mindex (M_Ranked arr) i = Ranked (mindex arr i)
mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a)
- mindexPartial (M_Ranked arr) i
- | Dict <- lemKnownReplicate (Proxy @n)
- = coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $
+ mindexPartial (M_Ranked arr) i =
+ coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $
mindexPartial arr i
- mscalar (Ranked x) = M_Ranked (M_Nest x)
+ mscalar (Ranked x) = M_Ranked (M_Nest ZSX x)
- mfromList1 :: forall m sh. KnownShapeX (m : sh)
- => NonEmpty (Mixed sh (Ranked n a)) -> Mixed (m : sh) (Ranked n a)
- mfromList1 l
- | Dict <- lemKnownReplicate (Proxy @n)
- = M_Ranked (mfromList1 (coerce l))
+ mfromList1 :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a)
+ mfromList1 l = M_Ranked (mfromList1 (coerce l))
mtoList1 :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)]
- mtoList1 (M_Ranked arr)
- | Dict <- lemKnownReplicate (Proxy @n)
- = coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoList1 arr)
+ mtoList1 (M_Ranked arr) =
+ coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoList1 arr)
- mlift :: forall sh1 sh2. KnownShapeX sh2
- => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
-> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a)
- mlift f (M_Ranked arr)
- | Dict <- lemKnownReplicate (Proxy @n)
- = coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $
- mlift f arr
+ mlift ssh2 f (M_Ranked arr) =
+ coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $
+ mlift ssh2 f arr
- mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3)
- => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
-> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a)
- mlift2 f (M_Ranked arr1) (M_Ranked arr2)
- | Dict <- lemKnownReplicate (Proxy @n)
- = coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $
- mlift2 f arr1 arr2
+ mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) =
+ coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $
+ mlift2 ssh3 f arr1 arr2
- memptyArray :: forall sh. IShX sh -> Mixed sh (Ranked n a)
- memptyArray i
- | Dict <- lemKnownReplicate (Proxy @n)
- = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $
- memptyArray i
+ mcast ssh1 sh2 psh' (M_Ranked arr) = M_Ranked (mcast ssh1 sh2 psh' arr)
- mshapeTree (Ranked arr)
- | Refl <- lemRankReplicate (Proxy @n)
- , Dict <- lemKnownReplicate (Proxy @n)
- = first shCvtXR (mshapeTree arr)
+ mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr)
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
@@ -763,122 +943,95 @@ instance (Elt a, KnownNat n) => Elt (Ranked n a) where
mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
- mvecsUnsafeNew idx (Ranked arr)
- | Dict <- lemKnownReplicate (Proxy @n)
- = MV_Ranked <$> mvecsUnsafeNew idx arr
-
- mvecsNewEmpty _
- | Dict <- lemKnownReplicate (Proxy @n)
- = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a))
-
mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
- mvecsWrite sh idx (Ranked arr) vecs
- | Dict <- lemKnownReplicate (Proxy @n)
- = mvecsWrite sh idx arr
- (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
- vecs)
-
- mvecsWritePartial :: forall sh sh' s. KnownShapeX sh'
- => IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a)
+ mvecsWrite sh idx (Ranked arr) vecs =
+ mvecsWrite sh idx arr
+ (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
+ vecs)
+
+ mvecsWritePartial :: forall sh sh' s.
+ IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a)
-> MixedVecs s (sh ++ sh') (Ranked n a)
-> ST s ()
- mvecsWritePartial sh idx arr vecs
- | Dict <- lemKnownReplicate (Proxy @n)
- = mvecsWritePartial sh idx
- (coerce @(Mixed sh' (Ranked n a))
- @(Mixed sh' (Mixed (Replicate n Nothing) a))
- arr)
- (coerce @(MixedVecs s (sh ++ sh') (Ranked n a))
- @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a))
- vecs)
+ mvecsWritePartial sh idx arr vecs =
+ mvecsWritePartial sh idx
+ (coerce @(Mixed sh' (Ranked n a))
+ @(Mixed sh' (Mixed (Replicate n Nothing) a))
+ arr)
+ (coerce @(MixedVecs s (sh ++ sh') (Ranked n a))
+ @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a))
+ vecs)
mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
- mvecsFreeze sh vecs
- | Dict <- lemKnownReplicate (Proxy @n)
- = coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
- @(Mixed sh (Ranked n a))
- <$> mvecsFreeze sh
- (coerce @(MixedVecs s sh (Ranked n a))
- @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
- vecs)
--}
+ mvecsFreeze sh vecs =
+ coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
+ @(Mixed sh (Ranked n a))
+ <$> mvecsFreeze sh
+ (coerce @(MixedVecs s sh (Ranked n a))
+ @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
+ vecs)
+
+instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where
+ memptyArray :: forall sh. IShX sh -> Mixed sh (Ranked n a)
+ memptyArray i
+ | Dict <- lemKnownReplicate (SNat @n)
+ = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $
+ memptyArray i
--- | The shape of a shape-typed array given as a list of 'SNat' values.
-TODO -- write ListS and implement IxS and ShS in terms of it.
-TODO -- for ListR and ListS, write an uncons function like for ListX and implement the cons pattern synonyms in terms of it directly, instead of using a separate uncons function for both types.
-data ShS sh where
- ZSS :: ShS '[]
- (:$$) :: forall n sh. SNat n -> ShS sh -> ShS (n : sh)
-deriving instance Show (ShS sh)
-deriving instance Eq (ShS sh)
-deriving instance Ord (ShS sh)
-infixr 3 :$$
+ mvecsUnsafeNew idx (Ranked arr)
+ | Dict <- lemKnownReplicate (SNat @n)
+ = MV_Ranked <$> mvecsUnsafeNew idx arr
-{-
-sshapeKnown :: ShS sh -> Dict KnownShape sh
-sshapeKnown ZSS = Dict
-sshapeKnown (GHC_SNat :$$ sh) | Dict <- sshapeKnown sh = Dict
+ mvecsNewEmpty _
+ | Dict <- lemKnownReplicate (SNat @n)
+ = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a))
-lemKnownMapJust :: forall sh. KnownShape sh => Proxy sh -> Dict KnownShapeX (MapJust sh)
-lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh))
- where
- go :: ShS sh' -> StaticShX (MapJust sh')
- go ZSS = ZKSX
- go (n :$$ sh) = n :!$@ go sh
+-- sshapeKnown :: ShS sh -> Dict KnownShape sh
+-- sshapeKnown ZSS = Dict
+-- sshapeKnown (GHC_SNat :$$ sh) | Dict <- sshapeKnown sh = Dict
lemCommMapJustApp :: forall sh1 sh2. ShS sh1 -> Proxy sh2
-> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2
lemCommMapJustApp ZSS _ = Refl
lemCommMapJustApp (_ :$$ sh) p | Refl <- lemCommMapJustApp sh p = Refl
-instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where
- mshape (M_Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = mshape arr
- mindex (M_Shaped arr) i | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (mindex arr i)
+instance Elt a => Elt (Shaped sh a) where
+ mshape (M_Shaped arr) = mshape arr
+ mindex (M_Shaped arr) i = Shaped (mindex arr i)
mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
- mindexPartial (M_Shaped arr) i
- | Dict <- lemKnownMapJust (Proxy @sh)
- = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
- mindexPartial arr i
+ mindexPartial (M_Shaped arr) i =
+ coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
+ mindexPartial arr i
- mscalar (Shaped x) = M_Shaped (M_Nest x)
+ mscalar (Shaped x) = M_Shaped (M_Nest ZSX x)
- mfromList1 :: forall n sh'. KnownShapeX (n : sh')
- => NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (n : sh') (Shaped sh a)
- mfromList1 l
- | Dict <- lemKnownMapJust (Proxy @sh)
- = M_Shaped (mfromList1 (coerce l))
+ mfromList1 :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a)
+ mfromList1 l = M_Shaped (mfromList1 (coerce l))
mtoList1 :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)]
mtoList1 (M_Shaped arr)
- | Dict <- lemKnownMapJust (Proxy @sh)
= coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoList1 arr)
- mlift :: forall sh1 sh2. KnownShapeX sh2
- => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
-> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a)
- mlift f (M_Shaped arr)
- | Dict <- lemKnownMapJust (Proxy @sh)
- = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
- mlift f arr
+ mlift ssh2 f (M_Shaped arr) =
+ coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
+ mlift ssh2 f arr
- mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3)
- => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
-> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a)
- mlift2 f (M_Shaped arr1) (M_Shaped arr2)
- | Dict <- lemKnownMapJust (Proxy @sh)
- = coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $
- mlift2 f arr1 arr2
+ mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) =
+ coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $
+ mlift2 ssh3 f arr1 arr2
- memptyArray :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)
- memptyArray i
- | Dict <- lemKnownMapJust (Proxy @sh)
- = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
- memptyArray i
+ mcast ssh1 sh2 psh' (M_Shaped arr) = M_Shaped (mcast ssh1 sh2 psh' arr)
- mshapeTree (Shaped arr)
- | Dict <- lemKnownMapJust (Proxy @sh)
- = first (shCvtXS (knownShape @sh)) (mshapeTree arr)
+ mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr)
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
@@ -886,162 +1039,84 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where
mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
- mvecsUnsafeNew idx (Shaped arr)
- | Dict <- lemKnownMapJust (Proxy @sh)
- = MV_Shaped <$> mvecsUnsafeNew idx arr
-
- mvecsNewEmpty _
- | Dict <- lemKnownMapJust (Proxy @sh)
- = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a))
-
mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
- mvecsWrite sh idx (Shaped arr) vecs
- | Dict <- lemKnownMapJust (Proxy @sh)
- = mvecsWrite sh idx arr
- (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
- vecs)
+ mvecsWrite sh idx (Shaped arr) vecs =
+ mvecsWrite sh idx arr
+ (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
+ vecs)
- mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2
- => IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
+ mvecsWritePartial :: forall sh1 sh2 s.
+ IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
-> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
-> ST s ()
- mvecsWritePartial sh idx arr vecs
- | Dict <- lemKnownMapJust (Proxy @sh)
- = mvecsWritePartial sh idx
- (coerce @(Mixed sh2 (Shaped sh a))
- @(Mixed sh2 (Mixed (MapJust sh) a))
- arr)
- (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a))
- @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))
- vecs)
+ mvecsWritePartial sh idx arr vecs =
+ mvecsWritePartial sh idx
+ (coerce @(Mixed sh2 (Shaped sh a))
+ @(Mixed sh2 (Mixed (MapJust sh) a))
+ arr)
+ (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a))
+ @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))
+ vecs)
mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
- mvecsFreeze sh vecs
- | Dict <- lemKnownMapJust (Proxy @sh)
- = coerce @(Mixed sh' (Mixed (MapJust sh) a))
- @(Mixed sh' (Shaped sh a))
- <$> mvecsFreeze sh
- (coerce @(MixedVecs s sh' (Shaped sh a))
- @(MixedVecs s sh' (Mixed (MapJust sh) a))
- vecs)
+ mvecsFreeze sh vecs =
+ coerce @(Mixed sh' (Mixed (MapJust sh) a))
+ @(Mixed sh' (Shaped sh a))
+ <$> mvecsFreeze sh
+ (coerce @(MixedVecs s sh' (Shaped sh a))
+ @(MixedVecs s sh' (Mixed (MapJust sh) a))
+ vecs)
+-- | Evidence for the static part of a shape. This pops up only when you are
+-- polymorphic in the element type of an array.
+type KnownShS :: [Nat] -> Constraint
+class KnownShS sh where knownShS :: ShS sh
+instance KnownShS '[] where knownShS = ZSS
+instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS
--- Utility functions to satisfy the type checker sometimes
+lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh)
+lemKnownMapJust _ = lemKnownShX (go (knownShS @sh))
+ where
+ go :: ShS sh' -> StaticShX (MapJust sh')
+ go ZSS = ZKX
+ go (n :$$ sh) = SKnown n :!% go sh
-rewriteMixed :: sh1 :~: sh2 -> Mixed sh1 a -> Mixed sh2 a
-rewriteMixed Refl x = x
+instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where
+ memptyArray :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)
+ memptyArray i
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
+ memptyArray i
+
+ mvecsUnsafeNew idx (Shaped arr)
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = MV_Shaped <$> mvecsUnsafeNew idx arr
+
+ mvecsNewEmpty _
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a))
-- ====== API OF RANKED ARRAYS ====== --
-arithPromoteRanked :: forall n a. KnownNat n
- => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a)
+arithPromoteRanked :: forall n a. PrimElt a
+ => (forall sh. Mixed sh a -> Mixed sh a)
-> Ranked n a -> Ranked n a
-arithPromoteRanked | Dict <- lemKnownReplicate (Proxy @n) = coerce
+arithPromoteRanked = coerce
-arithPromoteRanked2 :: forall n a. KnownNat n
- => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a -> Mixed sh a)
+arithPromoteRanked2 :: forall n a. PrimElt a
+ => (forall sh. Mixed sh a -> Mixed sh a -> Mixed sh a)
-> Ranked n a -> Ranked n a -> Ranked n a
-arithPromoteRanked2 | Dict <- lemKnownReplicate (Proxy @n) = coerce
+arithPromoteRanked2 = coerce
-instance (KnownNat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where
+instance (Storable a, Num a, PrimElt a) => Num (Ranked n a) where
(+) = arithPromoteRanked2 (+)
(-) = arithPromoteRanked2 (-)
(*) = arithPromoteRanked2 (*)
negate = arithPromoteRanked negate
abs = arithPromoteRanked abs
signum = arithPromoteRanked signum
- fromInteger n = case natSing @n of
- SZ -> Ranked (M_Primitive (X.scalar (fromInteger n)))
- _ -> error "Data.Array.Nested.fromIntegral(Ranked): \
- \Rank non-zero, use explicit mconstant"
-
--- [PRIMITIVE ELEMENT TYPES LIST] (really, a partial list of just the numeric types)
-deriving via Ranked n (Primitive Int) instance KnownNat n => Num (Ranked n Int)
-deriving via Ranked n (Primitive Double) instance KnownNat n => Num (Ranked n Double)
--}
-
-type role ListR nominal representational
-type ListR :: Nat -> Type -> Type
-data ListR n i where
- ZR :: ListR 0 i
- (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i
-deriving instance Show i => Show (ListR n i)
-deriving instance Eq i => Eq (ListR n i)
-deriving instance Ord i => Ord (ListR n i)
-deriving instance Functor (ListR n)
-infixr 3 :::
-
-instance Foldable (ListR n) where
- foldr f z l = foldr f z (listRToList l)
-
-listRToList :: ListR n i -> [i]
-listRToList ZR = []
-listRToList (i ::: is) = i : listRToList is
-
-knownListR :: ListR n i -> Dict KnownNat n
-knownListR ZR = Dict
-knownListR (_ ::: (l :: ListR m i)) | Dict <- knownListR l = knownNatSucc @m
-
--- | An index into a rank-typed array.
-type role IxR nominal representational
-type IxR :: Nat -> Type -> Type
-newtype IxR n i = IxR (ListR n i)
- deriving (Show, Eq, Ord)
- deriving newtype (Functor, Foldable)
-
-pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
-pattern ZIR = IxR ZR
-
-pattern (:.:)
- :: forall {n1} {i}.
- forall n. (n + 1 ~ n1)
- => i -> IxR n i -> IxR n1 i
-pattern i :.: sh <- (unconsIxR -> Just (UnconsIxRRes sh i))
- where i :.: IxR sh = IxR (i ::: sh)
-{-# COMPLETE ZIR, (:.:) #-}
-infixr 3 :.:
-
-data UnconsIxRRes i n1 =
- forall n. (n + 1 ~ n1) => UnconsIxRRes (IxR n i) i
-unconsIxR :: IxR n1 i -> Maybe (UnconsIxRRes i n1)
-unconsIxR (IxR (i ::: sh')) = Just (UnconsIxRRes (IxR sh') i)
-unconsIxR (IxR ZR) = Nothing
-
-type IIxR n = IxR n Int
-
-knownIxR :: IxR n i -> Dict KnownNat n
-knownIxR (IxR sh) = knownListR sh
-
-type role ShR nominal representational
-type ShR :: Nat -> Type -> Type
-newtype ShR n i = ShR (ListR n i)
- deriving (Show, Eq, Ord)
- deriving newtype (Functor, Foldable)
-
-type IShR n = ShR n Int
-
-pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
-pattern ZSR = ShR ZR
-
-pattern (:$:)
- :: forall {n1} {i}.
- forall n. (n + 1 ~ n1)
- => i -> ShR n i -> ShR n1 i
-pattern i :$: sh <- (unconsShR -> Just (UnconsShRRes sh i))
- where i :$: (ShR sh) = ShR (i ::: sh)
-{-# COMPLETE ZSR, (:$:) #-}
-infixr 3 :$:
-
-data UnconsShRRes i n1 =
- forall n. n + 1 ~ n1 => UnconsShRRes (ShR n i) i
-unconsShR :: ShR n1 i -> Maybe (UnconsShRRes i n1)
-unconsShR (ShR (i ::: sh')) = Just (UnconsShRRes (ShR sh') i)
-unconsShR (ShR ZR) = Nothing
-
-{-
-knownShR :: ShR n i -> Dict KnownNat n
-knownShR (ShR sh) = knownListR sh
+ fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mconstant"
zeroIxR :: SNat n -> IIxR n
zeroIxR SZ = ZIR
@@ -1049,100 +1124,122 @@ zeroIxR (SS n) = 0 :.: zeroIxR n
ixCvtXR :: IIxX sh -> IIxR (X.Rank sh)
ixCvtXR ZIX = ZIR
-ixCvtXR (n :.@ idx) = n :.: ixCvtXR idx
-ixCvtXR (n :.? idx) = n :.: ixCvtXR idx
+ixCvtXR (n :.% idx) = n :.: ixCvtXR idx
+
+shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n
+shCvtXR' ZSX =
+ castWith (subst2 (unsafeCoerce Refl :: 0 :~: n))
+ ZSR
+shCvtXR' (n :$% (idx :: IShX sh))
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) =
+ castWith (subst2 (lem1 @sh Refl))
+ (X.fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx))
+ where
+ lem1 :: forall sh' n' k.
+ k : sh' :~: Replicate n' Nothing
+ -> Rank sh' + 1 :~: n'
+ lem1 Refl = unsafeCoerce Refl
-shCvtXR :: IShX sh -> IShR (X.Rank sh)
-shCvtXR ZSX = ZSR
-shCvtXR (n :$@ idx) = X.fromSNat' n :$: shCvtXR idx
-shCvtXR (n :$? idx) = n :$: shCvtXR idx
+ lem2 :: k : sh :~: Replicate n Nothing
+ -> sh :~: Replicate (Rank sh) Nothing
+ lem2 Refl = unsafeCoerce Refl
ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
ixCvtRX ZIR = ZIX
-ixCvtRX (n :.: (idx :: IxR m Int)) = castWith (subst2 @IxX @Int (X.lemReplicateSucc @(Nothing @Nat) @m)) (n :.? ixCvtRX idx)
+ixCvtRX (n :.: (idx :: IxR m Int)) =
+ castWith (subst2 @IxX @Int (X.lemReplicateSucc @(Nothing @Nat) @m))
+ (n :.% ixCvtRX idx)
shCvtRX :: IShR n -> IShX (Replicate n Nothing)
shCvtRX ZSR = ZSX
-shCvtRX (n :$: (idx :: ShR m Int)) = castWith (subst2 @ShX @Int (X.lemReplicateSucc @(Nothing @Nat) @m)) (n :$? shCvtRX idx)
+shCvtRX (n :$: (idx :: ShR m Int)) =
+ castWith (subst2 @ShX @Int (X.lemReplicateSucc @(Nothing @Nat) @m))
+ (SUnknown n :$% shCvtRX idx)
shapeSizeR :: IShR n -> Int
shapeSizeR ZSR = 1
shapeSizeR (n :$: sh) = n * shapeSizeR sh
-rshape :: forall n a. (KnownNat n, Elt a) => Ranked n a -> IShR n
-rshape (Ranked arr)
- | Dict <- lemKnownReplicate (Proxy @n)
- , Refl <- lemRankReplicate (Proxy @n)
- = shCvtXR (mshape arr)
+rshape :: forall n a. Elt a => Ranked n a -> IShR n
+rshape (Ranked arr) = shCvtXR' (mshape arr)
rindex :: Elt a => Ranked n a -> IIxR n -> a
rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)
-rindexPartial :: forall n m a. (KnownNat n, Elt a) => Ranked (n + m) a -> IIxR n -> Ranked m a
+snatFromListR :: ListR n i -> SNat n
+snatFromListR ZR = SNat
+snatFromListR (_ ::: (l :: ListR n i)) | SNat <- snatFromListR l, Dict <- knownNatSucc @n = SNat
+
+snatFromIxR :: IxR n i -> SNat n
+snatFromIxR (IxR sh) = snatFromListR sh
+
+snatFromShR :: ShR n i -> SNat n
+snatFromShR (ShR sh) = snatFromListR sh
+
+rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a
rindexPartial (Ranked arr) idx =
Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
- (rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr)
+ (castWith (subst2 (lemReplicatePlusApp (snatFromIxR idx) (Proxy @m) (Proxy @Nothing))) arr)
(ixCvtRX idx))
-- | __WARNING__: All values returned from the function must have equal shape.
-- See the documentation of 'mgenerate' for more details.
-rgenerate :: forall n a. Elt a => IShR n -> (IIxR n -> a) -> Ranked n a
+rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a
rgenerate sh f
- | Dict <- knownShR sh
- , Dict <- lemKnownReplicate (Proxy @n)
- , Refl <- lemRankReplicate (Proxy @n)
+ | sn@SNat <- snatFromShR sh
+ , Dict <- lemKnownReplicate sn
+ , Refl <- lemRankReplicate sn
= Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR))
-- | See the documentation of 'mlift'.
-rlift :: forall n1 n2 a. (KnownNat n2, Elt a)
- => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
+rlift :: forall n1 n2 a. Elt a
+ => SNat n2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
-> Ranked n1 a -> Ranked n2 a
-rlift f (Ranked arr)
- | Dict <- lemKnownReplicate (Proxy @n2)
- = Ranked (mlift f arr)
+rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr)
rsumOuter1P :: forall n a.
- (Storable a, Num a, KnownNat n)
+ (Storable a, Num a)
=> Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
-rsumOuter1P (Ranked arr)
- | Dict <- lemKnownReplicate (Proxy @n)
- , Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
- = Ranked
- . coerce @(XArray (Replicate n 'Nothing) a) @(Mixed (Replicate n 'Nothing) (Primitive a))
- . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate n Nothing))
- . coerce @(Mixed (Replicate (n + 1) Nothing) (Primitive a)) @(XArray (Replicate (n + 1) Nothing) a)
- $ arr
+rsumOuter1P (Ranked (M_Primitive sh arr))
+ | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
+ , _ :$% shT <- sh
+ = Ranked (M_Primitive shT (X.sumOuter (SUnknown () :!% ZKX) (X.staticShapeFrom shT) arr))
-rsumOuter1 :: forall n a. (Storable a, Num a, PrimElt a, KnownNat n)
+rsumOuter1 :: forall n a. (Storable a, Num a, PrimElt a)
=> Ranked (n + 1) a -> Ranked n a
rsumOuter1 = coerce fromPrimitive . rsumOuter1P @n @a . coerce toPrimitive
-rtranspose :: forall n a. (KnownNat n, Elt a) => [Int] -> Ranked n a -> Ranked n a
-rtranspose perm
- | Dict <- lemKnownReplicate (Proxy @n)
+rtranspose :: forall n a. Elt a => [Int] -> Ranked n a -> Ranked n a
+rtranspose perm arr
+ | sn@SNat <- snatFromShR (rshape arr)
+ , Dict <- lemKnownReplicate sn
, length perm <= fromIntegral (natVal (Proxy @n))
- = rlift $ \(Proxy @sh') ->
- X.transposeUntyped (natSing @n) (knownShapeX @sh') perm
+ = rlift sn
+ (\ssh' -> X.transposeUntyped (natSing @n) ssh' perm)
+ arr
| otherwise
= error "Data.Array.Nested.rtranspose: Permutation longer than rank of array"
-rappend :: forall n a. (KnownNat n, Elt a)
+rappend :: forall n a. Elt a
=> Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a
-rappend
- | Dict <- lemKnownReplicate (Proxy @n)
+rappend arr1 arr2
+ | sn@SNat <- snatFromShR (rshape arr1)
+ , Dict <- lemKnownReplicate sn
, Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
= coerce (mappend @Nothing @Nothing @(Replicate n Nothing))
+ arr1 arr2
rscalar :: Elt a => a -> Ranked 0 a
rscalar x = Ranked (mscalar x)
-rfromVectorP :: forall n a. (KnownNat n, Storable a) => IShR n -> VS.Vector a -> Ranked n (Primitive a)
+rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a)
rfromVectorP sh v
- | Dict <- lemKnownReplicate (Proxy @n)
+ | Dict <- lemKnownReplicate (snatFromShR sh)
= Ranked (mfromVectorP (shCvtRX sh) v)
-rfromVector :: forall n a. (KnownNat n, Storable a, PrimElt a) => IShR n -> VS.Vector a -> Ranked n a
+rfromVector :: forall n a. (Storable a, PrimElt a) => IShR n -> VS.Vector a -> Ranked n a
rfromVector sh v = coerce fromPrimitive (rfromVectorP sh v)
rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a
@@ -1151,14 +1248,13 @@ rtoVectorP = coerce mtoVectorP
rtoVector :: (Storable a, PrimElt a) => Ranked n a -> VS.Vector a
rtoVector = coerce mtoVector
-rfromList1 :: forall n a. (KnownNat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (n + 1) a
+rfromList1 :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a
rfromList1 l
- | Dict <- lemKnownReplicate (Proxy @n)
- , Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
- = Ranked (mfromList1 @a @Nothing @(Replicate n Nothing) (coerce l))
+ | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
+ = Ranked (mfromList1 (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
rfromList :: Elt a => NonEmpty a -> Ranked 1 a
-rfromList = Ranked . mfromList1 . fmap mscalar
+rfromList l = Ranked (mfromList l)
rtoList :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
rtoList (Ranked arr)
@@ -1171,173 +1267,130 @@ rtoList1 = map runScalar . rtoList
runScalar :: Elt a => Ranked 0 a -> a
runScalar arr = rindex arr ZIR
-rconstantP :: forall n a. (KnownNat n, Storable a) => IShR n -> a -> Ranked n (Primitive a)
+rconstantP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)
rconstantP sh x
- | Dict <- lemKnownReplicate (Proxy @n)
+ | Dict <- lemKnownReplicate (snatFromShR sh)
= Ranked (mconstantP (shCvtRX sh) x)
-rconstant :: forall n a. (KnownNat n, Storable a, PrimElt a)
+rconstant :: forall n a. (Storable a, PrimElt a)
=> IShR n -> a -> Ranked n a
rconstant sh x = coerce fromPrimitive (rconstantP sh x)
-rslice :: forall n a. (KnownNat n, Elt a) => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a
-rslice i n
+rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a
+rslice i n arr
| Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
- = rlift $ \_ -> X.sliceU i n
-
-rrev1 :: forall n a. (KnownNat n, Elt a) => Ranked (n + 1) a -> Ranked (n + 1) a
-rrev1 = rlift $ \(Proxy @sh') ->
- case X.lemReplicateSucc @(Nothing @Nat) @n of
- Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')
+ = rlift (snatFromShR (rshape arr))
+ (\_ -> X.sliceU i n)
+ arr
+
+rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a
+rrev1 arr =
+ rlift (snatFromShR (rshape arr))
+ (\(_ :: StaticShX sh') ->
+ case X.lemReplicateSucc @(Nothing @Nat) @n of
+ Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh'))
+ arr
-rreshape :: forall n n' a. (KnownNat n, KnownNat n', Elt a)
+rreshape :: forall n n' a. Elt a
=> IShR n' -> Ranked n a -> Ranked n' a
-rreshape sh' (Ranked arr)
- | Dict <- lemKnownReplicate (Proxy @n)
- , Dict <- lemKnownReplicate (Proxy @n')
+rreshape sh' rarr@(Ranked arr)
+ | Dict <- lemKnownReplicate (snatFromShR (rshape rarr))
+ , Dict <- lemKnownReplicate (snatFromShR sh')
= Ranked (mreshape (shCvtRX sh') arr)
-rasXArrayPrimP :: Ranked n (Primitive a) -> XArray (Replicate n Nothing) a
-rasXArrayPrimP (Ranked arr) = masXArrayPrimP arr
+rasXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a)
+rasXArrayPrimP (Ranked arr) = first shCvtXR' (masXArrayPrimP arr)
-rasXArrayPrim :: PrimElt a => Ranked n a -> XArray (Replicate n Nothing) a
-rasXArrayPrim (Ranked arr) = masXArrayPrim arr
+rasXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a)
+rasXArrayPrim (Ranked arr) = first shCvtXR' (masXArrayPrim arr)
-rfromXArrayPrimP :: XArray (Replicate n Nothing) a -> Ranked n (Primitive a)
-rfromXArrayPrimP = Ranked . mfromXArrayPrimP
+rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a)
+rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (X.staticShapeFrom (X.shape (ssxFromSNat sn) arr)) arr)
-rfromXArrayPrim :: PrimElt a => XArray (Replicate n Nothing) a -> Ranked n a
-rfromXArrayPrim = Ranked . mfromXArrayPrim
+rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a
+rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (X.staticShapeFrom (X.shape (ssxFromSNat sn) arr)) arr)
-- ====== API OF SHAPED ARRAYS ====== --
-arithPromoteShaped :: forall sh a. KnownShape sh
- => (forall shx. KnownShapeX shx => Mixed shx a -> Mixed shx a)
+arithPromoteShaped :: forall sh a. PrimElt a
+ => (forall shx. Mixed shx a -> Mixed shx a)
-> Shaped sh a -> Shaped sh a
-arithPromoteShaped | Dict <- lemKnownMapJust (Proxy @sh) = coerce
+arithPromoteShaped = coerce
-arithPromoteShaped2 :: forall sh a. KnownShape sh
- => (forall shx. KnownShapeX shx => Mixed shx a -> Mixed shx a -> Mixed shx a)
+arithPromoteShaped2 :: forall sh a. PrimElt a
+ => (forall shx. Mixed shx a -> Mixed shx a -> Mixed shx a)
-> Shaped sh a -> Shaped sh a -> Shaped sh a
-arithPromoteShaped2 | Dict <- lemKnownMapJust (Proxy @sh) = coerce
+arithPromoteShaped2 = coerce
-instance (KnownShape sh, Storable a, Num a) => Num (Shaped sh (Primitive a)) where
+instance (Storable a, Num a, PrimElt a) => Num (Shaped sh a) where
(+) = arithPromoteShaped2 (+)
(-) = arithPromoteShaped2 (-)
(*) = arithPromoteShaped2 (*)
negate = arithPromoteShaped negate
abs = arithPromoteShaped abs
signum = arithPromoteShaped signum
- fromInteger n = sconstantP (fromInteger n)
-
--- [PRIMITIVE ELEMENT TYPES LIST] (really, a partial list of just the numeric types)
-deriving via Shaped sh (Primitive Int) instance KnownShape sh => Num (Shaped sh Int)
-deriving via Shaped sh (Primitive Double) instance KnownShape sh => Num (Shaped sh Double)
--}
-
-type role ListS nominal representational
-type ListS :: [Nat] -> Type -> Type
-data ListS sh i where
- ZS :: ListS '[] i
- (::$) :: forall n sh {i}. i -> ListS sh i -> ListS (n : sh) i
-deriving instance Show i => Show (ListS sh i)
-deriving instance Eq i => Eq (ListS sh i)
-deriving instance Ord i => Ord (ListS sh i)
-deriving instance Functor (ListS sh)
-infixr 3 ::$
-
-instance Foldable (ListS sh) where
- foldr f z l = foldr f z (listSToList l)
-
-listSToList :: ListS sh i -> [i]
-listSToList ZS = []
-listSToList (i ::$ is) = i : listSToList is
-
--- | An index into a shape-typed array.
---
--- For convenience, this contains regular 'Int's instead of bounded integers
--- (traditionally called \"@Fin@\"). Note that because the shape of a
--- shape-typed array is known statically, you can also retrieve the array shape
--- from a 'KnownShape' dictionary.
-type role IxS nominal representational
-type IxS :: [Nat] -> Type -> Type
-newtype IxS sh i = IxS (ListS sh i)
- deriving (Show, Eq, Ord)
- deriving newtype (Functor, Foldable)
-
-pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i
-pattern ZIS = IxS ZS
-
-pattern (:.$)
- :: forall {sh1} {i}.
- forall n sh. (n : sh ~ sh1)
- => i -> IxS sh i -> IxS sh1 i
-pattern i :.$ shl <- (unconsIxS -> Just (UnconsIxSRes shl i))
- where i :.$ IxS shl = IxS (i ::$ shl)
-{-# COMPLETE ZIS, (:.$) #-}
-infixr 3 :.$
-
-data UnconsIxSRes i sh1 =
- forall n sh. (n : sh ~ sh1) => UnconsIxSRes (IxS sh i) i
-unconsIxS :: IxS sh1 i -> Maybe (UnconsIxSRes i sh1)
-unconsIxS (IxS (i ::$ shl')) = Just (UnconsIxSRes (IxS shl') i)
-unconsIxS (IxS ZS) = Nothing
-
-type IIxS sh = IxS sh Int
-
-data UnconsShSRes sh1 =
- forall n sh. (n : sh ~ sh1) => UnconsShSRes (ShS sh) (SNat n)
-unconsShS :: ShS sh1 -> Maybe (UnconsShSRes sh1)
-unconsShS (i :$$ shl') = Just (UnconsShSRes shl' i)
-unconsShS ZSS = Nothing
+ fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mconstant"
-{-
zeroIxS :: ShS sh -> IIxS sh
zeroIxS ZSS = ZIS
zeroIxS (_ :$$ sh) = 0 :.$ zeroIxS sh
ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh
ixCvtXS ZSS ZIX = ZIS
-ixCvtXS (_ :$$ sh) (n :.@ idx) = n :.$ ixCvtXS sh idx
+ixCvtXS (_ :$$ sh) (n :.% idx) = n :.$ ixCvtXS sh idx
-shCvtXS :: ShS sh -> IShX (MapJust sh) -> ShS sh
-shCvtXS ZSS ZSX = ZSS
-shCvtXS (_ :$$ sh) (n :$@ idx) = n :$$ shCvtXS sh idx
+type family Tail l where
+ Tail (_ : xs) = xs
+
+shCvtXS' :: forall sh. IShX (MapJust sh) -> ShS sh
+shCvtXS' ZSX = castWith (subst1 (unsafeCoerce Refl :: '[] :~: sh)) ZSS
+shCvtXS' (SKnown n :$% (idx :: IShX mjshT)) =
+ castWith (subst1 (lem Refl)) $
+ n :$$ shCvtXS' @(Tail sh) (castWith (subst2 (unsafeCoerce Refl :: mjshT :~: MapJust (Tail sh)))
+ idx)
+ where
+ lem :: forall sh1 sh' n.
+ Just n : sh1 :~: MapJust sh'
+ -> n : Tail sh' :~: sh'
+ lem Refl = unsafeCoerce Refl
+shCvtXS' (SUnknown _ :$% _) = error "impossible"
ixCvtSX :: IIxS sh -> IIxX (MapJust sh)
ixCvtSX ZIS = ZIX
-ixCvtSX (n :.$ sh) = n :.@ ixCvtSX sh
+ixCvtSX (n :.$ sh) = n :.% ixCvtSX sh
shCvtSX :: ShS sh -> IShX (MapJust sh)
shCvtSX ZSS = ZSX
-shCvtSX (n :$$ sh) = n :$@ shCvtSX sh
+shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh
shapeSizeS :: ShS sh -> Int
shapeSizeS ZSS = 1
shapeSizeS (n :$$ sh) = X.fromSNat' n * shapeSizeS sh
--- | This does not touch the passed array, all information comes from 'KnownShape'.
-sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> ShS sh
-sshape _ = knownShape @sh
+sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
+sshape (Shaped arr) = shCvtXS' (mshape arr)
sindex :: Elt a => Shaped sh a -> IIxS sh -> a
sindex (Shaped arr) idx = mindex arr (ixCvtSX idx)
-sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, Elt a) => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a
-sindexPartial (Shaped arr) idx =
+shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh
+shsTakeIx _ _ ZIS = ZSS
+shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx
+
+sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a
+sindexPartial sarr@(Shaped arr) idx =
Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2)
- (rewriteMixed (lemCommMapJustApp (knownShape @sh1) (Proxy @sh2)) arr)
+ (castWith (subst2 (lemCommMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr)
(ixCvtSX idx))
-- | __WARNING__: All values returned from the function must have equal shape.
-- See the documentation of 'mgenerate' for more details.
-sgenerate :: forall sh a. (KnownShape sh, Elt a) => (IIxS sh -> a) -> Shaped sh a
-sgenerate f
- | Dict <- lemKnownMapJust (Proxy @sh)
- = Shaped (mgenerate (shCvtSX (knownShape @sh)) (f . ixCvtXS (knownShape @sh)))
+sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a
+sgenerate sh f = Shaped (mgenerate (shCvtSX sh) (f . ixCvtXS sh))
+{-
-- | See the documentation of 'mlift'.
slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a)
=> (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
@@ -1463,7 +1516,7 @@ sconstant :: forall sh a. (KnownShape sh, Storable a, PrimElt a)
sconstant x = coerce fromPrimitive (sconstantP @sh x)
sslice :: (KnownShape sh, Elt a) => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a
-sslice i n = withKnownNat n $ slift $ \_ -> X.slice i n
+sslice i n@SNat = slift $ \_ -> X.slice i n
srev1 :: (KnownNat n, KnownShape sh, Elt a) => Shaped (n : sh) a -> Shaped (n : sh) a
srev1 = slift $ \_ -> X.rev1