diff options
Diffstat (limited to 'src/Data/Array/XArray.hs')
-rw-r--r-- | src/Data/Array/XArray.hs | 14 |
1 files changed, 7 insertions, 7 deletions
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index 7f78420..92e9ffb 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -76,7 +76,7 @@ cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a cast ssh1 sh2 ssh' (XArray arr) | Refl <- lemRankApp ssh1 ssh' - , Refl <- lemRankApp (ssxFromShape sh2) ssh' + , Refl <- lemRankApp (ssxFromShX sh2) ssh' = let arrsh :: IShX sh1 (arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr)) in if shxToList arrsh == shxToList sh2 @@ -89,8 +89,8 @@ unScalar (XArray a) = S.unScalar a replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a replicate sh ssh' (XArray arr) | Dict <- lemKnownNatRankSSX ssh' - , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh) ssh') - , Refl <- lemRankApp (ssxFromShape sh) ssh' + , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShX sh) ssh') + , Refl <- lemRankApp (ssxFromShX sh) ssh' = XArray (S.stretch (shxToList sh ++ S.shapeL arr) $ S.reshape (map (const 1) (shxToList sh) ++ S.shapeL arr) arr) @@ -258,7 +258,7 @@ sumInner ssh ssh' arr | Refl <- lemAppNil @sh = let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) sh'F = shxFlatten sh' :$% ZSX - ssh'F = ssxFromShape sh'F + ssh'F = ssxFromShX sh'F go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a go (XArray arr') @@ -278,8 +278,8 @@ sumOuter ssh ssh' arr | Refl <- lemAppNil @sh = let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) shF = shxFlatten sh :$% ZSX - in sumInner ssh' (ssxFromShape shF) $ - transpose2 (ssxFromShape shF) ssh' $ + in sumInner ssh' (ssxFromShX shF) $ + transpose2 (ssxFromShX shF) ssh' $ reshapePartial ssh ssh' shF $ arr @@ -340,7 +340,7 @@ reshape ssh1 sh2 (XArray arr) reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a reshapePartial ssh1 ssh' sh2 (XArray arr) | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh') - , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh2) ssh') + , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShX sh2) ssh') = XArray (S.reshape (shxToList sh2 ++ drop (ssxLength ssh1) (S.shapeL arr)) arr) -- this was benchmarked to be (slightly) faster than S.iota, S.generate and S.fromVector(VS.enumFromTo). |