diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2025-08-04 18:05:05 +0200 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2025-08-04 18:05:05 +0200 |
commit | dc66969bc009714486da40254aa3eff3ea57b035 (patch) | |
tree | 7a671d0703048f2baa3890bba462d99454fe583b /src/Data/Array/XArray.hs | |
parent | 2fae6bf7f6704e3dd9a3f73acbdc84331adb1bf0 (diff) |
Failed experiment to add replicate/transpose combinationreptrans-failed
Diffstat (limited to 'src/Data/Array/XArray.hs')
-rw-r--r-- | src/Data/Array/XArray.hs | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index bf47622..f10e4f0 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -4,6 +4,7 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE StrictData #-} @@ -245,6 +246,20 @@ transpose2 ssh1 ssh2 (XArray arr) , let n1 = ssxLength ssh1 = XArray (S.transpose (ssxIotaFrom ssh2 n1 ++ ssxIotaFrom ssh1 0) arr) +reptransPartial :: forall f sh sh' a. (forall n. f n -> Int) -> RepTrans f sh sh' -> XArray sh a -> XArray sh' a +reptransPartial unNat = \descr (XArray (ORS.A (ORG.A sh (OI.T strides off vec)))) -> + XArray (ORS.A (ORG.A (computeShape descr sh) (OI.T (computeStrides descr strides) off vec))) + where + computeShape :: RepTrans f sh1 sh2 -> S.ShapeL -> S.ShapeL + computeShape RTNil _ = [] + computeShape (RTUse idx descr) sh = sh !! fromSNat' idx : computeShape descr sh + computeShape (RTRep n descr) sh = unNat n : computeShape descr sh + + computeStrides :: RepTrans f sh1 sh2 -> [Int] -> [Int] + computeStrides RTNil _ = [] + computeStrides (RTUse idx descr) str = str !! fromSNat' idx : computeStrides descr str + computeStrides (RTRep _ descr) str = 0 : computeStrides descr str + sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a sumFull _ (XArray arr) = S.unScalar $ |