From dc66969bc009714486da40254aa3eff3ea57b035 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 4 Aug 2025 18:05:05 +0200 Subject: Failed experiment to add replicate/transpose combination --- src/Data/Array/XArray.hs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) (limited to 'src/Data/Array/XArray.hs') diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index bf47622..f10e4f0 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -4,6 +4,7 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE StrictData #-} @@ -245,6 +246,20 @@ transpose2 ssh1 ssh2 (XArray arr) , let n1 = ssxLength ssh1 = XArray (S.transpose (ssxIotaFrom ssh2 n1 ++ ssxIotaFrom ssh1 0) arr) +reptransPartial :: forall f sh sh' a. (forall n. f n -> Int) -> RepTrans f sh sh' -> XArray sh a -> XArray sh' a +reptransPartial unNat = \descr (XArray (ORS.A (ORG.A sh (OI.T strides off vec)))) -> + XArray (ORS.A (ORG.A (computeShape descr sh) (OI.T (computeStrides descr strides) off vec))) + where + computeShape :: RepTrans f sh1 sh2 -> S.ShapeL -> S.ShapeL + computeShape RTNil _ = [] + computeShape (RTUse idx descr) sh = sh !! fromSNat' idx : computeShape descr sh + computeShape (RTRep n descr) sh = unNat n : computeShape descr sh + + computeStrides :: RepTrans f sh1 sh2 -> [Int] -> [Int] + computeStrides RTNil _ = [] + computeStrides (RTUse idx descr) str = str !! fromSNat' idx : computeStrides descr str + computeStrides (RTRep _ descr) str = 0 : computeStrides descr str + sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a sumFull _ (XArray arr) = S.unScalar $ -- cgit v1.2.3-70-g09d2