diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2024-04-03 12:37:35 +0200 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2024-04-03 12:37:35 +0200 |
commit | 92902c4f66db111b439f3b7eba9de50ad7c73f7b (patch) | |
tree | 27f12853825b7dd13d4bc8040dd2be6781deb635 /src/Fancy.hs | |
parent | 264c8e601f49cebed9280f0da2e73f380bb5be52 (diff) |
Reorganise, documentation
Diffstat (limited to 'src/Fancy.hs')
-rw-r--r-- | src/Fancy.hs | 598 |
1 files changed, 0 insertions, 598 deletions
diff --git a/src/Fancy.hs b/src/Fancy.hs deleted file mode 100644 index 7461c1f..0000000 --- a/src/Fancy.hs +++ /dev/null @@ -1,598 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE InstanceSigs #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} - -{-| -TODO: -* This module needs better structure with an Internal module and less public - exports etc. - -* We should be more consistent in whether functions take a 'StaticShapeX' - argument or a 'KnownShapeX' constraint. - --} - -module Fancy where - -import Control.Monad (forM_) -import Control.Monad.ST -import Data.Coerce (coerce, Coercible) -import Data.Kind -import Data.Proxy -import Data.Type.Equality -import qualified Data.Vector.Unboxed as VU -import qualified Data.Vector.Unboxed.Mutable as VUM - -import Array (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++)) -import qualified Array as X -import Nats - - -type family Replicate n a where - Replicate Z a = '[] - Replicate (S n) a = a : Replicate n a - -type family MapJust l where - MapJust '[] = '[] - MapJust (x : xs) = Just x : MapJust xs - -lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing) -lemKnownReplicate _ = X.lemKnownShapeX (go (knownNat @n)) - where - go :: SNat m -> StaticShapeX (Replicate m Nothing) - go SZ = SZX - go (SS n) = () :$? go n - -lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n -lemRankReplicate _ = go (knownNat @n) - where - go :: SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m - go SZ = Refl - go (SS n) | Refl <- go n = Refl - -lemReplicatePlusApp :: forall n m a. KnownNat n => Proxy n -> Proxy m -> Proxy a - -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a -lemReplicatePlusApp _ _ _ = go (knownNat @n) - where - go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a - go SZ = Refl - go (SS n) | Refl <- go n = Refl - - --- | Wrapper type used as a tag to attach instances on. The instances on arrays --- of @'Primitive' a@ are more polymorphic than the direct instances for arrays --- of scalars; this means that if @orthotope@ supports an element type @T@ that --- this library does not (directly), it may just work if you use an array of --- @'Primitive' T@ instead. -newtype Primitive a = Primitive a - - --- | Mixed arrays: some dimensions are size-typed, some are not. Distributes --- over product-typed elements using a dat afamily so that the full array is --- always in struct-of-arrays format. --- --- Built on top of 'XArray' which is built on top of @orthotope@, meaning that --- dimension permutations (e.g. 'transpose') are typically free. -type Mixed :: [Maybe Nat] -> Type -> Type -data family Mixed sh a - -newtype instance Mixed sh (Primitive a) = M_Primitive (XArray sh a) - -newtype instance Mixed sh Int = M_Int (XArray sh Int) -newtype instance Mixed sh Double = M_Double (XArray sh Double) -newtype instance Mixed sh () = M_Nil (XArray sh ()) -- no content, orthotope optimises this (via Vector) --- etc. - -data instance Mixed sh (a, b) = M_Tup2 (Mixed sh a) (Mixed sh b) --- etc. - -newtype instance Mixed sh1 (Mixed sh2 a) = M_Nest (Mixed (sh1 ++ sh2) a) - - --- | Internal helper data family mirrorring 'Mixed' that consists of mutable --- vectors instead of 'XArray's. -type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type -data family MixedVecs s sh a - -newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VU.MVector s a) - -newtype instance MixedVecs s sh Int = MV_Int (VU.MVector s Int) -newtype instance MixedVecs s sh Double = MV_Double (VU.MVector s Double) -newtype instance MixedVecs s sh () = MV_Nil (VU.MVector s ()) -- no content, MVector optimises this --- etc. - -data instance MixedVecs s sh (a, b) = MV_Tup2 (MixedVecs s sh a) (MixedVecs s sh b) --- etc. - -data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest (IxX sh2) (MixedVecs s (sh1 ++ sh2) a) - - --- | Allowable scalar types in a mixed array, and by extension in a 'Ranked' or --- 'Shaped' array. Note the polymorphic instance for 'GMixed' of @'Primitive' --- a@; see the documentation for 'Primitive' for more details. -class GMixed a where - -- ====== PUBLIC METHODS ====== -- - - mshape :: KnownShapeX sh => Mixed sh a -> IxX sh - mindex :: Mixed sh a -> IxX sh -> a - mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a - - mlift :: forall sh1 sh2. KnownShapeX sh2 - => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) - -> Mixed sh1 a -> Mixed sh2 a - - -- ====== PRIVATE METHODS ====== -- - -- Remember I said that this module needed better management of exports? - - -- | Create an empty array. The given shape must have size zero; this may or may not be checked. - memptyArray :: IxX sh -> Mixed sh a - - -- | Return the size of the individual (SoA) arrays in this value. If @a@ - -- does not contain tuples, this coincides with the total number of scalars - -- in the given value; if @a@ contains tuples, then it is some multiple of - -- this number of scalars. - mvecsNumElts :: a -> Int - - -- | Create uninitialised vectors for this array type, given the shape of - -- this vector and an example for the contents. The shape must not have size - -- zero; an error may be thrown otherwise. - mvecsUnsafeNew :: IxX sh -> a -> ST s (MixedVecs s sh a) - - -- | Given the shape of this array, an index and a value, write the value at - -- that index in the vectors. - mvecsWrite :: IxX sh -> IxX sh -> a -> MixedVecs s sh a -> ST s () - - -- | Given the shape of this array, an index and a value, write the value at - -- that index in the vectors. - mvecsWritePartial :: KnownShapeX sh' => IxX (sh ++ sh') -> IxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () - - -- | Given the shape of this array, finalise the vectors into 'XArray's. - mvecsFreeze :: IxX sh -> MixedVecs s sh a -> ST s (Mixed sh a) - - --- Arrays of scalars are basically just arrays of scalars. -instance VU.Unbox a => GMixed (Primitive a) where - mshape (M_Primitive a) = X.shape a - mindex (M_Primitive a) i = Primitive (X.index a i) - mindexPartial (M_Primitive a) i = M_Primitive (X.indexPartial a i) - - mlift :: forall sh1 sh2. - (Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) - -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) - mlift f (M_Primitive a) - | Refl <- X.lemAppNil @sh1 - , Refl <- X.lemAppNil @sh2 - = M_Primitive (f Proxy a) - - memptyArray sh = M_Primitive (X.generate sh (error "memptyArray Int: shape was not empty")) - mvecsNumElts _ = 1 - mvecsUnsafeNew sh _ = MV_Primitive <$> VUM.unsafeNew (X.shapeSize sh) - mvecsWrite sh i (Primitive x) (MV_Primitive v) = VUM.write v (X.toLinearIdx sh i) x - - -- TODO: this use of toVector is suboptimal - mvecsWritePartial - :: forall sh' sh s. (KnownShapeX sh', VU.Unbox a) - => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () - mvecsWritePartial sh i (M_Primitive arr) (MV_Primitive v) = do - let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIdx' (X.shape arr))) - VU.copy (VUM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr) - - mvecsFreeze sh (MV_Primitive v) = M_Primitive . X.fromVector sh <$> VU.freeze v - --- What a blessing that orthotope's Array has "representational" role on the value type! -deriving via Primitive Int instance GMixed Int -deriving via Primitive Double instance GMixed Double -deriving via Primitive () instance GMixed () - --- Arrays of pairs are pairs of arrays. -instance (GMixed a, GMixed b) => GMixed (a, b) where - mshape (M_Tup2 a _) = mshape a - mindex (M_Tup2 a b) i = (mindex a i, mindex b i) - mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) - mlift f (M_Tup2 a b) = M_Tup2 (mlift f a) (mlift f b) - - memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh) - mvecsNumElts (x, y) = mvecsNumElts x * mvecsNumElts y - mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y - mvecsWrite sh i (x, y) (MV_Tup2 a b) = do - mvecsWrite sh i x a - mvecsWrite sh i y b - mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do - mvecsWritePartial sh i x a - mvecsWritePartial sh i y b - mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b - --- Arrays of arrays are just arrays, but with more dimensions. -instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where - mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh - mshape (M_Nest arr) - | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') - = ixAppPrefix (knownShapeX @sh) (mshape arr) - where - ixAppPrefix :: StaticShapeX sh1 -> IxX (sh1 ++ sh') -> IxX sh1 - ixAppPrefix SZX _ = IZX - ixAppPrefix (_ :$@ ssh) (i ::@ idx) = i ::@ ixAppPrefix ssh idx - ixAppPrefix (_ :$? ssh) (i ::? idx) = i ::? ixAppPrefix ssh idx - - mindex (M_Nest arr) i = mindexPartial arr i - - mindexPartial :: forall sh1 sh2. - Mixed (sh1 ++ sh2) (Mixed sh' a) -> IxX sh1 -> Mixed sh2 (Mixed sh' a) - mindexPartial (M_Nest arr) i - | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) - - mlift :: forall sh1 sh2. KnownShapeX sh2 - => (forall sh3 b. (KnownShapeX sh3, VU.Unbox b) => Proxy sh3 -> XArray (sh1 ++ sh3) b -> XArray (sh2 ++ sh3) b) - -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) - mlift f (M_Nest arr) - | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) - = M_Nest (mlift f' arr) - where - f' :: forall sh3 b. (KnownShapeX sh3, VU.Unbox b) => Proxy sh3 -> XArray ((sh1 ++ sh') ++ sh3) b -> XArray ((sh2 ++ sh') ++ sh3) b - f' _ - | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @sh3) - , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @sh3) - , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @sh3)) - = f (Proxy @(sh' ++ sh3)) - - memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIdx (knownShapeX @sh')))) - - mvecsNumElts arr = - let n = X.shapeSize (mshape arr) - in if n == 0 then 0 else n * mvecsNumElts (mindex arr (X.zeroIdx (knownShapeX @sh'))) - - mvecsUnsafeNew sh example - | X.shapeSize sh' == 0 = error "mvecsUnsafeNew: empty example" - | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example)) - (mindex example (X.zeroIdx (knownShapeX @sh'))) - where - sh' = mshape example - - mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.ixAppend sh sh') idx val vecs - - mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2 - => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Mixed sh' a) - -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) - -> ST s () - mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs) - | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) - , Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.ixAppend sh12 sh') idx arr vecs - - mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs - - --- Public method. Turns out this doesn't have to be in the type class! --- | Create an array given a size and a function that computes the element at a --- given index. -mgenerate :: forall sh a. (KnownShapeX sh, GMixed a) => IxX sh -> (IxX sh -> a) -> Mixed sh a -mgenerate sh f - -- TODO: Do we need this checkBounds check elsewhere as well? - | not (checkBounds sh (knownShapeX @sh)) = - error $ "mgenerate: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh) - -- We need to be very careful here to ensure that neither 'sh' nor - -- 'firstelem' that we pass to 'mvecsUnsafeNew' are empty. - | X.shapeSize sh == 0 = memptyArray sh - | otherwise = - let firstidx = X.zeroIdx' sh - firstelem = f (X.zeroIdx' sh) - in if mvecsNumElts firstelem == 0 - then memptyArray sh - else runST $ do - vecs <- mvecsUnsafeNew sh firstelem - mvecsWrite sh firstidx firstelem vecs - -- TODO: This is likely fine if @a@ is big, but if @a@ is a - -- scalar this feels inefficient. Should improve this. - forM_ (tail (X.enumShape sh)) $ \idx -> - mvecsWrite sh idx (f idx) vecs - mvecsFreeze sh vecs - where - checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool - checkBounds IZX SZX = True - checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (unSNat n') && checkBounds sh' ssh' - checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh' - - --- | Newtype around a 'Mixed' of 'Nothing's. This works like a rank-typed array --- as in @orthotope@. -type Ranked :: Nat -> Type -> Type -newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) - --- | Newtype around a 'Mixed' of 'Just's. This works like a shape-typed array --- as in @orthotope@. -type Shaped :: [Nat] -> Type -> Type -newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) - --- just unwrap the newtype and defer to the general instance for nested arrays -newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a)) -newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh' ) a)) - -newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a)) -newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh' ) a)) - - --- 'Ranked' and 'Shaped' can already be used at the top level of an array nest; --- these instances allow them to also be used as elements of arrays, thus --- making them first-class in the API. -instance (KnownNat n, GMixed a) => GMixed (Ranked n a) where - mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr - mindex (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i) - - mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IxX sh -> Mixed sh' (Ranked n a) - mindexPartial (M_Ranked arr) i - | Dict <- lemKnownReplicate (Proxy @n) - = coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ - mindexPartial arr i - - mlift :: forall sh1 sh2. KnownShapeX sh2 - => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) - -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) - mlift f (M_Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n) - = coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ - mlift f arr - - memptyArray :: forall sh. IxX sh -> Mixed sh (Ranked n a) - memptyArray i - | Dict <- lemKnownReplicate (Proxy @n) - = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ - memptyArray i - - mvecsNumElts (Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n) - = mvecsNumElts arr - - mvecsUnsafeNew idx (Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n) - = MV_Ranked <$> mvecsUnsafeNew idx arr - - mvecsWrite :: forall sh s. IxX sh -> IxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () - mvecsWrite sh idx (Ranked arr) vecs - | Dict <- lemKnownReplicate (Proxy @n) - = mvecsWrite sh idx arr - (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) - vecs) - - mvecsWritePartial :: forall sh sh' s. KnownShapeX sh' - => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Ranked n a) - -> MixedVecs s (sh ++ sh') (Ranked n a) - -> ST s () - mvecsWritePartial sh idx arr vecs - | Dict <- lemKnownReplicate (Proxy @n) - = mvecsWritePartial sh idx - (coerce @(Mixed sh' (Ranked n a)) - @(Mixed sh' (Mixed (Replicate n Nothing) a)) - arr) - (coerce @(MixedVecs s (sh ++ sh') (Ranked n a)) - @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) - vecs) - - mvecsFreeze :: forall sh s. IxX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) - mvecsFreeze sh vecs - | Dict <- lemKnownReplicate (Proxy @n) - = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) - @(Mixed sh (Ranked n a)) - <$> mvecsFreeze sh - (coerce @(MixedVecs s sh (Ranked n a)) - @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) - vecs) - - -data SShape sh where - ShNil :: SShape '[] - ShCons :: SNat n -> SShape sh -> SShape (n : sh) -deriving instance Show (SShape sh) - -class KnownShape sh where knownShape :: SShape sh -instance KnownShape '[] where knownShape = ShNil -instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = ShCons knownNat knownShape - -lemKnownMapJust :: forall sh. KnownShape sh => Proxy sh -> Dict KnownShapeX (MapJust sh) -lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh)) - where - go :: SShape sh' -> StaticShapeX (MapJust sh') - go ShNil = SZX - go (ShCons n sh) = n :$@ go sh - -lemMapJustPlusApp :: forall sh1 sh2. KnownShape sh1 => Proxy sh1 -> Proxy sh2 - -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 -lemMapJustPlusApp _ _ = go (knownShape @sh1) - where - go :: SShape sh1' -> MapJust (sh1' ++ sh2) :~: MapJust sh1' ++ MapJust sh2 - go ShNil = Refl - go (ShCons _ sh) | Refl <- go sh = Refl - -instance (KnownShape sh, GMixed a) => GMixed (Shaped sh a) where - mshape (M_Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = mshape arr - mindex (M_Shaped arr) i | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (mindex arr i) - - mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IxX sh1 -> Mixed sh2 (Shaped sh a) - mindexPartial (M_Shaped arr) i - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ - mindexPartial arr i - - mlift :: forall sh1 sh2. KnownShapeX sh2 - => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) - -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) - mlift f (M_Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ - mlift f arr - - memptyArray :: forall sh'. IxX sh' -> Mixed sh' (Shaped sh a) - memptyArray i - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ - memptyArray i - - mvecsNumElts (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = mvecsNumElts arr - - mvecsUnsafeNew idx (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = MV_Shaped <$> mvecsUnsafeNew idx arr - - mvecsWrite :: forall sh' s. IxX sh' -> IxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () - mvecsWrite sh idx (Shaped arr) vecs - | Dict <- lemKnownMapJust (Proxy @sh) - = mvecsWrite sh idx arr - (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) - vecs) - - mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2 - => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Shaped sh a) - -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) - -> ST s () - mvecsWritePartial sh idx arr vecs - | Dict <- lemKnownMapJust (Proxy @sh) - = mvecsWritePartial sh idx - (coerce @(Mixed sh2 (Shaped sh a)) - @(Mixed sh2 (Mixed (MapJust sh) a)) - arr) - (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a)) - @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) - vecs) - - mvecsFreeze :: forall sh' s. IxX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) - mvecsFreeze sh vecs - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @(Mixed sh' (Mixed (MapJust sh) a)) - @(Mixed sh' (Shaped sh a)) - <$> mvecsFreeze sh - (coerce @(MixedVecs s sh' (Shaped sh a)) - @(MixedVecs s sh' (Mixed (MapJust sh) a)) - vecs) - - --- Utility function to satisfy the type checker sometimes -rewriteMixed :: sh1 :~: sh2 -> Mixed sh1 a -> Mixed sh2 a -rewriteMixed Refl x = x - - --- ====== API OF RANKED ARRAYS ====== -- - --- | An index into a rank-typed array. -type IxR :: Nat -> Type -data IxR n where - IZR :: IxR Z - (:::) :: Int -> IxR n -> IxR (S n) - -ixCvtXR :: IxX sh -> IxR (X.Rank sh) -ixCvtXR IZX = IZR -ixCvtXR (n ::@ idx) = n ::: ixCvtXR idx -ixCvtXR (n ::? idx) = n ::: ixCvtXR idx - -ixCvtRX :: IxR n -> IxX (Replicate n Nothing) -ixCvtRX IZR = IZX -ixCvtRX (n ::: idx) = n ::? ixCvtRX idx - - -rshape :: forall n a. (KnownNat n, GMixed a) => Ranked n a -> IxR n -rshape (Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n) - , Refl <- lemRankReplicate (Proxy @n) - = ixCvtXR (mshape arr) - -rindex :: GMixed a => Ranked n a -> IxR n -> a -rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) - -rindexPartial :: forall n m a. (KnownNat n, GMixed a) => Ranked (n + m) a -> IxR n -> Ranked m a -rindexPartial (Ranked arr) idx = - Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) - (rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr) - (ixCvtRX idx)) - -rgenerate :: forall n a. (KnownNat n, GMixed a) => IxR n -> (IxR n -> a) -> Ranked n a -rgenerate sh f - | Dict <- lemKnownReplicate (Proxy @n) - , Refl <- lemRankReplicate (Proxy @n) - = Ranked (mgenerate (ixCvtRX sh) (f . ixCvtXR)) - -rlift :: forall n1 n2 a. (KnownNat n2, GMixed a) - => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) - -> Ranked n1 a -> Ranked n2 a -rlift f (Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n2) - = Ranked (mlift f arr) - -rsumOuter1 :: forall n a. - (VU.Unbox a, Num a, KnownNat n, forall sh. Coercible (Mixed sh a) (XArray sh a)) - => Ranked (S n) a -> Ranked n a -rsumOuter1 (Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n) - = Ranked - . coerce @(XArray (Replicate n Nothing) a) @(Mixed (Replicate n Nothing) a) - . X.sumOuter (() :$? SZX) (knownShapeX @(Replicate n Nothing)) - . coerce @(Mixed (Replicate (S n) Nothing) a) @(XArray (Replicate (S n) Nothing) a) - $ arr - - --- ====== API OF SHAPED ARRAYS ====== -- - --- | An index into a shape-typed array. -type IxS :: [Nat] -> Type -data IxS sh where - IZS :: IxS '[] - (::$) :: Int -> IxS sh -> IxS (n : sh) - -cvtSShapeIxS :: SShape sh -> IxS sh -cvtSShapeIxS ShNil = IZS -cvtSShapeIxS (ShCons n sh) = fromIntegral (unSNat n) ::$ cvtSShapeIxS sh - -ixCvtXS :: SShape sh -> IxX (MapJust sh) -> IxS sh -ixCvtXS ShNil IZX = IZS -ixCvtXS (ShCons _ sh) (n ::@ idx) = n ::$ ixCvtXS sh idx - -ixCvtSX :: IxS sh -> IxX (MapJust sh) -ixCvtSX IZS = IZX -ixCvtSX (n ::$ sh) = n ::@ ixCvtSX sh - - -sshape :: forall sh a. (KnownShape sh, GMixed a) => Shaped sh a -> IxS sh -sshape _ = cvtSShapeIxS (knownShape @sh) - -sindex :: GMixed a => Shaped sh a -> IxS sh -> a -sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) - -sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, GMixed a) => Shaped (sh1 ++ sh2) a -> IxS sh1 -> Shaped sh2 a -sindexPartial (Shaped arr) idx = - Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) - (rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr) - (ixCvtSX idx)) - -sgenerate :: forall sh a. (KnownShape sh, GMixed a) => IxS sh -> (IxS sh -> a) -> Shaped sh a -sgenerate sh f - | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped (mgenerate (ixCvtSX sh) (f . ixCvtXS (knownShape @sh))) - -slift :: forall sh1 sh2 a. (KnownShape sh2, GMixed a) - => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) - -> Shaped sh1 a -> Shaped sh2 a -slift f (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh2) - = Shaped (mlift f arr) - -ssumOuter1 :: forall sh n a. - (VU.Unbox a, Num a, KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a)) - => Shaped (n : sh) a -> Shaped sh a -ssumOuter1 (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped - . coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) a) - . X.sumOuter (knownNat @n :$@ SZX) (knownShapeX @(MapJust sh)) - . coerce @(Mixed (Just n : MapJust sh) a) @(XArray (Just n : MapJust sh) a) - $ arr |