diff options
Diffstat (limited to 'src/Data/Array')
-rw-r--r-- | src/Data/Array/Nested/Convert.hs | 4 | ||||
-rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 14 | ||||
-rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 51 | ||||
-rw-r--r-- | src/Data/Array/Nested/Shaped.hs | 2 | ||||
-rw-r--r-- | src/Data/Array/XArray.hs | 2 |
5 files changed, 35 insertions, 38 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index b68c8b0..861bf20 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -175,7 +175,7 @@ convert = \c x -> munScalar (go c (mscalar x)) go (ConvXR @_ @sh) (M_Nest @esh esh x) | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh) = let ssx' = ssxAppend (ssxFromShX esh) - (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x) (ssxFromShX esh)))) + (ssxReplicate (shxRank (shxDropSSX @esh @sh (ssxFromShX esh) (mshape x)))) in M_Ranked (M_Nest esh (mcast ssx' x)) go ConvXS (M_Nest esh x) = M_Shaped (M_Nest esh x) go (ConvXS' @sh @sh' sh') (M_Nest @esh esh x) @@ -197,7 +197,7 @@ convert = \c x -> munScalar (go c (mscalar x)) = x go (ConvNest @_ @sh @sh' ssh) (M_Nest @esh esh x) | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh') - = M_Nest esh (M_Nest (shxTakeSSX (Proxy @sh') (mshape x) (ssxFromShX esh `ssxAppend` ssh)) x) + = M_Nest esh (M_Nest (shxTakeSSX (Proxy @sh') (ssxFromShX esh `ssxAppend` ssh) (mshape x)) x) go (ConvUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x)) | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh') = M_Nest esh x diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 652f1c6..a6e94b6 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -396,7 +396,7 @@ class Elt a => KnownElt a where instance Storable a => Elt (Primitive a) where mshape (M_Primitive sh _) = sh mindex (M_Primitive _ a) i = Primitive (X.index a i) - mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i) + mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx i sh) (X.indexPartial a i) mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) mfromListOuter l@(arr1 :| _) = let sh = SUnknown (length l) :$% mshape arr1 @@ -438,7 +438,7 @@ instance Storable a => Elt (Primitive a) where => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) = let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' - sh2 = shxCast' sh1 ssh2 + sh2 = shxCast' ssh2 sh1 in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShX sh') arr) mtranspose perm (M_Primitive sh arr) = @@ -561,7 +561,7 @@ instance Elt a => Elt (Mixed sh' a) where Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) mindexPartial (M_Nest sh arr) i | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) + = M_Nest (shxDropIx i sh) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) mscalar = M_Nest ZSX @@ -630,14 +630,14 @@ instance Elt a => Elt (Mixed sh' a) where | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') = let (sh1, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T - sh2 = shxCast' sh1 ssh2 + sh2 = shxCast' ssh2 sh1 in M_Nest (shxAppend sh2 shT) (mcastPartial ssh1 ssh2 (Proxy @(shT ++ sh')) arr) mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) => Perm is -> Mixed sh (Mixed sh' a) -> Mixed (PermutePrefix is sh) (Mixed sh' a) mtranspose perm (M_Nest sh arr) - | let sh' = shxDropSh @sh @sh' (mshape arr) sh + | let sh' = shxDropSh @sh @sh' sh (mshape arr) , Refl <- lemRankApp (ssxFromShX sh) (ssxFromShX sh') , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh')) , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') @@ -827,8 +827,8 @@ mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b) mrerankP ssh sh2 f (M_Primitive sh arr) = - let sh1 = shxDropSSX sh ssh - in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2) + let sh1 = shxDropSSX ssh sh + in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) ssh sh) sh2) (X.rerank ssh (ssxFromShX sh1) (ssxFromShX sh2) (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) arr) diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index 8e0c274..2ee3600 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -145,9 +145,9 @@ listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f listxAppend ZX idx' = idx' listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx' -listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f -listxDrop long ZX = long -listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short +listxDrop :: forall f g sh sh'. ListX sh g -> ListX (sh ++ sh') f -> ListX sh' f +listxDrop ZX long = long +listxDrop (_ ::% short) long = case long of _ ::% long' -> listxDrop short long' listxInit :: forall f n sh. ListX (n : sh) f -> ListX (Init (n : sh)) f listxInit (i ::% sh@(_ ::% _)) = i ::% listxInit sh @@ -235,7 +235,7 @@ ixxTail (IxX list) = IxX (listxTail list) ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i ixxAppend = coerce (listxAppend @_ @(Const i)) -ixxDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i +ixxDrop :: forall sh sh' i. IxX sh i -> IxX (sh ++ sh') i -> IxX sh' i ixxDrop = coerce (listxDrop @(Const i) @(Const i)) ixxInit :: forall n sh i. IxX (n : sh) i -> IxX (Init (n : sh)) i @@ -424,13 +424,13 @@ shxHead (ShX list) = listxHead list shxTail :: ShX (n : sh) i -> ShX sh i shxTail (ShX list) = ShX (listxTail list) -shxDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i +shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat)) -shxDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i +shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j)) -shxDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i +shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat)) shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i @@ -439,12 +439,9 @@ shxInit = coerce (listxInit @(SMayNat i SNat)) shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh)) shxLast = coerce (listxLast @(SMayNat i SNat)) -shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i -shxTakeSSX _ = flip go - where - go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i - go ZKX _ = ZSX - go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh +shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i +shxTakeSSX _ ZKX _ = ZSX +shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh shxZipWith :: (forall n. SMayNat i SNat n -> SMayNat j SNat n -> SMayNat k SNat n) -> ShX sh i -> ShX sh j -> ShX sh k @@ -468,17 +465,17 @@ shxEnum = \sh -> go sh id [] go ZSX f = (f ZIX :) go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]] -shxCast :: IShX sh -> StaticShX sh' -> Maybe (IShX sh') -shxCast ZSX ZKX = Just ZSX -shxCast (SKnown n :$% sh) (SKnown m :!% ssh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast sh ssh -shxCast (SUnknown n :$% sh) (SKnown m :!% ssh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast sh ssh -shxCast (SKnown n :$% sh) (SUnknown () :!% ssh) = (SUnknown (fromSNat' n) :$%) <$> shxCast sh ssh -shxCast (SUnknown n :$% sh) (SUnknown () :!% ssh) = (SUnknown n :$%) <$> shxCast sh ssh +shxCast :: StaticShX sh' -> IShX sh -> Maybe (IShX sh') +shxCast ZKX ZSX = Just ZSX +shxCast (SKnown m :!% ssh) (SKnown n :$% sh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast ssh sh +shxCast (SKnown m :!% ssh) (SUnknown n :$% sh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast ssh sh +shxCast (SUnknown () :!% ssh) (SKnown n :$% sh) = (SUnknown (fromSNat' n) :$%) <$> shxCast ssh sh +shxCast (SUnknown () :!% ssh) (SUnknown n :$% sh) = (SUnknown n :$%) <$> shxCast ssh sh shxCast _ _ = Nothing -- | Partial version of 'shxCast'. -shxCast' :: IShX sh -> StaticShX sh' -> IShX sh' -shxCast' sh ssh = case shxCast sh ssh of +shxCast' :: StaticShX sh' -> IShX sh -> IShX sh' +shxCast' ssh sh = case shxCast ssh sh of Just sh' -> sh' Nothing -> error $ "shxCast': Mismatch: (" ++ show sh ++ ") does not match (" ++ show ssh ++ ")" @@ -538,13 +535,13 @@ ssxHead (StaticShX list) = listxHead list ssxTail :: StaticShX (n : sh) -> StaticShX sh ssxTail (_ :!% ssh) = ssh -ssxDropSSX :: forall sh sh'. StaticShX (sh ++ sh') -> StaticShX sh -> StaticShX sh' +ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh' ssxDropSSX = coerce (listxDrop @(SMayNat () SNat) @(SMayNat () SNat)) -ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh' +ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i)) -ssxDropSh :: forall sh sh' i. StaticShX (sh ++ sh') -> ShX sh i -> StaticShX sh' +ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' ssxDropSh = coerce (listxDrop @(SMayNat () SNat) @(SMayNat i SNat)) ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh)) @@ -559,9 +556,9 @@ ssxReplicate (SS (n :: SNat n')) | Refl <- lemReplicateSucc @(Nothing @Nat) @n' = SUnknown () :!% ssxReplicate n -ssxIotaFrom :: Int -> StaticShX sh -> [Int] -ssxIotaFrom _ ZKX = [] -ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh +ssxIotaFrom :: StaticShX sh -> Int -> [Int] +ssxIotaFrom ZKX _ = [] +ssxIotaFrom (_ :!% ssh) i = i : ssxIotaFrom ssh (i+1) ssxFromShX :: ShX sh i -> StaticShX sh ssxFromShX ZSX = ZKX diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs index 0275aad..2b0b6b5 100644 --- a/src/Data/Array/Nested/Shaped.hs +++ b/src/Data/Array/Nested/Shaped.hs @@ -186,7 +186,7 @@ srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) srerankP sh sh2 f sarr@(Shaped arr) | Refl <- lemMapJustApp sh (Proxy @sh1) , Refl <- lemMapJustApp sh (Proxy @sh2) - = Shaped (mrerankP (ssxFromShX (shxTakeSSX (Proxy @(MapJust sh1)) (shxFromShS (sshape sarr)) (ssxFromShX (shxFromShS sh)))) + = Shaped (mrerankP (ssxFromShX (shxTakeSSX (Proxy @(MapJust sh1)) (ssxFromShX (shxFromShS sh)) (shxFromShS (sshape sarr)))) (shxFromShS sh2) (\a -> let Shaped r = f (Shaped a) in r) arr) diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index 9776e21..bf47622 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -243,7 +243,7 @@ transpose2 ssh1 ssh2 (XArray arr) , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1) , Refl <- lemRankAppComm ssh1 ssh2 , let n1 = ssxLength ssh1 - = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) + = XArray (S.transpose (ssxIotaFrom ssh2 n1 ++ ssxIotaFrom ssh1 0) arr) sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a sumFull _ (XArray arr) = |