{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-|
TODO:

* We should be more consistent in whether functions take a 'StaticShapeX'
  argument or a 'KnownShapeX' constraint.

* Document the choice of using 'INat' for ranks and 'Nat' for shapes. Point
  being that we need to do induction over the former, but the latter need to
  be able to get large.
-}
module Data.Array.Nested.Internal where

import Prelude hiding (mappend)

import Control.Monad (forM_)
import Control.Monad.ST
import qualified Data.Array.RankedS as S
import Data.Coerce (coerce, Coercible)
import Data.Kind
import Data.Proxy
import Data.Type.Equality
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as VSM
import Foreign.Storable (Storable)
import GHC.TypeLits

import Data.Array.Mixed (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++), pattern GHC_SNat)
import qualified Data.Array.Mixed as X
import Data.INat


-- | @'Replicate' n a@ is the type-level list containing @n@ copies of @a@.
-- It recurses on the 'INat' constructors 'Z' and 'S'.
type family Replicate n a where
  Replicate Z a = '[]
  Replicate (S n) a = a : Replicate n a

-- | Wrap every element of a type-level list in 'Just'.
type family MapJust l where
  MapJust '[] = '[]
  MapJust (x : xs) = Just x : MapJust xs

-- | The all-'Nothing' shape of rank @n@ is a 'KnownShapeX'. The proof builds
-- the corresponding 'StaticShapeX' by induction on the 'INat' singleton of @n@.
lemKnownReplicate :: forall n. KnownINat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing)
lemKnownReplicate _ = X.lemKnownShapeX (go (inatSing @n))
  where
    -- One '() :$?' (unknown-size dimension) per successor.
    go :: SINat m -> StaticShapeX (Replicate m Nothing)
    go SZ = SZX
    go (SS n) = () :$? go n

-- | The rank of @'Replicate' n 'Nothing'@ is @n@, by induction on @n@.
lemRankReplicate :: forall n. KnownINat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n
lemRankReplicate _ = go (inatSing @n)
  where
    go :: SINat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m
    go SZ = Refl
    -- Matching on the recursive proof exposes the induction hypothesis.
    go (SS n) | Refl <- go n = Refl

-- | Replicating @n +! m@ times is the same as appending @n@ replications to
-- @m@ replications. The proof is by induction on @n@; @m@ and @a@ stay fixed.
lemReplicatePlusApp :: forall n m a. KnownINat n => Proxy n -> Proxy m -> Proxy a
                    -> Replicate (n +! m) a :~: Replicate n a ++ Replicate m a
lemReplicatePlusApp _ _ _ = go (inatSing @n)
  where
    go :: SINat n' -> Replicate (n' +! m) a :~: Replicate n' a ++ Replicate m a
    go SZ = Refl
    go (SS n) | Refl <- go n = Refl


-- | Wrapper type used as a tag to attach instances on. The instances on arrays
-- of @'Primitive' a@ are more polymorphic than the direct instances for arrays
-- of scalars; this means that if @orthotope@ supports an element type @T@ that
-- this library does not (directly), it may just work if you use an array of
-- @'Primitive' T@ instead.
newtype Primitive a = Primitive a

-- | Mixed arrays: some dimensions are size-typed, some are not. Distributes
-- over product-typed elements using a data family so that the full array is
-- always in struct-of-arrays format.
--
-- Built on top of 'XArray' which is built on top of @orthotope@, meaning that
-- dimension permutations (e.g. 'transpose') are typically free.
--
-- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type
-- class.
type Mixed :: [Maybe Nat] -> Type -> Type
data family Mixed sh a

-- Scalar element types are stored directly as a single 'XArray'.
newtype instance Mixed sh (Primitive a) = M_Primitive (XArray sh a)
  deriving (Show)

newtype instance Mixed sh Int = M_Int (XArray sh Int)
  deriving (Show)
newtype instance Mixed sh Double = M_Double (XArray sh Double)
  deriving (Show)
newtype instance Mixed sh () = M_Nil (XArray sh ())  -- no content, orthotope optimises this (via Vector)
  deriving (Show)
-- etc.

-- A pair element type is stored as a pair of arrays (struct-of-arrays).
data instance Mixed sh (a, b) = M_Tup2 (Mixed sh a) (Mixed sh b)
deriving instance (Show (Mixed sh a), Show (Mixed sh b)) => Show (Mixed sh (a, b))
-- etc.
-- A nested array is stored as one flat array whose shape is the outer shape
-- followed by the inner shape.
newtype instance Mixed sh1 (Mixed sh2 a) = M_Nest (Mixed (sh1 ++ sh2) a)
deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a))


-- | Internal helper data family mirroring 'Mixed' that consists of mutable
-- vectors instead of 'XArray's.
type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type
data family MixedVecs s sh a

newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a)
newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int)
newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double)
newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ())  -- no content, MVector optimises this
-- etc.

data instance MixedVecs s sh (a, b) = MV_Tup2 (MixedVecs s sh a) (MixedVecs s sh b)
-- etc.

