{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-|
TODO:

* This module needs better structure, with an Internal module and fewer
  public exports.

* We should be more consistent in whether functions take a 'StaticShapeX'
  argument or a 'KnownShapeX' constraint.
-}
module Data.Array.Nested.Internal where

import Control.Monad (forM_)
import Control.Monad.ST
import Data.Coerce (coerce, Coercible)
import Data.Kind
import Data.Proxy
import Data.Type.Equality
import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Unboxed.Mutable as VUM

import Data.Array.Mixed (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++))
import qualified Data.Array.Mixed as X
import Data.Nat


-- | @'Replicate' n a@ is the type-level list consisting of @n@ copies of @a@.
type family Replicate n a where
  Replicate Z x = '[]
  Replicate (S k) x = x : Replicate k x

-- | Wrap every element of a type-level list in 'Just'.
type family MapJust l where
  MapJust '[] = '[]
  MapJust (y : ys) = Just y : MapJust ys

-- | A list of @n@ dimensions, none of which has a statically known size, has a
-- statically known shape description.
lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing)
lemKnownReplicate _ = X.lemKnownShapeX (unknownDims (knownNat @n))
  where
    -- The static shape of @k@ dimensions whose sizes are all unknown.
    unknownDims :: SNat k -> StaticShapeX (Replicate k Nothing)
    unknownDims SZ = SZX
    unknownDims (SS sk) = () :$? unknownDims sk

-- | The rank of @n@ unknown-size dimensions is @n@.
lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n
lemRankReplicate _ = rankOf (knownNat @n)
  where
    -- Proof by induction on the singleton natural.
    rankOf :: SNat k -> X.Rank (Replicate k (Nothing @Nat)) :~: k
    rankOf SZ = Refl
    rankOf (SS sk) = case rankOf sk of Refl -> Refl

-- | 'Replicate' distributes over addition: replicating @n + m@ times equals
-- appending a replication of length @n@ and one of length @m@.
lemReplicatePlusApp :: forall n m a. KnownNat n => Proxy n -> Proxy m -> Proxy a
                    -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a
lemReplicatePlusApp _ _ _ = byInduction (knownNat @n)
  where
    -- Induction on the first summand; @m@ and @a@ stay fixed.
    byInduction :: SNat k -> Replicate (k + m) a :~: Replicate k a ++ Replicate m a
    byInduction SZ = Refl
    byInduction (SS sk) = case byInduction sk of Refl -> Refl

-- | Wrapper type used as a tag to attach instances to. The instances on arrays
-- of @'Primitive' a@ are more polymorphic than the direct instances for arrays
-- of scalars; this means that if @orthotope@ supports an element type @T@ that
-- this library does not (directly), it may just work if you use an array of
-- @'Primitive' T@ instead.
newtype Primitive a = Primitive a

-- | Mixed arrays: some dimensions are size-typed, some are not. Distributes
-- over product-typed elements using a data family so that the full array is
-- always in struct-of-arrays format.
--
-- Built on top of 'XArray' which is built on top of @orthotope@, meaning that
-- dimension permutations (e.g. 'transpose') are typically free.
--
-- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type
-- class.
type Mixed :: [Maybe Nat] -> Type -> Type
data family Mixed sh a

newtype instance Mixed sh (Primitive a) = M_Primitive (XArray sh a)
  deriving (Show)

newtype instance Mixed sh Int = M_Int (XArray sh Int)
  deriving (Show)
newtype instance Mixed sh Double = M_Double (XArray sh Double)
  deriving (Show)
-- no content, orthotope optimises this (via Vector)
newtype instance Mixed sh () = M_Nil (XArray sh ())
  deriving (Show)
-- etc.

-- A pair-valued array is stored as a pair of arrays (struct-of-arrays).
data instance Mixed sh (a, b) = M_Tup2 (Mixed sh a) (Mixed sh b)
deriving instance (Show (Mixed sh a), Show (Mixed sh b)) => Show (Mixed sh (a, b))
-- etc.

