diff options
Diffstat (limited to 'src/Data')
| -rw-r--r-- | src/Data/Array/Mixed/Permutation.hs | 8 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Shape.hs | 14 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/XArray.hs | 2 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Ranked.hs | 26 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Shape.hs | 16 | 
5 files changed, 31 insertions, 35 deletions
| diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs index ca99b02..331d5e0 100644 --- a/src/Data/Array/Mixed/Permutation.hs +++ b/src/Data/Array/Mixed/Permutation.hs @@ -45,9 +45,9 @@ infixr 5 `PCons`  deriving instance Show (Perm list)  deriving instance Eq (Perm list) -permLengthSNat :: Perm list -> SNat (Rank list) -permLengthSNat PNil = SNat -permLengthSNat (_ `PCons` l) | SNat <- permLengthSNat l = SNat +permRank :: Perm list -> SNat (Rank list) +permRank PNil = SNat +permRank (_ `PCons` l) | SNat <- permRank l = SNat  permFromList :: [Int] -> (forall list. Perm list -> r) -> r  permFromList [] k = k PNil @@ -67,7 +67,7 @@ permToList' = map fromIntegral . permToList  -- then @Nothing@ is returned.  permCheckPermutation :: forall r list. Perm list -> (IsPermutation list => r) -> Maybe r  permCheckPermutation = \p k -> -  let n = permLengthSNat p +  let n = permRank p    in case (provePerm1 (Proxy @list) n p, provePerm2 (SNat @0) n p) of         (Just Refl, Just Refl) -> Just k         _ -> Nothing diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs index e46105d..95cd4ef 100644 --- a/src/Data/Array/Mixed/Shape.hs +++ b/src/Data/Array/Mixed/Shape.hs @@ -81,9 +81,9 @@ listxFold f (x ::% xs) = f x <> listxFold f xs  listxLength :: ListX sh f -> Int  listxLength = getSum . listxFold (\_ -> Sum 1) -listxLengthSNat :: ListX sh f -> SNat (Rank sh) -listxLengthSNat ZX = SNat -listxLengthSNat (_ ::% l) | SNat <- listxLengthSNat l = SNat +listxRank :: ListX sh f -> SNat (Rank sh) +listxRank ZX = SNat +listxRank (_ ::% l) | SNat <- listxRank l = SNat  listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS  listxShow f l = showString "[" . go "" l . showString "]" @@ -265,8 +265,8 @@ instance NFData i => NFData (ShX sh i) where  shxLength :: ShX sh i -> Int  shxLength (ShX l) = listxLength l -shxLengthSNat :: ShX sh f -> SNat (Rank sh) -shxLengthSNat (ShX list) = listxLengthSNat list +shxRank :: ShX sh f -> SNat (Rank sh) +shxRank (ShX list) = listxRank list  -- | This is more than @geq@: it also checks that the integers (the unknown  -- dimensions) are the same. @@ -344,10 +344,6 @@ shxEnum = \sh -> go sh id []      go ZSX f = (f ZIX :)      go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]] -shxRank :: ShX sh f -> SNat (Rank sh) -shxRank ZSX = SNat -shxRank (_ :$% sh) | SNat <- shxRank sh = SNat -  -- * Static mixed shapes diff --git a/src/Data/Array/Mixed/XArray.hs b/src/Data/Array/Mixed/XArray.hs index 20f5c7a..08295cd 100644 --- a/src/Data/Array/Mixed/XArray.hs +++ b/src/Data/Array/Mixed/XArray.hs @@ -258,7 +258,7 @@ sumInner ssh ssh' arr          go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a          go (XArray arr')            | Refl <- lemRankApp ssh ssh'F -          , let sn = listxLengthSNat (let StaticShX l = ssh in l) +          , let sn = listxRank (let StaticShX l = ssh in l)            = XArray (numEltSum1Inner sn arr')      in go $ diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index 1518791..6a4db8e 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -227,7 +227,7 @@ rshape :: forall n a. Elt a => Ranked n a -> IShR n  rshape (Ranked arr) = shCvtXR' (mshape arr)  rrank :: Elt a => Ranked n a -> SNat n -rrank = shrLengthSNat . rshape +rrank = shrRank . rshape  rindex :: Elt a => Ranked n a -> IIxR n -> a  rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) @@ -235,14 +235,14 @@ rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)  rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a  rindexPartial (Ranked arr) idx =    Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) -            (castWith (subst2 (lemReplicatePlusApp (ixrLengthSNat idx) (Proxy @m) (Proxy @Nothing))) arr) +            (castWith (subst2 (lemReplicatePlusApp (ixrRank idx) (Proxy @m) (Proxy @Nothing))) arr)              (ixCvtRX idx))  -- | __WARNING__: All values returned from the function must have equal shape.  -- See the documentation of 'mgenerate' for more details.  rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a  rgenerate sh f -  | sn@SNat <- shrLengthSNat sh +  | sn@SNat <- shrRank sh    , Dict <- lemKnownReplicate sn    , Refl <- lemRankReplicate sn    = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR)) @@ -274,7 +274,7 @@ rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive  rtranspose :: forall n a. Elt a => PermR -> Ranked n a -> Ranked n a  rtranspose perm arr -  | sn@SNat <- shrLengthSNat (rshape arr) +  | sn@SNat <- rrank arr    , Dict <- lemKnownReplicate sn    , length perm <= fromIntegral (natVal (Proxy @n))    = rlift sn @@ -291,7 +291,7 @@ rconcat  rappend :: forall n a. Elt a          => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a  rappend arr1 arr2 -  | sn@SNat <- shrLengthSNat (rshape arr1) +  | sn@SNat <- rrank arr1    , Dict <- lemKnownReplicate sn    , Refl <- lemReplicateSucc @(Nothing @Nat) @n    = coerce (mappend @Nothing @Nothing @(Replicate n Nothing)) @@ -302,7 +302,7 @@ rscalar x = Ranked (mscalar x)  rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a)  rfromVectorP sh v -  | Dict <- lemKnownReplicate (shrLengthSNat sh) +  | Dict <- lemKnownReplicate (shrRank sh)    = Ranked (mfromVectorP (shCvtRX sh) v)  rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a @@ -352,7 +352,7 @@ rfromOrthotope sn arr  rtoOrthotope :: PrimElt a => Ranked n a -> S.Array n a  rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr))) -  | Refl <- lemRankReplicate (shrLengthSNat $ shCvtXR' sh) +  | Refl <- lemRankReplicate (shrRank $ shCvtXR' sh)    = arr  runScalar :: Elt a => Ranked 0 a -> a @@ -412,12 +412,12 @@ rrerank sn sh2 f (rtoPrimitive -> arr) =  rreplicate :: forall n m a. Elt a             => IShR n -> Ranked m a -> Ranked (n + m) a  rreplicate sh (Ranked arr) -  | Refl <- lemReplicatePlusApp (shrLengthSNat sh) (Proxy @m) (Proxy @(Nothing @Nat)) +  | Refl <- lemReplicatePlusApp (shrRank sh) (Proxy @m) (Proxy @(Nothing @Nat))    = Ranked (mreplicate (shCvtRX sh) arr)  rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)  rreplicateScalP sh x -  | Dict <- lemKnownReplicate (shrLengthSNat sh) +  | Dict <- lemKnownReplicate (shrRank sh)    = Ranked (mreplicateScalP (shCvtRX sh) x)  rreplicateScal :: forall n a. PrimElt a @@ -427,13 +427,13 @@ rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x)  rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a  rslice i n arr    | Refl <- lemReplicateSucc @(Nothing @Nat) @n -  = rlift (shrLengthSNat (rshape arr)) +  = rlift (rrank arr)            (\_ -> X.sliceU i n)            arr  rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a  rrev1 arr = -  rlift (shrLengthSNat (rshape arr)) +  rlift (rrank arr)          (\(_ :: StaticShX sh') ->            case lemReplicateSucc @(Nothing @Nat) @n of              Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')) @@ -442,8 +442,8 @@ rrev1 arr =  rreshape :: forall n n' a. Elt a           => IShR n' -> Ranked n a -> Ranked n' a  rreshape sh' rarr@(Ranked arr) -  | Dict <- lemKnownReplicate (shrLengthSNat (rshape rarr)) -  , Dict <- lemKnownReplicate (shrLengthSNat sh') +  | Dict <- lemKnownReplicate (rrank rarr) +  , Dict <- lemKnownReplicate (shrRank sh')    = Ranked (mreshape (shCvtRX sh') arr)  rflatten :: Elt a => Ranked n a -> Ranked 1 a diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Internal/Shape.hs index 7d95f61..4fa4284 100644 --- a/src/Data/Array/Nested/Internal/Shape.hs +++ b/src/Data/Array/Nested/Internal/Shape.hs @@ -91,14 +91,14 @@ listrIndex SZ (x ::: _) = x  listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs  listrIndex _ ZR = error "k + 1 <= 0" -listrLengthSNat :: ListR n i -> SNat n -listrLengthSNat ZR = SNat -listrLengthSNat (_ ::: (sh :: ListR n i)) = snatSucc (listrLengthSNat sh) +listrRank :: ListR n i -> SNat n +listrRank ZR = SNat +listrRank (_ ::: (sh :: ListR n i)) = snatSucc (listrRank sh)  listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i  listrPermutePrefix = \perm sh ->    listrFromList perm $ \sperm -> -    case (listrLengthSNat sperm, listrLengthSNat sh) of +    case (listrRank sperm, listrRank sh) of        (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of          LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post          EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post @@ -168,8 +168,8 @@ ixrTail (IxR list) = IxR (listrTail list)  ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i  ixrAppend = coerce (listrAppend @_ @i) -ixrLengthSNat :: IxR n i -> SNat n -ixrLengthSNat (IxR sh) = listrLengthSNat sh +ixrRank :: IxR n i -> SNat n +ixrRank (IxR sh) = listrRank sh  ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i  ixrPermutePrefix = coerce (listrPermutePrefix @i) @@ -237,8 +237,8 @@ shrTail (ShR list) = ShR (listrTail list)  shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i  shrAppend = coerce (listrAppend @_ @i) -shrLengthSNat :: ShR n i -> SNat n -shrLengthSNat (ShR sh) = listrLengthSNat sh +shrRank :: ShR n i -> SNat n +shrRank (ShR sh) = listrRank sh  shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i  shrPermutePrefix = coerce (listrPermutePrefix @i) | 
