diff options
Diffstat (limited to 'src/Data/Array/XArray.hs')
-rw-r--r-- | src/Data/Array/XArray.hs | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index bf47622..f10e4f0 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -4,6 +4,7 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE StrictData #-} @@ -245,6 +246,20 @@ transpose2 ssh1 ssh2 (XArray arr) , let n1 = ssxLength ssh1 = XArray (S.transpose (ssxIotaFrom ssh2 n1 ++ ssxIotaFrom ssh1 0) arr) +reptransPartial :: forall f sh sh' a. (forall n. f n -> Int) -> RepTrans f sh sh' -> XArray sh a -> XArray sh' a +reptransPartial unNat = \descr (XArray (ORS.A (ORG.A sh (OI.T strides off vec)))) -> + XArray (ORS.A (ORG.A (computeShape descr sh) (OI.T (computeStrides descr strides) off vec))) + where + computeShape :: RepTrans f sh1 sh2 -> S.ShapeL -> S.ShapeL + computeShape RTNil _ = [] + computeShape (RTUse idx descr) sh = sh !! fromSNat' idx : computeShape descr sh + computeShape (RTRep n descr) sh = unNat n : computeShape descr sh + + computeStrides :: RepTrans f sh1 sh2 -> [Int] -> [Int] + computeStrides RTNil _ = [] + computeStrides (RTUse idx descr) str = str !! fromSNat' idx : computeStrides descr str + computeStrides (RTRep _ descr) str = 0 : computeStrides descr str + sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a sumFull _ (XArray arr) = S.unScalar $ |