-- The 'IxX' field stores the inner shape @sh2@ of the nested elements. Writes
-- and the final freeze append it to the outer shape to get the full shape.
data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest (IxX sh2) (MixedVecs s (sh1 ++ sh2) a)


-- | Allowable scalar types in a mixed array, and by extension in a 'Ranked' or
-- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive'
-- a@; see the documentation for 'Primitive' for more details.
class Elt a where
  -- ====== PUBLIC METHODS ====== --

  -- | The shape of the array.
  mshape :: KnownShapeX sh => Mixed sh a -> IxX sh

  -- | Read the element at a full index.
  mindex :: Mixed sh a -> IxX sh -> a

  -- | Index only the leading @sh@ dimensions, returning the sub-array of
  -- shape @sh'@.
  mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a

  -- | Apply a shape-changing function to every underlying 'XArray'
  -- (struct-of-arrays component). The function must be polymorphic in the
  -- element type @b@ and in the trailing shape @sh'@. Nested instances use
  -- @sh'@ for the inner dimensions.
  mlift :: forall sh1 sh2. KnownShapeX sh2
        => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
        -> Mixed sh1 a -> Mixed sh2 a

  -- | Binary version of 'mlift'.
  mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3)
         => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
         -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a

  -- ====== PRIVATE METHODS ====== --
  -- NOTE: these are implementation helpers and should not be part of the
  -- public API. This module currently has no export list, so they are exported
  -- anyway.

  -- | Create an empty array. The given shape must have size zero; this may or may not be checked.
  memptyArray :: IxX sh -> Mixed sh a

  -- | Return the size of the individual (SoA) arrays in this value. If @a@
  -- does not contain tuples, this coincides with the total number of scalars
  -- in the given value; if @a@ contains tuples, then it is some multiple of
  -- this number of scalars. 'mgenerate' only tests whether the result is
  -- zero, i.e. whether the value contains an empty array.
  mvecsNumElts :: a -> Int

  -- | Create uninitialised vectors for this array type, given the shape of
  -- this vector and an example for the contents. The shape must not have size
  -- zero; an error may be thrown otherwise.
  mvecsUnsafeNew :: IxX sh -> a -> ST s (MixedVecs s sh a)

  -- | Given the shape of this array, an index and a value, write the value at
  -- that index in the vectors.
  mvecsWrite :: IxX sh -> IxX sh -> a -> MixedVecs s sh a -> ST s ()

  -- | Given the full shape @sh ++ sh'@ of this array, an index into the
  -- leading @sh@ dimensions and a sub-array of shape @sh'@, write the whole
  -- sub-array at that index in the vectors.
  mvecsWritePartial :: KnownShapeX sh' => IxX (sh ++ sh') -> IxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()

  -- | Given the shape of this array, finalise the vectors into 'XArray's.
  mvecsFreeze :: IxX sh -> MixedVecs s sh a -> ST s (Mixed sh a)


-- Arrays of scalars are basically just arrays of scalars.
instance Storable a => Elt (Primitive a) where
  mshape (M_Primitive a) = X.shape a
  mindex (M_Primitive a) i = Primitive (X.index a i)
  mindexPartial (M_Primitive a) i = M_Primitive (X.indexPartial a i)

  -- A scalar has no inner dimensions, so the trailing shape is always '[].
  mlift :: forall sh1 sh2.
           (Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a)
        -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a)
  mlift f (M_Primitive a)
    | Refl <- X.lemAppNil @sh1
    , Refl <- X.lemAppNil @sh2
    = M_Primitive (f Proxy a)

  mlift2 :: forall sh1 sh2 sh3.
            (Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a)
         -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a)
  mlift2 f (M_Primitive a) (M_Primitive b)
    | Refl <- X.lemAppNil @sh1
    , Refl <- X.lemAppNil @sh2
    , Refl <- X.lemAppNil @sh3
    = M_Primitive (f Proxy a b)

  -- The generator is never meant to be called for an empty shape. The message
  -- names this instance, 'Primitive', because 'Int', 'Double' and '()' all
  -- reuse it via deriving-via.
  memptyArray sh = M_Primitive (X.generate sh (error "memptyArray Primitive: shape was not empty"))
  mvecsNumElts _ = 1
  mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh)
  mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (X.toLinearIdx sh i) x

  -- The sub-array is copied into the contiguous slice that starts at the
  -- linear index of (i ++ zeros) and spans its own size.
  -- TODO: this use of toVector is suboptimal
  mvecsWritePartial :: forall sh' sh s. KnownShapeX sh'
                    => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a)
                    -> ST s ()
  mvecsWritePartial sh i (M_Primitive arr) (MV_Primitive v) = do
    let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIdx' (X.shape arr)))
    VS.copy (VSM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr)

  mvecsFreeze sh (MV_Primitive v) = M_Primitive . X.fromVector sh <$> VS.freeze v

