diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-05-15 23:44:26 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-05-15 23:45:52 +0200 |
commit | cb8a20c32e2737c28fa2993fb29ede9c0faa000d (patch) | |
tree | 97abc964008e96f51cc6e2cfc5f60340406c3d9b /src/Data | |
parent | 5f1213fc9e464ec361e6543884968980dd28457d (diff) |
Move casts to DAN.Convert; split Ranked/Shaped types into .Base
Diffstat (limited to 'src/Data')
-rw-r--r-- | src/Data/Array/Nested/Convert.hs | 50 | ||||
-rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 7 | ||||
-rw-r--r-- | src/Data/Array/Nested/Ranked.hs | 245 | ||||
-rw-r--r-- | src/Data/Array/Nested/Ranked/Base.hs | 242 | ||||
-rw-r--r-- | src/Data/Array/Nested/Shaped.hs | 235 | ||||
-rw-r--r-- | src/Data/Array/Nested/Shaped/Base.hs | 235 |
6 files changed, 537 insertions, 477 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index e9bc20e..cdd2b6d 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -10,17 +10,63 @@ module Data.Array.Nested.Convert where import Control.Category import Data.Proxy import Data.Type.Equality +import GHC.TypeLits (Nat) import Data.Array.Mixed.Lemmas import Data.Array.Mixed.Types import Data.Array.Nested.Internal.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape -import Data.Array.Nested.Ranked -import Data.Array.Nested.Shaped +import Data.Array.Nested.Ranked.Base +import Data.Array.Nested.Shaped.Base import Data.Array.Nested.Shaped.Shape +mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a) + => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a +mcast ssh2 arr + | Refl <- lemAppNil @sh1 + , Refl <- lemAppNil @sh2 + = mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr + +mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a +mtoRanked arr + | Refl <- lemRankReplicate (shxRank (mshape arr)) + = Ranked (mcast (ssxFromShape (convSh (mshape arr))) arr) + where + convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing) + convSh ZSX = ZSX + convSh (smn :$% (sh :: IShX sh'T)) + | Refl <- lemReplicateSucc @(Nothing @Nat) @(Rank sh'T) + = SUnknown (fromSMayNat' smn) :$% convSh sh + +rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a +rtoMixed (Ranked arr) = arr + +-- | A more weakly-typed version of 'rtoMixed' that does a runtime shape +-- compatibility check. +rcastToMixed :: (Rank sh ~ n, Elt a) => StaticShX sh -> Ranked n a -> Mixed sh a +rcastToMixed sshx rarr@(Ranked arr) + | Refl <- lemRankReplicate (rrank rarr) + = mcast sshx arr + +mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') + => Mixed sh a -> ShS sh' -> Shaped sh' a +mcastToShaped arr targetsh + | Refl <- lemRankMapJust targetsh + = Shaped (mcast (ssxFromShape (shCvtSX targetsh)) arr) + +stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a +stoMixed (Shaped arr) = arr + +-- | A more weakly-typed version of 'stoMixed' that does a runtime shape +-- compatibility check. +scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') + => StaticShX sh' -> Shaped sh a -> Mixed sh' a +scastToMixed sshx sarr@(Shaped arr) + | Refl <- lemRankMapJust (sshape sarr) + = mcast sshx arr + stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a stoRanked sarr@(Shaped arr) | Refl <- lemRankMapJust (sshape sarr) diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index c18db63..efbcf19 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -924,10 +924,3 @@ mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c) -> Mixed sh a -> Mixed sh b -> Mixed sh c mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) = fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2)) - -mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a) - => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a -mcast ssh2 arr - | Refl <- lemAppNil @sh1 - , Refl <- lemAppNil @sh2 - = mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs index e2074ac..c0e1302 100644 --- a/src/Data/Array/Nested/Ranked.hs +++ b/src/Data/Array/Nested/Ranked.hs @@ -1,40 +1,28 @@ {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE InstanceSigs #-} -{-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Nested.Ranked where +module Data.Array.Nested.Ranked ( + module Data.Array.Nested.Ranked.Base, + module Data.Array.Nested.Ranked, +) where import Prelude hiding (mappend, mconcat) -import Control.DeepSeq (NFData(..)) -import Control.Monad.ST import Data.Array.RankedS qualified as S import Data.Bifunctor (first) import Data.Coerce (coerce) -import Data.Foldable (toList) -import Data.Kind (Type) import Data.List.NonEmpty (NonEmpty) import Data.Proxy import Data.Type.Equality import Data.Vector.Storable qualified as VS import Foreign.Storable (Storable) -import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) -import GHC.Generics (Generic) import GHC.TypeLits import GHC.TypeNats qualified as TN @@ -43,217 +31,17 @@ import Data.Array.Mixed.Permutation import Data.Array.Mixed.Types import Data.Array.XArray (XArray(..)) import Data.Array.XArray qualified as X +import Data.Array.Nested.Convert import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Ranked.Base import Data.Array.Nested.Ranked.Shape import Data.Array.Strided.Arith --- | A rank-typed array: the number of dimensions of the array (its /rank/) is --- represented on the type level as a 'Nat'. --- --- Valid elements of a ranked arrays are described by the 'Elt' type class. --- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are --- supported (and are represented as a single, flattened, struct-of-arrays --- array internally). --- --- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's. -type Ranked :: Nat -> Type -> Type -newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) -deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a) -deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a) - -instance (Show a, Elt a) => Show (Ranked n a) where - showsPrec d arr@(Ranked marr) = - let sh = show (toList (rshape arr)) - in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr - -instance Elt a => NFData (Ranked n a) where - rnf (Ranked arr) = rnf arr - --- just unwrap the newtype and defer to the general instance for nested arrays -newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a)) - deriving (Generic) - -deriving instance Eq (Mixed sh (Mixed (Replicate n Nothing) a)) => Eq (Mixed sh (Ranked n a)) - -newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a)) - --- 'Ranked' and 'Shaped' can already be used at the top level of an array nest; --- these instances allow them to also be used as elements of arrays, thus --- making them first-class in the API. -instance Elt a => Elt (Ranked n a) where - mshape (M_Ranked arr) = mshape arr - mindex (M_Ranked arr) i = Ranked (mindex arr i) - - mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a) - mindexPartial (M_Ranked arr) i = - coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ - mindexPartial arr i - - mscalar (Ranked x) = M_Ranked (M_Nest ZSX x) - - mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a) - mfromListOuter l = M_Ranked (mfromListOuter (coerce l)) - - mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)] - mtoListOuter (M_Ranked arr) = - coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr) - - mlift :: forall sh1 sh2. - StaticShX sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) - -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) - mlift ssh2 f (M_Ranked arr) = - coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ - mlift ssh2 f arr - - mlift2 :: forall sh1 sh2 sh3. - StaticShX sh3 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) - -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a) - mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) = - coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ - mlift2 ssh3 f arr1 arr2 - - mliftL :: forall sh1 sh2. - StaticShX sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) - -> NonEmpty (Mixed sh1 (Ranked n a)) -> NonEmpty (Mixed sh2 (Ranked n a)) - mliftL ssh2 f l = - coerce @(NonEmpty (Mixed sh2 (Mixed (Replicate n Nothing) a))) - @(NonEmpty (Mixed sh2 (Ranked n a))) $ - mliftL ssh2 f (coerce l) - - mcastPartial ssh1 ssh2 psh' (M_Ranked arr) = M_Ranked (mcastPartial ssh1 ssh2 psh' arr) - - mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr) - - mconcat l = M_Ranked (mconcat (coerce l)) - - mrnf (M_Ranked arr) = mrnf arr - - type ShapeTree (Ranked n a) = (IShR n, ShapeTree a) - - mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr) - - mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - - mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t - - mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - - marrayStrides (M_Ranked arr) = marrayStrides arr - - mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () - mvecsWrite sh idx (Ranked arr) vecs = - mvecsWrite sh idx arr - (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) - vecs) - - mvecsWritePartial :: forall sh sh' s. - IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) - -> MixedVecs s (sh ++ sh') (Ranked n a) - -> ST s () - mvecsWritePartial sh idx arr vecs = - mvecsWritePartial sh idx - (coerce @(Mixed sh' (Ranked n a)) - @(Mixed sh' (Mixed (Replicate n Nothing) a)) - arr) - (coerce @(MixedVecs s (sh ++ sh') (Ranked n a)) - @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) - vecs) - - mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) - mvecsFreeze sh vecs = - coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) - @(Mixed sh (Ranked n a)) - <$> mvecsFreeze sh - (coerce @(MixedVecs s sh (Ranked n a)) - @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) - vecs) - -instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where - memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a) - memptyArrayUnsafe i - | Dict <- lemKnownReplicate (SNat @n) - = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ - memptyArrayUnsafe i - - mvecsUnsafeNew idx (Ranked arr) - | Dict <- lemKnownReplicate (SNat @n) - = MV_Ranked <$> mvecsUnsafeNew idx arr - - mvecsNewEmpty _ - | Dict <- lemKnownReplicate (SNat @n) - = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) - - -liftRanked1 :: forall n a b. - (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b) - -> Ranked n a -> Ranked n b -liftRanked1 = coerce - -liftRanked2 :: forall n a b c. - (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b -> Mixed (Replicate n Nothing) c) - -> Ranked n a -> Ranked n b -> Ranked n c -liftRanked2 = coerce - -instance (NumElt a, PrimElt a) => Num (Ranked n a) where - (+) = liftRanked2 (+) - (-) = liftRanked2 (-) - (*) = liftRanked2 (*) - negate = liftRanked1 negate - abs = liftRanked1 abs - signum = liftRanked1 signum - fromInteger = error "Data.Array.Nested(Ranked).fromInteger: No singletons available, use explicit rreplicateScal" - -instance (FloatElt a, PrimElt a) => Fractional (Ranked n a) where - fromRational _ = error "Data.Array.Nested(Ranked).fromRational: No singletons available, use explicit rreplicateScal" - recip = liftRanked1 recip - (/) = liftRanked2 (/) - -instance (FloatElt a, PrimElt a) => Floating (Ranked n a) where - pi = error "Data.Array.Nested(Ranked).pi: No singletons available, use explicit rreplicateScal" - exp = liftRanked1 exp - log = liftRanked1 log - sqrt = liftRanked1 sqrt - (**) = liftRanked2 (**) - logBase = liftRanked2 logBase - sin = liftRanked1 sin - cos = liftRanked1 cos - tan = liftRanked1 tan - asin = liftRanked1 asin - acos = liftRanked1 acos - atan = liftRanked1 atan - sinh = liftRanked1 sinh - cosh = liftRanked1 cosh - tanh = liftRanked1 tanh - asinh = liftRanked1 asinh - acosh = liftRanked1 acosh - atanh = liftRanked1 atanh - log1p = liftRanked1 GHC.Float.log1p - expm1 = liftRanked1 GHC.Float.expm1 - log1pexp = liftRanked1 GHC.Float.log1pexp - log1mexp = liftRanked1 GHC.Float.log1mexp - -rquotArray, rremArray :: (IntElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a -rquotArray = liftRanked2 mquotArray -rremArray = liftRanked2 mremArray - -ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a -ratan2Array = liftRanked2 matan2Array - - remptyArray :: KnownElt a => Ranked 1 a remptyArray = mtoRanked (memptyArray ZSX) -rshape :: Elt a => Ranked n a -> IShR n -rshape (Ranked arr) = shCvtXR' (mshape arr) - -rrank :: Elt a => Ranked n a -> SNat n -rrank = shrRank . rshape - -- | The total number of elements in the array. rsize :: Elt a => Ranked n a -> Int rsize = shrSize . rshape @@ -536,24 +324,3 @@ rfromPrimitive (Ranked arr) = Ranked (fromPrimitive arr) rtoPrimitive :: PrimElt a => Ranked n a -> Ranked n (Primitive a) rtoPrimitive (Ranked arr) = Ranked (toPrimitive arr) - -mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a -mtoRanked arr - | Refl <- lemRankReplicate (shxRank (mshape arr)) - = Ranked (mcast (ssxFromShape (convSh (mshape arr))) arr) - where - convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing) - convSh ZSX = ZSX - convSh (smn :$% (sh :: IShX sh'T)) - | Refl <- lemReplicateSucc @(Nothing @Nat) @(Rank sh'T) - = SUnknown (fromSMayNat' smn) :$% convSh sh - -rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a -rtoMixed (Ranked arr) = arr - --- | A more weakly-typed version of 'rtoMixed' that does a runtime shape --- compatibility check. -rcastToMixed :: (Rank sh ~ n, Elt a) => StaticShX sh -> Ranked n a -> Mixed sh a -rcastToMixed sshx rarr@(Ranked arr) - | Refl <- lemRankReplicate (rrank rarr) - = mcast sshx arr diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs new file mode 100644 index 0000000..f827187 --- /dev/null +++ b/src/Data/Array/Nested/Ranked/Base.hs @@ -0,0 +1,242 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_HADDOCK not-home #-} +module Data.Array.Nested.Ranked.Base where + +import Prelude hiding (mappend, mconcat) + +import Control.DeepSeq (NFData(..)) +import Control.Monad.ST +import Data.Bifunctor (first) +import Data.Coerce (coerce) +import Data.Foldable (toList) +import Data.Kind (Type) +import Data.List.NonEmpty (NonEmpty) +import Data.Proxy +import Foreign.Storable (Storable) +import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) +import GHC.Generics (Generic) +import GHC.TypeLits + +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Types +import Data.Array.XArray (XArray(..)) +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Ranked.Shape +import Data.Array.Strided.Arith + + +-- | A rank-typed array: the number of dimensions of the array (its /rank/) is +-- represented on the type level as a 'Nat'. +-- +-- Valid elements of a ranked arrays are described by the 'Elt' type class. +-- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are +-- supported (and are represented as a single, flattened, struct-of-arrays +-- array internally). +-- +-- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's. +type Ranked :: Nat -> Type -> Type +newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) +deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a) +deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a) + +instance (Show a, Elt a) => Show (Ranked n a) where + showsPrec d arr@(Ranked marr) = + let sh = show (toList (rshape arr)) + in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr + +instance Elt a => NFData (Ranked n a) where + rnf (Ranked arr) = rnf arr + +-- just unwrap the newtype and defer to the general instance for nested arrays +newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a)) + deriving (Generic) + +deriving instance Eq (Mixed sh (Mixed (Replicate n Nothing) a)) => Eq (Mixed sh (Ranked n a)) + +newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a)) + +-- 'Ranked' and 'Shaped' can already be used at the top level of an array nest; +-- these instances allow them to also be used as elements of arrays, thus +-- making them first-class in the API. +instance Elt a => Elt (Ranked n a) where + mshape (M_Ranked arr) = mshape arr + mindex (M_Ranked arr) i = Ranked (mindex arr i) + + mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a) + mindexPartial (M_Ranked arr) i = + coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ + mindexPartial arr i + + mscalar (Ranked x) = M_Ranked (M_Nest ZSX x) + + mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a) + mfromListOuter l = M_Ranked (mfromListOuter (coerce l)) + + mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)] + mtoListOuter (M_Ranked arr) = + coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr) + + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) + mlift ssh2 f (M_Ranked arr) = + coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ + mlift ssh2 f arr + + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) + -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a) + mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) = + coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ + mlift2 ssh3 f arr1 arr2 + + mliftL :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) + -> NonEmpty (Mixed sh1 (Ranked n a)) -> NonEmpty (Mixed sh2 (Ranked n a)) + mliftL ssh2 f l = + coerce @(NonEmpty (Mixed sh2 (Mixed (Replicate n Nothing) a))) + @(NonEmpty (Mixed sh2 (Ranked n a))) $ + mliftL ssh2 f (coerce l) + + mcastPartial ssh1 ssh2 psh' (M_Ranked arr) = M_Ranked (mcastPartial ssh1 ssh2 psh' arr) + + mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr) + + mconcat l = M_Ranked (mconcat (coerce l)) + + mrnf (M_Ranked arr) = mrnf arr + + type ShapeTree (Ranked n a) = (IShR n, ShapeTree a) + + mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr) + + mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 + + mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + + mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + + marrayStrides (M_Ranked arr) = marrayStrides arr + + mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () + mvecsWrite sh idx (Ranked arr) vecs = + mvecsWrite sh idx arr + (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) + vecs) + + mvecsWritePartial :: forall sh sh' s. + IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) + -> MixedVecs s (sh ++ sh') (Ranked n a) + -> ST s () + mvecsWritePartial sh idx arr vecs = + mvecsWritePartial sh idx + (coerce @(Mixed sh' (Ranked n a)) + @(Mixed sh' (Mixed (Replicate n Nothing) a)) + arr) + (coerce @(MixedVecs s (sh ++ sh') (Ranked n a)) + @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) + vecs) + + mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) + mvecsFreeze sh vecs = + coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) + @(Mixed sh (Ranked n a)) + <$> mvecsFreeze sh + (coerce @(MixedVecs s sh (Ranked n a)) + @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) + vecs) + +instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where + memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a) + memptyArrayUnsafe i + | Dict <- lemKnownReplicate (SNat @n) + = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ + memptyArrayUnsafe i + + mvecsUnsafeNew idx (Ranked arr) + | Dict <- lemKnownReplicate (SNat @n) + = MV_Ranked <$> mvecsUnsafeNew idx arr + + mvecsNewEmpty _ + | Dict <- lemKnownReplicate (SNat @n) + = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) + + +liftRanked1 :: forall n a b. + (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b) + -> Ranked n a -> Ranked n b +liftRanked1 = coerce + +liftRanked2 :: forall n a b c. + (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b -> Mixed (Replicate n Nothing) c) + -> Ranked n a -> Ranked n b -> Ranked n c +liftRanked2 = coerce + +instance (NumElt a, PrimElt a) => Num (Ranked n a) where + (+) = liftRanked2 (+) + (-) = liftRanked2 (-) + (*) = liftRanked2 (*) + negate = liftRanked1 negate + abs = liftRanked1 abs + signum = liftRanked1 signum + fromInteger = error "Data.Array.Nested(Ranked).fromInteger: No singletons available, use explicit rreplicateScal" + +instance (FloatElt a, PrimElt a) => Fractional (Ranked n a) where + fromRational _ = error "Data.Array.Nested(Ranked).fromRational: No singletons available, use explicit rreplicateScal" + recip = liftRanked1 recip + (/) = liftRanked2 (/) + +instance (FloatElt a, PrimElt a) => Floating (Ranked n a) where + pi = error "Data.Array.Nested(Ranked).pi: No singletons available, use explicit rreplicateScal" + exp = liftRanked1 exp + log = liftRanked1 log + sqrt = liftRanked1 sqrt + (**) = liftRanked2 (**) + logBase = liftRanked2 logBase + sin = liftRanked1 sin + cos = liftRanked1 cos + tan = liftRanked1 tan + asin = liftRanked1 asin + acos = liftRanked1 acos + atan = liftRanked1 atan + sinh = liftRanked1 sinh + cosh = liftRanked1 cosh + tanh = liftRanked1 tanh + asinh = liftRanked1 asinh + acosh = liftRanked1 acosh + atanh = liftRanked1 atanh + log1p = liftRanked1 GHC.Float.log1p + expm1 = liftRanked1 GHC.Float.expm1 + log1pexp = liftRanked1 GHC.Float.log1pexp + log1mexp = liftRanked1 GHC.Float.log1mexp + +rquotArray, rremArray :: (IntElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a +rquotArray = liftRanked2 mquotArray +rremArray = liftRanked2 mremArray + +ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a +ratan2Array = liftRanked2 matan2Array + + +rshape :: Elt a => Ranked n a -> IShR n +rshape (Ranked arr) = shCvtXR' (mshape arr) + +rrank :: Elt a => Ranked n a -> SNat n +rrank = shrRank . rshape diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs index 4bccbc4..c4c61fd 100644 --- a/src/Data/Array/Nested/Shaped.hs +++ b/src/Data/Array/Nested/Shaped.hs @@ -1,41 +1,29 @@ {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE InstanceSigs #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Nested.Shaped where +module Data.Array.Nested.Shaped ( + module Data.Array.Nested.Shaped.Base, + module Data.Array.Nested.Shaped, +) where import Prelude hiding (mappend, mconcat) -import Control.DeepSeq (NFData(..)) -import Control.Monad.ST import Data.Array.Internal.RankedG qualified as RG import Data.Array.Internal.RankedS qualified as RS import Data.Array.Internal.ShapedG qualified as SG import Data.Array.Internal.ShapedS qualified as SS import Data.Bifunctor (first) import Data.Coerce (coerce) -import Data.Kind (Type) import Data.List.NonEmpty (NonEmpty) import Data.Proxy import Data.Type.Equality import Data.Vector.Storable qualified as VS import Foreign.Storable (Storable) -import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) -import GHC.Generics (Generic) import GHC.TypeLits import Data.Array.Mixed.Lemmas @@ -44,211 +32,17 @@ import Data.Array.Mixed.Types import Data.Array.XArray (XArray) import Data.Array.XArray qualified as X import Data.Array.Nested.Internal.Lemmas +import Data.Array.Nested.Convert import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Shaped.Base import Data.Array.Nested.Shaped.Shape import Data.Array.Strided.Arith --- | A shape-typed array: the full shape of the array (the sizes of its --- dimensions) is represented on the type level as a list of 'Nat's. Note that --- these are "GHC.TypeLits" naturals, because we do not need induction over --- them and we want very large arrays to be possible. --- --- Like for 'Ranked', the valid elements are described by the 'Elt' type class, --- and 'Shaped' itself is again an instance of 'Elt' as well. --- --- 'Shaped' is a newtype around a 'Mixed' of 'Just's. -type Shaped :: [Nat] -> Type -> Type -newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) -deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a) -deriving instance Ord (Mixed (MapJust sh) a) => Ord (Shaped sh a) - -instance (Show a, Elt a) => Show (Shaped n a) where - showsPrec d arr@(Shaped marr) = - let sh = show (shsToList (sshape arr)) - in showsMixedArray ("sfromListLinear " ++ sh) ("sreplicate " ++ sh) d marr - -instance Elt a => NFData (Shaped sh a) where - rnf (Shaped arr) = rnf arr - --- just unwrap the newtype and defer to the general instance for nested arrays -newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a)) - deriving (Generic) - -deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped sh' a)) - -newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a)) - -instance Elt a => Elt (Shaped sh a) where - mshape (M_Shaped arr) = mshape arr - mindex (M_Shaped arr) i = Shaped (mindex arr i) - - mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) - mindexPartial (M_Shaped arr) i = - coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ - mindexPartial arr i - - mscalar (Shaped x) = M_Shaped (M_Nest ZSX x) - - mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a) - mfromListOuter l = M_Shaped (mfromListOuter (coerce l)) - - mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] - mtoListOuter (M_Shaped arr) - = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr) - - mlift :: forall sh1 sh2. - StaticShX sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) - -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) - mlift ssh2 f (M_Shaped arr) = - coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ - mlift ssh2 f arr - - mlift2 :: forall sh1 sh2 sh3. - StaticShX sh3 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) - -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a) - mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) = - coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ - mlift2 ssh3 f arr1 arr2 - - mliftL :: forall sh1 sh2. - StaticShX sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) - -> NonEmpty (Mixed sh1 (Shaped sh a)) -> NonEmpty (Mixed sh2 (Shaped sh a)) - mliftL ssh2 f l = - coerce @(NonEmpty (Mixed sh2 (Mixed (MapJust sh) a))) - @(NonEmpty (Mixed sh2 (Shaped sh a))) $ - mliftL ssh2 f (coerce l) - - mcastPartial ssh1 ssh2 psh' (M_Shaped arr) = M_Shaped (mcastPartial ssh1 ssh2 psh' arr) - - mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr) - - mconcat l = M_Shaped (mconcat (coerce l)) - - mrnf (M_Shaped arr) = mrnf arr - - type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) - - mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr) - - mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - - mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t - - mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - - marrayStrides (M_Shaped arr) = marrayStrides arr - - mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () - mvecsWrite sh idx (Shaped arr) vecs = - mvecsWrite sh idx arr - (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) - vecs) - - mvecsWritePartial :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) - -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) - -> ST s () - mvecsWritePartial sh idx arr vecs = - mvecsWritePartial sh idx - (coerce @(Mixed sh2 (Shaped sh a)) - @(Mixed sh2 (Mixed (MapJust sh) a)) - arr) - (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a)) - @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) - vecs) - - mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) - mvecsFreeze sh vecs = - coerce @(Mixed sh' (Mixed (MapJust sh) a)) - @(Mixed sh' (Shaped sh a)) - <$> mvecsFreeze sh - (coerce @(MixedVecs s sh' (Shaped sh a)) - @(MixedVecs s sh' (Mixed (MapJust sh) a)) - vecs) - -instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where - memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) - memptyArrayUnsafe i - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ - memptyArrayUnsafe i - - mvecsUnsafeNew idx (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = MV_Shaped <$> mvecsUnsafeNew idx arr - - mvecsNewEmpty _ - | Dict <- lemKnownMapJust (Proxy @sh) - = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) - - -liftShaped1 :: forall sh a b. - (Mixed (MapJust sh) a -> Mixed (MapJust sh) b) - -> Shaped sh a -> Shaped sh b -liftShaped1 = coerce - -liftShaped2 :: forall sh a b c. - (Mixed (MapJust sh) a -> Mixed (MapJust sh) b -> Mixed (MapJust sh) c) - -> Shaped sh a -> Shaped sh b -> Shaped sh c -liftShaped2 = coerce - -instance (NumElt a, PrimElt a) => Num (Shaped sh a) where - (+) = liftShaped2 (+) - (-) = liftShaped2 (-) - (*) = liftShaped2 (*) - negate = liftShaped1 negate - abs = liftShaped1 abs - signum = liftShaped1 signum - fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicateScal" - -instance (FloatElt a, PrimElt a) => Fractional (Shaped sh a) where - fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal" - recip = liftShaped1 recip - (/) = liftShaped2 (/) - -instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where - pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal" - exp = liftShaped1 exp - log = liftShaped1 log - sqrt = liftShaped1 sqrt - (**) = liftShaped2 (**) - logBase = liftShaped2 logBase - sin = liftShaped1 sin - cos = liftShaped1 cos - tan = liftShaped1 tan - asin = liftShaped1 asin - acos = liftShaped1 acos - atan = liftShaped1 atan - sinh = liftShaped1 sinh - cosh = liftShaped1 cosh - tanh = liftShaped1 tanh - asinh = liftShaped1 asinh - acosh = liftShaped1 acosh - atanh = liftShaped1 atanh - log1p = liftShaped1 GHC.Float.log1p - expm1 = liftShaped1 GHC.Float.expm1 - log1pexp = liftShaped1 GHC.Float.log1pexp - log1mexp = liftShaped1 GHC.Float.log1mexp - -squotArray, sremArray :: (IntElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a -squotArray = liftShaped2 mquotArray -sremArray = liftShaped2 mremArray - -satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a -satan2Array = liftShaped2 matan2Array - - semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a semptyArray sh = Shaped (memptyArray (shCvtSX sh)) -sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh -sshape (Shaped arr) = shCvtXS' (mshape arr) - srank :: Elt a => Shaped sh a -> SNat (Rank sh) srank = shsRank . sshape @@ -476,20 +270,3 @@ sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr) stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a) stoPrimitive (Shaped arr) = Shaped (toPrimitive arr) - -mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') - => Mixed sh a -> ShS sh' -> Shaped sh' a -mcastToShaped arr targetsh - | Refl <- lemRankMapJust targetsh - = Shaped (mcast (ssxFromShape (shCvtSX targetsh)) arr) - -stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a -stoMixed (Shaped arr) = arr - --- | A more weakly-typed version of 'stoMixed' that does a runtime shape --- compatibility check. -scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') - => StaticShX sh' -> Shaped sh a -> Mixed sh' a -scastToMixed sshx sarr@(Shaped arr) - | Refl <- lemRankMapJust (sshape sarr) - = mcast sshx arr diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs new file mode 100644 index 0000000..74c231d --- /dev/null +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -0,0 +1,235 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_HADDOCK not-home #-} +module Data.Array.Nested.Shaped.Base where + +import Prelude hiding (mappend, mconcat) + +import Control.DeepSeq (NFData(..)) +import Control.Monad.ST +import Data.Bifunctor (first) +import Data.Coerce (coerce) +import Data.Kind (Type) +import Data.List.NonEmpty (NonEmpty) +import Data.Proxy +import Foreign.Storable (Storable) +import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) +import GHC.Generics (Generic) +import GHC.TypeLits + +import Data.Array.Mixed.Types +import Data.Array.XArray (XArray) +import Data.Array.Nested.Internal.Lemmas +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Shaped.Shape +import Data.Array.Strided.Arith + + +-- | A shape-typed array: the full shape of the array (the sizes of its +-- dimensions) is represented on the type level as a list of 'Nat's. Note that +-- these are "GHC.TypeLits" naturals, because we do not need induction over +-- them and we want very large arrays to be possible. +-- +-- Like for 'Ranked', the valid elements are described by the 'Elt' type class, +-- and 'Shaped' itself is again an instance of 'Elt' as well. +-- +-- 'Shaped' is a newtype around a 'Mixed' of 'Just's. +type Shaped :: [Nat] -> Type -> Type +newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) +deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a) +deriving instance Ord (Mixed (MapJust sh) a) => Ord (Shaped sh a) + +instance (Show a, Elt a) => Show (Shaped n a) where + showsPrec d arr@(Shaped marr) = + let sh = show (shsToList (sshape arr)) + in showsMixedArray ("sfromListLinear " ++ sh) ("sreplicate " ++ sh) d marr + +instance Elt a => NFData (Shaped sh a) where + rnf (Shaped arr) = rnf arr + +-- just unwrap the newtype and defer to the general instance for nested arrays +newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a)) + deriving (Generic) + +deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped sh' a)) + +newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a)) + +instance Elt a => Elt (Shaped sh a) where + mshape (M_Shaped arr) = mshape arr + mindex (M_Shaped arr) i = Shaped (mindex arr i) + + mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) + mindexPartial (M_Shaped arr) i = + coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ + mindexPartial arr i + + mscalar (Shaped x) = M_Shaped (M_Nest ZSX x) + + mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a) + mfromListOuter l = M_Shaped (mfromListOuter (coerce l)) + + mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] + mtoListOuter (M_Shaped arr) + = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr) + + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) + mlift ssh2 f (M_Shaped arr) = + coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ + mlift ssh2 f arr + + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) + -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a) + mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) = + coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ + mlift2 ssh3 f arr1 arr2 + + mliftL :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) + -> NonEmpty (Mixed sh1 (Shaped sh a)) -> NonEmpty (Mixed sh2 (Shaped sh a)) + mliftL ssh2 f l = + coerce @(NonEmpty (Mixed sh2 (Mixed (MapJust sh) a))) + @(NonEmpty (Mixed sh2 (Shaped sh a))) $ + mliftL ssh2 f (coerce l) + + mcastPartial ssh1 ssh2 psh' (M_Shaped arr) = M_Shaped (mcastPartial ssh1 ssh2 psh' arr) + + mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr) + + mconcat l = M_Shaped (mconcat (coerce l)) + + mrnf (M_Shaped arr) = mrnf arr + + type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) + + mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr) + + mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 + + mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + + mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + + marrayStrides (M_Shaped arr) = marrayStrides arr + + mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () + mvecsWrite sh idx (Shaped arr) vecs = + mvecsWrite sh idx arr + (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) + + mvecsWritePartial :: forall sh1 sh2 s. + IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) + -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) + -> ST s () + mvecsWritePartial sh idx arr vecs = + mvecsWritePartial sh idx + (coerce @(Mixed sh2 (Shaped sh a)) + @(Mixed sh2 (Mixed (MapJust sh) a)) + arr) + (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a)) + @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) + vecs) + + mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) + mvecsFreeze sh vecs = + coerce @(Mixed sh' (Mixed (MapJust sh) a)) + @(Mixed sh' (Shaped sh a)) + <$> mvecsFreeze sh + (coerce @(MixedVecs s sh' (Shaped sh a)) + @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) + +instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where + memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) + memptyArrayUnsafe i + | Dict <- lemKnownMapJust (Proxy @sh) + = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ + memptyArrayUnsafe i + + mvecsUnsafeNew idx (Shaped arr) + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsUnsafeNew idx arr + + mvecsNewEmpty _ + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) + + +liftShaped1 :: forall sh a b. + (Mixed (MapJust sh) a -> Mixed (MapJust sh) b) + -> Shaped sh a -> Shaped sh b +liftShaped1 = coerce + +liftShaped2 :: forall sh a b c. + (Mixed (MapJust sh) a -> Mixed (MapJust sh) b -> Mixed (MapJust sh) c) + -> Shaped sh a -> Shaped sh b -> Shaped sh c +liftShaped2 = coerce + +instance (NumElt a, PrimElt a) => Num (Shaped sh a) where + (+) = liftShaped2 (+) + (-) = liftShaped2 (-) + (*) = liftShaped2 (*) + negate = liftShaped1 negate + abs = liftShaped1 abs + signum = liftShaped1 signum + fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicateScal" + +instance (FloatElt a, PrimElt a) => Fractional (Shaped sh a) where + fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal" + recip = liftShaped1 recip + (/) = liftShaped2 (/) + +instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where + pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal" + exp = liftShaped1 exp + log = liftShaped1 log + sqrt = liftShaped1 sqrt + (**) = liftShaped2 (**) + logBase = liftShaped2 logBase + sin = liftShaped1 sin + cos = liftShaped1 cos + tan = liftShaped1 tan + asin = liftShaped1 asin + acos = liftShaped1 acos + atan = liftShaped1 atan + sinh = liftShaped1 sinh + cosh = liftShaped1 cosh + tanh = liftShaped1 tanh + asinh = liftShaped1 asinh + acosh = liftShaped1 acosh + atanh = liftShaped1 atanh + log1p = liftShaped1 GHC.Float.log1p + expm1 = liftShaped1 GHC.Float.expm1 + log1pexp = liftShaped1 GHC.Float.log1pexp + log1mexp = liftShaped1 GHC.Float.log1mexp + +squotArray, sremArray :: (IntElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a +squotArray = liftShaped2 mquotArray +sremArray = liftShaped2 mremArray + +satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a +satan2Array = liftShaped2 matan2Array + + +sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh +sshape (Shaped arr) = shCvtXS' (mshape arr) |