From 2e28993ef478ff8c1eed549010383baf51ddec90 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sat, 18 May 2024 13:24:32 +0200 Subject: More WIP --- src/Data/Array/Nested/Internal.hs | 1053 +++++++++++++++++++------------------ 1 file changed, 553 insertions(+), 500 deletions(-) (limited to 'src/Data/Array/Nested') diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index e7e2fd6..7d98975 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DefaultSignatures #-} +{-# LANGUAGE DeriveFoldable #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DerivingVia #-} {-# LANGUAGE FlexibleContexts #-} @@ -9,6 +10,7 @@ {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE RoleAnnotations #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -79,6 +81,7 @@ import qualified Data.Array.RankedS as S import Data.Bifunctor (first) import Data.Coerce (coerce, Coercible) import Data.Foldable (toList) +import Data.Functor.Const import Data.Kind import Data.List.NonEmpty (NonEmpty(..)) import Data.Proxy @@ -87,6 +90,8 @@ import qualified Data.Vector.Storable as VS import qualified Data.Vector.Storable.Mutable as VSM import Foreign.Storable (Storable) import GHC.TypeLits +import qualified GHC.TypeNats as TypeNats +import Unsafe.Coerce import Data.Array.Mixed import qualified Data.Array.Mixed as X @@ -145,30 +150,32 @@ knownNatSucc :: KnownNat n => Dict KnownNat (n + 1) knownNatSucc = Dict --- lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing) --- lemKnownReplicate _ = X.lemKnownShapeX (go (natSing @n)) --- where --- go :: SNat m -> StaticShX (Replicate m Nothing) --- go SZ = ZKSX --- go (SS (n :: SNat nm1)) | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1 = () :!$? go n +lemKnownShX :: StaticShX sh -> Dict KnownShX sh +lemKnownShX ZKX = Dict +lemKnownShX (SKnown GHC_SNat :!% ssh) | Dict <- lemKnownShX ssh = Dict +lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict -lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n -lemRankReplicate _ = go (natSing @n) - where - go :: forall m. SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m - go SZ = Refl - go (SS (n :: SNat nm1)) - | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1 - , Refl <- go n - = Refl +ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing) +ssxFromSNat SZ = ZKX +ssxFromSNat (SS (n :: SNat nm1)) | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n + +lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing) +lemKnownReplicate sn = lemKnownShX (ssxFromSNat sn) + +lemRankReplicate :: SNat n -> X.Rank (Replicate n (Nothing @Nat)) :~: n +lemRankReplicate SZ = Refl +lemRankReplicate (SS (n :: SNat nm1)) + | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1 + , Refl <- lemRankReplicate n + = Refl lemRankMapJust :: forall sh. ShS sh -> X.Rank (MapJust sh) :~: X.Rank sh lemRankMapJust ZSS = Refl lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl -lemReplicatePlusApp :: forall n m a. KnownNat n => Proxy n -> Proxy m -> Proxy a +lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a -lemReplicatePlusApp _ _ _ = go (natSing @n) +lemReplicatePlusApp sn _ _ = go sn where go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a go SZ = Refl @@ -177,9 +184,150 @@ lemReplicatePlusApp _ _ _ = go (natSing @n) , Refl <- go n = sym (X.lemReplicateSucc @a @(n'm1 + m)) -shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh') -shAppSplit _ ZKX idx = (ZSX, idx) -shAppSplit p (_ :!% ssh) (i :$% idx) = first (i :$%) (shAppSplit p ssh idx) + +-- === NEW INDEX TYPES === -- + +type role ListR nominal representational +type ListR :: Nat -> Type -> Type +data ListR n i where + ZR :: ListR 0 i + (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i +deriving instance Show i => Show (ListR n i) +deriving instance Eq i => Eq (ListR n i) +deriving instance Ord i => Ord (ListR n i) +deriving instance Functor (ListR n) +deriving instance Foldable (ListR n) +infixr 3 ::: + +data UnconsListRRes i n1 = + forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i +unconsListR :: ListR n1 i -> Maybe (UnconsListRRes i n1) +unconsListR (i ::: sh') = Just (UnconsListRRes sh' i) +unconsListR ZR = Nothing + + +-- | An index into a rank-typed array. +type role IxR nominal representational +type IxR :: Nat -> Type -> Type +newtype IxR n i = IxR (ListR n i) + deriving (Show, Eq, Ord) + deriving newtype (Functor, Foldable) + +pattern ZIR :: forall n i. () => n ~ 0 => IxR n i +pattern ZIR = IxR ZR + +pattern (:.:) + :: forall {n1} {i}. + forall n. (n + 1 ~ n1) + => i -> IxR n i -> IxR n1 i +pattern i :.: sh <- IxR (unconsListR -> Just (UnconsListRRes (IxR -> sh) i)) + where i :.: IxR sh = IxR (i ::: sh) +infixr 3 :.: + +{-# COMPLETE ZIR, (:.:) #-} + +type IIxR n = IxR n Int + + +type role ShR nominal representational +type ShR :: Nat -> Type -> Type +newtype ShR n i = ShR (ListR n i) + deriving (Show, Eq, Ord) + deriving newtype (Functor, Foldable) + +type IShR n = ShR n Int + +pattern ZSR :: forall n i. () => n ~ 0 => ShR n i +pattern ZSR = ShR ZR + +pattern (:$:) + :: forall {n1} {i}. + forall n. (n + 1 ~ n1) + => i -> ShR n i -> ShR n1 i +pattern i :$: sh <- ShR (unconsListR -> Just (UnconsListRRes (ShR -> sh) i)) + where i :$: (ShR sh) = ShR (i ::: sh) +infixr 3 :$: + +{-# COMPLETE ZSR, (:$:) #-} + + +type role ListS nominal representational +type ListS :: [Nat] -> (Nat -> Type) -> Type +data ListS sh f where + ZS :: ListS '[] f + (::$) :: forall n sh {f}. f n -> ListS sh f -> ListS (n : sh) f +deriving instance (forall n. Show (f n)) => Show (ListS sh f) +deriving instance (forall n. Eq (f n)) => Eq (ListS sh f) +deriving instance (forall n. Ord (f n)) => Ord (ListS sh f) +infixr 3 ::$ + +data UnconsListSRes f sh1 = + forall n sh. (n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n) +unconsListS :: ListS sh1 f -> Maybe (UnconsListSRes f sh1) +unconsListS (x ::$ sh') = Just (UnconsListSRes sh' x) +unconsListS ZS = Nothing + +fmapListS :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g +fmapListS _ ZS = ZS +fmapListS f (x ::$ xs) = f x ::$ fmapListS f xs + +foldListS :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m +foldListS _ ZS = mempty +foldListS f (x ::$ xs) = f x <> foldListS f xs + + +-- | An index into a shape-typed array. +-- +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). Note that because the shape of a +-- shape-typed array is known statically, you can also retrieve the array shape +-- from a 'KnownShape' dictionary. +type role IxS nominal representational +type IxS :: [Nat] -> Type -> Type +newtype IxS sh i = IxS (ListS sh (Const i)) + deriving (Show, Eq, Ord) + +pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i +pattern ZIS = IxS ZS + +pattern (:.$) + :: forall {sh1} {i}. + forall n sh. (n : sh ~ sh1) + => i -> IxS sh i -> IxS sh1 i +pattern i :.$ shl <- IxS (unconsListS -> Just (UnconsListSRes (IxS -> shl) (getConst -> i))) + where i :.$ IxS shl = IxS (Const i ::$ shl) +infixr 3 :.$ + +{-# COMPLETE ZIS, (:.$) #-} + +type IIxS sh = IxS sh Int + +instance Functor (IxS sh) where + fmap f (IxS l) = IxS (fmapListS (Const . f . getConst) l) + +instance Foldable (IxS sh) where + foldMap f (IxS l) = foldListS (f . getConst) l + + +-- | The shape of a shape-typed array given as a list of 'SNat' values. +type role ShS nominal +type ShS :: [Nat] -> Type +newtype ShS sh = ShS (ListS sh SNat) + deriving (Show, Eq, Ord) + +pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh +pattern ZSS = ShS ZS + +pattern (:$$) + :: forall {sh1}. + forall n sh. (n : sh ~ sh1) + => SNat n -> ShS sh -> ShS sh1 +pattern i :$$ shl <- ShS (unconsListS -> Just (UnconsListSRes (ShS -> shl) i)) + where i :$$ ShS shl = ShS (i ::$ shl) + +infixr 3 :$$ + +{-# COMPLETE ZSS, (:$$) #-} -- | Wrapper type used as a tag to attach instances on. The instances on arrays @@ -239,7 +387,7 @@ data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) deriving instance (Show (Mixed sh a), Show (Mixed sh b)) => Show (Mixed sh (a, b)) -- etc. -data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(StaticShX sh1) !(Mixed (sh1 ++ sh2) a) +data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(IShX sh1) !(Mixed (sh1 ++ sh2) a) deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a)) @@ -276,7 +424,7 @@ type family ShapeTree a where ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) --- | Allowable scalar types in a mixed array, and by extension in a 'Ranked' or +-- | Allowable element types in a mixed array, and by extension in a 'Ranked' or -- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive' -- a@; see the documentation for 'Primitive' for more details. class Elt a where @@ -296,7 +444,7 @@ class Elt a where -- -- If you want a single-dimensional array from your list, map 'mscalar' -- first. - mfromList1 :: forall n sh. SNat n -> NonEmpty (Mixed sh a) -> Mixed (Just n : sh) a + mfromList1 :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a mtoList1 :: Mixed (n : sh) a -> [Mixed sh a] @@ -316,10 +464,10 @@ class Elt a where -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a - -- ====== PRIVATE METHODS ====== -- + mcast :: forall sh1 sh2 sh'. X.Rank sh1 ~ X.Rank sh2 + => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a - -- | Create an empty array. The given shape must have size zero; this may or may not be checked. - memptyArray :: IShX sh -> Mixed sh a + -- ====== PRIVATE METHODS ====== -- mshapeTree :: a -> ShapeTree a @@ -329,12 +477,6 @@ class Elt a where mshowShapeTree :: Proxy a -> ShapeTree a -> String - -- | Create uninitialised vectors for this array type, given the shape of - -- this vector and an example for the contents. - mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a) - - mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a) - -- | Given the shape of this array, an index and a value, write the value at -- that index in the vectors. mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () @@ -347,14 +489,32 @@ class Elt a where mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) +-- | Element types for which we have evidence of the (static part of the) shape +-- in a type class constraint. Compare the instance contexts of the instances +-- of this class with those of 'Elt': some instances have an additional +-- "known-shape" constraint. +-- +-- This class is (currently) only required for 'mgenerate' / 'rgenerate' / +-- 'sgenerate'. +class Elt a => KnownElt a where + -- | Create an empty array. The given shape must have size zero; this may or may not be checked. + memptyArray :: IShX sh -> Mixed sh a + + -- | Create uninitialised vectors for this array type, given the shape of + -- this vector and an example for the contents. + mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a) + + mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a) + + -- Arrays of scalars are basically just arrays of scalars. instance Storable a => Elt (Primitive a) where mshape (M_Primitive sh _) = sh mindex (M_Primitive _ a) i = Primitive (X.index a i) mindexPartial (M_Primitive sh a) i = M_Primitive (X.shDropIx sh i) (X.indexPartial a i) mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) - mfromList1 sn l@(arr1 :| _) = - let sh = SKnown sn :$% mshape arr1 + mfromList1 l@(arr1 :| _) = + let sh = SUnknown (length l) :$% mshape arr1 in M_Primitive sh (X.fromList1 (X.staticShapeFrom sh) (map (\(M_Primitive _ a) -> a) (toList l))) mtoList1 (M_Primitive sh arr) = map (M_Primitive (X.shTail sh)) (X.toList1 arr) @@ -379,13 +539,16 @@ instance Storable a => Elt (Primitive a) where , let result = f ZKX a b = M_Primitive (X.shape ssh3 result) result - memptyArray sh = M_Primitive sh (X.empty sh) + mcast :: forall sh1 sh2 sh'. X.Rank sh1 ~ X.Rank sh2 + => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) + mcast ssh1 sh2 _ (M_Primitive sh1' arr) = + let (_, sh') = shAppSplit (Proxy @sh') ssh1 sh1' + in M_Primitive (shAppend sh2 sh') (X.cast ssh1 sh2 (X.staticShapeFrom sh') arr) + mshapeTree _ = () mshapeTreeEq _ () () = True mshapeTreeEmpty _ () = False mshowShapeTree _ () = "()" - mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh) - mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (X.toLinearIdx sh i) x -- TODO: this use of toVector is suboptimal @@ -404,26 +567,36 @@ deriving via Primitive Int instance Elt Int deriving via Primitive Double instance Elt Double deriving via Primitive () instance Elt () +instance Storable a => KnownElt (Primitive a) where + memptyArray sh = M_Primitive sh (X.empty sh) + mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh) + mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 + +-- [PRIMITIVE ELEMENT TYPES LIST] +deriving via Primitive Int instance KnownElt Int +deriving via Primitive Double instance KnownElt Double +deriving via Primitive () instance KnownElt () + -- Arrays of pairs are pairs of arrays. instance (Elt a, Elt b) => Elt (a, b) where mshape (M_Tup2 a _) = mshape a mindex (M_Tup2 a b) i = (mindex a i, mindex b i) mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) - mfromList1 n l = - M_Tup2 (mfromList1 n ((\(M_Tup2 x _) -> x) <$> l)) - (mfromList1 n ((\(M_Tup2 _ y) -> y) <$> l)) + mfromList1 l = + M_Tup2 (mfromList1 ((\(M_Tup2 x _) -> x) <$> l)) + (mfromList1 ((\(M_Tup2 _ y) -> y) <$> l)) mtoList1 (M_Tup2 a b) = zipWith M_Tup2 (mtoList1 a) (mtoList1 b) mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b) mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y) - memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh) + mcast ssh1 sh2 psh' (M_Tup2 a b) = + M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b) + mshapeTree (x, y) = (mshapeTree x, mshapeTree y) mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" - mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y - mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b) mvecsWrite sh i (x, y) (MV_Tup2 a b) = do mvecsWrite sh i x a mvecsWrite sh i y b @@ -432,48 +605,48 @@ instance (Elt a, Elt b) => Elt (a, b) where mvecsWritePartial sh i y b mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b --- | Evidence for the static part of a shape. This pops up only when you are --- polymorphic in the element type of an array. -type KnownShX :: [Maybe Nat] -> Constraint -class KnownShX sh where knownShX :: StaticShX sh -instance KnownShX '[] where knownShX = ZKX -instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX -instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX +instance (KnownElt a, KnownElt b) => KnownElt (a, b) where + memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh) + mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y + mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b) -- Arrays of arrays are just arrays, but with more dimensions. -instance (Elt a, KnownShX sh') => Elt (Mixed sh' a) where +instance Elt a => Elt (Mixed sh' a) where -- TODO: this is quadratic in the nesting depth because it repeatedly -- truncates the shape vector to one a little shorter. Fix with a -- moverlongShape method, a prefix of which is mshape. mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh - mshape (M_Nest ssh arr) - = fst (shAppSplit (Proxy @sh') ssh (mshape arr)) + mshape (M_Nest sh arr) + = fst (shAppSplit (Proxy @sh') (X.staticShapeFrom sh) (mshape arr)) mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a mindex (M_Nest _ arr) i = mindexPartial arr i mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) - mindexPartial (M_Nest ssh arr) i + mindexPartial (M_Nest sh arr) i | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = M_Nest (X.ssxDropIx ssh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) + = M_Nest (X.shDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) - mscalar = M_Nest ZKX + mscalar = M_Nest ZSX - mfromList1 :: forall n sh. SNat n -> NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Just n : sh) (Mixed sh' a) - mfromList1 sn l@(arr :| _) = - M_Nest (SKnown sn :!% X.staticShapeFrom (mshape arr)) - (mfromList1 sn ((\(M_Nest _ a) -> a) <$> l)) + mfromList1 :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) + mfromList1 l@(arr :| _) = + M_Nest (SUnknown (length l) :$% mshape arr) + (mfromList1 ((\(M_Nest _ a) -> a) <$> l)) - mtoList1 (M_Nest ssh arr) = map (M_Nest (X.ssxTail ssh)) (mtoList1 arr) + mtoList1 (M_Nest sh arr) = map (M_Nest (X.shTail sh)) (mtoList1 arr) mlift :: forall sh1 sh2. StaticShX sh2 -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) - mlift ssh2 f (M_Nest ssh1 arr) = M_Nest ssh2 (mlift (X.ssxAppend ssh2 ssh') f' arr) + mlift ssh2 f (M_Nest sh1 arr) = + let result = mlift (X.ssxAppend ssh2 ssh') f' arr + (sh2, _) = shAppSplit (Proxy @sh') ssh2 (mshape result) + in M_Nest sh2 result where - ssh' = X.staticShapeFrom (snd (shAppSplit (Proxy @sh') ssh1 (mshape arr))) + ssh' = X.staticShapeFrom (snd (shAppSplit (Proxy @sh') (X.staticShapeFrom sh1) (mshape arr))) f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b f' sshT @@ -485,9 +658,12 @@ instance (Elt a, KnownShX sh') => Elt (Mixed sh' a) where StaticShX sh3 -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a) - mlift2 ssh3 f (M_Nest ssh1 arr1) (M_Nest _ arr2) = M_Nest ssh3 (mlift2 (X.ssxAppend ssh3 ssh') f' arr1 arr2) + mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) = + let result = mlift2 (X.ssxAppend ssh3 ssh') f' arr1 arr2 + (sh3, _) = shAppSplit (Proxy @sh') ssh3 (mshape result) + in M_Nest sh3 result where - ssh' = X.staticShapeFrom (snd (shAppSplit (Proxy @sh') ssh1 (mshape arr1))) + ssh' = X.staticShapeFrom (snd (shAppSplit (Proxy @sh') (X.staticShapeFrom sh1) (mshape arr1))) f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b f' sshT @@ -496,7 +672,13 @@ instance (Elt a, KnownShX sh') => Elt (Mixed sh' a) where , Refl <- X.lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) = f (X.ssxAppend ssh' sshT) - memptyArray sh = M_Nest (X.staticShapeFrom sh) (memptyArray (X.shAppend sh (X.completeShXzeros (knownShX @sh')))) + mcast :: forall sh1 sh2 shT. X.Rank sh1 ~ X.Rank sh2 + => StaticShX sh1 -> IShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a) + mcast ssh1 sh2 _ (M_Nest sh1T arr) + | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') + , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') + = let (_, shT) = shAppSplit (Proxy @shT) ssh1 sh1T + in M_Nest (shAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr) mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) mshapeTree arr = (mshape arr, mshapeTree (mindex arr (X.zeroIxX (X.staticShapeFrom (mshape arr))))) @@ -507,14 +689,6 @@ instance (Elt a, KnownShX sh') => Elt (Mixed sh' a) where mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - mvecsUnsafeNew sh example - | X.shapeSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) - | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.shAppend sh sh') (mindex example (X.zeroIxX (X.staticShapeFrom sh'))) - where - sh' = mshape example - - mvecsNewEmpty _ = MV_Nest (X.completeShXzeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) - mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.shAppend sh sh') idx val vecs mvecsWritePartial :: forall sh1 sh2 s. @@ -525,7 +699,26 @@ instance (Elt a, KnownShX sh') => Elt (Mixed sh' a) where | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') = mvecsWritePartial (X.shAppend sh12 sh') idx arr vecs - mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest (X.staticShapeFrom sh) <$> mvecsFreeze (X.shAppend sh sh') vecs + mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (X.shAppend sh sh') vecs + +-- | Evidence for the static part of a shape. This pops up only when you are +-- polymorphic in the element type of an array. +type KnownShX :: [Maybe Nat] -> Constraint +class KnownShX sh where knownShX :: StaticShX sh +instance KnownShX '[] where knownShX = ZKX +instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX +instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX + +instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where + memptyArray sh = M_Nest sh (memptyArray (X.shAppend sh (X.completeShXzeros (knownShX @sh')))) + + mvecsUnsafeNew sh example + | X.shapeSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) + | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.shAppend sh sh') (mindex example (X.zeroIxX (X.staticShapeFrom sh'))) + where + sh' = mshape example + + mvecsNewEmpty _ = MV_Nest (X.completeShXzeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) -- | Create an array given a size and a function that computes the element at a @@ -545,7 +738,7 @@ instance (Elt a, KnownShX sh') => Elt (Mixed sh' a) where -- the entire hierarchy (after distributing out tuples) must be a rectangular -- array. The type of 'mgenerate' allows this requirement to be broken very -- easily, hence the runtime check. -mgenerate :: forall sh a. Elt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a +mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a mgenerate sh f = case X.enumShape sh of [] -> memptyArray sh firstidx : restidxs -> @@ -601,8 +794,8 @@ mtoVectorP (M_Primitive _ v) = X.toVector v mtoVector :: (Storable a, PrimElt a) => Mixed sh a -> VS.Vector a mtoVector arr = mtoVectorP (coerce toPrimitive arr) -mfromList :: Elt a => SNat n -> NonEmpty a -> Mixed '[Just n] a -mfromList sn = mfromList1 sn . fmap mscalar +mfromList :: Elt a => NonEmpty a -> Mixed '[Nothing] a +mfromList = mfromList1 . fmap mscalar mtoList :: Elt a => Mixed '[n] a -> [a] mtoList = map munScalar . mtoList1 @@ -620,7 +813,7 @@ mconstant sh x = fromPrimitive (mconstantP sh x) mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a mslice i n arr = let _ :$% sh = mshape arr - in withKnownNat n $ mlift (SKnown n :!% X.staticShapeFrom sh) (\_ -> X.slice i n) arr + in mlift (SKnown n :!% X.staticShapeFrom sh) (\_ -> X.slice i n) arr msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a msliceU i n arr = mlift (X.staticShapeFrom (mshape arr)) (\_ -> X.sliceU i n) arr @@ -640,10 +833,10 @@ masXArrayPrimP (M_Primitive sh arr) = (sh, arr) masXArrayPrim :: PrimElt a => Mixed sh a -> (IShX sh, XArray sh a) masXArrayPrim = masXArrayPrimP . toPrimitive -mfromXArrayPrimP :: IShX sh -> XArray sh a -> Mixed sh (Primitive a) -mfromXArrayPrimP = M_Primitive +mfromXArrayPrimP :: StaticShX sh -> XArray sh a -> Mixed sh (Primitive a) +mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr -mfromXArrayPrim :: PrimElt a => IShX sh -> XArray sh a -> Mixed sh a +mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP mliftPrim :: (Storable a, PrimElt a) @@ -703,59 +896,46 @@ newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixe newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh' ) a)) -{- -- 'Ranked' and 'Shaped' can already be used at the top level of an array nest; -- these instances allow them to also be used as elements of arrays, thus -- making them first-class in the API. -instance (Elt a, KnownNat n) => Elt (Ranked n a) where - mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr - mindex (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i) +instance Elt a => Elt (Ranked n a) where + mshape (M_Ranked arr) = mshape arr + mindex (M_Ranked arr) i = Ranked (mindex arr i) mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a) - mindexPartial (M_Ranked arr) i - | Dict <- lemKnownReplicate (Proxy @n) - = coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ + mindexPartial (M_Ranked arr) i = + coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ mindexPartial arr i - mscalar (Ranked x) = M_Ranked (M_Nest x) + mscalar (Ranked x) = M_Ranked (M_Nest ZSX x) - mfromList1 :: forall m sh. KnownShapeX (m : sh) - => NonEmpty (Mixed sh (Ranked n a)) -> Mixed (m : sh) (Ranked n a) - mfromList1 l - | Dict <- lemKnownReplicate (Proxy @n) - = M_Ranked (mfromList1 (coerce l)) + mfromList1 :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a) + mfromList1 l = M_Ranked (mfromList1 (coerce l)) mtoList1 :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)] - mtoList1 (M_Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n) - = coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoList1 arr) + mtoList1 (M_Ranked arr) = + coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoList1 arr) - mlift :: forall sh1 sh2. KnownShapeX sh2 - => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) - mlift f (M_Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n) - = coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ - mlift f arr + mlift ssh2 f (M_Ranked arr) = + coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ + mlift ssh2 f arr - mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3) - => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a) - mlift2 f (M_Ranked arr1) (M_Ranked arr2) - | Dict <- lemKnownReplicate (Proxy @n) - = coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ - mlift2 f arr1 arr2 + mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) = + coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ + mlift2 ssh3 f arr1 arr2 - memptyArray :: forall sh. IShX sh -> Mixed sh (Ranked n a) - memptyArray i - | Dict <- lemKnownReplicate (Proxy @n) - = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ - memptyArray i + mcast ssh1 sh2 psh' (M_Ranked arr) = M_Ranked (mcast ssh1 sh2 psh' arr) - mshapeTree (Ranked arr) - | Refl <- lemRankReplicate (Proxy @n) - , Dict <- lemKnownReplicate (Proxy @n) - = first shCvtXR (mshapeTree arr) + mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr) mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -763,122 +943,95 @@ instance (Elt a, KnownNat n) => Elt (Ranked n a) where mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - mvecsUnsafeNew idx (Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n) - = MV_Ranked <$> mvecsUnsafeNew idx arr - - mvecsNewEmpty _ - | Dict <- lemKnownReplicate (Proxy @n) - = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) - mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () - mvecsWrite sh idx (Ranked arr) vecs - | Dict <- lemKnownReplicate (Proxy @n) - = mvecsWrite sh idx arr - (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) - vecs) - - mvecsWritePartial :: forall sh sh' s. KnownShapeX sh' - => IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) + mvecsWrite sh idx (Ranked arr) vecs = + mvecsWrite sh idx arr + (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) + vecs) + + mvecsWritePartial :: forall sh sh' s. + IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) -> MixedVecs s (sh ++ sh') (Ranked n a) -> ST s () - mvecsWritePartial sh idx arr vecs - | Dict <- lemKnownReplicate (Proxy @n) - = mvecsWritePartial sh idx - (coerce @(Mixed sh' (Ranked n a)) - @(Mixed sh' (Mixed (Replicate n Nothing) a)) - arr) - (coerce @(MixedVecs s (sh ++ sh') (Ranked n a)) - @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) - vecs) + mvecsWritePartial sh idx arr vecs = + mvecsWritePartial sh idx + (coerce @(Mixed sh' (Ranked n a)) + @(Mixed sh' (Mixed (Replicate n Nothing) a)) + arr) + (coerce @(MixedVecs s (sh ++ sh') (Ranked n a)) + @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) + vecs) mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) - mvecsFreeze sh vecs - | Dict <- lemKnownReplicate (Proxy @n) - = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) - @(Mixed sh (Ranked n a)) - <$> mvecsFreeze sh - (coerce @(MixedVecs s sh (Ranked n a)) - @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) - vecs) --} + mvecsFreeze sh vecs = + coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) + @(Mixed sh (Ranked n a)) + <$> mvecsFreeze sh + (coerce @(MixedVecs s sh (Ranked n a)) + @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) + vecs) + +instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where + memptyArray :: forall sh. IShX sh -> Mixed sh (Ranked n a) + memptyArray i + | Dict <- lemKnownReplicate (SNat @n) + = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ + memptyArray i --- | The shape of a shape-typed array given as a list of 'SNat' values. -TODO -- write ListS and implement IxS and ShS in terms of it. -TODO -- for ListR and ListS, write an uncons function like for ListX and implement the cons pattern synonyms in terms of it directly, instead of using a separate uncons function for both types. -data ShS sh where - ZSS :: ShS '[] - (:$$) :: forall n sh. SNat n -> ShS sh -> ShS (n : sh) -deriving instance Show (ShS sh) -deriving instance Eq (ShS sh) -deriving instance Ord (ShS sh) -infixr 3 :$$ + mvecsUnsafeNew idx (Ranked arr) + | Dict <- lemKnownReplicate (SNat @n) + = MV_Ranked <$> mvecsUnsafeNew idx arr -{- -sshapeKnown :: ShS sh -> Dict KnownShape sh -sshapeKnown ZSS = Dict -sshapeKnown (GHC_SNat :$$ sh) | Dict <- sshapeKnown sh = Dict + mvecsNewEmpty _ + | Dict <- lemKnownReplicate (SNat @n) + = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) -lemKnownMapJust :: forall sh. KnownShape sh => Proxy sh -> Dict KnownShapeX (MapJust sh) -lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh)) - where - go :: ShS sh' -> StaticShX (MapJust sh') - go ZSS = ZKSX - go (n :$$ sh) = n :!$@ go sh +-- sshapeKnown :: ShS sh -> Dict KnownShape sh +-- sshapeKnown ZSS = Dict +-- sshapeKnown (GHC_SNat :$$ sh) | Dict <- sshapeKnown sh = Dict lemCommMapJustApp :: forall sh1 sh2. ShS sh1 -> Proxy sh2 -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 lemCommMapJustApp ZSS _ = Refl lemCommMapJustApp (_ :$$ sh) p | Refl <- lemCommMapJustApp sh p = Refl -instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where - mshape (M_Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = mshape arr - mindex (M_Shaped arr) i | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (mindex arr i) +instance Elt a => Elt (Shaped sh a) where + mshape (M_Shaped arr) = mshape arr + mindex (M_Shaped arr) i = Shaped (mindex arr i) mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) - mindexPartial (M_Shaped arr) i - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ - mindexPartial arr i + mindexPartial (M_Shaped arr) i = + coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ + mindexPartial arr i - mscalar (Shaped x) = M_Shaped (M_Nest x) + mscalar (Shaped x) = M_Shaped (M_Nest ZSX x) - mfromList1 :: forall n sh'. KnownShapeX (n : sh') - => NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (n : sh') (Shaped sh a) - mfromList1 l - | Dict <- lemKnownMapJust (Proxy @sh) - = M_Shaped (mfromList1 (coerce l)) + mfromList1 :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a) + mfromList1 l = M_Shaped (mfromList1 (coerce l)) mtoList1 :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] mtoList1 (M_Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoList1 arr) - mlift :: forall sh1 sh2. KnownShapeX sh2 - => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) - mlift f (M_Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ - mlift f arr + mlift ssh2 f (M_Shaped arr) = + coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ + mlift ssh2 f arr - mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3) - => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a) - mlift2 f (M_Shaped arr1) (M_Shaped arr2) - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ - mlift2 f arr1 arr2 + mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) = + coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ + mlift2 ssh3 f arr1 arr2 - memptyArray :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) - memptyArray i - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ - memptyArray i + mcast ssh1 sh2 psh' (M_Shaped arr) = M_Shaped (mcast ssh1 sh2 psh' arr) - mshapeTree (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = first (shCvtXS (knownShape @sh)) (mshapeTree arr) + mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr) mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -886,162 +1039,84 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - mvecsUnsafeNew idx (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = MV_Shaped <$> mvecsUnsafeNew idx arr - - mvecsNewEmpty _ - | Dict <- lemKnownMapJust (Proxy @sh) - = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) - mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () - mvecsWrite sh idx (Shaped arr) vecs - | Dict <- lemKnownMapJust (Proxy @sh) - = mvecsWrite sh idx arr - (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) - vecs) + mvecsWrite sh idx (Shaped arr) vecs = + mvecsWrite sh idx arr + (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) - mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2 - => IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) + mvecsWritePartial :: forall sh1 sh2 s. + IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) -> ST s () - mvecsWritePartial sh idx arr vecs - | Dict <- lemKnownMapJust (Proxy @sh) - = mvecsWritePartial sh idx - (coerce @(Mixed sh2 (Shaped sh a)) - @(Mixed sh2 (Mixed (MapJust sh) a)) - arr) - (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a)) - @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) - vecs) + mvecsWritePartial sh idx arr vecs = + mvecsWritePartial sh idx + (coerce @(Mixed sh2 (Shaped sh a)) + @(Mixed sh2 (Mixed (MapJust sh) a)) + arr) + (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a)) + @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) + vecs) mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) - mvecsFreeze sh vecs - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @(Mixed sh' (Mixed (MapJust sh) a)) - @(Mixed sh' (Shaped sh a)) - <$> mvecsFreeze sh - (coerce @(MixedVecs s sh' (Shaped sh a)) - @(MixedVecs s sh' (Mixed (MapJust sh) a)) - vecs) + mvecsFreeze sh vecs = + coerce @(Mixed sh' (Mixed (MapJust sh) a)) + @(Mixed sh' (Shaped sh a)) + <$> mvecsFreeze sh + (coerce @(MixedVecs s sh' (Shaped sh a)) + @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) +-- | Evidence for the static part of a shape. This pops up only when you are +-- polymorphic in the element type of an array. +type KnownShS :: [Nat] -> Constraint +class KnownShS sh where knownShS :: ShS sh +instance KnownShS '[] where knownShS = ZSS +instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS --- Utility functions to satisfy the type checker sometimes +lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh) +lemKnownMapJust _ = lemKnownShX (go (knownShS @sh)) + where + go :: ShS sh' -> StaticShX (MapJust sh') + go ZSS = ZKX + go (n :$$ sh) = SKnown n :!% go sh -rewriteMixed :: sh1 :~: sh2 -> Mixed sh1 a -> Mixed sh2 a -rewriteMixed Refl x = x +instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where + memptyArray :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) + memptyArray i + | Dict <- lemKnownMapJust (Proxy @sh) + = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ + memptyArray i + + mvecsUnsafeNew idx (Shaped arr) + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsUnsafeNew idx arr + + mvecsNewEmpty _ + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) -- ====== API OF RANKED ARRAYS ====== -- -arithPromoteRanked :: forall n a. KnownNat n - => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a) +arithPromoteRanked :: forall n a. PrimElt a + => (forall sh. Mixed sh a -> Mixed sh a) -> Ranked n a -> Ranked n a -arithPromoteRanked | Dict <- lemKnownReplicate (Proxy @n) = coerce +arithPromoteRanked = coerce -arithPromoteRanked2 :: forall n a. KnownNat n - => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a -> Mixed sh a) +arithPromoteRanked2 :: forall n a. PrimElt a + => (forall sh. Mixed sh a -> Mixed sh a -> Mixed sh a) -> Ranked n a -> Ranked n a -> Ranked n a -arithPromoteRanked2 | Dict <- lemKnownReplicate (Proxy @n) = coerce +arithPromoteRanked2 = coerce -instance (KnownNat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where +instance (Storable a, Num a, PrimElt a) => Num (Ranked n a) where (+) = arithPromoteRanked2 (+) (-) = arithPromoteRanked2 (-) (*) = arithPromoteRanked2 (*) negate = arithPromoteRanked negate abs = arithPromoteRanked abs signum = arithPromoteRanked signum - fromInteger n = case natSing @n of - SZ -> Ranked (M_Primitive (X.scalar (fromInteger n))) - _ -> error "Data.Array.Nested.fromIntegral(Ranked): \ - \Rank non-zero, use explicit mconstant" - --- [PRIMITIVE ELEMENT TYPES LIST] (really, a partial list of just the numeric types) -deriving via Ranked n (Primitive Int) instance KnownNat n => Num (Ranked n Int) -deriving via Ranked n (Primitive Double) instance KnownNat n => Num (Ranked n Double) --} - -type role ListR nominal representational -type ListR :: Nat -> Type -> Type -data ListR n i where - ZR :: ListR 0 i - (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i -deriving instance Show i => Show (ListR n i) -deriving instance Eq i => Eq (ListR n i) -deriving instance Ord i => Ord (ListR n i) -deriving instance Functor (ListR n) -infixr 3 ::: - -instance Foldable (ListR n) where - foldr f z l = foldr f z (listRToList l) - -listRToList :: ListR n i -> [i] -listRToList ZR = [] -listRToList (i ::: is) = i : listRToList is - -knownListR :: ListR n i -> Dict KnownNat n -knownListR ZR = Dict -knownListR (_ ::: (l :: ListR m i)) | Dict <- knownListR l = knownNatSucc @m - --- | An index into a rank-typed array. -type role IxR nominal representational -type IxR :: Nat -> Type -> Type -newtype IxR n i = IxR (ListR n i) - deriving (Show, Eq, Ord) - deriving newtype (Functor, Foldable) - -pattern ZIR :: forall n i. () => n ~ 0 => IxR n i -pattern ZIR = IxR ZR - -pattern (:.:) - :: forall {n1} {i}. - forall n. (n + 1 ~ n1) - => i -> IxR n i -> IxR n1 i -pattern i :.: sh <- (unconsIxR -> Just (UnconsIxRRes sh i)) - where i :.: IxR sh = IxR (i ::: sh) -{-# COMPLETE ZIR, (:.:) #-} -infixr 3 :.: - -data UnconsIxRRes i n1 = - forall n. (n + 1 ~ n1) => UnconsIxRRes (IxR n i) i -unconsIxR :: IxR n1 i -> Maybe (UnconsIxRRes i n1) -unconsIxR (IxR (i ::: sh')) = Just (UnconsIxRRes (IxR sh') i) -unconsIxR (IxR ZR) = Nothing - -type IIxR n = IxR n Int - -knownIxR :: IxR n i -> Dict KnownNat n -knownIxR (IxR sh) = knownListR sh - -type role ShR nominal representational -type ShR :: Nat -> Type -> Type -newtype ShR n i = ShR (ListR n i) - deriving (Show, Eq, Ord) - deriving newtype (Functor, Foldable) - -type IShR n = ShR n Int - -pattern ZSR :: forall n i. () => n ~ 0 => ShR n i -pattern ZSR = ShR ZR - -pattern (:$:) - :: forall {n1} {i}. - forall n. (n + 1 ~ n1) - => i -> ShR n i -> ShR n1 i -pattern i :$: sh <- (unconsShR -> Just (UnconsShRRes sh i)) - where i :$: (ShR sh) = ShR (i ::: sh) -{-# COMPLETE ZSR, (:$:) #-} -infixr 3 :$: - -data UnconsShRRes i n1 = - forall n. n + 1 ~ n1 => UnconsShRRes (ShR n i) i -unconsShR :: ShR n1 i -> Maybe (UnconsShRRes i n1) -unconsShR (ShR (i ::: sh')) = Just (UnconsShRRes (ShR sh') i) -unconsShR (ShR ZR) = Nothing - -{- -knownShR :: ShR n i -> Dict KnownNat n -knownShR (ShR sh) = knownListR sh + fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mconstant" zeroIxR :: SNat n -> IIxR n zeroIxR SZ = ZIR @@ -1049,100 +1124,122 @@ zeroIxR (SS n) = 0 :.: zeroIxR n ixCvtXR :: IIxX sh -> IIxR (X.Rank sh) ixCvtXR ZIX = ZIR -ixCvtXR (n :.@ idx) = n :.: ixCvtXR idx -ixCvtXR (n :.? idx) = n :.: ixCvtXR idx +ixCvtXR (n :.% idx) = n :.: ixCvtXR idx + +shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n +shCvtXR' ZSX = + castWith (subst2 (unsafeCoerce Refl :: 0 :~: n)) + ZSR +shCvtXR' (n :$% (idx :: IShX sh)) + | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) = + castWith (subst2 (lem1 @sh Refl)) + (X.fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx)) + where + lem1 :: forall sh' n' k. + k : sh' :~: Replicate n' Nothing + -> Rank sh' + 1 :~: n' + lem1 Refl = unsafeCoerce Refl -shCvtXR :: IShX sh -> IShR (X.Rank sh) -shCvtXR ZSX = ZSR -shCvtXR (n :$@ idx) = X.fromSNat' n :$: shCvtXR idx -shCvtXR (n :$? idx) = n :$: shCvtXR idx + lem2 :: k : sh :~: Replicate n Nothing + -> sh :~: Replicate (Rank sh) Nothing + lem2 Refl = unsafeCoerce Refl ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing) ixCvtRX ZIR = ZIX -ixCvtRX (n :.: (idx :: IxR m Int)) = castWith (subst2 @IxX @Int (X.lemReplicateSucc @(Nothing @Nat) @m)) (n :.? ixCvtRX idx) +ixCvtRX (n :.: (idx :: IxR m Int)) = + castWith (subst2 @IxX @Int (X.lemReplicateSucc @(Nothing @Nat) @m)) + (n :.% ixCvtRX idx) shCvtRX :: IShR n -> IShX (Replicate n Nothing) shCvtRX ZSR = ZSX -shCvtRX (n :$: (idx :: ShR m Int)) = castWith (subst2 @ShX @Int (X.lemReplicateSucc @(Nothing @Nat) @m)) (n :$? shCvtRX idx) +shCvtRX (n :$: (idx :: ShR m Int)) = + castWith (subst2 @ShX @Int (X.lemReplicateSucc @(Nothing @Nat) @m)) + (SUnknown n :$% shCvtRX idx) shapeSizeR :: IShR n -> Int shapeSizeR ZSR = 1 shapeSizeR (n :$: sh) = n * shapeSizeR sh -rshape :: forall n a. (KnownNat n, Elt a) => Ranked n a -> IShR n -rshape (Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n) - , Refl <- lemRankReplicate (Proxy @n) - = shCvtXR (mshape arr) +rshape :: forall n a. Elt a => Ranked n a -> IShR n +rshape (Ranked arr) = shCvtXR' (mshape arr) rindex :: Elt a => Ranked n a -> IIxR n -> a rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) -rindexPartial :: forall n m a. (KnownNat n, Elt a) => Ranked (n + m) a -> IIxR n -> Ranked m a +snatFromListR :: ListR n i -> SNat n +snatFromListR ZR = SNat +snatFromListR (_ ::: (l :: ListR n i)) | SNat <- snatFromListR l, Dict <- knownNatSucc @n = SNat + +snatFromIxR :: IxR n i -> SNat n +snatFromIxR (IxR sh) = snatFromListR sh + +snatFromShR :: ShR n i -> SNat n +snatFromShR (ShR sh) = snatFromListR sh + +rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a rindexPartial (Ranked arr) idx = Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) - (rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr) + (castWith (subst2 (lemReplicatePlusApp (snatFromIxR idx) (Proxy @m) (Proxy @Nothing))) arr) (ixCvtRX idx)) -- | __WARNING__: All values returned from the function must have equal shape. -- See the documentation of 'mgenerate' for more details. -rgenerate :: forall n a. Elt a => IShR n -> (IIxR n -> a) -> Ranked n a +rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a rgenerate sh f - | Dict <- knownShR sh - , Dict <- lemKnownReplicate (Proxy @n) - , Refl <- lemRankReplicate (Proxy @n) + | sn@SNat <- snatFromShR sh + , Dict <- lemKnownReplicate sn + , Refl <- lemRankReplicate sn = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR)) -- | See the documentation of 'mlift'. -rlift :: forall n1 n2 a. (KnownNat n2, Elt a) - => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) +rlift :: forall n1 n2 a. Elt a + => SNat n2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) -> Ranked n1 a -> Ranked n2 a -rlift f (Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n2) - = Ranked (mlift f arr) +rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr) rsumOuter1P :: forall n a. - (Storable a, Num a, KnownNat n) + (Storable a, Num a) => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) -rsumOuter1P (Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n) - , Refl <- X.lemReplicateSucc @(Nothing @Nat) @n - = Ranked - . coerce @(XArray (Replicate n 'Nothing) a) @(Mixed (Replicate n 'Nothing) (Primitive a)) - . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate n Nothing)) - . coerce @(Mixed (Replicate (n + 1) Nothing) (Primitive a)) @(XArray (Replicate (n + 1) Nothing) a) - $ arr +rsumOuter1P (Ranked (M_Primitive sh arr)) + | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n + , _ :$% shT <- sh + = Ranked (M_Primitive shT (X.sumOuter (SUnknown () :!% ZKX) (X.staticShapeFrom shT) arr)) -rsumOuter1 :: forall n a. (Storable a, Num a, PrimElt a, KnownNat n) +rsumOuter1 :: forall n a. (Storable a, Num a, PrimElt a) => Ranked (n + 1) a -> Ranked n a rsumOuter1 = coerce fromPrimitive . rsumOuter1P @n @a . coerce toPrimitive -rtranspose :: forall n a. (KnownNat n, Elt a) => [Int] -> Ranked n a -> Ranked n a -rtranspose perm - | Dict <- lemKnownReplicate (Proxy @n) +rtranspose :: forall n a. Elt a => [Int] -> Ranked n a -> Ranked n a +rtranspose perm arr + | sn@SNat <- snatFromShR (rshape arr) + , Dict <- lemKnownReplicate sn , length perm <= fromIntegral (natVal (Proxy @n)) - = rlift $ \(Proxy @sh') -> - X.transposeUntyped (natSing @n) (knownShapeX @sh') perm + = rlift sn + (\ssh' -> X.transposeUntyped (natSing @n) ssh' perm) + arr | otherwise = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array" -rappend :: forall n a. (KnownNat n, Elt a) +rappend :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a -rappend - | Dict <- lemKnownReplicate (Proxy @n) +rappend arr1 arr2 + | sn@SNat <- snatFromShR (rshape arr1) + , Dict <- lemKnownReplicate sn , Refl <- X.lemReplicateSucc @(Nothing @Nat) @n = coerce (mappend @Nothing @Nothing @(Replicate n Nothing)) + arr1 arr2 rscalar :: Elt a => a -> Ranked 0 a rscalar x = Ranked (mscalar x) -rfromVectorP :: forall n a. (KnownNat n, Storable a) => IShR n -> VS.Vector a -> Ranked n (Primitive a) +rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a) rfromVectorP sh v - | Dict <- lemKnownReplicate (Proxy @n) + | Dict <- lemKnownReplicate (snatFromShR sh) = Ranked (mfromVectorP (shCvtRX sh) v) -rfromVector :: forall n a. (KnownNat n, Storable a, PrimElt a) => IShR n -> VS.Vector a -> Ranked n a +rfromVector :: forall n a. (Storable a, PrimElt a) => IShR n -> VS.Vector a -> Ranked n a rfromVector sh v = coerce fromPrimitive (rfromVectorP sh v) rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a @@ -1151,14 +1248,13 @@ rtoVectorP = coerce mtoVectorP rtoVector :: (Storable a, PrimElt a) => Ranked n a -> VS.Vector a rtoVector = coerce mtoVector -rfromList1 :: forall n a. (KnownNat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (n + 1) a +rfromList1 :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a rfromList1 l - | Dict <- lemKnownReplicate (Proxy @n) - , Refl <- X.lemReplicateSucc @(Nothing @Nat) @n - = Ranked (mfromList1 @a @Nothing @(Replicate n Nothing) (coerce l)) + | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n + = Ranked (mfromList1 (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) rfromList :: Elt a => NonEmpty a -> Ranked 1 a -rfromList = Ranked . mfromList1 . fmap mscalar +rfromList l = Ranked (mfromList l) rtoList :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] rtoList (Ranked arr) @@ -1171,173 +1267,130 @@ rtoList1 = map runScalar . rtoList runScalar :: Elt a => Ranked 0 a -> a runScalar arr = rindex arr ZIR -rconstantP :: forall n a. (KnownNat n, Storable a) => IShR n -> a -> Ranked n (Primitive a) +rconstantP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) rconstantP sh x - | Dict <- lemKnownReplicate (Proxy @n) + | Dict <- lemKnownReplicate (snatFromShR sh) = Ranked (mconstantP (shCvtRX sh) x) -rconstant :: forall n a. (KnownNat n, Storable a, PrimElt a) +rconstant :: forall n a. (Storable a, PrimElt a) => IShR n -> a -> Ranked n a rconstant sh x = coerce fromPrimitive (rconstantP sh x) -rslice :: forall n a. (KnownNat n, Elt a) => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a -rslice i n +rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a +rslice i n arr | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n - = rlift $ \_ -> X.sliceU i n - -rrev1 :: forall n a. (KnownNat n, Elt a) => Ranked (n + 1) a -> Ranked (n + 1) a -rrev1 = rlift $ \(Proxy @sh') -> - case X.lemReplicateSucc @(Nothing @Nat) @n of - Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh') + = rlift (snatFromShR (rshape arr)) + (\_ -> X.sliceU i n) + arr + +rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a +rrev1 arr = + rlift (snatFromShR (rshape arr)) + (\(_ :: StaticShX sh') -> + case X.lemReplicateSucc @(Nothing @Nat) @n of + Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')) + arr -rreshape :: forall n n' a. (KnownNat n, KnownNat n', Elt a) +rreshape :: forall n n' a. Elt a => IShR n' -> Ranked n a -> Ranked n' a -rreshape sh' (Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n) - , Dict <- lemKnownReplicate (Proxy @n') +rreshape sh' rarr@(Ranked arr) + | Dict <- lemKnownReplicate (snatFromShR (rshape rarr)) + , Dict <- lemKnownReplicate (snatFromShR sh') = Ranked (mreshape (shCvtRX sh') arr) -rasXArrayPrimP :: Ranked n (Primitive a) -> XArray (Replicate n Nothing) a -rasXArrayPrimP (Ranked arr) = masXArrayPrimP arr +rasXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a) +rasXArrayPrimP (Ranked arr) = first shCvtXR' (masXArrayPrimP arr) -rasXArrayPrim :: PrimElt a => Ranked n a -> XArray (Replicate n Nothing) a -rasXArrayPrim (Ranked arr) = masXArrayPrim arr +rasXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a) +rasXArrayPrim (Ranked arr) = first shCvtXR' (masXArrayPrim arr) -rfromXArrayPrimP :: XArray (Replicate n Nothing) a -> Ranked n (Primitive a) -rfromXArrayPrimP = Ranked . mfromXArrayPrimP +rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a) +rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (X.staticShapeFrom (X.shape (ssxFromSNat sn) arr)) arr) -rfromXArrayPrim :: PrimElt a => XArray (Replicate n Nothing) a -> Ranked n a -rfromXArrayPrim = Ranked . mfromXArrayPrim +rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a +rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (X.staticShapeFrom (X.shape (ssxFromSNat sn) arr)) arr) -- ====== API OF SHAPED ARRAYS ====== -- -arithPromoteShaped :: forall sh a. KnownShape sh - => (forall shx. KnownShapeX shx => Mixed shx a -> Mixed shx a) +arithPromoteShaped :: forall sh a. PrimElt a + => (forall shx. Mixed shx a -> Mixed shx a) -> Shaped sh a -> Shaped sh a -arithPromoteShaped | Dict <- lemKnownMapJust (Proxy @sh) = coerce +arithPromoteShaped = coerce -arithPromoteShaped2 :: forall sh a. KnownShape sh - => (forall shx. KnownShapeX shx => Mixed shx a -> Mixed shx a -> Mixed shx a) +arithPromoteShaped2 :: forall sh a. PrimElt a + => (forall shx. Mixed shx a -> Mixed shx a -> Mixed shx a) -> Shaped sh a -> Shaped sh a -> Shaped sh a -arithPromoteShaped2 | Dict <- lemKnownMapJust (Proxy @sh) = coerce +arithPromoteShaped2 = coerce -instance (KnownShape sh, Storable a, Num a) => Num (Shaped sh (Primitive a)) where +instance (Storable a, Num a, PrimElt a) => Num (Shaped sh a) where (+) = arithPromoteShaped2 (+) (-) = arithPromoteShaped2 (-) (*) = arithPromoteShaped2 (*) negate = arithPromoteShaped negate abs = arithPromoteShaped abs signum = arithPromoteShaped signum - fromInteger n = sconstantP (fromInteger n) - --- [PRIMITIVE ELEMENT TYPES LIST] (really, a partial list of just the numeric types) -deriving via Shaped sh (Primitive Int) instance KnownShape sh => Num (Shaped sh Int) -deriving via Shaped sh (Primitive Double) instance KnownShape sh => Num (Shaped sh Double) --} - -type role ListS nominal representational -type ListS :: [Nat] -> Type -> Type -data ListS sh i where - ZS :: ListS '[] i - (::$) :: forall n sh {i}. i -> ListS sh i -> ListS (n : sh) i -deriving instance Show i => Show (ListS sh i) -deriving instance Eq i => Eq (ListS sh i) -deriving instance Ord i => Ord (ListS sh i) -deriving instance Functor (ListS sh) -infixr 3 ::$ - -instance Foldable (ListS sh) where - foldr f z l = foldr f z (listSToList l) - -listSToList :: ListS sh i -> [i] -listSToList ZS = [] -listSToList (i ::$ is) = i : listSToList is - --- | An index into a shape-typed array. --- --- For convenience, this contains regular 'Int's instead of bounded integers --- (traditionally called \"@Fin@\"). Note that because the shape of a --- shape-typed array is known statically, you can also retrieve the array shape --- from a 'KnownShape' dictionary. -type role IxS nominal representational -type IxS :: [Nat] -> Type -> Type -newtype IxS sh i = IxS (ListS sh i) - deriving (Show, Eq, Ord) - deriving newtype (Functor, Foldable) - -pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i -pattern ZIS = IxS ZS - -pattern (:.$) - :: forall {sh1} {i}. - forall n sh. (n : sh ~ sh1) - => i -> IxS sh i -> IxS sh1 i -pattern i :.$ shl <- (unconsIxS -> Just (UnconsIxSRes shl i)) - where i :.$ IxS shl = IxS (i ::$ shl) -{-# COMPLETE ZIS, (:.$) #-} -infixr 3 :.$ - -data UnconsIxSRes i sh1 = - forall n sh. (n : sh ~ sh1) => UnconsIxSRes (IxS sh i) i -unconsIxS :: IxS sh1 i -> Maybe (UnconsIxSRes i sh1) -unconsIxS (IxS (i ::$ shl')) = Just (UnconsIxSRes (IxS shl') i) -unconsIxS (IxS ZS) = Nothing - -type IIxS sh = IxS sh Int - -data UnconsShSRes sh1 = - forall n sh. (n : sh ~ sh1) => UnconsShSRes (ShS sh) (SNat n) -unconsShS :: ShS sh1 -> Maybe (UnconsShSRes sh1) -unconsShS (i :$$ shl') = Just (UnconsShSRes shl' i) -unconsShS ZSS = Nothing + fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mconstant" -{- zeroIxS :: ShS sh -> IIxS sh zeroIxS ZSS = ZIS zeroIxS (_ :$$ sh) = 0 :.$ zeroIxS sh ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh ixCvtXS ZSS ZIX = ZIS -ixCvtXS (_ :$$ sh) (n :.@ idx) = n :.$ ixCvtXS sh idx +ixCvtXS (_ :$$ sh) (n :.% idx) = n :.$ ixCvtXS sh idx -shCvtXS :: ShS sh -> IShX (MapJust sh) -> ShS sh -shCvtXS ZSS ZSX = ZSS -shCvtXS (_ :$$ sh) (n :$@ idx) = n :$$ shCvtXS sh idx +type family Tail l where + Tail (_ : xs) = xs + +shCvtXS' :: forall sh. IShX (MapJust sh) -> ShS sh +shCvtXS' ZSX = castWith (subst1 (unsafeCoerce Refl :: '[] :~: sh)) ZSS +shCvtXS' (SKnown n :$% (idx :: IShX mjshT)) = + castWith (subst1 (lem Refl)) $ + n :$$ shCvtXS' @(Tail sh) (castWith (subst2 (unsafeCoerce Refl :: mjshT :~: MapJust (Tail sh))) + idx) + where + lem :: forall sh1 sh' n. + Just n : sh1 :~: MapJust sh' + -> n : Tail sh' :~: sh' + lem Refl = unsafeCoerce Refl +shCvtXS' (SUnknown _ :$% _) = error "impossible" ixCvtSX :: IIxS sh -> IIxX (MapJust sh) ixCvtSX ZIS = ZIX -ixCvtSX (n :.$ sh) = n :.@ ixCvtSX sh +ixCvtSX (n :.$ sh) = n :.% ixCvtSX sh shCvtSX :: ShS sh -> IShX (MapJust sh) shCvtSX ZSS = ZSX -shCvtSX (n :$$ sh) = n :$@ shCvtSX sh +shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh shapeSizeS :: ShS sh -> Int shapeSizeS ZSS = 1 shapeSizeS (n :$$ sh) = X.fromSNat' n * shapeSizeS sh --- | This does not touch the passed array, all information comes from 'KnownShape'. -sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> ShS sh -sshape _ = knownShape @sh +sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh +sshape (Shaped arr) = shCvtXS' (mshape arr) sindex :: Elt a => Shaped sh a -> IIxS sh -> a sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) -sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, Elt a) => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a -sindexPartial (Shaped arr) idx = +shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh +shsTakeIx _ _ ZIS = ZSS +shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx + +sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a +sindexPartial sarr@(Shaped arr) idx = Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) - (rewriteMixed (lemCommMapJustApp (knownShape @sh1) (Proxy @sh2)) arr) + (castWith (subst2 (lemCommMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr) (ixCvtSX idx)) -- | __WARNING__: All values returned from the function must have equal shape. -- See the documentation of 'mgenerate' for more details. -sgenerate :: forall sh a. (KnownShape sh, Elt a) => (IIxS sh -> a) -> Shaped sh a -sgenerate f - | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped (mgenerate (shCvtSX (knownShape @sh)) (f . ixCvtXS (knownShape @sh))) +sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a +sgenerate sh f = Shaped (mgenerate (shCvtSX sh) (f . ixCvtXS sh)) +{- -- | See the documentation of 'mlift'. slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a) => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) @@ -1463,7 +1516,7 @@ sconstant :: forall sh a. (KnownShape sh, Storable a, PrimElt a) sconstant x = coerce fromPrimitive (sconstantP @sh x) sslice :: (KnownShape sh, Elt a) => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a -sslice i n = withKnownNat n $ slift $ \_ -> X.slice i n +sslice i n@SNat = slift $ \_ -> X.slice i n srev1 :: (KnownNat n, KnownShape sh, Elt a) => Shaped (n : sh) a -> Shaped (n : sh) a srev1 = slift $ \_ -> X.rev1 -- cgit v1.2.3-70-g09d2