deriving via Primitive Int instance Elt Int
deriving via Primitive Double instance Elt Double
deriving via Primitive () instance Elt ()

-- Arrays of pairs are pairs of arrays.
instance (Elt a, Elt b) => Elt (a, b) where
  -- Both components have the same shape, so the first one is used.
  mshape (M_Tup2 a _) = mshape a
  mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
  mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
  mlift f (M_Tup2 a b) = M_Tup2 (mlift f a) (mlift f b)
  mlift2 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 f a x) (mlift2 f b y)

  memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh)
  -- The product is zero exactly when one of the components is empty. That is
  -- the condition 'mgenerate' checks before allocating vectors.
  mvecsNumElts (x, y) = mvecsNumElts x * mvecsNumElts y
  mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
  mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
    mvecsWrite sh i x a
    mvecsWrite sh i y b
  mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do
    mvecsWritePartial sh i x a
    mvecsWritePartial sh i y b
  mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b

-- Arrays of arrays are just arrays, but with more dimensions.
instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where
  -- The flat array has shape sh ++ sh'; keep only the outer sh prefix.
  mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh
  mshape (M_Nest arr)
    | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh')
    = ixAppPrefix (knownShapeX @sh) (mshape arr)
    where
      ixAppPrefix :: StaticShapeX sh1 -> IxX (sh1 ++ sh') -> IxX sh1
      ixAppPrefix SZX _ = IZX
      ixAppPrefix (_ :$@ ssh) (i ::@ idx) = i ::@ ixAppPrefix ssh idx
      ixAppPrefix (_ :$? ssh) (i ::? idx) = i ::? ixAppPrefix ssh idx

  -- A full index into the outer array is a partial index into the flat array.
  mindex (M_Nest arr) i = mindexPartial arr i

  mindexPartial :: forall sh1 sh2.
                   Mixed (sh1 ++ sh2) (Mixed sh' a) -> IxX sh1 -> Mixed sh2 (Mixed sh' a)
  mindexPartial (M_Nest arr) i
    | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
    = M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)

  -- The inner shape sh' is absorbed into the trailing shape passed to 'f'.
  mlift :: forall sh1 sh2. KnownShapeX sh2
        => (forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b)
        -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)
  mlift f (M_Nest arr)
    | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
    = M_Nest (mlift f' arr)
    where
      f' :: forall shT b. (KnownShapeX shT, Storable b)
         => Proxy shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b
      f' _
        | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
        , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
        , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @shT))
        = f (Proxy @(sh' ++ shT))

  mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3)
         => (forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b)
         -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a)
  mlift2 f (M_Nest arr1) (M_Nest arr2)
    | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
    , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh3) (knownShapeX @sh'))
    = M_Nest (mlift2 f' arr1 arr2)
    where
      f' :: forall shT b. (KnownShapeX shT, Storable b)
         => Proxy shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b
      f' _
        | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
        , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
        , Refl <- X.lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT)
        , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @shT))
        = f (Proxy @(sh' ++ shT))

  memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIdx (knownShapeX @sh'))))

  -- Zero for an empty nested array; otherwise its size times the per-element
  -- count of its first element.
  mvecsNumElts arr =
    let n = X.shapeSize (mshape arr)
    in if n == 0 then 0 else n * mvecsNumElts (mindex arr (X.zeroIdx (knownShapeX @sh')))

  -- The example's inner shape fixes the shape of every element. An empty
  -- example is rejected because the first element is needed for the recursion.
  mvecsUnsafeNew sh example
    | X.shapeSize sh' == 0 = error "mvecsUnsafeNew: empty example"
    | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example))
                                                 (mindex example (X.zeroIdx (knownShapeX @sh')))
    where
      sh' = mshape example

  mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.ixAppend sh sh') idx val vecs

  mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2
                    => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Mixed sh' a)
                    -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
                    -> ST s ()
  mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs)
    | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
    , Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
    = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.ixAppend sh12 sh') idx arr vecs

  mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs


-- Public method. Turns out this doesn't have to be in the type class!
-- | Create an array given a size and a function that computes the element at a
-- given index.
--
-- The runtime shape is first checked against the statically known dimensions
-- of @sh@; a mismatch is an 'error'. If the shape has size zero, or the first
-- element contains an empty array (its 'mvecsNumElts' is 0), an empty array is
-- returned without allocating vectors.
mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IxX sh -> (IxX sh -> a) -> Mixed sh a
mgenerate sh f
  -- TODO: Do we need this checkBounds check elsewhere as well?
  | not (checkBounds sh (knownShapeX @sh)) =
      error $ "mgenerate: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh)
  -- We need to be very careful here to ensure that neither 'sh' nor
  -- 'firstelem' that we pass to 'mvecsUnsafeNew' are empty.
  | X.shapeSize sh == 0 = memptyArray sh
  | otherwise =
      let firstidx = X.zeroIdx' sh
          firstelem = f firstidx
      in if mvecsNumElts firstelem == 0
           then memptyArray sh
           else runST $ do
             -- 'firstelem' doubles as the example that sizes the vectors.
             vecs <- mvecsUnsafeNew sh firstelem
             mvecsWrite sh firstidx firstelem vecs
             -- TODO: This is likely fine if @a@ is big, but if @a@ is a
             -- scalar this feels inefficient. Should improve this.
             -- 'X.enumShape' is expected to start at the zero index, which was
             -- written above. 'drop 1' skips it without the partiality of 'tail'.
             forM_ (drop 1 (X.enumShape sh)) $ \idx ->
               mvecsWrite sh idx (f idx) vecs
             mvecsFreeze sh vecs
  where
    -- Known ('Just') dimensions must match exactly; unknown ones accept any size.
    checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool
    checkBounds IZX SZX = True
    checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (fromSNat n') && checkBounds sh' ssh'
    checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh'

-- | Permute the dimensions of an array with 'X.transpose'. The permutation
-- convention is that of 'X.transpose'.
mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a
mtranspose perm =
  mlift (\(Proxy @sh') -> X.rerankTop (knownShapeX @sh) (knownShapeX @sh) (knownShapeX @sh')
                                      (X.transpose perm))

-- | Concatenate two arrays along their outermost dimension.
mappend :: forall n m sh a. (KnownShapeX sh, KnownShapeX (n : sh), KnownShapeX (m : sh), KnownShapeX (X.AddMaybe n m : sh), Elt a)
        => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a
mappend = mlift2 go
  where
    go :: forall sh' b. (KnownShapeX sh', Storable b)
       => Proxy sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b
    go Proxy | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') = X.append

-- | Map a function over every scalar of a primitive array ('S.mapA').
mliftPrim :: (KnownShapeX sh, Storable a)
          => (a -> a)
          -> Mixed sh (Primitive a) -> Mixed sh (Primitive a)
mliftPrim f (M_Primitive (X.XArray arr)) = M_Primitive (X.XArray (S.mapA f arr))

-- | Combine two primitive arrays elementwise ('S.zipWithA').
mliftPrim2 :: (KnownShapeX sh, Storable a)
           => (a -> a -> a)
           -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) -> Mixed sh (Primitive a)
mliftPrim2 f (M_Primitive (X.XArray arr1)) (M_Primitive (X.XArray arr2)) = M_Primitive (X.XArray (S.zipWithA f arr1 arr2))

-- | Elementwise arithmetic. 'fromInteger' builds a constant array and
-- therefore needs every dimension of @sh@ to be statically known.
instance (KnownShapeX sh, Storable a, Num a) => Num (Mixed sh (Primitive a)) where
  (+) = mliftPrim2 (+)
  (-) = mliftPrim2 (-)
  (*) = mliftPrim2 (*)
  negate = mliftPrim negate
  abs = mliftPrim abs
  signum = mliftPrim signum
  fromInteger n =
    case X.ssxToShape' (knownShapeX @sh) of
      Just sh -> M_Primitive (X.constant sh (fromInteger n))
      Nothing -> error "Data.Array.Nested.fromInteger: \
                       \Unknown components in shape, use explicit replicate"

deriving via Mixed sh (Primitive Int) instance KnownShapeX sh => Num (Mixed sh Int)
deriving via Mixed sh (Primitive Double) instance KnownShapeX sh => Num (Mixed sh Double)


-- | A rank-typed array: the number of dimensions of the array (its /rank/) is
-- represented on the type level as an 'INat'.
--
-- Valid elements of a ranked array are described by the 'Elt' type class.
-- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are
-- supported (and are represented as a single, flattened, struct-of-arrays
-- array internally).
--
-- Note that this 'INat' is not a "GHC.TypeLits" natural, because we want a
-- type-level natural that supports induction.
--
-- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's.
type Ranked :: INat -> Type -> Type
newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a)

-- | A shape-typed array: the full shape of the array (the sizes of its
-- dimensions) is represented on the type level as a list of 'Nat's. Note that
-- these are "GHC.TypeLits" naturals, because we do not need induction over
-- them and we want very large arrays to be possible.
--
-- Like for 'Ranked', the valid elements are described by the 'Elt' type class,
-- and 'Shaped' itself is again an instance of 'Elt' as well.
--
-- 'Shaped' is a newtype around a 'Mixed' of 'Just's.
type Shaped :: [Nat] -> Type -> Type
newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a)

-- just unwrap the newtype and defer to the general instance for nested arrays
newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a))
deriving instance Show (Mixed sh (Mixed (Replicate n Nothing) a)) => Show (Mixed sh (Ranked n a))

newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a))
deriving instance Show (Mixed sh (Mixed (MapJust sh') a)) => Show (Mixed sh (Shaped sh' a))

-- The mutable-vector representations are likewise thin wrappers.
newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a))
newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a))

