diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-05-15 19:24:39 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-15 21:21:36 +0200 |
commit | ac5c0f1d9f3ba04d1e6647625a7699f463bb3e73 (patch) | |
tree | 66c4a81ae66b6bb3d99b771067b8b3d55f6bffc1 /src/Data/Array/Mixed.hs | |
parent | e2c96efd486beeb7f690a468edec4e978c56f994 (diff) |
WIP stranspose type
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r-- | src/Data/Array/Mixed.hs | 85 |
1 files changed, 83 insertions, 2 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 672b832..7f9076b 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFoldable #-} {-# LANGUAGE DeriveFunctor #-} @@ -7,6 +8,7 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} @@ -14,6 +16,7 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module Data.Array.Mixed where @@ -22,9 +25,11 @@ import qualified Data.Array.Ranked as ORB import Data.Coerce import Data.Kind import Data.Proxy +import Data.Type.Bool import Data.Type.Equality import qualified Data.Vector.Storable as VS import Foreign.Storable (Storable) +import GHC.TypeError import GHC.TypeLits import Unsafe.Coerce (unsafeCoerce) @@ -343,17 +348,93 @@ rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2) , Dict <- lemKnownNatRankSSX ssh2 , Refl <- lemRankApp ssh ssh1 , Refl <- lemRankApp ssh ssh2 - , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the + , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- should be redundant but the solver is not clever enough = XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2) (\a b -> unXArray (f (XArray a) (XArray b))) arr1 arr2) where unXArray (XArray a) = a +type family Elem x l where + Elem x '[] = 'False + Elem x (x : _) = 'True + Elem x (_ : ys) = Elem x ys + +type family AllElem' as bs where + AllElem' '[] bs = 'True + AllElem' (a : as) bs = Elem a bs && AllElem' as bs + +type AllElem as bs = Assert (AllElem' as bs) + (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs)) + +type family Count i n where + Count n n = '[] + Count i n = i : Count (i + 1) n + +type Permutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as) + +type family Index i sh where + Index 0 (n : sh) = n + Index i (_ : sh) = Index (i - 1) sh + +type family Permute is sh where + Permute '[] sh = '[] + Permute (i : is) sh = Index i sh : Permute is sh + +data HList f list where + HNil :: HList f '[] + HCons :: f a -> HList f l -> HList f (a : l) +infixr 5 `HCons` + +foldHList :: Monoid m => (forall a. f a -> m) -> HList f list -> m +foldHList _ HNil = mempty +foldHList f (x `HCons` l) = f x <> foldHList f l + +class KnownNatList l where makeNatList :: HList SNat l +instance KnownNatList '[] where makeNatList = HNil +instance (KnownNat n, KnownNatList l) => KnownNatList (n : l) where makeNatList = natSing `HCons` makeNatList + +type family TakeLen ref l where + TakeLen '[] l = '[] + TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs + +type family DropLen ref l where + DropLen '[] l = l + DropLen (_ : ref) (_ : xs) = DropLen ref xs + +lemPermuteRank :: Proxy sh -> HList SNat is -> Rank (Permute is sh) :~: Rank is +lemPermuteRank _ HNil = Refl +lemPermuteRank p (_ `HCons` is) | Refl <- lemPermuteRank p is = Refl + +lemPermuteRank2 :: forall is sh. (Rank is <= Rank sh) + => Proxy sh -> HList SNat is -> Rank (DropLen is sh) :~: Rank sh - Rank is +lemPermuteRank2 _ HNil = Refl +lemPermuteRank2 p ((_ :: SNat n) `HCons` (is :: HList SNat is')) = + let p1 :: Rank (DropLen is' sh) :~: Rank sh - Rank is' + p1 = lemPermuteRank2 p is + p9 :: Rank (DropLen (n : is') sh) :~: Rank sh - (1 + Rank is') + p9 = _ + in p9 + -- | The list argument gives indices into the original dimension list. -transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a +-- +-- This function does not throw: the constraints ensure that the permutation is always valid. +transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh, KnownShapeX sh) + => HList SNat is + -> XArray sh a + -> XArray (Permute is (TakeLen is sh) ++ DropLen is sh) a transpose perm (XArray arr) | Dict <- lemKnownNatRankSSX (knownShapeX @sh) + , Refl <- lemPermuteRank (Proxy @(TakeLen is sh)) perm + = let perm' = foldHList (\sn -> [fromSNat' sn]) perm :: [Int] + in XArray (S.transpose perm' arr) + +-- | The list argument gives indices into the original dimension list. +-- +-- This version throws a runtime error if the permutation is invalid. +transposeUntyped :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a +transposeUntyped perm (XArray arr) + | Dict <- lemKnownNatRankSSX (knownShapeX @sh) = XArray (S.transpose perm arr) transpose2 :: forall sh1 sh2 a. |