diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-05-15 23:44:26 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-05-15 23:45:52 +0200 | 
| commit | cb8a20c32e2737c28fa2993fb29ede9c0faa000d (patch) | |
| tree | 97abc964008e96f51cc6e2cfc5f60340406c3d9b /src/Data/Array/Nested/Ranked | |
| parent | 5f1213fc9e464ec361e6543884968980dd28457d (diff) | |
Move casts to DAN.Convert; split Ranked/Shaped types into .Base
Diffstat (limited to 'src/Data/Array/Nested/Ranked')
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Base.hs | 242 | 
1 files changed, 242 insertions, 0 deletions
| diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs new file mode 100644 index 0000000..f827187 --- /dev/null +++ b/src/Data/Array/Nested/Ranked/Base.hs @@ -0,0 +1,242 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_HADDOCK not-home #-} +module Data.Array.Nested.Ranked.Base where + +import Prelude hiding (mappend, mconcat) + +import Control.DeepSeq (NFData(..)) +import Control.Monad.ST +import Data.Bifunctor (first) +import Data.Coerce (coerce) +import Data.Foldable (toList) +import Data.Kind (Type) +import Data.List.NonEmpty (NonEmpty) +import Data.Proxy +import Foreign.Storable (Storable) +import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) +import GHC.Generics (Generic) +import GHC.TypeLits + +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Types +import Data.Array.XArray (XArray(..)) +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Ranked.Shape +import Data.Array.Strided.Arith + + +-- | A rank-typed array: the number of dimensions of the array (its /rank/) is +-- represented on the type level as a 'Nat'. +-- +-- Valid elements of a ranked arrays are described by the 'Elt' type class. +-- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are +-- supported (and are represented as a single, flattened, struct-of-arrays +-- array internally). +-- +-- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's. +type Ranked :: Nat -> Type -> Type +newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) +deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a) +deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a) + +instance (Show a, Elt a) => Show (Ranked n a) where +  showsPrec d arr@(Ranked marr) = +    let sh = show (toList (rshape arr)) +    in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr + +instance Elt a => NFData (Ranked n a) where +  rnf (Ranked arr) = rnf arr + +-- just unwrap the newtype and defer to the general instance for nested arrays +newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a)) +  deriving (Generic) + +deriving instance Eq (Mixed sh (Mixed (Replicate n Nothing) a)) => Eq (Mixed sh (Ranked n a)) + +newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a)) + +-- 'Ranked' and 'Shaped' can already be used at the top level of an array nest; +-- these instances allow them to also be used as elements of arrays, thus +-- making them first-class in the API. +instance Elt a => Elt (Ranked n a) where +  mshape (M_Ranked arr) = mshape arr +  mindex (M_Ranked arr) i = Ranked (mindex arr i) + +  mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a) +  mindexPartial (M_Ranked arr) i = +    coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ +        mindexPartial arr i + +  mscalar (Ranked x) = M_Ranked (M_Nest ZSX x) + +  mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a) +  mfromListOuter l = M_Ranked (mfromListOuter (coerce l)) + +  mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)] +  mtoListOuter (M_Ranked arr) = +    coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr) + +  mlift :: forall sh1 sh2. +           StaticShX sh2 +        -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) +        -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) +  mlift ssh2 f (M_Ranked arr) = +    coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ +      mlift ssh2 f arr + +  mlift2 :: forall sh1 sh2 sh3. +            StaticShX sh3 +         -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) +         -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a) +  mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) = +    coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ +      mlift2 ssh3 f arr1 arr2 + +  mliftL :: forall sh1 sh2. +            StaticShX sh2 +         -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) +         -> NonEmpty (Mixed sh1 (Ranked n a)) -> NonEmpty (Mixed sh2 (Ranked n a)) +  mliftL ssh2 f l = +    coerce @(NonEmpty (Mixed sh2 (Mixed (Replicate n Nothing) a))) +           @(NonEmpty (Mixed sh2 (Ranked n a))) $ +      mliftL ssh2 f (coerce l) + +  mcastPartial ssh1 ssh2 psh' (M_Ranked arr) = M_Ranked (mcastPartial ssh1 ssh2 psh' arr) + +  mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr) + +  mconcat l = M_Ranked (mconcat (coerce l)) + +  mrnf (M_Ranked arr) = mrnf arr + +  type ShapeTree (Ranked n a) = (IShR n, ShapeTree a) + +  mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr) + +  mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 + +  mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + +  mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + +  marrayStrides (M_Ranked arr) = marrayStrides arr + +  mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () +  mvecsWrite sh idx (Ranked arr) vecs = +    mvecsWrite sh idx arr +      (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) +         vecs) + +  mvecsWritePartial :: forall sh sh' s. +                       IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) +                    -> MixedVecs s (sh ++ sh') (Ranked n a) +                    -> ST s () +  mvecsWritePartial sh idx arr vecs = +    mvecsWritePartial sh idx +      (coerce @(Mixed sh' (Ranked n a)) +              @(Mixed sh' (Mixed (Replicate n Nothing) a)) +         arr) +      (coerce @(MixedVecs s (sh ++ sh') (Ranked n a)) +              @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) +         vecs) + +  mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) +  mvecsFreeze sh vecs = +    coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) +           @(Mixed sh (Ranked n a)) +      <$> mvecsFreeze sh +            (coerce @(MixedVecs s sh (Ranked n a)) +                    @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) +                    vecs) + +instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where +  memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a) +  memptyArrayUnsafe i +    | Dict <- lemKnownReplicate (SNat @n) +    = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ +        memptyArrayUnsafe i + +  mvecsUnsafeNew idx (Ranked arr) +    | Dict <- lemKnownReplicate (SNat @n) +    = MV_Ranked <$> mvecsUnsafeNew idx arr + +  mvecsNewEmpty _ +    | Dict <- lemKnownReplicate (SNat @n) +    = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) + + +liftRanked1 :: forall n a b. +               (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b) +            -> Ranked n a -> Ranked n b +liftRanked1 = coerce + +liftRanked2 :: forall n a b c. +               (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b -> Mixed (Replicate n Nothing) c) +            -> Ranked n a -> Ranked n b -> Ranked n c +liftRanked2 = coerce + +instance (NumElt a, PrimElt a) => Num (Ranked n a) where +  (+) = liftRanked2 (+) +  (-) = liftRanked2 (-) +  (*) = liftRanked2 (*) +  negate = liftRanked1 negate +  abs = liftRanked1 abs +  signum = liftRanked1 signum +  fromInteger = error "Data.Array.Nested(Ranked).fromInteger: No singletons available, use explicit rreplicateScal" + +instance (FloatElt a, PrimElt a) => Fractional (Ranked n a) where +  fromRational _ = error "Data.Array.Nested(Ranked).fromRational: No singletons available, use explicit rreplicateScal" +  recip = liftRanked1 recip +  (/) = liftRanked2 (/) + +instance (FloatElt a, PrimElt a) => Floating (Ranked n a) where +  pi = error "Data.Array.Nested(Ranked).pi: No singletons available, use explicit rreplicateScal" +  exp = liftRanked1 exp +  log = liftRanked1 log +  sqrt = liftRanked1 sqrt +  (**) = liftRanked2 (**) +  logBase = liftRanked2 logBase +  sin = liftRanked1 sin +  cos = liftRanked1 cos +  tan = liftRanked1 tan +  asin = liftRanked1 asin +  acos = liftRanked1 acos +  atan = liftRanked1 atan +  sinh = liftRanked1 sinh +  cosh = liftRanked1 cosh +  tanh = liftRanked1 tanh +  asinh = liftRanked1 asinh +  acosh = liftRanked1 acosh +  atanh = liftRanked1 atanh +  log1p = liftRanked1 GHC.Float.log1p +  expm1 = liftRanked1 GHC.Float.expm1 +  log1pexp = liftRanked1 GHC.Float.log1pexp +  log1mexp = liftRanked1 GHC.Float.log1mexp + +rquotArray, rremArray :: (IntElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a +rquotArray = liftRanked2 mquotArray +rremArray = liftRanked2 mremArray + +ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a +ratan2Array = liftRanked2 matan2Array + + +rshape :: Elt a => Ranked n a -> IShR n +rshape (Ranked arr) = shCvtXR' (mshape arr) + +rrank :: Elt a => Ranked n a -> SNat n +rrank = shrRank . rshape | 