-- 'Ranked' and 'Shaped' can already be used at the top level of an array nest;
-- these instances allow them to also be used as elements of arrays, thus
-- making them first-class in the API.
--
-- Each method opens the 'KnownShapeX' dictionary for the all-'Nothing' inner
-- shape by case-matching on 'lemKnownReplicate'. It then delegates to the
-- nested-'Mixed' instance, converting between the 'Ranked' wrappers and their
-- representations with 'coerce'.
instance (KnownINat n, Elt a) => Elt (Ranked n a) where
  mshape (M_Ranked arr) =
    case lemKnownReplicate (Proxy @n) of
      Dict -> mshape arr

  mindex (M_Ranked arr) i =
    case lemKnownReplicate (Proxy @n) of
      Dict -> Ranked (mindex arr i)

  mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IxX sh -> Mixed sh' (Ranked n a)
  mindexPartial (M_Ranked arr) i =
    case lemKnownReplicate (Proxy @n) of
      Dict -> coerce @(Mixed sh' (Mixed (Replicate n Nothing) a))
                     @(Mixed sh' (Ranked n a))
                     (mindexPartial arr i)

  mlift :: forall sh1 sh2. KnownShapeX sh2
        => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
        -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a)
  mlift f (M_Ranked arr) =
    case lemKnownReplicate (Proxy @n) of
      Dict -> coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a))
                     @(Mixed sh2 (Ranked n a))
                     (mlift f arr)

  mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3)
         => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
         -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a)
  mlift2 f (M_Ranked arr1) (M_Ranked arr2) =
    case lemKnownReplicate (Proxy @n) of
      Dict -> coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a))
                     @(Mixed sh3 (Ranked n a))
                     (mlift2 f arr1 arr2)

  memptyArray :: forall sh. IxX sh -> Mixed sh (Ranked n a)
  memptyArray i =
    case lemKnownReplicate (Proxy @n) of
      Dict -> coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
                     @(Mixed sh (Ranked n a))
                     (memptyArray i)

  mvecsNumElts (Ranked arr) =
    case lemKnownReplicate (Proxy @n) of
      Dict -> mvecsNumElts arr

  mvecsUnsafeNew idx (Ranked arr) =
    case lemKnownReplicate (Proxy @n) of
      Dict -> MV_Ranked <$> mvecsUnsafeNew idx arr

  mvecsWrite :: forall sh s. IxX sh -> IxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
  mvecsWrite sh idx (Ranked arr) vecs =
    case lemKnownReplicate (Proxy @n) of
      Dict ->
        let rawVecs = coerce @(MixedVecs s sh (Ranked n a))
                             @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
                             vecs
        in mvecsWrite sh idx arr rawVecs

  mvecsWritePartial :: forall sh sh' s. KnownShapeX sh'
                    => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Ranked n a) -> MixedVecs s (sh ++ sh') (Ranked n a)
                    -> ST s ()
  mvecsWritePartial sh idx arr vecs =
    case lemKnownReplicate (Proxy @n) of
      Dict ->
        let rawArr = coerce @(Mixed sh' (Ranked n a))
                            @(Mixed sh' (Mixed (Replicate n Nothing) a))
                            arr
            rawVecs = coerce @(MixedVecs s (sh ++ sh') (Ranked n a))
                             @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a))
                             vecs
        in mvecsWritePartial sh idx rawArr rawVecs

  mvecsFreeze :: forall sh s. IxX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
  mvecsFreeze sh vecs =
    case lemKnownReplicate (Proxy @n) of
      Dict ->
        let rawVecs = coerce @(MixedVecs s sh (Ranked n a))
                             @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
                             vecs
        in coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
                  @(Mixed sh (Ranked n a))
             <$> mvecsFreeze sh rawVecs

-- | The shape of a shape-typed array given as a list of 'SNat' values.
data SShape sh where
  ShNil :: SShape '[]
  ShCons :: SNat n -> SShape sh -> SShape (n : sh)
deriving instance Show (SShape sh)
infixr 5 `ShCons`

-- | A statically-known shape of a shape-typed array.
class KnownShape sh where
  knownShape :: SShape sh
instance KnownShape '[] where
  knownShape = ShNil
instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where
  knownShape = ShCons natSing knownShape

-- | Recover a 'KnownShape' dictionary from a shape singleton. Matching on
-- 'GHC_SNat' supplies the per-dimension 'KnownNat' evidence that the cons
-- instance requires.
sshapeKnown :: SShape sh -> Dict KnownShape sh
sshapeKnown ShNil = Dict
sshapeKnown (ShCons GHC_SNat sh) | Dict <- sshapeKnown sh = Dict

-- | A fully static shape maps to a 'KnownShapeX' in which every dimension is
-- a known ('Just') size.
lemKnownMapJust :: forall sh. KnownShape sh => Proxy sh -> Dict KnownShapeX (MapJust sh)
lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh))
  where
    go :: SShape sh' -> StaticShapeX (MapJust sh')
    go ShNil = SZX
    go (ShCons n sh) = n :$@ go sh

