diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2024-04-03 12:37:35 +0200 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2024-04-03 12:37:35 +0200 |
commit | 92902c4f66db111b439f3b7eba9de50ad7c73f7b (patch) | |
tree | 27f12853825b7dd13d4bc8040dd2be6781deb635 | |
parent | 264c8e601f49cebed9280f0da2e73f380bb5be52 (diff) |
Reorganise, documentation
-rw-r--r-- | ox-arrays.cabal | 9 | ||||
-rw-r--r-- | src/Data/Array/Mixed.hs (renamed from src/Array.hs) | 8 | ||||
-rw-r--r-- | src/Data/Array/Nested.hs | 40 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal.hs (renamed from src/Fancy.hs) | 85 | ||||
-rw-r--r-- | src/Data/Nat.hs (renamed from src/Nats.hs) | 22 |
5 files changed, 122 insertions, 42 deletions
diff --git a/ox-arrays.cabal b/ox-arrays.cabal index 2930ba0..5bdff7d 100644 --- a/ox-arrays.cabal +++ b/ox-arrays.cabal @@ -7,13 +7,14 @@ build-type: Simple library exposed-modules: - Array - Fancy - Nats + Data.Array.Mixed + Data.Array.Nested + Data.Array.Nested.Internal + Data.Nat build-depends: base >=4.18, ghc-typelits-knownnat, - ghc-typelits-natnormalise, + -- ghc-typelits-natnormalise, orthotope, vector hs-source-dirs: src diff --git a/src/Array.hs b/src/Data/Array/Mixed.hs index cbf04fc..e1e2d5a 100644 --- a/src/Array.hs +++ b/src/Data/Array/Mixed.hs @@ -9,8 +9,7 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -module Array where +module Data.Array.Mixed where import qualified Data.Array.RankedU as U import Data.Kind @@ -20,9 +19,10 @@ import qualified Data.Vector.Unboxed as VU import qualified GHC.TypeLits as GHC import Unsafe.Coerce (unsafeCoerce) -import Nats +import Data.Nat +-- | Type-level list append. type family l1 ++ l2 where '[] ++ l2 = l2 (x : xs) ++ l2 = x : xs ++ l2 @@ -41,6 +41,7 @@ data IxX sh where (::?) :: Int -> IxX sh -> IxX (Nothing : sh) deriving instance Show (IxX sh) +-- | The part of a shape that is statically known. type StaticShapeX :: [Maybe Nat] -> Type data StaticShapeX sh where SZX :: StaticShapeX '[] @@ -48,6 +49,7 @@ data StaticShapeX sh where (:$?) :: () -> StaticShapeX sh -> StaticShapeX (Nothing : sh) deriving instance Show (StaticShapeX sh) +-- | Evidence for the static part of a shape. type KnownShapeX :: [Maybe Nat] -> Constraint class KnownShapeX sh where knownShapeX :: StaticShapeX sh diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs new file mode 100644 index 0000000..983a636 --- /dev/null +++ b/src/Data/Array/Nested.hs @@ -0,0 +1,40 @@ +{-# LANGUAGE ExplicitNamespaces #-} +module Data.Array.Nested ( + -- * Ranked arrays + Ranked, + IxR(..), + rshape, rindex, rindexPartial, rgenerate, rsumOuter1, + -- ** Lifting orthotope operations to 'Ranked' arrays + rlift, + + -- * Shaped arrays + Shaped, + IxS(..), + KnownShape(..), SShape(..), + sshape, sindex, sindexPartial, sgenerate, ssumOuter1, + -- ** Lifting orthotope operations to 'Shaped' arrays + slift, + + -- * Mixed arrays + Mixed, + IxX(..), + KnownShapeX(..), StaticShapeX(..), + mgenerate, + + -- * Array elements + Elt(mshape, mindex, mindexPartial, mlift), + Primitive(..), + + -- * Natural numbers + module Data.Nat, + + -- * Further utilities / re-exports + type (++), + VU.Unbox, +) where + +import qualified Data.Vector.Unboxed as VU + +import Data.Array.Mixed +import Data.Array.Nested.Internal +import Data.Nat diff --git a/src/Fancy.hs b/src/Data/Array/Nested/Internal.hs index 7461c1f..1139c57 100644 --- a/src/Fancy.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -23,7 +23,7 @@ TODO: -} -module Fancy where +module Data.Array.Nested.Internal where import Control.Monad (forM_) import Control.Monad.ST @@ -34,9 +34,9 @@ import Data.Type.Equality import qualified Data.Vector.Unboxed as VU import qualified Data.Vector.Unboxed.Mutable as VUM -import Array (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++)) -import qualified Array as X -import Nats +import Data.Array.Mixed (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++)) +import qualified Data.Array.Mixed as X +import Data.Nat type family Replicate n a where @@ -79,11 +79,14 @@ newtype Primitive a = Primitive a -- | Mixed arrays: some dimensions are size-typed, some are not. Distributes --- over product-typed elements using a dat afamily so that the full array is +-- over product-typed elements using a data family so that the full array is -- always in struct-of-arrays format. -- -- Built on top of 'XArray' which is built on top of @orthotope@, meaning that -- dimension permutations (e.g. 'transpose') are typically free. +-- +-- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type +-- class. type Mixed :: [Maybe Nat] -> Type -> Type data family Mixed sh a @@ -119,9 +122,9 @@ data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest (IxX sh2) (MixedVecs s (sh -- | Allowable scalar types in a mixed array, and by extension in a 'Ranked' or --- 'Shaped' array. Note the polymorphic instance for 'GMixed' of @'Primitive' +-- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive' -- a@; see the documentation for 'Primitive' for more details. -class GMixed a where +class Elt a where -- ====== PUBLIC METHODS ====== -- mshape :: KnownShapeX sh => Mixed sh a -> IxX sh @@ -162,7 +165,7 @@ class GMixed a where -- Arrays of scalars are basically just arrays of scalars. -instance VU.Unbox a => GMixed (Primitive a) where +instance VU.Unbox a => Elt (Primitive a) where mshape (M_Primitive a) = X.shape a mindex (M_Primitive a) i = Primitive (X.index a i) mindexPartial (M_Primitive a) i = M_Primitive (X.indexPartial a i) @@ -191,12 +194,12 @@ instance VU.Unbox a => GMixed (Primitive a) where mvecsFreeze sh (MV_Primitive v) = M_Primitive . X.fromVector sh <$> VU.freeze v -- What a blessing that orthotope's Array has "representational" role on the value type! -deriving via Primitive Int instance GMixed Int -deriving via Primitive Double instance GMixed Double -deriving via Primitive () instance GMixed () +deriving via Primitive Int instance Elt Int +deriving via Primitive Double instance Elt Double +deriving via Primitive () instance Elt () -- Arrays of pairs are pairs of arrays. -instance (GMixed a, GMixed b) => GMixed (a, b) where +instance (Elt a, Elt b) => Elt (a, b) where mshape (M_Tup2 a _) = mshape a mindex (M_Tup2 a b) i = (mindex a i, mindex b i) mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) @@ -214,7 +217,7 @@ instance (GMixed a, GMixed b) => GMixed (a, b) where mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b -- Arrays of arrays are just arrays, but with more dimensions. -instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where +instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh mshape (M_Nest arr) | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') @@ -277,7 +280,7 @@ instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where -- Public method. Turns out this doesn't have to be in the type class! -- | Create an array given a size and a function that computes the element at a -- given index. -mgenerate :: forall sh a. (KnownShapeX sh, GMixed a) => IxX sh -> (IxX sh -> a) -> Mixed sh a +mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IxX sh -> (IxX sh -> a) -> Mixed sh a mgenerate sh f -- TODO: Do we need this checkBounds check elsewhere as well? | not (checkBounds sh (knownShapeX @sh)) = @@ -305,13 +308,28 @@ mgenerate sh f checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh' --- | Newtype around a 'Mixed' of 'Nothing's. This works like a rank-typed array --- as in @orthotope@. +-- | A rank-typed array: the number of dimensions of the array (its /rank/) is +-- represented on the type level as a 'Nat'. +-- +-- Valid elements of a ranked arrays are described by the 'Elt' type class. +-- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are +-- supported (and are represented as a single, flattened, struct-of-arrays +-- array internally). +-- +-- Note that this 'Nat' is not a "GHC.TypeLits" natural, because we want a +-- type-level natural that supports induction. +-- +-- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's. type Ranked :: Nat -> Type -> Type newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) --- | Newtype around a 'Mixed' of 'Just's. This works like a shape-typed array --- as in @orthotope@. +-- | A shape-typed array: the full shape of the array (the sizes of its +-- dimensions) is represented on the type level as a list of 'Nat's. +-- +-- Like for 'Ranked', the valid elements are described by the 'Elt' type class, +-- and 'Shaped' itself is again an instance of 'Elt' as well. +-- +-- 'Shaped' is a newtype around a 'Mixed' of 'Just's. type Shaped :: [Nat] -> Type -> Type newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) @@ -326,7 +344,7 @@ newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixe -- 'Ranked' and 'Shaped' can already be used at the top level of an array nest; -- these instances allow them to also be used as elements of arrays, thus -- making them first-class in the API. -instance (KnownNat n, GMixed a) => GMixed (Ranked n a) where +instance (KnownNat n, Elt a) => Elt (Ranked n a) where mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr mindex (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i) @@ -390,11 +408,13 @@ instance (KnownNat n, GMixed a) => GMixed (Ranked n a) where vecs) +-- | The shape of a shape-typed array given as a list of 'SNat' values. data SShape sh where ShNil :: SShape '[] ShCons :: SNat n -> SShape sh -> SShape (n : sh) deriving instance Show (SShape sh) +-- | A statically-known shape of a shape-typed array. class KnownShape sh where knownShape :: SShape sh instance KnownShape '[] where knownShape = ShNil instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = ShCons knownNat knownShape @@ -414,7 +434,7 @@ lemMapJustPlusApp _ _ = go (knownShape @sh1) go ShNil = Refl go (ShCons _ sh) | Refl <- go sh = Refl -instance (KnownShape sh, GMixed a) => GMixed (Shaped sh a) where +instance (KnownShape sh, Elt a) => Elt (Shaped sh a) where mshape (M_Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = mshape arr mindex (M_Shaped arr) i | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (mindex arr i) @@ -501,28 +521,28 @@ ixCvtRX IZR = IZX ixCvtRX (n ::: idx) = n ::? ixCvtRX idx -rshape :: forall n a. (KnownNat n, GMixed a) => Ranked n a -> IxR n +rshape :: forall n a. (KnownNat n, Elt a) => Ranked n a -> IxR n rshape (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) , Refl <- lemRankReplicate (Proxy @n) = ixCvtXR (mshape arr) -rindex :: GMixed a => Ranked n a -> IxR n -> a +rindex :: Elt a => Ranked n a -> IxR n -> a rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) -rindexPartial :: forall n m a. (KnownNat n, GMixed a) => Ranked (n + m) a -> IxR n -> Ranked m a +rindexPartial :: forall n m a. (KnownNat n, Elt a) => Ranked (n + m) a -> IxR n -> Ranked m a rindexPartial (Ranked arr) idx = Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) (rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr) (ixCvtRX idx)) -rgenerate :: forall n a. (KnownNat n, GMixed a) => IxR n -> (IxR n -> a) -> Ranked n a +rgenerate :: forall n a. (KnownNat n, Elt a) => IxR n -> (IxR n -> a) -> Ranked n a rgenerate sh f | Dict <- lemKnownReplicate (Proxy @n) , Refl <- lemRankReplicate (Proxy @n) = Ranked (mgenerate (ixCvtRX sh) (f . ixCvtXR)) -rlift :: forall n1 n2 a. (KnownNat n2, GMixed a) +rlift :: forall n1 n2 a. (KnownNat n2, Elt a) => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) -> Ranked n1 a -> Ranked n2 a rlift f (Ranked arr) @@ -544,6 +564,11 @@ rsumOuter1 (Ranked arr) -- ====== API OF SHAPED ARRAYS ====== -- -- | An index into a shape-typed array. +-- +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). Note that because the shape of a +-- shape-typed array is known statically, you can also retrieve the array shape +-- from a 'KnownShape' dictionary. type IxS :: [Nat] -> Type data IxS sh where IZS :: IxS '[] @@ -562,24 +587,24 @@ ixCvtSX IZS = IZX ixCvtSX (n ::$ sh) = n ::@ ixCvtSX sh -sshape :: forall sh a. (KnownShape sh, GMixed a) => Shaped sh a -> IxS sh +sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> IxS sh sshape _ = cvtSShapeIxS (knownShape @sh) -sindex :: GMixed a => Shaped sh a -> IxS sh -> a +sindex :: Elt a => Shaped sh a -> IxS sh -> a sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) -sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, GMixed a) => Shaped (sh1 ++ sh2) a -> IxS sh1 -> Shaped sh2 a +sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, Elt a) => Shaped (sh1 ++ sh2) a -> IxS sh1 -> Shaped sh2 a sindexPartial (Shaped arr) idx = Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) (rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr) (ixCvtSX idx)) -sgenerate :: forall sh a. (KnownShape sh, GMixed a) => IxS sh -> (IxS sh -> a) -> Shaped sh a +sgenerate :: forall sh a. (KnownShape sh, Elt a) => IxS sh -> (IxS sh -> a) -> Shaped sh a sgenerate sh f | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (mgenerate (ixCvtSX sh) (f . ixCvtXS (knownShape @sh))) -slift :: forall sh1 sh2 a. (KnownShape sh2, GMixed a) +slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a) => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) -> Shaped sh1 a -> Shaped sh2 a slift f (Shaped arr) diff --git a/src/Nats.hs b/src/Data/Nat.hs index fdc090e..5dacc8a 100644 --- a/src/Nats.hs +++ b/src/Data/Nat.hs @@ -8,48 +8,60 @@ {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Nats where +module Data.Nat where import Data.Proxy import Numeric.Natural import qualified GHC.TypeLits as G +-- | Evidence for the constraint @c a@. data Dict c a where Dict :: c a => Dict c a +-- | A peano natural number. Intended to be used at the type level. data Nat = Z | S Nat deriving (Show) +-- | Singleton for a 'Nat'. data SNat n where SZ :: SNat Z SS :: SNat n -> SNat (S n) deriving instance Show (SNat n) +-- | A singleton 'SNat' corresponding to @n@. class KnownNat n where knownNat :: SNat n instance KnownNat Z where knownNat = SZ instance KnownNat n => KnownNat (S n) where knownNat = SS knownNat -unSNat :: SNat n -> Natural -unSNat SZ = 0 -unSNat (SS n) = 1 + unSNat n - +-- | Convert a 'Nat' to a normal number. unNat :: Nat -> Natural unNat Z = 0 unNat (S n) = 1 + unNat n +-- | Convert an 'SNat' to a normal number. +unSNat :: SNat n -> Natural +unSNat SZ = 0 +unSNat (SS n) = 1 + unSNat n + +-- | A 'KnownNat' dictionary is just a singleton natural, so we can create +-- evidence of 'KnownNat' given an 'SNat'. snatKnown :: SNat n -> Dict KnownNat n snatKnown SZ = Dict snatKnown (SS n) | Dict <- snatKnown n = Dict +-- | Add two 'Nat's type family n + m where Z + m = m S n + m = S (n + m) +-- | Convert a 'Nat' to a "GHC.TypeLits" 'G.Nat'. type family GNat n where GNat Z = 0 GNat (S n) = 1 G.+ GNat n +-- | If an inductive 'Nat' is known, then the corresponding "GHC.TypeLits" +-- 'G.Nat' is also known. gknownNat :: KnownNat n => Proxy n -> Dict G.KnownNat (GNat n) gknownNat (Proxy @n) = go (knownNat @n) where |