aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal.hs
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2024-04-03 12:37:35 +0200
committerTom Smeding <t.j.smeding@uu.nl>2024-04-03 12:37:35 +0200
commit92902c4f66db111b439f3b7eba9de50ad7c73f7b (patch)
tree27f12853825b7dd13d4bc8040dd2be6781deb635 /src/Data/Array/Nested/Internal.hs
parent264c8e601f49cebed9280f0da2e73f380bb5be52 (diff)
Reorganise, documentation
Diffstat (limited to 'src/Data/Array/Nested/Internal.hs')
-rw-r--r--src/Data/Array/Nested/Internal.hs623
1 files changed, 623 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
new file mode 100644
index 0000000..1139c57
--- /dev/null
+++ b/src/Data/Array/Nested/Internal.hs
@@ -0,0 +1,623 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingVia #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+
+{-|
+TODO:
+* This module needs better structure with an Internal module and less public
+ exports etc.
+
+* We should be more consistent in whether functions take a 'StaticShapeX'
+ argument or a 'KnownShapeX' constraint.
+
+-}
+
+module Data.Array.Nested.Internal where
+
+import Control.Monad (forM_)
+import Control.Monad.ST
+import Data.Coerce (coerce, Coercible)
+import Data.Kind
+import Data.Proxy
+import Data.Type.Equality
+import qualified Data.Vector.Unboxed as VU
+import qualified Data.Vector.Unboxed.Mutable as VUM
+
+import Data.Array.Mixed (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++))
+import qualified Data.Array.Mixed as X
+import Data.Nat
+
+
+type family Replicate n a where
+ Replicate Z a = '[]
+ Replicate (S n) a = a : Replicate n a
+
+type family MapJust l where
+ MapJust '[] = '[]
+ MapJust (x : xs) = Just x : MapJust xs
+
+lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing)
+lemKnownReplicate _ = X.lemKnownShapeX (go (knownNat @n))
+ where
+ go :: SNat m -> StaticShapeX (Replicate m Nothing)
+ go SZ = SZX
+ go (SS n) = () :$? go n
+
+lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n
+lemRankReplicate _ = go (knownNat @n)
+ where
+ go :: SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m
+ go SZ = Refl
+ go (SS n) | Refl <- go n = Refl
+
+lemReplicatePlusApp :: forall n m a. KnownNat n => Proxy n -> Proxy m -> Proxy a
+ -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a
+lemReplicatePlusApp _ _ _ = go (knownNat @n)
+ where
+ go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a
+ go SZ = Refl
+ go (SS n) | Refl <- go n = Refl
+
+
+-- | Wrapper type used as a tag to attach instances on. The instances on arrays
+-- of @'Primitive' a@ are more polymorphic than the direct instances for arrays
+-- of scalars; this means that if @orthotope@ supports an element type @T@ that
+-- this library does not (directly), it may just work if you use an array of
+-- @'Primitive' T@ instead.
+newtype Primitive a = Primitive a
+
+
+-- | Mixed arrays: some dimensions are size-typed, some are not. Distributes
+-- over product-typed elements using a data family so that the full array is
+-- always in struct-of-arrays format.
+--
+-- Built on top of 'XArray' which is built on top of @orthotope@, meaning that
+-- dimension permutations (e.g. 'transpose') are typically free.
+--
+-- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type
+-- class.
+type Mixed :: [Maybe Nat] -> Type -> Type
+data family Mixed sh a
+
+newtype instance Mixed sh (Primitive a) = M_Primitive (XArray sh a)
+
+newtype instance Mixed sh Int = M_Int (XArray sh Int)
+newtype instance Mixed sh Double = M_Double (XArray sh Double)
+newtype instance Mixed sh () = M_Nil (XArray sh ()) -- no content, orthotope optimises this (via Vector)
+-- etc.
+
+data instance Mixed sh (a, b) = M_Tup2 (Mixed sh a) (Mixed sh b)
+-- etc.
+
+newtype instance Mixed sh1 (Mixed sh2 a) = M_Nest (Mixed (sh1 ++ sh2) a)
+
+
+-- | Internal helper data family mirrorring 'Mixed' that consists of mutable
+-- vectors instead of 'XArray's.
+type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type
+data family MixedVecs s sh a
+
+newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VU.MVector s a)
+
+newtype instance MixedVecs s sh Int = MV_Int (VU.MVector s Int)
+newtype instance MixedVecs s sh Double = MV_Double (VU.MVector s Double)
+newtype instance MixedVecs s sh () = MV_Nil (VU.MVector s ()) -- no content, MVector optimises this
+-- etc.
+
+data instance MixedVecs s sh (a, b) = MV_Tup2 (MixedVecs s sh a) (MixedVecs s sh b)
+-- etc.
+
+data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest (IxX sh2) (MixedVecs s (sh1 ++ sh2) a)
+
+
+-- | Allowable scalar types in a mixed array, and by extension in a 'Ranked' or
+-- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive'
+-- a@; see the documentation for 'Primitive' for more details.
+class Elt a where
+ -- ====== PUBLIC METHODS ====== --
+
+ mshape :: KnownShapeX sh => Mixed sh a -> IxX sh
+ mindex :: Mixed sh a -> IxX sh -> a
+ mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a
+
+ mlift :: forall sh1 sh2. KnownShapeX sh2
+ => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ -> Mixed sh1 a -> Mixed sh2 a
+
+ -- ====== PRIVATE METHODS ====== --
+ -- Remember I said that this module needed better management of exports?
+
+ -- | Create an empty array. The given shape must have size zero; this may or may not be checked.
+ memptyArray :: IxX sh -> Mixed sh a
+
+ -- | Return the size of the individual (SoA) arrays in this value. If @a@
+ -- does not contain tuples, this coincides with the total number of scalars
+ -- in the given value; if @a@ contains tuples, then it is some multiple of
+ -- this number of scalars.
+ mvecsNumElts :: a -> Int
+
+ -- | Create uninitialised vectors for this array type, given the shape of
+ -- this vector and an example for the contents. The shape must not have size
+ -- zero; an error may be thrown otherwise.
+ mvecsUnsafeNew :: IxX sh -> a -> ST s (MixedVecs s sh a)
+
+ -- | Given the shape of this array, an index and a value, write the value at
+ -- that index in the vectors.
+ mvecsWrite :: IxX sh -> IxX sh -> a -> MixedVecs s sh a -> ST s ()
+
+ -- | Given the shape of this array, an index and a value, write the value at
+ -- that index in the vectors.
+ mvecsWritePartial :: KnownShapeX sh' => IxX (sh ++ sh') -> IxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
+
+ -- | Given the shape of this array, finalise the vectors into 'XArray's.
+ mvecsFreeze :: IxX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
+
+
+-- Arrays of scalars are basically just arrays of scalars.
+instance VU.Unbox a => Elt (Primitive a) where
+ mshape (M_Primitive a) = X.shape a
+ mindex (M_Primitive a) i = Primitive (X.index a i)
+ mindexPartial (M_Primitive a) i = M_Primitive (X.indexPartial a i)
+
+ mlift :: forall sh1 sh2.
+ (Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a)
+ -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a)
+ mlift f (M_Primitive a)
+ | Refl <- X.lemAppNil @sh1
+ , Refl <- X.lemAppNil @sh2
+ = M_Primitive (f Proxy a)
+
+ memptyArray sh = M_Primitive (X.generate sh (error "memptyArray Int: shape was not empty"))
+ mvecsNumElts _ = 1
+ mvecsUnsafeNew sh _ = MV_Primitive <$> VUM.unsafeNew (X.shapeSize sh)
+ mvecsWrite sh i (Primitive x) (MV_Primitive v) = VUM.write v (X.toLinearIdx sh i) x
+
+ -- TODO: this use of toVector is suboptimal
+ mvecsWritePartial
+ :: forall sh' sh s. (KnownShapeX sh', VU.Unbox a)
+ => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
+ mvecsWritePartial sh i (M_Primitive arr) (MV_Primitive v) = do
+ let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIdx' (X.shape arr)))
+ VU.copy (VUM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr)
+
+ mvecsFreeze sh (MV_Primitive v) = M_Primitive . X.fromVector sh <$> VU.freeze v
+
+-- What a blessing that orthotope's Array has "representational" role on the value type!
+deriving via Primitive Int instance Elt Int
+deriving via Primitive Double instance Elt Double
+deriving via Primitive () instance Elt ()
+
+-- Arrays of pairs are pairs of arrays.
+instance (Elt a, Elt b) => Elt (a, b) where
+ mshape (M_Tup2 a _) = mshape a
+ mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
+ mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
+ mlift f (M_Tup2 a b) = M_Tup2 (mlift f a) (mlift f b)
+
+ memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh)
+ mvecsNumElts (x, y) = mvecsNumElts x * mvecsNumElts y
+ mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
+ mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
+ mvecsWrite sh i x a
+ mvecsWrite sh i y b
+ mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do
+ mvecsWritePartial sh i x a
+ mvecsWritePartial sh i y b
+ mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b
+
+-- Arrays of arrays are just arrays, but with more dimensions.
+instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where
+ mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh
+ mshape (M_Nest arr)
+ | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh')
+ = ixAppPrefix (knownShapeX @sh) (mshape arr)
+ where
+ ixAppPrefix :: StaticShapeX sh1 -> IxX (sh1 ++ sh') -> IxX sh1
+ ixAppPrefix SZX _ = IZX
+ ixAppPrefix (_ :$@ ssh) (i ::@ idx) = i ::@ ixAppPrefix ssh idx
+ ixAppPrefix (_ :$? ssh) (i ::? idx) = i ::? ixAppPrefix ssh idx
+
+ mindex (M_Nest arr) i = mindexPartial arr i
+
+ mindexPartial :: forall sh1 sh2.
+ Mixed (sh1 ++ sh2) (Mixed sh' a) -> IxX sh1 -> Mixed sh2 (Mixed sh' a)
+ mindexPartial (M_Nest arr) i
+ | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
+ = M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
+
+ mlift :: forall sh1 sh2. KnownShapeX sh2
+ => (forall sh3 b. (KnownShapeX sh3, VU.Unbox b) => Proxy sh3 -> XArray (sh1 ++ sh3) b -> XArray (sh2 ++ sh3) b)
+ -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)
+ mlift f (M_Nest arr)
+ | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
+ = M_Nest (mlift f' arr)
+ where
+ f' :: forall sh3 b. (KnownShapeX sh3, VU.Unbox b) => Proxy sh3 -> XArray ((sh1 ++ sh') ++ sh3) b -> XArray ((sh2 ++ sh') ++ sh3) b
+ f' _
+ | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @sh3)
+ , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @sh3)
+ , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @sh3))
+ = f (Proxy @(sh' ++ sh3))
+
+ memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIdx (knownShapeX @sh'))))
+
+ mvecsNumElts arr =
+ let n = X.shapeSize (mshape arr)
+ in if n == 0 then 0 else n * mvecsNumElts (mindex arr (X.zeroIdx (knownShapeX @sh')))
+
+ mvecsUnsafeNew sh example
+ | X.shapeSize sh' == 0 = error "mvecsUnsafeNew: empty example"
+ | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example))
+ (mindex example (X.zeroIdx (knownShapeX @sh')))
+ where
+ sh' = mshape example
+
+ mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.ixAppend sh sh') idx val vecs
+
+ mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2
+ => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Mixed sh' a)
+ -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
+ -> ST s ()
+ mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs)
+ | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
+ , Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
+ = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.ixAppend sh12 sh') idx arr vecs
+
+ mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs
+
+
+-- Public method. Turns out this doesn't have to be in the type class!
+-- | Create an array given a size and a function that computes the element at a
+-- given index.
+mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IxX sh -> (IxX sh -> a) -> Mixed sh a
+mgenerate sh f
+ -- TODO: Do we need this checkBounds check elsewhere as well?
+ | not (checkBounds sh (knownShapeX @sh)) =
+ error $ "mgenerate: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh)
+ -- We need to be very careful here to ensure that neither 'sh' nor
+ -- 'firstelem' that we pass to 'mvecsUnsafeNew' are empty.
+ | X.shapeSize sh == 0 = memptyArray sh
+ | otherwise =
+ let firstidx = X.zeroIdx' sh
+ firstelem = f (X.zeroIdx' sh)
+ in if mvecsNumElts firstelem == 0
+ then memptyArray sh
+ else runST $ do
+ vecs <- mvecsUnsafeNew sh firstelem
+ mvecsWrite sh firstidx firstelem vecs
+ -- TODO: This is likely fine if @a@ is big, but if @a@ is a
+ -- scalar this feels inefficient. Should improve this.
+ forM_ (tail (X.enumShape sh)) $ \idx ->
+ mvecsWrite sh idx (f idx) vecs
+ mvecsFreeze sh vecs
+ where
+ checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool
+ checkBounds IZX SZX = True
+ checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (unSNat n') && checkBounds sh' ssh'
+ checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh'
+
+
+-- | A rank-typed array: the number of dimensions of the array (its /rank/) is
+-- represented on the type level as a 'Nat'.
+--
+-- Valid elements of a ranked arrays are described by the 'Elt' type class.
+-- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are
+-- supported (and are represented as a single, flattened, struct-of-arrays
+-- array internally).
+--
+-- Note that this 'Nat' is not a "GHC.TypeLits" natural, because we want a
+-- type-level natural that supports induction.
+--
+-- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's.
+type Ranked :: Nat -> Type -> Type
+newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
+
+-- | A shape-typed array: the full shape of the array (the sizes of its
+-- dimensions) is represented on the type level as a list of 'Nat's.
+--
+-- Like for 'Ranked', the valid elements are described by the 'Elt' type class,
+-- and 'Shaped' itself is again an instance of 'Elt' as well.
+--
+-- 'Shaped' is a newtype around a 'Mixed' of 'Just's.
+type Shaped :: [Nat] -> Type -> Type
+newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
+
+-- just unwrap the newtype and defer to the general instance for nested arrays
+newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a))
+newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh' ) a))
+
+newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a))
+newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh' ) a))
+
+
+-- 'Ranked' and 'Shaped' can already be used at the top level of an array nest;
+-- these instances allow them to also be used as elements of arrays, thus
+-- making them first-class in the API.
+instance (KnownNat n, Elt a) => Elt (Ranked n a) where
+ mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr
+ mindex (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i)
+
+ mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IxX sh -> Mixed sh' (Ranked n a)
+ mindexPartial (M_Ranked arr) i
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $
+ mindexPartial arr i
+
+ mlift :: forall sh1 sh2. KnownShapeX sh2
+ => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a)
+ mlift f (M_Ranked arr)
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $
+ mlift f arr
+
+ memptyArray :: forall sh. IxX sh -> Mixed sh (Ranked n a)
+ memptyArray i
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $
+ memptyArray i
+
+ mvecsNumElts (Ranked arr)
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = mvecsNumElts arr
+
+ mvecsUnsafeNew idx (Ranked arr)
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = MV_Ranked <$> mvecsUnsafeNew idx arr
+
+ mvecsWrite :: forall sh s. IxX sh -> IxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
+ mvecsWrite sh idx (Ranked arr) vecs
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = mvecsWrite sh idx arr
+ (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
+ vecs)
+
+ mvecsWritePartial :: forall sh sh' s. KnownShapeX sh'
+ => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Ranked n a)
+ -> MixedVecs s (sh ++ sh') (Ranked n a)
+ -> ST s ()
+ mvecsWritePartial sh idx arr vecs
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = mvecsWritePartial sh idx
+ (coerce @(Mixed sh' (Ranked n a))
+ @(Mixed sh' (Mixed (Replicate n Nothing) a))
+ arr)
+ (coerce @(MixedVecs s (sh ++ sh') (Ranked n a))
+ @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a))
+ vecs)
+
+ mvecsFreeze :: forall sh s. IxX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
+ mvecsFreeze sh vecs
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
+ @(Mixed sh (Ranked n a))
+ <$> mvecsFreeze sh
+ (coerce @(MixedVecs s sh (Ranked n a))
+ @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
+ vecs)
+
+
+-- | The shape of a shape-typed array given as a list of 'SNat' values.
+data SShape sh where
+ ShNil :: SShape '[]
+ ShCons :: SNat n -> SShape sh -> SShape (n : sh)
+deriving instance Show (SShape sh)
+
+-- | A statically-known shape of a shape-typed array.
+class KnownShape sh where knownShape :: SShape sh
+instance KnownShape '[] where knownShape = ShNil
+instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = ShCons knownNat knownShape
+
+lemKnownMapJust :: forall sh. KnownShape sh => Proxy sh -> Dict KnownShapeX (MapJust sh)
+lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh))
+ where
+ go :: SShape sh' -> StaticShapeX (MapJust sh')
+ go ShNil = SZX
+ go (ShCons n sh) = n :$@ go sh
+
+lemMapJustPlusApp :: forall sh1 sh2. KnownShape sh1 => Proxy sh1 -> Proxy sh2
+ -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2
+lemMapJustPlusApp _ _ = go (knownShape @sh1)
+ where
+ go :: SShape sh1' -> MapJust (sh1' ++ sh2) :~: MapJust sh1' ++ MapJust sh2
+ go ShNil = Refl
+ go (ShCons _ sh) | Refl <- go sh = Refl
+
+instance (KnownShape sh, Elt a) => Elt (Shaped sh a) where
+ mshape (M_Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = mshape arr
+ mindex (M_Shaped arr) i | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (mindex arr i)
+
+ mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IxX sh1 -> Mixed sh2 (Shaped sh a)
+ mindexPartial (M_Shaped arr) i
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
+ mindexPartial arr i
+
+ mlift :: forall sh1 sh2. KnownShapeX sh2
+ => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a)
+ mlift f (M_Shaped arr)
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
+ mlift f arr
+
+ memptyArray :: forall sh'. IxX sh' -> Mixed sh' (Shaped sh a)
+ memptyArray i
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
+ memptyArray i
+
+ mvecsNumElts (Shaped arr)
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = mvecsNumElts arr
+
+ mvecsUnsafeNew idx (Shaped arr)
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = MV_Shaped <$> mvecsUnsafeNew idx arr
+
+ mvecsWrite :: forall sh' s. IxX sh' -> IxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
+ mvecsWrite sh idx (Shaped arr) vecs
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = mvecsWrite sh idx arr
+ (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
+ vecs)
+
+ mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2
+ => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Shaped sh a)
+ -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
+ -> ST s ()
+ mvecsWritePartial sh idx arr vecs
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = mvecsWritePartial sh idx
+ (coerce @(Mixed sh2 (Shaped sh a))
+ @(Mixed sh2 (Mixed (MapJust sh) a))
+ arr)
+ (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a))
+ @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))
+ vecs)
+
+ mvecsFreeze :: forall sh' s. IxX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
+ mvecsFreeze sh vecs
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = coerce @(Mixed sh' (Mixed (MapJust sh) a))
+ @(Mixed sh' (Shaped sh a))
+ <$> mvecsFreeze sh
+ (coerce @(MixedVecs s sh' (Shaped sh a))
+ @(MixedVecs s sh' (Mixed (MapJust sh) a))
+ vecs)
+
+
+-- Utility function to satisfy the type checker sometimes
+rewriteMixed :: sh1 :~: sh2 -> Mixed sh1 a -> Mixed sh2 a
+rewriteMixed Refl x = x
+
+
+-- ====== API OF RANKED ARRAYS ====== --
+
+-- | An index into a rank-typed array.
+type IxR :: Nat -> Type
+data IxR n where
+ IZR :: IxR Z
+ (:::) :: Int -> IxR n -> IxR (S n)
+
+ixCvtXR :: IxX sh -> IxR (X.Rank sh)
+ixCvtXR IZX = IZR
+ixCvtXR (n ::@ idx) = n ::: ixCvtXR idx
+ixCvtXR (n ::? idx) = n ::: ixCvtXR idx
+
+ixCvtRX :: IxR n -> IxX (Replicate n Nothing)
+ixCvtRX IZR = IZX
+ixCvtRX (n ::: idx) = n ::? ixCvtRX idx
+
+
+rshape :: forall n a. (KnownNat n, Elt a) => Ranked n a -> IxR n
+rshape (Ranked arr)
+ | Dict <- lemKnownReplicate (Proxy @n)
+ , Refl <- lemRankReplicate (Proxy @n)
+ = ixCvtXR (mshape arr)
+
+rindex :: Elt a => Ranked n a -> IxR n -> a
+rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)
+
+rindexPartial :: forall n m a. (KnownNat n, Elt a) => Ranked (n + m) a -> IxR n -> Ranked m a
+rindexPartial (Ranked arr) idx =
+ Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
+ (rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr)
+ (ixCvtRX idx))
+
+rgenerate :: forall n a. (KnownNat n, Elt a) => IxR n -> (IxR n -> a) -> Ranked n a
+rgenerate sh f
+ | Dict <- lemKnownReplicate (Proxy @n)
+ , Refl <- lemRankReplicate (Proxy @n)
+ = Ranked (mgenerate (ixCvtRX sh) (f . ixCvtXR))
+
+rlift :: forall n1 n2 a. (KnownNat n2, Elt a)
+ => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
+ -> Ranked n1 a -> Ranked n2 a
+rlift f (Ranked arr)
+ | Dict <- lemKnownReplicate (Proxy @n2)
+ = Ranked (mlift f arr)
+
+rsumOuter1 :: forall n a.
+ (VU.Unbox a, Num a, KnownNat n, forall sh. Coercible (Mixed sh a) (XArray sh a))
+ => Ranked (S n) a -> Ranked n a
+rsumOuter1 (Ranked arr)
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = Ranked
+ . coerce @(XArray (Replicate n Nothing) a) @(Mixed (Replicate n Nothing) a)
+ . X.sumOuter (() :$? SZX) (knownShapeX @(Replicate n Nothing))
+ . coerce @(Mixed (Replicate (S n) Nothing) a) @(XArray (Replicate (S n) Nothing) a)
+ $ arr
+
+
+-- ====== API OF SHAPED ARRAYS ====== --
+
+-- | An index into a shape-typed array.
+--
+-- For convenience, this contains regular 'Int's instead of bounded integers
+-- (traditionally called \"@Fin@\"). Note that because the shape of a
+-- shape-typed array is known statically, you can also retrieve the array shape
+-- from a 'KnownShape' dictionary.
+type IxS :: [Nat] -> Type
+data IxS sh where
+ IZS :: IxS '[]
+ (::$) :: Int -> IxS sh -> IxS (n : sh)
+
+cvtSShapeIxS :: SShape sh -> IxS sh
+cvtSShapeIxS ShNil = IZS
+cvtSShapeIxS (ShCons n sh) = fromIntegral (unSNat n) ::$ cvtSShapeIxS sh
+
+ixCvtXS :: SShape sh -> IxX (MapJust sh) -> IxS sh
+ixCvtXS ShNil IZX = IZS
+ixCvtXS (ShCons _ sh) (n ::@ idx) = n ::$ ixCvtXS sh idx
+
+ixCvtSX :: IxS sh -> IxX (MapJust sh)
+ixCvtSX IZS = IZX
+ixCvtSX (n ::$ sh) = n ::@ ixCvtSX sh
+
+
+sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> IxS sh
+sshape _ = cvtSShapeIxS (knownShape @sh)
+
+sindex :: Elt a => Shaped sh a -> IxS sh -> a
+sindex (Shaped arr) idx = mindex arr (ixCvtSX idx)
+
+sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, Elt a) => Shaped (sh1 ++ sh2) a -> IxS sh1 -> Shaped sh2 a
+sindexPartial (Shaped arr) idx =
+ Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2)
+ (rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr)
+ (ixCvtSX idx))
+
+sgenerate :: forall sh a. (KnownShape sh, Elt a) => IxS sh -> (IxS sh -> a) -> Shaped sh a
+sgenerate sh f
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = Shaped (mgenerate (ixCvtSX sh) (f . ixCvtXS (knownShape @sh)))
+
+slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a)
+ => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
+ -> Shaped sh1 a -> Shaped sh2 a
+slift f (Shaped arr)
+ | Dict <- lemKnownMapJust (Proxy @sh2)
+ = Shaped (mlift f arr)
+
+ssumOuter1 :: forall sh n a.
+ (VU.Unbox a, Num a, KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a))
+ => Shaped (n : sh) a -> Shaped sh a
+ssumOuter1 (Shaped arr)
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = Shaped
+ . coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) a)
+ . X.sumOuter (knownNat @n :$@ SZX) (knownShapeX @(MapJust sh))
+ . coerce @(Mixed (Just n : MapJust sh) a) @(XArray (Just n : MapJust sh) a)
+ $ arr