aboutsummaryrefslogtreecommitdiff
path: root/src/Data
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2024-04-03 12:37:35 +0200
committerTom Smeding <t.j.smeding@uu.nl>2024-04-03 12:37:35 +0200
commit92902c4f66db111b439f3b7eba9de50ad7c73f7b (patch)
tree27f12853825b7dd13d4bc8040dd2be6781deb635 /src/Data
parent264c8e601f49cebed9280f0da2e73f380bb5be52 (diff)
Reorganise, documentation
Diffstat (limited to 'src/Data')
-rw-r--r--src/Data/Array/Mixed.hs314
-rw-r--r--src/Data/Array/Nested.hs40
-rw-r--r--src/Data/Array/Nested/Internal.hs623
-rw-r--r--src/Data/Nat.hs70
4 files changed, 1047 insertions, 0 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
new file mode 100644
index 0000000..e1e2d5a
--- /dev/null
+++ b/src/Data/Array/Mixed.hs
@@ -0,0 +1,314 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Mixed where
+
+import qualified Data.Array.RankedU as U
+import Data.Kind
+import Data.Proxy
+import Data.Type.Equality
+import qualified Data.Vector.Unboxed as VU
+import qualified GHC.TypeLits as GHC
+import Unsafe.Coerce (unsafeCoerce)
+
+import Data.Nat
+
+
+-- | Type-level list append.
+type family l1 ++ l2 where
+ '[] ++ l2 = l2
+ (x : xs) ++ l2 = x : xs ++ l2
+
+lemAppNil :: l ++ '[] :~: l
+lemAppNil = unsafeCoerce Refl
+
+lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
+lemAppAssoc _ _ _ = unsafeCoerce Refl
+
+
+type IxX :: [Maybe Nat] -> Type
+data IxX sh where
+ IZX :: IxX '[]
+ (::@) :: Int -> IxX sh -> IxX (Just n : sh)
+ (::?) :: Int -> IxX sh -> IxX (Nothing : sh)
+deriving instance Show (IxX sh)
+
+-- | The part of a shape that is statically known.
+type StaticShapeX :: [Maybe Nat] -> Type
+data StaticShapeX sh where
+ SZX :: StaticShapeX '[]
+ (:$@) :: SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh)
+ (:$?) :: () -> StaticShapeX sh -> StaticShapeX (Nothing : sh)
+deriving instance Show (StaticShapeX sh)
+
+-- | Evidence for the static part of a shape.
+type KnownShapeX :: [Maybe Nat] -> Constraint
+class KnownShapeX sh where
+ knownShapeX :: StaticShapeX sh
+instance KnownShapeX '[] where
+ knownShapeX = SZX
+instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where
+ knownShapeX = knownNat :$@ knownShapeX
+instance KnownShapeX sh => KnownShapeX (Nothing : sh) where
+ knownShapeX = () :$? knownShapeX
+
+type family Rank sh where
+ Rank '[] = Z
+ Rank (_ : sh) = S (Rank sh)
+
+type XArray :: [Maybe Nat] -> Type -> Type
+data XArray sh a = XArray (U.Array (GNat (Rank sh)) a)
+
+zeroIdx :: StaticShapeX sh -> IxX sh
+zeroIdx SZX = IZX
+zeroIdx (_ :$@ ssh) = 0 ::@ zeroIdx ssh
+zeroIdx (_ :$? ssh) = 0 ::? zeroIdx ssh
+
+zeroIdx' :: IxX sh -> IxX sh
+zeroIdx' IZX = IZX
+zeroIdx' (_ ::@ sh) = 0 ::@ zeroIdx' sh
+zeroIdx' (_ ::? sh) = 0 ::? zeroIdx' sh
+
+ixAppend :: IxX sh -> IxX sh' -> IxX (sh ++ sh')
+ixAppend IZX idx' = idx'
+ixAppend (i ::@ idx) idx' = i ::@ ixAppend idx idx'
+ixAppend (i ::? idx) idx' = i ::? ixAppend idx idx'
+
+ixDrop :: IxX (sh ++ sh') -> IxX sh -> IxX sh'
+ixDrop sh IZX = sh
+ixDrop (_ ::@ sh) (_ ::@ idx) = ixDrop sh idx
+ixDrop (_ ::? sh) (_ ::? idx) = ixDrop sh idx
+
+ssxAppend :: StaticShapeX sh -> StaticShapeX sh' -> StaticShapeX (sh ++ sh')
+ssxAppend SZX idx' = idx'
+ssxAppend (n :$@ idx) idx' = n :$@ ssxAppend idx idx'
+ssxAppend (() :$? idx) idx' = () :$? ssxAppend idx idx'
+
+shapeSize :: IxX sh -> Int
+shapeSize IZX = 1
+shapeSize (n ::@ sh) = n * shapeSize sh
+shapeSize (n ::? sh) = n * shapeSize sh
+
+fromLinearIdx :: IxX sh -> Int -> IxX sh
+fromLinearIdx = \sh i -> case go sh i of
+ (idx, 0) -> idx
+ _ -> error $ "fromLinearIdx: out of range (" ++ show i ++
+ " in array of shape " ++ show sh ++ ")"
+ where
+ -- returns (index in subarray, remaining index in enclosing array)
+ go :: IxX sh -> Int -> (IxX sh, Int)
+ go IZX i = (IZX, i)
+ go (n ::@ sh) i =
+ let (idx, i') = go sh i
+ (upi, locali) = i' `quotRem` n
+ in (locali ::@ idx, upi)
+ go (n ::? sh) i =
+ let (idx, i') = go sh i
+ (upi, locali) = i' `quotRem` n
+ in (locali ::? idx, upi)
+
+toLinearIdx :: IxX sh -> IxX sh -> Int
+toLinearIdx = \sh i -> fst (go sh i)
+ where
+ -- returns (index in subarray, size of subarray)
+ go :: IxX sh -> IxX sh -> (Int, Int)
+ go IZX IZX = (0, 1)
+ go (n ::@ sh) (i ::@ ix) =
+ let (lidx, sz) = go sh ix
+ in (sz * i + lidx, n * sz)
+ go (n ::? sh) (i ::? ix) =
+ let (lidx, sz) = go sh ix
+ in (sz * i + lidx, n * sz)
+
+enumShape :: IxX sh -> [IxX sh]
+enumShape = \sh -> go 0 sh id []
+ where
+ go :: Int -> IxX sh -> (IxX sh -> a) -> [a] -> [a]
+ go _ IZX _ = id
+ go i (n ::@ sh) f
+ | i < n = go (i + 1) (n ::@ sh) f . go 0 sh (f . (i ::@))
+ | otherwise = id
+ go i (n ::? sh) f
+ | i < n = go (i + 1) (n ::? sh) f . go 0 sh (f . (i ::?))
+ | otherwise = id
+
+shapeLshape :: IxX sh -> U.ShapeL
+shapeLshape IZX = []
+shapeLshape (n ::@ sh) = n : shapeLshape sh
+shapeLshape (n ::? sh) = n : shapeLshape sh
+
+ssxLength :: StaticShapeX sh -> Int
+ssxLength SZX = 0
+ssxLength (_ :$@ ssh) = 1 + ssxLength ssh
+ssxLength (_ :$? ssh) = 1 + ssxLength ssh
+
+ssxIotaFrom :: Int -> StaticShapeX sh -> [Int]
+ssxIotaFrom _ SZX = []
+ssxIotaFrom i (_ :$@ ssh) = i : ssxIotaFrom (i+1) ssh
+ssxIotaFrom i (_ :$? ssh) = i : ssxIotaFrom (i+1) ssh
+
+lemRankApp :: StaticShapeX sh1 -> StaticShapeX sh2
+ -> GNat (Rank (sh1 ++ sh2)) :~: GNat (Rank sh1) GHC.+ GNat (Rank sh2)
+lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this
+
+lemRankAppComm :: StaticShapeX sh1 -> StaticShapeX sh2
+ -> GNat (Rank (sh1 ++ sh2)) :~: GNat (Rank (sh2 ++ sh1))
+lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this
+
+lemKnownNatRank :: IxX sh -> Dict KnownNat (Rank sh)
+lemKnownNatRank IZX = Dict
+lemKnownNatRank (_ ::@ sh) | Dict <- lemKnownNatRank sh = Dict
+lemKnownNatRank (_ ::? sh) | Dict <- lemKnownNatRank sh = Dict
+
+lemKnownNatRankSSX :: StaticShapeX sh -> Dict KnownNat (Rank sh)
+lemKnownNatRankSSX SZX = Dict
+lemKnownNatRankSSX (_ :$@ ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
+lemKnownNatRankSSX (_ :$? ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
+
+lemKnownShapeX :: StaticShapeX sh -> Dict KnownShapeX sh
+lemKnownShapeX SZX = Dict
+lemKnownShapeX (n :$@ ssh) | Dict <- lemKnownShapeX ssh, Dict <- snatKnown n = Dict
+lemKnownShapeX (() :$? ssh) | Dict <- lemKnownShapeX ssh = Dict
+
+lemAppKnownShapeX :: StaticShapeX sh1 -> StaticShapeX sh2 -> Dict KnownShapeX (sh1 ++ sh2)
+lemAppKnownShapeX SZX ssh' = lemKnownShapeX ssh'
+lemAppKnownShapeX (n :$@ ssh) ssh'
+ | Dict <- lemAppKnownShapeX ssh ssh'
+ , Dict <- snatKnown n
+ = Dict
+lemAppKnownShapeX (() :$? ssh) ssh'
+ | Dict <- lemAppKnownShapeX ssh ssh'
+ = Dict
+
+shape :: forall sh a. KnownShapeX sh => XArray sh a -> IxX sh
+shape (XArray arr) = go (knownShapeX @sh) (U.shapeL arr)
+ where
+ go :: StaticShapeX sh' -> [Int] -> IxX sh'
+ go SZX [] = IZX
+ go (n :$@ ssh) (_ : l) = fromIntegral (unSNat n) ::@ go ssh l
+ go (() :$? ssh) (n : l) = n ::? go ssh l
+ go _ _ = error "Invalid shapeL"
+
+fromVector :: forall sh a. U.Unbox a => IxX sh -> VU.Vector a -> XArray sh a
+fromVector sh v
+ | Dict <- lemKnownNatRank sh
+ , Dict <- gknownNat (Proxy @(Rank sh))
+ = XArray (U.fromVector (shapeLshape sh) v)
+
+toVector :: U.Unbox a => XArray sh a -> VU.Vector a
+toVector (XArray arr) = U.toVector arr
+
+scalar :: U.Unbox a => a -> XArray '[] a
+scalar = XArray . U.scalar
+
+unScalar :: U.Unbox a => XArray '[] a -> a
+unScalar (XArray a) = U.unScalar a
+
+generate :: U.Unbox a => IxX sh -> (IxX sh -> a) -> XArray sh a
+generate sh f = fromVector sh $ VU.generate (shapeSize sh) (f . fromLinearIdx sh)
+
+-- generateM :: (Monad m, U.Unbox a) => IxX sh -> (IxX sh -> m a) -> m (XArray sh a)
+-- generateM sh f | Dict <- lemKnownNatRank sh =
+-- XArray . U.fromVector (shapeLshape sh)
+-- <$> VU.generateM (shapeSize sh) (f . fromLinearIdx sh)
+
+indexPartial :: U.Unbox a => XArray (sh ++ sh') a -> IxX sh -> XArray sh' a
+indexPartial (XArray arr) IZX = XArray arr
+indexPartial (XArray arr) (i ::@ idx) = indexPartial (XArray (U.index arr i)) idx
+indexPartial (XArray arr) (i ::? idx) = indexPartial (XArray (U.index arr i)) idx
+
+index :: forall sh a. U.Unbox a => XArray sh a -> IxX sh -> a
+index xarr i
+ | Refl <- lemAppNil @sh
+ = let XArray arr' = indexPartial xarr i :: XArray '[] a
+ in U.unScalar arr'
+
+append :: forall sh a. (KnownShapeX sh, U.Unbox a) => XArray sh a -> XArray sh a -> XArray sh a
+append (XArray a) (XArray b)
+ | Dict <- lemKnownNatRankSSX (knownShapeX @sh)
+ , Dict <- gknownNat (Proxy @(Rank sh))
+ = XArray (U.append a b)
+
+rerank :: forall sh sh1 sh2 a b.
+ (U.Unbox a, U.Unbox b)
+ => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2
+ -> (XArray sh1 a -> XArray sh2 b)
+ -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b
+rerank ssh ssh1 ssh2 f (XArray arr)
+ | Dict <- lemKnownNatRankSSX ssh
+ , Dict <- gknownNat (Proxy @(Rank sh))
+ , Dict <- lemKnownNatRankSSX ssh2
+ , Dict <- gknownNat (Proxy @(Rank sh2))
+ , Refl <- lemRankApp ssh ssh1
+ , Refl <- lemRankApp ssh ssh2
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the
+ , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough
+ = XArray (U.rerank @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2))
+ (\a -> unXArray (f (XArray a)))
+ arr)
+ where
+ unXArray (XArray a) = a
+
+rerank2 :: forall sh sh1 sh2 a b c.
+ (U.Unbox a, U.Unbox b, U.Unbox c)
+ => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2
+ -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c)
+ -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c
+rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2)
+ | Dict <- lemKnownNatRankSSX ssh
+ , Dict <- gknownNat (Proxy @(Rank sh))
+ , Dict <- lemKnownNatRankSSX ssh2
+ , Dict <- gknownNat (Proxy @(Rank sh2))
+ , Refl <- lemRankApp ssh ssh1
+ , Refl <- lemRankApp ssh ssh2
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the
+ , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough
+ = XArray (U.rerank2 @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2))
+ (\a b -> unXArray (f (XArray a) (XArray b)))
+ arr1 arr2)
+ where
+ unXArray (XArray a) = a
+
+-- | The list argument gives indices into the original dimension list.
+transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a
+transpose perm (XArray arr)
+ | Dict <- lemKnownNatRankSSX (knownShapeX @sh)
+ , Dict <- gknownNat (Proxy @(Rank sh))
+ = XArray (U.transpose perm arr)
+
+transpose2 :: forall sh1 sh2 a.
+ StaticShapeX sh1 -> StaticShapeX sh2
+ -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a
+transpose2 ssh1 ssh2 (XArray arr)
+ | Refl <- lemRankApp ssh1 ssh2
+ , Refl <- lemRankApp ssh2 ssh1
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2)
+ , Dict <- gknownNat (Proxy @(Rank (sh1 ++ sh2)))
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1)
+ , Dict <- gknownNat (Proxy @(Rank (sh2 ++ sh1)))
+ , Refl <- lemRankAppComm ssh1 ssh2
+ , let n1 = ssxLength ssh1
+ = XArray (U.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr)
+
+sumFull :: (U.Unbox a, Num a) => XArray sh a -> a
+sumFull (XArray arr) = U.sumA arr
+
+sumInner :: forall sh sh' a. (U.Unbox a, Num a)
+ => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh a
+sumInner ssh ssh'
+ | Refl <- lemAppNil @sh
+ = rerank ssh ssh' SZX (scalar . sumFull)
+
+sumOuter :: forall sh sh' a. (U.Unbox a, Num a)
+ => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh' a
+sumOuter ssh ssh'
+ | Refl <- lemAppNil @sh
+ = sumInner ssh' ssh . transpose2 ssh ssh'
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
new file mode 100644
index 0000000..983a636
--- /dev/null
+++ b/src/Data/Array/Nested.hs
@@ -0,0 +1,40 @@
+{-# LANGUAGE ExplicitNamespaces #-}
+module Data.Array.Nested (
+ -- * Ranked arrays
+ Ranked,
+ IxR(..),
+ rshape, rindex, rindexPartial, rgenerate, rsumOuter1,
+ -- ** Lifting orthotope operations to 'Ranked' arrays
+ rlift,
+
+ -- * Shaped arrays
+ Shaped,
+ IxS(..),
+ KnownShape(..), SShape(..),
+ sshape, sindex, sindexPartial, sgenerate, ssumOuter1,
+ -- ** Lifting orthotope operations to 'Shaped' arrays
+ slift,
+
+ -- * Mixed arrays
+ Mixed,
+ IxX(..),
+ KnownShapeX(..), StaticShapeX(..),
+ mgenerate,
+
+ -- * Array elements
+ Elt(mshape, mindex, mindexPartial, mlift),
+ Primitive(..),
+
+ -- * Natural numbers
+ module Data.Nat,
+
+ -- * Further utilities / re-exports
+ type (++),
+ VU.Unbox,
+) where
+
+import qualified Data.Vector.Unboxed as VU
+
+import Data.Array.Mixed
+import Data.Array.Nested.Internal
+import Data.Nat
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
new file mode 100644
index 0000000..1139c57
--- /dev/null
+++ b/src/Data/Array/Nested/Internal.hs
@@ -0,0 +1,623 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingVia #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+
+{-|
+TODO:
+* This module needs better structure with an Internal module and less public
+ exports etc.
+
+* We should be more consistent in whether functions take a 'StaticShapeX'
+ argument or a 'KnownShapeX' constraint.
+
+-}
+
+module Data.Array.Nested.Internal where
+
+import Control.Monad (forM_)
+import Control.Monad.ST
+import Data.Coerce (coerce, Coercible)
+import Data.Kind
+import Data.Proxy
+import Data.Type.Equality
+import qualified Data.Vector.Unboxed as VU
+import qualified Data.Vector.Unboxed.Mutable as VUM
+
+import Data.Array.Mixed (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++))
+import qualified Data.Array.Mixed as X
+import Data.Nat
+
+
+type family Replicate n a where
+ Replicate Z a = '[]
+ Replicate (S n) a = a : Replicate n a
+
+type family MapJust l where
+ MapJust '[] = '[]
+ MapJust (x : xs) = Just x : MapJust xs
+
+lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing)
+lemKnownReplicate _ = X.lemKnownShapeX (go (knownNat @n))
+ where
+ go :: SNat m -> StaticShapeX (Replicate m Nothing)
+ go SZ = SZX
+ go (SS n) = () :$? go n
+
+lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n
+lemRankReplicate _ = go (knownNat @n)
+ where
+ go :: SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m
+ go SZ = Refl
+ go (SS n) | Refl <- go n = Refl
+
+lemReplicatePlusApp :: forall n m a. KnownNat n => Proxy n -> Proxy m -> Proxy a
+ -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a
+lemReplicatePlusApp _ _ _ = go (knownNat @n)
+ where
+ go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a
+ go SZ = Refl
+ go (SS n) | Refl <- go n = Refl
+
+
+-- | Wrapper type used as a tag to attach instances on. The instances on arrays
+-- of @'Primitive' a@ are more polymorphic than the direct instances for arrays
+-- of scalars; this means that if @orthotope@ supports an element type @T@ that
+-- this library does not (directly), it may just work if you use an array of
+-- @'Primitive' T@ instead.
+newtype Primitive a = Primitive a
+
+
+-- | Mixed arrays: some dimensions are size-typed, some are not. Distributes
+-- over product-typed elements using a data family so that the full array is
+-- always in struct-of-arrays format.
+--
+-- Built on top of 'XArray' which is built on top of @orthotope@, meaning that
+-- dimension permutations (e.g. 'transpose') are typically free.
+--
+-- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type
+-- class.
+type Mixed :: [Maybe Nat] -> Type -> Type
+data family Mixed sh a
+
+newtype instance Mixed sh (Primitive a) = M_Primitive (XArray sh a)
+
+newtype instance Mixed sh Int = M_Int (XArray sh Int)
+newtype instance Mixed sh Double = M_Double (XArray sh Double)
+newtype instance Mixed sh () = M_Nil (XArray sh ()) -- no content, orthotope optimises this (via Vector)
+-- etc.
+
+data instance Mixed sh (a, b) = M_Tup2 (Mixed sh a) (Mixed sh b)
+-- etc.
+
+newtype instance Mixed sh1 (Mixed sh2 a) = M_Nest (Mixed (sh1 ++ sh2) a)
+
+
+-- | Internal helper data family mirrorring 'Mixed' that consists of mutable
+-- vectors instead of 'XArray's.
+type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type
+data family MixedVecs s sh a
+
+newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VU.MVector s a)
+
+newtype instance MixedVecs s sh Int = MV_Int (VU.MVector s Int)
+newtype instance MixedVecs s sh Double = MV_Double (VU.MVector s Double)
+newtype instance MixedVecs s sh () = MV_Nil (VU.MVector s ()) -- no content, MVector optimises this
+-- etc.
+
+data instance MixedVecs s sh (a, b) = MV_Tup2 (MixedVecs s sh a) (MixedVecs s sh b)
+-- etc.
+
+data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest (IxX sh2) (MixedVecs s (sh1 ++ sh2) a)
+
+
+-- | Allowable scalar types in a mixed array, and by extension in a 'Ranked' or
+-- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive'
+-- a@; see the documentation for 'Primitive' for more details.
+class Elt a where
+ -- ====== PUBLIC METHODS ====== --
+
+ mshape :: KnownShapeX sh => Mixed sh a -> IxX sh
+ mindex :: Mixed sh a -> IxX sh -> a
+ mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a
+
+ mlift :: forall sh1 sh2. KnownShapeX sh2
+ => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ -> Mixed sh1 a -> Mixed sh2 a
+
+ -- ====== PRIVATE METHODS ====== --
+ -- Remember I said that this module needed better management of exports?
+
+ -- | Create an empty array. The given shape must have size zero; this may or may not be checked.
+ memptyArray :: IxX sh -> Mixed sh a
+
+ -- | Return the size of the individual (SoA) arrays in this value. If @a@
+ -- does not contain tuples, this coincides with the total number of scalars
+ -- in the given value; if @a@ contains tuples, then it is some multiple of
+ -- this number of scalars.
+ mvecsNumElts :: a -> Int
+
+ -- | Create uninitialised vectors for this array type, given the shape of
+ -- this vector and an example for the contents. The shape must not have size
+ -- zero; an error may be thrown otherwise.
+ mvecsUnsafeNew :: IxX sh -> a -> ST s (MixedVecs s sh a)
+
+ -- | Given the shape of this array, an index and a value, write the value at
+ -- that index in the vectors.
+ mvecsWrite :: IxX sh -> IxX sh -> a -> MixedVecs s sh a -> ST s ()
+
+ -- | Given the shape of this array, an index and a value, write the value at
+ -- that index in the vectors.
+ mvecsWritePartial :: KnownShapeX sh' => IxX (sh ++ sh') -> IxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
+
+ -- | Given the shape of this array, finalise the vectors into 'XArray's.
+ mvecsFreeze :: IxX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
+
+
+-- Arrays of scalars are basically just arrays of scalars.
+instance VU.Unbox a => Elt (Primitive a) where
+ mshape (M_Primitive a) = X.shape a
+ mindex (M_Primitive a) i = Primitive (X.index a i)
+ mindexPartial (M_Primitive a) i = M_Primitive (X.indexPartial a i)
+
+ mlift :: forall sh1 sh2.
+ (Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a)
+ -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a)
+ mlift f (M_Primitive a)
+ | Refl <- X.lemAppNil @sh1
+ , Refl <- X.lemAppNil @sh2
+ = M_Primitive (f Proxy a)
+
+ memptyArray sh = M_Primitive (X.generate sh (error "memptyArray Int: shape was not empty"))
+ mvecsNumElts _ = 1
+ mvecsUnsafeNew sh _ = MV_Primitive <$> VUM.unsafeNew (X.shapeSize sh)
+ mvecsWrite sh i (Primitive x) (MV_Primitive v) = VUM.write v (X.toLinearIdx sh i) x
+
+ -- TODO: this use of toVector is suboptimal
+ mvecsWritePartial
+ :: forall sh' sh s. (KnownShapeX sh', VU.Unbox a)
+ => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
+ mvecsWritePartial sh i (M_Primitive arr) (MV_Primitive v) = do
+ let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIdx' (X.shape arr)))
+ VU.copy (VUM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr)
+
+ mvecsFreeze sh (MV_Primitive v) = M_Primitive . X.fromVector sh <$> VU.freeze v
+
+-- What a blessing that orthotope's Array has "representational" role on the value type!
+deriving via Primitive Int instance Elt Int
+deriving via Primitive Double instance Elt Double
+deriving via Primitive () instance Elt ()
+
+-- Arrays of pairs are pairs of arrays.
+instance (Elt a, Elt b) => Elt (a, b) where
+ mshape (M_Tup2 a _) = mshape a
+ mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
+ mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
+ mlift f (M_Tup2 a b) = M_Tup2 (mlift f a) (mlift f b)
+
+ memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh)
+ mvecsNumElts (x, y) = mvecsNumElts x * mvecsNumElts y
+ mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
+ mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
+ mvecsWrite sh i x a
+ mvecsWrite sh i y b
+ mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do
+ mvecsWritePartial sh i x a
+ mvecsWritePartial sh i y b
+ mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b
+
+-- Arrays of arrays are just arrays, but with more dimensions.
+instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where
+ mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh
+ mshape (M_Nest arr)
+ | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh')
+ = ixAppPrefix (knownShapeX @sh) (mshape arr)
+ where
+ ixAppPrefix :: StaticShapeX sh1 -> IxX (sh1 ++ sh') -> IxX sh1
+ ixAppPrefix SZX _ = IZX
+ ixAppPrefix (_ :$@ ssh) (i ::@ idx) = i ::@ ixAppPrefix ssh idx
+ ixAppPrefix (_ :$? ssh) (i ::? idx) = i ::? ixAppPrefix ssh idx
+
+ mindex (M_Nest arr) i = mindexPartial arr i
+
+ mindexPartial :: forall sh1 sh2.
+ Mixed (sh1 ++ sh2) (Mixed sh' a) -> IxX sh1 -> Mixed sh2 (Mixed sh' a)
+ mindexPartial (M_Nest arr) i
+ | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
+ = M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
+
+ mlift :: forall sh1 sh2. KnownShapeX sh2
+ => (forall sh3 b. (KnownShapeX sh3, VU.Unbox b) => Proxy sh3 -> XArray (sh1 ++ sh3) b -> XArray (sh2 ++ sh3) b)
+ -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)
+ mlift f (M_Nest arr)
+ | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
+ = M_Nest (mlift f' arr)
+ where
+ f' :: forall sh3 b. (KnownShapeX sh3, VU.Unbox b) => Proxy sh3 -> XArray ((sh1 ++ sh') ++ sh3) b -> XArray ((sh2 ++ sh') ++ sh3) b
+ f' _
+ | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @sh3)
+ , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @sh3)
+ , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @sh3))
+ = f (Proxy @(sh' ++ sh3))
+
+ memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIdx (knownShapeX @sh'))))
+
+ mvecsNumElts arr =
+ let n = X.shapeSize (mshape arr)
+ in if n == 0 then 0 else n * mvecsNumElts (mindex arr (X.zeroIdx (knownShapeX @sh')))
+
+ mvecsUnsafeNew sh example
+ | X.shapeSize sh' == 0 = error "mvecsUnsafeNew: empty example"
+ | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example))
+ (mindex example (X.zeroIdx (knownShapeX @sh')))
+ where
+ sh' = mshape example
+
+ mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.ixAppend sh sh') idx val vecs
+
+ mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2
+ => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Mixed sh' a)
+ -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
+ -> ST s ()
+ mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs)
+ | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
+ , Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
+ = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.ixAppend sh12 sh') idx arr vecs
+
+ mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs
+
+
+-- Public method. Turns out this doesn't have to be in the type class!
+-- | Create an array given a size and a function that computes the element at a
+-- given index.
+mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IxX sh -> (IxX sh -> a) -> Mixed sh a
+mgenerate sh f
+ -- TODO: Do we need this checkBounds check elsewhere as well?
+ | not (checkBounds sh (knownShapeX @sh)) =
+ error $ "mgenerate: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh)
+ -- We need to be very careful here to ensure that neither 'sh' nor
+ -- 'firstelem' that we pass to 'mvecsUnsafeNew' are empty.
+ | X.shapeSize sh == 0 = memptyArray sh
+ | otherwise =
+ let firstidx = X.zeroIdx' sh
+ firstelem = f (X.zeroIdx' sh)
+ in if mvecsNumElts firstelem == 0
+ then memptyArray sh
+ else runST $ do
+ vecs <- mvecsUnsafeNew sh firstelem
+ mvecsWrite sh firstidx firstelem vecs
+ -- TODO: This is likely fine if @a@ is big, but if @a@ is a
+ -- scalar this feels inefficient. Should improve this.
+ forM_ (tail (X.enumShape sh)) $ \idx ->
+ mvecsWrite sh idx (f idx) vecs
+ mvecsFreeze sh vecs
+ where
+ checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool
+ checkBounds IZX SZX = True
+ checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (unSNat n') && checkBounds sh' ssh'
+ checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh'
+
+
+-- | A rank-typed array: the number of dimensions of the array (its /rank/) is
+-- represented on the type level as a 'Nat'.
+--
+-- Valid elements of a ranked arrays are described by the 'Elt' type class.
+-- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are
+-- supported (and are represented as a single, flattened, struct-of-arrays
+-- array internally).
+--
+-- Note that this 'Nat' is not a "GHC.TypeLits" natural, because we want a
+-- type-level natural that supports induction.
+--
+-- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's.
+type Ranked :: Nat -> Type -> Type
+newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
+
+-- | A shape-typed array: the full shape of the array (the sizes of its
+-- dimensions) is represented on the type level as a list of 'Nat's.
+--
+-- Like for 'Ranked', the valid elements are described by the 'Elt' type class,
+-- and 'Shaped' itself is again an instance of 'Elt' as well.
+--
+-- 'Shaped' is a newtype around a 'Mixed' of 'Just's.
+type Shaped :: [Nat] -> Type -> Type
+newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
+
+-- just unwrap the newtype and defer to the general instance for nested arrays
+newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a))
+newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh' ) a))
+
+newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a))
+newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh' ) a))
+
+
+-- 'Ranked' and 'Shaped' can already be used at the top level of an array nest;
+-- these instances allow them to also be used as elements of arrays, thus
+-- making them first-class in the API.
+instance (KnownNat n, Elt a) => Elt (Ranked n a) where
+ mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr
+ mindex (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i)
+
+ mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IxX sh -> Mixed sh' (Ranked n a)
+ mindexPartial (M_Ranked arr) i
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $
+ mindexPartial arr i
+
+ mlift :: forall sh1 sh2. KnownShapeX sh2
+ => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a)
+ mlift f (M_Ranked arr)
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $
+ mlift f arr
+
+ memptyArray :: forall sh. IxX sh -> Mixed sh (Ranked n a)
+ memptyArray i
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $
+ memptyArray i
+
+ mvecsNumElts (Ranked arr)
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = mvecsNumElts arr
+
+ mvecsUnsafeNew idx (Ranked arr)
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = MV_Ranked <$> mvecsUnsafeNew idx arr
+
+ mvecsWrite :: forall sh s. IxX sh -> IxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
+ mvecsWrite sh idx (Ranked arr) vecs
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = mvecsWrite sh idx arr
+ (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
+ vecs)
+
+ mvecsWritePartial :: forall sh sh' s. KnownShapeX sh'
+ => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Ranked n a)
+ -> MixedVecs s (sh ++ sh') (Ranked n a)
+ -> ST s ()
+ mvecsWritePartial sh idx arr vecs
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = mvecsWritePartial sh idx
+ (coerce @(Mixed sh' (Ranked n a))
+ @(Mixed sh' (Mixed (Replicate n Nothing) a))
+ arr)
+ (coerce @(MixedVecs s (sh ++ sh') (Ranked n a))
+ @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a))
+ vecs)
+
+ mvecsFreeze :: forall sh s. IxX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
+ mvecsFreeze sh vecs
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
+ @(Mixed sh (Ranked n a))
+ <$> mvecsFreeze sh
+ (coerce @(MixedVecs s sh (Ranked n a))
+ @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
+ vecs)
+
+
+-- | The shape of a shape-typed array given as a list of 'SNat' values.
+data SShape sh where
+ ShNil :: SShape '[]
+ ShCons :: SNat n -> SShape sh -> SShape (n : sh)
+deriving instance Show (SShape sh)
+
+-- | A statically-known shape of a shape-typed array.
+class KnownShape sh where knownShape :: SShape sh
+instance KnownShape '[] where knownShape = ShNil
+instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = ShCons knownNat knownShape
+
+lemKnownMapJust :: forall sh. KnownShape sh => Proxy sh -> Dict KnownShapeX (MapJust sh)
+lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh))
+ where
+ go :: SShape sh' -> StaticShapeX (MapJust sh')
+ go ShNil = SZX
+ go (ShCons n sh) = n :$@ go sh
+
+lemMapJustPlusApp :: forall sh1 sh2. KnownShape sh1 => Proxy sh1 -> Proxy sh2
+ -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2
+lemMapJustPlusApp _ _ = go (knownShape @sh1)
+ where
+ go :: SShape sh1' -> MapJust (sh1' ++ sh2) :~: MapJust sh1' ++ MapJust sh2
+ go ShNil = Refl
+ go (ShCons _ sh) | Refl <- go sh = Refl
+
+instance (KnownShape sh, Elt a) => Elt (Shaped sh a) where
+ mshape (M_Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = mshape arr
+ mindex (M_Shaped arr) i | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (mindex arr i)
+
+ mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IxX sh1 -> Mixed sh2 (Shaped sh a)
+ mindexPartial (M_Shaped arr) i
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
+ mindexPartial arr i
+
+ mlift :: forall sh1 sh2. KnownShapeX sh2
+ => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a)
+ mlift f (M_Shaped arr)
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
+ mlift f arr
+
+ memptyArray :: forall sh'. IxX sh' -> Mixed sh' (Shaped sh a)
+ memptyArray i
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
+ memptyArray i
+
+ mvecsNumElts (Shaped arr)
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = mvecsNumElts arr
+
+ mvecsUnsafeNew idx (Shaped arr)
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = MV_Shaped <$> mvecsUnsafeNew idx arr
+
+ mvecsWrite :: forall sh' s. IxX sh' -> IxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
+ mvecsWrite sh idx (Shaped arr) vecs
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = mvecsWrite sh idx arr
+ (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
+ vecs)
+
+ mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2
+ => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Shaped sh a)
+ -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
+ -> ST s ()
+ mvecsWritePartial sh idx arr vecs
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = mvecsWritePartial sh idx
+ (coerce @(Mixed sh2 (Shaped sh a))
+ @(Mixed sh2 (Mixed (MapJust sh) a))
+ arr)
+ (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a))
+ @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))
+ vecs)
+
+ mvecsFreeze :: forall sh' s. IxX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
+ mvecsFreeze sh vecs
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = coerce @(Mixed sh' (Mixed (MapJust sh) a))
+ @(Mixed sh' (Shaped sh a))
+ <$> mvecsFreeze sh
+ (coerce @(MixedVecs s sh' (Shaped sh a))
+ @(MixedVecs s sh' (Mixed (MapJust sh) a))
+ vecs)
+
+
+-- Utility function to satisfy the type checker sometimes
+rewriteMixed :: sh1 :~: sh2 -> Mixed sh1 a -> Mixed sh2 a
+rewriteMixed Refl x = x
+
+
+-- ====== API OF RANKED ARRAYS ====== --
+
+-- | An index into a rank-typed array.
+type IxR :: Nat -> Type
+data IxR n where
+ IZR :: IxR Z
+ (:::) :: Int -> IxR n -> IxR (S n)
+
+ixCvtXR :: IxX sh -> IxR (X.Rank sh)
+ixCvtXR IZX = IZR
+ixCvtXR (n ::@ idx) = n ::: ixCvtXR idx
+ixCvtXR (n ::? idx) = n ::: ixCvtXR idx
+
+ixCvtRX :: IxR n -> IxX (Replicate n Nothing)
+ixCvtRX IZR = IZX
+ixCvtRX (n ::: idx) = n ::? ixCvtRX idx
+
+
+rshape :: forall n a. (KnownNat n, Elt a) => Ranked n a -> IxR n
+rshape (Ranked arr)
+ | Dict <- lemKnownReplicate (Proxy @n)
+ , Refl <- lemRankReplicate (Proxy @n)
+ = ixCvtXR (mshape arr)
+
+rindex :: Elt a => Ranked n a -> IxR n -> a
+rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)
+
+rindexPartial :: forall n m a. (KnownNat n, Elt a) => Ranked (n + m) a -> IxR n -> Ranked m a
+rindexPartial (Ranked arr) idx =
+ Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
+ (rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr)
+ (ixCvtRX idx))
+
+rgenerate :: forall n a. (KnownNat n, Elt a) => IxR n -> (IxR n -> a) -> Ranked n a
+rgenerate sh f
+ | Dict <- lemKnownReplicate (Proxy @n)
+ , Refl <- lemRankReplicate (Proxy @n)
+ = Ranked (mgenerate (ixCvtRX sh) (f . ixCvtXR))
+
+rlift :: forall n1 n2 a. (KnownNat n2, Elt a)
+ => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
+ -> Ranked n1 a -> Ranked n2 a
+rlift f (Ranked arr)
+ | Dict <- lemKnownReplicate (Proxy @n2)
+ = Ranked (mlift f arr)
+
+rsumOuter1 :: forall n a.
+ (VU.Unbox a, Num a, KnownNat n, forall sh. Coercible (Mixed sh a) (XArray sh a))
+ => Ranked (S n) a -> Ranked n a
+rsumOuter1 (Ranked arr)
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = Ranked
+ . coerce @(XArray (Replicate n Nothing) a) @(Mixed (Replicate n Nothing) a)
+ . X.sumOuter (() :$? SZX) (knownShapeX @(Replicate n Nothing))
+ . coerce @(Mixed (Replicate (S n) Nothing) a) @(XArray (Replicate (S n) Nothing) a)
+ $ arr
+
+
+-- ====== API OF SHAPED ARRAYS ====== --
+
+-- | An index into a shape-typed array.
+--
+-- For convenience, this contains regular 'Int's instead of bounded integers
+-- (traditionally called \"@Fin@\"). Note that because the shape of a
+-- shape-typed array is known statically, you can also retrieve the array shape
+-- from a 'KnownShape' dictionary.
+type IxS :: [Nat] -> Type
+data IxS sh where
+ IZS :: IxS '[]
+ (::$) :: Int -> IxS sh -> IxS (n : sh)
+
+cvtSShapeIxS :: SShape sh -> IxS sh
+cvtSShapeIxS ShNil = IZS
+cvtSShapeIxS (ShCons n sh) = fromIntegral (unSNat n) ::$ cvtSShapeIxS sh
+
+ixCvtXS :: SShape sh -> IxX (MapJust sh) -> IxS sh
+ixCvtXS ShNil IZX = IZS
+ixCvtXS (ShCons _ sh) (n ::@ idx) = n ::$ ixCvtXS sh idx
+
+ixCvtSX :: IxS sh -> IxX (MapJust sh)
+ixCvtSX IZS = IZX
+ixCvtSX (n ::$ sh) = n ::@ ixCvtSX sh
+
+
+sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> IxS sh
+sshape _ = cvtSShapeIxS (knownShape @sh)
+
+sindex :: Elt a => Shaped sh a -> IxS sh -> a
+sindex (Shaped arr) idx = mindex arr (ixCvtSX idx)
+
+sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, Elt a) => Shaped (sh1 ++ sh2) a -> IxS sh1 -> Shaped sh2 a
+sindexPartial (Shaped arr) idx =
+ Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2)
+ (rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr)
+ (ixCvtSX idx))
+
+sgenerate :: forall sh a. (KnownShape sh, Elt a) => IxS sh -> (IxS sh -> a) -> Shaped sh a
+sgenerate sh f
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = Shaped (mgenerate (ixCvtSX sh) (f . ixCvtXS (knownShape @sh)))
+
+slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a)
+ => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
+ -> Shaped sh1 a -> Shaped sh2 a
+slift f (Shaped arr)
+ | Dict <- lemKnownMapJust (Proxy @sh2)
+ = Shaped (mlift f arr)
+
+ssumOuter1 :: forall sh n a.
+ (VU.Unbox a, Num a, KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a))
+ => Shaped (n : sh) a -> Shaped sh a
+ssumOuter1 (Shaped arr)
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = Shaped
+ . coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) a)
+ . X.sumOuter (knownNat @n :$@ SZX) (knownShapeX @(MapJust sh))
+ . coerce @(Mixed (Just n : MapJust sh) a) @(XArray (Just n : MapJust sh) a)
+ $ arr
diff --git a/src/Data/Nat.hs b/src/Data/Nat.hs
new file mode 100644
index 0000000..5dacc8a
--- /dev/null
+++ b/src/Data/Nat.hs
@@ -0,0 +1,70 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Nat where
+
+import Data.Proxy
+import Numeric.Natural
+import qualified GHC.TypeLits as G
+
+
+-- | Evidence for the constraint @c a@.
+data Dict c a where
+ Dict :: c a => Dict c a
+
+-- | A peano natural number. Intended to be used at the type level.
+data Nat = Z | S Nat
+ deriving (Show)
+
+-- | Singleton for a 'Nat'.
+data SNat n where
+ SZ :: SNat Z
+ SS :: SNat n -> SNat (S n)
+deriving instance Show (SNat n)
+
+-- | A singleton 'SNat' corresponding to @n@.
+class KnownNat n where knownNat :: SNat n
+instance KnownNat Z where knownNat = SZ
+instance KnownNat n => KnownNat (S n) where knownNat = SS knownNat
+
+-- | Convert a 'Nat' to a normal number.
+unNat :: Nat -> Natural
+unNat Z = 0
+unNat (S n) = 1 + unNat n
+
+-- | Convert an 'SNat' to a normal number.
+unSNat :: SNat n -> Natural
+unSNat SZ = 0
+unSNat (SS n) = 1 + unSNat n
+
+-- | A 'KnownNat' dictionary is just a singleton natural, so we can create
+-- evidence of 'KnownNat' given an 'SNat'.
+snatKnown :: SNat n -> Dict KnownNat n
+snatKnown SZ = Dict
+snatKnown (SS n) | Dict <- snatKnown n = Dict
+
+-- | Add two 'Nat's
+type family n + m where
+ Z + m = m
+ S n + m = S (n + m)
+
+-- | Convert a 'Nat' to a "GHC.TypeLits" 'G.Nat'.
+type family GNat n where
+ GNat Z = 0
+ GNat (S n) = 1 G.+ GNat n
+
+-- | If an inductive 'Nat' is known, then the corresponding "GHC.TypeLits"
+-- 'G.Nat' is also known.
+gknownNat :: KnownNat n => Proxy n -> Dict G.KnownNat (GNat n)
+gknownNat (Proxy @n) = go (knownNat @n)
+ where
+ go :: SNat m -> Dict G.KnownNat (GNat m)
+ go SZ = Dict
+ go (SS n) | Dict <- go n = Dict