{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Nested.Internal.Shaped where

import Prelude hiding (mappend, mconcat)

import Control.DeepSeq (NFData)
import Control.Monad.ST
import Data.Bifunctor (first)
import Data.Coerce (coerce)
import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
import Data.Type.Equality
import Data.Vector.Storable qualified as VS
import Foreign.Storable (Storable)
import GHC.Float qualified (log1p, expm1, log1pexp, log1mexp)
import GHC.TypeLits

import Data.Array.Mixed.XArray (XArray)
import Data.Array.Mixed.XArray qualified as X
import Data.Array.Mixed.Internal.Arith
import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Permutation
import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
import Data.Array.Nested.Internal.Lemmas
import Data.Array.Nested.Internal.Mixed
import Data.Array.Nested.Internal.Shape


-- | A shape-typed array: the full shape of the array (the sizes of its
-- dimensions) is represented on the type level as a list of 'Nat's. Note that
-- these are "GHC.TypeLits" naturals, because we do not need induction over
-- them and we want very large arrays to be possible.
--
-- Like for 'Ranked', the valid elements are described by the 'Elt' type class,
-- and 'Shaped' itself is again an instance of 'Elt' as well.
--
-- 'Shaped' is a newtype around a 'Mixed' of 'Just's.
type Shaped :: [Nat] -> Type -> Type
newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a)
deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a)
deriving instance Ord (Mixed (MapJust sh) a) => Ord (Shaped sh a)
deriving instance NFData (Mixed (MapJust sh) a) => NFData (Shaped sh a)

-- just unwrap the newtype and defer to the general instance for nested arrays
newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a))
deriving instance Show (Mixed sh (Mixed (MapJust sh') a)) => Show (Mixed sh (Shaped sh' a))

newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a))

-- Every method below unwraps 'M_Shaped' / 'MV_Shaped' / 'Shaped' (by pattern
-- matching or by 'coerce') and delegates to the same method on the underlying
-- nested 'Mixed' representation. The explicit type applications on 'coerce'
-- pin down the representation types, which is why the instance signatures
-- with explicit @forall@s are needed.
instance Elt a => Elt (Shaped sh a) where
  mshape (M_Shaped arr) = mshape arr
  mindex (M_Shaped arr) i = Shaped (mindex arr i)

  mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
  mindexPartial (M_Shaped arr) i =
    coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
      mindexPartial arr i

  mscalar (Shaped x) = M_Shaped (M_Nest ZSX x)

  mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a)
  mfromListOuter l = M_Shaped (mfromListOuter (coerce l))

  mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)]
  mtoListOuter (M_Shaped arr)
    = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr)

  mlift :: forall sh1 sh2.
           StaticShX sh2
        -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
        -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a)
  mlift ssh2 f (M_Shaped arr) =
    coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
      mlift ssh2 f arr

  mlift2 :: forall sh1 sh2 sh3.
            StaticShX sh3
         -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
         -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a)
  mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) =
    coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $
      mlift2 ssh3 f arr1 arr2

  mliftL :: forall sh1 sh2.
            StaticShX sh2
         -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
         -> NonEmpty (Mixed sh1 (Shaped sh a)) -> NonEmpty (Mixed sh2 (Shaped sh a))
  mliftL ssh2 f l =
    coerce @(NonEmpty (Mixed sh2 (Mixed (MapJust sh) a)))
           @(NonEmpty (Mixed sh2 (Shaped sh a))) $
      mliftL ssh2 f (coerce l)

  mcast ssh1 sh2 psh' (M_Shaped arr) = M_Shaped (mcast ssh1 sh2 psh' arr)

  mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr)

  mconcat l = M_Shaped (mconcat (coerce l))

  -- The shape tree of a 'Shaped' is that of the underlying 'Mixed', with the
  -- top-level shape converted to a 'ShS' (see 'mshapeTree' below).
  type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a)

  mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr)

  mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2

  mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t

  mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"

  mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
  mvecsWrite sh idx (Shaped arr) vecs =
    mvecsWrite sh idx arr
      (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
         vecs)

  mvecsWritePartial :: forall sh1 sh2 s.
                       IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
                    -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
                    -> ST s ()
  mvecsWritePartial sh idx arr vecs =
    mvecsWritePartial sh idx
      (coerce @(Mixed sh2 (Shaped sh a))
              @(Mixed sh2 (Mixed (MapJust sh) a))
         arr)
      (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a))
              @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))
         vecs)

  mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
  mvecsFreeze sh vecs =
    coerce @(Mixed sh' (Mixed (MapJust sh) a))
           @(Mixed sh' (Shaped sh a))
      <$> mvecsFreeze sh
            (coerce @(MixedVecs s sh' (Shaped sh a))
                    @(MixedVecs s sh' (Mixed (MapJust sh) a))
                    vecs)

-- Each method first brings the evidence produced by 'lemKnownMapJust' into
-- scope (by matching on 'Dict') and then defers to the 'KnownElt' method of
-- the underlying @Mixed (MapJust sh) a@.
instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where
  memptyArray :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)
  memptyArray i
    | Dict <- lemKnownMapJust (Proxy @sh)
    = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
        memptyArray i

  mvecsUnsafeNew idx (Shaped arr)
    | Dict <- lemKnownMapJust (Proxy @sh)
    = MV_Shaped <$> mvecsUnsafeNew idx arr

  mvecsNewEmpty _
    | Dict <- lemKnownMapJust (Proxy @sh)
    = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a))


-- | Turn a shape-polymorphic unary operation on 'Mixed' arrays into one on
-- 'Shaped' arrays. Implemented as a plain 'coerce' through the 'Shaped'
-- newtype.
arithPromoteShaped :: forall sh a. PrimElt a
                   => (forall shx. Mixed shx a -> Mixed shx a)
                   -> Shaped sh a -> Shaped sh a
arithPromoteShaped = coerce

-- | Binary counterpart of 'arithPromoteShaped'; also a plain 'coerce'.
arithPromoteShaped2 :: forall sh a. PrimElt a
                    => (forall shx. Mixed shx a -> Mixed shx a -> Mixed shx a)
                    -> Shaped sh a -> Shaped sh a -> Shaped sh a
arithPromoteShaped2 = coerce

-- All operations are promoted from the corresponding instance on 'Mixed'.
--
-- 'fromInteger' is deliberately an error: this instance has no 'KnownShS'
-- constraint, so there is no shape singleton from which a constant array
-- could be built.
instance (NumElt a, PrimElt a) => Num (Shaped sh a) where
  (+) = arithPromoteShaped2 (+)
  (-) = arithPromoteShaped2 (-)
  (*) = arithPromoteShaped2 (*)
  negate = arithPromoteShaped negate
  abs = arithPromoteShaped abs
  signum = arithPromoteShaped signum
  -- The message names the method that was actually called ('fromInteger'),
  -- in line with the 'fromRational' and 'pi' messages below.
  fromInteger _ = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicateScal"

-- 'fromRational' is an error for the same reason as 'fromInteger' above.
instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Shaped sh a) where
  fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal"
  recip = arithPromoteShaped recip
  (/) = arithPromoteShaped2 (/)

