diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-06-10 23:28:11 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-06-10 23:28:11 +0200 | 
| commit | 5f6a64660b16d8f188caca5216e55debf4264611 (patch) | |
| tree | 7e378c5929126db6c583862220e4163de7b2b3df /src/Data | |
| parent | 87484b9adcbaa6b380ed3ba1a499bd227a8863a8 (diff) | |
Add *flatten
Diffstat (limited to 'src/Data')
| -rw-r--r-- | src/Data/Array/Nested.hs | 6 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 3 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Ranked.hs | 3 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Shape.hs | 10 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Shaped.hs | 5 | 
5 files changed, 24 insertions, 3 deletions
| diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 51f9fc0..b5c0772 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -10,7 +10,7 @@ module Data.Array.Nested (    rtranspose, rappend, rconcat, rscalar, rfromVector, rtoVector, runScalar,    rrerank,    rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1, -  rslice, rrev1, rreshape, riota, +  rslice, rrev1, rreshape, rflatten, riota,    rminIndexPrim, rmaxIndexPrim, rdot,    rnest, runNest,    -- ** Lifting orthotope operations to 'Ranked' arrays @@ -30,7 +30,7 @@ module Data.Array.Nested (    -- TODO: sconcat? What should its type be?    srerank,    sreplicate, sreplicateScal, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1, -  sslice, srev1, sreshape, siota, +  sslice, srev1, sreshape, sflatten, siota,    sminIndexPrim, smaxIndexPrim, sdot,    snest, sunNest,    -- ** Lifting orthotope operations to 'Shaped' arrays @@ -47,7 +47,7 @@ module Data.Array.Nested (    mtranspose, mappend, mconcat, mscalar, mfromVector, mtoVector, munScalar,    mrerank,    mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1, -  mslice, mrev1, mreshape, miota, +  mslice, mrev1, mreshape, mflatten, miota,    mminIndexPrim, mmaxIndexPrim, mdot,    mnest, munNest,    -- ** Lifting orthotope operations to 'Mixed' arrays diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index 9c2096d..69df52a 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -796,6 +796,9 @@ mreshape sh' arr =          (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh')          arr +mflatten :: Elt a => Mixed sh a -> Mixed '[Flatten sh] a +mflatten arr = mreshape (shxFlatten (mshape arr) :$% ZSX) arr +  miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a  miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn) diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index 3e9f528..59c1820 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -446,6 +446,9 @@ rreshape sh' rarr@(Ranked arr)    , Dict <- lemKnownReplicate (shrLengthSNat sh')    = Ranked (mreshape (shCvtRX sh') arr) +rflatten :: Elt a => Ranked n a -> Ranked 1 a +rflatten (Ranked arr) = mtoRanked (mflatten arr) +  riota :: (Enum a, PrimElt a, Elt a) => Int -> Ranked 1 a  riota n = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Internal/Shape.hs index 9d718cc..7d95f61 100644 --- a/src/Data/Array/Nested/Internal/Shape.hs +++ b/src/Data/Array/Nested/Internal/Shape.hs @@ -6,6 +6,7 @@  {-# LANGUAGE GADTs #-}  {-# LANGUAGE GeneralizedNewtypeDeriving #-}  {-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NoStarIsType #-}  {-# LANGUAGE PatternSynonyms #-}  {-# LANGUAGE PolyKinds #-}  {-# LANGUAGE QuantifiedConstraints #-} @@ -18,6 +19,7 @@  {-# LANGUAGE TypeApplications #-}  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-}  {-# LANGUAGE ViewPatterns #-}  {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}  {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} @@ -475,6 +477,14 @@ shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh)))  shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh)  shsPermutePrefix = coerce (listsPermutePrefix @SNat) +type family Product sh where +  Product '[] = 1 +  Product (n : ns) = n * Product ns + +shsProduct :: ShS sh -> SNat (Product sh) +shsProduct ZSS = SNat +shsProduct (n :$$ sh) = n `snatMul` shsProduct sh +  -- | Evidence for the static part of a shape. This pops up only when you are  -- polymorphic in the element type of an array.  type KnownShS :: [Nat] -> Constraint diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index 863e604..1855015 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -370,6 +370,11 @@ srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr  sreshape :: Elt a => ShS sh' -> Shaped sh a -> Shaped sh' a  sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr) +sflatten :: Elt a => Shaped sh a -> Shaped '[Product sh] a +sflatten arr = +  case shsProduct (sshape arr) of  -- TODO: simplify when removing the KnownNat stuff +    n@SNat -> sreshape (n :$$ ZSS) arr +  siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a  siota sn = Shaped (miota sn) | 