-- An array of arrays is flattened into one array over the concatenated shape.
newtype instance Mixed sh1 (Mixed sh2 a) = M_Nest (Mixed (sh1 ++ sh2) a)
deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a))


-- | Internal helper data family mirroring 'Mixed', consisting of mutable
-- vectors instead of 'XArray's.
type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type
data family MixedVecs s sh a

newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VU.MVector s a)

newtype instance MixedVecs s sh Int = MV_Int (VU.MVector s Int)
newtype instance MixedVecs s sh Double = MV_Double (VU.MVector s Double)
newtype instance MixedVecs s sh () = MV_Nil (VU.MVector s ())  -- no content, MVector optimises this
-- etc.

data instance MixedVecs s sh (a, b) = MV_Tup2 (MixedVecs s sh a) (MixedVecs s sh b)
-- etc.

-- The 'IxX' field stores the shape of the nested elements; 'mvecsWrite' and
-- 'mvecsFreeze' append it to the outer shape.
data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest (IxX sh2) (MixedVecs s (sh1 ++ sh2) a)


-- | Allowable scalar types in a mixed array, and by extension in a 'Ranked' or
-- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive'
-- a@; see the documentation for 'Primitive' for more details.
class Elt a where
  -- ====== PUBLIC METHODS ====== --

  -- | The shape of the array.
  mshape :: KnownShapeX sh => Mixed sh a -> IxX sh

  -- | The element at a full index.
  mindex :: Mixed sh a -> IxX sh -> a

  -- | Index only the outer dimensions @sh@, returning the sub-array spanned by
  -- the remaining inner dimensions @sh'@.
  mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a

  -- | Apply an 'XArray' operation that changes the outer shape from @sh1@ to
  -- @sh2@. The operation must work for every inner shape @sh'@: for nested
  -- element types it is applied to the flattened array (which has extra inner
  -- dimensions), and for tuples it is applied to each component array.
  mlift :: forall sh1 sh2. KnownShapeX sh2
        => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
        -> Mixed sh1 a -> Mixed sh2 a

  -- ====== PRIVATE METHODS ====== --
  -- Remember I said that this module needed better management of exports?

  -- | Create an empty array. The given shape must have size zero; this may or
  -- may not be checked.
  memptyArray :: IxX sh -> Mixed sh a

  -- | Return the size of the individual (SoA) arrays in this value. If @a@
  -- does not contain tuples, this coincides with the total number of scalars
  -- in the given value. For tuples the instance below multiplies the counts of
  -- the components, so the result is zero exactly when some component is
  -- empty; 'mgenerate' relies on this zero test.
  mvecsNumElts :: a -> Int

  -- | Create uninitialised vectors for this array type, given the shape of
  -- this vector and an example for the contents. The shape must not have size
  -- zero; an error may be thrown otherwise.
  mvecsUnsafeNew :: IxX sh -> a -> ST s (MixedVecs s sh a)

  -- | Given the shape of this array, an index and a value, write the value at
  -- that index in the vectors.
  mvecsWrite :: IxX sh -> IxX sh -> a -> MixedVecs s sh a -> ST s ()

  -- | Given the full shape @sh ++ sh'@ of this array, an index into its outer
  -- dimensions @sh@, and a sub-array covering the inner dimensions @sh'@,
  -- write all elements of that sub-array into the vectors at that position.
  mvecsWritePartial :: KnownShapeX sh' => IxX (sh ++ sh') -> IxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()

  -- | Given the shape of this array, finalise the vectors into 'XArray's.
  mvecsFreeze :: IxX sh -> MixedVecs s sh a -> ST s (Mixed sh a)


