diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-05-15 19:24:39 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-15 21:21:36 +0200 |
commit | ac5c0f1d9f3ba04d1e6647625a7699f463bb3e73 (patch) | |
tree | 66c4a81ae66b6bb3d99b771067b8b3d55f6bffc1 /src/Data/Array/Nested | |
parent | e2c96efd486beeb7f690a468edec4e978c56f994 (diff) |
WIP stranspose type
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 31 |
1 files changed, 20 insertions, 11 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 627e0d3..b3f8143 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -22,7 +22,6 @@ {-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -{-# OPTIONS_GHC -Wno-unused-imports #-} {-| TODO: @@ -88,7 +87,7 @@ import Foreign.Storable (Storable) import GHC.TypeLits import Unsafe.Coerce (unsafeCoerce) -import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat, Dict(..)) +import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat, Dict(..), HList(..)) import qualified Data.Array.Mixed as X @@ -192,6 +191,13 @@ lemRankReplicate _ = go (natSing @n) , Refl <- go n = Refl +lemRankMapJust :: forall sh. KnownShape sh => Proxy sh -> X.Rank (MapJust sh) :~: X.Rank sh +lemRankMapJust _ = go (knownShape @sh) + where + go :: forall sh'. ShS sh' -> X.Rank (MapJust sh') :~: X.Rank sh' + go ZSS = Refl + go (_ :$$ sh') | Refl <- go sh' = Refl + lemReplicatePlusApp :: forall n m a. KnownNat n => Proxy n -> Proxy m -> Proxy a -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a lemReplicatePlusApp _ _ _ = go (natSing @n) @@ -577,10 +583,10 @@ mgenerate sh f = case X.enumShape sh of mvecsWrite sh idx val vecs mvecsFreeze sh vecs -mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a -mtranspose perm = - mlift (\(Proxy @sh') -> X.rerankTop (knownShapeX @sh) (knownShapeX @sh) (knownShapeX @sh') - (X.transpose perm)) +mtranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShapeX sh, Elt a) => HList SNat is -> Mixed sh a -> Mixed sh a +mtranspose perm = mlift $ \(Proxy @sh') -> + X.rerankTop (knownShapeX @sh) (knownShapeX @sh) (knownShapeX @sh') + (X.transpose perm) mappend :: forall n m sh a. (KnownShapeX sh, KnownShapeX (n : sh), KnownShapeX (m : sh), KnownShapeX (X.AddMaybe n m : sh), Elt a) => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a @@ -1088,7 +1094,7 @@ rgenerate sh f -- | See the documentation of 'mlift'. rlift :: forall n1 n2 a. (KnownNat n2, Elt a) - => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) + => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) -> Ranked n1 a -> Ranked n2 a rlift f (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n2) @@ -1111,9 +1117,11 @@ rsumOuter1 :: forall n a. (Storable a, Num a, PrimElt a, KnownNat n) rsumOuter1 = coerce fromPrimitive . rsumOuter1P @n @a . coerce toPrimitive rtranspose :: forall n a. (KnownNat n, Elt a) => [Int] -> Ranked n a -> Ranked n a -rtranspose perm (Ranked arr) +rtranspose perm | Dict <- lemKnownReplicate (Proxy @n) - = Ranked (mtranspose perm arr) + = rlift $ \(Proxy @sh') -> + X.rerankTop (knownShapeX @(Replicate n Nothing)) (knownShapeX @(Replicate n Nothing)) (knownShapeX @sh') + (X.transposeUntyped perm) rappend :: forall n a. (KnownNat n, Elt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a @@ -1312,7 +1320,7 @@ sgenerate f -- | See the documentation of 'mlift'. slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a) - => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) + => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) -> Shaped sh1 a -> Shaped sh2 a slift f (Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh2) @@ -1334,9 +1342,10 @@ ssumOuter1 :: forall sh n a. => Shaped (n : sh) a -> Shaped sh a ssumOuter1 = coerce fromPrimitive . ssumOuter1P @sh @n @a . coerce toPrimitive -stranspose :: forall sh a. (KnownShape sh, Elt a) => [Int] -> Shaped sh a -> Shaped sh a +stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShape sh, Elt a) => HList SNat is -> Shaped sh a -> Shaped sh a stranspose perm (Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) + , Refl <- lemRankMapJust (Proxy @sh) = Shaped (mtranspose perm arr) sappend :: forall n m sh a. (KnownNat n, KnownNat m, KnownShape sh, Elt a) |