diff options
| author | Tom Smeding <t.j.smeding@uu.nl> | 2025-08-04 18:05:05 +0200 | 
|---|---|---|
| committer | Tom Smeding <t.j.smeding@uu.nl> | 2025-08-04 18:05:05 +0200 | 
| commit | dc66969bc009714486da40254aa3eff3ea57b035 (patch) | |
| tree | 7a671d0703048f2baa3890bba462d99454fe583b /src | |
| parent | 2fae6bf7f6704e3dd9a3f73acbdc84331adb1bf0 (diff) | |
Failed experiment to add replicate/transpose combinationreptrans-failed
Diffstat (limited to 'src')
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 7 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Permutation.hs | 9 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped.hs | 9 | ||||
| -rw-r--r-- | src/Data/Array/XArray.hs | 15 | 
4 files changed, 40 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 144230e..4028b1d 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -5,6 +5,7 @@  {-# LANGUAGE DerivingVia #-}  {-# LANGUAGE FlexibleContexts #-}  {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-}  {-# LANGUAGE ImportQualifiedPost #-}  {-# LANGUAGE InstanceSigs #-}  {-# LANGUAGE RankNTypes #-} @@ -287,6 +288,8 @@ mremArray = mliftNumElt2 (liftO2 . intEltRem)  matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a  matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2) +type MRepTrans = RepTrans (SMayNat Int SNat) +  -- | Allowable element types in a mixed array, and by extension in a 'Ranked' or  -- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive'  -- a@; see the documentation for 'Primitive' for more details. @@ -340,6 +343,8 @@ class Elt a where    mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)               => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a +  mreptransPartial :: Proxy sh' -> MRepTrans sh1 sh2 -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a +    -- | All arrays in the input must have equal shapes, including subarrays    -- inside their elements.    mconcat :: NonEmpty (Mixed (Nothing : sh) a) -> Mixed (Nothing : sh) a @@ -445,6 +450,8 @@ instance Storable a => Elt (Primitive a) where      M_Primitive (shxPermutePrefix perm sh)                  (X.transpose (ssxFromShX sh) perm arr) +  mreptransPartial p descr (M_Primitive sh arr) = _ +    mconcat :: forall sh. NonEmpty (Mixed (Nothing : sh) (Primitive a)) -> Mixed (Nothing : sh) (Primitive a)    mconcat l@(M_Primitive (_ :$% sh) _ :| _) =      let result = X.concat (ssxFromShX sh) (fmap (\(M_Primitive _ arr) -> arr) l) diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs index 03d1640..c893dac 100644 --- a/src/Data/Array/Nested/Permutation.hs +++ b/src/Data/Array/Nested/Permutation.hs @@ -281,3 +281,12 @@ lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0"  lemIndexSucc :: Proxy i -> Proxy a -> Proxy l               -> Index (i + 1) (a : l) :~: Index i l  lemIndexSucc _ _ _ = unsafeCoerceRefl + + +-- * Replication-transpositions + +data RepTrans f sh sh' where +  RTNil :: RepTrans f sh '[] +  RTUse :: SNat i -> RepTrans f sh sh' -> RepTrans f sh (Index i sh : sh') +  RTRep :: f n -> RepTrans f sh sh' -> RepTrans f sh (n : sh') + diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs index 198a068..2c64bb4 100644 --- a/src/Data/Array/Nested/Shaped.hs +++ b/src/Data/Array/Nested/Shaped.hs @@ -1,4 +1,5 @@  {-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-}  {-# LANGUAGE ImportQualifiedPost #-}  {-# LANGUAGE RankNTypes #-}  {-# LANGUAGE ScopedTypeVariables #-} @@ -199,6 +200,14 @@ srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)  srerank sh sh2 f (stoPrimitive -> arr) =    sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr +-- data RepTrans sh sh' where +--   RTNil :: RepTrans sh '[] +--   RTUse :: SNat i -> RepTrans sh sh' -> RepTrans sh (Index i sh : sh') +--   RTRep :: SNat n -> RepTrans sh sh' -> RepTrans sh (n : sh') + +-- sreptrans :: RepTrans sh sh' -> Shaped sh a -> Shaped sh' a +-- sreptrans  +  sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a  sreplicate sh (Shaped arr)    | Refl <- lemMapJustApp sh (Proxy @sh') diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index bf47622..f10e4f0 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -4,6 +4,7 @@  {-# LANGUAGE GADTs #-}  {-# LANGUAGE ImportQualifiedPost #-}  {-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE RankNTypes #-}  {-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE StandaloneKindSignatures #-}  {-# LANGUAGE StrictData #-} @@ -245,6 +246,20 @@ transpose2 ssh1 ssh2 (XArray arr)    , let n1 = ssxLength ssh1    = XArray (S.transpose (ssxIotaFrom ssh2 n1 ++ ssxIotaFrom ssh1 0) arr) +reptransPartial :: forall f sh sh' a. (forall n. f n -> Int) -> RepTrans f sh sh' -> XArray sh a -> XArray sh' a +reptransPartial unNat = \descr (XArray (ORS.A (ORG.A sh (OI.T strides off vec)))) -> +  XArray (ORS.A (ORG.A (computeShape descr sh) (OI.T (computeStrides descr strides) off vec))) +  where +    computeShape :: RepTrans f sh1 sh2 -> S.ShapeL -> S.ShapeL +    computeShape RTNil _ = [] +    computeShape (RTUse idx descr) sh = sh !! fromSNat' idx : computeShape descr sh +    computeShape (RTRep n descr) sh = unNat n : computeShape descr sh + +    computeStrides :: RepTrans f sh1 sh2 -> [Int] -> [Int] +    computeStrides RTNil _ = [] +    computeStrides (RTUse idx descr) str = str !! fromSNat' idx : computeStrides descr str +    computeStrides (RTRep _ descr) str = 0 : computeStrides descr str +  sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a  sumFull _ (XArray arr) =    S.unScalar $  | 