-- Arrays of scalars are basically just arrays of scalars.
instance VU.Unbox a => Elt (Primitive a) where
  mshape (M_Primitive a) = X.shape a
  mindex (M_Primitive a) i = Primitive (X.index a i)
  mindexPartial (M_Primitive a) i = M_Primitive (X.indexPartial a i)

  -- A scalar element contributes no inner dimensions, so the lifted function
  -- is only ever needed at @sh' = '[]@.
  mlift :: forall sh1 sh2.
           (Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a)
        -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a)
  mlift f (M_Primitive a)
    | Refl <- X.lemAppNil @sh1
    , Refl <- X.lemAppNil @sh2
    = M_Primitive (f Proxy a)

  -- The shape must have size zero, so the generator must never be called. The
  -- message names this instance, which also serves Int, Double and ().
  memptyArray sh = M_Primitive (X.generate sh (error "memptyArray Primitive: shape was not empty"))

  mvecsNumElts _ = 1

  mvecsUnsafeNew sh _ = MV_Primitive <$> VUM.unsafeNew (X.shapeSize sh)

  mvecsWrite sh i (Primitive x) (MV_Primitive v) = VUM.write v (X.toLinearIdx sh i) x

  -- Copy the sub-array into the contiguous slice that starts at the linear
  -- index of @i ++ 0...0@.
  -- TODO: this use of toVector is suboptimal
  mvecsWritePartial :: forall sh' sh s. (KnownShapeX sh', VU.Unbox a)
                    => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
  mvecsWritePartial sh i (M_Primitive arr) (MV_Primitive v) = do
    let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIdx' (X.shape arr)))
    VU.copy (VUM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr)

  mvecsFreeze sh (MV_Primitive v) = M_Primitive . X.fromVector sh <$> VU.freeze v

deriving via Primitive Int instance Elt Int
deriving via Primitive Double instance Elt Double
deriving via Primitive () instance Elt ()

-- Arrays of pairs are pairs of arrays.
instance (Elt a, Elt b) => Elt (a, b) where
  mshape (M_Tup2 a _) = mshape a
  mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
  mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
  mlift f (M_Tup2 a b) = M_Tup2 (mlift f a) (mlift f b)

  memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh)

  -- A product, so that the result is zero if either component is empty.
  mvecsNumElts (x, y) = mvecsNumElts x * mvecsNumElts y

  mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y

  mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
    mvecsWrite sh i x a
    mvecsWrite sh i y b

  mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do
    mvecsWritePartial sh i x a
    mvecsWritePartial sh i y b

  mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b

