aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-04-14 13:41:49 +0200
committerTom Smeding <tom@tomsmeding.com>2024-04-14 13:41:49 +0200
commit40dcdf2360c90437fd5c8f76f5f75c96203ef880 (patch)
treea006c444a20a20e7a430ae91ad9553314df35046 /src/Data/Array/Nested/Internal.hs
parentde25bf9ad34d823e9a6f5b0c6c82531586750e89 (diff)
Add append
Diffstat (limited to 'src/Data/Array/Nested/Internal.hs')
-rw-r--r--src/Data/Array/Nested/Internal.hs18
1 files changed, 18 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index f7c383a..209d594 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -15,6 +15,7 @@
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-|
TODO:
@@ -29,6 +30,8 @@ TODO:
module Data.Array.Nested.Internal where
+import Prelude hiding (mappend)
+
import Control.Monad (forM_)
import Control.Monad.ST
import qualified Data.Array.RankedS as S
@@ -354,6 +357,13 @@ mtranspose perm =
mlift (\(Proxy @sh') -> X.rerankTop (knownShapeX @sh) (knownShapeX @sh) (knownShapeX @sh')
(X.transpose perm))
+mappend :: forall n m sh a. (KnownShapeX sh, KnownShapeX (n : sh), KnownShapeX (m : sh), KnownShapeX (X.AddMaybe n m : sh), Elt a)
+ => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a
+mappend = mlift2 go
+ where go :: forall sh' b. (KnownShapeX sh', Storable b)
+ => Proxy sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b
+ go Proxy | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') = X.append
+
mliftPrim :: (KnownShapeX sh, Storable a)
=> (a -> a)
-> Mixed sh (Primitive a) -> Mixed sh (Primitive a)
@@ -693,6 +703,10 @@ rtranspose perm (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
= Ranked (mtranspose perm arr)
+rappend :: forall n a. (KnownINat n, Elt a)
+ => Ranked (S n) a -> Ranked (S n) a -> Ranked (S n) a
+rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend
+
-- ====== API OF SHAPED ARRAYS ====== --
@@ -782,3 +796,7 @@ stranspose :: forall sh a. (KnownShape sh, Elt a) => [Int] -> Shaped sh a -> Sha
stranspose perm (Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh)
= Shaped (mtranspose perm arr)
+
+sappend :: forall n m sh a. (KnownNat n, KnownNat m, KnownShape sh, Elt a)
+ => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a
+sappend | Dict <- lemKnownMapJust (Proxy @sh) = coerce mappend