diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-05-30 22:20:57 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-30 22:20:57 +0200 |
commit | f0752d67cd188f438280e1f0c692dc1f5f14a190 (patch) | |
tree | 2dd05c13aef3b3c6384bfa091b14633bc86e65a4 /src/Data/Array/Nested/Internal.hs | |
parent | 19eab026f4f4c6a2d38ceb1fffa6062ba2637a46 (diff) |
Refactor Nested (modules, function names)
Diffstat (limited to 'src/Data/Array/Nested/Internal.hs')
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 2054 |
1 files changed, 0 insertions, 2054 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs deleted file mode 100644 index 0870789..0000000 --- a/src/Data/Array/Nested/Internal.hs +++ /dev/null @@ -1,2054 +0,0 @@ -{-# LANGUAGE ConstraintKinds #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DefaultSignatures #-} -{-# LANGUAGE DeriveFoldable #-} -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE InstanceSigs #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE RoleAnnotations #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE StrictData #-} -{-# LANGUAGE TypeAbstractions #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} - -module Data.Array.Nested.Internal where - -import Prelude hiding (mappend) - -import Control.DeepSeq (NFData) -import Control.Monad (forM_, when) -import Control.Monad.ST -import qualified Data.Array.RankedS as S -import Data.Bifunctor (first) -import Data.Coerce (coerce, Coercible) -import Data.Foldable as Foldable (toList) -import Data.Functor.Const -import Data.Int -import Data.Kind -import Data.List.NonEmpty (NonEmpty(..)) -import Data.Monoid (Sum(..)) -import Data.Proxy -import Data.Type.Equality -import qualified Data.Vector.Storable as VS -import qualified Data.Vector.Storable.Mutable as VSM -import Foreign.C.Types (CInt(..)) -import Foreign.Storable (Storable) -import qualified GHC.Float (log1p, expm1, log1pexp, log1mexp) -import GHC.Generics (Generic) -import GHC.IsList (IsList) -import qualified GHC.IsList as IsList -import GHC.TypeLits -import qualified GHC.TypeNats as TypeNats -import Unsafe.Coerce - -import Data.Array.Mixed -import qualified Data.Array.Mixed as X -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape -import Data.Array.Mixed.Internal.Arith -import Data.Array.Mixed.Types - - --- Invariant in the API --- ==================== --- --- In the underlying XArray, there is some shape for elements of an empty --- array. For example, for this array: --- --- arr :: Ranked I3 (Ranked I2 Int, Ranked I1 Float) --- rshape arr == 0 :.: 0 :.: 0 :.: ZIR --- --- the two underlying XArrays have a shape, and those shapes might be anything. --- The invariant is that these element shapes are unobservable in the API. --- (This is possible because you ought to not be able to get to such an element --- without indexing out of bounds.) --- --- Note, though, that the converse situation may arise: the outer array might --- be nonempty but then the inner arrays might. This is fine, an invariant only --- applies if the _outer_ array is empty. --- --- TODO: can we enforce that the elements of an empty (nested) array have --- all-zero shape? --- -> no, because mlift and also any kind of internals probing from outsiders - - --- Primitive element types --- ======================= --- --- There are a few primitive element types; arrays containing elements of such --- type are a newtype over an XArray, which it itself a newtype over a Vector. --- Unfortunately, the setup of the library requires us to list these primitive --- element types multiple times; to aid in extending the list, all these lists --- have been marked with [PRIMITIVE ELEMENT TYPES LIST]. - - -type family MapJust l where - MapJust '[] = '[] - MapJust (x : xs) = Just x : MapJust xs - - --- Stupid things that the type checker should be able to figure out in-line, but can't - -subst1 :: forall f a b. a :~: b -> f a :~: f b -subst1 Refl = Refl - -subst2 :: forall f c a b. a :~: b -> f a c :~: f b c -subst2 Refl = Refl - -lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l -lemAppLeft _ Refl = Refl - -knownNatSucc :: KnownNat n => Dict KnownNat (n + 1) -knownNatSucc = Dict - - -lemKnownShX :: StaticShX sh -> Dict KnownShX sh -lemKnownShX ZKX = Dict -lemKnownShX (SKnown SNat :!% ssh) | Dict <- lemKnownShX ssh = Dict -lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict - -ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing) -ssxFromSNat SZ = ZKX -ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n - -lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing) -lemKnownReplicate sn = lemKnownShX (ssxFromSNat sn) - -lemRankReplicate :: SNat n -> Rank (Replicate n (Nothing @Nat)) :~: n -lemRankReplicate SZ = Refl -lemRankReplicate (SS (n :: SNat nm1)) - | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 - , Refl <- lemRankReplicate n - = Refl - -lemRankMapJust :: forall sh. ShS sh -> Rank (MapJust sh) :~: Rank sh -lemRankMapJust ZSS = Refl -lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl - -lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a - -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a -lemReplicatePlusApp sn _ _ = go sn - where - go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a - go SZ = Refl - go (SS (n :: SNat n'm1)) - | Refl <- lemReplicateSucc @a @n'm1 - , Refl <- go n - = sym (lemReplicateSucc @a @(n'm1 + m)) - -lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True -lemLeqPlus _ _ _ = Refl - -lemLeqSuccSucc :: (k + 1 <= n) => Proxy k -> Proxy n -> (k <=? n - 1) :~: True -lemLeqSuccSucc _ _ = unsafeCoerce Refl - -lemDropLenApp :: Rank l1 <= Rank l2 - => Proxy l1 -> Proxy l2 -> Proxy rest - -> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest) -lemDropLenApp _ _ _ = unsafeCoerce Refl - -lemTakeLenApp :: Rank l1 <= Rank l2 - => Proxy l1 -> Proxy l2 -> Proxy rest - -> TakeLen l1 l2 :~: TakeLen l1 (l2 ++ rest) -lemTakeLenApp _ _ _ = unsafeCoerce Refl - -srankSh :: ShX sh f -> SNat (Rank sh) -srankSh ZSX = SNat -srankSh (_ :$% sh) | SNat <- srankSh sh = SNat - - --- === NEW INDEX TYPES === -- - -type role ListR nominal representational -type ListR :: Nat -> Type -> Type -data ListR n i where - ZR :: ListR 0 i - (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i -deriving instance Eq i => Eq (ListR n i) -deriving instance Ord i => Ord (ListR n i) -deriving instance Functor (ListR n) -deriving instance Foldable (ListR n) -infixr 3 ::: - -instance Show i => Show (ListR n i) where - showsPrec _ = showListR shows - -data UnconsListRRes i n1 = - forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i -unconsListR :: ListR n1 i -> Maybe (UnconsListRRes i n1) -unconsListR (i ::: sh') = Just (UnconsListRRes sh' i) -unconsListR ZR = Nothing - -showListR :: forall sh i. (i -> ShowS) -> ListR sh i -> ShowS -showListR f l = showString "[" . go "" l . showString "]" - where - go :: String -> ListR sh' i -> ShowS - go _ ZR = id - go prefix (x ::: xs) = showString prefix . f x . go "," xs - -listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i -listrAppend ZR sh = sh -listrAppend (x ::: xs) sh = x ::: listrAppend xs sh - -listrFromList :: [i] -> (forall n. ListR n i -> r) -> r -listrFromList [] k = k ZR -listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l) - -listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i -listrIndex SZ (x ::: _) = x -listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs -listrIndex _ ZR = error "k + 1 <= 0" - - --- | An index into a rank-typed array. -type role IxR nominal representational -type IxR :: Nat -> Type -> Type -newtype IxR n i = IxR (ListR n i) - deriving (Eq, Ord) - deriving newtype (Functor, Foldable) - -pattern ZIR :: forall n i. () => n ~ 0 => IxR n i -pattern ZIR = IxR ZR - -pattern (:.:) - :: forall {n1} {i}. - forall n. (n + 1 ~ n1) - => i -> IxR n i -> IxR n1 i -pattern i :.: sh <- IxR (unconsListR -> Just (UnconsListRRes (IxR -> sh) i)) - where i :.: IxR sh = IxR (i ::: sh) -infixr 3 :.: - -{-# COMPLETE ZIR, (:.:) #-} - -type IIxR n = IxR n Int - -instance Show i => Show (IxR n i) where - showsPrec _ (IxR l) = showListR shows l - - -type role ShR nominal representational -type ShR :: Nat -> Type -> Type -newtype ShR n i = ShR (ListR n i) - deriving (Eq, Ord) - deriving newtype (Functor, Foldable) - -pattern ZSR :: forall n i. () => n ~ 0 => ShR n i -pattern ZSR = ShR ZR - -pattern (:$:) - :: forall {n1} {i}. - forall n. (n + 1 ~ n1) - => i -> ShR n i -> ShR n1 i -pattern i :$: sh <- ShR (unconsListR -> Just (UnconsListRRes (ShR -> sh) i)) - where i :$: (ShR sh) = ShR (i ::: sh) -infixr 3 :$: - -{-# COMPLETE ZSR, (:$:) #-} - -type IShR n = ShR n Int - -instance Show i => Show (ShR n i) where - showsPrec _ (ShR l) = showListR shows l - - --- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (ListR n i) where - type Item (ListR n i) = i - fromList = go (SNat @n) - where - go :: SNat n' -> [i] -> ListR n' i - go SZ [] = ZR - go (SS n) (i : is) = i ::: go n is - go _ _ = error "IsList(ListR): Mismatched list length" - toList = Foldable.toList - --- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (IxR n i) where - type Item (IxR n i) = i - fromList = IxR . IsList.fromList - toList = Foldable.toList - --- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (ShR n i) where - type Item (ShR n i) = i - fromList = ShR . IsList.fromList - toList = Foldable.toList - - -type role ListS nominal representational -type ListS :: [Nat] -> (Nat -> Type) -> Type -data ListS sh f where - ZS :: ListS '[] f - (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f -deriving instance (forall n. Eq (f n)) => Eq (ListS sh f) -deriving instance (forall n. Ord (f n)) => Ord (ListS sh f) -infixr 3 ::$ - -instance (forall n. Show (f n)) => Show (ListS sh f) where - showsPrec _ = showListS shows - -data UnconsListSRes f sh1 = - forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n) -unconsListS :: ListS sh1 f -> Maybe (UnconsListSRes f sh1) -unconsListS (x ::$ sh') = Just (UnconsListSRes sh' x) -unconsListS ZS = Nothing - -fmapListS :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g -fmapListS _ ZS = ZS -fmapListS f (x ::$ xs) = f x ::$ fmapListS f xs - -foldListS :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m -foldListS _ ZS = mempty -foldListS f (x ::$ xs) = f x <> foldListS f xs - -showListS :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS -showListS f l = showString "[" . go "" l . showString "]" - where - go :: String -> ListS sh' f -> ShowS - go _ ZS = id - go prefix (x ::$ xs) = showString prefix . f x . go "," xs - -listSToList :: ListS sh (Const i) -> [i] -listSToList ZS = [] -listSToList (Const i ::$ is) = i : listSToList is - - --- | An index into a shape-typed array. --- --- For convenience, this contains regular 'Int's instead of bounded integers --- (traditionally called \"@Fin@\"). Note that because the shape of a --- shape-typed array is known statically, you can also retrieve the array shape --- from a 'KnownShape' dictionary. -type role IxS nominal representational -type IxS :: [Nat] -> Type -> Type -newtype IxS sh i = IxS (ListS sh (Const i)) - deriving (Eq, Ord) - -pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i -pattern ZIS = IxS ZS - -pattern (:.$) - :: forall {sh1} {i}. - forall n sh. (KnownNat n, n : sh ~ sh1) - => i -> IxS sh i -> IxS sh1 i -pattern i :.$ shl <- IxS (unconsListS -> Just (UnconsListSRes (IxS -> shl) (getConst -> i))) - where i :.$ IxS shl = IxS (Const i ::$ shl) -infixr 3 :.$ - -{-# COMPLETE ZIS, (:.$) #-} - -type IIxS sh = IxS sh Int - -instance Show i => Show (IxS sh i) where - showsPrec _ (IxS l) = showListS (\(Const i) -> shows i) l - -instance Functor (IxS sh) where - fmap f (IxS l) = IxS (fmapListS (Const . f . getConst) l) - -instance Foldable (IxS sh) where - foldMap f (IxS l) = foldListS (f . getConst) l - - --- | The shape of a shape-typed array given as a list of 'SNat' values. -type role ShS nominal -type ShS :: [Nat] -> Type -newtype ShS sh = ShS (ListS sh SNat) - deriving (Eq, Ord) - -pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh -pattern ZSS = ShS ZS - -pattern (:$$) - :: forall {sh1}. - forall n sh. (KnownNat n, n : sh ~ sh1) - => SNat n -> ShS sh -> ShS sh1 -pattern i :$$ shl <- ShS (unconsListS -> Just (UnconsListSRes (ShS -> shl) i)) - where i :$$ ShS shl = ShS (i ::$ shl) - -infixr 3 :$$ - -{-# COMPLETE ZSS, (:$$) #-} - -instance Show (ShS sh) where - showsPrec _ (ShS l) = showListS (shows . fromSNat) l - -lengthShS :: ShS sh -> Int -lengthShS (ShS l) = getSum (foldListS (\_ -> Sum 1) l) - -shSToList :: ShS sh -> [Int] -shSToList ZSS = [] -shSToList (sn :$$ sh) = fromSNat' sn : shSToList sh - - --- | Untyped: length is checked at runtime. -instance KnownShS sh => IsList (ListS sh (Const i)) where - type Item (ListS sh (Const i)) = i - fromList topl = go (knownShS @sh) topl - where - go :: ShS sh' -> [i] -> ListS sh' (Const i) - go ZSS [] = ZS - go (_ :$$ sh) (i : is) = Const i ::$ go sh is - go _ _ = error $ "IsList(ListS): Mismatched list length (type says " - ++ show (lengthShS (knownShS @sh)) ++ ", list has length " - ++ show (length topl) ++ ")" - toList = listSToList - --- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. -instance KnownShS sh => IsList (IxS sh i) where - type Item (IxS sh i) = i - fromList = IxS . IsList.fromList - toList = Foldable.toList - --- | Untyped: length and values are checked at runtime. -instance KnownShS sh => IsList (ShS sh) where - type Item (ShS sh) = Int - fromList topl = ShS (go (knownShS @sh) topl) - where - go :: ShS sh' -> [Int] -> ListS sh' SNat - go ZSS [] = ZS - go (sn :$$ sh) (i : is) - | i == fromSNat' sn = sn ::$ go sh is - | otherwise = error $ "IsList(ShS): Value does not match typing (type says " - ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" - go _ _ = error $ "IsList(ShS): Mismatched list length (type says " - ++ show (lengthShS (knownShS @sh)) ++ ", list has length " - ++ show (length topl) ++ ")" - toList = shSToList - - --- | Wrapper type used as a tag to attach instances on. The instances on arrays --- of @'Primitive' a@ are more polymorphic than the direct instances for arrays --- of scalars; this means that if @orthotope@ supports an element type @T@ that --- this library does not (directly), it may just work if you use an array of --- @'Primitive' T@ instead. -newtype Primitive a = Primitive a - --- | Element types that are primitive; arrays of these types are just a newtype --- wrapper over an array. -class Storable a => PrimElt a where - fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a - toPrimitive :: Mixed sh a -> Mixed sh (Primitive a) - - default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a - fromPrimitive = coerce - - default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a) - toPrimitive = coerce - --- [PRIMITIVE ELEMENT TYPES LIST] -instance PrimElt Int -instance PrimElt Int64 -instance PrimElt Int32 -instance PrimElt CInt -instance PrimElt Float -instance PrimElt Double -instance PrimElt () - - --- | Mixed arrays: some dimensions are size-typed, some are not. Distributes --- over product-typed elements using a data family so that the full array is --- always in struct-of-arrays format. --- --- Built on top of 'XArray' which is built on top of @orthotope@, meaning that --- dimension permutations (e.g. 'mtranspose') are typically free. --- --- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type --- class. -type Mixed :: [Maybe Nat] -> Type -> Type -data family Mixed sh a --- NOTE: When opening up the Mixed abstraction, you might see dimension sizes --- that you're not supposed to see. In particular, you might see (nonempty) --- sizes of the elements of an empty array, which is information that should --- ostensibly not exist; the full array is still empty. - -data instance Mixed sh (Primitive a) = M_Primitive !(IShX sh) !(XArray sh a) - deriving (Show, Eq, Generic) - --- | Only on scalars, because lexicographical ordering is strange on multi-dimensional arrays. -deriving instance (Ord a, Storable a) => Ord (Mixed '[] (Primitive a)) - -instance NFData a => NFData (Mixed sh (Primitive a)) - --- [PRIMITIVE ELEMENT TYPES LIST] -newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Show, Eq, Generic) -newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Show, Eq, Generic) -newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Show, Eq, Generic) -newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Show, Eq, Generic) -newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Show, Eq, Generic) -newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Show, Eq, Generic) -newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Show, Eq, Generic) -- no content, orthotope optimises this (via Vector) --- etc. - --- [PRIMITIVE ELEMENT TYPES LIST] -deriving instance Ord (Mixed '[] Int) ; instance NFData (Mixed sh Int) -deriving instance Ord (Mixed '[] Int64) ; instance NFData (Mixed sh Int64) -deriving instance Ord (Mixed '[] Int32) ; instance NFData (Mixed sh Int32) -deriving instance Ord (Mixed '[] CInt) ; instance NFData (Mixed sh CInt) -deriving instance Ord (Mixed '[] Float) ; instance NFData (Mixed sh Float) -deriving instance Ord (Mixed '[] Double) ; instance NFData (Mixed sh Double) -deriving instance Ord (Mixed '[] ()) ; instance NFData (Mixed sh ()) - -data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) deriving (Generic) -deriving instance (Show (Mixed sh a), Show (Mixed sh b)) => Show (Mixed sh (a, b)) -instance (NFData (Mixed sh a), NFData (Mixed sh b)) => NFData (Mixed sh (a, b)) --- etc., larger tuples (perhaps use generics to allow arbitrary product types) - -data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(IShX sh1) !(Mixed (sh1 ++ sh2) a) deriving (Generic) -deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a)) -instance NFData (Mixed (sh1 ++ sh2) a) => NFData (Mixed sh1 (Mixed sh2 a)) - - --- | Internal helper data family mirroring 'Mixed' that consists of mutable --- vectors instead of 'XArray's. -type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type -data family MixedVecs s sh a - -newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a) - --- [PRIMITIVE ELEMENT TYPES LIST] -newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int) -newtype instance MixedVecs s sh Int64 = MV_Int64 (VS.MVector s Int64) -newtype instance MixedVecs s sh Int32 = MV_Int32 (VS.MVector s Int32) -newtype instance MixedVecs s sh CInt = MV_CInt (VS.MVector s CInt) -newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double) -newtype instance MixedVecs s sh Float = MV_Float (VS.MVector s Float) -newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ()) -- no content, MVector optimises this --- etc. - -data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b) --- etc. - -data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a) - - --- | Tree giving the shape of every array component. -type family ShapeTree a where - ShapeTree (a, b) = (ShapeTree a, ShapeTree b) - ShapeTree (Mixed sh a) = (IShX sh, ShapeTree a) - ShapeTree (Ranked n a) = (IShR n, ShapeTree a) - ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) - - -- to avoid having to list all of the primitive types: - ShapeTree _ = () - - --- | Allowable element types in a mixed array, and by extension in a 'Ranked' or --- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive' --- a@; see the documentation for 'Primitive' for more details. -class Elt a where - -- ====== PUBLIC METHODS ====== -- - - mshape :: Mixed sh a -> IShX sh - mindex :: Mixed sh a -> IIxX sh -> a - mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a - mscalar :: a -> Mixed '[] a - - -- | All arrays in the list, even subarrays inside @a@, must have the same - -- shape; if they do not, a runtime error will be thrown. See the - -- documentation of 'mgenerate' for more information about this restriction. - -- Furthermore, the length of the list must correspond with @n@: if @n@ is - -- @Just m@ and @m@ does not equal the length of the list, a runtime error is - -- thrown. - -- - -- Consider also 'mfromListPrim', which can avoid intermediate arrays. - mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a - - mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a] - - -- | Note: this library makes no particular guarantees about the shapes of - -- arrays "inside" an empty array. With 'mlift' and 'mlift2' you can see the - -- full 'XArray' and as such you can distinguish different empty arrays by - -- the "shapes" of their elements. This information is meaningless, so you - -- should not use it. - mlift :: forall sh1 sh2. - StaticShX sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) - -> Mixed sh1 a -> Mixed sh2 a - - -- | See the documentation for 'mlift'. - mlift2 :: forall sh1 sh2 sh3. - StaticShX sh3 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) - -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a - - mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 - => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a - - mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) - => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a - - -- ====== PRIVATE METHODS ====== -- - - mshapeTree :: a -> ShapeTree a - - mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool - - mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool - - mshowShapeTree :: Proxy a -> ShapeTree a -> String - - -- | Given the shape of this array, an index and a value, write the value at - -- that index in the vectors. - mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () - - -- | Given the shape of this array, an index and a value, write the value at - -- that index in the vectors. - mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () - - -- | Given the shape of this array, finalise the vectors into 'XArray's. - mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) - - --- | Element types for which we have evidence of the (static part of the) shape --- in a type class constraint. Compare the instance contexts of the instances --- of this class with those of 'Elt': some instances have an additional --- "known-shape" constraint. --- --- This class is (currently) only required for 'mgenerate' / 'rgenerate' / --- 'sgenerate'. -class Elt a => KnownElt a where - -- | Create an empty array. The given shape must have size zero; this may or may not be checked. - memptyArray :: IShX sh -> Mixed sh a - - -- | Create uninitialised vectors for this array type, given the shape of - -- this vector and an example for the contents. - mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a) - - mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a) - - --- Arrays of scalars are basically just arrays of scalars. -instance Storable a => Elt (Primitive a) where - mshape (M_Primitive sh _) = sh - mindex (M_Primitive _ a) i = Primitive (X.index a i) - mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i) - mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) - mfromListOuter l@(arr1 :| _) = - let sh = SUnknown (length l) :$% mshape arr1 - in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l))) - mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr) - - mlift :: forall sh1 sh2. - StaticShX sh2 - -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) - -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) - mlift ssh2 f (M_Primitive _ a) - | Refl <- lemAppNil @sh1 - , Refl <- lemAppNil @sh2 - , let result = f ZKX a - = M_Primitive (X.shape ssh2 result) result - - mlift2 :: forall sh1 sh2 sh3. - StaticShX sh3 - -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a) - -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a) - mlift2 ssh3 f (M_Primitive _ a) (M_Primitive _ b) - | Refl <- lemAppNil @sh1 - , Refl <- lemAppNil @sh2 - , Refl <- lemAppNil @sh3 - , let result = f ZKX a b - = M_Primitive (X.shape ssh3 result) result - - mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 - => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) - mcast ssh1 sh2 _ (M_Primitive sh1' arr) = - let (_, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' - in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr) - - mtranspose perm (M_Primitive sh arr) = - M_Primitive (shxPermutePrefix perm sh) - (X.transpose (ssxFromShape sh) perm arr) - - mshapeTree _ = () - mshapeTreeEq _ () () = True - mshapeTreeEmpty _ () = False - mshowShapeTree _ () = "()" - mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x - - -- TODO: this use of toVector is suboptimal - mvecsWritePartial - :: forall sh' sh s. - IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () - mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do - let arrsh = X.shape (ssxFromShape sh') arr - offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) - VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) - - mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v - --- [PRIMITIVE ELEMENT TYPES LIST] -deriving via Primitive Int instance Elt Int -deriving via Primitive Int64 instance Elt Int64 -deriving via Primitive Int32 instance Elt Int32 -deriving via Primitive CInt instance Elt CInt -deriving via Primitive Double instance Elt Double -deriving via Primitive Float instance Elt Float -deriving via Primitive () instance Elt () - -instance Storable a => KnownElt (Primitive a) where - memptyArray sh = M_Primitive sh (X.empty sh) - mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh) - mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 - --- [PRIMITIVE ELEMENT TYPES LIST] -deriving via Primitive Int instance KnownElt Int -deriving via Primitive Int64 instance KnownElt Int64 -deriving via Primitive Int32 instance KnownElt Int32 -deriving via Primitive CInt instance KnownElt CInt -deriving via Primitive Double instance KnownElt Double -deriving via Primitive Float instance KnownElt Float -deriving via Primitive () instance KnownElt () - --- Arrays of pairs are pairs of arrays. -instance (Elt a, Elt b) => Elt (a, b) where - mshape (M_Tup2 a _) = mshape a - mindex (M_Tup2 a b) i = (mindex a i, mindex b i) - mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) - mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) - mfromListOuter l = - M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l)) - (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l)) - mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b) - mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b) - mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y) - - mcast ssh1 sh2 psh' (M_Tup2 a b) = - M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b) - - mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b) - - mshapeTree (x, y) = (mshapeTree x, mshapeTree y) - mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' - mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 - mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" - mvecsWrite sh i (x, y) (MV_Tup2 a b) = do - mvecsWrite sh i x a - mvecsWrite sh i y b - mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do - mvecsWritePartial sh i x a - mvecsWritePartial sh i y b - mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b - -instance (KnownElt a, KnownElt b) => KnownElt (a, b) where - memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh) - mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y - mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b) - --- Arrays of arrays are just arrays, but with more dimensions. -instance Elt a => Elt (Mixed sh' a) where - -- TODO: this is quadratic in the nesting depth because it repeatedly - -- truncates the shape vector to one a little shorter. Fix with a - -- moverlongShape method, a prefix of which is mshape. - mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh - mshape (M_Nest sh arr) - = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr)) - - mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a - mindex (M_Nest _ arr) i = mindexPartial arr i - - mindexPartial :: forall sh1 sh2. - Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) - mindexPartial (M_Nest sh arr) i - | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) - - mscalar = M_Nest ZSX - - mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) - mfromListOuter l@(arr :| _) = - M_Nest (SUnknown (length l) :$% mshape arr) - (mfromListOuter ((\(M_Nest _ a) -> a) <$> l)) - - mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr) - - mlift :: forall sh1 sh2. - StaticShX sh2 - -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) - -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) - mlift ssh2 f (M_Nest sh1 arr) = - let result = mlift (ssxAppend ssh2 ssh') f' arr - (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result) - in M_Nest sh2 result - where - ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr))) - - f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b - f' sshT - | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) - , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) - = f (ssxAppend ssh' sshT) - - mlift2 :: forall sh1 sh2 sh3. - StaticShX sh3 - -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) - -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a) - mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) = - let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2 - (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result) - in M_Nest sh3 result - where - ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1))) - - f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b - f' sshT - | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) - , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) - , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) - = f (ssxAppend ssh' sshT) - - mcast :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2 - => StaticShX sh1 -> IShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a) - mcast ssh1 sh2 _ (M_Nest sh1T arr) - | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') - , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') - = let (_, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T - in M_Nest (shxAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr) - - mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) - => Perm is -> Mixed sh (Mixed sh' a) - -> Mixed (PermutePrefix is sh) (Mixed sh' a) - mtranspose perm (M_Nest sh arr) - | let sh' = shxDropSh @sh @sh' (mshape arr) sh - , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh') - , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh')) - , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') - , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') - , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') - = M_Nest (shxPermutePrefix perm sh) - (mtranspose perm arr) - - mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) - mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr))))) - - mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - - mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t - - mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - - mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs - - mvecsWritePartial :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) - -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) - -> ST s () - mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) - | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs - - mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs - -instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where - memptyArray sh = M_Nest sh (memptyArray (shxAppend sh (shxCompleteZeros (knownShX @sh')))) - - mvecsUnsafeNew sh example - | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) - | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh'))) - where - sh' = mshape example - - mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) - - --- | Create an array given a size and a function that computes the element at a --- given index. --- --- __WARNING__: It is required that every @a@ returned by the argument to --- 'mgenerate' has the same shape. For example, the following will throw a --- runtime error: --- --- > foo :: Mixed [Nothing] (Mixed [Nothing] Double) --- > foo = mgenerate (10 :.: ZIR) $ \(i :.: ZIR) -> --- > mgenerate (i :.: ZIR) $ \(j :.: ZIR) -> --- > ... --- --- because the size of the inner 'mgenerate' is not always the same (it depends --- on @i@). Nested arrays in @ox-arrays@ are always stored fully flattened, so --- the entire hierarchy (after distributing out tuples) must be a rectangular --- array. The type of 'mgenerate' allows this requirement to be broken very --- easily, hence the runtime check. -mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a -mgenerate sh f = case shxEnum sh of - [] -> memptyArray sh - firstidx : restidxs -> - let firstelem = f (ixxZero' sh) - shapetree = mshapeTree firstelem - in if mshapeTreeEmpty (Proxy @a) shapetree - then memptyArray sh - else runST $ do - vecs <- mvecsUnsafeNew sh firstelem - mvecsWrite sh firstidx firstelem vecs - -- TODO: This is likely fine if @a@ is big, but if @a@ is a - -- scalar this array copying inefficient. Should improve this. - forM_ restidxs $ \idx -> do - let val = f idx - when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ - error "Data.Array.Nested mgenerate: generated values do not have equal shapes" - mvecsWrite sh idx val vecs - mvecsFreeze sh vecs - -msumOuter1P :: forall sh n a. (Storable a, NumElt a) - => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) -msumOuter1P (M_Primitive (n :$% sh) arr) = - let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX - in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr) - -msumOuter1 :: forall sh n a. (NumElt a, PrimElt a) - => Mixed (n : sh) a -> Mixed sh a -msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive - -mappend :: forall n m sh a. Elt a - => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a -mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 - where - sn :$% sh = mshape arr1 - sm :$% _ = mshape arr2 - ssh = ssxFromShape sh - snm :: SMayNat () SNat (AddMaybe n m) - snm = case (sn, sm) of - (SUnknown{}, _) -> SUnknown () - (SKnown{}, SUnknown{}) -> SUnknown () - (SKnown n, SKnown m) -> SKnown (snatPlus n m) - - f :: forall sh' b. Storable b - => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b - f ssh' = X.append (ssxAppend ssh ssh') - -mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) -mfromVectorP sh v = M_Primitive sh (X.fromVector sh v) - -mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a -mfromVector sh v = fromPrimitive (mfromVectorP sh v) - -mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a -mtoVectorP (M_Primitive _ v) = X.toVector v - -mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a -mtoVector arr = mtoVectorP (toPrimitive arr) - -mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a -mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise? - -mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a -mfromList1Prim l = - let ssh = SUnknown () :!% ZKX - xarr = X.fromList1 ssh l - in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr - -mtoList1 :: Elt a => Mixed '[n] a -> [a] -mtoList1 = map munScalar . mtoListOuter - -mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a -mfromListPrim l = - let ssh = SUnknown () :!% ZKX - xarr = X.fromList1 ssh l - in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr - -mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a -mfromListPrimLinear sh l = - let M_Primitive _ xarr = toPrimitive (mfromListPrim l) - in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr) - -munScalar :: Elt a => Mixed '[] a -> a -munScalar arr = mindex arr ZIX - -mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) - => StaticShX sh -> IShX sh2 - -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) - -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b) -mrerankP ssh sh2 f (M_Primitive sh arr) = - let sh1 = shxDropSSX sh ssh - in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2) - (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2) - (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) - arr) - --- | See the caveats at @X.rerank@. -mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) - => StaticShX sh -> IShX sh2 - -> (Mixed sh1 a -> Mixed sh2 b) - -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b -mrerank ssh sh2 f (toPrimitive -> arr) = - fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr - -mreplicate :: forall sh sh' a. Elt a - => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a -mreplicate sh arr = - let ssh' = ssxFromShape (mshape arr) - in mlift (ssxAppend (ssxFromShape sh) ssh') - (\(sshT :: StaticShX shT) -> - case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of - Refl -> X.replicate sh (ssxAppend ssh' sshT)) - arr - -mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) -mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x) - -mreplicateScal :: forall sh a. PrimElt a - => IShX sh -> a -> Mixed sh a -mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x) - -mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a -mslice i n arr = - let _ :$% sh = mshape arr - in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr - -msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a -msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr - -mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a -mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr - -mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a -mreshape sh' arr = - mlift (ssxFromShape sh') - (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh') - arr - -miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a -miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn) - -masXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a) -masXArrayPrimP (M_Primitive sh arr) = (sh, arr) - -masXArrayPrim :: PrimElt a => Mixed sh a -> (IShX sh, XArray sh a) -masXArrayPrim = masXArrayPrimP . toPrimitive - -mfromXArrayPrimP :: StaticShX sh -> XArray sh a -> Mixed sh (Primitive a) -mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr - -mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a -mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP - -mliftPrim :: PrimElt a - => (a -> a) - -> Mixed sh a -> Mixed sh a -mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr)) - -mliftPrim2 :: PrimElt a - => (a -> a -> a) - -> Mixed sh a -> Mixed sh a -> Mixed sh a -mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) = - fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2)) - -mliftNumElt1 :: PrimElt a => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a) -> Mixed sh a -> Mixed sh a -mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (srankSh sh) arr)) - -mliftNumElt2 :: PrimElt a - => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a -> S.Array (Rank sh) a) - -> Mixed sh a -> Mixed sh a -> Mixed sh a -mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_Primitive sh2 (XArray arr2)) - | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (srankSh sh1) arr1 arr2)) - | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2 - -instance (NumElt a, PrimElt a) => Num (Mixed sh a) where - (+) = mliftNumElt2 numEltAdd - (-) = mliftNumElt2 numEltSub - (*) = mliftNumElt2 numEltMul - negate = mliftNumElt1 numEltNeg - abs = mliftNumElt1 numEltAbs - signum = mliftNumElt1 numEltSignum - fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate" - -instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Mixed sh a) where - fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate" - recip = mliftNumElt1 floatEltRecip - (/) = mliftNumElt2 floatEltDiv - -instance (FloatElt a, NumElt a, PrimElt a) => Floating (Mixed sh a) where - pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate" - exp = mliftNumElt1 floatEltExp - log = mliftNumElt1 floatEltLog - sqrt = mliftNumElt1 floatEltSqrt - - (**) = mliftNumElt2 floatEltPow - logBase = mliftNumElt2 floatEltLogbase - - sin = mliftNumElt1 floatEltSin - cos = mliftNumElt1 floatEltCos - tan = mliftNumElt1 floatEltTan - asin = mliftNumElt1 floatEltAsin - acos = mliftNumElt1 floatEltAcos - atan = mliftNumElt1 floatEltAtan - sinh = mliftNumElt1 floatEltSinh - cosh = mliftNumElt1 floatEltCosh - tanh = mliftNumElt1 floatEltTanh - asinh = mliftNumElt1 floatEltAsinh - acosh = mliftNumElt1 floatEltAcosh - atanh = mliftNumElt1 floatEltAtanh - log1p = mliftNumElt1 floatEltLog1p - expm1 = mliftNumElt1 floatEltExpm1 - log1pexp = mliftNumElt1 floatEltLog1pexp - log1mexp = mliftNumElt1 floatEltLog1mexp - -mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a -mtoRanked arr - | Refl <- lemAppNil @sh - , Refl <- lemAppNil @(Replicate (Rank sh) (Nothing @Nat)) - , Refl <- lemRankReplicate (srankSh (mshape arr)) - = Ranked (mcast (ssxFromShape (mshape arr)) (convSh (mshape arr)) (Proxy @'[]) arr) - where - convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing) - convSh ZSX = ZSX - convSh (smn :$% (sh :: IShX sh'T)) - | Refl <- lemReplicateSucc @(Nothing @Nat) @(Rank sh'T) - = SUnknown (fromSMayNat' smn) :$% convSh sh - -mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') - => Mixed sh a -> ShS sh' -> Shaped sh' a -mcastToShaped arr targetsh - | Refl <- lemAppNil @sh - , Refl <- lemAppNil @(MapJust sh') - , Refl <- lemRankMapJust targetsh - = Shaped (mcast (ssxFromShape (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr) - - --- | A rank-typed array: the number of dimensions of the array (its /rank/) is --- represented on the type level as a 'Nat'. --- --- Valid elements of a ranked arrays are described by the 'Elt' type class. --- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are --- supported (and are represented as a single, flattened, struct-of-arrays --- array internally). --- --- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's. -type Ranked :: Nat -> Type -> Type -newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) -deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a) -deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a) -deriving instance Ord (Mixed '[] a) => Ord (Ranked 0 a) -deriving instance NFData (Mixed (Replicate n Nothing) a) => NFData (Ranked n a) - --- | A shape-typed array: the full shape of the array (the sizes of its --- dimensions) is represented on the type level as a list of 'Nat's. Note that --- these are "GHC.TypeLits" naturals, because we do not need induction over --- them and we want very large arrays to be possible. --- --- Like for 'Ranked', the valid elements are described by the 'Elt' type class, --- and 'Shaped' itself is again an instance of 'Elt' as well. --- --- 'Shaped' is a newtype around a 'Mixed' of 'Just's. -type Shaped :: [Nat] -> Type -> Type -newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) -deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a) -deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a) -deriving instance Ord (Mixed '[] a) => Ord (Shaped '[] a) -deriving instance NFData (Mixed (MapJust sh) a) => NFData (Shaped sh a) - --- just unwrap the newtype and defer to the general instance for nested arrays -newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a)) -deriving instance Show (Mixed sh (Mixed (Replicate n Nothing) a)) => Show (Mixed sh (Ranked n a)) -newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh' ) a)) -deriving instance Show (Mixed sh (Mixed (MapJust sh' ) a)) => Show (Mixed sh (Shaped sh' a)) - -newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a)) -newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh' ) a)) - - --- 'Ranked' and 'Shaped' can already be used at the top level of an array nest; --- these instances allow them to also be used as elements of arrays, thus --- making them first-class in the API. -instance Elt a => Elt (Ranked n a) where - mshape (M_Ranked arr) = mshape arr - mindex (M_Ranked arr) i = Ranked (mindex arr i) - - mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a) - mindexPartial (M_Ranked arr) i = - coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ - mindexPartial arr i - - mscalar (Ranked x) = M_Ranked (M_Nest ZSX x) - - mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a) - mfromListOuter l = M_Ranked (mfromListOuter (coerce l)) - - mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)] - mtoListOuter (M_Ranked arr) = - coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr) - - mlift :: forall sh1 sh2. - StaticShX sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) - -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) - mlift ssh2 f (M_Ranked arr) = - coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ - mlift ssh2 f arr - - mlift2 :: forall sh1 sh2 sh3. - StaticShX sh3 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) - -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a) - mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) = - coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ - mlift2 ssh3 f arr1 arr2 - - mcast ssh1 sh2 psh' (M_Ranked arr) = M_Ranked (mcast ssh1 sh2 psh' arr) - - mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr) - - mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr) - - mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - - mshapeTreeEmpty _ (sh, t) = shapeSizeR sh == 0 && mshapeTreeEmpty (Proxy @a) t - - mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - - mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () - mvecsWrite sh idx (Ranked arr) vecs = - mvecsWrite sh idx arr - (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) - vecs) - - mvecsWritePartial :: forall sh sh' s. - IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) - -> MixedVecs s (sh ++ sh') (Ranked n a) - -> ST s () - mvecsWritePartial sh idx arr vecs = - mvecsWritePartial sh idx - (coerce @(Mixed sh' (Ranked n a)) - @(Mixed sh' (Mixed (Replicate n Nothing) a)) - arr) - (coerce @(MixedVecs s (sh ++ sh') (Ranked n a)) - @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) - vecs) - - mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) - mvecsFreeze sh vecs = - coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) - @(Mixed sh (Ranked n a)) - <$> mvecsFreeze sh - (coerce @(MixedVecs s sh (Ranked n a)) - @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) - vecs) - -instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where - memptyArray :: forall sh. IShX sh -> Mixed sh (Ranked n a) - memptyArray i - | Dict <- lemKnownReplicate (SNat @n) - = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ - memptyArray i - - mvecsUnsafeNew idx (Ranked arr) - | Dict <- lemKnownReplicate (SNat @n) - = MV_Ranked <$> mvecsUnsafeNew idx arr - - mvecsNewEmpty _ - | Dict <- lemKnownReplicate (SNat @n) - = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) - --- sshapeKnown :: ShS sh -> Dict KnownShape sh --- sshapeKnown ZSS = Dict --- sshapeKnown (SNat :$$ sh) | Dict <- sshapeKnown sh = Dict - -lemCommMapJustApp :: forall sh1 sh2. ShS sh1 -> Proxy sh2 - -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 -lemCommMapJustApp ZSS _ = Refl -lemCommMapJustApp (_ :$$ sh) p | Refl <- lemCommMapJustApp sh p = Refl - -instance Elt a => Elt (Shaped sh a) where - mshape (M_Shaped arr) = mshape arr - mindex (M_Shaped arr) i = Shaped (mindex arr i) - - mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) - mindexPartial (M_Shaped arr) i = - coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ - mindexPartial arr i - - mscalar (Shaped x) = M_Shaped (M_Nest ZSX x) - - mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a) - mfromListOuter l = M_Shaped (mfromListOuter (coerce l)) - - mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] - mtoListOuter (M_Shaped arr) - = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr) - - mlift :: forall sh1 sh2. - StaticShX sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) - -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) - mlift ssh2 f (M_Shaped arr) = - coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ - mlift ssh2 f arr - - mlift2 :: forall sh1 sh2 sh3. - StaticShX sh3 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) - -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a) - mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) = - coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ - mlift2 ssh3 f arr1 arr2 - - mcast ssh1 sh2 psh' (M_Shaped arr) = M_Shaped (mcast ssh1 sh2 psh' arr) - - mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr) - - mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr) - - mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - - mshapeTreeEmpty _ (sh, t) = shapeSizeS sh == 0 && mshapeTreeEmpty (Proxy @a) t - - mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - - mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () - mvecsWrite sh idx (Shaped arr) vecs = - mvecsWrite sh idx arr - (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) - vecs) - - mvecsWritePartial :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) - -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) - -> ST s () - mvecsWritePartial sh idx arr vecs = - mvecsWritePartial sh idx - (coerce @(Mixed sh2 (Shaped sh a)) - @(Mixed sh2 (Mixed (MapJust sh) a)) - arr) - (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a)) - @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) - vecs) - - mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) - mvecsFreeze sh vecs = - coerce @(Mixed sh' (Mixed (MapJust sh) a)) - @(Mixed sh' (Shaped sh a)) - <$> mvecsFreeze sh - (coerce @(MixedVecs s sh' (Shaped sh a)) - @(MixedVecs s sh' (Mixed (MapJust sh) a)) - vecs) - --- | Evidence for the static part of a shape. This pops up only when you are --- polymorphic in the element type of an array. -type KnownShS :: [Nat] -> Constraint -class KnownShS sh where knownShS :: ShS sh -instance KnownShS '[] where knownShS = ZSS -instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS - -lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh) -lemKnownMapJust _ = lemKnownShX (go (knownShS @sh)) - where - go :: ShS sh' -> StaticShX (MapJust sh') - go ZSS = ZKX - go (n :$$ sh) = SKnown n :!% go sh - -instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where - memptyArray :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) - memptyArray i - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ - memptyArray i - - mvecsUnsafeNew idx (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = MV_Shaped <$> mvecsUnsafeNew idx arr - - mvecsNewEmpty _ - | Dict <- lemKnownMapJust (Proxy @sh) - = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) - - --- ====== API OF RANKED ARRAYS ====== -- - -arithPromoteRanked :: forall n a. PrimElt a - => (forall sh. Mixed sh a -> Mixed sh a) - -> Ranked n a -> Ranked n a -arithPromoteRanked = coerce - -arithPromoteRanked2 :: forall n a. PrimElt a - => (forall sh. Mixed sh a -> Mixed sh a -> Mixed sh a) - -> Ranked n a -> Ranked n a -> Ranked n a -arithPromoteRanked2 = coerce - -instance (NumElt a, PrimElt a) => Num (Ranked n a) where - (+) = arithPromoteRanked2 (+) - (-) = arithPromoteRanked2 (-) - (*) = arithPromoteRanked2 (*) - negate = arithPromoteRanked negate - abs = arithPromoteRanked abs - signum = arithPromoteRanked signum - fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicateScal" - -instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Ranked n a) where - fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicateScal" - recip = arithPromoteRanked recip - (/) = arithPromoteRanked2 (/) - -instance (FloatElt a, NumElt a, PrimElt a) => Floating (Ranked n a) where - pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicateScal" - exp = arithPromoteRanked exp - log = arithPromoteRanked log - sqrt = arithPromoteRanked sqrt - (**) = arithPromoteRanked2 (**) - logBase = arithPromoteRanked2 logBase - sin = arithPromoteRanked sin - cos = arithPromoteRanked cos - tan = arithPromoteRanked tan - asin = arithPromoteRanked asin - acos = arithPromoteRanked acos - atan = arithPromoteRanked atan - sinh = arithPromoteRanked sinh - cosh = arithPromoteRanked cosh - tanh = arithPromoteRanked tanh - asinh = arithPromoteRanked asinh - acosh = arithPromoteRanked acosh - atanh = arithPromoteRanked atanh - log1p = arithPromoteRanked GHC.Float.log1p - expm1 = arithPromoteRanked GHC.Float.expm1 - log1pexp = arithPromoteRanked GHC.Float.log1pexp - log1mexp = arithPromoteRanked GHC.Float.log1mexp - -zeroIxR :: SNat n -> IIxR n -zeroIxR SZ = ZIR -zeroIxR (SS n) = 0 :.: zeroIxR n - -ixCvtXR :: IIxX sh -> IIxR (Rank sh) -ixCvtXR ZIX = ZIR -ixCvtXR (n :.% idx) = n :.: ixCvtXR idx - -shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n -shCvtXR' ZSX = - castWith (subst2 (unsafeCoerce Refl :: 0 :~: n)) - ZSR -shCvtXR' (n :$% (idx :: IShX sh)) - | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) = - castWith (subst2 (lem1 @sh Refl)) - (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx)) - where - lem1 :: forall sh' n' k. - k : sh' :~: Replicate n' Nothing - -> Rank sh' + 1 :~: n' - lem1 Refl = unsafeCoerce Refl - - lem2 :: k : sh :~: Replicate n Nothing - -> sh :~: Replicate (Rank sh) Nothing - lem2 Refl = unsafeCoerce Refl - -ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing) -ixCvtRX ZIR = ZIX -ixCvtRX (n :.: (idx :: IxR m Int)) = - castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m)) - (n :.% ixCvtRX idx) - -shCvtRX :: IShR n -> IShX (Replicate n Nothing) -shCvtRX ZSR = ZSX -shCvtRX (n :$: (idx :: ShR m Int)) = - castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m)) - (SUnknown n :$% shCvtRX idx) - -shapeSizeR :: IShR n -> Int -shapeSizeR ZSR = 1 -shapeSizeR (n :$: sh) = n * shapeSizeR sh - - -rshape :: forall n a. Elt a => Ranked n a -> IShR n -rshape (Ranked arr) = shCvtXR' (mshape arr) - -rindex :: Elt a => Ranked n a -> IIxR n -> a -rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) - -snatFromListR :: ListR n i -> SNat n -snatFromListR ZR = SNat -snatFromListR (_ ::: (l :: ListR n i)) | SNat <- snatFromListR l, Dict <- knownNatSucc @n = SNat - -snatFromIxR :: IxR n i -> SNat n -snatFromIxR (IxR sh) = snatFromListR sh - -snatFromShR :: ShR n i -> SNat n -snatFromShR (ShR sh) = snatFromListR sh - -rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a -rindexPartial (Ranked arr) idx = - Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) - (castWith (subst2 (lemReplicatePlusApp (snatFromIxR idx) (Proxy @m) (Proxy @Nothing))) arr) - (ixCvtRX idx)) - --- | __WARNING__: All values returned from the function must have equal shape. --- See the documentation of 'mgenerate' for more details. -rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a -rgenerate sh f - | sn@SNat <- snatFromShR sh - , Dict <- lemKnownReplicate sn - , Refl <- lemRankReplicate sn - = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR)) - --- | See the documentation of 'mlift'. -rlift :: forall n1 n2 a. Elt a - => SNat n2 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) - -> Ranked n1 a -> Ranked n2 a -rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr) - --- | See the documentation of 'mlift2'. -rlift2 :: forall n1 n2 n3 a. Elt a - => SNat n3 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b) - -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a -rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2) - -rsumOuter1P :: forall n a. - (Storable a, NumElt a) - => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) -rsumOuter1P (Ranked arr) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = Ranked (msumOuter1P arr) - -rsumOuter1 :: forall n a. (NumElt a, PrimElt a) - => Ranked (n + 1) a -> Ranked n a -rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive - -applyPermR :: forall i n. [Int] -> ListR n i -> ListR n i -applyPermR = \perm sh -> - listrFromList perm $ \sperm -> - case (snatFromListR sperm, snatFromListR sh) of - (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of - LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post - EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post - GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" - ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" - where - listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) - listrSplitAt SZ sh = (ZR, sh) - listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh) - listrSplitAt SS{} ZR = error "m' + 1 <= 0" - - applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i - applyPermRFull _ ZR _ = ZR - applyPermRFull sm@SNat (i ::: perm) l = - TypeNats.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) -> - case cmpNat (SNat @(idx + 1)) sm of - LTI -> listrIndex si l ::: applyPermRFull sm perm l - EQI -> listrIndex si l ::: applyPermRFull sm perm l - GTI -> error "applyPermR: Index in permutation out of range" - -applyPermIxR :: forall n i. [Int] -> IxR n i -> IxR n i -applyPermIxR = coerce (applyPermR @i) - -applyPermShR :: forall n i. [Int] -> ShR n i -> ShR n i -applyPermShR = coerce (applyPermR @i) - -rtranspose :: forall n a. Elt a => [Int] -> Ranked n a -> Ranked n a -rtranspose perm arr - | sn@SNat <- snatFromShR (rshape arr) - , Dict <- lemKnownReplicate sn - , length perm <= fromIntegral (natVal (Proxy @n)) - = rlift sn - (\ssh' -> X.transposeUntyped (natSing @n) ssh' perm) - arr - | otherwise - = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array" - -rappend :: forall n a. Elt a - => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a -rappend arr1 arr2 - | sn@SNat <- snatFromShR (rshape arr1) - , Dict <- lemKnownReplicate sn - , Refl <- lemReplicateSucc @(Nothing @Nat) @n - = coerce (mappend @Nothing @Nothing @(Replicate n Nothing)) - arr1 arr2 - -rscalar :: Elt a => a -> Ranked 0 a -rscalar x = Ranked (mscalar x) - -rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a) -rfromVectorP sh v - | Dict <- lemKnownReplicate (snatFromShR sh) - = Ranked (mfromVectorP (shCvtRX sh) v) - -rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a -rfromVector sh v = rfromPrimitive (rfromVectorP sh v) - -rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a -rtoVectorP = coerce mtoVectorP - -rtoVector :: PrimElt a => Ranked n a -> VS.Vector a -rtoVector = coerce mtoVector - -rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a -rfromListOuter l - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) - -rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a -rfromList1 l = Ranked (mfromList1 l) - -rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a -rfromList1Prim l = Ranked (mfromList1Prim l) - -rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] -rtoListOuter (Ranked arr) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr) - -rtoList1 :: Elt a => Ranked 1 a -> [a] -rtoList1 = map runScalar . rtoListOuter - -rfromListPrim :: PrimElt a => [a] -> Ranked 1 a -rfromListPrim l = - let ssh = SUnknown () :!% ZKX - xarr = X.fromList1 ssh l - in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr - -rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a -rfromListPrimLinear sh l = - let M_Primitive _ xarr = toPrimitive (mfromListPrim l) - in Ranked $ fromPrimitive $ M_Primitive (shCvtRX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtRX sh) xarr) - -rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a -rfromOrthotope sn arr - | Refl <- lemRankReplicate sn - = let xarr = XArray arr - in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr)) - -runScalar :: Elt a => Ranked 0 a -> a -runScalar arr = rindex arr ZIR - -rrerankP :: forall n1 n2 n a b. (Storable a, Storable b) - => SNat n -> IShR n2 - -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b)) - -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b) -rrerankP sn sh2 f (Ranked arr) - | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat)) - , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat)) - = Ranked (mrerankP (ssxFromSNat sn) (shCvtRX sh2) - (\a -> let Ranked r = f (Ranked a) in r) - arr) - --- | If there is a zero-sized dimension in the @n@-prefix of the shape of the --- input array, then there is no way to deduce the full shape of the output --- array (more precisely, the @n2@ part): that could only come from calling --- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in --- this case; we choose to fill the @n2@ part of the output shape with zeros. --- --- For example, if: --- --- @ --- arr :: Ranked 5 Int -- of shape [3, 0, 4, 2, 21] --- f :: Ranked 2 Int -> Ranked 3 Float --- @ --- --- then: --- --- @ --- rrerank _ _ _ f arr :: Ranked 5 Float --- @ --- --- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the --- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended --- to return an array with shape all-0 here (it probably didn't), but there is --- no better number to put here absent a subarray of the input to pass to @f@. -rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b) - => SNat n -> IShR n2 - -> (Ranked n1 a -> Ranked n2 b) - -> Ranked (n + n1) a -> Ranked (n + n2) b -rrerank ssh sh2 f (rtoPrimitive -> arr) = - rfromPrimitive $ rrerankP ssh sh2 (rtoPrimitive . f . rfromPrimitive) arr - -rreplicate :: forall n m a. Elt a - => IShR n -> Ranked m a -> Ranked (n + m) a -rreplicate sh (Ranked arr) - | Refl <- lemReplicatePlusApp (snatFromShR sh) (Proxy @m) (Proxy @(Nothing @Nat)) - = Ranked (mreplicate (shCvtRX sh) arr) - -rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) -rreplicateScalP sh x - | Dict <- lemKnownReplicate (snatFromShR sh) - = Ranked (mreplicateScalP (shCvtRX sh) x) - -rreplicateScal :: forall n a. PrimElt a - => IShR n -> a -> Ranked n a -rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x) - -rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a -rslice i n arr - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = rlift (snatFromShR (rshape arr)) - (\_ -> X.sliceU i n) - arr - -rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a -rrev1 arr = - rlift (snatFromShR (rshape arr)) - (\(_ :: StaticShX sh') -> - case lemReplicateSucc @(Nothing @Nat) @n of - Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')) - arr - -rreshape :: forall n n' a. Elt a - => IShR n' -> Ranked n a -> Ranked n' a -rreshape sh' rarr@(Ranked arr) - | Dict <- lemKnownReplicate (snatFromShR (rshape rarr)) - , Dict <- lemKnownReplicate (snatFromShR sh') - = Ranked (mreshape (shCvtRX sh') arr) - -riota :: (Enum a, PrimElt a, Elt a) => Int -> Ranked 1 a -riota n = TypeNats.withSomeSNat (fromIntegral n) $ mtoRanked . miota - -rasXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a) -rasXArrayPrimP (Ranked arr) = first shCvtXR' (masXArrayPrimP arr) - -rasXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a) -rasXArrayPrim (Ranked arr) = first shCvtXR' (masXArrayPrim arr) - -rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a) -rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) - -rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a -rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) - -rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a -rcastToShaped (Ranked arr) targetsh - | Refl <- lemRankReplicate (srankSh (shCvtSX targetsh)) - , Refl <- lemRankMapJust targetsh - = mcastToShaped arr targetsh - -rfromPrimitive :: PrimElt a => Ranked n (Primitive a) -> Ranked n a -rfromPrimitive (Ranked arr) = Ranked (fromPrimitive arr) - -rtoPrimitive :: PrimElt a => Ranked n a -> Ranked n (Primitive a) -rtoPrimitive (Ranked arr) = Ranked (toPrimitive arr) - - --- ====== API OF SHAPED ARRAYS ====== -- - -arithPromoteShaped :: forall sh a. PrimElt a - => (forall shx. Mixed shx a -> Mixed shx a) - -> Shaped sh a -> Shaped sh a -arithPromoteShaped = coerce - -arithPromoteShaped2 :: forall sh a. PrimElt a - => (forall shx. Mixed shx a -> Mixed shx a -> Mixed shx a) - -> Shaped sh a -> Shaped sh a -> Shaped sh a -arithPromoteShaped2 = coerce - -instance (NumElt a, PrimElt a) => Num (Shaped sh a) where - (+) = arithPromoteShaped2 (+) - (-) = arithPromoteShaped2 (-) - (*) = arithPromoteShaped2 (*) - negate = arithPromoteShaped negate - abs = arithPromoteShaped abs - signum = arithPromoteShaped signum - fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicateScal" - -instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Shaped sh a) where - fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal" - recip = arithPromoteShaped recip - (/) = arithPromoteShaped2 (/) - -instance (FloatElt a, NumElt a, PrimElt a) => Floating (Shaped sh a) where - pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal" - exp = arithPromoteShaped exp - log = arithPromoteShaped log - sqrt = arithPromoteShaped sqrt - (**) = arithPromoteShaped2 (**) - logBase = arithPromoteShaped2 logBase - sin = arithPromoteShaped sin - cos = arithPromoteShaped cos - tan = arithPromoteShaped tan - asin = arithPromoteShaped asin - acos = arithPromoteShaped acos - atan = arithPromoteShaped atan - sinh = arithPromoteShaped sinh - cosh = arithPromoteShaped cosh - tanh = arithPromoteShaped tanh - asinh = arithPromoteShaped asinh - acosh = arithPromoteShaped acosh - atanh = arithPromoteShaped atanh - log1p = arithPromoteShaped GHC.Float.log1p - expm1 = arithPromoteShaped GHC.Float.expm1 - log1pexp = arithPromoteShaped GHC.Float.log1pexp - log1mexp = arithPromoteShaped GHC.Float.log1mexp - -zeroIxS :: ShS sh -> IIxS sh -zeroIxS ZSS = ZIS -zeroIxS (_ :$$ sh) = 0 :.$ zeroIxS sh - -ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh -ixCvtXS ZSS ZIX = ZIS -ixCvtXS (_ :$$ sh) (n :.% idx) = n :.$ ixCvtXS sh idx - -type family Tail l where - Tail (_ : xs) = xs - -shCvtXS' :: forall sh. IShX (MapJust sh) -> ShS sh -shCvtXS' ZSX = castWith (subst1 (unsafeCoerce Refl :: '[] :~: sh)) ZSS -shCvtXS' (SKnown n@SNat :$% (idx :: IShX mjshT)) = - castWith (subst1 (lem Refl)) $ - n :$$ shCvtXS' @(Tail sh) (castWith (subst2 (unsafeCoerce Refl :: mjshT :~: MapJust (Tail sh))) - idx) - where - lem :: forall sh1 sh' n. - Just n : sh1 :~: MapJust sh' - -> n : Tail sh' :~: sh' - lem Refl = unsafeCoerce Refl -shCvtXS' (SUnknown _ :$% _) = error "impossible" - -ixCvtSX :: IIxS sh -> IIxX (MapJust sh) -ixCvtSX ZIS = ZIX -ixCvtSX (n :.$ sh) = n :.% ixCvtSX sh - -shCvtSX :: ShS sh -> IShX (MapJust sh) -shCvtSX ZSS = ZSX -shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh - -shapeSizeS :: ShS sh -> Int -shapeSizeS ZSS = 1 -shapeSizeS (n :$$ sh) = fromSNat' n * shapeSizeS sh - - -sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh -sshape (Shaped arr) = shCvtXS' (mshape arr) - -sindex :: Elt a => Shaped sh a -> IIxS sh -> a -sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) - -shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh -shsTakeIx _ _ ZIS = ZSS -shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx - -sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a -sindexPartial sarr@(Shaped arr) idx = - Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) - (castWith (subst2 (lemCommMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr) - (ixCvtSX idx)) - --- | __WARNING__: All values returned from the function must have equal shape. --- See the documentation of 'mgenerate' for more details. -sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a -sgenerate sh f = Shaped (mgenerate (shCvtSX sh) (f . ixCvtXS sh)) - --- | See the documentation of 'mlift'. -slift :: forall sh1 sh2 a. Elt a - => ShS sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) - -> Shaped sh1 a -> Shaped sh2 a -slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShape (shCvtSX sh2)) f arr) - --- | See the documentation of 'mlift'. -slift2 :: forall sh1 sh2 sh3 a. Elt a - => ShS sh3 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b) - -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a -slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShape (shCvtSX sh3)) f arr1 arr2) - -ssumOuter1P :: forall sh n a. (Storable a, NumElt a) - => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) -ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr) - -ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a) - => Shaped (n : sh) a -> Shaped sh a -ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive - -lemCommMapJustTakeLen :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh) -lemCommMapJustTakeLen PNil _ = Refl -lemCommMapJustTakeLen (_ `PCons` is) (_ :$$ sh) | Refl <- lemCommMapJustTakeLen is sh = Refl -lemCommMapJustTakeLen (_ `PCons` _) ZSS = error "TakeLen of empty" - -lemCommMapJustDropLen :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh) -lemCommMapJustDropLen PNil _ = Refl -lemCommMapJustDropLen (_ `PCons` is) (_ :$$ sh) | Refl <- lemCommMapJustDropLen is sh = Refl -lemCommMapJustDropLen (_ `PCons` _) ZSS = error "DropLen of empty" - -lemCommMapJustIndex :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh) -lemCommMapJustIndex SZ (_ :$$ _) = Refl -lemCommMapJustIndex (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) - | Refl <- lemCommMapJustIndex i sh - , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) - , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = Refl -lemCommMapJustIndex _ ZSS = error "Index of empty" - -lemCommMapJustPermute :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh) -lemCommMapJustPermute PNil _ = Refl -lemCommMapJustPermute (i `PCons` is) sh - | Refl <- lemCommMapJustPermute is sh - , Refl <- lemCommMapJustIndex i sh - = Refl - -listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f -listsAppend ZS idx' = idx' -listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx' - -listsTakeLen :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f -listsTakeLen PNil _ = ZS -listsTakeLen (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLen is sh -listsTakeLen (_ `PCons` _) ZS = error "Permutation longer than shape" - -listsDropLen :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f -listsDropLen PNil sh = sh -listsDropLen (_ `PCons` is) (_ ::$ sh) = listsDropLen is sh -listsDropLen (_ `PCons` _) ZS = error "Permutation longer than shape" - -listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f -listsPermute PNil _ = ZS -listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) = listsIndex (Proxy @is') (Proxy @sh) i sh (listsPermute is sh) - -listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> ListS (Permute is shT) f -> ListS (Index i sh : Permute is shT) f -listsIndex _ _ SZ (n ::$ _) rest = n ::$ rest -listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) rest - | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = listsIndex p pT i sh rest -listsIndex _ _ _ ZS _ = error "Index into empty shape" - -shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh) -shsTakeLen = coerce (listsTakeLen @SNat) - -shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh) -shsPermute = coerce (listsPermute @SNat) - -shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (Permute is shT) -> ShS (Index i sh : Permute is shT) -shsIndex pis pshT = coerce (listsIndex @SNat pis pshT) - -applyPermS :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f -applyPermS perm sh = listsAppend (listsPermute perm (listsTakeLen perm sh)) (listsDropLen perm sh) - -applyPermIxS :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i -applyPermIxS = coerce (applyPermS @(Const i)) - -applyPermShS :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh) -applyPermShS = coerce (applyPermS @SNat) - -stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a) - => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a -stranspose perm sarr@(Shaped arr) - | Refl <- lemRankMapJust (sshape sarr) - , Refl <- lemCommMapJustTakeLen perm (sshape sarr) - , Refl <- lemCommMapJustDropLen perm (sshape sarr) - , Refl <- lemCommMapJustPermute perm (shsTakeLen perm (sshape sarr)) - , Refl <- lemCommMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh)) - = Shaped (mtranspose perm arr) - -sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a -sappend = coerce mappend - -sscalar :: Elt a => a -> Shaped '[] a -sscalar x = Shaped (mscalar x) - -sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a) -sfromVectorP sh v = Shaped (mfromVectorP (shCvtSX sh) v) - -sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a -sfromVector sh v = sfromPrimitive (sfromVectorP sh v) - -stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a -stoVectorP = coerce mtoVectorP - -stoVector :: PrimElt a => Shaped sh a -> VS.Vector a -stoVector = coerce mtoVector - -sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a -sfromListOuter sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy $ mfromListOuter (coerce l)) - -sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a -sfromList1 sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1 - -sfromList1Prim :: (PrimElt a, Elt a) => SNat n -> [a] -> Shaped '[n] a -sfromList1Prim sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1Prim - -stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a] -stoListOuter (Shaped arr) = coerce (mtoListOuter arr) - -stoList1 :: Elt a => Shaped '[n] a -> [a] -stoList1 = map sunScalar . stoListOuter - -sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a -sfromListPrim sn l - | Refl <- lemAppNil @'[Just n] - = let ssh = SUnknown () :!% ZKX - xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l) - in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr - -sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a -sfromListPrimLinear sh l = - let M_Primitive _ xarr = toPrimitive (mfromListPrim l) - in Shaped $ fromPrimitive $ M_Primitive (shCvtSX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtSX sh) xarr) - -sunScalar :: Elt a => Shaped '[] a -> a -sunScalar arr = sindex arr ZIS - -srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) - => ShS sh -> ShS sh2 - -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b)) - -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b) -srerankP sh sh2 f sarr@(Shaped arr) - | Refl <- lemCommMapJustApp sh (Proxy @sh1) - , Refl <- lemCommMapJustApp sh (Proxy @sh2) - = Shaped (mrerankP (ssxFromShape (shxTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (ssxFromShape (shCvtSX sh)))) - (shCvtSX sh2) - (\a -> let Shaped r = f (Shaped a) in r) - arr) - -srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) - => ShS sh -> ShS sh2 - -> (Shaped sh1 a -> Shaped sh2 b) - -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b -srerank sh sh2 f (stoPrimitive -> arr) = - sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr - -sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a -sreplicate sh (Shaped arr) - | Refl <- lemCommMapJustApp sh (Proxy @sh') - = Shaped (mreplicate (shCvtSX sh) arr) - -sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) -sreplicateScalP sh x = Shaped (mreplicateScalP (shCvtSX sh) x) - -sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a -sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x) - -sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a -sslice i n@SNat arr = - let _ :$$ sh = sshape arr - in slift (n :$$ sh) (\_ -> X.slice i n) arr - -srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a -srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr - -sreshape :: Elt a => ShS sh' -> Shaped sh a -> Shaped sh' a -sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr) - -siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a -siota sn = Shaped (miota sn) - -sasXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a) -sasXArrayPrimP (Shaped arr) = first shCvtXS' (masXArrayPrimP arr) - -sasXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a) -sasXArrayPrim (Shaped arr) = first shCvtXS' (masXArrayPrim arr) - -sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a) -sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShape (shCvtSX sh)) arr) - -sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a -sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShape (shCvtSX sh)) arr) - -stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a -stoRanked sarr@(Shaped arr) - | Refl <- lemRankMapJust (sshape sarr) - = mtoRanked arr - -sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a -sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr) - -stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a) -stoPrimitive (Shaped arr) = Shaped (toPrimitive arr) |