-- | 'MapJust' distributes over '++'. The proof is by induction on the
-- statically known @sh1@.
lemMapJustPlusApp :: forall sh1 sh2. KnownShape sh1 => Proxy sh1 -> Proxy sh2
                  -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2
lemMapJustPlusApp _ _ = go (knownShape @sh1)
  where
    go :: SShape sh1' -> MapJust (sh1' ++ sh2) :~: MapJust sh1' ++ MapJust sh2
    go ShNil = Refl
    go (ShCons _ sh) | Refl <- go sh = Refl

-- Each method opens the 'KnownShapeX' dictionary for @MapJust sh@ and
-- delegates to the nested-'Mixed' instance. It converts between the 'Shaped'
-- wrappers and their representations with 'coerce'.
instance (KnownShape sh, Elt a) => Elt (Shaped sh a) where
  mshape (M_Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = mshape arr
  mindex (M_Shaped arr) i | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (mindex arr i)

  mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IxX sh1 -> Mixed sh2 (Shaped sh a)
  mindexPartial (M_Shaped arr) i
    | Dict <- lemKnownMapJust (Proxy @sh)
    = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
        mindexPartial arr i

  mlift :: forall sh1 sh2. KnownShapeX sh2
        => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
        -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a)
  mlift f (M_Shaped arr)
    | Dict <- lemKnownMapJust (Proxy @sh)
    = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
        mlift f arr

  mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3)
         => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
         -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a)
  mlift2 f (M_Shaped arr1) (M_Shaped arr2)
    | Dict <- lemKnownMapJust (Proxy @sh)
    = coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $
        mlift2 f arr1 arr2

  memptyArray :: forall sh'. IxX sh' -> Mixed sh' (Shaped sh a)
  memptyArray i
    | Dict <- lemKnownMapJust (Proxy @sh)
    = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
        memptyArray i

  mvecsNumElts (Shaped arr)
    | Dict <- lemKnownMapJust (Proxy @sh)
    = mvecsNumElts arr

  mvecsUnsafeNew idx (Shaped arr)
    | Dict <- lemKnownMapJust (Proxy @sh)
    = MV_Shaped <$> mvecsUnsafeNew idx arr

  mvecsWrite :: forall sh' s. IxX sh' -> IxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
  mvecsWrite sh idx (Shaped arr) vecs
    | Dict <- lemKnownMapJust (Proxy @sh)
    = mvecsWrite sh idx arr
        (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
           vecs)

  mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2
                    => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Shaped sh a) -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
                    -> ST s ()
  mvecsWritePartial sh idx arr vecs
    | Dict <- lemKnownMapJust (Proxy @sh)
    = mvecsWritePartial sh idx
        (coerce @(Mixed sh2 (Shaped sh a)) @(Mixed sh2 (Mixed (MapJust sh) a))
           arr)
        (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a)) @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))
           vecs)

  mvecsFreeze :: forall sh' s. IxX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
  mvecsFreeze sh vecs
    | Dict <- lemKnownMapJust (Proxy @sh)
    = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a))
        <$> mvecsFreeze sh
              (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
                 vecs)


