aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/XArray.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/XArray.hs')
-rw-r--r--src/Data/Array/XArray.hs14
1 files changed, 7 insertions, 7 deletions
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs
index 7f78420..92e9ffb 100644
--- a/src/Data/Array/XArray.hs
+++ b/src/Data/Array/XArray.hs
@@ -76,7 +76,7 @@ cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2
-> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
cast ssh1 sh2 ssh' (XArray arr)
| Refl <- lemRankApp ssh1 ssh'
- , Refl <- lemRankApp (ssxFromShape sh2) ssh'
+ , Refl <- lemRankApp (ssxFromShX sh2) ssh'
= let arrsh :: IShX sh1
(arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr))
in if shxToList arrsh == shxToList sh2
@@ -89,8 +89,8 @@ unScalar (XArray a) = S.unScalar a
replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a
replicate sh ssh' (XArray arr)
| Dict <- lemKnownNatRankSSX ssh'
- , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh) ssh')
- , Refl <- lemRankApp (ssxFromShape sh) ssh'
+ , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShX sh) ssh')
+ , Refl <- lemRankApp (ssxFromShX sh) ssh'
= XArray (S.stretch (shxToList sh ++ S.shapeL arr) $
S.reshape (map (const 1) (shxToList sh) ++ S.shapeL arr)
arr)
@@ -258,7 +258,7 @@ sumInner ssh ssh' arr
| Refl <- lemAppNil @sh
= let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
sh'F = shxFlatten sh' :$% ZSX
- ssh'F = ssxFromShape sh'F
+ ssh'F = ssxFromShX sh'F
go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a
go (XArray arr')
@@ -278,8 +278,8 @@ sumOuter ssh ssh' arr
| Refl <- lemAppNil @sh
= let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
shF = shxFlatten sh :$% ZSX
- in sumInner ssh' (ssxFromShape shF) $
- transpose2 (ssxFromShape shF) ssh' $
+ in sumInner ssh' (ssxFromShX shF) $
+ transpose2 (ssxFromShX shF) ssh' $
reshapePartial ssh ssh' shF $
arr
@@ -340,7 +340,7 @@ reshape ssh1 sh2 (XArray arr)
reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
reshapePartial ssh1 ssh' sh2 (XArray arr)
| Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh')
- , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh2) ssh')
+ , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShX sh2) ssh')
= XArray (S.reshape (shxToList sh2 ++ drop (ssxLength ssh1) (S.shapeL arr)) arr)
-- this was benchmarked to be (slightly) faster than S.iota, S.generate and S.fromVector(VS.enumFromTo).