From 26b8f3c19cd919f5a45ef07e4ba76bae5cab35ce Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Fri, 16 May 2025 11:54:41 +0200 Subject: Shape/index function rename --- src/Data/Array/Mixed/Lemmas.hs | 14 ++++----- src/Data/Array/Nested/Convert.hs | 32 ++++++++++++++++---- src/Data/Array/Nested/Mixed.hs | 56 +++++++++++++++++------------------ src/Data/Array/Nested/Mixed/Shape.hs | 18 +++++------ src/Data/Array/Nested/Ranked.hs | 32 ++++++++++---------- src/Data/Array/Nested/Ranked/Base.hs | 4 +-- src/Data/Array/Nested/Ranked/Shape.hs | 50 +++++++++++++------------------ src/Data/Array/Nested/Shaped.hs | 44 +++++++++++++-------------- src/Data/Array/Nested/Shaped/Base.hs | 4 +-- src/Data/Array/Nested/Shaped/Shape.hs | 30 +++++++++---------- src/Data/Array/XArray.hs | 14 ++++----- 11 files changed, 153 insertions(+), 145 deletions(-) (limited to 'src') diff --git a/src/Data/Array/Mixed/Lemmas.hs b/src/Data/Array/Mixed/Lemmas.hs index 7f41a06..cfb7bc6 100644 --- a/src/Data/Array/Mixed/Lemmas.hs +++ b/src/Data/Array/Mixed/Lemmas.hs @@ -64,16 +64,12 @@ lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2 -> c + 1 :~: (a + 1 + b) lem _ _ _ Refl = Refl -lemRankAppComm :: StaticShX sh1 -> StaticShX sh2 +lemRankAppComm :: proxy sh1 -> proxy sh2 -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1) -lemRankAppComm _ _ = unsafeCoerceRefl -- TODO improve this - -lemRankReplicate :: SNat n -> Rank (Replicate n (Nothing @Nat)) :~: n -lemRankReplicate SZ = Refl -lemRankReplicate (SS (n :: SNat nm1)) - | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 - , Refl <- lemRankReplicate n - = Refl +lemRankAppComm _ _ = unsafeCoerceRefl + +lemRankReplicate :: proxy n -> Rank (Replicate n (Nothing @Nat)) :~: n +lemRankReplicate _ = unsafeCoerceRefl -- ** Various type families diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index 01abae3..92bc3b4 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeAbstractions #-} {-# LANGUAGE TypeApplications #-} @@ -16,6 +17,9 @@ module Data.Array.Nested.Convert ( rtoMixed, rcastToMixed, rcastToShaped, stoMixed, scastToMixed, stoRanked, mcast, mcastToShaped, mtoRanked, + + -- * Additional index/shape casting functions + ixrFromIxS, shrFromShS, ) where import Control.Category @@ -28,6 +32,7 @@ import Data.Array.Nested.Internal.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Ranked.Base +import Data.Array.Nested.Ranked.Shape import Data.Array.Nested.Shaped.Base import Data.Array.Nested.Shaped.Shape @@ -37,7 +42,7 @@ mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a) mcast ssh2 arr | Refl <- lemAppNil @sh1 , Refl <- lemAppNil @sh2 - = mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr + = mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a mtoRanked = castCastable (CastXR CastId) @@ -74,10 +79,25 @@ stoRanked sarr@(Shaped arr) rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a rcastToShaped (Ranked arr) targetsh - | Refl <- lemRankReplicate (shxRank (shCvtSX targetsh)) + | Refl <- lemRankReplicate (shxRank (shxFromShS targetsh)) , Refl <- lemRankMapJust targetsh = mcastToShaped targetsh arr +ixrFromIxS :: IIxS sh -> IIxR (Rank sh) +ixrFromIxS ZIS = ZIR +ixrFromIxS (i :.$ ix) = i :.: ixrFromIxS ix + +-- ixsFromIxR :: IIxR (Rank sh) -> IIxS sh +-- ixsFromIxR = \ix -> go ix _ +-- where +-- go :: IIxR n -> (forall sh. KnownShS sh => IIxS sh -> r) -> r +-- go ZIR k = k ZIS +-- go (i :.: ix) k = go ix (i :.$) + +shrFromShS :: ShS sh -> IShR (Rank sh) +shrFromShS ZSS = ZSR +shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh + -- | The constructors that perform runtime shape checking are marked with a -- @'@: 'CastXS'' and 'CastXX''. For the other constructors, the types ensure -- that the shapes are already compatible. To convert between 'Ranked' and @@ -122,20 +142,20 @@ castCastable = \c x -> munScalar (go c (mscalar x)) go (CastXR @_ @_ @sh c) (M_Nest @esh esh x) | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh) = let x' = go c x - ssx' = ssxAppend (ssxFromShape esh) - (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x') (ssxFromShape esh)))) + ssx' = ssxAppend (ssxFromShX esh) + (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x') (ssxFromShX esh)))) in M_Ranked (M_Nest esh (mcast ssx' x')) go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x)) go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x) | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh') - = M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh'))) + = M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh'))) (go c x))) go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x)) go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x)) go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x) go (CastXX' @sh @sh' ssx c) (M_Nest @esh esh x) | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh') - = M_Nest esh $ mcast (ssxFromShape esh `ssxAppend` ssx) (go c x) + = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) (go c x) lemRankAppRankEq :: Rank sh ~ Rank sh' => Proxy esh -> Proxy sh -> Proxy sh' diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 679e73b..373e62d 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -402,7 +402,7 @@ instance Storable a => Elt (Primitive a) where mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) mfromListOuter l@(arr1 :| _) = let sh = SUnknown (length l) :$% mshape arr1 - in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l))) + in M_Primitive sh (X.fromListOuter (ssxFromShX sh) (map (\(M_Primitive _ a) -> a) (toList l))) mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr) mlift :: forall sh1 sh2. @@ -441,16 +441,16 @@ instance Storable a => Elt (Primitive a) where mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) = let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' sh2 = shxCast' sh1 ssh2 - in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr) + in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShX sh') arr) mtranspose perm (M_Primitive sh arr) = M_Primitive (shxPermutePrefix perm sh) - (X.transpose (ssxFromShape sh) perm arr) + (X.transpose (ssxFromShX sh) perm arr) mconcat :: forall sh. NonEmpty (Mixed (Nothing : sh) (Primitive a)) -> Mixed (Nothing : sh) (Primitive a) mconcat l@(M_Primitive (_ :$% sh) _ :| _) = - let result = X.concat (ssxFromShape sh) (fmap (\(M_Primitive _ arr) -> arr) l) - in M_Primitive (X.shape (SUnknown () :!% ssxFromShape sh) result) result + let result = X.concat (ssxFromShX sh) (fmap (\(M_Primitive _ arr) -> arr) l) + in M_Primitive (X.shape (SUnknown () :!% ssxFromShX sh) result) result mrnf (M_Primitive sh a) = rnf sh `seq` rnf a @@ -467,7 +467,7 @@ instance Storable a => Elt (Primitive a) where :: forall sh' sh s. IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do - let arrsh = X.shape (ssxFromShape sh') arr + let arrsh = X.shape (ssxFromShX sh') arr offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) @@ -554,7 +554,7 @@ instance Elt a => Elt (Mixed sh' a) where -- moverlongShape method, a prefix of which is mshape. mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh mshape (M_Nest sh arr) - = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr)) + = fst (shxSplitApp (Proxy @sh') (ssxFromShX sh) (mshape arr)) mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a mindex (M_Nest _ arr) i = mindexPartial arr i @@ -583,7 +583,7 @@ instance Elt a => Elt (Mixed sh' a) where (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result) in M_Nest sh2 result where - ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr))) + ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr))) f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b f' sshT @@ -600,7 +600,7 @@ instance Elt a => Elt (Mixed sh' a) where (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result) in M_Nest sh3 result where - ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1))) + ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1))) f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b f' sshT @@ -618,7 +618,7 @@ instance Elt a => Elt (Mixed sh' a) where (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape (NE.head result)) in fmap (M_Nest sh2) result where - ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1))) + ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1))) f' :: forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray ((sh1 ++ sh') ++ shT) b) -> NonEmpty (XArray ((sh2 ++ sh') ++ shT) b) f' sshT @@ -640,7 +640,7 @@ instance Elt a => Elt (Mixed sh' a) where -> Mixed (PermutePrefix is sh) (Mixed sh' a) mtranspose perm (M_Nest sh arr) | let sh' = shxDropSh @sh @sh' (mshape arr) sh - , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh') + , Refl <- lemRankApp (ssxFromShX sh) (ssxFromShX sh') , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh')) , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') @@ -651,14 +651,14 @@ instance Elt a => Elt (Mixed sh' a) where mconcat :: NonEmpty (Mixed (Nothing : sh) (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) mconcat l@(M_Nest sh1 _ :| _) = let result = mconcat (fmap (\(M_Nest _ arr) -> arr) l) - in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape result))) result + in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape result))) result mrnf (M_Nest sh arr) = rnf sh `seq` mrnf arr type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a) mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) - mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr))))) + mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShX (mshape arr))))) mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -685,7 +685,7 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where mvecsUnsafeNew sh example | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) - | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh'))) + | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShX sh'))) where sh' = mshape example @@ -743,14 +743,14 @@ msumOuter1P :: forall sh n a. (Storable a, NumElt a) => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) msumOuter1P (M_Primitive (n :$% sh) arr) = let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX - in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr) + in M_Primitive sh (X.sumOuter nssh (ssxFromShX sh) arr) msumOuter1 :: forall sh n a. (NumElt a, PrimElt a) => Mixed (n : sh) a -> Mixed sh a msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a -msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShape sh) arr +msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr mappend :: forall n m sh a. Elt a => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a @@ -758,7 +758,7 @@ mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 where sn :$% sh = mshape arr1 sm :$% _ = mshape arr2 - ssh = ssxFromShape sh + ssh = ssxFromShX sh snm :: SMayNat () SNat (AddMaybe n m) snm = case (sn, sm) of (SUnknown{}, _) -> SUnknown () @@ -834,7 +834,7 @@ mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) mrerankP ssh sh2 f (M_Primitive sh arr) = let sh1 = shxDropSSX sh ssh in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2) - (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2) + (X.rerank ssh (ssxFromShX sh1) (ssxFromShX sh2) (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) arr) @@ -849,8 +849,8 @@ mrerank ssh sh2 f (toPrimitive -> arr) = mreplicate :: forall sh sh' a. Elt a => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a mreplicate sh arr = - let ssh' = ssxFromShape (mshape arr) - in mlift (ssxAppend (ssxFromShape sh) ssh') + let ssh' = ssxFromShX (mshape arr) + in mlift (ssxAppend (ssxFromShX sh) ssh') (\(sshT :: StaticShX shT) -> case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of Refl -> X.replicate sh (ssxAppend ssh' sshT)) @@ -866,18 +866,18 @@ mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x) mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a mslice i n arr = let _ :$% sh = mshape arr - in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr + in mlift (SKnown n :!% ssxFromShX sh) (\_ -> X.slice i n) arr msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a -msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr +msliceU i n arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.sliceU i n) arr mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a -mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr +mrev1 arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.rev1) arr mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a mreshape sh' arr = - mlift (ssxFromShape sh') - (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh') + mlift (ssxFromShX sh') + (\sshIn -> X.reshapePartial (ssxFromShX (mshape arr)) sshIn sh') arr mflatten :: Elt a => Mixed sh a -> Mixed '[Flatten sh] a @@ -889,12 +889,12 @@ miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn) -- | Throws if the array is empty. mminIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh mminIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = - ixxFromList (ssxFromShape sh) (numEltMinIndex (shxRank sh) (fromO arr)) + ixxFromList (ssxFromShX sh) (numEltMinIndex (shxRank sh) (fromO arr)) -- | Throws if the array is empty. mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = - ixxFromList (ssxFromShape sh) (numEltMaxIndex (shxRank sh) (fromO arr)) + ixxFromList (ssxFromShX sh) (numEltMaxIndex (shxRank sh) (fromO arr)) mdot1Inner :: forall sh n a. (PrimElt a, NumElt a) => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a @@ -904,7 +904,7 @@ mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primi = case sh1 of _ :$% _ | sh1 == sh2 - , Refl <- lemRankApp (ssxInit (ssxFromShape sh1)) (ssxLast (ssxFromShape sh1) :!% ZKX) -> + , Refl <- lemRankApp (ssxInit (ssxFromShX sh1)) (ssxLast (ssxFromShX sh1) :!% ZKX) -> fromPrimitive $ M_Primitive (shxInit sh1) (XArray (liftO2 (numEltDotprodInner (shxRank (shxInit sh1))) a b)) | otherwise -> error $ "mdot1Inner: Unequal shapes (" ++ show sh1 ++ " and " ++ show sh2 ++ ")" ZSX -> error "unreachable" diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index d934873..6800f11 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -408,6 +408,12 @@ shxToList :: IShX sh -> [Int] shxToList ZSX = [] shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh +-- | This may fail if @sh@ has @Nothing@s in it. +shxFromSSX' :: StaticShX sh -> Maybe (IShX sh) +shxFromSSX' ZKX = Just ZSX +shxFromSSX' (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX' sh +shxFromSSX' (SUnknown _ :!% _) = Nothing + shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i shxAppend = coerce (listxAppend @_ @(SMayNat i SNat)) @@ -540,12 +546,6 @@ ssxInit = coerce (listxInit @(SMayNat () SNat)) ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () SNat (Last (n : sh)) ssxLast = coerce (listxLast @(SMayNat () SNat)) --- | This may fail if @sh@ has @Nothing@s in it. -ssxToShX' :: StaticShX sh -> Maybe (IShX sh) -ssxToShX' ZKX = Just ZSX -ssxToShX' (SKnown n :!% sh) = (SKnown n :$%) <$> ssxToShX' sh -ssxToShX' (SUnknown _ :!% _) = Nothing - ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) ssxReplicate SZ = ZKX ssxReplicate (SS (n :: SNat n')) @@ -556,9 +556,9 @@ ssxIotaFrom :: Int -> StaticShX sh -> [Int] ssxIotaFrom _ ZKX = [] ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh -ssxFromShape :: IShX sh -> StaticShX sh -ssxFromShape ZSX = ZKX -ssxFromShape (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShape sh +ssxFromShX :: IShX sh -> StaticShX sh +ssxFromShX ZSX = ZKX +ssxFromShX (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShX sh ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing) ssxFromSNat SZ = ZKX diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs index 14f2709..74e3893 100644 --- a/src/Data/Array/Nested/Ranked.hs +++ b/src/Data/Array/Nested/Ranked.hs @@ -50,13 +50,13 @@ rsize :: Elt a => Ranked n a -> Int rsize = shrSize . rshape rindex :: Elt a => Ranked n a -> IIxR n -> a -rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) +rindex (Ranked arr) idx = mindex arr (ixxFromIxR idx) rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a rindexPartial (Ranked arr) idx = Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) (castWith (subst2 (lemReplicatePlusApp (ixrRank idx) (Proxy @m) (Proxy @Nothing))) arr) - (ixCvtRX idx)) + (ixxFromIxR idx)) -- | __WARNING__: All values returned from the function must have equal shape. -- See the documentation of 'mgenerate' for more details. @@ -65,7 +65,7 @@ rgenerate sh f | sn@SNat <- shrRank sh , Dict <- lemKnownReplicate sn , Refl <- lemRankReplicate sn - = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR)) + = Ranked (mgenerate (shxFromShR sh) (f . ixrFromIxX)) -- | See the documentation of 'mlift'. rlift :: forall n1 n2 a. Elt a @@ -126,7 +126,7 @@ rscalar x = Ranked (mscalar x) rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a) rfromVectorP sh v | Dict <- lemKnownReplicate (shrRank sh) - = Ranked (mfromVectorP (shCvtRX sh) v) + = Ranked (mfromVectorP (shxFromShR sh) v) rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a rfromVector sh v = rfromPrimitive (rfromVectorP sh v) @@ -165,7 +165,7 @@ rfromListPrim l = rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a rfromListPrimLinear sh l = let M_Primitive _ xarr = toPrimitive (mfromListPrim l) - in Ranked $ fromPrimitive $ M_Primitive (shCvtRX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtRX sh) xarr) + in Ranked $ fromPrimitive $ M_Primitive (shxFromShR sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShR sh) xarr) rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a rfromListLinear sh l = rreshape sh (rfromList1 l) @@ -181,7 +181,7 @@ rfromOrthotope sn arr rtoOrthotope :: PrimElt a => Ranked n a -> S.Array n a rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr))) - | Refl <- lemRankReplicate (shrRank $ shCvtXR' sh) + | Refl <- lemRankReplicate (shrRank $ shrFromShX2 sh) = arr runScalar :: Elt a => Ranked 0 a -> a @@ -210,7 +210,7 @@ rrerankP :: forall n1 n2 n a b. (Storable a, Storable b) rrerankP sn sh2 f (Ranked arr) | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat)) , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat)) - = Ranked (mrerankP (ssxFromSNat sn) (shCvtRX sh2) + = Ranked (mrerankP (ssxFromSNat sn) (shxFromShR sh2) (\a -> let Ranked r = f (Ranked a) in r) arr) @@ -248,12 +248,12 @@ rreplicate :: forall n m a. Elt a => IShR n -> Ranked m a -> Ranked (n + m) a rreplicate sh (Ranked arr) | Refl <- lemReplicatePlusApp (shrRank sh) (Proxy @m) (Proxy @(Nothing @Nat)) - = Ranked (mreplicate (shCvtRX sh) arr) + = Ranked (mreplicate (shxFromShR sh) arr) rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) rreplicateScalP sh x | Dict <- lemKnownReplicate (shrRank sh) - = Ranked (mreplicateScalP (shCvtRX sh) x) + = Ranked (mreplicateScalP (shxFromShR sh) x) rreplicateScal :: forall n a. PrimElt a => IShR n -> a -> Ranked n a @@ -279,7 +279,7 @@ rreshape :: forall n n' a. Elt a rreshape sh' rarr@(Ranked arr) | Dict <- lemKnownReplicate (rrank rarr) , Dict <- lemKnownReplicate (shrRank sh') - = Ranked (mreshape (shCvtRX sh') arr) + = Ranked (mreshape (shxFromShR sh') arr) rflatten :: Elt a => Ranked n a -> Ranked 1 a rflatten (Ranked arr) = mtoRanked (mflatten arr) @@ -291,13 +291,13 @@ riota n = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota rminIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n rminIndexPrim rarr@(Ranked arr) | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr)) - = ixCvtXR (mminIndexPrim arr) + = ixrFromIxX (mminIndexPrim arr) -- | Throws if the array is empty. rmaxIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n rmaxIndexPrim rarr@(Ranked arr) | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr)) - = ixCvtXR (mmaxIndexPrim arr) + = ixrFromIxX (mmaxIndexPrim arr) rdot1Inner :: forall n a. (PrimElt a, NumElt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked n a rdot1Inner arr1 arr2 @@ -311,16 +311,16 @@ rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a rdot = coerce mdot rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a) -rtoXArrayPrimP (Ranked arr) = first shCvtXR' (mtoXArrayPrimP arr) +rtoXArrayPrimP (Ranked arr) = first shrFromShX2 (mtoXArrayPrimP arr) rtoXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a) -rtoXArrayPrim (Ranked arr) = first shCvtXR' (mtoXArrayPrim arr) +rtoXArrayPrim (Ranked arr) = first shrFromShX2 (mtoXArrayPrim arr) rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a) -rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) +rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShX (X.shape (ssxFromSNat sn) arr)) arr) rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a -rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) +rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShX (X.shape (ssxFromSNat sn) arr)) arr) rfromPrimitive :: PrimElt a => Ranked n (Primitive a) -> Ranked n a rfromPrimitive (Ranked arr) = Ranked (fromPrimitive arr) diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs index ce7025d..b7aa00f 100644 --- a/src/Data/Array/Nested/Ranked/Base.hs +++ b/src/Data/Array/Nested/Ranked/Base.hs @@ -137,7 +137,7 @@ instance Elt a => Elt (Ranked n a) where type ShapeTree (Ranked n a) = (IShR n, ShapeTree a) - mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr) + mshapeTree (Ranked arr) = first shrFromShX2 (mshapeTree arr) mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -248,7 +248,7 @@ ratan2Array = liftRanked2 matan2Array rshape :: Elt a => Ranked n a -> IShR n -rshape (Ranked arr) = shCvtXR' (mshape arr) +rshape (Ranked arr) = shrFromShX2 (mshape arr) rrank :: Elt a => Ranked n a -> SNat n rrank = shrRank . rshape diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index c18f9ee..8f54673 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -213,15 +213,15 @@ ixrZero :: SNat n -> IIxR n ixrZero SZ = ZIR ixrZero (SS n) = 0 :.: ixrZero n -ixCvtXR :: IIxX sh -> IIxR (Rank sh) -ixCvtXR ZIX = ZIR -ixCvtXR (n :.% idx) = n :.: ixCvtXR idx +ixrFromIxX :: IIxX sh -> IIxR (Rank sh) +ixrFromIxX ZIX = ZIR +ixrFromIxX (n :.% idx) = n :.: ixrFromIxX idx -ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing) -ixCvtRX ZIR = ZIX -ixCvtRX (n :.: (idx :: IxR m Int)) = +ixxFromIxR :: IIxR n -> IIxX (Replicate n Nothing) +ixxFromIxR ZIR = ZIX +ixxFromIxR (n :.: (idx :: IxR m Int)) = castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m)) - (n :.% ixCvtRX idx) + (n :.% ixxFromIxR idx) ixrHead :: IxR (n + 1) i -> i ixrHead (IxR list) = listrHead list @@ -278,29 +278,21 @@ instance Show i => Show (ShR n i) where instance NFData i => NFData (ShR sh i) -shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n -shCvtXR' ZSX = - castWith (subst2 (unsafeCoerceRefl :: 0 :~: n)) - ZSR -shCvtXR' (n :$% (idx :: IShX sh)) - | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) = - castWith (subst2 (lem1 @sh Refl)) - (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx)) - where - lem1 :: forall sh' n' k. - k : sh' :~: Replicate n' Nothing - -> Rank sh' + 1 :~: n' - lem1 Refl = unsafeCoerceRefl - - lem2 :: k : sh :~: Replicate n Nothing - -> sh :~: Replicate (Rank sh) Nothing - lem2 Refl = unsafeCoerceRefl - -shCvtRX :: IShR n -> IShX (Replicate n Nothing) -shCvtRX ZSR = ZSX -shCvtRX (n :$: (idx :: ShR m Int)) = +shrFromShX :: forall sh. IShX sh -> IShR (Rank sh) +shrFromShX ZSX = ZSR +shrFromShX (n :$% idx) = fromSMayNat' n :$: shrFromShX idx + +-- | Convenience wrapper around 'shrFromShX' that applies 'lemRankReplicate'. +shrFromShX2 :: forall n. IShX (Replicate n Nothing) -> IShR n +shrFromShX2 sh + | Refl <- lemRankReplicate (Proxy @n) + = shrFromShX sh + +shxFromShR :: IShR n -> IShX (Replicate n Nothing) +shxFromShR ZSR = ZSX +shxFromShR (n :$: (idx :: ShR m Int)) = castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m)) - (SUnknown n :$% shCvtRX idx) + (SUnknown n :$% shxFromShR idx) -- | This checks only whether the ranks are equal, not whether the actual -- values are. diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs index 97c7277..c442d6f 100644 --- a/src/Data/Array/Nested/Shaped.hs +++ b/src/Data/Array/Nested/Shaped.hs @@ -44,7 +44,7 @@ import Data.Array.Strided.Arith semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a -semptyArray sh = Shaped (memptyArray (shCvtSX sh)) +semptyArray sh = Shaped (memptyArray (shxFromShS sh)) srank :: Elt a => Shaped sh a -> SNat (Rank sh) srank = shsRank . sshape @@ -54,7 +54,7 @@ ssize :: Elt a => Shaped sh a -> Int ssize = shsSize . sshape sindex :: Elt a => Shaped sh a -> IIxS sh -> a -sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) +sindex (Shaped arr) idx = mindex arr (ixxFromIxS idx) shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh shsTakeIx _ _ ZIS = ZSS @@ -64,26 +64,26 @@ sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> sindexPartial sarr@(Shaped arr) idx = Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr) - (ixCvtSX idx)) + (ixxFromIxS idx)) -- | __WARNING__: All values returned from the function must have equal shape. -- See the documentation of 'mgenerate' for more details. sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a -sgenerate sh f = Shaped (mgenerate (shCvtSX sh) (f . ixCvtXS sh)) +sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX sh)) -- | See the documentation of 'mlift'. slift :: forall sh1 sh2 a. Elt a => ShS sh2 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) -> Shaped sh1 a -> Shaped sh2 a -slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShape (shCvtSX sh2)) f arr) +slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShX (shxFromShS sh2)) f arr) -- | See the documentation of 'mlift'. slift2 :: forall sh1 sh2 sh3 a. Elt a => ShS sh3 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b) -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a -slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShape (shCvtSX sh3)) f arr1 arr2) +slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShX (shxFromShS sh3)) f arr1 arr2) ssumOuter1P :: forall sh n a. (Storable a, NumElt a) => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) @@ -113,7 +113,7 @@ sscalar :: Elt a => a -> Shaped '[] a sscalar x = Shaped (mscalar x) sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a) -sfromVectorP sh v = Shaped (mfromVectorP (shCvtSX sh) v) +sfromVectorP sh v = Shaped (mfromVectorP (shxFromShS sh) v) sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a sfromVector sh v = sfromPrimitive (sfromVectorP sh v) @@ -149,17 +149,17 @@ sfromListPrim sn l sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a sfromListPrimLinear sh l = let M_Primitive _ xarr = toPrimitive (mfromListPrim l) - in Shaped $ fromPrimitive $ M_Primitive (shCvtSX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtSX sh) xarr) + in Shaped $ fromPrimitive $ M_Primitive (shxFromShS sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShS sh) xarr) sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a -sfromListLinear sh l = Shaped (mfromListLinear (shCvtSX sh) l) +sfromListLinear sh l = Shaped (mfromListLinear (shxFromShS sh) l) stoListLinear :: Elt a => Shaped sh a -> [a] stoListLinear (Shaped arr) = mtoListLinear arr sfromOrthotope :: PrimElt a => ShS sh -> SS.Array sh a -> Shaped sh a sfromOrthotope sh (SS.A (SG.A arr)) = - Shaped (fromPrimitive (M_Primitive (shCvtSX sh) (X.XArray (RS.A (RG.A (shsToList sh) arr))))) + Shaped (fromPrimitive (M_Primitive (shxFromShS sh) (X.XArray (RS.A (RG.A (shsToList sh) arr))))) stoOrthotope :: PrimElt a => Shaped sh a -> SS.Array sh a stoOrthotope (stoPrimitive -> Shaped (M_Primitive _ (X.XArray (RS.A (RG.A _ arr))))) = SS.A (SG.A arr) @@ -170,7 +170,7 @@ sunScalar arr = sindex arr ZIS snest :: forall sh sh' a. Elt a => ShS sh -> Shaped (sh ++ sh') a -> Shaped sh (Shaped sh' a) snest sh arr | Refl <- lemMapJustApp sh (Proxy @sh') - = coerce (mnest (ssxFromShape (shCvtSX sh)) (coerce arr)) + = coerce (mnest (ssxFromShX (shxFromShS sh)) (coerce arr)) sunNest :: forall sh sh' a. Elt a => Shaped sh (Shaped sh' a) -> Shaped (sh ++ sh') a sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr))) @@ -190,8 +190,8 @@ srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) srerankP sh sh2 f sarr@(Shaped arr) | Refl <- lemMapJustApp sh (Proxy @sh1) , Refl <- lemMapJustApp sh (Proxy @sh2) - = Shaped (mrerankP (ssxFromShape (shxTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (ssxFromShape (shCvtSX sh)))) - (shCvtSX sh2) + = Shaped (mrerankP (ssxFromShX (shxTakeSSX (Proxy @(MapJust sh1)) (shxFromShS (sshape sarr)) (ssxFromShX (shxFromShS sh)))) + (shxFromShS sh2) (\a -> let Shaped r = f (Shaped a) in r) arr) @@ -205,10 +205,10 @@ srerank sh sh2 f (stoPrimitive -> arr) = sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a sreplicate sh (Shaped arr) | Refl <- lemMapJustApp sh (Proxy @sh') - = Shaped (mreplicate (shCvtSX sh) arr) + = Shaped (mreplicate (shxFromShS sh) arr) sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) -sreplicateScalP sh x = Shaped (mreplicateScalP (shCvtSX sh) x) +sreplicateScalP sh x = Shaped (mreplicateScalP (shxFromShS sh) x) sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x) @@ -222,7 +222,7 @@ srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr sreshape :: (Elt a, Product sh ~ Product sh') => ShS sh' -> Shaped sh a -> Shaped sh' a -sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr) +sreshape sh' (Shaped arr) = Shaped (mreshape (shxFromShS sh') arr) sflatten :: Elt a => Shaped sh a -> Shaped '[Product sh] a sflatten arr = @@ -234,11 +234,11 @@ siota sn = Shaped (miota sn) -- | Throws if the array is empty. sminIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh -sminIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mminIndexPrim arr) +sminIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mminIndexPrim arr) -- | Throws if the array is empty. smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh -smaxIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr) +smaxIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr) sdot1Inner :: forall sh n a. (PrimElt a, NumElt a) => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a @@ -257,16 +257,16 @@ sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a sdot = coerce mdot stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a) -stoXArrayPrimP (Shaped arr) = first shCvtXS' (mtoXArrayPrimP arr) +stoXArrayPrimP (Shaped arr) = first shsFromShX (mtoXArrayPrimP arr) stoXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a) -stoXArrayPrim (Shaped arr) = first shCvtXS' (mtoXArrayPrim arr) +stoXArrayPrim (Shaped arr) = first shsFromShX (mtoXArrayPrim arr) sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a) -sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShape (shCvtSX sh)) arr) +sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShX (shxFromShS sh)) arr) sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a -sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShape (shCvtSX sh)) arr) +sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShX (shxFromShS sh)) arr) sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr) diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs index 8f41455..ea9c24e 100644 --- a/src/Data/Array/Nested/Shaped/Base.hs +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -130,7 +130,7 @@ instance Elt a => Elt (Shaped sh a) where type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) - mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr) + mshapeTree (Shaped arr) = first shsFromShX (mshapeTree arr) mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -241,4 +241,4 @@ satan2Array = liftShaped2 matan2Array sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh -sshape (Shaped arr) = shCvtXS' (mshape arr) +sshape (Shaped arr) = shsFromShX (mshape arr) diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index 59a7d61..8553b56 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -231,13 +231,13 @@ ixsZero :: ShS sh -> IIxS sh ixsZero ZSS = ZIS ixsZero (SNat :$$ sh) = 0 :.$ ixsZero sh -ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh -ixCvtXS ZSS ZIX = ZIS -ixCvtXS (SNat :$$ sh) (n :.% idx) = n :.$ ixCvtXS sh idx +ixsFromIxX :: ShS sh -> IIxX (MapJust sh) -> IIxS sh +ixsFromIxX ZSS ZIX = ZIS +ixsFromIxX (SNat :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx -ixCvtSX :: IIxS sh -> IIxX (MapJust sh) -ixCvtSX ZIS = ZIX -ixCvtSX (n :.$ sh) = n :.% ixCvtSX sh +ixxFromIxS :: IIxS sh -> IIxX (MapJust sh) +ixxFromIxS ZIS = ZIX +ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh ixsHead :: IxS (n : sh) i -> i ixsHead (IxS list) = getConst (listsHead list) @@ -322,22 +322,22 @@ shsToList :: ShS sh -> [Int] shsToList ZSS = [] shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh -shCvtXS' :: forall sh. IShX (MapJust sh) -> ShS sh -shCvtXS' ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS -shCvtXS' (SKnown n@SNat :$% (idx :: IShX mjshT)) = +shsFromShX :: forall sh. IShX (MapJust sh) -> ShS sh +shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS +shsFromShX (SKnown n@SNat :$% (idx :: IShX mjshT)) = castWith (subst1 (lem Refl)) $ - n :$$ shCvtXS' @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) - idx) + n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) + idx) where lem :: forall sh1 sh' n. Just n : sh1 :~: MapJust sh' -> n : Tail sh' :~: sh' lem Refl = unsafeCoerceRefl -shCvtXS' (SUnknown _ :$% _) = error "impossible" +shsFromShX (SUnknown _ :$% _) = error "impossible" -shCvtSX :: ShS sh -> IShX (MapJust sh) -shCvtSX ZSS = ZSX -shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh +shxFromShS :: ShS sh -> IShX (MapJust sh) +shxFromShS ZSS = ZSX +shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh shsHead :: ShS (n : sh) -> SNat n shsHead (ShS list) = listsHead list diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index 7f78420..92e9ffb 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -76,7 +76,7 @@ cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a cast ssh1 sh2 ssh' (XArray arr) | Refl <- lemRankApp ssh1 ssh' - , Refl <- lemRankApp (ssxFromShape sh2) ssh' + , Refl <- lemRankApp (ssxFromShX sh2) ssh' = let arrsh :: IShX sh1 (arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr)) in if shxToList arrsh == shxToList sh2 @@ -89,8 +89,8 @@ unScalar (XArray a) = S.unScalar a replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a replicate sh ssh' (XArray arr) | Dict <- lemKnownNatRankSSX ssh' - , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh) ssh') - , Refl <- lemRankApp (ssxFromShape sh) ssh' + , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShX sh) ssh') + , Refl <- lemRankApp (ssxFromShX sh) ssh' = XArray (S.stretch (shxToList sh ++ S.shapeL arr) $ S.reshape (map (const 1) (shxToList sh) ++ S.shapeL arr) arr) @@ -258,7 +258,7 @@ sumInner ssh ssh' arr | Refl <- lemAppNil @sh = let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) sh'F = shxFlatten sh' :$% ZSX - ssh'F = ssxFromShape sh'F + ssh'F = ssxFromShX sh'F go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a go (XArray arr') @@ -278,8 +278,8 @@ sumOuter ssh ssh' arr | Refl <- lemAppNil @sh = let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) shF = shxFlatten sh :$% ZSX - in sumInner ssh' (ssxFromShape shF) $ - transpose2 (ssxFromShape shF) ssh' $ + in sumInner ssh' (ssxFromShX shF) $ + transpose2 (ssxFromShX shF) ssh' $ reshapePartial ssh ssh' shF $ arr @@ -340,7 +340,7 @@ reshape ssh1 sh2 (XArray arr) reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a reshapePartial ssh1 ssh' sh2 (XArray arr) | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh') - , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh2) ssh') + , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShX sh2) ssh') = XArray (S.reshape (shxToList sh2 ++ drop (ssxLength ssh1) (S.shapeL arr)) arr) -- this was benchmarked to be (slightly) faster than S.iota, S.generate and S.fromVector(VS.enumFromTo). -- cgit v1.2.3-70-g09d2