aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/XArray.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/XArray.hs')
-rw-r--r--src/Data/Array/XArray.hs15
1 files changed, 15 insertions, 0 deletions
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs
index bf47622..f10e4f0 100644
--- a/src/Data/Array/XArray.hs
+++ b/src/Data/Array/XArray.hs
@@ -4,6 +4,7 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NoStarIsType #-}
+{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
@@ -245,6 +246,20 @@ transpose2 ssh1 ssh2 (XArray arr)
, let n1 = ssxLength ssh1
= XArray (S.transpose (ssxIotaFrom ssh2 n1 ++ ssxIotaFrom ssh1 0) arr)
+reptransPartial :: forall f sh sh' a. (forall n. f n -> Int) -> RepTrans f sh sh' -> XArray sh a -> XArray sh' a
+reptransPartial unNat = \descr (XArray (ORS.A (ORG.A sh (OI.T strides off vec)))) ->
+ XArray (ORS.A (ORG.A (computeShape descr sh) (OI.T (computeStrides descr strides) off vec)))
+ where
+ computeShape :: RepTrans f sh1 sh2 -> S.ShapeL -> S.ShapeL
+ computeShape RTNil _ = []
+ computeShape (RTUse idx descr) sh = sh !! fromSNat' idx : computeShape descr sh
+ computeShape (RTRep n descr) sh = unNat n : computeShape descr sh
+
+ computeStrides :: RepTrans f sh1 sh2 -> [Int] -> [Int]
+ computeStrides RTNil _ = []
+ computeStrides (RTUse idx descr) str = str !! fromSNat' idx : computeStrides descr str
+ computeStrides (RTRep _ descr) str = 0 : computeStrides descr str
+
sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a
sumFull _ (XArray arr) =
S.unScalar $