aboutsummaryrefslogtreecommitdiff
path: root/src/Data
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data')
-rw-r--r--src/Data/Array/Nested/Mixed.hs7
-rw-r--r--src/Data/Array/Nested/Permutation.hs9
-rw-r--r--src/Data/Array/Nested/Shaped.hs9
-rw-r--r--src/Data/Array/XArray.hs15
4 files changed, 40 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index 144230e..4028b1d 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -5,6 +5,7 @@
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE RankNTypes #-}
@@ -287,6 +288,8 @@ mremArray = mliftNumElt2 (liftO2 . intEltRem)
matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a
matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2)
+type MRepTrans = RepTrans (SMayNat Int SNat)
+
-- | Allowable element types in a mixed array, and by extension in a 'Ranked' or
-- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive'
-- a@; see the documentation for 'Primitive' for more details.
@@ -340,6 +343,8 @@ class Elt a where
mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
=> Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a
+ mreptransPartial :: Proxy sh' -> MRepTrans sh1 sh2 -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a
+
-- | All arrays in the input must have equal shapes, including subarrays
-- inside their elements.
mconcat :: NonEmpty (Mixed (Nothing : sh) a) -> Mixed (Nothing : sh) a
@@ -445,6 +450,8 @@ instance Storable a => Elt (Primitive a) where
M_Primitive (shxPermutePrefix perm sh)
(X.transpose (ssxFromShX sh) perm arr)
+ mreptransPartial p descr (M_Primitive sh arr) = _
+
mconcat :: forall sh. NonEmpty (Mixed (Nothing : sh) (Primitive a)) -> Mixed (Nothing : sh) (Primitive a)
mconcat l@(M_Primitive (_ :$% sh) _ :| _) =
let result = X.concat (ssxFromShX sh) (fmap (\(M_Primitive _ arr) -> arr) l)
diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs
index 03d1640..c893dac 100644
--- a/src/Data/Array/Nested/Permutation.hs
+++ b/src/Data/Array/Nested/Permutation.hs
@@ -281,3 +281,12 @@ lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0"
lemIndexSucc :: Proxy i -> Proxy a -> Proxy l
-> Index (i + 1) (a : l) :~: Index i l
lemIndexSucc _ _ _ = unsafeCoerceRefl
+
+
+-- * Replication-transpositions
+
+data RepTrans f sh sh' where
+ RTNil :: RepTrans f sh '[]
+ RTUse :: SNat i -> RepTrans f sh sh' -> RepTrans f sh (Index i sh : sh')
+ RTRep :: f n -> RepTrans f sh sh' -> RepTrans f sh (n : sh')
+
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
index 198a068..2c64bb4 100644
--- a/src/Data/Array/Nested/Shaped.hs
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -1,4 +1,5 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
@@ -199,6 +200,14 @@ srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
srerank sh sh2 f (stoPrimitive -> arr) =
sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr
+-- data RepTrans sh sh' where
+-- RTNil :: RepTrans sh '[]
+-- RTUse :: SNat i -> RepTrans sh sh' -> RepTrans sh (Index i sh : sh')
+-- RTRep :: SNat n -> RepTrans sh sh' -> RepTrans sh (n : sh')
+
+-- sreptrans :: RepTrans sh sh' -> Shaped sh a -> Shaped sh' a
+-- sreptrans
+
sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a
sreplicate sh (Shaped arr)
| Refl <- lemMapJustApp sh (Proxy @sh')
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs
index bf47622..f10e4f0 100644
--- a/src/Data/Array/XArray.hs
+++ b/src/Data/Array/XArray.hs
@@ -4,6 +4,7 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NoStarIsType #-}
+{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
@@ -245,6 +246,20 @@ transpose2 ssh1 ssh2 (XArray arr)
, let n1 = ssxLength ssh1
= XArray (S.transpose (ssxIotaFrom ssh2 n1 ++ ssxIotaFrom ssh1 0) arr)
+reptransPartial :: forall f sh sh' a. (forall n. f n -> Int) -> RepTrans f sh sh' -> XArray sh a -> XArray sh' a
+reptransPartial unNat = \descr (XArray (ORS.A (ORG.A sh (OI.T strides off vec)))) ->
+ XArray (ORS.A (ORG.A (computeShape descr sh) (OI.T (computeStrides descr strides) off vec)))
+ where
+ computeShape :: RepTrans f sh1 sh2 -> S.ShapeL -> S.ShapeL
+ computeShape RTNil _ = []
+ computeShape (RTUse idx descr) sh = sh !! fromSNat' idx : computeShape descr sh
+ computeShape (RTRep n descr) sh = unNat n : computeShape descr sh
+
+ computeStrides :: RepTrans f sh1 sh2 -> [Int] -> [Int]
+ computeStrides RTNil _ = []
+ computeStrides (RTUse idx descr) str = str !! fromSNat' idx : computeStrides descr str
+ computeStrides (RTRep _ descr) str = 0 : computeStrides descr str
+
sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a
sumFull _ (XArray arr) =
S.unScalar $