diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 7 | ||||
-rw-r--r-- | src/Data/Array/Nested/Permutation.hs | 9 | ||||
-rw-r--r-- | src/Data/Array/Nested/Shaped.hs | 9 | ||||
-rw-r--r-- | src/Data/Array/XArray.hs | 15 |
4 files changed, 40 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 144230e..4028b1d 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -5,6 +5,7 @@ {-# LANGUAGE DerivingVia #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE RankNTypes #-} @@ -287,6 +288,8 @@ mremArray = mliftNumElt2 (liftO2 . intEltRem) matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2) +type MRepTrans = RepTrans (SMayNat Int SNat) + -- | Allowable element types in a mixed array, and by extension in a 'Ranked' or -- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive' -- a@; see the documentation for 'Primitive' for more details. @@ -340,6 +343,8 @@ class Elt a where mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a + mreptransPartial :: Proxy sh' -> MRepTrans sh1 sh2 -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a + -- | All arrays in the input must have equal shapes, including subarrays -- inside their elements. mconcat :: NonEmpty (Mixed (Nothing : sh) a) -> Mixed (Nothing : sh) a @@ -445,6 +450,8 @@ instance Storable a => Elt (Primitive a) where M_Primitive (shxPermutePrefix perm sh) (X.transpose (ssxFromShX sh) perm arr) + mreptransPartial p descr (M_Primitive sh arr) = _ + mconcat :: forall sh. NonEmpty (Mixed (Nothing : sh) (Primitive a)) -> Mixed (Nothing : sh) (Primitive a) mconcat l@(M_Primitive (_ :$% sh) _ :| _) = let result = X.concat (ssxFromShX sh) (fmap (\(M_Primitive _ arr) -> arr) l) diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs index 03d1640..c893dac 100644 --- a/src/Data/Array/Nested/Permutation.hs +++ b/src/Data/Array/Nested/Permutation.hs @@ -281,3 +281,12 @@ lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0" lemIndexSucc :: Proxy i -> Proxy a -> Proxy l -> Index (i + 1) (a : l) :~: Index i l lemIndexSucc _ _ _ = unsafeCoerceRefl + + +-- * Replication-transpositions + +data RepTrans f sh sh' where + RTNil :: RepTrans f sh '[] + RTUse :: SNat i -> RepTrans f sh sh' -> RepTrans f sh (Index i sh : sh') + RTRep :: f n -> RepTrans f sh sh' -> RepTrans f sh (n : sh') + diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs index 198a068..2c64bb4 100644 --- a/src/Data/Array/Nested/Shaped.hs +++ b/src/Data/Array/Nested/Shaped.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -199,6 +200,14 @@ srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) srerank sh sh2 f (stoPrimitive -> arr) = sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr +-- data RepTrans sh sh' where +-- RTNil :: RepTrans sh '[] +-- RTUse :: SNat i -> RepTrans sh sh' -> RepTrans sh (Index i sh : sh') +-- RTRep :: SNat n -> RepTrans sh sh' -> RepTrans sh (n : sh') + +-- sreptrans :: RepTrans sh sh' -> Shaped sh a -> Shaped sh' a +-- sreptrans + sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a sreplicate sh (Shaped arr) | Refl <- lemMapJustApp sh (Proxy @sh') diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index bf47622..f10e4f0 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -4,6 +4,7 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE StrictData #-} @@ -245,6 +246,20 @@ transpose2 ssh1 ssh2 (XArray arr) , let n1 = ssxLength ssh1 = XArray (S.transpose (ssxIotaFrom ssh2 n1 ++ ssxIotaFrom ssh1 0) arr) +reptransPartial :: forall f sh sh' a. (forall n. f n -> Int) -> RepTrans f sh sh' -> XArray sh a -> XArray sh' a +reptransPartial unNat = \descr (XArray (ORS.A (ORG.A sh (OI.T strides off vec)))) -> + XArray (ORS.A (ORG.A (computeShape descr sh) (OI.T (computeStrides descr strides) off vec))) + where + computeShape :: RepTrans f sh1 sh2 -> S.ShapeL -> S.ShapeL + computeShape RTNil _ = [] + computeShape (RTUse idx descr) sh = sh !! fromSNat' idx : computeShape descr sh + computeShape (RTRep n descr) sh = unNat n : computeShape descr sh + + computeStrides :: RepTrans f sh1 sh2 -> [Int] -> [Int] + computeStrides RTNil _ = [] + computeStrides (RTUse idx descr) str = str !! fromSNat' idx : computeStrides descr str + computeStrides (RTRep _ descr) str = 0 : computeStrides descr str + sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a sumFull _ (XArray arr) = S.unScalar $ |