-- Arrays of arrays are just arrays, but with more dimensions.
instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where
  -- The flattened array has shape @sh ++ sh'@; keep only the outer @sh@ part.
  mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh
  mshape (M_Nest arr)
    | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh')
    = ixAppPrefix (knownShapeX @sh) (mshape arr)
    where
      ixAppPrefix :: StaticShapeX sh1 -> IxX (sh1 ++ sh') -> IxX sh1
      ixAppPrefix SZX _ = IZX
      ixAppPrefix (_ :$@ ssh) (i ::@ idx) = i ::@ ixAppPrefix ssh idx
      ixAppPrefix (_ :$? ssh) (i ::? idx) = i ::? ixAppPrefix ssh idx

  -- A full outer index selects an inner array: a partial index into the
  -- flattened array.
  mindex (M_Nest arr) i = mindexPartial arr i

  mindexPartial :: forall sh1 sh2.
                   Mixed (sh1 ++ sh2) (Mixed sh' a) -> IxX sh1 -> Mixed sh2 (Mixed sh' a)
  mindexPartial (M_Nest arr) i
    | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
    = M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)

  mlift :: forall sh1 sh2. KnownShapeX sh2
        => (forall sh3 b. (KnownShapeX sh3, VU.Unbox b) => Proxy sh3 -> XArray (sh1 ++ sh3) b -> XArray (sh2 ++ sh3) b)
        -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)
  mlift f (M_Nest arr)
    | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
    = M_Nest (mlift f' arr)
    where
      -- Re-associate so that the nested shape @sh'@ is treated as part of the
      -- inner shape that @f@ must be polymorphic over.
      f' :: forall sh3 b. (KnownShapeX sh3, VU.Unbox b) => Proxy sh3 -> XArray ((sh1 ++ sh') ++ sh3) b -> XArray ((sh2 ++ sh') ++ sh3) b
      f' _
        | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @sh3)
        , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @sh3)
        , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @sh3))
        = f (Proxy @(sh' ++ sh3))

  memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIdx (knownShapeX @sh'))))

  -- Number of inner positions times the count of one inner element; zero if
  -- the inner array is empty.
  mvecsNumElts arr =
    let n = X.shapeSize (mshape arr)
    in if n == 0 then 0 else n * mvecsNumElts (mindex arr (X.zeroIdx (knownShapeX @sh')))

  -- The vectors are sized for the flattened shape @sh ++ sh'@, where @sh'@ is
  -- taken from the example element.
  mvecsUnsafeNew sh example
    | X.shapeSize sh' == 0 = error "mvecsUnsafeNew: empty example"
    | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh sh')
                                                 (mindex example (X.zeroIdx (knownShapeX @sh')))
    where
      sh' = mshape example

  mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.ixAppend sh sh') idx val vecs

  mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2
                    => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Mixed sh' a)
                    -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
                    -> ST s ()
  mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs)
    | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
    , Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
    = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.ixAppend sh12 sh') idx arr vecs

  mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs


-- Public method. Turns out this doesn't have to be in the type class!

-- | Create an array given a size and a function that computes the element at a
-- given index.
--
-- The given shape must agree with the statically known sizes in @sh@;
-- otherwise an error is thrown. If the shape has size zero, or if the element
-- at the zero index is empty according to 'mvecsNumElts', an empty array is
-- returned.
mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IxX sh -> (IxX sh -> a) -> Mixed sh a
mgenerate sh f
  -- TODO: Do we need this checkBounds check elsewhere as well?
  | not (checkBounds sh (knownShapeX @sh)) =
      error $ "mgenerate: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh)
  -- We need to be very careful here to ensure that neither 'sh' nor
  -- 'firstelem' that we pass to 'mvecsUnsafeNew' are empty.
  | X.shapeSize sh == 0 = memptyArray sh
  | mvecsNumElts firstelem == 0 = memptyArray sh
  | otherwise = runST $ do
      vecs <- mvecsUnsafeNew sh firstelem
      -- 'firstelem' was already computed to size the vectors; reuse it.
      mvecsWrite sh firstidx firstelem vecs
      -- TODO: This is likely fine if @a@ is big, but if @a@ is a
      -- scalar this feels inefficient. Should improve this.
      -- 'X.enumShape' is assumed to list 'firstidx' first; it was written above.
      forM_ (drop 1 (X.enumShape sh)) $ \idx ->
        mvecsWrite sh idx (f idx) vecs
      mvecsFreeze sh vecs
  where
    -- Lazily bound: only forced once the shape is known to be non-empty.
    firstidx :: IxX sh
    firstidx = X.zeroIdx' sh

    firstelem :: a
    firstelem = f firstidx

    -- Every statically-sized dimension must match the runtime size.
    checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool
    checkBounds IZX SZX = True
    checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (unSNat n') && checkBounds sh' ssh'
    checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh'

-- | Permute the outer dimensions @sh@ of the array according to @perm@ (the
-- permutation is passed to 'X.transpose'); the element type is untouched.
mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a
mtranspose perm =
  mlift (\(Proxy @sh') -> X.rerankTop (knownShapeX @sh) (knownShapeX @sh) (knownShapeX @sh') (X.transpose perm))


-- | A rank-typed array: the number of dimensions of the array (its /rank/) is
-- represented on the type level as a 'Nat'.
--
-- Valid elements of a ranked arrays are described by the 'Elt' type class.
-- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are
-- supported (and are represented as a single, flattened, struct-of-arrays
-- array internally).
--
-- Note that this 'Nat' is not a "GHC.TypeLits" natural, because we want a
-- type-level natural that supports induction.
--
-- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's.
type Ranked :: Nat -> Type -> Type
newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a)

-- | A shape-typed array: the full shape of the array (the sizes of its
-- dimensions) is represented on the type level as a list of 'Nat's.
--
-- Like for 'Ranked', the valid elements are described by the 'Elt' type class,
-- and 'Shaped' itself is again an instance of 'Elt' as well.
--
-- 'Shaped' is a newtype around a 'Mixed' of 'Just's.
type Shaped :: [Nat] -> Type -> Type
newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a)

-- These representations simply unwrap the newtype, so that all work is done by
-- the general instances for nested 'Mixed' arrays.
newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a))
deriving instance Show (Mixed sh (Mixed (Replicate n Nothing) a)) => Show (Mixed sh (Ranked n a))
newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a))
deriving instance Show (Mixed sh (Mixed (MapJust sh') a)) => Show (Mixed sh (Shaped sh' a))

newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a))
newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a))

