diff options
Diffstat (limited to 'src/Data/Array')
| -rw-r--r-- | src/Data/Array/Mixed/Lemmas.hs | 12 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Convert.hs | 32 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 56 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 18 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked.hs | 32 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Base.hs | 4 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 46 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped.hs | 44 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Base.hs | 4 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 30 | ||||
| -rw-r--r-- | src/Data/Array/XArray.hs | 14 | 
11 files changed, 150 insertions, 142 deletions
| diff --git a/src/Data/Array/Mixed/Lemmas.hs b/src/Data/Array/Mixed/Lemmas.hs index 7f41a06..cfb7bc6 100644 --- a/src/Data/Array/Mixed/Lemmas.hs +++ b/src/Data/Array/Mixed/Lemmas.hs @@ -64,16 +64,12 @@ lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2          -> c + 1 :~: (a + 1 + b)      lem _ _ _ Refl = Refl -lemRankAppComm :: StaticShX sh1 -> StaticShX sh2 +lemRankAppComm :: proxy sh1 -> proxy sh2                 -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1) -lemRankAppComm _ _ = unsafeCoerceRefl  -- TODO improve this +lemRankAppComm _ _ = unsafeCoerceRefl -lemRankReplicate :: SNat n -> Rank (Replicate n (Nothing @Nat)) :~: n -lemRankReplicate SZ = Refl -lemRankReplicate (SS (n :: SNat nm1)) -  | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 -  , Refl <- lemRankReplicate n -  = Refl +lemRankReplicate :: proxy n -> Rank (Replicate n (Nothing @Nat)) :~: n +lemRankReplicate _ = unsafeCoerceRefl  -- ** Various type families diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index 01abae3..92bc3b4 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -1,6 +1,7 @@  {-# LANGUAGE DataKinds #-}  {-# LANGUAGE GADTs #-}  {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-}  {-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE TypeAbstractions #-}  {-# LANGUAGE TypeApplications #-} @@ -16,6 +17,9 @@ module Data.Array.Nested.Convert (    rtoMixed, rcastToMixed, rcastToShaped,    stoMixed, scastToMixed, stoRanked,    mcast, mcastToShaped, mtoRanked, + +  -- * Additional index/shape casting functions +  ixrFromIxS, shrFromShS,  ) where  import Control.Category @@ -28,6 +32,7 @@ import Data.Array.Nested.Internal.Lemmas  import Data.Array.Nested.Mixed  import Data.Array.Nested.Mixed.Shape  import Data.Array.Nested.Ranked.Base +import Data.Array.Nested.Ranked.Shape  import Data.Array.Nested.Shaped.Base  import Data.Array.Nested.Shaped.Shape @@ -37,7 +42,7 @@ mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a)  mcast ssh2 arr    | Refl <- lemAppNil @sh1    , Refl <- lemAppNil @sh2 -  = mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr +  = mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr  mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a  mtoRanked = castCastable (CastXR CastId) @@ -74,10 +79,25 @@ stoRanked sarr@(Shaped arr)  rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a  rcastToShaped (Ranked arr) targetsh -  | Refl <- lemRankReplicate (shxRank (shCvtSX targetsh)) +  | Refl <- lemRankReplicate (shxRank (shxFromShS targetsh))    , Refl <- lemRankMapJust targetsh    = mcastToShaped targetsh arr +ixrFromIxS :: IIxS sh -> IIxR (Rank sh) +ixrFromIxS ZIS = ZIR +ixrFromIxS (i :.$ ix) = i :.: ixrFromIxS ix + +-- ixsFromIxR :: IIxR (Rank sh) -> IIxS sh +-- ixsFromIxR = \ix -> go ix _ +--   where +--     go :: IIxR n -> (forall sh. KnownShS sh => IIxS sh -> r) -> r +--     go ZIR k = k ZIS +--     go (i :.: ix) k = go ix (i :.$) + +shrFromShS :: ShS sh -> IShR (Rank sh) +shrFromShS ZSS = ZSR +shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh +  -- | The constructors that perform runtime shape checking are marked with a  -- @'@: 'CastXS'' and 'CastXX''. For the other constructors, the types ensure  -- that the shapes are already compatible. To convert between 'Ranked' and @@ -122,20 +142,20 @@ castCastable = \c x -> munScalar (go c (mscalar x))      go (CastXR @_ @_ @sh c) (M_Nest @esh esh x)        | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh)        = let x' = go c x -            ssx' = ssxAppend (ssxFromShape esh) -                             (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x') (ssxFromShape esh)))) +            ssx' = ssxAppend (ssxFromShX esh) +                             (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x') (ssxFromShX esh))))          in M_Ranked (M_Nest esh (mcast ssx' x'))      go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x))      go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x)        | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh') -      = M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh'))) +      = M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh')))                                      (go c x)))      go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))      go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))      go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x)      go (CastXX' @sh @sh' ssx c) (M_Nest @esh esh x)        | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh') -      = M_Nest esh $ mcast (ssxFromShape esh `ssxAppend` ssx) (go c x) +      = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) (go c x)      lemRankAppRankEq :: Rank sh ~ Rank sh'                       => Proxy esh -> Proxy sh -> Proxy sh' diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 679e73b..373e62d 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -402,7 +402,7 @@ instance Storable a => Elt (Primitive a) where    mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)    mfromListOuter l@(arr1 :| _) =      let sh = SUnknown (length l) :$% mshape arr1 -    in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l))) +    in M_Primitive sh (X.fromListOuter (ssxFromShX sh) (map (\(M_Primitive _ a) -> a) (toList l)))    mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr)    mlift :: forall sh1 sh2. @@ -441,16 +441,16 @@ instance Storable a => Elt (Primitive a) where    mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) =      let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1'          sh2 = shxCast' sh1 ssh2 -    in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr) +    in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShX sh') arr)    mtranspose perm (M_Primitive sh arr) =      M_Primitive (shxPermutePrefix perm sh) -                (X.transpose (ssxFromShape sh) perm arr) +                (X.transpose (ssxFromShX sh) perm arr)    mconcat :: forall sh. NonEmpty (Mixed (Nothing : sh) (Primitive a)) -> Mixed (Nothing : sh) (Primitive a)    mconcat l@(M_Primitive (_ :$% sh) _ :| _) = -    let result = X.concat (ssxFromShape sh) (fmap (\(M_Primitive _ arr) -> arr) l) -    in M_Primitive (X.shape (SUnknown () :!% ssxFromShape sh) result) result +    let result = X.concat (ssxFromShX sh) (fmap (\(M_Primitive _ arr) -> arr) l) +    in M_Primitive (X.shape (SUnknown () :!% ssxFromShX sh) result) result    mrnf (M_Primitive sh a) = rnf sh `seq` rnf a @@ -467,7 +467,7 @@ instance Storable a => Elt (Primitive a) where      :: forall sh' sh s.         IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()    mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do -    let arrsh = X.shape (ssxFromShape sh') arr +    let arrsh = X.shape (ssxFromShX sh') arr          offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh))      VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) @@ -554,7 +554,7 @@ instance Elt a => Elt (Mixed sh' a) where    -- moverlongShape method, a prefix of which is mshape.    mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh    mshape (M_Nest sh arr) -    = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr)) +    = fst (shxSplitApp (Proxy @sh') (ssxFromShX sh) (mshape arr))    mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a    mindex (M_Nest _ arr) i = mindexPartial arr i @@ -583,7 +583,7 @@ instance Elt a => Elt (Mixed sh' a) where          (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result)      in M_Nest sh2 result      where -      ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr))) +      ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr)))        f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b        f' sshT @@ -600,7 +600,7 @@ instance Elt a => Elt (Mixed sh' a) where          (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result)      in M_Nest sh3 result      where -      ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1))) +      ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1)))        f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b        f' sshT @@ -618,7 +618,7 @@ instance Elt a => Elt (Mixed sh' a) where          (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape (NE.head result))      in fmap (M_Nest sh2) result      where -      ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1))) +      ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1)))        f' :: forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray ((sh1 ++ sh') ++ shT) b) -> NonEmpty (XArray ((sh2 ++ sh') ++ shT) b)        f' sshT @@ -640,7 +640,7 @@ instance Elt a => Elt (Mixed sh' a) where               -> Mixed (PermutePrefix is sh) (Mixed sh' a)    mtranspose perm (M_Nest sh arr)      | let sh' = shxDropSh @sh @sh' (mshape arr) sh -    , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh') +    , Refl <- lemRankApp (ssxFromShX sh) (ssxFromShX sh')      , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh'))      , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh')      , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') @@ -651,14 +651,14 @@ instance Elt a => Elt (Mixed sh' a) where    mconcat :: NonEmpty (Mixed (Nothing : sh) (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a)    mconcat l@(M_Nest sh1 _ :| _) =      let result = mconcat (fmap (\(M_Nest _ arr) -> arr) l) -    in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape result))) result +    in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape result))) result    mrnf (M_Nest sh arr) = rnf sh `seq` mrnf arr    type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a)    mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) -  mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr))))) +  mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShX (mshape arr)))))    mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -685,7 +685,7 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where    mvecsUnsafeNew sh example      | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) -    | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh'))) +    | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShX sh')))      where        sh' = mshape example @@ -743,14 +743,14 @@ msumOuter1P :: forall sh n a. (Storable a, NumElt a)              => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)  msumOuter1P (M_Primitive (n :$% sh) arr) =    let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX -  in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr) +  in M_Primitive sh (X.sumOuter nssh (ssxFromShX sh) arr)  msumOuter1 :: forall sh n a. (NumElt a, PrimElt a)             => Mixed (n : sh) a -> Mixed sh a  msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive  msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a -msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShape sh) arr +msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr  mappend :: forall n m sh a. Elt a          => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a @@ -758,7 +758,7 @@ mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2    where      sn :$% sh = mshape arr1      sm :$% _ = mshape arr2 -    ssh = ssxFromShape sh +    ssh = ssxFromShX sh      snm :: SMayNat () SNat (AddMaybe n m)      snm = case (sn, sm) of              (SUnknown{}, _) -> SUnknown () @@ -834,7 +834,7 @@ mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)  mrerankP ssh sh2 f (M_Primitive sh arr) =    let sh1 = shxDropSSX sh ssh    in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2) -                 (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2) +                 (X.rerank ssh (ssxFromShX sh1) (ssxFromShX sh2)                             (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r)                             arr) @@ -849,8 +849,8 @@ mrerank ssh sh2 f (toPrimitive -> arr) =  mreplicate :: forall sh sh' a. Elt a             => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a  mreplicate sh arr = -  let ssh' = ssxFromShape (mshape arr) -  in mlift (ssxAppend (ssxFromShape sh) ssh') +  let ssh' = ssxFromShX (mshape arr) +  in mlift (ssxAppend (ssxFromShX sh) ssh')             (\(sshT :: StaticShX shT) ->                case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of                  Refl -> X.replicate sh (ssxAppend ssh' sshT)) @@ -866,18 +866,18 @@ mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x)  mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a  mslice i n arr =    let _ :$% sh = mshape arr -  in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr +  in mlift (SKnown n :!% ssxFromShX sh) (\_ -> X.slice i n) arr  msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a -msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr +msliceU i n arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.sliceU i n) arr  mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a -mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr +mrev1 arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.rev1) arr  mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a  mreshape sh' arr = -  mlift (ssxFromShape sh') -        (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh') +  mlift (ssxFromShX sh') +        (\sshIn -> X.reshapePartial (ssxFromShX (mshape arr)) sshIn sh')          arr  mflatten :: Elt a => Mixed sh a -> Mixed '[Flatten sh] a @@ -889,12 +889,12 @@ miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn)  -- | Throws if the array is empty.  mminIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh  mminIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = -  ixxFromList (ssxFromShape sh) (numEltMinIndex (shxRank sh) (fromO arr)) +  ixxFromList (ssxFromShX sh) (numEltMinIndex (shxRank sh) (fromO arr))  -- | Throws if the array is empty.  mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh  mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = -  ixxFromList (ssxFromShape sh) (numEltMaxIndex (shxRank sh) (fromO arr)) +  ixxFromList (ssxFromShX sh) (numEltMaxIndex (shxRank sh) (fromO arr))  mdot1Inner :: forall sh n a. (PrimElt a, NumElt a)             => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a @@ -904,7 +904,7 @@ mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primi    = case sh1 of        _ :$% _          | sh1 == sh2 -        , Refl <- lemRankApp (ssxInit (ssxFromShape sh1)) (ssxLast (ssxFromShape sh1) :!% ZKX) -> +        , Refl <- lemRankApp (ssxInit (ssxFromShX sh1)) (ssxLast (ssxFromShX sh1) :!% ZKX) ->              fromPrimitive $ M_Primitive (shxInit sh1) (XArray (liftO2 (numEltDotprodInner (shxRank (shxInit sh1))) a b))          | otherwise -> error $ "mdot1Inner: Unequal shapes (" ++ show sh1 ++ " and " ++ show sh2 ++ ")"        ZSX -> error "unreachable" diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index d934873..6800f11 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -408,6 +408,12 @@ shxToList :: IShX sh -> [Int]  shxToList ZSX = []  shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh +-- | This may fail if @sh@ has @Nothing@s in it. +shxFromSSX' :: StaticShX sh -> Maybe (IShX sh) +shxFromSSX' ZKX = Just ZSX +shxFromSSX' (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX' sh +shxFromSSX' (SUnknown _ :!% _) = Nothing +  shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i  shxAppend = coerce (listxAppend @_ @(SMayNat i SNat)) @@ -540,12 +546,6 @@ ssxInit = coerce (listxInit @(SMayNat () SNat))  ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () SNat (Last (n : sh))  ssxLast = coerce (listxLast @(SMayNat () SNat)) --- | This may fail if @sh@ has @Nothing@s in it. -ssxToShX' :: StaticShX sh -> Maybe (IShX sh) -ssxToShX' ZKX = Just ZSX -ssxToShX' (SKnown n :!% sh) = (SKnown n :$%) <$> ssxToShX' sh -ssxToShX' (SUnknown _ :!% _) = Nothing -  ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing)  ssxReplicate SZ = ZKX  ssxReplicate (SS (n :: SNat n')) @@ -556,9 +556,9 @@ ssxIotaFrom :: Int -> StaticShX sh -> [Int]  ssxIotaFrom _ ZKX = []  ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh -ssxFromShape :: IShX sh -> StaticShX sh -ssxFromShape ZSX = ZKX -ssxFromShape (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShape sh +ssxFromShX :: IShX sh -> StaticShX sh +ssxFromShX ZSX = ZKX +ssxFromShX (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShX sh  ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing)  ssxFromSNat SZ = ZKX diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs index 14f2709..74e3893 100644 --- a/src/Data/Array/Nested/Ranked.hs +++ b/src/Data/Array/Nested/Ranked.hs @@ -50,13 +50,13 @@ rsize :: Elt a => Ranked n a -> Int  rsize = shrSize . rshape  rindex :: Elt a => Ranked n a -> IIxR n -> a -rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) +rindex (Ranked arr) idx = mindex arr (ixxFromIxR idx)  rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a  rindexPartial (Ranked arr) idx =    Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)              (castWith (subst2 (lemReplicatePlusApp (ixrRank idx) (Proxy @m) (Proxy @Nothing))) arr) -            (ixCvtRX idx)) +            (ixxFromIxR idx))  -- | __WARNING__: All values returned from the function must have equal shape.  -- See the documentation of 'mgenerate' for more details. @@ -65,7 +65,7 @@ rgenerate sh f    | sn@SNat <- shrRank sh    , Dict <- lemKnownReplicate sn    , Refl <- lemRankReplicate sn -  = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR)) +  = Ranked (mgenerate (shxFromShR sh) (f . ixrFromIxX))  -- | See the documentation of 'mlift'.  rlift :: forall n1 n2 a. Elt a @@ -126,7 +126,7 @@ rscalar x = Ranked (mscalar x)  rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a)  rfromVectorP sh v    | Dict <- lemKnownReplicate (shrRank sh) -  = Ranked (mfromVectorP (shCvtRX sh) v) +  = Ranked (mfromVectorP (shxFromShR sh) v)  rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a  rfromVector sh v = rfromPrimitive (rfromVectorP sh v) @@ -165,7 +165,7 @@ rfromListPrim l =  rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a  rfromListPrimLinear sh l =    let M_Primitive _ xarr = toPrimitive (mfromListPrim l) -  in Ranked $ fromPrimitive $ M_Primitive (shCvtRX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtRX sh) xarr) +  in Ranked $ fromPrimitive $ M_Primitive (shxFromShR sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShR sh) xarr)  rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a  rfromListLinear sh l = rreshape sh (rfromList1 l) @@ -181,7 +181,7 @@ rfromOrthotope sn arr  rtoOrthotope :: PrimElt a => Ranked n a -> S.Array n a  rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr))) -  | Refl <- lemRankReplicate (shrRank $ shCvtXR' sh) +  | Refl <- lemRankReplicate (shrRank $ shrFromShX2 sh)    = arr  runScalar :: Elt a => Ranked 0 a -> a @@ -210,7 +210,7 @@ rrerankP :: forall n1 n2 n a b. (Storable a, Storable b)  rrerankP sn sh2 f (Ranked arr)    | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat))    , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat)) -  = Ranked (mrerankP (ssxFromSNat sn) (shCvtRX sh2) +  = Ranked (mrerankP (ssxFromSNat sn) (shxFromShR sh2)                       (\a -> let Ranked r = f (Ranked a) in r)                       arr) @@ -248,12 +248,12 @@ rreplicate :: forall n m a. Elt a             => IShR n -> Ranked m a -> Ranked (n + m) a  rreplicate sh (Ranked arr)    | Refl <- lemReplicatePlusApp (shrRank sh) (Proxy @m) (Proxy @(Nothing @Nat)) -  = Ranked (mreplicate (shCvtRX sh) arr) +  = Ranked (mreplicate (shxFromShR sh) arr)  rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)  rreplicateScalP sh x    | Dict <- lemKnownReplicate (shrRank sh) -  = Ranked (mreplicateScalP (shCvtRX sh) x) +  = Ranked (mreplicateScalP (shxFromShR sh) x)  rreplicateScal :: forall n a. PrimElt a                 => IShR n -> a -> Ranked n a @@ -279,7 +279,7 @@ rreshape :: forall n n' a. Elt a  rreshape sh' rarr@(Ranked arr)    | Dict <- lemKnownReplicate (rrank rarr)    , Dict <- lemKnownReplicate (shrRank sh') -  = Ranked (mreshape (shCvtRX sh') arr) +  = Ranked (mreshape (shxFromShR sh') arr)  rflatten :: Elt a => Ranked n a -> Ranked 1 a  rflatten (Ranked arr) = mtoRanked (mflatten arr) @@ -291,13 +291,13 @@ riota n = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota  rminIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n  rminIndexPrim rarr@(Ranked arr)    | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr)) -  = ixCvtXR (mminIndexPrim arr) +  = ixrFromIxX (mminIndexPrim arr)  -- | Throws if the array is empty.  rmaxIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n  rmaxIndexPrim rarr@(Ranked arr)    | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr)) -  = ixCvtXR (mmaxIndexPrim arr) +  = ixrFromIxX (mmaxIndexPrim arr)  rdot1Inner :: forall n a. (PrimElt a, NumElt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked n a  rdot1Inner arr1 arr2 @@ -311,16 +311,16 @@ rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a  rdot = coerce mdot  rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a) -rtoXArrayPrimP (Ranked arr) = first shCvtXR' (mtoXArrayPrimP arr) +rtoXArrayPrimP (Ranked arr) = first shrFromShX2 (mtoXArrayPrimP arr)  rtoXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a) -rtoXArrayPrim (Ranked arr) = first shCvtXR' (mtoXArrayPrim arr) +rtoXArrayPrim (Ranked arr) = first shrFromShX2 (mtoXArrayPrim arr)  rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a) -rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) +rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShX (X.shape (ssxFromSNat sn) arr)) arr)  rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a -rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) +rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShX (X.shape (ssxFromSNat sn) arr)) arr)  rfromPrimitive :: PrimElt a => Ranked n (Primitive a) -> Ranked n a  rfromPrimitive (Ranked arr) = Ranked (fromPrimitive arr) diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs index ce7025d..b7aa00f 100644 --- a/src/Data/Array/Nested/Ranked/Base.hs +++ b/src/Data/Array/Nested/Ranked/Base.hs @@ -137,7 +137,7 @@ instance Elt a => Elt (Ranked n a) where    type ShapeTree (Ranked n a) = (IShR n, ShapeTree a) -  mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr) +  mshapeTree (Ranked arr) = first shrFromShX2 (mshapeTree arr)    mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -248,7 +248,7 @@ ratan2Array = liftRanked2 matan2Array  rshape :: Elt a => Ranked n a -> IShR n -rshape (Ranked arr) = shCvtXR' (mshape arr) +rshape (Ranked arr) = shrFromShX2 (mshape arr)  rrank :: Elt a => Ranked n a -> SNat n  rrank = shrRank . rshape diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index c18f9ee..8f54673 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -213,15 +213,15 @@ ixrZero :: SNat n -> IIxR n  ixrZero SZ = ZIR  ixrZero (SS n) = 0 :.: ixrZero n -ixCvtXR :: IIxX sh -> IIxR (Rank sh) -ixCvtXR ZIX = ZIR -ixCvtXR (n :.% idx) = n :.: ixCvtXR idx +ixrFromIxX :: IIxX sh -> IIxR (Rank sh) +ixrFromIxX ZIX = ZIR +ixrFromIxX (n :.% idx) = n :.: ixrFromIxX idx -ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing) -ixCvtRX ZIR = ZIX -ixCvtRX (n :.: (idx :: IxR m Int)) = +ixxFromIxR :: IIxR n -> IIxX (Replicate n Nothing) +ixxFromIxR ZIR = ZIX +ixxFromIxR (n :.: (idx :: IxR m Int)) =    castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m)) -    (n :.% ixCvtRX idx) +    (n :.% ixxFromIxR idx)  ixrHead :: IxR (n + 1) i -> i  ixrHead (IxR list) = listrHead list @@ -278,29 +278,21 @@ instance Show i => Show (ShR n i) where  instance NFData i => NFData (ShR sh i) -shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n -shCvtXR' ZSX = -  castWith (subst2 (unsafeCoerceRefl :: 0 :~: n)) -    ZSR -shCvtXR' (n :$% (idx :: IShX sh)) -  | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) = -  castWith (subst2 (lem1 @sh Refl)) -    (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx)) -  where -    lem1 :: forall sh' n' k. -            k : sh' :~: Replicate n' Nothing -         -> Rank sh' + 1 :~: n' -    lem1 Refl = unsafeCoerceRefl +shrFromShX :: forall sh. IShX sh -> IShR (Rank sh) +shrFromShX ZSX = ZSR +shrFromShX (n :$% idx) = fromSMayNat' n :$: shrFromShX idx -    lem2 :: k : sh :~: Replicate n Nothing -         -> sh :~: Replicate (Rank sh) Nothing -    lem2 Refl = unsafeCoerceRefl +-- | Convenience wrapper around 'shrFromShX' that applies 'lemRankReplicate'. +shrFromShX2 :: forall n. IShX (Replicate n Nothing) -> IShR n +shrFromShX2 sh +  | Refl <- lemRankReplicate (Proxy @n) +  = shrFromShX sh -shCvtRX :: IShR n -> IShX (Replicate n Nothing) -shCvtRX ZSR = ZSX -shCvtRX (n :$: (idx :: ShR m Int)) = +shxFromShR :: IShR n -> IShX (Replicate n Nothing) +shxFromShR ZSR = ZSX +shxFromShR (n :$: (idx :: ShR m Int)) =    castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m)) -    (SUnknown n :$% shCvtRX idx) +    (SUnknown n :$% shxFromShR idx)  -- | This checks only whether the ranks are equal, not whether the actual  -- values are. diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs index 97c7277..c442d6f 100644 --- a/src/Data/Array/Nested/Shaped.hs +++ b/src/Data/Array/Nested/Shaped.hs @@ -44,7 +44,7 @@ import Data.Array.Strided.Arith  semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a -semptyArray sh = Shaped (memptyArray (shCvtSX sh)) +semptyArray sh = Shaped (memptyArray (shxFromShS sh))  srank :: Elt a => Shaped sh a -> SNat (Rank sh)  srank = shsRank . sshape @@ -54,7 +54,7 @@ ssize :: Elt a => Shaped sh a -> Int  ssize = shsSize . sshape  sindex :: Elt a => Shaped sh a -> IIxS sh -> a -sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) +sindex (Shaped arr) idx = mindex arr (ixxFromIxS idx)  shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh  shsTakeIx _ _ ZIS = ZSS @@ -64,26 +64,26 @@ sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 ->  sindexPartial sarr@(Shaped arr) idx =    Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2)              (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr) -            (ixCvtSX idx)) +            (ixxFromIxS idx))  -- | __WARNING__: All values returned from the function must have equal shape.  -- See the documentation of 'mgenerate' for more details.  sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a -sgenerate sh f = Shaped (mgenerate (shCvtSX sh) (f . ixCvtXS sh)) +sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX sh))  -- | See the documentation of 'mlift'.  slift :: forall sh1 sh2 a. Elt a        => ShS sh2        -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)        -> Shaped sh1 a -> Shaped sh2 a -slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShape (shCvtSX sh2)) f arr) +slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShX (shxFromShS sh2)) f arr)  -- | See the documentation of 'mlift'.  slift2 :: forall sh1 sh2 sh3 a. Elt a         => ShS sh3         -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b)         -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a -slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShape (shCvtSX sh3)) f arr1 arr2) +slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShX (shxFromShS sh3)) f arr1 arr2)  ssumOuter1P :: forall sh n a. (Storable a, NumElt a)              => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) @@ -113,7 +113,7 @@ sscalar :: Elt a => a -> Shaped '[] a  sscalar x = Shaped (mscalar x)  sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a) -sfromVectorP sh v = Shaped (mfromVectorP (shCvtSX sh) v) +sfromVectorP sh v = Shaped (mfromVectorP (shxFromShS sh) v)  sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a  sfromVector sh v = sfromPrimitive (sfromVectorP sh v) @@ -149,17 +149,17 @@ sfromListPrim sn l  sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a  sfromListPrimLinear sh l =    let M_Primitive _ xarr = toPrimitive (mfromListPrim l) -  in Shaped $ fromPrimitive $ M_Primitive (shCvtSX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtSX sh) xarr) +  in Shaped $ fromPrimitive $ M_Primitive (shxFromShS sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShS sh) xarr)  sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a -sfromListLinear sh l = Shaped (mfromListLinear (shCvtSX sh) l) +sfromListLinear sh l = Shaped (mfromListLinear (shxFromShS sh) l)  stoListLinear :: Elt a => Shaped sh a -> [a]  stoListLinear (Shaped arr) = mtoListLinear arr  sfromOrthotope :: PrimElt a => ShS sh -> SS.Array sh a -> Shaped sh a  sfromOrthotope sh (SS.A (SG.A arr)) = -  Shaped (fromPrimitive (M_Primitive (shCvtSX sh) (X.XArray (RS.A (RG.A (shsToList sh) arr))))) +  Shaped (fromPrimitive (M_Primitive (shxFromShS sh) (X.XArray (RS.A (RG.A (shsToList sh) arr)))))  stoOrthotope :: PrimElt a => Shaped sh a -> SS.Array sh a  stoOrthotope (stoPrimitive -> Shaped (M_Primitive _ (X.XArray (RS.A (RG.A _ arr))))) = SS.A (SG.A arr) @@ -170,7 +170,7 @@ sunScalar arr = sindex arr ZIS  snest :: forall sh sh' a. Elt a => ShS sh -> Shaped (sh ++ sh') a -> Shaped sh (Shaped sh' a)  snest sh arr    | Refl <- lemMapJustApp sh (Proxy @sh') -  = coerce (mnest (ssxFromShape (shCvtSX sh)) (coerce arr)) +  = coerce (mnest (ssxFromShX (shxFromShS sh)) (coerce arr))  sunNest :: forall sh sh' a. Elt a => Shaped sh (Shaped sh' a) -> Shaped (sh ++ sh') a  sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr))) @@ -190,8 +190,8 @@ srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)  srerankP sh sh2 f sarr@(Shaped arr)    | Refl <- lemMapJustApp sh (Proxy @sh1)    , Refl <- lemMapJustApp sh (Proxy @sh2) -  = Shaped (mrerankP (ssxFromShape (shxTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (ssxFromShape (shCvtSX sh)))) -                     (shCvtSX sh2) +  = Shaped (mrerankP (ssxFromShX (shxTakeSSX (Proxy @(MapJust sh1)) (shxFromShS (sshape sarr)) (ssxFromShX (shxFromShS sh)))) +                     (shxFromShS sh2)                       (\a -> let Shaped r = f (Shaped a) in r)                       arr) @@ -205,10 +205,10 @@ srerank sh sh2 f (stoPrimitive -> arr) =  sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a  sreplicate sh (Shaped arr)    | Refl <- lemMapJustApp sh (Proxy @sh') -  = Shaped (mreplicate (shCvtSX sh) arr) +  = Shaped (mreplicate (shxFromShS sh) arr)  sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) -sreplicateScalP sh x = Shaped (mreplicateScalP (shCvtSX sh) x) +sreplicateScalP sh x = Shaped (mreplicateScalP (shxFromShS sh) x)  sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a  sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x) @@ -222,7 +222,7 @@ srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a  srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr  sreshape :: (Elt a, Product sh ~ Product sh') => ShS sh' -> Shaped sh a -> Shaped sh' a -sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr) +sreshape sh' (Shaped arr) = Shaped (mreshape (shxFromShS sh') arr)  sflatten :: Elt a => Shaped sh a -> Shaped '[Product sh] a  sflatten arr = @@ -234,11 +234,11 @@ siota sn = Shaped (miota sn)  -- | Throws if the array is empty.  sminIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh -sminIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mminIndexPrim arr) +sminIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mminIndexPrim arr)  -- | Throws if the array is empty.  smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh -smaxIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr) +smaxIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr)  sdot1Inner :: forall sh n a. (PrimElt a, NumElt a)             => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a @@ -257,16 +257,16 @@ sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a  sdot = coerce mdot  stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a) -stoXArrayPrimP (Shaped arr) = first shCvtXS' (mtoXArrayPrimP arr) +stoXArrayPrimP (Shaped arr) = first shsFromShX (mtoXArrayPrimP arr)  stoXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a) -stoXArrayPrim (Shaped arr) = first shCvtXS' (mtoXArrayPrim arr) +stoXArrayPrim (Shaped arr) = first shsFromShX (mtoXArrayPrim arr)  sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a) -sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShape (shCvtSX sh)) arr) +sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShX (shxFromShS sh)) arr)  sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a -sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShape (shCvtSX sh)) arr) +sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShX (shxFromShS sh)) arr)  sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a  sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr) diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs index 8f41455..ea9c24e 100644 --- a/src/Data/Array/Nested/Shaped/Base.hs +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -130,7 +130,7 @@ instance Elt a => Elt (Shaped sh a) where    type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) -  mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr) +  mshapeTree (Shaped arr) = first shsFromShX (mshapeTree arr)    mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -241,4 +241,4 @@ satan2Array = liftShaped2 matan2Array  sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh -sshape (Shaped arr) = shCvtXS' (mshape arr) +sshape (Shaped arr) = shsFromShX (mshape arr) diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index 59a7d61..8553b56 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -231,13 +231,13 @@ ixsZero :: ShS sh -> IIxS sh  ixsZero ZSS = ZIS  ixsZero (SNat :$$ sh) = 0 :.$ ixsZero sh -ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh -ixCvtXS ZSS ZIX = ZIS -ixCvtXS (SNat :$$ sh) (n :.% idx) = n :.$ ixCvtXS sh idx +ixsFromIxX :: ShS sh -> IIxX (MapJust sh) -> IIxS sh +ixsFromIxX ZSS ZIX = ZIS +ixsFromIxX (SNat :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx -ixCvtSX :: IIxS sh -> IIxX (MapJust sh) -ixCvtSX ZIS = ZIX -ixCvtSX (n :.$ sh) = n :.% ixCvtSX sh +ixxFromIxS :: IIxS sh -> IIxX (MapJust sh) +ixxFromIxS ZIS = ZIX +ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh  ixsHead :: IxS (n : sh) i -> i  ixsHead (IxS list) = getConst (listsHead list) @@ -322,22 +322,22 @@ shsToList :: ShS sh -> [Int]  shsToList ZSS = []  shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh -shCvtXS' :: forall sh. IShX (MapJust sh) -> ShS sh -shCvtXS' ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS -shCvtXS' (SKnown n@SNat :$% (idx :: IShX mjshT)) = +shsFromShX :: forall sh. IShX (MapJust sh) -> ShS sh +shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS +shsFromShX (SKnown n@SNat :$% (idx :: IShX mjshT)) =    castWith (subst1 (lem Refl)) $ -    n :$$ shCvtXS' @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) -                                 idx) +    n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) +                                   idx)    where      lem :: forall sh1 sh' n.             Just n : sh1 :~: MapJust sh'          -> n : Tail sh' :~: sh'      lem Refl = unsafeCoerceRefl -shCvtXS' (SUnknown _ :$% _) = error "impossible" +shsFromShX (SUnknown _ :$% _) = error "impossible" -shCvtSX :: ShS sh -> IShX (MapJust sh) -shCvtSX ZSS = ZSX -shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh +shxFromShS :: ShS sh -> IShX (MapJust sh) +shxFromShS ZSS = ZSX +shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh  shsHead :: ShS (n : sh) -> SNat n  shsHead (ShS list) = listsHead list diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index 7f78420..92e9ffb 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -76,7 +76,7 @@ cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2       -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a  cast ssh1 sh2 ssh' (XArray arr)    | Refl <- lemRankApp ssh1 ssh' -  , Refl <- lemRankApp (ssxFromShape sh2) ssh' +  , Refl <- lemRankApp (ssxFromShX sh2) ssh'    = let arrsh :: IShX sh1          (arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr))      in if shxToList arrsh == shxToList sh2 @@ -89,8 +89,8 @@ unScalar (XArray a) = S.unScalar a  replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a  replicate sh ssh' (XArray arr)    | Dict <- lemKnownNatRankSSX ssh' -  , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh) ssh') -  , Refl <- lemRankApp (ssxFromShape sh) ssh' +  , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShX sh) ssh') +  , Refl <- lemRankApp (ssxFromShX sh) ssh'    = XArray (S.stretch (shxToList sh ++ S.shapeL arr) $              S.reshape (map (const 1) (shxToList sh) ++ S.shapeL arr)                arr) @@ -258,7 +258,7 @@ sumInner ssh ssh' arr    | Refl <- lemAppNil @sh    = let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)          sh'F = shxFlatten sh' :$% ZSX -        ssh'F = ssxFromShape sh'F +        ssh'F = ssxFromShX sh'F          go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a          go (XArray arr') @@ -278,8 +278,8 @@ sumOuter ssh ssh' arr    | Refl <- lemAppNil @sh    = let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)          shF = shxFlatten sh :$% ZSX -    in sumInner ssh' (ssxFromShape shF) $ -       transpose2 (ssxFromShape shF) ssh' $ +    in sumInner ssh' (ssxFromShX shF) $ +       transpose2 (ssxFromShX shF) ssh' $         reshapePartial ssh ssh' shF $           arr @@ -340,7 +340,7 @@ reshape ssh1 sh2 (XArray arr)  reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a  reshapePartial ssh1 ssh' sh2 (XArray arr)    | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh') -  , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh2) ssh') +  , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShX sh2) ssh')    = XArray (S.reshape (shxToList sh2 ++ drop (ssxLength ssh1) (S.shapeL arr)) arr)  -- this was benchmarked to be (slightly) faster than S.iota, S.generate and S.fromVector(VS.enumFromTo). | 
