aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-05-15 23:44:26 +0200
committerTom Smeding <tom@tomsmeding.com>2025-05-15 23:45:52 +0200
commitcb8a20c32e2737c28fa2993fb29ede9c0faa000d (patch)
tree97abc964008e96f51cc6e2cfc5f60340406c3d9b /src/Data/Array
parent5f1213fc9e464ec361e6543884968980dd28457d (diff)
Move casts to DAN.Convert; split Ranked/Shaped types into .Base
Diffstat (limited to 'src/Data/Array')
-rw-r--r--src/Data/Array/Nested/Convert.hs50
-rw-r--r--src/Data/Array/Nested/Mixed.hs7
-rw-r--r--src/Data/Array/Nested/Ranked.hs245
-rw-r--r--src/Data/Array/Nested/Ranked/Base.hs242
-rw-r--r--src/Data/Array/Nested/Shaped.hs235
-rw-r--r--src/Data/Array/Nested/Shaped/Base.hs235
6 files changed, 537 insertions, 477 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
index e9bc20e..cdd2b6d 100644
--- a/src/Data/Array/Nested/Convert.hs
+++ b/src/Data/Array/Nested/Convert.hs
@@ -10,17 +10,63 @@ module Data.Array.Nested.Convert where
import Control.Category
import Data.Proxy
import Data.Type.Equality
+import GHC.TypeLits (Nat)
import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Types
import Data.Array.Nested.Internal.Lemmas
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
-import Data.Array.Nested.Ranked
-import Data.Array.Nested.Shaped
+import Data.Array.Nested.Ranked.Base
+import Data.Array.Nested.Shaped.Base
import Data.Array.Nested.Shaped.Shape
+mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a)
+ => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a
+mcast ssh2 arr
+ | Refl <- lemAppNil @sh1
+ , Refl <- lemAppNil @sh2
+ = mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr
+
+mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a
+mtoRanked arr
+ | Refl <- lemRankReplicate (shxRank (mshape arr))
+ = Ranked (mcast (ssxFromShape (convSh (mshape arr))) arr)
+ where
+ convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing)
+ convSh ZSX = ZSX
+ convSh (smn :$% (sh :: IShX sh'T))
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @(Rank sh'T)
+ = SUnknown (fromSMayNat' smn) :$% convSh sh
+
+rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a
+rtoMixed (Ranked arr) = arr
+
+-- | A more weakly-typed version of 'rtoMixed' that does a runtime shape
+-- compatibility check.
+rcastToMixed :: (Rank sh ~ n, Elt a) => StaticShX sh -> Ranked n a -> Mixed sh a
+rcastToMixed sshx rarr@(Ranked arr)
+ | Refl <- lemRankReplicate (rrank rarr)
+ = mcast sshx arr
+
+mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
+ => Mixed sh a -> ShS sh' -> Shaped sh' a
+mcastToShaped arr targetsh
+ | Refl <- lemRankMapJust targetsh
+ = Shaped (mcast (ssxFromShape (shCvtSX targetsh)) arr)
+
+stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a
+stoMixed (Shaped arr) = arr
+
+-- | A more weakly-typed version of 'stoMixed' that does a runtime shape
+-- compatibility check.
+scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
+ => StaticShX sh' -> Shaped sh a -> Mixed sh' a
+scastToMixed sshx sarr@(Shaped arr)
+ | Refl <- lemRankMapJust (sshape sarr)
+ = mcast sshx arr
+
stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a
stoRanked sarr@(Shaped arr)
| Refl <- lemRankMapJust (sshape sarr)
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index c18db63..efbcf19 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -924,10 +924,3 @@ mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c)
-> Mixed sh a -> Mixed sh b -> Mixed sh c
mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) =
fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2))
-
-mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a)
- => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a
-mcast ssh2 arr
- | Refl <- lemAppNil @sh1
- , Refl <- lemAppNil @sh2
- = mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr
diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs
index e2074ac..c0e1302 100644
--- a/src/Data/Array/Nested/Ranked.hs
+++ b/src/Data/Array/Nested/Ranked.hs
@@ -1,40 +1,28 @@
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveGeneric #-}
-{-# LANGUAGE DerivingVia #-}
-{-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ImportQualifiedPost #-}
-{-# LANGUAGE InstanceSigs #-}
-{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Nested.Ranked where
+module Data.Array.Nested.Ranked (
+ module Data.Array.Nested.Ranked.Base,
+ module Data.Array.Nested.Ranked,
+) where
import Prelude hiding (mappend, mconcat)
-import Control.DeepSeq (NFData(..))
-import Control.Monad.ST
import Data.Array.RankedS qualified as S
import Data.Bifunctor (first)
import Data.Coerce (coerce)
-import Data.Foldable (toList)
-import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
import Data.Type.Equality
import Data.Vector.Storable qualified as VS
import Foreign.Storable (Storable)
-import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
-import GHC.Generics (Generic)
import GHC.TypeLits
import GHC.TypeNats qualified as TN
@@ -43,217 +31,17 @@ import Data.Array.Mixed.Permutation
import Data.Array.Mixed.Types
import Data.Array.XArray (XArray(..))
import Data.Array.XArray qualified as X
+import Data.Array.Nested.Convert
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Ranked.Base
import Data.Array.Nested.Ranked.Shape
import Data.Array.Strided.Arith
--- | A rank-typed array: the number of dimensions of the array (its /rank/) is
--- represented on the type level as a 'Nat'.
---
--- Valid elements of a ranked arrays are described by the 'Elt' type class.
--- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are
--- supported (and are represented as a single, flattened, struct-of-arrays
--- array internally).
---
--- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's.
-type Ranked :: Nat -> Type -> Type
-newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
-deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a)
-deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a)
-
-instance (Show a, Elt a) => Show (Ranked n a) where
- showsPrec d arr@(Ranked marr) =
- let sh = show (toList (rshape arr))
- in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr
-
-instance Elt a => NFData (Ranked n a) where
- rnf (Ranked arr) = rnf arr
-
--- just unwrap the newtype and defer to the general instance for nested arrays
-newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a))
- deriving (Generic)
-
-deriving instance Eq (Mixed sh (Mixed (Replicate n Nothing) a)) => Eq (Mixed sh (Ranked n a))
-
-newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a))
-
--- 'Ranked' and 'Shaped' can already be used at the top level of an array nest;
--- these instances allow them to also be used as elements of arrays, thus
--- making them first-class in the API.
-instance Elt a => Elt (Ranked n a) where
- mshape (M_Ranked arr) = mshape arr
- mindex (M_Ranked arr) i = Ranked (mindex arr i)
-
- mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a)
- mindexPartial (M_Ranked arr) i =
- coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $
- mindexPartial arr i
-
- mscalar (Ranked x) = M_Ranked (M_Nest ZSX x)
-
- mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a)
- mfromListOuter l = M_Ranked (mfromListOuter (coerce l))
-
- mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)]
- mtoListOuter (M_Ranked arr) =
- coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr)
-
- mlift :: forall sh1 sh2.
- StaticShX sh2
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
- -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a)
- mlift ssh2 f (M_Ranked arr) =
- coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $
- mlift ssh2 f arr
-
- mlift2 :: forall sh1 sh2 sh3.
- StaticShX sh3
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
- -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a)
- mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) =
- coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $
- mlift2 ssh3 f arr1 arr2
-
- mliftL :: forall sh1 sh2.
- StaticShX sh2
- -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
- -> NonEmpty (Mixed sh1 (Ranked n a)) -> NonEmpty (Mixed sh2 (Ranked n a))
- mliftL ssh2 f l =
- coerce @(NonEmpty (Mixed sh2 (Mixed (Replicate n Nothing) a)))
- @(NonEmpty (Mixed sh2 (Ranked n a))) $
- mliftL ssh2 f (coerce l)
-
- mcastPartial ssh1 ssh2 psh' (M_Ranked arr) = M_Ranked (mcastPartial ssh1 ssh2 psh' arr)
-
- mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr)
-
- mconcat l = M_Ranked (mconcat (coerce l))
-
- mrnf (M_Ranked arr) = mrnf arr
-
- type ShapeTree (Ranked n a) = (IShR n, ShapeTree a)
-
- mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr)
-
- mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
-
- mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
-
- mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
-
- marrayStrides (M_Ranked arr) = marrayStrides arr
-
- mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
- mvecsWrite sh idx (Ranked arr) vecs =
- mvecsWrite sh idx arr
- (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
- vecs)
-
- mvecsWritePartial :: forall sh sh' s.
- IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a)
- -> MixedVecs s (sh ++ sh') (Ranked n a)
- -> ST s ()
- mvecsWritePartial sh idx arr vecs =
- mvecsWritePartial sh idx
- (coerce @(Mixed sh' (Ranked n a))
- @(Mixed sh' (Mixed (Replicate n Nothing) a))
- arr)
- (coerce @(MixedVecs s (sh ++ sh') (Ranked n a))
- @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a))
- vecs)
-
- mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
- mvecsFreeze sh vecs =
- coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
- @(Mixed sh (Ranked n a))
- <$> mvecsFreeze sh
- (coerce @(MixedVecs s sh (Ranked n a))
- @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
- vecs)
-
-instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where
- memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a)
- memptyArrayUnsafe i
- | Dict <- lemKnownReplicate (SNat @n)
- = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $
- memptyArrayUnsafe i
-
- mvecsUnsafeNew idx (Ranked arr)
- | Dict <- lemKnownReplicate (SNat @n)
- = MV_Ranked <$> mvecsUnsafeNew idx arr
-
- mvecsNewEmpty _
- | Dict <- lemKnownReplicate (SNat @n)
- = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a))
-
-
-liftRanked1 :: forall n a b.
- (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b)
- -> Ranked n a -> Ranked n b
-liftRanked1 = coerce
-
-liftRanked2 :: forall n a b c.
- (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b -> Mixed (Replicate n Nothing) c)
- -> Ranked n a -> Ranked n b -> Ranked n c
-liftRanked2 = coerce
-
-instance (NumElt a, PrimElt a) => Num (Ranked n a) where
- (+) = liftRanked2 (+)
- (-) = liftRanked2 (-)
- (*) = liftRanked2 (*)
- negate = liftRanked1 negate
- abs = liftRanked1 abs
- signum = liftRanked1 signum
- fromInteger = error "Data.Array.Nested(Ranked).fromInteger: No singletons available, use explicit rreplicateScal"
-
-instance (FloatElt a, PrimElt a) => Fractional (Ranked n a) where
- fromRational _ = error "Data.Array.Nested(Ranked).fromRational: No singletons available, use explicit rreplicateScal"
- recip = liftRanked1 recip
- (/) = liftRanked2 (/)
-
-instance (FloatElt a, PrimElt a) => Floating (Ranked n a) where
- pi = error "Data.Array.Nested(Ranked).pi: No singletons available, use explicit rreplicateScal"
- exp = liftRanked1 exp
- log = liftRanked1 log
- sqrt = liftRanked1 sqrt
- (**) = liftRanked2 (**)
- logBase = liftRanked2 logBase
- sin = liftRanked1 sin
- cos = liftRanked1 cos
- tan = liftRanked1 tan
- asin = liftRanked1 asin
- acos = liftRanked1 acos
- atan = liftRanked1 atan
- sinh = liftRanked1 sinh
- cosh = liftRanked1 cosh
- tanh = liftRanked1 tanh
- asinh = liftRanked1 asinh
- acosh = liftRanked1 acosh
- atanh = liftRanked1 atanh
- log1p = liftRanked1 GHC.Float.log1p
- expm1 = liftRanked1 GHC.Float.expm1
- log1pexp = liftRanked1 GHC.Float.log1pexp
- log1mexp = liftRanked1 GHC.Float.log1mexp
-
-rquotArray, rremArray :: (IntElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a
-rquotArray = liftRanked2 mquotArray
-rremArray = liftRanked2 mremArray
-
-ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a
-ratan2Array = liftRanked2 matan2Array
-
-
remptyArray :: KnownElt a => Ranked 1 a
remptyArray = mtoRanked (memptyArray ZSX)
-rshape :: Elt a => Ranked n a -> IShR n
-rshape (Ranked arr) = shCvtXR' (mshape arr)
-
-rrank :: Elt a => Ranked n a -> SNat n
-rrank = shrRank . rshape
-
-- | The total number of elements in the array.
rsize :: Elt a => Ranked n a -> Int
rsize = shrSize . rshape
@@ -536,24 +324,3 @@ rfromPrimitive (Ranked arr) = Ranked (fromPrimitive arr)
rtoPrimitive :: PrimElt a => Ranked n a -> Ranked n (Primitive a)
rtoPrimitive (Ranked arr) = Ranked (toPrimitive arr)
-
-mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a
-mtoRanked arr
- | Refl <- lemRankReplicate (shxRank (mshape arr))
- = Ranked (mcast (ssxFromShape (convSh (mshape arr))) arr)
- where
- convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing)
- convSh ZSX = ZSX
- convSh (smn :$% (sh :: IShX sh'T))
- | Refl <- lemReplicateSucc @(Nothing @Nat) @(Rank sh'T)
- = SUnknown (fromSMayNat' smn) :$% convSh sh
-
-rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a
-rtoMixed (Ranked arr) = arr
-
--- | A more weakly-typed version of 'rtoMixed' that does a runtime shape
--- compatibility check.
-rcastToMixed :: (Rank sh ~ n, Elt a) => StaticShX sh -> Ranked n a -> Mixed sh a
-rcastToMixed sshx rarr@(Ranked arr)
- | Refl <- lemRankReplicate (rrank rarr)
- = mcast sshx arr
diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs
new file mode 100644
index 0000000..f827187
--- /dev/null
+++ b/src/Data/Array/Nested/Ranked/Base.hs
@@ -0,0 +1,242 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# OPTIONS_HADDOCK not-home #-}
+module Data.Array.Nested.Ranked.Base where
+
+import Prelude hiding (mappend, mconcat)
+
+import Control.DeepSeq (NFData(..))
+import Control.Monad.ST
+import Data.Bifunctor (first)
+import Data.Coerce (coerce)
+import Data.Foldable (toList)
+import Data.Kind (Type)
+import Data.List.NonEmpty (NonEmpty)
+import Data.Proxy
+import Foreign.Storable (Storable)
+import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
+import GHC.Generics (Generic)
+import GHC.TypeLits
+
+import Data.Array.Mixed.Lemmas
+import Data.Array.Mixed.Types
+import Data.Array.XArray (XArray(..))
+import Data.Array.Nested.Mixed
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Ranked.Shape
+import Data.Array.Strided.Arith
+
+
+-- | A rank-typed array: the number of dimensions of the array (its /rank/) is
+-- represented on the type level as a 'Nat'.
+--
+-- Valid elements of a ranked arrays are described by the 'Elt' type class.
+-- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are
+-- supported (and are represented as a single, flattened, struct-of-arrays
+-- array internally).
+--
+-- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's.
+type Ranked :: Nat -> Type -> Type
+newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
+deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a)
+deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a)
+
+instance (Show a, Elt a) => Show (Ranked n a) where
+ showsPrec d arr@(Ranked marr) =
+ let sh = show (toList (rshape arr))
+ in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr
+
+instance Elt a => NFData (Ranked n a) where
+ rnf (Ranked arr) = rnf arr
+
+-- just unwrap the newtype and defer to the general instance for nested arrays
+newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a))
+ deriving (Generic)
+
+deriving instance Eq (Mixed sh (Mixed (Replicate n Nothing) a)) => Eq (Mixed sh (Ranked n a))
+
+newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a))
+
+-- 'Ranked' and 'Shaped' can already be used at the top level of an array nest;
+-- these instances allow them to also be used as elements of arrays, thus
+-- making them first-class in the API.
+instance Elt a => Elt (Ranked n a) where
+ mshape (M_Ranked arr) = mshape arr
+ mindex (M_Ranked arr) i = Ranked (mindex arr i)
+
+ mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a)
+ mindexPartial (M_Ranked arr) i =
+ coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $
+ mindexPartial arr i
+
+ mscalar (Ranked x) = M_Ranked (M_Nest ZSX x)
+
+ mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a)
+ mfromListOuter l = M_Ranked (mfromListOuter (coerce l))
+
+ mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)]
+ mtoListOuter (M_Ranked arr) =
+ coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr)
+
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a)
+ mlift ssh2 f (M_Ranked arr) =
+ coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $
+ mlift ssh2 f arr
+
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
+ -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a)
+ mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) =
+ coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $
+ mlift2 ssh3 f arr1 arr2
+
+ mliftL :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
+ -> NonEmpty (Mixed sh1 (Ranked n a)) -> NonEmpty (Mixed sh2 (Ranked n a))
+ mliftL ssh2 f l =
+ coerce @(NonEmpty (Mixed sh2 (Mixed (Replicate n Nothing) a)))
+ @(NonEmpty (Mixed sh2 (Ranked n a))) $
+ mliftL ssh2 f (coerce l)
+
+ mcastPartial ssh1 ssh2 psh' (M_Ranked arr) = M_Ranked (mcastPartial ssh1 ssh2 psh' arr)
+
+ mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr)
+
+ mconcat l = M_Ranked (mconcat (coerce l))
+
+ mrnf (M_Ranked arr) = mrnf arr
+
+ type ShapeTree (Ranked n a) = (IShR n, ShapeTree a)
+
+ mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr)
+
+ mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
+
+ mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+
+ mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
+
+ marrayStrides (M_Ranked arr) = marrayStrides arr
+
+ mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
+ mvecsWrite sh idx (Ranked arr) vecs =
+ mvecsWrite sh idx arr
+ (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
+ vecs)
+
+ mvecsWritePartial :: forall sh sh' s.
+ IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a)
+ -> MixedVecs s (sh ++ sh') (Ranked n a)
+ -> ST s ()
+ mvecsWritePartial sh idx arr vecs =
+ mvecsWritePartial sh idx
+ (coerce @(Mixed sh' (Ranked n a))
+ @(Mixed sh' (Mixed (Replicate n Nothing) a))
+ arr)
+ (coerce @(MixedVecs s (sh ++ sh') (Ranked n a))
+ @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a))
+ vecs)
+
+ mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
+ mvecsFreeze sh vecs =
+ coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
+ @(Mixed sh (Ranked n a))
+ <$> mvecsFreeze sh
+ (coerce @(MixedVecs s sh (Ranked n a))
+ @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
+ vecs)
+
+instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where
+ memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a)
+ memptyArrayUnsafe i
+ | Dict <- lemKnownReplicate (SNat @n)
+ = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $
+ memptyArrayUnsafe i
+
+ mvecsUnsafeNew idx (Ranked arr)
+ | Dict <- lemKnownReplicate (SNat @n)
+ = MV_Ranked <$> mvecsUnsafeNew idx arr
+
+ mvecsNewEmpty _
+ | Dict <- lemKnownReplicate (SNat @n)
+ = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a))
+
+
+liftRanked1 :: forall n a b.
+ (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b)
+ -> Ranked n a -> Ranked n b
+liftRanked1 = coerce
+
+liftRanked2 :: forall n a b c.
+ (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b -> Mixed (Replicate n Nothing) c)
+ -> Ranked n a -> Ranked n b -> Ranked n c
+liftRanked2 = coerce
+
+instance (NumElt a, PrimElt a) => Num (Ranked n a) where
+ (+) = liftRanked2 (+)
+ (-) = liftRanked2 (-)
+ (*) = liftRanked2 (*)
+ negate = liftRanked1 negate
+ abs = liftRanked1 abs
+ signum = liftRanked1 signum
+ fromInteger = error "Data.Array.Nested(Ranked).fromInteger: No singletons available, use explicit rreplicateScal"
+
+instance (FloatElt a, PrimElt a) => Fractional (Ranked n a) where
+ fromRational _ = error "Data.Array.Nested(Ranked).fromRational: No singletons available, use explicit rreplicateScal"
+ recip = liftRanked1 recip
+ (/) = liftRanked2 (/)
+
+instance (FloatElt a, PrimElt a) => Floating (Ranked n a) where
+ pi = error "Data.Array.Nested(Ranked).pi: No singletons available, use explicit rreplicateScal"
+ exp = liftRanked1 exp
+ log = liftRanked1 log
+ sqrt = liftRanked1 sqrt
+ (**) = liftRanked2 (**)
+ logBase = liftRanked2 logBase
+ sin = liftRanked1 sin
+ cos = liftRanked1 cos
+ tan = liftRanked1 tan
+ asin = liftRanked1 asin
+ acos = liftRanked1 acos
+ atan = liftRanked1 atan
+ sinh = liftRanked1 sinh
+ cosh = liftRanked1 cosh
+ tanh = liftRanked1 tanh
+ asinh = liftRanked1 asinh
+ acosh = liftRanked1 acosh
+ atanh = liftRanked1 atanh
+ log1p = liftRanked1 GHC.Float.log1p
+ expm1 = liftRanked1 GHC.Float.expm1
+ log1pexp = liftRanked1 GHC.Float.log1pexp
+ log1mexp = liftRanked1 GHC.Float.log1mexp
+
+rquotArray, rremArray :: (IntElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a
+rquotArray = liftRanked2 mquotArray
+rremArray = liftRanked2 mremArray
+
+ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a
+ratan2Array = liftRanked2 matan2Array
+
+
+rshape :: Elt a => Ranked n a -> IShR n
+rshape (Ranked arr) = shCvtXR' (mshape arr)
+
+rrank :: Elt a => Ranked n a -> SNat n
+rrank = shrRank . rshape
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
index 4bccbc4..c4c61fd 100644
--- a/src/Data/Array/Nested/Shaped.hs
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -1,41 +1,29 @@
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveGeneric #-}
-{-# LANGUAGE DerivingVia #-}
-{-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ImportQualifiedPost #-}
-{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Nested.Shaped where
+module Data.Array.Nested.Shaped (
+ module Data.Array.Nested.Shaped.Base,
+ module Data.Array.Nested.Shaped,
+) where
import Prelude hiding (mappend, mconcat)
-import Control.DeepSeq (NFData(..))
-import Control.Monad.ST
import Data.Array.Internal.RankedG qualified as RG
import Data.Array.Internal.RankedS qualified as RS
import Data.Array.Internal.ShapedG qualified as SG
import Data.Array.Internal.ShapedS qualified as SS
import Data.Bifunctor (first)
import Data.Coerce (coerce)
-import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
import Data.Type.Equality
import Data.Vector.Storable qualified as VS
import Foreign.Storable (Storable)
-import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
-import GHC.Generics (Generic)
import GHC.TypeLits
import Data.Array.Mixed.Lemmas
@@ -44,211 +32,17 @@ import Data.Array.Mixed.Types
import Data.Array.XArray (XArray)
import Data.Array.XArray qualified as X
import Data.Array.Nested.Internal.Lemmas
+import Data.Array.Nested.Convert
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Shaped.Base
import Data.Array.Nested.Shaped.Shape
import Data.Array.Strided.Arith
--- | A shape-typed array: the full shape of the array (the sizes of its
--- dimensions) is represented on the type level as a list of 'Nat's. Note that
--- these are "GHC.TypeLits" naturals, because we do not need induction over
--- them and we want very large arrays to be possible.
---
--- Like for 'Ranked', the valid elements are described by the 'Elt' type class,
--- and 'Shaped' itself is again an instance of 'Elt' as well.
---
--- 'Shaped' is a newtype around a 'Mixed' of 'Just's.
-type Shaped :: [Nat] -> Type -> Type
-newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
-deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a)
-deriving instance Ord (Mixed (MapJust sh) a) => Ord (Shaped sh a)
-
-instance (Show a, Elt a) => Show (Shaped n a) where
- showsPrec d arr@(Shaped marr) =
- let sh = show (shsToList (sshape arr))
- in showsMixedArray ("sfromListLinear " ++ sh) ("sreplicate " ++ sh) d marr
-
-instance Elt a => NFData (Shaped sh a) where
- rnf (Shaped arr) = rnf arr
-
--- just unwrap the newtype and defer to the general instance for nested arrays
-newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a))
- deriving (Generic)
-
-deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped sh' a))
-
-newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a))
-
-instance Elt a => Elt (Shaped sh a) where
- mshape (M_Shaped arr) = mshape arr
- mindex (M_Shaped arr) i = Shaped (mindex arr i)
-
- mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
- mindexPartial (M_Shaped arr) i =
- coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
- mindexPartial arr i
-
- mscalar (Shaped x) = M_Shaped (M_Nest ZSX x)
-
- mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a)
- mfromListOuter l = M_Shaped (mfromListOuter (coerce l))
-
- mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)]
- mtoListOuter (M_Shaped arr)
- = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr)
-
- mlift :: forall sh1 sh2.
- StaticShX sh2
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
- -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a)
- mlift ssh2 f (M_Shaped arr) =
- coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
- mlift ssh2 f arr
-
- mlift2 :: forall sh1 sh2 sh3.
- StaticShX sh3
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
- -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a)
- mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) =
- coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $
- mlift2 ssh3 f arr1 arr2
-
- mliftL :: forall sh1 sh2.
- StaticShX sh2
- -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
- -> NonEmpty (Mixed sh1 (Shaped sh a)) -> NonEmpty (Mixed sh2 (Shaped sh a))
- mliftL ssh2 f l =
- coerce @(NonEmpty (Mixed sh2 (Mixed (MapJust sh) a)))
- @(NonEmpty (Mixed sh2 (Shaped sh a))) $
- mliftL ssh2 f (coerce l)
-
- mcastPartial ssh1 ssh2 psh' (M_Shaped arr) = M_Shaped (mcastPartial ssh1 ssh2 psh' arr)
-
- mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr)
-
- mconcat l = M_Shaped (mconcat (coerce l))
-
- mrnf (M_Shaped arr) = mrnf arr
-
- type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a)
-
- mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr)
-
- mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
-
- mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
-
- mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
-
- marrayStrides (M_Shaped arr) = marrayStrides arr
-
- mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
- mvecsWrite sh idx (Shaped arr) vecs =
- mvecsWrite sh idx arr
- (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
- vecs)
-
- mvecsWritePartial :: forall sh1 sh2 s.
- IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
- -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
- -> ST s ()
- mvecsWritePartial sh idx arr vecs =
- mvecsWritePartial sh idx
- (coerce @(Mixed sh2 (Shaped sh a))
- @(Mixed sh2 (Mixed (MapJust sh) a))
- arr)
- (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a))
- @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))
- vecs)
-
- mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
- mvecsFreeze sh vecs =
- coerce @(Mixed sh' (Mixed (MapJust sh) a))
- @(Mixed sh' (Shaped sh a))
- <$> mvecsFreeze sh
- (coerce @(MixedVecs s sh' (Shaped sh a))
- @(MixedVecs s sh' (Mixed (MapJust sh) a))
- vecs)
-
-instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where
- memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)
- memptyArrayUnsafe i
- | Dict <- lemKnownMapJust (Proxy @sh)
- = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
- memptyArrayUnsafe i
-
- mvecsUnsafeNew idx (Shaped arr)
- | Dict <- lemKnownMapJust (Proxy @sh)
- = MV_Shaped <$> mvecsUnsafeNew idx arr
-
- mvecsNewEmpty _
- | Dict <- lemKnownMapJust (Proxy @sh)
- = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a))
-
-
-liftShaped1 :: forall sh a b.
- (Mixed (MapJust sh) a -> Mixed (MapJust sh) b)
- -> Shaped sh a -> Shaped sh b
-liftShaped1 = coerce
-
-liftShaped2 :: forall sh a b c.
- (Mixed (MapJust sh) a -> Mixed (MapJust sh) b -> Mixed (MapJust sh) c)
- -> Shaped sh a -> Shaped sh b -> Shaped sh c
-liftShaped2 = coerce
-
-instance (NumElt a, PrimElt a) => Num (Shaped sh a) where
- (+) = liftShaped2 (+)
- (-) = liftShaped2 (-)
- (*) = liftShaped2 (*)
- negate = liftShaped1 negate
- abs = liftShaped1 abs
- signum = liftShaped1 signum
- fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicateScal"
-
-instance (FloatElt a, PrimElt a) => Fractional (Shaped sh a) where
- fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal"
- recip = liftShaped1 recip
- (/) = liftShaped2 (/)
-
-instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where
- pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal"
- exp = liftShaped1 exp
- log = liftShaped1 log
- sqrt = liftShaped1 sqrt
- (**) = liftShaped2 (**)
- logBase = liftShaped2 logBase
- sin = liftShaped1 sin
- cos = liftShaped1 cos
- tan = liftShaped1 tan
- asin = liftShaped1 asin
- acos = liftShaped1 acos
- atan = liftShaped1 atan
- sinh = liftShaped1 sinh
- cosh = liftShaped1 cosh
- tanh = liftShaped1 tanh
- asinh = liftShaped1 asinh
- acosh = liftShaped1 acosh
- atanh = liftShaped1 atanh
- log1p = liftShaped1 GHC.Float.log1p
- expm1 = liftShaped1 GHC.Float.expm1
- log1pexp = liftShaped1 GHC.Float.log1pexp
- log1mexp = liftShaped1 GHC.Float.log1mexp
-
-squotArray, sremArray :: (IntElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a
-squotArray = liftShaped2 mquotArray
-sremArray = liftShaped2 mremArray
-
-satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a
-satan2Array = liftShaped2 matan2Array
-
-
semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a
semptyArray sh = Shaped (memptyArray (shCvtSX sh))
-sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
-sshape (Shaped arr) = shCvtXS' (mshape arr)
-
srank :: Elt a => Shaped sh a -> SNat (Rank sh)
srank = shsRank . sshape
@@ -476,20 +270,3 @@ sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr)
stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a)
stoPrimitive (Shaped arr) = Shaped (toPrimitive arr)
-
-mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
- => Mixed sh a -> ShS sh' -> Shaped sh' a
-mcastToShaped arr targetsh
- | Refl <- lemRankMapJust targetsh
- = Shaped (mcast (ssxFromShape (shCvtSX targetsh)) arr)
-
-stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a
-stoMixed (Shaped arr) = arr
-
--- | A more weakly-typed version of 'stoMixed' that does a runtime shape
--- compatibility check.
-scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
- => StaticShX sh' -> Shaped sh a -> Mixed sh' a
-scastToMixed sshx sarr@(Shaped arr)
- | Refl <- lemRankMapJust (sshape sarr)
- = mcast sshx arr
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs
new file mode 100644
index 0000000..74c231d
--- /dev/null
+++ b/src/Data/Array/Nested/Shaped/Base.hs
@@ -0,0 +1,235 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# OPTIONS_HADDOCK not-home #-}
+module Data.Array.Nested.Shaped.Base where
+
+import Prelude hiding (mappend, mconcat)
+
+import Control.DeepSeq (NFData(..))
+import Control.Monad.ST
+import Data.Bifunctor (first)
+import Data.Coerce (coerce)
+import Data.Kind (Type)
+import Data.List.NonEmpty (NonEmpty)
+import Data.Proxy
+import Foreign.Storable (Storable)
+import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
+import GHC.Generics (Generic)
+import GHC.TypeLits
+
+import Data.Array.Mixed.Types
+import Data.Array.XArray (XArray)
+import Data.Array.Nested.Internal.Lemmas
+import Data.Array.Nested.Mixed
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Shaped.Shape
+import Data.Array.Strided.Arith
+
+
+-- | A shape-typed array: the full shape of the array (the sizes of its
+-- dimensions) is represented on the type level as a list of 'Nat's. Note that
+-- these are "GHC.TypeLits" naturals, because we do not need induction over
+-- them and we want very large arrays to be possible.
+--
+-- Like for 'Ranked', the valid elements are described by the 'Elt' type class,
+-- and 'Shaped' itself is again an instance of 'Elt' as well.
+--
+-- 'Shaped' is a newtype around a 'Mixed' of 'Just's.
+type Shaped :: [Nat] -> Type -> Type
+newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
+deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a)
+deriving instance Ord (Mixed (MapJust sh) a) => Ord (Shaped sh a)
+
+instance (Show a, Elt a) => Show (Shaped n a) where
+ showsPrec d arr@(Shaped marr) =
+ let sh = show (shsToList (sshape arr))
+ in showsMixedArray ("sfromListLinear " ++ sh) ("sreplicate " ++ sh) d marr
+
+instance Elt a => NFData (Shaped sh a) where
+ rnf (Shaped arr) = rnf arr
+
+-- just unwrap the newtype and defer to the general instance for nested arrays
+newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a))
+ deriving (Generic)
+
+deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped sh' a))
+
+newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a))
+
+instance Elt a => Elt (Shaped sh a) where
+ mshape (M_Shaped arr) = mshape arr
+ mindex (M_Shaped arr) i = Shaped (mindex arr i)
+
+ mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
+ mindexPartial (M_Shaped arr) i =
+ coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
+ mindexPartial arr i
+
+ mscalar (Shaped x) = M_Shaped (M_Nest ZSX x)
+
+ mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a)
+ mfromListOuter l = M_Shaped (mfromListOuter (coerce l))
+
+ mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)]
+ mtoListOuter (M_Shaped arr)
+ = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr)
+
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a)
+ mlift ssh2 f (M_Shaped arr) =
+ coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
+ mlift ssh2 f arr
+
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
+ -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a)
+ mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) =
+ coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $
+ mlift2 ssh3 f arr1 arr2
+
+ mliftL :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
+ -> NonEmpty (Mixed sh1 (Shaped sh a)) -> NonEmpty (Mixed sh2 (Shaped sh a))
+ mliftL ssh2 f l =
+ coerce @(NonEmpty (Mixed sh2 (Mixed (MapJust sh) a)))
+ @(NonEmpty (Mixed sh2 (Shaped sh a))) $
+ mliftL ssh2 f (coerce l)
+
+ mcastPartial ssh1 ssh2 psh' (M_Shaped arr) = M_Shaped (mcastPartial ssh1 ssh2 psh' arr)
+
+ mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr)
+
+ mconcat l = M_Shaped (mconcat (coerce l))
+
+ mrnf (M_Shaped arr) = mrnf arr
+
+ type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a)
+
+ mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr)
+
+ mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
+
+ mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+
+ mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
+
+ marrayStrides (M_Shaped arr) = marrayStrides arr
+
+ mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
+ mvecsWrite sh idx (Shaped arr) vecs =
+ mvecsWrite sh idx arr
+ (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
+ vecs)
+
+ mvecsWritePartial :: forall sh1 sh2 s.
+ IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
+ -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
+ -> ST s ()
+ mvecsWritePartial sh idx arr vecs =
+ mvecsWritePartial sh idx
+ (coerce @(Mixed sh2 (Shaped sh a))
+ @(Mixed sh2 (Mixed (MapJust sh) a))
+ arr)
+ (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a))
+ @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))
+ vecs)
+
+ mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
+ mvecsFreeze sh vecs =
+ coerce @(Mixed sh' (Mixed (MapJust sh) a))
+ @(Mixed sh' (Shaped sh a))
+ <$> mvecsFreeze sh
+ (coerce @(MixedVecs s sh' (Shaped sh a))
+ @(MixedVecs s sh' (Mixed (MapJust sh) a))
+ vecs)
+
+instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where
+ memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)
+ memptyArrayUnsafe i
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
+ memptyArrayUnsafe i
+
+ mvecsUnsafeNew idx (Shaped arr)
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = MV_Shaped <$> mvecsUnsafeNew idx arr
+
+ mvecsNewEmpty _
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a))
+
+
+liftShaped1 :: forall sh a b.
+ (Mixed (MapJust sh) a -> Mixed (MapJust sh) b)
+ -> Shaped sh a -> Shaped sh b
+liftShaped1 = coerce
+
+liftShaped2 :: forall sh a b c.
+ (Mixed (MapJust sh) a -> Mixed (MapJust sh) b -> Mixed (MapJust sh) c)
+ -> Shaped sh a -> Shaped sh b -> Shaped sh c
+liftShaped2 = coerce
+
+instance (NumElt a, PrimElt a) => Num (Shaped sh a) where
+ (+) = liftShaped2 (+)
+ (-) = liftShaped2 (-)
+ (*) = liftShaped2 (*)
+ negate = liftShaped1 negate
+ abs = liftShaped1 abs
+ signum = liftShaped1 signum
+ fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicateScal"
+
+instance (FloatElt a, PrimElt a) => Fractional (Shaped sh a) where
+ fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal"
+ recip = liftShaped1 recip
+ (/) = liftShaped2 (/)
+
+instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where
+ pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal"
+ exp = liftShaped1 exp
+ log = liftShaped1 log
+ sqrt = liftShaped1 sqrt
+ (**) = liftShaped2 (**)
+ logBase = liftShaped2 logBase
+ sin = liftShaped1 sin
+ cos = liftShaped1 cos
+ tan = liftShaped1 tan
+ asin = liftShaped1 asin
+ acos = liftShaped1 acos
+ atan = liftShaped1 atan
+ sinh = liftShaped1 sinh
+ cosh = liftShaped1 cosh
+ tanh = liftShaped1 tanh
+ asinh = liftShaped1 asinh
+ acosh = liftShaped1 acosh
+ atanh = liftShaped1 atanh
+ log1p = liftShaped1 GHC.Float.log1p
+ expm1 = liftShaped1 GHC.Float.expm1
+ log1pexp = liftShaped1 GHC.Float.log1pexp
+ log1mexp = liftShaped1 GHC.Float.log1mexp
+
+squotArray, sremArray :: (IntElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a
+squotArray = liftShaped2 mquotArray
+sremArray = liftShaped2 mremArray
+
+satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a
+satan2Array = liftShaped2 matan2Array
+
+
+sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
+sshape (Shaped arr) = shCvtXS' (mshape arr)