aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed/XArray.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Mixed/XArray.hs')
-rw-r--r--src/Data/Array/Mixed/XArray.hs5
1 files changed, 3 insertions, 2 deletions
diff --git a/src/Data/Array/Mixed/XArray.hs b/src/Data/Array/Mixed/XArray.hs
index 71bdc1f..204c1d8 100644
--- a/src/Data/Array/Mixed/XArray.hs
+++ b/src/Data/Array/Mixed/XArray.hs
@@ -34,6 +34,7 @@ import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Permutation
import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
+import Data.Array.Strided.Arith
type XArray :: [Maybe Nat] -> Type -> Type
@@ -240,7 +241,7 @@ transpose2 ssh1 ssh2 (XArray arr)
sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a
sumFull _ (XArray arr) =
S.unScalar $
- numEltSum1Inner (SNat @0) $
+ liftO1 (numEltSum1Inner (SNat @0)) $
S.fromVector [product (S.shapeL arr)] $
S.toVector arr
@@ -256,7 +257,7 @@ sumInner ssh ssh' arr
go (XArray arr')
| Refl <- lemRankApp ssh ssh'F
, let sn = listxRank (let StaticShX l = ssh in l)
- = XArray (numEltSum1Inner sn arr')
+ = XArray (liftO1 (numEltSum1Inner sn) arr')
in go $
transpose2 ssh'F ssh $