aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/XArray.hs
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2025-08-04 18:05:05 +0200
committerTom Smeding <t.j.smeding@uu.nl>2025-08-04 18:05:05 +0200
commitdc66969bc009714486da40254aa3eff3ea57b035 (patch)
tree7a671d0703048f2baa3890bba462d99454fe583b /src/Data/Array/XArray.hs
parent2fae6bf7f6704e3dd9a3f73acbdc84331adb1bf0 (diff)
Failed experiment to add replicate/transpose combinationreptrans-failed
Diffstat (limited to 'src/Data/Array/XArray.hs')
-rw-r--r--src/Data/Array/XArray.hs15
1 files changed, 15 insertions, 0 deletions
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs
index bf47622..f10e4f0 100644
--- a/src/Data/Array/XArray.hs
+++ b/src/Data/Array/XArray.hs
@@ -4,6 +4,7 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NoStarIsType #-}
+{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
@@ -245,6 +246,20 @@ transpose2 ssh1 ssh2 (XArray arr)
, let n1 = ssxLength ssh1
= XArray (S.transpose (ssxIotaFrom ssh2 n1 ++ ssxIotaFrom ssh1 0) arr)
+reptransPartial :: forall f sh sh' a. (forall n. f n -> Int) -> RepTrans f sh sh' -> XArray sh a -> XArray sh' a
+reptransPartial unNat = \descr (XArray (ORS.A (ORG.A sh (OI.T strides off vec)))) ->
+ XArray (ORS.A (ORG.A (computeShape descr sh) (OI.T (computeStrides descr strides) off vec)))
+ where
+ computeShape :: RepTrans f sh1 sh2 -> S.ShapeL -> S.ShapeL
+ computeShape RTNil _ = []
+ computeShape (RTUse idx descr) sh = sh !! fromSNat' idx : computeShape descr sh
+ computeShape (RTRep n descr) sh = unNat n : computeShape descr sh
+
+ computeStrides :: RepTrans f sh1 sh2 -> [Int] -> [Int]
+ computeStrides RTNil _ = []
+ computeStrides (RTUse idx descr) str = str !! fromSNat' idx : computeStrides descr str
+ computeStrides (RTRep _ descr) str = 0 : computeStrides descr str
+
sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a
sumFull _ (XArray arr) =
S.unScalar $