-- Utility functions to satisfy the type checker sometimes

-- | Transport a 'Mixed' array along a proof that two shapes are equal.
rewriteMixed :: sh1 :~: sh2 -> Mixed sh1 a -> Mixed sh2 a
rewriteMixed Refl x = x

-- | Wrap an 'XArray' back into a 'Mixed', for element types whose 'Mixed'
-- representation is coercible to a plain 'XArray'.
coerceMixedXArray :: Coercible (Mixed sh a) (XArray sh a) => XArray sh a -> Mixed sh a
coerceMixedXArray = coerce


-- ====== API OF RANKED ARRAYS ====== --

-- | Lift a shape-polymorphic unary 'Mixed' operation to 'Ranked' arrays.
arithPromoteRanked :: forall n a. KnownINat n
                   => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a)
                   -> Ranked n a -> Ranked n a
arithPromoteRanked | Dict <- lemKnownReplicate (Proxy @n) = coerce

-- | Lift a shape-polymorphic binary 'Mixed' operation to 'Ranked' arrays.
arithPromoteRanked2 :: forall n a. KnownINat n
                    => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a -> Mixed sh a)
                    -> Ranked n a -> Ranked n a -> Ranked n a
arithPromoteRanked2 | Dict <- lemKnownReplicate (Proxy @n) = coerce

