aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ox-arrays.cabal9
-rw-r--r--src/Data/Array/Mixed.hs (renamed from src/Array.hs)8
-rw-r--r--src/Data/Array/Nested.hs40
-rw-r--r--src/Data/Array/Nested/Internal.hs (renamed from src/Fancy.hs)85
-rw-r--r--src/Data/Nat.hs (renamed from src/Nats.hs)22
5 files changed, 122 insertions, 42 deletions
diff --git a/ox-arrays.cabal b/ox-arrays.cabal
index 2930ba0..5bdff7d 100644
--- a/ox-arrays.cabal
+++ b/ox-arrays.cabal
@@ -7,13 +7,14 @@ build-type: Simple
library
exposed-modules:
- Array
- Fancy
- Nats
+ Data.Array.Mixed
+ Data.Array.Nested
+ Data.Array.Nested.Internal
+ Data.Nat
build-depends:
base >=4.18,
ghc-typelits-knownnat,
- ghc-typelits-natnormalise,
+ -- ghc-typelits-natnormalise,
orthotope,
vector
hs-source-dirs: src
diff --git a/src/Array.hs b/src/Data/Array/Mixed.hs
index cbf04fc..e1e2d5a 100644
--- a/src/Array.hs
+++ b/src/Data/Array/Mixed.hs
@@ -9,8 +9,7 @@
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-module Array where
+module Data.Array.Mixed where
import qualified Data.Array.RankedU as U
import Data.Kind
@@ -20,9 +19,10 @@ import qualified Data.Vector.Unboxed as VU
import qualified GHC.TypeLits as GHC
import Unsafe.Coerce (unsafeCoerce)
-import Nats
+import Data.Nat
+-- | Type-level list append.
type family l1 ++ l2 where
'[] ++ l2 = l2
(x : xs) ++ l2 = x : xs ++ l2
@@ -41,6 +41,7 @@ data IxX sh where
(::?) :: Int -> IxX sh -> IxX (Nothing : sh)
deriving instance Show (IxX sh)
+-- | The part of a shape that is statically known.
type StaticShapeX :: [Maybe Nat] -> Type
data StaticShapeX sh where
SZX :: StaticShapeX '[]
@@ -48,6 +49,7 @@ data StaticShapeX sh where
(:$?) :: () -> StaticShapeX sh -> StaticShapeX (Nothing : sh)
deriving instance Show (StaticShapeX sh)
+-- | Evidence for the static part of a shape.
type KnownShapeX :: [Maybe Nat] -> Constraint
class KnownShapeX sh where
knownShapeX :: StaticShapeX sh
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
new file mode 100644
index 0000000..983a636
--- /dev/null
+++ b/src/Data/Array/Nested.hs
@@ -0,0 +1,40 @@
+{-# LANGUAGE ExplicitNamespaces #-}
+module Data.Array.Nested (
+ -- * Ranked arrays
+ Ranked,
+ IxR(..),
+ rshape, rindex, rindexPartial, rgenerate, rsumOuter1,
+ -- ** Lifting orthotope operations to 'Ranked' arrays
+ rlift,
+
+ -- * Shaped arrays
+ Shaped,
+ IxS(..),
+ KnownShape(..), SShape(..),
+ sshape, sindex, sindexPartial, sgenerate, ssumOuter1,
+ -- ** Lifting orthotope operations to 'Shaped' arrays
+ slift,
+
+ -- * Mixed arrays
+ Mixed,
+ IxX(..),
+ KnownShapeX(..), StaticShapeX(..),
+ mgenerate,
+
+ -- * Array elements
+ Elt(mshape, mindex, mindexPartial, mlift),
+ Primitive(..),
+
+ -- * Natural numbers
+ module Data.Nat,
+
+ -- * Further utilities / re-exports
+ type (++),
+ VU.Unbox,
+) where
+
+import qualified Data.Vector.Unboxed as VU
+
+import Data.Array.Mixed
+import Data.Array.Nested.Internal
+import Data.Nat
diff --git a/src/Fancy.hs b/src/Data/Array/Nested/Internal.hs
index 7461c1f..1139c57 100644
--- a/src/Fancy.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -23,7 +23,7 @@ TODO:
-}
-module Fancy where
+module Data.Array.Nested.Internal where
import Control.Monad (forM_)
import Control.Monad.ST
@@ -34,9 +34,9 @@ import Data.Type.Equality
import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Unboxed.Mutable as VUM
-import Array (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++))
-import qualified Array as X
-import Nats
+import Data.Array.Mixed (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++))
+import qualified Data.Array.Mixed as X
+import Data.Nat
type family Replicate n a where
@@ -79,11 +79,14 @@ newtype Primitive a = Primitive a
-- | Mixed arrays: some dimensions are size-typed, some are not. Distributes
--- over product-typed elements using a dat afamily so that the full array is
+-- over product-typed elements using a data family so that the full array is
-- always in struct-of-arrays format.
--
-- Built on top of 'XArray' which is built on top of @orthotope@, meaning that
-- dimension permutations (e.g. 'transpose') are typically free.
+--
+-- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type
+-- class.
type Mixed :: [Maybe Nat] -> Type -> Type
data family Mixed sh a
@@ -119,9 +122,9 @@ data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest (IxX sh2) (MixedVecs s (sh
-- | Allowable scalar types in a mixed array, and by extension in a 'Ranked' or
--- 'Shaped' array. Note the polymorphic instance for 'GMixed' of @'Primitive'
+-- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive'
-- a@; see the documentation for 'Primitive' for more details.
-class GMixed a where
+class Elt a where
-- ====== PUBLIC METHODS ====== --
mshape :: KnownShapeX sh => Mixed sh a -> IxX sh
@@ -162,7 +165,7 @@ class GMixed a where
-- Arrays of scalars are basically just arrays of scalars.
-instance VU.Unbox a => GMixed (Primitive a) where
+instance VU.Unbox a => Elt (Primitive a) where
mshape (M_Primitive a) = X.shape a
mindex (M_Primitive a) i = Primitive (X.index a i)
mindexPartial (M_Primitive a) i = M_Primitive (X.indexPartial a i)
@@ -191,12 +194,12 @@ instance VU.Unbox a => GMixed (Primitive a) where
mvecsFreeze sh (MV_Primitive v) = M_Primitive . X.fromVector sh <$> VU.freeze v
-- What a blessing that orthotope's Array has "representational" role on the value type!
-deriving via Primitive Int instance GMixed Int
-deriving via Primitive Double instance GMixed Double
-deriving via Primitive () instance GMixed ()
+deriving via Primitive Int instance Elt Int
+deriving via Primitive Double instance Elt Double
+deriving via Primitive () instance Elt ()
-- Arrays of pairs are pairs of arrays.
-instance (GMixed a, GMixed b) => GMixed (a, b) where
+instance (Elt a, Elt b) => Elt (a, b) where
mshape (M_Tup2 a _) = mshape a
mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
@@ -214,7 +217,7 @@ instance (GMixed a, GMixed b) => GMixed (a, b) where
mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b
-- Arrays of arrays are just arrays, but with more dimensions.
-instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where
+instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where
mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh
mshape (M_Nest arr)
| Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh')
@@ -277,7 +280,7 @@ instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where
-- Public method. Turns out this doesn't have to be in the type class!
-- | Create an array given a size and a function that computes the element at a
-- given index.
-mgenerate :: forall sh a. (KnownShapeX sh, GMixed a) => IxX sh -> (IxX sh -> a) -> Mixed sh a
+mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IxX sh -> (IxX sh -> a) -> Mixed sh a
mgenerate sh f
-- TODO: Do we need this checkBounds check elsewhere as well?
| not (checkBounds sh (knownShapeX @sh)) =
@@ -305,13 +308,28 @@ mgenerate sh f
checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh'
--- | Newtype around a 'Mixed' of 'Nothing's. This works like a rank-typed array
--- as in @orthotope@.
+-- | A rank-typed array: the number of dimensions of the array (its /rank/) is
+-- represented on the type level as a 'Nat'.
+--
+-- Valid elements of a ranked arrays are described by the 'Elt' type class.
+-- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are
+-- supported (and are represented as a single, flattened, struct-of-arrays
+-- array internally).
+--
+-- Note that this 'Nat' is not a "GHC.TypeLits" natural, because we want a
+-- type-level natural that supports induction.
+--
+-- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's.
type Ranked :: Nat -> Type -> Type
newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
--- | Newtype around a 'Mixed' of 'Just's. This works like a shape-typed array
--- as in @orthotope@.
+-- | A shape-typed array: the full shape of the array (the sizes of its
+-- dimensions) is represented on the type level as a list of 'Nat's.
+--
+-- Like for 'Ranked', the valid elements are described by the 'Elt' type class,
+-- and 'Shaped' itself is again an instance of 'Elt' as well.
+--
+-- 'Shaped' is a newtype around a 'Mixed' of 'Just's.
type Shaped :: [Nat] -> Type -> Type
newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
@@ -326,7 +344,7 @@ newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixe
-- 'Ranked' and 'Shaped' can already be used at the top level of an array nest;
-- these instances allow them to also be used as elements of arrays, thus
-- making them first-class in the API.
-instance (KnownNat n, GMixed a) => GMixed (Ranked n a) where
+instance (KnownNat n, Elt a) => Elt (Ranked n a) where
mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr
mindex (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i)
@@ -390,11 +408,13 @@ instance (KnownNat n, GMixed a) => GMixed (Ranked n a) where
vecs)
+-- | The shape of a shape-typed array given as a list of 'SNat' values.
data SShape sh where
ShNil :: SShape '[]
ShCons :: SNat n -> SShape sh -> SShape (n : sh)
deriving instance Show (SShape sh)
+-- | A statically-known shape of a shape-typed array.
class KnownShape sh where knownShape :: SShape sh
instance KnownShape '[] where knownShape = ShNil
instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = ShCons knownNat knownShape
@@ -414,7 +434,7 @@ lemMapJustPlusApp _ _ = go (knownShape @sh1)
go ShNil = Refl
go (ShCons _ sh) | Refl <- go sh = Refl
-instance (KnownShape sh, GMixed a) => GMixed (Shaped sh a) where
+instance (KnownShape sh, Elt a) => Elt (Shaped sh a) where
mshape (M_Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = mshape arr
mindex (M_Shaped arr) i | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (mindex arr i)
@@ -501,28 +521,28 @@ ixCvtRX IZR = IZX
ixCvtRX (n ::: idx) = n ::? ixCvtRX idx
-rshape :: forall n a. (KnownNat n, GMixed a) => Ranked n a -> IxR n
+rshape :: forall n a. (KnownNat n, Elt a) => Ranked n a -> IxR n
rshape (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
, Refl <- lemRankReplicate (Proxy @n)
= ixCvtXR (mshape arr)
-rindex :: GMixed a => Ranked n a -> IxR n -> a
+rindex :: Elt a => Ranked n a -> IxR n -> a
rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)
-rindexPartial :: forall n m a. (KnownNat n, GMixed a) => Ranked (n + m) a -> IxR n -> Ranked m a
+rindexPartial :: forall n m a. (KnownNat n, Elt a) => Ranked (n + m) a -> IxR n -> Ranked m a
rindexPartial (Ranked arr) idx =
Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
(rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr)
(ixCvtRX idx))
-rgenerate :: forall n a. (KnownNat n, GMixed a) => IxR n -> (IxR n -> a) -> Ranked n a
+rgenerate :: forall n a. (KnownNat n, Elt a) => IxR n -> (IxR n -> a) -> Ranked n a
rgenerate sh f
| Dict <- lemKnownReplicate (Proxy @n)
, Refl <- lemRankReplicate (Proxy @n)
= Ranked (mgenerate (ixCvtRX sh) (f . ixCvtXR))
-rlift :: forall n1 n2 a. (KnownNat n2, GMixed a)
+rlift :: forall n1 n2 a. (KnownNat n2, Elt a)
=> (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
-> Ranked n1 a -> Ranked n2 a
rlift f (Ranked arr)
@@ -544,6 +564,11 @@ rsumOuter1 (Ranked arr)
-- ====== API OF SHAPED ARRAYS ====== --
-- | An index into a shape-typed array.
+--
+-- For convenience, this contains regular 'Int's instead of bounded integers
+-- (traditionally called \"@Fin@\"). Note that because the shape of a
+-- shape-typed array is known statically, you can also retrieve the array shape
+-- from a 'KnownShape' dictionary.
type IxS :: [Nat] -> Type
data IxS sh where
IZS :: IxS '[]
@@ -562,24 +587,24 @@ ixCvtSX IZS = IZX
ixCvtSX (n ::$ sh) = n ::@ ixCvtSX sh
-sshape :: forall sh a. (KnownShape sh, GMixed a) => Shaped sh a -> IxS sh
+sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> IxS sh
sshape _ = cvtSShapeIxS (knownShape @sh)
-sindex :: GMixed a => Shaped sh a -> IxS sh -> a
+sindex :: Elt a => Shaped sh a -> IxS sh -> a
sindex (Shaped arr) idx = mindex arr (ixCvtSX idx)
-sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, GMixed a) => Shaped (sh1 ++ sh2) a -> IxS sh1 -> Shaped sh2 a
+sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, Elt a) => Shaped (sh1 ++ sh2) a -> IxS sh1 -> Shaped sh2 a
sindexPartial (Shaped arr) idx =
Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2)
(rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr)
(ixCvtSX idx))
-sgenerate :: forall sh a. (KnownShape sh, GMixed a) => IxS sh -> (IxS sh -> a) -> Shaped sh a
+sgenerate :: forall sh a. (KnownShape sh, Elt a) => IxS sh -> (IxS sh -> a) -> Shaped sh a
sgenerate sh f
| Dict <- lemKnownMapJust (Proxy @sh)
= Shaped (mgenerate (ixCvtSX sh) (f . ixCvtXS (knownShape @sh)))
-slift :: forall sh1 sh2 a. (KnownShape sh2, GMixed a)
+slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a)
=> (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
-> Shaped sh1 a -> Shaped sh2 a
slift f (Shaped arr)
diff --git a/src/Nats.hs b/src/Data/Nat.hs
index fdc090e..5dacc8a 100644
--- a/src/Nats.hs
+++ b/src/Data/Nat.hs
@@ -8,48 +8,60 @@
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Nats where
+module Data.Nat where
import Data.Proxy
import Numeric.Natural
import qualified GHC.TypeLits as G
+-- | Evidence for the constraint @c a@.
data Dict c a where
Dict :: c a => Dict c a
+-- | A peano natural number. Intended to be used at the type level.
data Nat = Z | S Nat
deriving (Show)
+-- | Singleton for a 'Nat'.
data SNat n where
SZ :: SNat Z
SS :: SNat n -> SNat (S n)
deriving instance Show (SNat n)
+-- | A singleton 'SNat' corresponding to @n@.
class KnownNat n where knownNat :: SNat n
instance KnownNat Z where knownNat = SZ
instance KnownNat n => KnownNat (S n) where knownNat = SS knownNat
-unSNat :: SNat n -> Natural
-unSNat SZ = 0
-unSNat (SS n) = 1 + unSNat n
-
+-- | Convert a 'Nat' to a normal number.
unNat :: Nat -> Natural
unNat Z = 0
unNat (S n) = 1 + unNat n
+-- | Convert an 'SNat' to a normal number.
+unSNat :: SNat n -> Natural
+unSNat SZ = 0
+unSNat (SS n) = 1 + unSNat n
+
+-- | A 'KnownNat' dictionary is just a singleton natural, so we can create
+-- evidence of 'KnownNat' given an 'SNat'.
snatKnown :: SNat n -> Dict KnownNat n
snatKnown SZ = Dict
snatKnown (SS n) | Dict <- snatKnown n = Dict
+-- | Add two 'Nat's
type family n + m where
Z + m = m
S n + m = S (n + m)
+-- | Convert a 'Nat' to a "GHC.TypeLits" 'G.Nat'.
type family GNat n where
GNat Z = 0
GNat (S n) = 1 G.+ GNat n
+-- | If an inductive 'Nat' is known, then the corresponding "GHC.TypeLits"
+-- 'G.Nat' is also known.
gknownNat :: KnownNat n => Proxy n -> Dict G.KnownNat (GNat n)
gknownNat (Proxy @n) = go (knownNat @n)
where