diff options
Diffstat (limited to 'src/Data/Array/Nested')
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 37 | 
1 files changed, 20 insertions, 17 deletions
| diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 840bb96..bdded69 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -4,6 +4,7 @@  {-# LANGUAGE FlexibleInstances #-}  {-# LANGUAGE GADTs #-}  {-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE PatternSynonyms #-}  {-# LANGUAGE PolyKinds #-}  {-# LANGUAGE QuantifiedConstraints #-}  {-# LANGUAGE RankNTypes #-} @@ -17,12 +18,13 @@  {-|  TODO: -* This module needs better structure with an Internal module and less public -  exports etc. -  * We should be more consistent in whether functions take a 'StaticShapeX'    argument or a 'KnownShapeX' constraint. +* Document the choice of using 'Nat' for ranks and 'GHC.Nat' for shapes. Point +  being that we need to do induction over the former, but the latter need to be +  able to get large. +  -}  module Data.Array.Nested.Internal where @@ -35,8 +37,9 @@ import Data.Proxy  import Data.Type.Equality  import qualified Data.Vector.Unboxed as VU  import qualified Data.Vector.Unboxed.Mutable as VUM +import qualified GHC.TypeLits as GHC -import Data.Array.Mixed (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++)) +import Data.Array.Mixed (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++), pattern GHC_SNat)  import qualified Data.Array.Mixed as X  import Data.Nat @@ -56,10 +59,10 @@ lemKnownReplicate _ = X.lemKnownShapeX (go (knownNat @n))      go SZ = SZX      go (SS n) = () :$? go n -lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n +lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @GHC.Nat)) :~: n  lemRankReplicate _ = go (knownNat @n)    where -    go :: SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m +    go :: SNat m -> X.Rank (Replicate m (Nothing @GHC.Nat)) :~: m      go SZ = Refl      go (SS n) | Refl <- go n = Refl @@ -89,7 +92,7 @@ newtype Primitive a = Primitive a  --  -- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type  -- class. -type Mixed :: [Maybe Nat] -> Type -> Type +type Mixed :: [Maybe GHC.Nat] -> Type -> Type  data family Mixed sh a  newtype instance Mixed sh (Primitive a) = M_Primitive (XArray sh a) @@ -113,7 +116,7 @@ deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a))  -- | Internal helper data family mirrorring 'Mixed' that consists of mutable  -- vectors instead of 'XArray's. -type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type +type MixedVecs :: Type -> [Maybe GHC.Nat] -> Type -> Type  data family MixedVecs s sh a  newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VU.MVector s a) @@ -311,7 +314,7 @@ mgenerate sh f    where      checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool      checkBounds IZX SZX = True -    checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (unSNat n') && checkBounds sh' ssh' +    checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (GHC.fromSNat n') && checkBounds sh' ssh'      checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh'  mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a @@ -343,7 +346,7 @@ deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a)  -- and 'Shaped' itself is again an instance of 'Elt' as well.  --  -- 'Shaped' is a newtype around a 'Mixed' of 'Just's. -type Shaped :: [Nat] -> Type -> Type +type Shaped :: [GHC.Nat] -> Type -> Type  newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)  deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a) @@ -427,18 +430,18 @@ instance (KnownNat n, Elt a) => Elt (Ranked n a) where  -- | The shape of a shape-typed array given as a list of 'SNat' values.  data SShape sh where    ShNil :: SShape '[] -  ShCons :: SNat n -> SShape sh -> SShape (n : sh) +  ShCons :: GHC.SNat n -> SShape sh -> SShape (n : sh)  deriving instance Show (SShape sh)  infixr 5 `ShCons`  -- | A statically-known shape of a shape-typed array.  class KnownShape sh where knownShape :: SShape sh  instance KnownShape '[] where knownShape = ShNil -instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = ShCons knownNat knownShape +instance (GHC.KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = ShCons GHC.natSing knownShape  sshapeKnown :: SShape sh -> Dict KnownShape sh  sshapeKnown ShNil = Dict -sshapeKnown (ShCons n sh) | Dict <- snatKnown n, Dict <- sshapeKnown sh = Dict +sshapeKnown (ShCons GHC_SNat sh) | Dict <- sshapeKnown sh = Dict  lemKnownMapJust :: forall sh. KnownShape sh => Proxy sh -> Dict KnownShapeX (MapJust sh)  lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh)) @@ -596,7 +599,7 @@ rtranspose perm (Ranked arr)  -- (traditionally called \"@Fin@\"). Note that because the shape of a  -- shape-typed array is known statically, you can also retrieve the array shape  -- from a 'KnownShape' dictionary. -type IxS :: [Nat] -> Type +type IxS :: [GHC.Nat] -> Type  data IxS sh where    IZS :: IxS '[]    (::$) :: Int -> IxS sh -> IxS (n : sh) @@ -604,7 +607,7 @@ infixr 5 ::$  cvtSShapeIxS :: SShape sh -> IxS sh  cvtSShapeIxS ShNil = IZS -cvtSShapeIxS (ShCons n sh) = fromIntegral (unSNat n) ::$ cvtSShapeIxS sh +cvtSShapeIxS (ShCons n sh) = fromIntegral (GHC.fromSNat n) ::$ cvtSShapeIxS sh  ixCvtXS :: SShape sh -> IxX (MapJust sh) -> IxS sh  ixCvtXS ShNil IZX = IZS @@ -640,13 +643,13 @@ slift f (Shaped arr)    = Shaped (mlift f arr)  ssumOuter1 :: forall sh n a. -              (VU.Unbox a, Num a, KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a)) +              (VU.Unbox a, Num a, GHC.KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a))             => Shaped (n : sh) a -> Shaped sh a  ssumOuter1 (Shaped arr)    | Dict <- lemKnownMapJust (Proxy @sh)    = Shaped      . coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) a) -    . X.sumOuter (knownNat @n :$@ SZX) (knownShapeX @(MapJust sh)) +    . X.sumOuter (GHC.natSing @n :$@ SZX) (knownShapeX @(MapJust sh))      . coerce @(Mixed (Just n : MapJust sh) a) @(XArray (Just n : MapJust sh) a)      $ arr | 