-- | Elementwise arithmetic, delegating to the 'Mixed' instance.
instance (KnownINat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where
  (+) = arithPromoteRanked2 (+)
  (-) = arithPromoteRanked2 (-)
  (*) = arithPromoteRanked2 (*)
  negate = arithPromoteRanked negate
  abs = arithPromoteRanked abs
  signum = arithPromoteRanked signum
  -- NOTE(review): the 'Mixed' 'fromInteger' errors when any dimension is
  -- unknown. For a 'Ranked' array of rank > 0 every dimension is 'Nothing',
  -- so this presumably only succeeds at rank 0 — confirm.
  fromInteger n | Dict <- lemKnownReplicate (Proxy @n) = Ranked (fromInteger n)

deriving via Ranked n (Primitive Int) instance KnownINat n => Num (Ranked n Int)
deriving via Ranked n (Primitive Double) instance KnownINat n => Num (Ranked n Double)

-- | An index into a rank-typed array.
type IxR :: INat -> Type
data IxR n where
  IZR :: IxR Z
  (:::) :: Int -> IxR n -> IxR (S n)
infixr 5 :::

-- | Forget which dimensions of a mixed index are statically known.
ixCvtXR :: IxX sh -> IxR (X.Rank sh)
ixCvtXR IZX = IZR
ixCvtXR (n ::@ idx) = n ::: ixCvtXR idx
ixCvtXR (n ::? idx) = n ::: ixCvtXR idx

-- | View a ranked index as a mixed index whose dimensions are all unknown.
ixCvtRX :: IxR n -> IxX (Replicate n Nothing)
ixCvtRX IZR = IZX
ixCvtRX (n ::: idx) = n ::? ixCvtRX idx

-- | The runtime shape of a ranked array.
rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IxR n
rshape (Ranked arr)
  | Dict <- lemKnownReplicate (Proxy @n)
  , Refl <- lemRankReplicate (Proxy @n)
  = ixCvtXR (mshape arr)

-- | Read the element at a full index.
rindex :: Elt a => Ranked n a -> IxR n -> a
rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)

-- | Index the outer @n@ dimensions, returning the rank-@m@ sub-array.
rindexPartial :: forall n m a. (KnownINat n, Elt a) => Ranked (n +! m) a -> IxR n -> Ranked m a
rindexPartial (Ranked arr) idx =
  Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
            (rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr)
            (ixCvtRX idx))

-- | Build a ranked array from its shape and an element function (via 'mgenerate').
rgenerate :: forall n a. (KnownINat n, Elt a) => IxR n -> (IxR n -> a) -> Ranked n a
rgenerate sh f
  | Dict <- lemKnownReplicate (Proxy @n)
  , Refl <- lemRankReplicate (Proxy @n)
  = Ranked (mgenerate (ixCvtRX sh) (f . ixCvtXR))

-- | Ranked counterpart of 'mlift'.
rlift :: forall n1 n2 a. (KnownINat n2, Elt a)
      => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
      -> Ranked n1 a -> Ranked n2 a
rlift f (Ranked arr)
  | Dict <- lemKnownReplicate (Proxy @n2)
  = Ranked (mlift f arr)

-- | Sum away the outermost dimension with 'X.sumOuter'. The quantified
-- 'Coercible' constraint limits this to element types whose 'Mixed'
-- representation is a newtype over 'XArray'.
rsumOuter1 :: forall n a. (Storable a, Num a, KnownINat n, forall sh. Coercible (Mixed sh a) (XArray sh a))
           => Ranked (S n) a -> Ranked n a
rsumOuter1 (Ranked arr)
  | Dict <- lemKnownReplicate (Proxy @n)
  = Ranked
  . coerceMixedXArray
  . X.sumOuter (() :$? SZX) (knownShapeX @(Replicate n Nothing))
  . coerce @(Mixed (Replicate (S n) Nothing) a) @(XArray (Replicate (S n) Nothing) a)
  $ arr

-- | Permute dimensions (via 'mtranspose').
rtranspose :: forall n a. (KnownINat n, Elt a) => [Int] -> Ranked n a -> Ranked n a
rtranspose perm (Ranked arr)
  | Dict <- lemKnownReplicate (Proxy @n)
  = Ranked (mtranspose perm arr)

-- | Concatenate along the outermost dimension (via 'mappend').
rappend :: forall n a. (KnownINat n, Elt a)
        => Ranked (S n) a -> Ranked (S n) a -> Ranked (S n) a
rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend


-- ====== API OF SHAPED ARRAYS ====== --

-- | Lift a shape-polymorphic unary 'Mixed' operation to 'Shaped' arrays.
arithPromoteShaped :: forall sh a. KnownShape sh
                   => (forall shx. KnownShapeX shx => Mixed shx a -> Mixed shx a)
                   -> Shaped sh a -> Shaped sh a
arithPromoteShaped | Dict <- lemKnownMapJust (Proxy @sh) = coerce

-- | Lift a shape-polymorphic binary 'Mixed' operation to 'Shaped' arrays.
arithPromoteShaped2 :: forall sh a. KnownShape sh
                    => (forall shx. KnownShapeX shx => Mixed shx a -> Mixed shx a -> Mixed shx a)
                    -> Shaped sh a -> Shaped sh a -> Shaped sh a
arithPromoteShaped2 | Dict <- lemKnownMapJust (Proxy @sh) = coerce