-- 'pi' is an error for the same reason as 'fromInteger' above; everything
-- else is promoted from the 'Floating' instance on 'Mixed'.
instance (FloatElt a, NumElt a, PrimElt a) => Floating (Shaped sh a) where
  pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal"
  exp = arithPromoteShaped exp
  log = arithPromoteShaped log
  sqrt = arithPromoteShaped sqrt
  (**) = arithPromoteShaped2 (**)
  logBase = arithPromoteShaped2 logBase
  sin = arithPromoteShaped sin
  cos = arithPromoteShaped cos
  tan = arithPromoteShaped tan
  asin = arithPromoteShaped asin
  acos = arithPromoteShaped acos
  atan = arithPromoteShaped atan
  sinh = arithPromoteShaped sinh
  cosh = arithPromoteShaped cosh
  tanh = arithPromoteShaped tanh
  asinh = arithPromoteShaped asinh
  acosh = arithPromoteShaped acosh
  atanh = arithPromoteShaped atanh
  log1p = arithPromoteShaped GHC.Float.log1p
  expm1 = arithPromoteShaped GHC.Float.expm1
  log1pexp = arithPromoteShaped GHC.Float.log1pexp
  log1mexp = arithPromoteShaped GHC.Float.log1mexp


-- | The shape of a 'Shaped' array as a value-level singleton, obtained by
-- converting the 'mshape' of the underlying 'Mixed' array with 'shCvtXS''.
sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
sshape (Shaped arr) = shCvtXS' (mshape arr)

-- | Index with a full index, delegating to 'mindex' on the underlying array.
sindex :: Elt a => Shaped sh a -> IIxS sh -> a
sindex (Shaped arr) idx = mindex arr (ixCvtSX idx)

-- | Recover the prefix @sh@ of a shape @sh ++ sh'@ by walking along an index
-- of type @IIxS sh@; the index components themselves are ignored, only the
-- length of the index is used.
shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh
shsTakeIx _ _ ZIS = ZSS
shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx

-- | Index with a prefix index, delegating to 'mindexPartial'. The 'castWith'
-- rewrites the underlying array type using 'lemMapJustApp' so that it has the
-- @MapJust sh1 ++ MapJust sh2@ form that 'mindexPartial' expects.
sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a
sindexPartial sarr@(Shaped arr) idx =
  Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2)
            (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr)
            (ixCvtSX idx))

