diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-06-29 12:36:03 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-06-29 12:40:02 +0200 | 
| commit | 260e00c3d661c21de5986ccf01d3292d3b8f7633 (patch) | |
| tree | 4e29bed0adb2724ac48e61abf5376ba613ff1c2a /src/Data/Array/Nested | |
| parent | 64404591661d3bc239804a1c17a25f81c434d852 (diff) | |
Flip some index/shape-related functions
This ensures that the argument order consistently puts the main thing
being operated on at the end, and supporting singletons at the start.
Diffstat (limited to 'src/Data/Array/Nested')
| -rw-r--r-- | src/Data/Array/Nested/Convert.hs | 4 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 14 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 51 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped.hs | 2 | 
4 files changed, 34 insertions, 37 deletions
| diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index b68c8b0..861bf20 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -175,7 +175,7 @@ convert = \c x -> munScalar (go c (mscalar x))      go (ConvXR @_ @sh) (M_Nest @esh esh x)        | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh)        = let ssx' = ssxAppend (ssxFromShX esh) -                             (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x) (ssxFromShX esh)))) +                             (ssxReplicate (shxRank (shxDropSSX @esh @sh (ssxFromShX esh) (mshape x))))          in M_Ranked (M_Nest esh (mcast ssx' x))      go ConvXS (M_Nest esh x) = M_Shaped (M_Nest esh x)      go (ConvXS' @sh @sh' sh') (M_Nest @esh esh x) @@ -197,7 +197,7 @@ convert = \c x -> munScalar (go c (mscalar x))        = x      go (ConvNest @_ @sh @sh' ssh) (M_Nest @esh esh x)        | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh') -      = M_Nest esh (M_Nest (shxTakeSSX (Proxy @sh') (mshape x) (ssxFromShX esh `ssxAppend` ssh)) x) +      = M_Nest esh (M_Nest (shxTakeSSX (Proxy @sh') (ssxFromShX esh `ssxAppend` ssh) (mshape x)) x)      go (ConvUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x))        | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh')        = M_Nest esh x diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 652f1c6..a6e94b6 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -396,7 +396,7 @@ class Elt a => KnownElt a where  instance Storable a => Elt (Primitive a) where    mshape (M_Primitive sh _) = sh    mindex (M_Primitive _ a) i = Primitive (X.index a i) -  mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i) +  mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx i sh) (X.indexPartial a i)    mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)    mfromListOuter l@(arr1 :| _) =      let sh = SUnknown (length l) :$% mshape arr1 @@ -438,7 +438,7 @@ instance Storable a => Elt (Primitive a) where                 => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a)    mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) =      let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' -        sh2 = shxCast' sh1 ssh2 +        sh2 = shxCast' ssh2 sh1      in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShX sh') arr)    mtranspose perm (M_Primitive sh arr) = @@ -561,7 +561,7 @@ instance Elt a => Elt (Mixed sh' a) where                     Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)    mindexPartial (M_Nest sh arr) i      | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') -    = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) +    = M_Nest (shxDropIx i sh) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)    mscalar = M_Nest ZSX @@ -630,14 +630,14 @@ instance Elt a => Elt (Mixed sh' a) where      | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh')      , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh')      = let (sh1, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T -          sh2 = shxCast' sh1 ssh2 +          sh2 = shxCast' ssh2 sh1        in M_Nest (shxAppend sh2 shT) (mcastPartial ssh1 ssh2 (Proxy @(shT ++ sh')) arr)    mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)               => Perm is -> Mixed sh (Mixed sh' a)               -> Mixed (PermutePrefix is sh) (Mixed sh' a)    mtranspose perm (M_Nest sh arr) -    | let sh' = shxDropSh @sh @sh' (mshape arr) sh +    | let sh' = shxDropSh @sh @sh' sh (mshape arr)      , Refl <- lemRankApp (ssxFromShX sh) (ssxFromShX sh')      , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh'))      , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') @@ -827,8 +827,8 @@ mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)           -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b))           -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b)  mrerankP ssh sh2 f (M_Primitive sh arr) = -  let sh1 = shxDropSSX sh ssh -  in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2) +  let sh1 = shxDropSSX ssh sh +  in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) ssh sh) sh2)                   (X.rerank ssh (ssxFromShX sh1) (ssxFromShX sh2)                             (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r)                             arr) diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index 8e0c274..2ee3600 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -145,9 +145,9 @@ listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f  listxAppend ZX idx' = idx'  listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx' -listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f -listxDrop long ZX = long -listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short +listxDrop :: forall f g sh sh'. ListX sh g -> ListX (sh ++ sh') f -> ListX sh' f +listxDrop ZX long = long +listxDrop (_ ::% short) long = case long of _ ::% long' -> listxDrop short long'  listxInit :: forall f n sh. ListX (n : sh) f -> ListX (Init (n : sh)) f  listxInit (i ::% sh@(_ ::% _)) = i ::% listxInit sh @@ -235,7 +235,7 @@ ixxTail (IxX list) = IxX (listxTail list)  ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i  ixxAppend = coerce (listxAppend @_ @(Const i)) -ixxDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i +ixxDrop :: forall sh sh' i. IxX sh i -> IxX (sh ++ sh') i -> IxX sh' i  ixxDrop = coerce (listxDrop @(Const i) @(Const i))  ixxInit :: forall n sh i. IxX (n : sh) i -> IxX (Init (n : sh)) i @@ -424,13 +424,13 @@ shxHead (ShX list) = listxHead list  shxTail :: ShX (n : sh) i -> ShX sh i  shxTail (ShX list) = ShX (listxTail list) -shxDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i +shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i  shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat)) -shxDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i +shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i  shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j)) -shxDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i +shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i  shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat))  shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i @@ -439,12 +439,9 @@ shxInit = coerce (listxInit @(SMayNat i SNat))  shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh))  shxLast = coerce (listxLast @(SMayNat i SNat)) -shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i -shxTakeSSX _ = flip go -  where -    go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i -    go ZKX _ = ZSX -    go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh +shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i +shxTakeSSX _ ZKX _ = ZSX +shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh  shxZipWith :: (forall n. SMayNat i SNat n -> SMayNat j SNat n -> SMayNat k SNat n)             -> ShX sh i -> ShX sh j -> ShX sh k @@ -468,17 +465,17 @@ shxEnum = \sh -> go sh id []      go ZSX f = (f ZIX :)      go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]] -shxCast :: IShX sh -> StaticShX sh' -> Maybe (IShX sh') -shxCast ZSX ZKX = Just ZSX -shxCast (SKnown n   :$% sh) (SKnown m    :!% ssh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast sh ssh -shxCast (SUnknown n :$% sh) (SKnown m    :!% ssh) | n == fromSNat' m              = (SKnown m :$%) <$> shxCast sh ssh -shxCast (SKnown n   :$% sh) (SUnknown () :!% ssh)                                 = (SUnknown (fromSNat' n) :$%) <$> shxCast sh ssh -shxCast (SUnknown n :$% sh) (SUnknown () :!% ssh)                                 = (SUnknown n :$%) <$> shxCast sh ssh +shxCast :: StaticShX sh' -> IShX sh -> Maybe (IShX sh') +shxCast ZKX ZSX = Just ZSX +shxCast (SKnown m    :!% ssh) (SKnown n   :$% sh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast ssh sh +shxCast (SKnown m    :!% ssh) (SUnknown n :$% sh) | n == fromSNat' m              = (SKnown m :$%) <$> shxCast ssh sh +shxCast (SUnknown () :!% ssh) (SKnown n   :$% sh)                                 = (SUnknown (fromSNat' n) :$%) <$> shxCast ssh sh +shxCast (SUnknown () :!% ssh) (SUnknown n :$% sh)                                 = (SUnknown n :$%) <$> shxCast ssh sh  shxCast _ _ = Nothing  -- | Partial version of 'shxCast'. -shxCast' :: IShX sh -> StaticShX sh' -> IShX sh' -shxCast' sh ssh = case shxCast sh ssh of +shxCast' :: StaticShX sh' -> IShX sh -> IShX sh' +shxCast' ssh sh = case shxCast ssh sh of    Just sh' -> sh'    Nothing -> error $ "shxCast': Mismatch: (" ++ show sh ++ ") does not match (" ++ show ssh ++ ")" @@ -538,13 +535,13 @@ ssxHead (StaticShX list) = listxHead list  ssxTail :: StaticShX (n : sh) -> StaticShX sh  ssxTail (_ :!% ssh) = ssh -ssxDropSSX :: forall sh sh'. StaticShX (sh ++ sh') -> StaticShX sh -> StaticShX sh' +ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh'  ssxDropSSX = coerce (listxDrop @(SMayNat () SNat) @(SMayNat () SNat)) -ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh' +ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh'  ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i)) -ssxDropSh :: forall sh sh' i. StaticShX (sh ++ sh') -> ShX sh i -> StaticShX sh' +ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh'  ssxDropSh = coerce (listxDrop @(SMayNat () SNat) @(SMayNat i SNat))  ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh)) @@ -559,9 +556,9 @@ ssxReplicate (SS (n :: SNat n'))    | Refl <- lemReplicateSucc @(Nothing @Nat) @n'    = SUnknown () :!% ssxReplicate n -ssxIotaFrom :: Int -> StaticShX sh -> [Int] -ssxIotaFrom _ ZKX = [] -ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh +ssxIotaFrom :: StaticShX sh -> Int -> [Int] +ssxIotaFrom ZKX _ = [] +ssxIotaFrom (_ :!% ssh) i = i : ssxIotaFrom ssh (i+1)  ssxFromShX :: ShX sh i -> StaticShX sh  ssxFromShX ZSX = ZKX diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs index 0275aad..2b0b6b5 100644 --- a/src/Data/Array/Nested/Shaped.hs +++ b/src/Data/Array/Nested/Shaped.hs @@ -186,7 +186,7 @@ srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)  srerankP sh sh2 f sarr@(Shaped arr)    | Refl <- lemMapJustApp sh (Proxy @sh1)    , Refl <- lemMapJustApp sh (Proxy @sh2) -  = Shaped (mrerankP (ssxFromShX (shxTakeSSX (Proxy @(MapJust sh1)) (shxFromShS (sshape sarr)) (ssxFromShX (shxFromShS sh)))) +  = Shaped (mrerankP (ssxFromShX (shxTakeSSX (Proxy @(MapJust sh1)) (ssxFromShX (shxFromShS sh)) (shxFromShS (sshape sarr))))                       (shxFromShS sh2)                       (\a -> let Shaped r = f (Shaped a) in r)                       arr) | 