-- | Elementwise arithmetic. Every dimension is static here, so 'fromInteger'
-- can build a constant array of the full shape.
instance (KnownShape sh, Storable a, Num a) => Num (Shaped sh (Primitive a)) where
  (+) = arithPromoteShaped2 (+)
  (-) = arithPromoteShaped2 (-)
  (*) = arithPromoteShaped2 (*)
  negate = arithPromoteShaped negate
  abs = arithPromoteShaped abs
  signum = arithPromoteShaped signum
  fromInteger n | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (fromInteger n)

deriving via Shaped sh (Primitive Int) instance KnownShape sh => Num (Shaped sh Int)
deriving via Shaped sh (Primitive Double) instance KnownShape sh => Num (Shaped sh Double)

-- | An index into a shape-typed array.
--
-- For convenience, this contains regular 'Int's instead of bounded integers
-- (traditionally called \"@Fin@\"). Note that because the shape of a
-- shape-typed array is known statically, you can also retrieve the array shape
-- from a 'KnownShape' dictionary.
type IxS :: [Nat] -> Type
data IxS sh where
  IZS :: IxS '[]
  (::$) :: Int -> IxS sh -> IxS (n : sh)
infixr 5 ::$

-- | The index whose components are the sizes of each dimension of the shape.
cvtSShapeIxS :: SShape sh -> IxS sh
cvtSShapeIxS ShNil = IZS
cvtSShapeIxS (ShCons n sh) = fromIntegral (fromSNat n) ::$ cvtSShapeIxS sh

-- | Convert a mixed index with all-known dimensions to a shaped index. The
-- 'SShape' argument drives the recursion.
ixCvtXS :: SShape sh -> IxX (MapJust sh) -> IxS sh
ixCvtXS ShNil IZX = IZS
ixCvtXS (ShCons _ sh) (n ::@ idx) = n ::$ ixCvtXS sh idx

-- | View a shaped index as a mixed index whose dimensions are all known.
ixCvtSX :: IxS sh -> IxX (MapJust sh)
ixCvtSX IZS = IZX
ixCvtSX (n ::$ sh) = n ::@ ixCvtSX sh

-- | The shape of a shaped array. It comes from the 'KnownShape' dictionary,
-- not from the array value, which is ignored.
-- NOTE(review): the 'Elt a' constraint is not used by the body.
sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> IxS sh
sshape _ = cvtSShapeIxS (knownShape @sh)

-- | Read the element at a full index.
sindex :: Elt a => Shaped sh a -> IxS sh -> a
sindex (Shaped arr) idx = mindex arr (ixCvtSX idx)

-- | Index the leading @sh1@ dimensions, returning the @sh2@-shaped sub-array.
sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, Elt a) => Shaped (sh1 ++ sh2) a -> IxS sh1 -> Shaped sh2 a
sindexPartial (Shaped arr) idx =
  Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2)
            (rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr)
            (ixCvtSX idx))

-- | Build a shaped array from its shape and an element function (via 'mgenerate').
sgenerate :: forall sh a. (KnownShape sh, Elt a) => IxS sh -> (IxS sh -> a) -> Shaped sh a
sgenerate sh f
  | Dict <- lemKnownMapJust (Proxy @sh)
  = Shaped (mgenerate (ixCvtSX sh) (f . ixCvtXS (knownShape @sh)))

-- | Shaped counterpart of 'mlift'.
slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a)
      => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
      -> Shaped sh1 a -> Shaped sh2 a
slift f (Shaped arr)
  | Dict <- lemKnownMapJust (Proxy @sh2)
  = Shaped (mlift f arr)

-- | Sum away the outermost dimension @n@ with 'X.sumOuter'. The quantified
-- 'Coercible' constraint limits this to element types whose 'Mixed'
-- representation is a newtype over 'XArray'.
ssumOuter1 :: forall sh n a. (Storable a, Num a, KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a))
           => Shaped (n : sh) a -> Shaped sh a
ssumOuter1 (Shaped arr)
  | Dict <- lemKnownMapJust (Proxy @sh)
  = Shaped
  . coerceMixedXArray
  . X.sumOuter (natSing @n :$@ SZX) (knownShapeX @(MapJust sh))
  . coerce @(Mixed (Just n : MapJust sh) a) @(XArray (Just n : MapJust sh) a)
  $ arr

-- | Permute dimensions (via 'mtranspose').
stranspose :: forall sh a. (KnownShape sh, Elt a) => [Int] -> Shaped sh a -> Shaped sh a
stranspose perm (Shaped arr)
  | Dict <- lemKnownMapJust (Proxy @sh)
  = Shaped (mtranspose perm arr)

-- | Concatenate along the outermost dimension (via 'mappend'). The result's
-- outer size is @n + m@.
sappend :: forall n m sh a. (KnownNat n, KnownNat m, KnownShape sh, Elt a)
        => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a
sappend | Dict <- lemKnownMapJust (Proxy @sh) = coerce mappend