diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/Data/Array/Mixed.hs | 85 | ||||
| -rw-r--r-- | src/Data/Array/Nested.hs | 3 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 31 | 
3 files changed, 106 insertions, 13 deletions
| diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 672b832..7f9076b 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE ConstraintKinds #-}  {-# LANGUAGE DataKinds #-}  {-# LANGUAGE DeriveFoldable #-}  {-# LANGUAGE DeriveFunctor #-} @@ -7,6 +8,7 @@  {-# LANGUAGE GeneralizedNewtypeDeriving #-}  {-# LANGUAGE PolyKinds #-}  {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RankNTypes #-}  {-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE StandaloneDeriving #-}  {-# LANGUAGE StandaloneKindSignatures #-} @@ -14,6 +16,7 @@  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-}  {-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}  {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}  module Data.Array.Mixed where @@ -22,9 +25,11 @@ import qualified Data.Array.Ranked as ORB  import Data.Coerce  import Data.Kind  import Data.Proxy +import Data.Type.Bool  import Data.Type.Equality  import qualified Data.Vector.Storable as VS  import Foreign.Storable (Storable) +import GHC.TypeError  import GHC.TypeLits  import Unsafe.Coerce (unsafeCoerce) @@ -343,17 +348,93 @@ rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2)    , Dict <- lemKnownNatRankSSX ssh2    , Refl <- lemRankApp ssh ssh1    , Refl <- lemRankApp ssh ssh2 -  , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)      -- these two should be redundant but the +  , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)  -- should be redundant but the solver is not clever enough    = XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2)                  (\a b -> unXArray (f (XArray a) (XArray b)))                  arr1 arr2)    where      unXArray (XArray a) = a +type family Elem x l where +  Elem x '[] = 'False +  Elem x (x : _) = 'True +  Elem x (_ : ys) = Elem x ys + +type family AllElem' as bs where +  AllElem' '[] bs = 'True +  AllElem' (a : as) bs = Elem a bs && AllElem' as bs + +type AllElem as bs = Assert (AllElem' as bs) +  (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs)) + +type family Count i n where +  Count n n = '[] +  Count i n = i : Count (i + 1) n + +type Permutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as) + +type family Index i sh where +  Index 0 (n : sh) = n +  Index i (_ : sh) = Index (i - 1) sh + +type family Permute is sh where +  Permute '[] sh = '[] +  Permute (i : is) sh = Index i sh : Permute is sh + +data HList f list where +  HNil :: HList f '[] +  HCons :: f a -> HList f l -> HList f (a : l) +infixr 5 `HCons` + +foldHList :: Monoid m => (forall a. f a -> m) -> HList f list -> m +foldHList _ HNil = mempty +foldHList f (x `HCons` l) = f x <> foldHList f l + +class KnownNatList l where makeNatList :: HList SNat l +instance KnownNatList '[] where makeNatList = HNil +instance (KnownNat n, KnownNatList l) => KnownNatList (n : l) where makeNatList = natSing `HCons` makeNatList + +type family TakeLen ref l where +  TakeLen '[] l = '[] +  TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs + +type family DropLen ref l where +  DropLen '[] l = l +  DropLen (_ : ref) (_ : xs) = DropLen ref xs + +lemPermuteRank :: Proxy sh -> HList SNat is -> Rank (Permute is sh) :~: Rank is +lemPermuteRank _ HNil = Refl +lemPermuteRank p (_ `HCons` is) | Refl <- lemPermuteRank p is = Refl + +lemPermuteRank2 :: forall is sh. (Rank is <= Rank sh) +                => Proxy sh -> HList SNat is -> Rank (DropLen is sh) :~: Rank sh - Rank is +lemPermuteRank2 _ HNil = Refl +lemPermuteRank2 p ((_ :: SNat n) `HCons` (is :: HList SNat is')) = +  let p1 :: Rank (DropLen is' sh) :~: Rank sh - Rank is' +      p1 = lemPermuteRank2 p is +      p9 :: Rank (DropLen (n : is') sh) :~: Rank sh - (1 + Rank is') +      p9 = _ +  in p9 +  -- | The list argument gives indices into the original dimension list. -transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a +-- +-- This function does not throw: the constraints ensure that the permutation is always valid. +transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh, KnownShapeX sh) +          => HList SNat is +          -> XArray sh a +          -> XArray (Permute is (TakeLen is sh) ++ DropLen is sh) a  transpose perm (XArray arr)    | Dict <- lemKnownNatRankSSX (knownShapeX @sh) +  , Refl <- lemPermuteRank (Proxy @(TakeLen is sh)) perm +  = let perm' = foldHList (\sn -> [fromSNat' sn]) perm :: [Int] +    in XArray (S.transpose perm' arr) + +-- | The list argument gives indices into the original dimension list. +-- +-- This version throws a runtime error if the permutation is invalid. +transposeUntyped :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a +transposeUntyped perm (XArray arr) +  | Dict <- lemKnownNatRankSSX (knownShapeX @sh)    = XArray (S.transpose perm arr)  transpose2 :: forall sh1 sh2 a. diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index c12d8ad..f451920 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -40,6 +40,9 @@ module Data.Array.Nested (    -- * Further utilities / re-exports    type (++),    Storable, +  HList, +  Permutation, +  makeNatList,  ) where  import Prelude hiding (mappend) diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 627e0d3..b3f8143 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -22,7 +22,6 @@  {-# LANGUAGE ViewPatterns #-}  {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}  {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -{-# OPTIONS_GHC -Wno-unused-imports #-}  {-|  TODO: @@ -88,7 +87,7 @@ import Foreign.Storable (Storable)  import GHC.TypeLits  import Unsafe.Coerce (unsafeCoerce) -import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat, Dict(..)) +import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat, Dict(..), HList(..))  import qualified Data.Array.Mixed as X @@ -192,6 +191,13 @@ lemRankReplicate _ = go (natSing @n)        , Refl <- go n        = Refl +lemRankMapJust :: forall sh. KnownShape sh => Proxy sh -> X.Rank (MapJust sh) :~: X.Rank sh +lemRankMapJust _ = go (knownShape @sh) +  where +    go :: forall sh'. ShS sh' -> X.Rank (MapJust sh') :~: X.Rank sh' +    go ZSS = Refl +    go (_ :$$ sh') | Refl <- go sh' = Refl +  lemReplicatePlusApp :: forall n m a. KnownNat n => Proxy n -> Proxy m -> Proxy a                      -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a  lemReplicatePlusApp _ _ _ = go (natSing @n) @@ -577,10 +583,10 @@ mgenerate sh f = case X.enumShape sh of                    mvecsWrite sh idx val vecs                  mvecsFreeze sh vecs -mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a -mtranspose perm = -  mlift (\(Proxy @sh') -> X.rerankTop (knownShapeX @sh) (knownShapeX @sh) (knownShapeX @sh') -                            (X.transpose perm)) +mtranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShapeX sh, Elt a) => HList SNat is -> Mixed sh a -> Mixed sh a +mtranspose perm = mlift $ \(Proxy @sh') -> +  X.rerankTop (knownShapeX @sh) (knownShapeX @sh) (knownShapeX @sh') +    (X.transpose perm)  mappend :: forall n m sh a. (KnownShapeX sh, KnownShapeX (n : sh), KnownShapeX (m : sh), KnownShapeX (X.AddMaybe n m : sh), Elt a)          => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a @@ -1088,7 +1094,7 @@ rgenerate sh f  -- | See the documentation of 'mlift'.  rlift :: forall n1 n2 a. (KnownNat n2, Elt a) -      => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) +      => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)        -> Ranked n1 a -> Ranked n2 a  rlift f (Ranked arr)    | Dict <- lemKnownReplicate (Proxy @n2) @@ -1111,9 +1117,11 @@ rsumOuter1 :: forall n a. (Storable a, Num a, PrimElt a, KnownNat n)  rsumOuter1 = coerce fromPrimitive . rsumOuter1P @n @a . coerce toPrimitive  rtranspose :: forall n a. (KnownNat n, Elt a) => [Int] -> Ranked n a -> Ranked n a -rtranspose perm (Ranked arr) +rtranspose perm    | Dict <- lemKnownReplicate (Proxy @n) -  = Ranked (mtranspose perm arr) +  = rlift $ \(Proxy @sh') -> +      X.rerankTop (knownShapeX @(Replicate n Nothing)) (knownShapeX @(Replicate n Nothing)) (knownShapeX @sh') +        (X.transposeUntyped perm)  rappend :: forall n a. (KnownNat n, Elt a)          => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a @@ -1312,7 +1320,7 @@ sgenerate f  -- | See the documentation of 'mlift'.  slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a) -      => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) +      => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)        -> Shaped sh1 a -> Shaped sh2 a  slift f (Shaped arr)    | Dict <- lemKnownMapJust (Proxy @sh2) @@ -1334,9 +1342,10 @@ ssumOuter1 :: forall sh n a.             => Shaped (n : sh) a -> Shaped sh a  ssumOuter1 = coerce fromPrimitive . ssumOuter1P @sh @n @a . coerce toPrimitive -stranspose :: forall sh a. (KnownShape sh, Elt a) => [Int] -> Shaped sh a -> Shaped sh a +stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShape sh, Elt a) => HList SNat is -> Shaped sh a -> Shaped sh a  stranspose perm (Shaped arr)    | Dict <- lemKnownMapJust (Proxy @sh) +  , Refl <- lemRankMapJust (Proxy @sh)    = Shaped (mtranspose perm arr)  sappend :: forall n m sh a. (KnownNat n, KnownNat m, KnownShape sh, Elt a) | 