-- 'Ranked' and 'Shaped' can already be used at the top level of an array nest;
-- the instances below allow them to also be used as array elements, which
-- makes them first-class in the API. Every method obtains the 'KnownShapeX'
-- dictionary for the underlying 'Mixed' type and then delegates, coercing
-- across the newtype where needed.
instance (KnownNat n, Elt a) => Elt (Ranked n a) where
  mshape (M_Ranked arr) =
    case lemKnownReplicate (Proxy @n) of
      Dict -> mshape arr

  mindex (M_Ranked arr) i =
    case lemKnownReplicate (Proxy @n) of
      Dict -> Ranked (mindex arr i)

  mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IxX sh -> Mixed sh' (Ranked n a)
  mindexPartial (M_Ranked arr) i =
    case lemKnownReplicate (Proxy @n) of
      Dict ->
        let inner = mindexPartial arr i
        in coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) inner

  mlift :: forall sh1 sh2. KnownShapeX sh2
        => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
        -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a)
  mlift f (M_Ranked arr) =
    case lemKnownReplicate (Proxy @n) of
      Dict ->
        let lifted = mlift f arr
        in coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) lifted

  memptyArray :: forall sh. IxX sh -> Mixed sh (Ranked n a)
  memptyArray i =
    case lemKnownReplicate (Proxy @n) of
      Dict ->
        let empty = memptyArray i
        in coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) empty

  mvecsNumElts (Ranked arr) =
    case lemKnownReplicate (Proxy @n) of
      Dict -> mvecsNumElts arr

  mvecsUnsafeNew idx (Ranked arr) =
    case lemKnownReplicate (Proxy @n) of
      Dict -> MV_Ranked <$> mvecsUnsafeNew idx arr

  mvecsWrite :: forall sh s. IxX sh -> IxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
  mvecsWrite sh idx (Ranked arr) vecs =
    case lemKnownReplicate (Proxy @n) of
      Dict ->
        let vecs' = coerce @(MixedVecs s sh (Ranked n a))
                           @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
                           vecs
        in mvecsWrite sh idx arr vecs'

  mvecsWritePartial :: forall sh sh' s. KnownShapeX sh'
                    => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Ranked n a)
                    -> MixedVecs s (sh ++ sh') (Ranked n a)
                    -> ST s ()
  mvecsWritePartial sh idx arr vecs =
    case lemKnownReplicate (Proxy @n) of
      Dict ->
        let arr' = coerce @(Mixed sh' (Ranked n a))
                          @(Mixed sh' (Mixed (Replicate n Nothing) a))
                          arr
            vecs' = coerce @(MixedVecs s (sh ++ sh') (Ranked n a))
                           @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a))
                           vecs
        in mvecsWritePartial sh idx arr' vecs'

  mvecsFreeze :: forall sh s. IxX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
  mvecsFreeze sh vecs =
    case lemKnownReplicate (Proxy @n) of
      Dict ->
        let vecs' = coerce @(MixedVecs s sh (Ranked n a))
                           @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
                           vecs
        in coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a))
             <$> mvecsFreeze sh vecs'

-- | The shape of a shape-typed array given as a list of 'SNat' values.
data SShape sh where
  ShNil :: SShape '[]
  ShCons :: SNat n -> SShape sh -> SShape (n : sh)
deriving instance Show (SShape sh)
infixr 5 `ShCons`