-- | __WARNING__: All values returned from the function must have equal shape.
-- See the documentation of 'mgenerate' for more details.
sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a
sgenerate sh f = Shaped (mgenerate (shCvtSX sh) (f . ixCvtXS sh))

-- | See the documentation of 'mlift'.
slift :: forall sh1 sh2 a. Elt a
      => ShS sh2
      -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
      -> Shaped sh1 a -> Shaped sh2 a
slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShape (shCvtSX sh2)) f arr)

-- | See the documentation of 'mlift'.
slift2 :: forall sh1 sh2 sh3 a. Elt a
       => ShS sh3
       -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b)
       -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a
slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShape (shCvtSX sh3)) f arr1 arr2)

-- | 'msumOuter1P' on the underlying array; removes the outermost dimension.
ssumOuter1P :: forall sh n a. (Storable a, NumElt a)
            => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr)

-- | 'ssumOuter1P' for any 'PrimElt', going through 'stoPrimitive' and
-- 'sfromPrimitive'.
ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
           => Shaped (n : sh) a -> Shaped sh a
ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive

-- | 'mtranspose' on the underlying array. The pattern guards only bring type
-- equalities into scope (relating 'MapJust' to rank, 'TakeLen', 'DropLen',
-- permutation and append) so that the result type checks.
stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a)
           => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a
stranspose perm sarr@(Shaped arr)
  | Refl <- lemRankMapJust (sshape sarr)
  , Refl <- lemTakeLenMapJust perm (sshape sarr)
  , Refl <- lemDropLenMapJust perm (sshape sarr)
  , Refl <- lemPermuteMapJust perm (shsTakeLen perm (sshape sarr))
  , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh))
  = Shaped (mtranspose perm arr)

-- | 'mappend' on the underlying arrays; the outer dimensions are added.
sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a
sappend = coerce mappend

-- | Wrap a single element as a zero-dimensional array, via 'mscalar'.
sscalar :: Elt a => a -> Shaped '[] a
sscalar x = Shaped (mscalar x)

-- | 'mfromVectorP' with the given shape.
sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a)
sfromVectorP sh v = Shaped (mfromVectorP (shCvtSX sh) v)

-- | 'sfromVectorP' for any 'PrimElt'.
sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a
sfromVector sh v = sfromPrimitive (sfromVectorP sh v)

-- | 'mtoVectorP' on the underlying array.
stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a
stoVectorP = coerce mtoVectorP

-- | 'mtoVector' on the underlying array.
stoVector :: PrimElt a => Shaped sh a -> VS.Vector a
stoVector = coerce mtoVector

-- | Build an array from a non-empty list of sub-arrays. 'mfromListOuter'
-- yields an array whose outer dimension is statically unknown; 'mcast' then
-- casts that dimension to the statically known @n@.
--
-- NOTE(review): what happens when the list length differs from @n@ is decided
-- by 'mcast', which is not visible here -- confirm.
sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
sfromListOuter sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy $ mfromListOuter (coerce l))

-- | Like 'sfromListOuter', but builds a one-dimensional array from elements
-- using 'mfromList1'.
sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a
sfromList1 sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1

-- | Like 'sfromList1', but takes a plain list and uses 'mfromList1Prim'.
sfromList1Prim :: (PrimElt a, Elt a) => SNat n -> [a] -> Shaped '[n] a
sfromList1Prim sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1Prim

-- | 'mtoListOuter' on the underlying array.
stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
stoListOuter (Shaped arr) = coerce (mtoListOuter arr)

-- | The elements of a one-dimensional array: 'stoListOuter' followed by
-- 'sunScalar' on each zero-dimensional sub-array.
stoList1 :: Elt a => Shaped '[n] a -> [a]
stoList1 = map sunScalar . stoListOuter

-- | Build a one-dimensional array of primitives from a list, working directly
-- on 'XArray's: 'X.fromList1' produces an array with unknown length, which is
-- then 'X.cast' to the known length @n@.
sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a
sfromListPrim sn l
  | Refl <- lemAppNil @'[Just n]
  = let ssh = SUnknown () :!% ZKX
        xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l)
    in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr

-- | Build an array of the given shape from a flat list: the list is turned
-- into a one-dimensional array by 'mfromListPrim' and its 'XArray' is then
-- passed through 'X.reshape' to the target shape.
sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a
sfromListPrimLinear sh l =
  let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
  in Shaped $ fromPrimitive $ M_Primitive (shCvtSX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtSX sh) xarr)

-- | The single element of a zero-dimensional array ('sindex' at 'ZIS').
sunScalar :: Elt a => Shaped '[] a -> a
sunScalar arr = sindex arr ZIS

