aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-06-10 23:28:11 +0200
committerTom Smeding <tom@tomsmeding.com>2024-06-10 23:28:11 +0200
commit5f6a64660b16d8f188caca5216e55debf4264611 (patch)
tree7e378c5929126db6c583862220e4163de7b2b3df /src
parent87484b9adcbaa6b380ed3ba1a499bd227a8863a8 (diff)
Add *flatten
Diffstat (limited to 'src')
-rw-r--r--src/Data/Array/Nested.hs6
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs3
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs3
-rw-r--r--src/Data/Array/Nested/Internal/Shape.hs10
-rw-r--r--src/Data/Array/Nested/Internal/Shaped.hs5
5 files changed, 24 insertions, 3 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 51f9fc0..b5c0772 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -10,7 +10,7 @@ module Data.Array.Nested (
rtranspose, rappend, rconcat, rscalar, rfromVector, rtoVector, runScalar,
rrerank,
rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1,
- rslice, rrev1, rreshape, riota,
+ rslice, rrev1, rreshape, rflatten, riota,
rminIndexPrim, rmaxIndexPrim, rdot,
rnest, runNest,
-- ** Lifting orthotope operations to 'Ranked' arrays
@@ -30,7 +30,7 @@ module Data.Array.Nested (
-- TODO: sconcat? What should its type be?
srerank,
sreplicate, sreplicateScal, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1,
- sslice, srev1, sreshape, siota,
+ sslice, srev1, sreshape, sflatten, siota,
sminIndexPrim, smaxIndexPrim, sdot,
snest, sunNest,
-- ** Lifting orthotope operations to 'Shaped' arrays
@@ -47,7 +47,7 @@ module Data.Array.Nested (
mtranspose, mappend, mconcat, mscalar, mfromVector, mtoVector, munScalar,
mrerank,
mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1,
- mslice, mrev1, mreshape, miota,
+ mslice, mrev1, mreshape, mflatten, miota,
mminIndexPrim, mmaxIndexPrim, mdot,
mnest, munNest,
-- ** Lifting orthotope operations to 'Mixed' arrays
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index 9c2096d..69df52a 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -796,6 +796,9 @@ mreshape sh' arr =
(\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh')
arr
+mflatten :: Elt a => Mixed sh a -> Mixed '[Flatten sh] a
+mflatten arr = mreshape (shxFlatten (mshape arr) :$% ZSX) arr
+
miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a
miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn)
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index 3e9f528..59c1820 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -446,6 +446,9 @@ rreshape sh' rarr@(Ranked arr)
, Dict <- lemKnownReplicate (shrLengthSNat sh')
= Ranked (mreshape (shCvtRX sh') arr)
+rflatten :: Elt a => Ranked n a -> Ranked 1 a
+rflatten (Ranked arr) = mtoRanked (mflatten arr)
+
riota :: (Enum a, PrimElt a, Elt a) => Int -> Ranked 1 a
riota n = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota
diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Internal/Shape.hs
index 9d718cc..7d95f61 100644
--- a/src/Data/Array/Nested/Internal/Shape.hs
+++ b/src/Data/Array/Nested/Internal/Shape.hs
@@ -6,6 +6,7 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
@@ -18,6 +19,7 @@
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
@@ -475,6 +477,14 @@ shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh)))
shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh)
shsPermutePrefix = coerce (listsPermutePrefix @SNat)
+type family Product sh where
+ Product '[] = 1
+ Product (n : ns) = n * Product ns
+
+shsProduct :: ShS sh -> SNat (Product sh)
+shsProduct ZSS = SNat
+shsProduct (n :$$ sh) = n `snatMul` shsProduct sh
+
-- | Evidence for the static part of a shape. This pops up only when you are
-- polymorphic in the element type of an array.
type KnownShS :: [Nat] -> Constraint
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
index 863e604..1855015 100644
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -370,6 +370,11 @@ srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr
sreshape :: Elt a => ShS sh' -> Shaped sh a -> Shaped sh' a
sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr)
+sflatten :: Elt a => Shaped sh a -> Shaped '[Product sh] a
+sflatten arr =
+ case shsProduct (sshape arr) of -- TODO: simplify when removing the KnownNat stuff
+ n@SNat -> sreshape (n :$$ ZSS) arr
+
siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a
siota sn = Shaped (miota sn)