-- | A statically-known shape of a shape-typed array.
class KnownShape sh where
  knownShape :: SShape sh

instance KnownShape '[] where
  knownShape = ShNil

instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where
  knownShape = knownNat `ShCons` knownShape

-- | Turn a shape singleton back into a 'KnownShape' dictionary.
sshapeKnown :: SShape sh -> Dict KnownShape sh
sshapeKnown shp = case shp of
  ShNil -> Dict
  ShCons sn rest ->
    case snatKnown sn of
      Dict -> case sshapeKnown rest of
        Dict -> Dict

-- | A fully static shape has a statically known 'MapJust' shape description.
lemKnownMapJust :: forall sh. KnownShape sh => Proxy sh -> Dict KnownShapeX (MapJust sh)
lemKnownMapJust _ = X.lemKnownShapeX (staticOf (knownShape @sh))
  where
    -- Every dimension becomes a statically sized one.
    staticOf :: SShape sh' -> StaticShapeX (MapJust sh')
    staticOf ShNil = SZX
    staticOf (ShCons sn rest) = sn :$@ staticOf rest

-- | 'MapJust' distributes over list append.
lemMapJustPlusApp :: forall sh1 sh2. KnownShape sh1 => Proxy sh1 -> Proxy sh2
                  -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2
lemMapJustPlusApp _ _ = induct (knownShape @sh1)
  where
    -- Induction on the first list; @sh2@ stays fixed.
    induct :: SShape sh1' -> MapJust (sh1' ++ sh2) :~: MapJust sh1' ++ MapJust sh2
    induct ShNil = Refl
    induct (ShCons _ rest) = case induct rest of Refl -> Refl

instance (KnownShape sh, Elt a) => Elt (Shaped sh a) where
  mshape (M_Shaped arr) =
    case lemKnownMapJust (Proxy @sh) of
      Dict -> mshape arr

  mindex (M_Shaped arr) i =
    case lemKnownMapJust (Proxy @sh) of
      Dict -> Shaped (mindex arr i)

  mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IxX sh1 -> Mixed sh2 (Shaped sh a)
  mindexPartial (M_Shaped arr) i =
    case lemKnownMapJust (Proxy @sh) of
      Dict ->
        let inner = mindexPartial arr i
        in coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) inner

  mlift :: forall sh1 sh2. KnownShapeX sh2
        => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
        -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a)
  mlift f (M_Shaped arr) =
    case lemKnownMapJust (Proxy @sh) of
      Dict ->
        let lifted = mlift f arr
        in coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) lifted

  memptyArray :: forall sh'. IxX sh' -> Mixed sh' (Shaped sh a)
  memptyArray i =
    case lemKnownMapJust (Proxy @sh) of
      Dict ->
        let empty = memptyArray i
        in coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) empty

  mvecsNumElts (Shaped arr) =
    case lemKnownMapJust (Proxy @sh) of
      Dict -> mvecsNumElts arr

  mvecsUnsafeNew idx (Shaped arr) =
    case lemKnownMapJust (Proxy @sh) of
      Dict -> MV_Shaped <$> mvecsUnsafeNew idx arr

  mvecsWrite :: forall sh' s. IxX sh' -> IxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
  mvecsWrite sh idx (Shaped arr) vecs =
    case lemKnownMapJust (Proxy @sh) of
      Dict ->
        let vecs' = coerce @(MixedVecs s sh' (Shaped sh a))
                           @(MixedVecs s sh' (Mixed (MapJust sh) a))
                           vecs
        in mvecsWrite sh idx arr vecs'

  mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2
                    => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Shaped sh a)
                    -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
                    -> ST s ()
  mvecsWritePartial sh idx arr vecs =
    case lemKnownMapJust (Proxy @sh) of
      Dict ->
        let arr' = coerce @(Mixed sh2 (Shaped sh a))
                          @(Mixed sh2 (Mixed (MapJust sh) a))
                          arr
            vecs' = coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a))
                           @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))
                           vecs
        in mvecsWritePartial sh idx arr' vecs'

  mvecsFreeze :: forall sh' s. IxX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
  mvecsFreeze sh vecs =
    case lemKnownMapJust (Proxy @sh) of
      Dict ->
        let vecs' = coerce @(MixedVecs s sh' (Shaped sh a))
                           @(MixedVecs s sh' (Mixed (MapJust sh) a))
                           vecs
        in coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a))
             <$> mvecsFreeze sh vecs'