-- | 'mnest' on the underlying array: split the shape @sh ++ sh'@ into an
-- outer array of shape @sh@ containing arrays of shape @sh'@.
snest :: forall sh sh' a. Elt a => ShS sh -> Shaped (sh ++ sh') a -> Shaped sh (Shaped sh' a)
snest sh arr
  | Refl <- lemMapJustApp sh (Proxy @sh')
  = coerce (mnest (ssxFromShape (shCvtSX sh)) (coerce arr))

-- | Inverse direction of 'snest': unwrap the 'M_Nest' constructor and re-wrap
-- the contained array at the combined shape.
sunNest :: forall sh sh' a. Elt a => Shaped sh (Shaped sh' a) -> Shaped (sh ++ sh') a
sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr)))
  | Refl <- lemMapJustApp (sshape sarr) (Proxy @sh')
  = Shaped arr

-- | 'mrerankP' on the underlying array: the function on 'Shaped' arrays is
-- wrapped so that it operates on the underlying 'Mixed' arrays.
srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
         => ShS sh -> ShS sh2
         -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b))
         -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b)
srerankP sh sh2 f sarr@(Shaped arr)
  | Refl <- lemMapJustApp sh (Proxy @sh1)
  , Refl <- lemMapJustApp sh (Proxy @sh2)
  = Shaped (mrerankP (ssxFromShape (shxTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (ssxFromShape (shCvtSX sh))))
                     (shCvtSX sh2)
                     (\a -> let Shaped r = f (Shaped a) in r)
                     arr)

-- | 'srerankP' for any 'PrimElt', converting with 'stoPrimitive' and
-- 'sfromPrimitive' on the way in and out (also around the argument function).
srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
        => ShS sh -> ShS sh2
        -> (Shaped sh1 a -> Shaped sh2 b)
        -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b
srerank sh sh2 f (stoPrimitive -> arr) =
  sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr

-- | 'mreplicate' on the underlying array: prefix the shape with @sh@.
sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a
sreplicate sh (Shaped arr)
  | Refl <- lemMapJustApp sh (Proxy @sh')
  = Shaped (mreplicate (shCvtSX sh) arr)

-- | 'mreplicateScalP' with the given shape.
sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a)
sreplicateScalP sh x = Shaped (mreplicateScalP (shCvtSX sh) x)

-- | 'sreplicateScalP' for any 'PrimElt'.
sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a
sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x)

-- | Restrict the outermost dimension from @i + n + k@ to @n@ by lifting
-- @'X.slice' i n@ over the array; the remaining dimensions are unchanged.
sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a
sslice i n@SNat arr =
  let _ :$$ sh = sshape arr
  in slift (n :$$ sh) (\_ -> X.slice i n) arr

-- | Lift 'X.rev1' over the array; the shape is unchanged.
srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a
srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr

-- | 'mreshape' on the underlying array to the given target shape.
sreshape :: Elt a => ShS sh' -> Shaped sh a -> Shaped sh' a
sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr)

-- | 'miota' wrapped as a 'Shaped' array of length @n@.
siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a
siota sn = Shaped (miota sn)

-- | 'mtoXArrayPrimP' on the underlying array, with the shape converted to a
-- 'ShS'.
stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a)
stoXArrayPrimP (Shaped arr) = first shCvtXS' (mtoXArrayPrimP arr)

-- | 'mtoXArrayPrim' on the underlying array, with the shape converted to a
-- 'ShS'.
stoXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a)
stoXArrayPrim (Shaped arr) = first shCvtXS' (mtoXArrayPrim arr)

-- | 'mfromXArrayPrimP' with the static shape derived from the given 'ShS'.
sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a)
sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShape (shCvtSX sh)) arr)

-- | 'mfromXArrayPrim' with the static shape derived from the given 'ShS'.
sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a
sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShape (shCvtSX sh)) arr)

-- | 'fromPrimitive' on the underlying array.
sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a
sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr)

-- | 'toPrimitive' on the underlying array.
stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a)
stoPrimitive (Shaped arr) = Shaped (toPrimitive arr)

-- | Cast a 'Mixed' array to a 'Shaped' array of the same rank with the given
-- target shape, using 'mcast' on the whole shape (empty suffix).
--
-- NOTE(review): how a mismatch between the actual and the target shape is
-- handled is up to 'mcast', which is not visible here -- confirm.
mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
              => Mixed sh a -> ShS sh' -> Shaped sh' a
mcastToShaped arr targetsh
  | Refl <- lemAppNil @sh
  , Refl <- lemAppNil @(MapJust sh')
  , Refl <- lemRankMapJust targetsh
  = Shaped (mcast (ssxFromShape (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr)