aboutsummaryrefslogtreecommitdiff
path: root/src/Data
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-15 19:24:39 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-15 21:21:36 +0200
commitac5c0f1d9f3ba04d1e6647625a7699f463bb3e73 (patch)
tree66c4a81ae66b6bb3d99b771067b8b3d55f6bffc1 /src/Data
parente2c96efd486beeb7f690a468edec4e978c56f994 (diff)
WIP stranspose type
Diffstat (limited to 'src/Data')
-rw-r--r--src/Data/Array/Mixed.hs85
-rw-r--r--src/Data/Array/Nested.hs3
-rw-r--r--src/Data/Array/Nested/Internal.hs31
3 files changed, 106 insertions, 13 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 672b832..7f9076b 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -1,3 +1,4 @@
+{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
@@ -7,6 +8,7 @@
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
@@ -14,6 +16,7 @@
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed where
@@ -22,9 +25,11 @@ import qualified Data.Array.Ranked as ORB
import Data.Coerce
import Data.Kind
import Data.Proxy
+import Data.Type.Bool
import Data.Type.Equality
import qualified Data.Vector.Storable as VS
import Foreign.Storable (Storable)
+import GHC.TypeError
import GHC.TypeLits
import Unsafe.Coerce (unsafeCoerce)
@@ -343,17 +348,93 @@ rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2)
, Dict <- lemKnownNatRankSSX ssh2
, Refl <- lemRankApp ssh ssh1
, Refl <- lemRankApp ssh ssh2
- , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- should be redundant but the solver is not clever enough
= XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2)
(\a b -> unXArray (f (XArray a) (XArray b)))
arr1 arr2)
where
unXArray (XArray a) = a
+type family Elem x l where
+ Elem x '[] = 'False
+ Elem x (x : _) = 'True
+ Elem x (_ : ys) = Elem x ys
+
+type family AllElem' as bs where
+ AllElem' '[] bs = 'True
+ AllElem' (a : as) bs = Elem a bs && AllElem' as bs
+
+type AllElem as bs = Assert (AllElem' as bs)
+ (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs))
+
+type family Count i n where
+ Count n n = '[]
+ Count i n = i : Count (i + 1) n
+
+type Permutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as)
+
+type family Index i sh where
+ Index 0 (n : sh) = n
+ Index i (_ : sh) = Index (i - 1) sh
+
+type family Permute is sh where
+ Permute '[] sh = '[]
+ Permute (i : is) sh = Index i sh : Permute is sh
+
+data HList f list where
+ HNil :: HList f '[]
+ HCons :: f a -> HList f l -> HList f (a : l)
+infixr 5 `HCons`
+
+foldHList :: Monoid m => (forall a. f a -> m) -> HList f list -> m
+foldHList _ HNil = mempty
+foldHList f (x `HCons` l) = f x <> foldHList f l
+
+class KnownNatList l where makeNatList :: HList SNat l
+instance KnownNatList '[] where makeNatList = HNil
+instance (KnownNat n, KnownNatList l) => KnownNatList (n : l) where makeNatList = natSing `HCons` makeNatList
+
+type family TakeLen ref l where
+ TakeLen '[] l = '[]
+ TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs
+
+type family DropLen ref l where
+ DropLen '[] l = l
+ DropLen (_ : ref) (_ : xs) = DropLen ref xs
+
+lemPermuteRank :: Proxy sh -> HList SNat is -> Rank (Permute is sh) :~: Rank is
+lemPermuteRank _ HNil = Refl
+lemPermuteRank p (_ `HCons` is) | Refl <- lemPermuteRank p is = Refl
+
+lemPermuteRank2 :: forall is sh. (Rank is <= Rank sh)
+ => Proxy sh -> HList SNat is -> Rank (DropLen is sh) :~: Rank sh - Rank is
+lemPermuteRank2 _ HNil = Refl
+lemPermuteRank2 p ((_ :: SNat n) `HCons` (is :: HList SNat is')) =
+ let p1 :: Rank (DropLen is' sh) :~: Rank sh - Rank is'
+ p1 = lemPermuteRank2 p is
+ p9 :: Rank (DropLen (n : is') sh) :~: Rank sh - (1 + Rank is')
+ p9 = _
+ in p9
+
-- | The list argument gives indices into the original dimension list.
-transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a
+--
+-- This function does not throw: the constraints ensure that the permutation is always valid.
+transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh, KnownShapeX sh)
+ => HList SNat is
+ -> XArray sh a
+ -> XArray (Permute is (TakeLen is sh) ++ DropLen is sh) a
transpose perm (XArray arr)
| Dict <- lemKnownNatRankSSX (knownShapeX @sh)
+ , Refl <- lemPermuteRank (Proxy @(TakeLen is sh)) perm
+ = let perm' = foldHList (\sn -> [fromSNat' sn]) perm :: [Int]
+ in XArray (S.transpose perm' arr)
+
+-- | The list argument gives indices into the original dimension list.
+--
+-- This version throws a runtime error if the permutation is invalid.
+transposeUntyped :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a
+transposeUntyped perm (XArray arr)
+ | Dict <- lemKnownNatRankSSX (knownShapeX @sh)
= XArray (S.transpose perm arr)
transpose2 :: forall sh1 sh2 a.
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index c12d8ad..f451920 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -40,6 +40,9 @@ module Data.Array.Nested (
-- * Further utilities / re-exports
type (++),
Storable,
+ HList,
+ Permutation,
+ makeNatList,
) where
import Prelude hiding (mappend)
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 627e0d3..b3f8143 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -22,7 +22,6 @@
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-{-# OPTIONS_GHC -Wno-unused-imports #-}
{-|
TODO:
@@ -88,7 +87,7 @@ import Foreign.Storable (Storable)
import GHC.TypeLits
import Unsafe.Coerce (unsafeCoerce)
-import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat, Dict(..))
+import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat, Dict(..), HList(..))
import qualified Data.Array.Mixed as X
@@ -192,6 +191,13 @@ lemRankReplicate _ = go (natSing @n)
, Refl <- go n
= Refl
+lemRankMapJust :: forall sh. KnownShape sh => Proxy sh -> X.Rank (MapJust sh) :~: X.Rank sh
+lemRankMapJust _ = go (knownShape @sh)
+ where
+ go :: forall sh'. ShS sh' -> X.Rank (MapJust sh') :~: X.Rank sh'
+ go ZSS = Refl
+ go (_ :$$ sh') | Refl <- go sh' = Refl
+
lemReplicatePlusApp :: forall n m a. KnownNat n => Proxy n -> Proxy m -> Proxy a
-> Replicate (n + m) a :~: Replicate n a ++ Replicate m a
lemReplicatePlusApp _ _ _ = go (natSing @n)
@@ -577,10 +583,10 @@ mgenerate sh f = case X.enumShape sh of
mvecsWrite sh idx val vecs
mvecsFreeze sh vecs
-mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a
-mtranspose perm =
- mlift (\(Proxy @sh') -> X.rerankTop (knownShapeX @sh) (knownShapeX @sh) (knownShapeX @sh')
- (X.transpose perm))
+mtranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShapeX sh, Elt a) => HList SNat is -> Mixed sh a -> Mixed sh a
+mtranspose perm = mlift $ \(Proxy @sh') ->
+ X.rerankTop (knownShapeX @sh) (knownShapeX @sh) (knownShapeX @sh')
+ (X.transpose perm)
mappend :: forall n m sh a. (KnownShapeX sh, KnownShapeX (n : sh), KnownShapeX (m : sh), KnownShapeX (X.AddMaybe n m : sh), Elt a)
=> Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a
@@ -1088,7 +1094,7 @@ rgenerate sh f
-- | See the documentation of 'mlift'.
rlift :: forall n1 n2 a. (KnownNat n2, Elt a)
- => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
+ => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
-> Ranked n1 a -> Ranked n2 a
rlift f (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n2)
@@ -1111,9 +1117,11 @@ rsumOuter1 :: forall n a. (Storable a, Num a, PrimElt a, KnownNat n)
rsumOuter1 = coerce fromPrimitive . rsumOuter1P @n @a . coerce toPrimitive
rtranspose :: forall n a. (KnownNat n, Elt a) => [Int] -> Ranked n a -> Ranked n a
-rtranspose perm (Ranked arr)
+rtranspose perm
| Dict <- lemKnownReplicate (Proxy @n)
- = Ranked (mtranspose perm arr)
+ = rlift $ \(Proxy @sh') ->
+ X.rerankTop (knownShapeX @(Replicate n Nothing)) (knownShapeX @(Replicate n Nothing)) (knownShapeX @sh')
+ (X.transposeUntyped perm)
rappend :: forall n a. (KnownNat n, Elt a)
=> Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a
@@ -1312,7 +1320,7 @@ sgenerate f
-- | See the documentation of 'mlift'.
slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a)
- => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
+ => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
-> Shaped sh1 a -> Shaped sh2 a
slift f (Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh2)
@@ -1334,9 +1342,10 @@ ssumOuter1 :: forall sh n a.
=> Shaped (n : sh) a -> Shaped sh a
ssumOuter1 = coerce fromPrimitive . ssumOuter1P @sh @n @a . coerce toPrimitive
-stranspose :: forall sh a. (KnownShape sh, Elt a) => [Int] -> Shaped sh a -> Shaped sh a
+stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShape sh, Elt a) => HList SNat is -> Shaped sh a -> Shaped sh a
stranspose perm (Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh)
+ , Refl <- lemRankMapJust (Proxy @sh)
= Shaped (mtranspose perm arr)
sappend :: forall n m sh a. (KnownNat n, KnownNat m, KnownShape sh, Elt a)