-- | Cast a 'Mixed' array along a proof that two shape types are equal. Useful
-- where the type checker cannot see the equality by itself.
rewriteMixed :: sh1 :~: sh2 -> Mixed sh1 a -> Mixed sh2 a
rewriteMixed proof x = case proof of Refl -> x


-- ====== API OF RANKED ARRAYS ====== --

-- | An index into a rank-typed array.
type IxR :: Nat -> Type
data IxR n where
  IZR :: IxR Z
  (:::) :: Int -> IxR n -> IxR (S n)
infixr 5 :::

-- | Forget which dimensions of a mixed index are statically sized.
ixCvtXR :: IxX sh -> IxR (X.Rank sh)
ixCvtXR IZX = IZR
ixCvtXR (i ::@ rest) = i ::: ixCvtXR rest
ixCvtXR (i ::? rest) = i ::: ixCvtXR rest

-- | View a ranked index as a mixed index with only unknown-size dimensions.
ixCvtRX :: IxR n -> IxX (Replicate n Nothing)
ixCvtRX IZR = IZX
ixCvtRX (i ::: rest) = i ::? ixCvtRX rest

-- | The shape of a rank-typed array.
rshape :: forall n a. (KnownNat n, Elt a) => Ranked n a -> IxR n
rshape (Ranked arr) =
  case (lemKnownReplicate (Proxy @n), lemRankReplicate (Proxy @n)) of
    (Dict, Refl) -> ixCvtXR (mshape arr)

-- | The element at a full index.
rindex :: Elt a => Ranked n a -> IxR n -> a
rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)

-- | Index the outer @n@ dimensions, returning the rank-@m@ sub-array.
rindexPartial :: forall n m a. (KnownNat n, Elt a) => Ranked (n + m) a -> IxR n -> Ranked m a
rindexPartial (Ranked arr) idx =
  Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) split (ixCvtRX idx))
  where
    -- Split the @n + m@ unknown dimensions into an outer @n@ and inner @m@.
    split = rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr

-- | Create a rank-typed array from its shape and an element function.
rgenerate :: forall n a. (KnownNat n, Elt a) => IxR n -> (IxR n -> a) -> Ranked n a
rgenerate sh f =
  case (lemKnownReplicate (Proxy @n), lemRankReplicate (Proxy @n)) of
    (Dict, Refl) -> Ranked (mgenerate (ixCvtRX sh) (\i -> f (ixCvtXR i)))

-- | Ranked counterpart of 'mlift'.
rlift :: forall n1 n2 a. (KnownNat n2, Elt a)
      => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
      -> Ranked n1 a -> Ranked n2 a
rlift f (Ranked arr) =
  case lemKnownReplicate (Proxy @n2) of
    Dict -> Ranked (mlift f arr)

-- | Sum over the outermost dimension of an array of scalars.
rsumOuter1 :: forall n a. (VU.Unbox a, Num a, KnownNat n, forall sh. Coercible (Mixed sh a) (XArray sh a))
           => Ranked (S n) a -> Ranked n a
rsumOuter1 (Ranked arr) =
  case lemKnownReplicate (Proxy @n) of
    Dict ->
      let input = coerce @(Mixed (Replicate (S n) Nothing) a) @(XArray (Replicate (S n) Nothing) a) arr
          summed = X.sumOuter (() :$? SZX) (knownShapeX @(Replicate n Nothing)) input
      in Ranked (coerce @(XArray (Replicate n Nothing) a) @(Mixed (Replicate n Nothing) a) summed)

-- | Permute the dimensions of a rank-typed array.
rtranspose :: forall n a. (KnownNat n, Elt a) => [Int] -> Ranked n a -> Ranked n a
rtranspose perm (Ranked arr) =
  case lemKnownReplicate (Proxy @n) of
    Dict -> Ranked (mtranspose perm arr)


-- ====== API OF SHAPED ARRAYS ====== --

-- | An index into a shape-typed array.
--
-- For convenience, this contains regular 'Int's instead of bounded integers
-- (traditionally called \"@Fin@\"). Note that because the shape of a
-- shape-typed array is known statically, you can also retrieve the array shape
-- from a 'KnownShape' dictionary.
type IxS :: [Nat] -> Type
data IxS sh where
  IZS :: IxS '[]
  (::$) :: Int -> IxS sh -> IxS (n : sh)
infixr 5 ::$

-- | Read off the dimension sizes of a shape singleton as an index value.
cvtSShapeIxS :: SShape sh -> IxS sh
cvtSShapeIxS ShNil = IZS
cvtSShapeIxS (ShCons sn rest) = fromIntegral (unSNat sn) ::$ cvtSShapeIxS rest

-- | Convert a mixed index over statically sized dimensions to a shaped index.
ixCvtXS :: SShape sh -> IxX (MapJust sh) -> IxS sh
ixCvtXS ShNil IZX = IZS
ixCvtXS (ShCons _ shRest) (i ::@ rest) = i ::$ ixCvtXS shRest rest

-- | Convert a shaped index to a mixed index over statically sized dimensions.
ixCvtSX :: IxS sh -> IxX (MapJust sh)
ixCvtSX IZS = IZX
ixCvtSX (i ::$ rest) = i ::@ ixCvtSX rest

-- | The shape of a shape-typed array, taken from the type.
sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> IxS sh
sshape _ = cvtSShapeIxS (knownShape @sh)

-- | The element at a full index.
sindex :: Elt a => Shaped sh a -> IxS sh -> a
sindex (Shaped arr) idx = mindex arr (ixCvtSX idx)

-- | Index the outer dimensions @sh1@, returning the sub-array of shape @sh2@.
sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, Elt a) => Shaped (sh1 ++ sh2) a -> IxS sh1 -> Shaped sh2 a
sindexPartial (Shaped arr) idx =
  Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) split (ixCvtSX idx))
  where
    -- Split the static shape into the outer @sh1@ and the inner @sh2@.
    split = rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr

-- | Create a shape-typed array from its shape and an element function.
sgenerate :: forall sh a. (KnownShape sh, Elt a) => IxS sh -> (IxS sh -> a) -> Shaped sh a
sgenerate sh f =
  case lemKnownMapJust (Proxy @sh) of
    Dict -> Shaped (mgenerate (ixCvtSX sh) (\i -> f (ixCvtXS (knownShape @sh) i)))

-- | Shaped counterpart of 'mlift'.
slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a)
      => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
      -> Shaped sh1 a -> Shaped sh2 a
slift f (Shaped arr) =
  case lemKnownMapJust (Proxy @sh2) of
    Dict -> Shaped (mlift f arr)

-- | Sum over the outermost dimension of an array of scalars.
ssumOuter1 :: forall sh n a. (VU.Unbox a, Num a, KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a))
           => Shaped (n : sh) a -> Shaped sh a
ssumOuter1 (Shaped arr) =
  case lemKnownMapJust (Proxy @sh) of
    Dict ->
      let input = coerce @(Mixed (Just n : MapJust sh) a) @(XArray (Just n : MapJust sh) a) arr
          summed = X.sumOuter (knownNat @n :$@ SZX) (knownShapeX @(MapJust sh)) input
      in Shaped (coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) a) summed)

-- | Permute the dimensions of a shape-typed array.
stranspose :: forall sh a. (KnownShape sh, Elt a) => [Int] -> Shaped sh a -> Shaped sh a
stranspose perm (Shaped arr) =
  case lemKnownMapJust (Proxy @sh) of
    Dict -> Shaped (mtranspose perm arr)