From 8b59d8ef4ff97936f2a753d1ce345e0404c26b2b Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Thu, 30 May 2024 22:47:52 +0200 Subject: Clearer module purposes Thanks Mikolaj for discussion --- src/Data/Array/Mixed.hs | 335 -------------- src/Data/Array/Mixed/Permutation.hs | 5 + src/Data/Array/Mixed/XArray.hs | 331 +++++++++++++ src/Data/Array/Nested.hs | 13 +- src/Data/Array/Nested/Convert.hs | 28 -- src/Data/Array/Nested/Internal/Convert.hs | 28 ++ src/Data/Array/Nested/Internal/Lemmas.hs | 59 +++ src/Data/Array/Nested/Internal/Mixed.hs | 741 ++++++++++++++++++++++++++++++ src/Data/Array/Nested/Internal/Ranked.hs | 446 ++++++++++++++++++ src/Data/Array/Nested/Internal/Shape.hs | 467 +++++++++++++++++++ src/Data/Array/Nested/Internal/Shaped.hs | 379 +++++++++++++++ src/Data/Array/Nested/Lemmas.hs | 59 --- src/Data/Array/Nested/Mixed.hs | 741 ------------------------------ src/Data/Array/Nested/Ranked.hs | 446 ------------------ src/Data/Array/Nested/Shape.hs | 467 ------------------- src/Data/Array/Nested/Shaped.hs | 379 --------------- 16 files changed, 2462 insertions(+), 2462 deletions(-) delete mode 100644 src/Data/Array/Mixed.hs create mode 100644 src/Data/Array/Mixed/XArray.hs delete mode 100644 src/Data/Array/Nested/Convert.hs create mode 100644 src/Data/Array/Nested/Internal/Convert.hs create mode 100644 src/Data/Array/Nested/Internal/Lemmas.hs create mode 100644 src/Data/Array/Nested/Internal/Mixed.hs create mode 100644 src/Data/Array/Nested/Internal/Ranked.hs create mode 100644 src/Data/Array/Nested/Internal/Shape.hs create mode 100644 src/Data/Array/Nested/Internal/Shaped.hs delete mode 100644 src/Data/Array/Nested/Lemmas.hs delete mode 100644 src/Data/Array/Nested/Mixed.hs delete mode 100644 src/Data/Array/Nested/Ranked.hs delete mode 100644 src/Data/Array/Nested/Shape.hs delete mode 100644 src/Data/Array/Nested/Shaped.hs (limited to 'src/Data') diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs deleted file mode 100644 index 4a338a2..0000000 --- a/src/Data/Array/Mixed.hs +++ /dev/null @@ -1,335 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE NoStarIsType #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE StrictData #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Mixed where - -import Control.DeepSeq (NFData(..)) -import Data.Array.Ranked qualified as ORB -import Data.Array.RankedS qualified as S -import Data.Coerce -import Data.Kind -import Data.Proxy -import Data.Type.Equality -import Data.Type.Ord -import Data.Vector.Storable qualified as VS -import Foreign.Storable (Storable) -import GHC.Generics (Generic) -import GHC.TypeLits - -import Data.Array.Mixed.Internal.Arith -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape -import Data.Array.Mixed.Types - - -type XArray :: [Maybe Nat] -> Type -> Type -newtype XArray sh a = XArray (S.Array (Rank sh) a) - deriving (Show, Eq, Generic) - --- | Only on scalars, because lexicographical ordering is strange on multi-dimensional arrays. -deriving instance (Ord a, Storable a) => Ord (XArray '[] a) - -instance NFData a => NFData (XArray sh a) - - -shape :: forall sh a. StaticShX sh -> XArray sh a -> IShX sh -shape = \ssh (XArray arr) -> go ssh (S.shapeL arr) - where - go :: StaticShX sh' -> [Int] -> IShX sh' - go ZKX [] = ZSX - go (n :!% ssh) (i : l) = fromSMayNat (\_ -> SUnknown i) SKnown n :$% go ssh l - go _ _ = error "Invalid shapeL" - -fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a -fromVector sh v - | Dict <- lemKnownNatRank sh - = XArray (S.fromVector (shxToList sh) v) - -toVector :: Storable a => XArray sh a -> VS.Vector a -toVector (XArray arr) = S.toVector arr - -scalar :: Storable a => a -> XArray '[] a -scalar = XArray . S.scalar - --- | Will throw if the array does not have the casted-to shape. -cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2 - => StaticShX sh1 -> IShX sh2 -> StaticShX sh' - -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a -cast ssh1 sh2 ssh' (XArray arr) - | Refl <- lemRankApp ssh1 ssh' - , Refl <- lemRankApp (ssxFromShape sh2) ssh' - = let arrsh :: IShX sh1 - (arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr)) - in if shxToList arrsh == shxToList sh2 - then XArray arr - else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")" - -unScalar :: Storable a => XArray '[] a -> a -unScalar (XArray a) = S.unScalar a - -replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a -replicate sh ssh' (XArray arr) - | Dict <- lemKnownNatRankSSX ssh' - , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh) ssh') - , Refl <- lemRankApp (ssxFromShape sh) ssh' - = XArray (S.stretch (shxToList sh ++ S.shapeL arr) $ - S.reshape (map (const 1) (shxToList sh) ++ S.shapeL arr) $ - arr) - -replicateScal :: forall sh a. Storable a => IShX sh -> a -> XArray sh a -replicateScal sh x - | Dict <- lemKnownNatRank sh - = XArray (S.constant (shxToList sh) x) - -generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a -generate sh f = fromVector sh $ VS.generate (shxSize sh) (f . ixxFromLinear sh) - --- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a) --- generateM sh f | Dict <- lemKnownNatRank sh = --- XArray . S.fromVector (shxShapeL sh) --- <$> VS.generateM (shxSize sh) (f . ixxFromLinear sh) - -indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a -indexPartial (XArray arr) ZIX = XArray arr -indexPartial (XArray arr) (i :.% idx) = indexPartial (XArray (S.index arr i)) idx - -index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a -index xarr i - | Refl <- lemAppNil @sh - = let XArray arr' = indexPartial xarr i :: XArray '[] a - in S.unScalar arr' - -append :: forall n m sh a. Storable a - => StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a -append ssh (XArray a) (XArray b) - | Dict <- lemKnownNatRankSSX ssh - = XArray (S.append a b) - --- | If the prefix of the shape of the input array (@sh@) is empty (i.e. --- contains a zero), then there is no way to deduce the full shape of the output --- array (more precisely, the @sh2@ part): that could only come from calling --- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in --- this case; we choose to fill the shape with zeros wherever we cannot deduce --- what it should be. --- --- For example, if: --- --- @ --- arr :: XArray '[Just 3, Just 0, Just 4, Just 2, Nothing] Int -- of shape [3, 0, 4, 2, 21] --- f :: XArray '[Just 2, Nothing] Int -> XArray '[Just 5, Nothing, Just 17] Float --- @ --- --- then: --- --- @ --- rerank _ _ _ f arr :: XArray '[Just 3, Just 0, Just 4, Just 5, Nothing, Just 17] Float --- @ --- --- and this result will have shape @[3, 0, 4, 5, 0, 17]@. Note the second @0@ --- in this shape: we don't know if @f@ intended to return an array with shape 0 --- here (it probably didn't), but there is no better number to put here absent --- a subarray of the input to pass to @f@. --- --- In this particular case the fact that @sh@ is empty was evident from the --- type-level information, but the same situation occurs when @sh@ consists of --- @Nothing@s, and some of those happen to be zero at runtime. -rerank :: forall sh sh1 sh2 a b. - (Storable a, Storable b) - => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 - -> (XArray sh1 a -> XArray sh2 b) - -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b -rerank ssh ssh1 ssh2 f xarr@(XArray arr) - | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) - = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr) - in if any (== 0) (shxToList sh) - then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) - else case () of - () | Dict <- lemKnownNatRankSSX ssh - , Dict <- lemKnownNatRankSSX ssh2 - , Refl <- lemRankApp ssh ssh1 - , Refl <- lemRankApp ssh ssh2 - -> XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2) - (\a -> let XArray r = f (XArray a) in r) - arr) - -rerankTop :: forall sh1 sh2 sh a b. - (Storable a, Storable b) - => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh - -> (XArray sh1 a -> XArray sh2 b) - -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b -rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh - --- | The caveat about empty arrays at @rerank@ applies here too. -rerank2 :: forall sh sh1 sh2 a b c. - (Storable a, Storable b, Storable c) - => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 - -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c) - -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c -rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2) - | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) - = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1) - in if any (== 0) (shxToList sh) - then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) - else case () of - () | Dict <- lemKnownNatRankSSX ssh - , Dict <- lemKnownNatRankSSX ssh2 - , Refl <- lemRankApp ssh ssh1 - , Refl <- lemRankApp ssh ssh2 - -> XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2) - (\a b -> let XArray r = f (XArray a) (XArray b) in r) - arr1 arr2) - -class KnownNatList l where makeNatList :: Perm l -instance KnownNatList '[] where makeNatList = PNil -instance (KnownNat n, KnownNatList l) => KnownNatList (n : l) where makeNatList = natSing `PCons` makeNatList - --- | The list argument gives indices into the original dimension list. -transpose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh) - => StaticShX sh - -> Perm is - -> XArray sh a - -> XArray (PermutePrefix is sh) a -transpose ssh perm (XArray arr) - | Dict <- lemKnownNatRankSSX ssh - , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh) - , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm - , Refl <- lemRankDropLen ssh perm - = XArray (S.transpose (permToList' perm) arr) - --- | The list argument gives indices into the original dimension list. --- --- The permutation (the list) must have length <= @n@. If it is longer, this --- function throws. -transposeUntyped :: forall n sh a. - SNat n -> StaticShX sh -> [Int] - -> XArray (Replicate n Nothing ++ sh) a -> XArray (Replicate n Nothing ++ sh) a -transposeUntyped sn ssh perm (XArray arr) - | length perm <= fromSNat' sn - , Dict <- lemKnownNatRankSSX (ssxAppend (ssxReplicate sn) ssh) - = XArray (S.transpose perm arr) - | otherwise - = error "Data.Array.Mixed.transposeUntyped: Permutation longer than length of unshaped prefix of shape type" - -transpose2 :: forall sh1 sh2 a. - StaticShX sh1 -> StaticShX sh2 - -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a -transpose2 ssh1 ssh2 (XArray arr) - | Refl <- lemRankApp ssh1 ssh2 - , Refl <- lemRankApp ssh2 ssh1 - , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2) - , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1) - , Refl <- lemRankAppComm ssh1 ssh2 - , let n1 = ssxLength ssh1 - = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) - -sumFull :: (Storable a, NumElt a) => XArray sh a -> a -sumFull (XArray arr) = - S.unScalar $ - numEltSum1Inner (SNat @0) $ - S.fromVector [product (S.shapeL arr)] $ - S.toVector arr - -sumInner :: forall sh sh' a. (Storable a, NumElt a) - => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a -sumInner ssh ssh' arr - | Refl <- lemAppNil @sh - = let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) - sh'F = shxFlatten sh' :$% ZSX - ssh'F = ssxFromShape sh'F - - go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a - go (XArray arr') - | Refl <- lemRankApp ssh ssh'F - , let sn = listxLengthSNat (let StaticShX l = ssh in l) - = XArray (numEltSum1Inner sn arr') - - in go $ - transpose2 ssh'F ssh $ - reshapePartial ssh' ssh sh'F $ - transpose2 ssh ssh' $ - arr - -sumOuter :: forall sh sh' a. (Storable a, NumElt a) - => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a -sumOuter ssh ssh' arr - | Refl <- lemAppNil @sh - = let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) - shF = shxFlatten sh :$% ZSX - in sumInner ssh' (ssxFromShape shF) $ - transpose2 (ssxFromShape shF) ssh' $ - reshapePartial ssh ssh' shF $ - arr - -fromListOuter :: forall n sh a. Storable a - => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a -fromListOuter ssh l - | Dict <- lemKnownNatRankSSX ssh - = case ssh of - SKnown m :!% _ | fromSNat' m /= length l -> - error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ")" ++ - "does not match the type (" ++ show (fromSNat' m) ++ ")" - _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l))) - -toListOuter :: Storable a => XArray (n : sh) a -> [XArray sh a] -toListOuter (XArray arr) = - case S.shapeL arr of - 0 : _ -> [] - _ -> coerce (ORB.toList (S.unravel arr)) - -fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a -fromList1 ssh l = - let n = length l - in case ssh of - SKnown m :!% _ | fromSNat' m /= n -> - error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ")" ++ - "does not match the type (" ++ show (fromSNat' m) ++ ")" - _ -> XArray (S.fromVector [n] (VS.fromListN n l)) - -toList1 :: Storable a => XArray '[n] a -> [a] -toList1 (XArray arr) = S.toList arr - --- | Throws if the given shape is not, in fact, empty. -empty :: forall sh a. Storable a => IShX sh -> XArray sh a -empty sh - | Dict <- lemKnownNatRank sh - = XArray (S.constant (shxToList sh) - (error "Data.Array.Mixed.empty: shape was not empty")) - -slice :: SNat i -> SNat n -> XArray (Just (i + n + k) : sh) a -> XArray (Just n : sh) a -slice i n (XArray arr) = XArray (S.slice [(fromSNat' i, fromSNat' n)] arr) - -sliceU :: Int -> Int -> XArray (Nothing : sh) a -> XArray (Nothing : sh) a -sliceU i n (XArray arr) = XArray (S.slice [(i, n)] arr) - -rev1 :: XArray (n : sh) a -> XArray (n : sh) a -rev1 (XArray arr) = XArray (S.rev [0] arr) - --- | Throws if the given array and the target shape do not have the same number of elements. -reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a -reshape ssh1 sh2 (XArray arr) - | Dict <- lemKnownNatRankSSX ssh1 - , Dict <- lemKnownNatRank sh2 - = XArray (S.reshape (shxToList sh2) arr) - --- | Throws if the given array and the target shape do not have the same number of elements. -reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a -reshapePartial ssh1 ssh' sh2 (XArray arr) - | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh') - , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh2) ssh') - = XArray (S.reshape (shxToList sh2 ++ drop (ssxLength ssh1) (S.shapeL arr)) arr) - --- this was benchmarked to be (slightly) faster than S.iota, S.generate and S.fromVector(VS.enumFromTo). -iota :: (Enum a, Storable a) => SNat n -> XArray '[Just n] a -iota sn = XArray (S.fromVector [fromSNat' sn] (VS.fromListN (fromSNat' sn) [toEnum 0 .. toEnum (fromSNat' sn - 1)])) diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs index 83a5ee4..6ff3bdc 100644 --- a/src/Data/Array/Mixed/Permutation.hs +++ b/src/Data/Array/Mixed/Permutation.hs @@ -61,6 +61,11 @@ permToList (x `PCons` l) = TN.fromSNat x : permToList l permToList' :: Perm list -> [Int] permToList' = map fromIntegral . permToList +-- | Utility class for generating permutations from type class information. +class KnownPerm l where makePerm :: Perm l +instance KnownPerm '[] where makePerm = PNil +instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = natSing `PCons` makePerm + -- ** Applying permutations diff --git a/src/Data/Array/Mixed/XArray.hs b/src/Data/Array/Mixed/XArray.hs new file mode 100644 index 0000000..cc0a6a5 --- /dev/null +++ b/src/Data/Array/Mixed/XArray.hs @@ -0,0 +1,331 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Mixed.XArray where + +import Control.DeepSeq (NFData(..)) +import Data.Array.Ranked qualified as ORB +import Data.Array.RankedS qualified as S +import Data.Coerce +import Data.Kind +import Data.Proxy +import Data.Type.Equality +import Data.Type.Ord +import Data.Vector.Storable qualified as VS +import Foreign.Storable (Storable) +import GHC.Generics (Generic) +import GHC.TypeLits + +import Data.Array.Mixed.Internal.Arith +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Permutation +import Data.Array.Mixed.Shape +import Data.Array.Mixed.Types + + +type XArray :: [Maybe Nat] -> Type -> Type +newtype XArray sh a = XArray (S.Array (Rank sh) a) + deriving (Show, Eq, Generic) + +-- | Only on scalars, because lexicographical ordering is strange on multi-dimensional arrays. +deriving instance (Ord a, Storable a) => Ord (XArray '[] a) + +instance NFData a => NFData (XArray sh a) + + +shape :: forall sh a. StaticShX sh -> XArray sh a -> IShX sh +shape = \ssh (XArray arr) -> go ssh (S.shapeL arr) + where + go :: StaticShX sh' -> [Int] -> IShX sh' + go ZKX [] = ZSX + go (n :!% ssh) (i : l) = fromSMayNat (\_ -> SUnknown i) SKnown n :$% go ssh l + go _ _ = error "Invalid shapeL" + +fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a +fromVector sh v + | Dict <- lemKnownNatRank sh + = XArray (S.fromVector (shxToList sh) v) + +toVector :: Storable a => XArray sh a -> VS.Vector a +toVector (XArray arr) = S.toVector arr + +scalar :: Storable a => a -> XArray '[] a +scalar = XArray . S.scalar + +-- | Will throw if the array does not have the casted-to shape. +cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2 + => StaticShX sh1 -> IShX sh2 -> StaticShX sh' + -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a +cast ssh1 sh2 ssh' (XArray arr) + | Refl <- lemRankApp ssh1 ssh' + , Refl <- lemRankApp (ssxFromShape sh2) ssh' + = let arrsh :: IShX sh1 + (arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr)) + in if shxToList arrsh == shxToList sh2 + then XArray arr + else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")" + +unScalar :: Storable a => XArray '[] a -> a +unScalar (XArray a) = S.unScalar a + +replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a +replicate sh ssh' (XArray arr) + | Dict <- lemKnownNatRankSSX ssh' + , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh) ssh') + , Refl <- lemRankApp (ssxFromShape sh) ssh' + = XArray (S.stretch (shxToList sh ++ S.shapeL arr) $ + S.reshape (map (const 1) (shxToList sh) ++ S.shapeL arr) $ + arr) + +replicateScal :: forall sh a. Storable a => IShX sh -> a -> XArray sh a +replicateScal sh x + | Dict <- lemKnownNatRank sh + = XArray (S.constant (shxToList sh) x) + +generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a +generate sh f = fromVector sh $ VS.generate (shxSize sh) (f . ixxFromLinear sh) + +-- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a) +-- generateM sh f | Dict <- lemKnownNatRank sh = +-- XArray . S.fromVector (shxShapeL sh) +-- <$> VS.generateM (shxSize sh) (f . ixxFromLinear sh) + +indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a +indexPartial (XArray arr) ZIX = XArray arr +indexPartial (XArray arr) (i :.% idx) = indexPartial (XArray (S.index arr i)) idx + +index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a +index xarr i + | Refl <- lemAppNil @sh + = let XArray arr' = indexPartial xarr i :: XArray '[] a + in S.unScalar arr' + +append :: forall n m sh a. Storable a + => StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a +append ssh (XArray a) (XArray b) + | Dict <- lemKnownNatRankSSX ssh + = XArray (S.append a b) + +-- | If the prefix of the shape of the input array (@sh@) is empty (i.e. +-- contains a zero), then there is no way to deduce the full shape of the output +-- array (more precisely, the @sh2@ part): that could only come from calling +-- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in +-- this case; we choose to fill the shape with zeros wherever we cannot deduce +-- what it should be. +-- +-- For example, if: +-- +-- @ +-- arr :: XArray '[Just 3, Just 0, Just 4, Just 2, Nothing] Int -- of shape [3, 0, 4, 2, 21] +-- f :: XArray '[Just 2, Nothing] Int -> XArray '[Just 5, Nothing, Just 17] Float +-- @ +-- +-- then: +-- +-- @ +-- rerank _ _ _ f arr :: XArray '[Just 3, Just 0, Just 4, Just 5, Nothing, Just 17] Float +-- @ +-- +-- and this result will have shape @[3, 0, 4, 5, 0, 17]@. Note the second @0@ +-- in this shape: we don't know if @f@ intended to return an array with shape 0 +-- here (it probably didn't), but there is no better number to put here absent +-- a subarray of the input to pass to @f@. +-- +-- In this particular case the fact that @sh@ is empty was evident from the +-- type-level information, but the same situation occurs when @sh@ consists of +-- @Nothing@s, and some of those happen to be zero at runtime. +rerank :: forall sh sh1 sh2 a b. + (Storable a, Storable b) + => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 + -> (XArray sh1 a -> XArray sh2 b) + -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b +rerank ssh ssh1 ssh2 f xarr@(XArray arr) + | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) + = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr) + in if any (== 0) (shxToList sh) + then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) + else case () of + () | Dict <- lemKnownNatRankSSX ssh + , Dict <- lemKnownNatRankSSX ssh2 + , Refl <- lemRankApp ssh ssh1 + , Refl <- lemRankApp ssh ssh2 + -> XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2) + (\a -> let XArray r = f (XArray a) in r) + arr) + +rerankTop :: forall sh1 sh2 sh a b. + (Storable a, Storable b) + => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh + -> (XArray sh1 a -> XArray sh2 b) + -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b +rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh + +-- | The caveat about empty arrays at @rerank@ applies here too. +rerank2 :: forall sh sh1 sh2 a b c. + (Storable a, Storable b, Storable c) + => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 + -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c) + -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c +rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2) + | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) + = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1) + in if any (== 0) (shxToList sh) + then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) + else case () of + () | Dict <- lemKnownNatRankSSX ssh + , Dict <- lemKnownNatRankSSX ssh2 + , Refl <- lemRankApp ssh ssh1 + , Refl <- lemRankApp ssh ssh2 + -> XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2) + (\a b -> let XArray r = f (XArray a) (XArray b) in r) + arr1 arr2) + +-- | The list argument gives indices into the original dimension list. +transpose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh) + => StaticShX sh + -> Perm is + -> XArray sh a + -> XArray (PermutePrefix is sh) a +transpose ssh perm (XArray arr) + | Dict <- lemKnownNatRankSSX ssh + , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh) + , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm + , Refl <- lemRankDropLen ssh perm + = XArray (S.transpose (permToList' perm) arr) + +-- | The list argument gives indices into the original dimension list. +-- +-- The permutation (the list) must have length <= @n@. If it is longer, this +-- function throws. +transposeUntyped :: forall n sh a. + SNat n -> StaticShX sh -> [Int] + -> XArray (Replicate n Nothing ++ sh) a -> XArray (Replicate n Nothing ++ sh) a +transposeUntyped sn ssh perm (XArray arr) + | length perm <= fromSNat' sn + , Dict <- lemKnownNatRankSSX (ssxAppend (ssxReplicate sn) ssh) + = XArray (S.transpose perm arr) + | otherwise + = error "Data.Array.Mixed.transposeUntyped: Permutation longer than length of unshaped prefix of shape type" + +transpose2 :: forall sh1 sh2 a. + StaticShX sh1 -> StaticShX sh2 + -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a +transpose2 ssh1 ssh2 (XArray arr) + | Refl <- lemRankApp ssh1 ssh2 + , Refl <- lemRankApp ssh2 ssh1 + , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2) + , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1) + , Refl <- lemRankAppComm ssh1 ssh2 + , let n1 = ssxLength ssh1 + = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) + +sumFull :: (Storable a, NumElt a) => XArray sh a -> a +sumFull (XArray arr) = + S.unScalar $ + numEltSum1Inner (SNat @0) $ + S.fromVector [product (S.shapeL arr)] $ + S.toVector arr + +sumInner :: forall sh sh' a. (Storable a, NumElt a) + => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a +sumInner ssh ssh' arr + | Refl <- lemAppNil @sh + = let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) + sh'F = shxFlatten sh' :$% ZSX + ssh'F = ssxFromShape sh'F + + go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a + go (XArray arr') + | Refl <- lemRankApp ssh ssh'F + , let sn = listxLengthSNat (let StaticShX l = ssh in l) + = XArray (numEltSum1Inner sn arr') + + in go $ + transpose2 ssh'F ssh $ + reshapePartial ssh' ssh sh'F $ + transpose2 ssh ssh' $ + arr + +sumOuter :: forall sh sh' a. (Storable a, NumElt a) + => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a +sumOuter ssh ssh' arr + | Refl <- lemAppNil @sh + = let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) + shF = shxFlatten sh :$% ZSX + in sumInner ssh' (ssxFromShape shF) $ + transpose2 (ssxFromShape shF) ssh' $ + reshapePartial ssh ssh' shF $ + arr + +fromListOuter :: forall n sh a. Storable a + => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a +fromListOuter ssh l + | Dict <- lemKnownNatRankSSX ssh + = case ssh of + SKnown m :!% _ | fromSNat' m /= length l -> + error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ")" ++ + "does not match the type (" ++ show (fromSNat' m) ++ ")" + _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l))) + +toListOuter :: Storable a => XArray (n : sh) a -> [XArray sh a] +toListOuter (XArray arr) = + case S.shapeL arr of + 0 : _ -> [] + _ -> coerce (ORB.toList (S.unravel arr)) + +fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a +fromList1 ssh l = + let n = length l + in case ssh of + SKnown m :!% _ | fromSNat' m /= n -> + error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ")" ++ + "does not match the type (" ++ show (fromSNat' m) ++ ")" + _ -> XArray (S.fromVector [n] (VS.fromListN n l)) + +toList1 :: Storable a => XArray '[n] a -> [a] +toList1 (XArray arr) = S.toList arr + +-- | Throws if the given shape is not, in fact, empty. +empty :: forall sh a. Storable a => IShX sh -> XArray sh a +empty sh + | Dict <- lemKnownNatRank sh + = XArray (S.constant (shxToList sh) + (error "Data.Array.Mixed.empty: shape was not empty")) + +slice :: SNat i -> SNat n -> XArray (Just (i + n + k) : sh) a -> XArray (Just n : sh) a +slice i n (XArray arr) = XArray (S.slice [(fromSNat' i, fromSNat' n)] arr) + +sliceU :: Int -> Int -> XArray (Nothing : sh) a -> XArray (Nothing : sh) a +sliceU i n (XArray arr) = XArray (S.slice [(i, n)] arr) + +rev1 :: XArray (n : sh) a -> XArray (n : sh) a +rev1 (XArray arr) = XArray (S.rev [0] arr) + +-- | Throws if the given array and the target shape do not have the same number of elements. +reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a +reshape ssh1 sh2 (XArray arr) + | Dict <- lemKnownNatRankSSX ssh1 + , Dict <- lemKnownNatRank sh2 + = XArray (S.reshape (shxToList sh2) arr) + +-- | Throws if the given array and the target shape do not have the same number of elements. +reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a +reshapePartial ssh1 ssh' sh2 (XArray arr) + | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh') + , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh2) ssh') + = XArray (S.reshape (shxToList sh2 ++ drop (ssxLength ssh1) (S.shapeL arr)) arr) + +-- this was benchmarked to be (slightly) faster than S.iota, S.generate and S.fromVector(VS.enumFromTo). +iota :: (Enum a, Storable a) => SNat n -> XArray '[Just n] a +iota sn = XArray (S.fromVector [fromSNat' sn] (VS.fromListN (fromSNat' sn) [toEnum 0 .. toEnum (fromSNat' sn - 1)])) diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index c982b4d..4d2d616 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -61,21 +61,20 @@ module Data.Array.Nested ( pattern SZ, pattern SS, Perm(..), IsPermutation, - KnownNatList(..), + KnownPerm(..), NumElt, FloatElt, ) where import Prelude hiding (mappend) -import Data.Array.Mixed import Data.Array.Mixed.Internal.Arith import Data.Array.Mixed.Permutation import Data.Array.Mixed.Shape import Data.Array.Mixed.Types -import Data.Array.Nested.Convert -import Data.Array.Nested.Mixed -import Data.Array.Nested.Ranked -import Data.Array.Nested.Shape -import Data.Array.Nested.Shaped +import Data.Array.Nested.Internal.Convert +import Data.Array.Nested.Internal.Mixed +import Data.Array.Nested.Internal.Ranked +import Data.Array.Nested.Internal.Shape +import Data.Array.Nested.Internal.Shaped import Foreign.Storable import GHC.TypeLits diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs deleted file mode 100644 index cb22c32..0000000 --- a/src/Data/Array/Nested/Convert.hs +++ /dev/null @@ -1,28 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE TypeOperators #-} -module Data.Array.Nested.Convert where - -import Data.Type.Equality - -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Shape -import Data.Array.Nested.Lemmas -import Data.Array.Nested.Mixed -import Data.Array.Nested.Ranked -import Data.Array.Nested.Shape -import Data.Array.Nested.Shaped - - -stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a -stoRanked sarr@(Shaped arr) - | Refl <- lemRankMapJust (sshape sarr) - = mtoRanked arr - -rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a -rcastToShaped (Ranked arr) targetsh - | Refl <- lemRankReplicate (shxRank (shCvtSX targetsh)) - , Refl <- lemRankMapJust targetsh - = mcastToShaped arr targetsh diff --git a/src/Data/Array/Nested/Internal/Convert.hs b/src/Data/Array/Nested/Internal/Convert.hs new file mode 100644 index 0000000..e101981 --- /dev/null +++ b/src/Data/Array/Nested/Internal/Convert.hs @@ -0,0 +1,28 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE TypeOperators #-} +module Data.Array.Nested.Internal.Convert where + +import Data.Type.Equality + +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Shape +import Data.Array.Nested.Internal.Lemmas +import Data.Array.Nested.Internal.Mixed +import Data.Array.Nested.Internal.Ranked +import Data.Array.Nested.Internal.Shape +import Data.Array.Nested.Internal.Shaped + + +stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a +stoRanked sarr@(Shaped arr) + | Refl <- lemRankMapJust (sshape sarr) + = mtoRanked arr + +rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a +rcastToShaped (Ranked arr) targetsh + | Refl <- lemRankReplicate (shxRank (shCvtSX targetsh)) + , Refl <- lemRankMapJust targetsh + = mcastToShaped arr targetsh diff --git a/src/Data/Array/Nested/Internal/Lemmas.hs b/src/Data/Array/Nested/Internal/Lemmas.hs new file mode 100644 index 0000000..5ce5373 --- /dev/null +++ b/src/Data/Array/Nested/Internal/Lemmas.hs @@ -0,0 +1,59 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +module Data.Array.Nested.Internal.Lemmas where + +import Data.Proxy +import Data.Type.Equality +import GHC.TypeLits + +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Permutation +import Data.Array.Mixed.Shape +import Data.Array.Mixed.Types +import Data.Array.Nested.Internal.Shape + + +lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh +lemRankMapJust ZSS = Refl +lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl + +lemMapJustApp :: ShS sh1 -> Proxy sh2 + -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 +lemMapJustApp ZSS _ = Refl +lemMapJustApp (_ :$$ sh) p | Refl <- lemMapJustApp sh p = Refl + +lemMapJustTakeLen :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh) +lemMapJustTakeLen PNil _ = Refl +lemMapJustTakeLen (_ `PCons` is) (_ :$$ sh) | Refl <- lemMapJustTakeLen is sh = Refl +lemMapJustTakeLen (_ `PCons` _) ZSS = error "TakeLen of empty" + +lemMapJustDropLen :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh) +lemMapJustDropLen PNil _ = Refl +lemMapJustDropLen (_ `PCons` is) (_ :$$ sh) | Refl <- lemMapJustDropLen is sh = Refl +lemMapJustDropLen (_ `PCons` _) ZSS = error "DropLen of empty" + +lemMapJustIndex :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh) +lemMapJustIndex SZ (_ :$$ _) = Refl +lemMapJustIndex (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) + | Refl <- lemMapJustIndex i sh + , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) + , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + = Refl +lemMapJustIndex _ ZSS = error "Index of empty" + +lemMapJustPermute :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh) +lemMapJustPermute PNil _ = Refl +lemMapJustPermute (i `PCons` is) sh + | Refl <- lemMapJustPermute is sh + , Refl <- lemMapJustIndex i sh + = Refl + +lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh) +lemKnownMapJust _ = lemKnownShX (go (knownShS @sh)) + where + go :: ShS sh' -> StaticShX (MapJust sh') + go ZSS = ZKX + go (n :$$ sh) = SKnown n :!% go sh diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs new file mode 100644 index 0000000..98871d5 --- /dev/null +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -0,0 +1,741 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DefaultSignatures #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +module Data.Array.Nested.Internal.Mixed where + +import Control.DeepSeq (NFData) +import Control.Monad (forM_, when) +import Control.Monad.ST +import Data.Array.RankedS qualified as S +import Data.Coerce +import Data.Foldable (toList) +import Data.Int +import Data.Kind (Type) +import Data.List.NonEmpty (NonEmpty(..)) +import Data.Proxy +import Data.Type.Equality +import Data.Vector.Storable qualified as VS +import Data.Vector.Storable.Mutable qualified as VSM +import Foreign.C.Types (CInt) +import Foreign.Storable (Storable) +import GHC.Float qualified (log1p, expm1, log1pexp, log1mexp) +import GHC.Generics (Generic) +import GHC.TypeLits + +import Data.Array.Mixed.XArray (XArray(..)) +import Data.Array.Mixed.XArray qualified as X +import Data.Array.Mixed.Internal.Arith +import Data.Array.Mixed.Shape +import Data.Array.Mixed.Types +import Data.Array.Mixed.Permutation +import Data.Array.Mixed.Lemmas + + +-- Invariant in the API +-- ==================== +-- +-- In the underlying XArray, there is some shape for elements of an empty +-- array. For example, for this array: +-- +-- arr :: Ranked I3 (Ranked I2 Int, Ranked I1 Float) +-- rshape arr == 0 :.: 0 :.: 0 :.: ZIR +-- +-- the two underlying XArrays have a shape, and those shapes might be anything. +-- The invariant is that these element shapes are unobservable in the API. +-- (This is possible because you ought to not be able to get to such an element +-- without indexing out of bounds.) +-- +-- Note, though, that the converse situation may arise: the outer array might +-- be nonempty but then the inner arrays might. This is fine, an invariant only +-- applies if the _outer_ array is empty. +-- +-- TODO: can we enforce that the elements of an empty (nested) array have +-- all-zero shape? +-- -> no, because mlift and also any kind of internals probing from outsiders + + +-- Primitive element types +-- ======================= +-- +-- There are a few primitive element types; arrays containing elements of such +-- type are a newtype over an XArray, which it itself a newtype over a Vector. +-- Unfortunately, the setup of the library requires us to list these primitive +-- element types multiple times; to aid in extending the list, all these lists +-- have been marked with [PRIMITIVE ELEMENT TYPES LIST]. + + +-- | Wrapper type used as a tag to attach instances on. The instances on arrays +-- of @'Primitive' a@ are more polymorphic than the direct instances for arrays +-- of scalars; this means that if @orthotope@ supports an element type @T@ that +-- this library does not (directly), it may just work if you use an array of +-- @'Primitive' T@ instead. +newtype Primitive a = Primitive a + +-- | Element types that are primitive; arrays of these types are just a newtype +-- wrapper over an array. +class Storable a => PrimElt a where + fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a + toPrimitive :: Mixed sh a -> Mixed sh (Primitive a) + + default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a + fromPrimitive = coerce + + default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a) + toPrimitive = coerce + +-- [PRIMITIVE ELEMENT TYPES LIST] +instance PrimElt Int +instance PrimElt Int64 +instance PrimElt Int32 +instance PrimElt CInt +instance PrimElt Float +instance PrimElt Double +instance PrimElt () + + +-- | Mixed arrays: some dimensions are size-typed, some are not. Distributes +-- over product-typed elements using a data family so that the full array is +-- always in struct-of-arrays format. +-- +-- Built on top of 'XArray' which is built on top of @orthotope@, meaning that +-- dimension permutations (e.g. 'mtranspose') are typically free. +-- +-- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type +-- class. +type Mixed :: [Maybe Nat] -> Type -> Type +data family Mixed sh a +-- NOTE: When opening up the Mixed abstraction, you might see dimension sizes +-- that you're not supposed to see. In particular, you might see (nonempty) +-- sizes of the elements of an empty array, which is information that should +-- ostensibly not exist; the full array is still empty. + +data instance Mixed sh (Primitive a) = M_Primitive !(IShX sh) !(XArray sh a) + deriving (Show, Eq, Generic) + +-- | Only on scalars, because lexicographical ordering is strange on multi-dimensional arrays. +deriving instance (Ord a, Storable a) => Ord (Mixed '[] (Primitive a)) + +instance NFData a => NFData (Mixed sh (Primitive a)) + +-- [PRIMITIVE ELEMENT TYPES LIST] +newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Show, Eq, Generic) +newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Show, Eq, Generic) +newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Show, Eq, Generic) +newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Show, Eq, Generic) +newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Show, Eq, Generic) +newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Show, Eq, Generic) +newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Show, Eq, Generic) -- no content, orthotope optimises this (via Vector) +-- etc. + +-- [PRIMITIVE ELEMENT TYPES LIST] +deriving instance Ord (Mixed '[] Int) ; instance NFData (Mixed sh Int) +deriving instance Ord (Mixed '[] Int64) ; instance NFData (Mixed sh Int64) +deriving instance Ord (Mixed '[] Int32) ; instance NFData (Mixed sh Int32) +deriving instance Ord (Mixed '[] CInt) ; instance NFData (Mixed sh CInt) +deriving instance Ord (Mixed '[] Float) ; instance NFData (Mixed sh Float) +deriving instance Ord (Mixed '[] Double) ; instance NFData (Mixed sh Double) +deriving instance Ord (Mixed '[] ()) ; instance NFData (Mixed sh ()) + +data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) deriving (Generic) +deriving instance (Show (Mixed sh a), Show (Mixed sh b)) => Show (Mixed sh (a, b)) +instance (NFData (Mixed sh a), NFData (Mixed sh b)) => NFData (Mixed sh (a, b)) +-- etc., larger tuples (perhaps use generics to allow arbitrary product types) + +data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(IShX sh1) !(Mixed (sh1 ++ sh2) a) deriving (Generic) +deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a)) +instance NFData (Mixed (sh1 ++ sh2) a) => NFData (Mixed sh1 (Mixed sh2 a)) + + +-- | Internal helper data family mirroring 'Mixed' that consists of mutable +-- vectors instead of 'XArray's. +type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type +data family MixedVecs s sh a + +newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a) + +-- [PRIMITIVE ELEMENT TYPES LIST] +newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int) +newtype instance MixedVecs s sh Int64 = MV_Int64 (VS.MVector s Int64) +newtype instance MixedVecs s sh Int32 = MV_Int32 (VS.MVector s Int32) +newtype instance MixedVecs s sh CInt = MV_CInt (VS.MVector s CInt) +newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double) +newtype instance MixedVecs s sh Float = MV_Float (VS.MVector s Float) +newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ()) -- no content, MVector optimises this +-- etc. + +data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b) +-- etc. + +data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a) + + +mliftNumElt1 :: PrimElt a => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a) -> Mixed sh a -> Mixed sh a +mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr)) + +mliftNumElt2 :: PrimElt a + => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a -> S.Array (Rank sh) a) + -> Mixed sh a -> Mixed sh a -> Mixed sh a +mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_Primitive sh2 (XArray arr2)) + | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (shxRank sh1) arr1 arr2)) + | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2 + +instance (NumElt a, PrimElt a) => Num (Mixed sh a) where + (+) = mliftNumElt2 numEltAdd + (-) = mliftNumElt2 numEltSub + (*) = mliftNumElt2 numEltMul + negate = mliftNumElt1 numEltNeg + abs = mliftNumElt1 numEltAbs + signum = mliftNumElt1 numEltSignum + fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate" + +instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Mixed sh a) where + fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate" + recip = mliftNumElt1 floatEltRecip + (/) = mliftNumElt2 floatEltDiv + +instance (FloatElt a, NumElt a, PrimElt a) => Floating (Mixed sh a) where + pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate" + exp = mliftNumElt1 floatEltExp + log = mliftNumElt1 floatEltLog + sqrt = mliftNumElt1 floatEltSqrt + + (**) = mliftNumElt2 floatEltPow + logBase = mliftNumElt2 floatEltLogbase + + sin = mliftNumElt1 floatEltSin + cos = mliftNumElt1 floatEltCos + tan = mliftNumElt1 floatEltTan + asin = mliftNumElt1 floatEltAsin + acos = mliftNumElt1 floatEltAcos + atan = mliftNumElt1 floatEltAtan + sinh = mliftNumElt1 floatEltSinh + cosh = mliftNumElt1 floatEltCosh + tanh = mliftNumElt1 floatEltTanh + asinh = mliftNumElt1 floatEltAsinh + acosh = mliftNumElt1 floatEltAcosh + atanh = mliftNumElt1 floatEltAtanh + log1p = mliftNumElt1 floatEltLog1p + expm1 = mliftNumElt1 floatEltExpm1 + log1pexp = mliftNumElt1 floatEltLog1pexp + log1mexp = mliftNumElt1 floatEltLog1mexp + + +-- | Allowable element types in a mixed array, and by extension in a 'Ranked' or +-- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive' +-- a@; see the documentation for 'Primitive' for more details. +class Elt a where + -- ====== PUBLIC METHODS ====== -- + + mshape :: Mixed sh a -> IShX sh + mindex :: Mixed sh a -> IIxX sh -> a + mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a + mscalar :: a -> Mixed '[] a + + -- | All arrays in the list, even subarrays inside @a@, must have the same + -- shape; if they do not, a runtime error will be thrown. See the + -- documentation of 'mgenerate' for more information about this restriction. + -- Furthermore, the length of the list must correspond with @n@: if @n@ is + -- @Just m@ and @m@ does not equal the length of the list, a runtime error is + -- thrown. + -- + -- Consider also 'mfromListPrim', which can avoid intermediate arrays. + mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a + + mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a] + + -- | Note: this library makes no particular guarantees about the shapes of + -- arrays "inside" an empty array. With 'mlift' and 'mlift2' you can see the + -- full 'XArray' and as such you can distinguish different empty arrays by + -- the "shapes" of their elements. This information is meaningless, so you + -- should not use it. + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + -> Mixed sh1 a -> Mixed sh2 a + + -- | See the documentation for 'mlift'. + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) + -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a + + mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 + => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a + + mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) + => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a + + -- ====== PRIVATE METHODS ====== -- + + -- | Tree giving the shape of every array component. + type ShapeTree a + + mshapeTree :: a -> ShapeTree a + + mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool + + mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool + + mshowShapeTree :: Proxy a -> ShapeTree a -> String + + -- | Given the shape of this array, an index and a value, write the value at + -- that index in the vectors. + mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () + + -- | Given the shape of this array, an index and a value, write the value at + -- that index in the vectors. + mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () + + -- | Given the shape of this array, finalise the vectors into 'XArray's. + mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) + + +-- | Element types for which we have evidence of the (static part of the) shape +-- in a type class constraint. Compare the instance contexts of the instances +-- of this class with those of 'Elt': some instances have an additional +-- "known-shape" constraint. +-- +-- This class is (currently) only required for 'mgenerate' / 'rgenerate' / +-- 'sgenerate'. +class Elt a => KnownElt a where + -- | Create an empty array. The given shape must have size zero; this may or may not be checked. + memptyArray :: IShX sh -> Mixed sh a + + -- | Create uninitialised vectors for this array type, given the shape of + -- this vector and an example for the contents. + mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a) + + mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a) + + +-- Arrays of scalars are basically just arrays of scalars. +instance Storable a => Elt (Primitive a) where + mshape (M_Primitive sh _) = sh + mindex (M_Primitive _ a) i = Primitive (X.index a i) + mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i) + mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) + mfromListOuter l@(arr1 :| _) = + let sh = SUnknown (length l) :$% mshape arr1 + in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l))) + mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr) + + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) + -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) + mlift ssh2 f (M_Primitive _ a) + | Refl <- lemAppNil @sh1 + , Refl <- lemAppNil @sh2 + , let result = f ZKX a + = M_Primitive (X.shape ssh2 result) result + + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a) + -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a) + mlift2 ssh3 f (M_Primitive _ a) (M_Primitive _ b) + | Refl <- lemAppNil @sh1 + , Refl <- lemAppNil @sh2 + , Refl <- lemAppNil @sh3 + , let result = f ZKX a b + = M_Primitive (X.shape ssh3 result) result + + mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 + => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) + mcast ssh1 sh2 _ (M_Primitive sh1' arr) = + let (_, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' + in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr) + + mtranspose perm (M_Primitive sh arr) = + M_Primitive (shxPermutePrefix perm sh) + (X.transpose (ssxFromShape sh) perm arr) + + type ShapeTree (Primitive a) = () + mshapeTree _ = () + mshapeTreeEq _ () () = True + mshapeTreeEmpty _ () = False + mshowShapeTree _ () = "()" + mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x + + -- TODO: this use of toVector is suboptimal + mvecsWritePartial + :: forall sh' sh s. + IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () + mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do + let arrsh = X.shape (ssxFromShape sh') arr + offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) + VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) + + mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v + +-- [PRIMITIVE ELEMENT TYPES LIST] +deriving via Primitive Int instance Elt Int +deriving via Primitive Int64 instance Elt Int64 +deriving via Primitive Int32 instance Elt Int32 +deriving via Primitive CInt instance Elt CInt +deriving via Primitive Double instance Elt Double +deriving via Primitive Float instance Elt Float +deriving via Primitive () instance Elt () + +instance Storable a => KnownElt (Primitive a) where + memptyArray sh = M_Primitive sh (X.empty sh) + mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh) + mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 + +-- [PRIMITIVE ELEMENT TYPES LIST] +deriving via Primitive Int instance KnownElt Int +deriving via Primitive Int64 instance KnownElt Int64 +deriving via Primitive Int32 instance KnownElt Int32 +deriving via Primitive CInt instance KnownElt CInt +deriving via Primitive Double instance KnownElt Double +deriving via Primitive Float instance KnownElt Float +deriving via Primitive () instance KnownElt () + +-- Arrays of pairs are pairs of arrays. +instance (Elt a, Elt b) => Elt (a, b) where + mshape (M_Tup2 a _) = mshape a + mindex (M_Tup2 a b) i = (mindex a i, mindex b i) + mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) + mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) + mfromListOuter l = + M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l)) + (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l)) + mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b) + mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b) + mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y) + + mcast ssh1 sh2 psh' (M_Tup2 a b) = + M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b) + + mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b) + + type ShapeTree (a, b) = (ShapeTree a, ShapeTree b) + mshapeTree (x, y) = (mshapeTree x, mshapeTree y) + mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' + mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 + mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" + mvecsWrite sh i (x, y) (MV_Tup2 a b) = do + mvecsWrite sh i x a + mvecsWrite sh i y b + mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do + mvecsWritePartial sh i x a + mvecsWritePartial sh i y b + mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b + +instance (KnownElt a, KnownElt b) => KnownElt (a, b) where + memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh) + mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y + mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b) + +-- Arrays of arrays are just arrays, but with more dimensions. +instance Elt a => Elt (Mixed sh' a) where + -- TODO: this is quadratic in the nesting depth because it repeatedly + -- truncates the shape vector to one a little shorter. Fix with a + -- moverlongShape method, a prefix of which is mshape. + mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh + mshape (M_Nest sh arr) + = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr)) + + mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a + mindex (M_Nest _ arr) i = mindexPartial arr i + + mindexPartial :: forall sh1 sh2. + Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) + mindexPartial (M_Nest sh arr) i + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') + = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) + + mscalar = M_Nest ZSX + + mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) + mfromListOuter l@(arr :| _) = + M_Nest (SUnknown (length l) :$% mshape arr) + (mfromListOuter ((\(M_Nest _ a) -> a) <$> l)) + + mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr) + + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) + -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) + mlift ssh2 f (M_Nest sh1 arr) = + let result = mlift (ssxAppend ssh2 ssh') f' arr + (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result) + in M_Nest sh2 result + where + ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr))) + + f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b + f' sshT + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) + , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) + = f (ssxAppend ssh' sshT) + + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) + -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a) + mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) = + let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2 + (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result) + in M_Nest sh3 result + where + ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1))) + + f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b + f' sshT + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) + , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) + , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) + = f (ssxAppend ssh' sshT) + + mcast :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2 + => StaticShX sh1 -> IShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a) + mcast ssh1 sh2 _ (M_Nest sh1T arr) + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') + , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') + = let (_, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T + in M_Nest (shxAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr) + + mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) + => Perm is -> Mixed sh (Mixed sh' a) + -> Mixed (PermutePrefix is sh) (Mixed sh' a) + mtranspose perm (M_Nest sh arr) + | let sh' = shxDropSh @sh @sh' (mshape arr) sh + , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh') + , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh')) + , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') + , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') + , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') + = M_Nest (shxPermutePrefix perm sh) + (mtranspose perm arr) + + type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a) + + mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) + mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr))))) + + mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 + + mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + + mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + + mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs + + mvecsWritePartial :: forall sh1 sh2 s. + IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) + -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) + -> ST s () + mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') + = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs + + mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs + +instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where + memptyArray sh = M_Nest sh (memptyArray (shxAppend sh (shxCompleteZeros (knownShX @sh')))) + + mvecsUnsafeNew sh example + | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) + | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh'))) + where + sh' = mshape example + + mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) + + +-- | Create an array given a size and a function that computes the element at a +-- given index. +-- +-- __WARNING__: It is required that every @a@ returned by the argument to +-- 'mgenerate' has the same shape. For example, the following will throw a +-- runtime error: +-- +-- > foo :: Mixed [Nothing] (Mixed [Nothing] Double) +-- > foo = mgenerate (10 :.: ZIR) $ \(i :.: ZIR) -> +-- > mgenerate (i :.: ZIR) $ \(j :.: ZIR) -> +-- > ... +-- +-- because the size of the inner 'mgenerate' is not always the same (it depends +-- on @i@). Nested arrays in @ox-arrays@ are always stored fully flattened, so +-- the entire hierarchy (after distributing out tuples) must be a rectangular +-- array. The type of 'mgenerate' allows this requirement to be broken very +-- easily, hence the runtime check. +mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a +mgenerate sh f = case shxEnum sh of + [] -> memptyArray sh + firstidx : restidxs -> + let firstelem = f (ixxZero' sh) + shapetree = mshapeTree firstelem + in if mshapeTreeEmpty (Proxy @a) shapetree + then memptyArray sh + else runST $ do + vecs <- mvecsUnsafeNew sh firstelem + mvecsWrite sh firstidx firstelem vecs + -- TODO: This is likely fine if @a@ is big, but if @a@ is a + -- scalar this array copying inefficient. Should improve this. + forM_ restidxs $ \idx -> do + let val = f idx + when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ + error "Data.Array.Nested mgenerate: generated values do not have equal shapes" + mvecsWrite sh idx val vecs + mvecsFreeze sh vecs + +msumOuter1P :: forall sh n a. (Storable a, NumElt a) + => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) +msumOuter1P (M_Primitive (n :$% sh) arr) = + let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX + in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr) + +msumOuter1 :: forall sh n a. (NumElt a, PrimElt a) + => Mixed (n : sh) a -> Mixed sh a +msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive + +mappend :: forall n m sh a. Elt a + => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a +mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 + where + sn :$% sh = mshape arr1 + sm :$% _ = mshape arr2 + ssh = ssxFromShape sh + snm :: SMayNat () SNat (AddMaybe n m) + snm = case (sn, sm) of + (SUnknown{}, _) -> SUnknown () + (SKnown{}, SUnknown{}) -> SUnknown () + (SKnown n, SKnown m) -> SKnown (snatPlus n m) + + f :: forall sh' b. Storable b + => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b + f ssh' = X.append (ssxAppend ssh ssh') + +mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) +mfromVectorP sh v = M_Primitive sh (X.fromVector sh v) + +mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a +mfromVector sh v = fromPrimitive (mfromVectorP sh v) + +mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a +mtoVectorP (M_Primitive _ v) = X.toVector v + +mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a +mtoVector arr = mtoVectorP (toPrimitive arr) + +mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a +mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise? + +mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a +mfromList1Prim l = + let ssh = SUnknown () :!% ZKX + xarr = X.fromList1 ssh l + in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr + +mtoList1 :: Elt a => Mixed '[n] a -> [a] +mtoList1 = map munScalar . mtoListOuter + +mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a +mfromListPrim l = + let ssh = SUnknown () :!% ZKX + xarr = X.fromList1 ssh l + in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr + +mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a +mfromListPrimLinear sh l = + let M_Primitive _ xarr = toPrimitive (mfromListPrim l) + in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr) + +munScalar :: Elt a => Mixed '[] a -> a +munScalar arr = mindex arr ZIX + +mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) + => StaticShX sh -> IShX sh2 + -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) + -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b) +mrerankP ssh sh2 f (M_Primitive sh arr) = + let sh1 = shxDropSSX sh ssh + in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2) + (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2) + (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) + arr) + +-- | See the caveats at @X.rerank@. +mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) + => StaticShX sh -> IShX sh2 + -> (Mixed sh1 a -> Mixed sh2 b) + -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b +mrerank ssh sh2 f (toPrimitive -> arr) = + fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr + +mreplicate :: forall sh sh' a. Elt a + => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a +mreplicate sh arr = + let ssh' = ssxFromShape (mshape arr) + in mlift (ssxAppend (ssxFromShape sh) ssh') + (\(sshT :: StaticShX shT) -> + case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of + Refl -> X.replicate sh (ssxAppend ssh' sshT)) + arr + +mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) +mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x) + +mreplicateScal :: forall sh a. PrimElt a + => IShX sh -> a -> Mixed sh a +mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x) + +mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a +mslice i n arr = + let _ :$% sh = mshape arr + in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr + +msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a +msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr + +mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a +mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr + +mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a +mreshape sh' arr = + mlift (ssxFromShape sh') + (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh') + arr + +miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a +miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn) + +masXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a) +masXArrayPrimP (M_Primitive sh arr) = (sh, arr) + +masXArrayPrim :: PrimElt a => Mixed sh a -> (IShX sh, XArray sh a) +masXArrayPrim = masXArrayPrimP . toPrimitive + +mfromXArrayPrimP :: StaticShX sh -> XArray sh a -> Mixed sh (Primitive a) +mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr + +mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a +mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP + +mliftPrim :: PrimElt a + => (a -> a) + -> Mixed sh a -> Mixed sh a +mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr)) + +mliftPrim2 :: PrimElt a + => (a -> a -> a) + -> Mixed sh a -> Mixed sh a -> Mixed sh a +mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) = + fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2)) diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs new file mode 100644 index 0000000..d5bd70f --- /dev/null +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -0,0 +1,446 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Internal.Ranked where + +import Prelude hiding (mappend) + +import Control.DeepSeq (NFData) +import Control.Monad.ST +import Data.Array.RankedS qualified as S +import Data.Bifunctor (first) +import Data.Coerce (coerce) +import Data.Kind (Type) +import Data.List.NonEmpty (NonEmpty) +import Data.Proxy +import Data.Type.Equality +import Data.Vector.Storable qualified as VS +import Foreign.Storable (Storable) +import GHC.Float qualified (log1p, expm1, log1pexp, log1mexp) +import GHC.TypeLits +import GHC.TypeNats qualified as TN + +import Data.Array.Mixed.XArray (XArray(..)) +import Data.Array.Mixed.XArray qualified as X +import Data.Array.Mixed.Internal.Arith +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Shape +import Data.Array.Mixed.Types +import Data.Array.Nested.Internal.Mixed +import Data.Array.Nested.Internal.Shape + + +-- | A rank-typed array: the number of dimensions of the array (its /rank/) is +-- represented on the type level as a 'Nat'. +-- +-- Valid elements of a ranked arrays are described by the 'Elt' type class. +-- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are +-- supported (and are represented as a single, flattened, struct-of-arrays +-- array internally). +-- +-- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's. +type Ranked :: Nat -> Type -> Type +newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) +deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a) +deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a) +deriving instance Ord (Mixed '[] a) => Ord (Ranked 0 a) +deriving instance NFData (Mixed (Replicate n Nothing) a) => NFData (Ranked n a) + +-- just unwrap the newtype and defer to the general instance for nested arrays +newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a)) +deriving instance Show (Mixed sh (Mixed (Replicate n Nothing) a)) => Show (Mixed sh (Ranked n a)) + +newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a)) + +-- 'Ranked' and 'Shaped' can already be used at the top level of an array nest; +-- these instances allow them to also be used as elements of arrays, thus +-- making them first-class in the API. +instance Elt a => Elt (Ranked n a) where + mshape (M_Ranked arr) = mshape arr + mindex (M_Ranked arr) i = Ranked (mindex arr i) + + mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a) + mindexPartial (M_Ranked arr) i = + coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ + mindexPartial arr i + + mscalar (Ranked x) = M_Ranked (M_Nest ZSX x) + + mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a) + mfromListOuter l = M_Ranked (mfromListOuter (coerce l)) + + mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)] + mtoListOuter (M_Ranked arr) = + coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr) + + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) + mlift ssh2 f (M_Ranked arr) = + coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ + mlift ssh2 f arr + + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) + -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a) + mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) = + coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ + mlift2 ssh3 f arr1 arr2 + + mcast ssh1 sh2 psh' (M_Ranked arr) = M_Ranked (mcast ssh1 sh2 psh' arr) + + mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr) + + type ShapeTree (Ranked n a) = (IShR n, ShapeTree a) + + mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr) + + mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 + + mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + + mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + + mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () + mvecsWrite sh idx (Ranked arr) vecs = + mvecsWrite sh idx arr + (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) + vecs) + + mvecsWritePartial :: forall sh sh' s. + IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) + -> MixedVecs s (sh ++ sh') (Ranked n a) + -> ST s () + mvecsWritePartial sh idx arr vecs = + mvecsWritePartial sh idx + (coerce @(Mixed sh' (Ranked n a)) + @(Mixed sh' (Mixed (Replicate n Nothing) a)) + arr) + (coerce @(MixedVecs s (sh ++ sh') (Ranked n a)) + @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) + vecs) + + mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) + mvecsFreeze sh vecs = + coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) + @(Mixed sh (Ranked n a)) + <$> mvecsFreeze sh + (coerce @(MixedVecs s sh (Ranked n a)) + @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) + vecs) + +instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where + memptyArray :: forall sh. IShX sh -> Mixed sh (Ranked n a) + memptyArray i + | Dict <- lemKnownReplicate (SNat @n) + = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ + memptyArray i + + mvecsUnsafeNew idx (Ranked arr) + | Dict <- lemKnownReplicate (SNat @n) + = MV_Ranked <$> mvecsUnsafeNew idx arr + + mvecsNewEmpty _ + | Dict <- lemKnownReplicate (SNat @n) + = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) + + +arithPromoteRanked :: forall n a. PrimElt a + => (forall sh. Mixed sh a -> Mixed sh a) + -> Ranked n a -> Ranked n a +arithPromoteRanked = coerce + +arithPromoteRanked2 :: forall n a. PrimElt a + => (forall sh. Mixed sh a -> Mixed sh a -> Mixed sh a) + -> Ranked n a -> Ranked n a -> Ranked n a +arithPromoteRanked2 = coerce + +instance (NumElt a, PrimElt a) => Num (Ranked n a) where + (+) = arithPromoteRanked2 (+) + (-) = arithPromoteRanked2 (-) + (*) = arithPromoteRanked2 (*) + negate = arithPromoteRanked negate + abs = arithPromoteRanked abs + signum = arithPromoteRanked signum + fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicateScal" + +instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Ranked n a) where + fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicateScal" + recip = arithPromoteRanked recip + (/) = arithPromoteRanked2 (/) + +instance (FloatElt a, NumElt a, PrimElt a) => Floating (Ranked n a) where + pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicateScal" + exp = arithPromoteRanked exp + log = arithPromoteRanked log + sqrt = arithPromoteRanked sqrt + (**) = arithPromoteRanked2 (**) + logBase = arithPromoteRanked2 logBase + sin = arithPromoteRanked sin + cos = arithPromoteRanked cos + tan = arithPromoteRanked tan + asin = arithPromoteRanked asin + acos = arithPromoteRanked acos + atan = arithPromoteRanked atan + sinh = arithPromoteRanked sinh + cosh = arithPromoteRanked cosh + tanh = arithPromoteRanked tanh + asinh = arithPromoteRanked asinh + acosh = arithPromoteRanked acosh + atanh = arithPromoteRanked atanh + log1p = arithPromoteRanked GHC.Float.log1p + expm1 = arithPromoteRanked GHC.Float.expm1 + log1pexp = arithPromoteRanked GHC.Float.log1pexp + log1mexp = arithPromoteRanked GHC.Float.log1mexp + + +rshape :: forall n a. Elt a => Ranked n a -> IShR n +rshape (Ranked arr) = shCvtXR' (mshape arr) + +rindex :: Elt a => Ranked n a -> IIxR n -> a +rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) + +rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a +rindexPartial (Ranked arr) idx = + Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) + (castWith (subst2 (lemReplicatePlusApp (ixrToSNat idx) (Proxy @m) (Proxy @Nothing))) arr) + (ixCvtRX idx)) + +-- | __WARNING__: All values returned from the function must have equal shape. +-- See the documentation of 'mgenerate' for more details. +rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a +rgenerate sh f + | sn@SNat <- shrToSNat sh + , Dict <- lemKnownReplicate sn + , Refl <- lemRankReplicate sn + = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR)) + +-- | See the documentation of 'mlift'. +rlift :: forall n1 n2 a. Elt a + => SNat n2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) + -> Ranked n1 a -> Ranked n2 a +rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr) + +-- | See the documentation of 'mlift2'. +rlift2 :: forall n1 n2 n3 a. Elt a + => SNat n3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b) + -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a +rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2) + +rsumOuter1P :: forall n a. + (Storable a, NumElt a) + => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) +rsumOuter1P (Ranked arr) + | Refl <- lemReplicateSucc @(Nothing @Nat) @n + = Ranked (msumOuter1P arr) + +rsumOuter1 :: forall n a. (NumElt a, PrimElt a) + => Ranked (n + 1) a -> Ranked n a +rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive + +rtranspose :: forall n a. Elt a => [Int] -> Ranked n a -> Ranked n a +rtranspose perm arr + | sn@SNat <- shrToSNat (rshape arr) + , Dict <- lemKnownReplicate sn + , length perm <= fromIntegral (natVal (Proxy @n)) + = rlift sn + (\ssh' -> X.transposeUntyped (natSing @n) ssh' perm) + arr + | otherwise + = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array" + +rappend :: forall n a. Elt a + => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a +rappend arr1 arr2 + | sn@SNat <- shrToSNat (rshape arr1) + , Dict <- lemKnownReplicate sn + , Refl <- lemReplicateSucc @(Nothing @Nat) @n + = coerce (mappend @Nothing @Nothing @(Replicate n Nothing)) + arr1 arr2 + +rscalar :: Elt a => a -> Ranked 0 a +rscalar x = Ranked (mscalar x) + +rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a) +rfromVectorP sh v + | Dict <- lemKnownReplicate (shrToSNat sh) + = Ranked (mfromVectorP (shCvtRX sh) v) + +rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a +rfromVector sh v = rfromPrimitive (rfromVectorP sh v) + +rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a +rtoVectorP = coerce mtoVectorP + +rtoVector :: PrimElt a => Ranked n a -> VS.Vector a +rtoVector = coerce mtoVector + +rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a +rfromListOuter l + | Refl <- lemReplicateSucc @(Nothing @Nat) @n + = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) + +rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a +rfromList1 l = Ranked (mfromList1 l) + +rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a +rfromList1Prim l = Ranked (mfromList1Prim l) + +rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] +rtoListOuter (Ranked arr) + | Refl <- lemReplicateSucc @(Nothing @Nat) @n + = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr) + +rtoList1 :: Elt a => Ranked 1 a -> [a] +rtoList1 = map runScalar . rtoListOuter + +rfromListPrim :: PrimElt a => [a] -> Ranked 1 a +rfromListPrim l = + let ssh = SUnknown () :!% ZKX + xarr = X.fromList1 ssh l + in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr + +rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a +rfromListPrimLinear sh l = + let M_Primitive _ xarr = toPrimitive (mfromListPrim l) + in Ranked $ fromPrimitive $ M_Primitive (shCvtRX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtRX sh) xarr) + +rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a +rfromOrthotope sn arr + | Refl <- lemRankReplicate sn + = let xarr = XArray arr + in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr)) + +runScalar :: Elt a => Ranked 0 a -> a +runScalar arr = rindex arr ZIR + +rrerankP :: forall n1 n2 n a b. (Storable a, Storable b) + => SNat n -> IShR n2 + -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b)) + -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b) +rrerankP sn sh2 f (Ranked arr) + | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat)) + , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat)) + = Ranked (mrerankP (ssxFromSNat sn) (shCvtRX sh2) + (\a -> let Ranked r = f (Ranked a) in r) + arr) + +-- | If there is a zero-sized dimension in the @n@-prefix of the shape of the +-- input array, then there is no way to deduce the full shape of the output +-- array (more precisely, the @n2@ part): that could only come from calling +-- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in +-- this case; we choose to fill the @n2@ part of the output shape with zeros. +-- +-- For example, if: +-- +-- @ +-- arr :: Ranked 5 Int -- of shape [3, 0, 4, 2, 21] +-- f :: Ranked 2 Int -> Ranked 3 Float +-- @ +-- +-- then: +-- +-- @ +-- rrerank _ _ _ f arr :: Ranked 5 Float +-- @ +-- +-- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the +-- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended +-- to return an array with shape all-0 here (it probably didn't), but there is +-- no better number to put here absent a subarray of the input to pass to @f@. +rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b) + => SNat n -> IShR n2 + -> (Ranked n1 a -> Ranked n2 b) + -> Ranked (n + n1) a -> Ranked (n + n2) b +rrerank ssh sh2 f (rtoPrimitive -> arr) = + rfromPrimitive $ rrerankP ssh sh2 (rtoPrimitive . f . rfromPrimitive) arr + +rreplicate :: forall n m a. Elt a + => IShR n -> Ranked m a -> Ranked (n + m) a +rreplicate sh (Ranked arr) + | Refl <- lemReplicatePlusApp (shrToSNat sh) (Proxy @m) (Proxy @(Nothing @Nat)) + = Ranked (mreplicate (shCvtRX sh) arr) + +rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) +rreplicateScalP sh x + | Dict <- lemKnownReplicate (shrToSNat sh) + = Ranked (mreplicateScalP (shCvtRX sh) x) + +rreplicateScal :: forall n a. PrimElt a + => IShR n -> a -> Ranked n a +rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x) + +rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a +rslice i n arr + | Refl <- lemReplicateSucc @(Nothing @Nat) @n + = rlift (shrToSNat (rshape arr)) + (\_ -> X.sliceU i n) + arr + +rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a +rrev1 arr = + rlift (shrToSNat (rshape arr)) + (\(_ :: StaticShX sh') -> + case lemReplicateSucc @(Nothing @Nat) @n of + Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')) + arr + +rreshape :: forall n n' a. Elt a + => IShR n' -> Ranked n a -> Ranked n' a +rreshape sh' rarr@(Ranked arr) + | Dict <- lemKnownReplicate (shrToSNat (rshape rarr)) + , Dict <- lemKnownReplicate (shrToSNat sh') + = Ranked (mreshape (shCvtRX sh') arr) + +riota :: (Enum a, PrimElt a, Elt a) => Int -> Ranked 1 a +riota n = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota + +rasXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a) +rasXArrayPrimP (Ranked arr) = first shCvtXR' (masXArrayPrimP arr) + +rasXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a) +rasXArrayPrim (Ranked arr) = first shCvtXR' (masXArrayPrim arr) + +rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a) +rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) + +rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a +rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) + +rfromPrimitive :: PrimElt a => Ranked n (Primitive a) -> Ranked n a +rfromPrimitive (Ranked arr) = Ranked (fromPrimitive arr) + +rtoPrimitive :: PrimElt a => Ranked n a -> Ranked n (Primitive a) +rtoPrimitive (Ranked arr) = Ranked (toPrimitive arr) + +mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a +mtoRanked arr + | Refl <- lemAppNil @sh + , Refl <- lemAppNil @(Replicate (Rank sh) (Nothing @Nat)) + , Refl <- lemRankReplicate (shxRank (mshape arr)) + = Ranked (mcast (ssxFromShape (mshape arr)) (convSh (mshape arr)) (Proxy @'[]) arr) + where + convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing) + convSh ZSX = ZSX + convSh (smn :$% (sh :: IShX sh'T)) + | Refl <- lemReplicateSucc @(Nothing @Nat) @(Rank sh'T) + = SUnknown (fromSMayNat' smn) :$% convSh sh diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Internal/Shape.hs new file mode 100644 index 0000000..319f171 --- /dev/null +++ b/src/Data/Array/Nested/Internal/Shape.hs @@ -0,0 +1,467 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveFoldable #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Internal.Shape where + +import Data.Array.Mixed.Types +import Data.Coerce (coerce) +import Data.Foldable qualified as Foldable +import Data.Functor.Const +import Data.Kind (Type, Constraint) +import Data.Monoid (Sum(..)) +import Data.Proxy +import Data.Type.Equality +import GHC.IsList (IsList) +import GHC.IsList qualified as IsList +import GHC.TypeLits +import GHC.TypeNats qualified as TN + +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Permutation +import Data.Array.Mixed.Shape + + +type role ListR nominal representational +type ListR :: Nat -> Type -> Type +data ListR n i where + ZR :: ListR 0 i + (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i +deriving instance Eq i => Eq (ListR n i) +deriving instance Ord i => Ord (ListR n i) +deriving instance Functor (ListR n) +deriving instance Foldable (ListR n) +infixr 3 ::: + +instance Show i => Show (ListR n i) where + showsPrec _ = listrShow shows + +data UnconsListRRes i n1 = + forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i +listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1) +listrUncons (i ::: sh') = Just (UnconsListRRes sh' i) +listrUncons ZR = Nothing + +listrShow :: forall sh i. (i -> ShowS) -> ListR sh i -> ShowS +listrShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> ListR sh' i -> ShowS + go _ ZR = id + go prefix (x ::: xs) = showString prefix . f x . go "," xs + +listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i +listrAppend ZR sh = sh +listrAppend (x ::: xs) sh = x ::: listrAppend xs sh + +listrFromList :: [i] -> (forall n. ListR n i -> r) -> r +listrFromList [] k = k ZR +listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l) + +listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i +listrIndex SZ (x ::: _) = x +listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs +listrIndex _ ZR = error "k + 1 <= 0" + +listrToSNat :: ListR n i -> SNat n +listrToSNat ZR = SNat +listrToSNat (_ ::: (l :: ListR n i)) | SNat <- listrToSNat l, Dict <- lemKnownNatSucc @n = SNat + +listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i +listrPermutePrefix = \perm sh -> + listrFromList perm $ \sperm -> + case (listrToSNat sperm, listrToSNat sh) of + (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of + LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post + EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post + GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" + ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" + where + listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) + listrSplitAt SZ sh = (ZR, sh) + listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh) + listrSplitAt SS{} ZR = error "m' + 1 <= 0" + + applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i + applyPermRFull _ ZR _ = ZR + applyPermRFull sm@SNat (i ::: perm) l = + TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) -> + case cmpNat (SNat @(idx + 1)) sm of + LTI -> listrIndex si l ::: applyPermRFull sm perm l + EQI -> listrIndex si l ::: applyPermRFull sm perm l + GTI -> error "listrPermutePrefix: Index in permutation out of range" + + +-- | An index into a rank-typed array. +type role IxR nominal representational +type IxR :: Nat -> Type -> Type +newtype IxR n i = IxR (ListR n i) + deriving (Eq, Ord) + deriving newtype (Functor, Foldable) + +pattern ZIR :: forall n i. () => n ~ 0 => IxR n i +pattern ZIR = IxR ZR + +pattern (:.:) + :: forall {n1} {i}. + forall n. (n + 1 ~ n1) + => i -> IxR n i -> IxR n1 i +pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i)) + where i :.: IxR sh = IxR (i ::: sh) +infixr 3 :.: + +{-# COMPLETE ZIR, (:.:) #-} + +type IIxR n = IxR n Int + +instance Show i => Show (IxR n i) where + showsPrec _ (IxR l) = listrShow shows l + +ixrZero :: SNat n -> IIxR n +ixrZero SZ = ZIR +ixrZero (SS n) = 0 :.: ixrZero n + +ixCvtXR :: IIxX sh -> IIxR (Rank sh) +ixCvtXR ZIX = ZIR +ixCvtXR (n :.% idx) = n :.: ixCvtXR idx + +ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing) +ixCvtRX ZIR = ZIX +ixCvtRX (n :.: (idx :: IxR m Int)) = + castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m)) + (n :.% ixCvtRX idx) + +ixrToSNat :: IxR n i -> SNat n +ixrToSNat (IxR sh) = listrToSNat sh + +ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i +ixrPermutePrefix = coerce (listrPermutePrefix @i) + + +type role ShR nominal representational +type ShR :: Nat -> Type -> Type +newtype ShR n i = ShR (ListR n i) + deriving (Eq, Ord) + deriving newtype (Functor, Foldable) + +pattern ZSR :: forall n i. () => n ~ 0 => ShR n i +pattern ZSR = ShR ZR + +pattern (:$:) + :: forall {n1} {i}. + forall n. (n + 1 ~ n1) + => i -> ShR n i -> ShR n1 i +pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i)) + where i :$: (ShR sh) = ShR (i ::: sh) +infixr 3 :$: + +{-# COMPLETE ZSR, (:$:) #-} + +type IShR n = ShR n Int + +instance Show i => Show (ShR n i) where + showsPrec _ (ShR l) = listrShow shows l + +shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n +shCvtXR' ZSX = + castWith (subst2 (unsafeCoerceRefl :: 0 :~: n)) + ZSR +shCvtXR' (n :$% (idx :: IShX sh)) + | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) = + castWith (subst2 (lem1 @sh Refl)) + (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx)) + where + lem1 :: forall sh' n' k. + k : sh' :~: Replicate n' Nothing + -> Rank sh' + 1 :~: n' + lem1 Refl = unsafeCoerceRefl + + lem2 :: k : sh :~: Replicate n Nothing + -> sh :~: Replicate (Rank sh) Nothing + lem2 Refl = unsafeCoerceRefl + +shCvtRX :: IShR n -> IShX (Replicate n Nothing) +shCvtRX ZSR = ZSX +shCvtRX (n :$: (idx :: ShR m Int)) = + castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m)) + (SUnknown n :$% shCvtRX idx) + +-- | The number of elements in an array described by this shape. +shrSize :: IShR n -> Int +shrSize ZSR = 1 +shrSize (n :$: sh) = n * shrSize sh + +shrToSNat :: ShR n i -> SNat n +shrToSNat (ShR sh) = listrToSNat sh + +shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i +shrPermutePrefix = coerce (listrPermutePrefix @i) + + +-- | Untyped: length is checked at runtime. +instance KnownNat n => IsList (ListR n i) where + type Item (ListR n i) = i + fromList = go (SNat @n) + where + go :: SNat n' -> [i] -> ListR n' i + go SZ [] = ZR + go (SS n) (i : is) = i ::: go n is + go _ _ = error "IsList(ListR): Mismatched list length" + toList = Foldable.toList + +-- | Untyped: length is checked at runtime. +instance KnownNat n => IsList (IxR n i) where + type Item (IxR n i) = i + fromList = IxR . IsList.fromList + toList = Foldable.toList + +-- | Untyped: length is checked at runtime. +instance KnownNat n => IsList (ShR n i) where + type Item (ShR n i) = i + fromList = ShR . IsList.fromList + toList = Foldable.toList + + +type role ListS nominal representational +type ListS :: [Nat] -> (Nat -> Type) -> Type +data ListS sh f where + ZS :: ListS '[] f + -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity + (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f +deriving instance (forall n. Eq (f n)) => Eq (ListS sh f) +deriving instance (forall n. Ord (f n)) => Ord (ListS sh f) +infixr 3 ::$ + +instance (forall n. Show (f n)) => Show (ListS sh f) where + showsPrec _ = listsShow shows + +data UnconsListSRes f sh1 = + forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n) +listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1) +listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x) +listsUncons ZS = Nothing + +listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g +listsFmap _ ZS = ZS +listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs + +listsFold :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m +listsFold _ ZS = mempty +listsFold f (x ::$ xs) = f x <> listsFold f xs + +listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS +listsShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> ListS sh' f -> ShowS + go _ ZS = id + go prefix (x ::$ xs) = showString prefix . f x . go "," xs + +listsToList :: ListS sh (Const i) -> [i] +listsToList ZS = [] +listsToList (Const i ::$ is) = i : listsToList is + +listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f +listsAppend ZS idx' = idx' +listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx' + +listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f +listsTakeLenPerm PNil _ = ZS +listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh +listsTakeLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" + +listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f +listsDropLenPerm PNil sh = sh +listsDropLenPerm (_ `PCons` is) (_ ::$ sh) = listsDropLenPerm is sh +listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" + +listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f +listsPermute PNil _ = ZS +listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) = + case listsIndex (Proxy @is') (Proxy @sh) i sh of + (item, SNat) -> item ::$ listsPermute is sh + +-- TODO: remove this SNat when the KnownNat constaint in ListS is removed +listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh)) +listsIndex _ _ SZ (n ::$ _) = (n, SNat) +listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) + | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + = listsIndex p pT i sh +listsIndex _ _ _ ZS = error "Index into empty shape" + +shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh) +shsTakeLen = coerce (listsTakeLenPerm @SNat) + +shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh) +shsPermute = coerce (listsPermute @SNat) + +shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh) +shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh))) + +applyPermS :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f +applyPermS perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh) + +applyPermIxS :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i +applyPermIxS = coerce (applyPermS @(Const i)) + +applyPermShS :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh) +applyPermShS = coerce (applyPermS @SNat) + + +-- | An index into a shape-typed array. +-- +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). Note that because the shape of a +-- shape-typed array is known statically, you can also retrieve the array shape +-- from a 'KnownShape' dictionary. +type role IxS nominal representational +type IxS :: [Nat] -> Type -> Type +newtype IxS sh i = IxS (ListS sh (Const i)) + deriving (Eq, Ord) + +pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i +pattern ZIS = IxS ZS + +pattern (:.$) + :: forall {sh1} {i}. + forall n sh. (KnownNat n, n : sh ~ sh1) + => i -> IxS sh i -> IxS sh1 i +pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i))) + where i :.$ IxS shl = IxS (Const i ::$ shl) +infixr 3 :.$ + +{-# COMPLETE ZIS, (:.$) #-} + +type IIxS sh = IxS sh Int + +instance Show i => Show (IxS sh i) where + showsPrec _ (IxS l) = listsShow (\(Const i) -> shows i) l + +instance Functor (IxS sh) where + fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l) + +instance Foldable (IxS sh) where + foldMap f (IxS l) = listsFold (f . getConst) l + +ixsZero :: ShS sh -> IIxS sh +ixsZero ZSS = ZIS +ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh + +ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh +ixCvtXS ZSS ZIX = ZIS +ixCvtXS (_ :$$ sh) (n :.% idx) = n :.$ ixCvtXS sh idx + +ixCvtSX :: IIxS sh -> IIxX (MapJust sh) +ixCvtSX ZIS = ZIX +ixCvtSX (n :.$ sh) = n :.% ixCvtSX sh + + +-- | The shape of a shape-typed array given as a list of 'SNat' values. +type role ShS nominal +type ShS :: [Nat] -> Type +newtype ShS sh = ShS (ListS sh SNat) + deriving (Eq, Ord) + +pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh +pattern ZSS = ShS ZS + +pattern (:$$) + :: forall {sh1}. + forall n sh. (KnownNat n, n : sh ~ sh1) + => SNat n -> ShS sh -> ShS sh1 +pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i)) + where i :$$ ShS shl = ShS (i ::$ shl) + +infixr 3 :$$ + +{-# COMPLETE ZSS, (:$$) #-} + +instance Show (ShS sh) where + showsPrec _ (ShS l) = listsShow (shows . fromSNat) l + +shsLength :: ShS sh -> Int +shsLength (ShS l) = getSum (listsFold (\_ -> Sum 1) l) + +shsToList :: ShS sh -> [Int] +shsToList ZSS = [] +shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh + +shCvtXS' :: forall sh. IShX (MapJust sh) -> ShS sh +shCvtXS' ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS +shCvtXS' (SKnown n@SNat :$% (idx :: IShX mjshT)) = + castWith (subst1 (lem Refl)) $ + n :$$ shCvtXS' @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) + idx) + where + lem :: forall sh1 sh' n. + Just n : sh1 :~: MapJust sh' + -> n : Tail sh' :~: sh' + lem Refl = unsafeCoerceRefl +shCvtXS' (SUnknown _ :$% _) = error "impossible" + +shCvtSX :: ShS sh -> IShX (MapJust sh) +shCvtSX ZSS = ZSX +shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh + +shsSize :: ShS sh -> Int +shsSize ZSS = 1 +shsSize (n :$$ sh) = fromSNat' n * shsSize sh + +-- | Evidence for the static part of a shape. This pops up only when you are +-- polymorphic in the element type of an array. +type KnownShS :: [Nat] -> Constraint +class KnownShS sh where knownShS :: ShS sh +instance KnownShS '[] where knownShS = ZSS +instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS + + +-- | Untyped: length is checked at runtime. +instance KnownShS sh => IsList (ListS sh (Const i)) where + type Item (ListS sh (Const i)) = i + fromList topl = go (knownShS @sh) topl + where + go :: ShS sh' -> [i] -> ListS sh' (Const i) + go ZSS [] = ZS + go (_ :$$ sh) (i : is) = Const i ::$ go sh is + go _ _ = error $ "IsList(ListS): Mismatched list length (type says " + ++ show (shsLength (knownShS @sh)) ++ ", list has length " + ++ show (length topl) ++ ")" + toList = listsToList + +-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. +instance KnownShS sh => IsList (IxS sh i) where + type Item (IxS sh i) = i + fromList = IxS . IsList.fromList + toList = Foldable.toList + +-- | Untyped: length and values are checked at runtime. +instance KnownShS sh => IsList (ShS sh) where + type Item (ShS sh) = Int + fromList topl = ShS (go (knownShS @sh) topl) + where + go :: ShS sh' -> [Int] -> ListS sh' SNat + go ZSS [] = ZS + go (sn :$$ sh) (i : is) + | i == fromSNat' sn = sn ::$ go sh is + | otherwise = error $ "IsList(ShS): Value does not match typing (type says " + ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" + go _ _ = error $ "IsList(ShS): Mismatched list length (type says " + ++ show (shsLength (knownShS @sh)) ++ ", list has length " + ++ show (length topl) ++ ")" + toList = shsToList diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs new file mode 100644 index 0000000..afa91eb --- /dev/null +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -0,0 +1,379 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Internal.Shaped where + +import Prelude hiding (mappend) + +import Control.DeepSeq (NFData) +import Control.Monad.ST +import Data.Bifunctor (first) +import Data.Coerce (coerce) +import Data.Kind (Type) +import Data.List.NonEmpty (NonEmpty) +import Data.Proxy +import Data.Type.Equality +import Data.Vector.Storable qualified as VS +import Foreign.Storable (Storable) +import GHC.Float qualified (log1p, expm1, log1pexp, log1mexp) +import GHC.TypeLits + +import Data.Array.Mixed.XArray (XArray) +import Data.Array.Mixed.XArray qualified as X +import Data.Array.Mixed.Internal.Arith +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Permutation +import Data.Array.Mixed.Shape +import Data.Array.Mixed.Types +import Data.Array.Nested.Internal.Lemmas +import Data.Array.Nested.Internal.Mixed +import Data.Array.Nested.Internal.Shape + + +-- | A shape-typed array: the full shape of the array (the sizes of its +-- dimensions) is represented on the type level as a list of 'Nat's. Note that +-- these are "GHC.TypeLits" naturals, because we do not need induction over +-- them and we want very large arrays to be possible. +-- +-- Like for 'Ranked', the valid elements are described by the 'Elt' type class, +-- and 'Shaped' itself is again an instance of 'Elt' as well. +-- +-- 'Shaped' is a newtype around a 'Mixed' of 'Just's. +type Shaped :: [Nat] -> Type -> Type +newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) +deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a) +deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a) +deriving instance Ord (Mixed '[] a) => Ord (Shaped '[] a) +deriving instance NFData (Mixed (MapJust sh) a) => NFData (Shaped sh a) + +-- just unwrap the newtype and defer to the general instance for nested arrays +newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a)) +deriving instance Show (Mixed sh (Mixed (MapJust sh') a)) => Show (Mixed sh (Shaped sh' a)) + +newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh' ) a)) + +instance Elt a => Elt (Shaped sh a) where + mshape (M_Shaped arr) = mshape arr + mindex (M_Shaped arr) i = Shaped (mindex arr i) + + mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) + mindexPartial (M_Shaped arr) i = + coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ + mindexPartial arr i + + mscalar (Shaped x) = M_Shaped (M_Nest ZSX x) + + mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a) + mfromListOuter l = M_Shaped (mfromListOuter (coerce l)) + + mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] + mtoListOuter (M_Shaped arr) + = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr) + + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) + mlift ssh2 f (M_Shaped arr) = + coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ + mlift ssh2 f arr + + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) + -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a) + mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) = + coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ + mlift2 ssh3 f arr1 arr2 + + mcast ssh1 sh2 psh' (M_Shaped arr) = M_Shaped (mcast ssh1 sh2 psh' arr) + + mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr) + + type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) + + mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr) + + mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 + + mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + + mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + + mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () + mvecsWrite sh idx (Shaped arr) vecs = + mvecsWrite sh idx arr + (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) + + mvecsWritePartial :: forall sh1 sh2 s. + IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) + -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) + -> ST s () + mvecsWritePartial sh idx arr vecs = + mvecsWritePartial sh idx + (coerce @(Mixed sh2 (Shaped sh a)) + @(Mixed sh2 (Mixed (MapJust sh) a)) + arr) + (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a)) + @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) + vecs) + + mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) + mvecsFreeze sh vecs = + coerce @(Mixed sh' (Mixed (MapJust sh) a)) + @(Mixed sh' (Shaped sh a)) + <$> mvecsFreeze sh + (coerce @(MixedVecs s sh' (Shaped sh a)) + @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) + +instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where + memptyArray :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) + memptyArray i + | Dict <- lemKnownMapJust (Proxy @sh) + = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ + memptyArray i + + mvecsUnsafeNew idx (Shaped arr) + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsUnsafeNew idx arr + + mvecsNewEmpty _ + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) + + +arithPromoteShaped :: forall sh a. PrimElt a + => (forall shx. Mixed shx a -> Mixed shx a) + -> Shaped sh a -> Shaped sh a +arithPromoteShaped = coerce + +arithPromoteShaped2 :: forall sh a. PrimElt a + => (forall shx. Mixed shx a -> Mixed shx a -> Mixed shx a) + -> Shaped sh a -> Shaped sh a -> Shaped sh a +arithPromoteShaped2 = coerce + +instance (NumElt a, PrimElt a) => Num (Shaped sh a) where + (+) = arithPromoteShaped2 (+) + (-) = arithPromoteShaped2 (-) + (*) = arithPromoteShaped2 (*) + negate = arithPromoteShaped negate + abs = arithPromoteShaped abs + signum = arithPromoteShaped signum + fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicateScal" + +instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Shaped sh a) where + fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal" + recip = arithPromoteShaped recip + (/) = arithPromoteShaped2 (/) + +instance (FloatElt a, NumElt a, PrimElt a) => Floating (Shaped sh a) where + pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal" + exp = arithPromoteShaped exp + log = arithPromoteShaped log + sqrt = arithPromoteShaped sqrt + (**) = arithPromoteShaped2 (**) + logBase = arithPromoteShaped2 logBase + sin = arithPromoteShaped sin + cos = arithPromoteShaped cos + tan = arithPromoteShaped tan + asin = arithPromoteShaped asin + acos = arithPromoteShaped acos + atan = arithPromoteShaped atan + sinh = arithPromoteShaped sinh + cosh = arithPromoteShaped cosh + tanh = arithPromoteShaped tanh + asinh = arithPromoteShaped asinh + acosh = arithPromoteShaped acosh + atanh = arithPromoteShaped atanh + log1p = arithPromoteShaped GHC.Float.log1p + expm1 = arithPromoteShaped GHC.Float.expm1 + log1pexp = arithPromoteShaped GHC.Float.log1pexp + log1mexp = arithPromoteShaped GHC.Float.log1mexp + + +sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh +sshape (Shaped arr) = shCvtXS' (mshape arr) + +sindex :: Elt a => Shaped sh a -> IIxS sh -> a +sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) + +shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh +shsTakeIx _ _ ZIS = ZSS +shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx + +sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a +sindexPartial sarr@(Shaped arr) idx = + Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) + (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr) + (ixCvtSX idx)) + +-- | __WARNING__: All values returned from the function must have equal shape. +-- See the documentation of 'mgenerate' for more details. +sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a +sgenerate sh f = Shaped (mgenerate (shCvtSX sh) (f . ixCvtXS sh)) + +-- | See the documentation of 'mlift'. +slift :: forall sh1 sh2 a. Elt a + => ShS sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) + -> Shaped sh1 a -> Shaped sh2 a +slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShape (shCvtSX sh2)) f arr) + +-- | See the documentation of 'mlift'. +slift2 :: forall sh1 sh2 sh3 a. Elt a + => ShS sh3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b) + -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a +slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShape (shCvtSX sh3)) f arr1 arr2) + +ssumOuter1P :: forall sh n a. (Storable a, NumElt a) + => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) +ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr) + +ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a) + => Shaped (n : sh) a -> Shaped sh a +ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive + +stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a) + => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a +stranspose perm sarr@(Shaped arr) + | Refl <- lemRankMapJust (sshape sarr) + , Refl <- lemMapJustTakeLen perm (sshape sarr) + , Refl <- lemMapJustDropLen perm (sshape sarr) + , Refl <- lemMapJustPermute perm (shsTakeLen perm (sshape sarr)) + , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh)) + = Shaped (mtranspose perm arr) + +sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a +sappend = coerce mappend + +sscalar :: Elt a => a -> Shaped '[] a +sscalar x = Shaped (mscalar x) + +sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a) +sfromVectorP sh v = Shaped (mfromVectorP (shCvtSX sh) v) + +sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a +sfromVector sh v = sfromPrimitive (sfromVectorP sh v) + +stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a +stoVectorP = coerce mtoVectorP + +stoVector :: PrimElt a => Shaped sh a -> VS.Vector a +stoVector = coerce mtoVector + +sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a +sfromListOuter sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy $ mfromListOuter (coerce l)) + +sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a +sfromList1 sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1 + +sfromList1Prim :: (PrimElt a, Elt a) => SNat n -> [a] -> Shaped '[n] a +sfromList1Prim sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1Prim + +stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a] +stoListOuter (Shaped arr) = coerce (mtoListOuter arr) + +stoList1 :: Elt a => Shaped '[n] a -> [a] +stoList1 = map sunScalar . stoListOuter + +sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a +sfromListPrim sn l + | Refl <- lemAppNil @'[Just n] + = let ssh = SUnknown () :!% ZKX + xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l) + in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr + +sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a +sfromListPrimLinear sh l = + let M_Primitive _ xarr = toPrimitive (mfromListPrim l) + in Shaped $ fromPrimitive $ M_Primitive (shCvtSX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtSX sh) xarr) + +sunScalar :: Elt a => Shaped '[] a -> a +sunScalar arr = sindex arr ZIS + +srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) + => ShS sh -> ShS sh2 + -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b)) + -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b) +srerankP sh sh2 f sarr@(Shaped arr) + | Refl <- lemMapJustApp sh (Proxy @sh1) + , Refl <- lemMapJustApp sh (Proxy @sh2) + = Shaped (mrerankP (ssxFromShape (shxTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (ssxFromShape (shCvtSX sh)))) + (shCvtSX sh2) + (\a -> let Shaped r = f (Shaped a) in r) + arr) + +srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) + => ShS sh -> ShS sh2 + -> (Shaped sh1 a -> Shaped sh2 b) + -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b +srerank sh sh2 f (stoPrimitive -> arr) = + sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr + +sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a +sreplicate sh (Shaped arr) + | Refl <- lemMapJustApp sh (Proxy @sh') + = Shaped (mreplicate (shCvtSX sh) arr) + +sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) +sreplicateScalP sh x = Shaped (mreplicateScalP (shCvtSX sh) x) + +sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a +sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x) + +sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a +sslice i n@SNat arr = + let _ :$$ sh = sshape arr + in slift (n :$$ sh) (\_ -> X.slice i n) arr + +srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a +srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr + +sreshape :: Elt a => ShS sh' -> Shaped sh a -> Shaped sh' a +sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr) + +siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a +siota sn = Shaped (miota sn) + +sasXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a) +sasXArrayPrimP (Shaped arr) = first shCvtXS' (masXArrayPrimP arr) + +sasXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a) +sasXArrayPrim (Shaped arr) = first shCvtXS' (masXArrayPrim arr) + +sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a) +sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShape (shCvtSX sh)) arr) + +sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a +sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShape (shCvtSX sh)) arr) + +sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a +sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr) + +stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a) +stoPrimitive (Shaped arr) = Shaped (toPrimitive arr) + +mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') + => Mixed sh a -> ShS sh' -> Shaped sh' a +mcastToShaped arr targetsh + | Refl <- lemAppNil @sh + , Refl <- lemAppNil @(MapJust sh') + , Refl <- lemRankMapJust targetsh + = Shaped (mcast (ssxFromShape (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr) diff --git a/src/Data/Array/Nested/Lemmas.hs b/src/Data/Array/Nested/Lemmas.hs deleted file mode 100644 index c4fe066..0000000 --- a/src/Data/Array/Nested/Lemmas.hs +++ /dev/null @@ -1,59 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -module Data.Array.Nested.Lemmas where - -import Data.Proxy -import Data.Type.Equality -import GHC.TypeLits - -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape -import Data.Array.Mixed.Types -import Data.Array.Nested.Shape - - -lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh -lemRankMapJust ZSS = Refl -lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl - -lemMapJustApp :: ShS sh1 -> Proxy sh2 - -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 -lemMapJustApp ZSS _ = Refl -lemMapJustApp (_ :$$ sh) p | Refl <- lemMapJustApp sh p = Refl - -lemMapJustTakeLen :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh) -lemMapJustTakeLen PNil _ = Refl -lemMapJustTakeLen (_ `PCons` is) (_ :$$ sh) | Refl <- lemMapJustTakeLen is sh = Refl -lemMapJustTakeLen (_ `PCons` _) ZSS = error "TakeLen of empty" - -lemMapJustDropLen :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh) -lemMapJustDropLen PNil _ = Refl -lemMapJustDropLen (_ `PCons` is) (_ :$$ sh) | Refl <- lemMapJustDropLen is sh = Refl -lemMapJustDropLen (_ `PCons` _) ZSS = error "DropLen of empty" - -lemMapJustIndex :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh) -lemMapJustIndex SZ (_ :$$ _) = Refl -lemMapJustIndex (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) - | Refl <- lemMapJustIndex i sh - , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) - , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = Refl -lemMapJustIndex _ ZSS = error "Index of empty" - -lemMapJustPermute :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh) -lemMapJustPermute PNil _ = Refl -lemMapJustPermute (i `PCons` is) sh - | Refl <- lemMapJustPermute is sh - , Refl <- lemMapJustIndex i sh - = Refl - -lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh) -lemKnownMapJust _ = lemKnownShX (go (knownShS @sh)) - where - go :: ShS sh' -> StaticShX (MapJust sh') - go ZSS = ZKX - go (n :$$ sh) = SKnown n :!% go sh diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs deleted file mode 100644 index 84e16b3..0000000 --- a/src/Data/Array/Nested/Mixed.hs +++ /dev/null @@ -1,741 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DefaultSignatures #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE InstanceSigs #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} -module Data.Array.Nested.Mixed where - -import Control.DeepSeq (NFData) -import Control.Monad (forM_, when) -import Control.Monad.ST -import Data.Array.RankedS qualified as S -import Data.Coerce -import Data.Foldable (toList) -import Data.Int -import Data.Kind (Type) -import Data.List.NonEmpty (NonEmpty(..)) -import Data.Proxy -import Data.Type.Equality -import Data.Vector.Storable qualified as VS -import Data.Vector.Storable.Mutable qualified as VSM -import Foreign.C.Types (CInt) -import Foreign.Storable (Storable) -import GHC.Float qualified (log1p, expm1, log1pexp, log1mexp) -import GHC.Generics (Generic) -import GHC.TypeLits - -import Data.Array.Mixed (XArray(..)) -import Data.Array.Mixed qualified as X -import Data.Array.Mixed.Internal.Arith -import Data.Array.Mixed.Shape -import Data.Array.Mixed.Types -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Lemmas - - --- Invariant in the API --- ==================== --- --- In the underlying XArray, there is some shape for elements of an empty --- array. For example, for this array: --- --- arr :: Ranked I3 (Ranked I2 Int, Ranked I1 Float) --- rshape arr == 0 :.: 0 :.: 0 :.: ZIR --- --- the two underlying XArrays have a shape, and those shapes might be anything. --- The invariant is that these element shapes are unobservable in the API. --- (This is possible because you ought to not be able to get to such an element --- without indexing out of bounds.) --- --- Note, though, that the converse situation may arise: the outer array might --- be nonempty but then the inner arrays might. This is fine, an invariant only --- applies if the _outer_ array is empty. --- --- TODO: can we enforce that the elements of an empty (nested) array have --- all-zero shape? --- -> no, because mlift and also any kind of internals probing from outsiders - - --- Primitive element types --- ======================= --- --- There are a few primitive element types; arrays containing elements of such --- type are a newtype over an XArray, which it itself a newtype over a Vector. --- Unfortunately, the setup of the library requires us to list these primitive --- element types multiple times; to aid in extending the list, all these lists --- have been marked with [PRIMITIVE ELEMENT TYPES LIST]. - - --- | Wrapper type used as a tag to attach instances on. The instances on arrays --- of @'Primitive' a@ are more polymorphic than the direct instances for arrays --- of scalars; this means that if @orthotope@ supports an element type @T@ that --- this library does not (directly), it may just work if you use an array of --- @'Primitive' T@ instead. -newtype Primitive a = Primitive a - --- | Element types that are primitive; arrays of these types are just a newtype --- wrapper over an array. -class Storable a => PrimElt a where - fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a - toPrimitive :: Mixed sh a -> Mixed sh (Primitive a) - - default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a - fromPrimitive = coerce - - default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a) - toPrimitive = coerce - --- [PRIMITIVE ELEMENT TYPES LIST] -instance PrimElt Int -instance PrimElt Int64 -instance PrimElt Int32 -instance PrimElt CInt -instance PrimElt Float -instance PrimElt Double -instance PrimElt () - - --- | Mixed arrays: some dimensions are size-typed, some are not. Distributes --- over product-typed elements using a data family so that the full array is --- always in struct-of-arrays format. --- --- Built on top of 'XArray' which is built on top of @orthotope@, meaning that --- dimension permutations (e.g. 'mtranspose') are typically free. --- --- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type --- class. -type Mixed :: [Maybe Nat] -> Type -> Type -data family Mixed sh a --- NOTE: When opening up the Mixed abstraction, you might see dimension sizes --- that you're not supposed to see. In particular, you might see (nonempty) --- sizes of the elements of an empty array, which is information that should --- ostensibly not exist; the full array is still empty. - -data instance Mixed sh (Primitive a) = M_Primitive !(IShX sh) !(XArray sh a) - deriving (Show, Eq, Generic) - --- | Only on scalars, because lexicographical ordering is strange on multi-dimensional arrays. -deriving instance (Ord a, Storable a) => Ord (Mixed '[] (Primitive a)) - -instance NFData a => NFData (Mixed sh (Primitive a)) - --- [PRIMITIVE ELEMENT TYPES LIST] -newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Show, Eq, Generic) -newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Show, Eq, Generic) -newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Show, Eq, Generic) -newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Show, Eq, Generic) -newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Show, Eq, Generic) -newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Show, Eq, Generic) -newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Show, Eq, Generic) -- no content, orthotope optimises this (via Vector) --- etc. - --- [PRIMITIVE ELEMENT TYPES LIST] -deriving instance Ord (Mixed '[] Int) ; instance NFData (Mixed sh Int) -deriving instance Ord (Mixed '[] Int64) ; instance NFData (Mixed sh Int64) -deriving instance Ord (Mixed '[] Int32) ; instance NFData (Mixed sh Int32) -deriving instance Ord (Mixed '[] CInt) ; instance NFData (Mixed sh CInt) -deriving instance Ord (Mixed '[] Float) ; instance NFData (Mixed sh Float) -deriving instance Ord (Mixed '[] Double) ; instance NFData (Mixed sh Double) -deriving instance Ord (Mixed '[] ()) ; instance NFData (Mixed sh ()) - -data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) deriving (Generic) -deriving instance (Show (Mixed sh a), Show (Mixed sh b)) => Show (Mixed sh (a, b)) -instance (NFData (Mixed sh a), NFData (Mixed sh b)) => NFData (Mixed sh (a, b)) --- etc., larger tuples (perhaps use generics to allow arbitrary product types) - -data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(IShX sh1) !(Mixed (sh1 ++ sh2) a) deriving (Generic) -deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a)) -instance NFData (Mixed (sh1 ++ sh2) a) => NFData (Mixed sh1 (Mixed sh2 a)) - - --- | Internal helper data family mirroring 'Mixed' that consists of mutable --- vectors instead of 'XArray's. -type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type -data family MixedVecs s sh a - -newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a) - --- [PRIMITIVE ELEMENT TYPES LIST] -newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int) -newtype instance MixedVecs s sh Int64 = MV_Int64 (VS.MVector s Int64) -newtype instance MixedVecs s sh Int32 = MV_Int32 (VS.MVector s Int32) -newtype instance MixedVecs s sh CInt = MV_CInt (VS.MVector s CInt) -newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double) -newtype instance MixedVecs s sh Float = MV_Float (VS.MVector s Float) -newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ()) -- no content, MVector optimises this --- etc. - -data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b) --- etc. - -data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a) - - -mliftNumElt1 :: PrimElt a => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a) -> Mixed sh a -> Mixed sh a -mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr)) - -mliftNumElt2 :: PrimElt a - => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a -> S.Array (Rank sh) a) - -> Mixed sh a -> Mixed sh a -> Mixed sh a -mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_Primitive sh2 (XArray arr2)) - | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (shxRank sh1) arr1 arr2)) - | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2 - -instance (NumElt a, PrimElt a) => Num (Mixed sh a) where - (+) = mliftNumElt2 numEltAdd - (-) = mliftNumElt2 numEltSub - (*) = mliftNumElt2 numEltMul - negate = mliftNumElt1 numEltNeg - abs = mliftNumElt1 numEltAbs - signum = mliftNumElt1 numEltSignum - fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate" - -instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Mixed sh a) where - fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate" - recip = mliftNumElt1 floatEltRecip - (/) = mliftNumElt2 floatEltDiv - -instance (FloatElt a, NumElt a, PrimElt a) => Floating (Mixed sh a) where - pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate" - exp = mliftNumElt1 floatEltExp - log = mliftNumElt1 floatEltLog - sqrt = mliftNumElt1 floatEltSqrt - - (**) = mliftNumElt2 floatEltPow - logBase = mliftNumElt2 floatEltLogbase - - sin = mliftNumElt1 floatEltSin - cos = mliftNumElt1 floatEltCos - tan = mliftNumElt1 floatEltTan - asin = mliftNumElt1 floatEltAsin - acos = mliftNumElt1 floatEltAcos - atan = mliftNumElt1 floatEltAtan - sinh = mliftNumElt1 floatEltSinh - cosh = mliftNumElt1 floatEltCosh - tanh = mliftNumElt1 floatEltTanh - asinh = mliftNumElt1 floatEltAsinh - acosh = mliftNumElt1 floatEltAcosh - atanh = mliftNumElt1 floatEltAtanh - log1p = mliftNumElt1 floatEltLog1p - expm1 = mliftNumElt1 floatEltExpm1 - log1pexp = mliftNumElt1 floatEltLog1pexp - log1mexp = mliftNumElt1 floatEltLog1mexp - - --- | Allowable element types in a mixed array, and by extension in a 'Ranked' or --- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive' --- a@; see the documentation for 'Primitive' for more details. -class Elt a where - -- ====== PUBLIC METHODS ====== -- - - mshape :: Mixed sh a -> IShX sh - mindex :: Mixed sh a -> IIxX sh -> a - mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a - mscalar :: a -> Mixed '[] a - - -- | All arrays in the list, even subarrays inside @a@, must have the same - -- shape; if they do not, a runtime error will be thrown. See the - -- documentation of 'mgenerate' for more information about this restriction. - -- Furthermore, the length of the list must correspond with @n@: if @n@ is - -- @Just m@ and @m@ does not equal the length of the list, a runtime error is - -- thrown. - -- - -- Consider also 'mfromListPrim', which can avoid intermediate arrays. - mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a - - mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a] - - -- | Note: this library makes no particular guarantees about the shapes of - -- arrays "inside" an empty array. With 'mlift' and 'mlift2' you can see the - -- full 'XArray' and as such you can distinguish different empty arrays by - -- the "shapes" of their elements. This information is meaningless, so you - -- should not use it. - mlift :: forall sh1 sh2. - StaticShX sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) - -> Mixed sh1 a -> Mixed sh2 a - - -- | See the documentation for 'mlift'. - mlift2 :: forall sh1 sh2 sh3. - StaticShX sh3 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) - -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a - - mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 - => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a - - mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) - => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a - - -- ====== PRIVATE METHODS ====== -- - - -- | Tree giving the shape of every array component. - type ShapeTree a - - mshapeTree :: a -> ShapeTree a - - mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool - - mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool - - mshowShapeTree :: Proxy a -> ShapeTree a -> String - - -- | Given the shape of this array, an index and a value, write the value at - -- that index in the vectors. - mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () - - -- | Given the shape of this array, an index and a value, write the value at - -- that index in the vectors. - mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () - - -- | Given the shape of this array, finalise the vectors into 'XArray's. - mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) - - --- | Element types for which we have evidence of the (static part of the) shape --- in a type class constraint. Compare the instance contexts of the instances --- of this class with those of 'Elt': some instances have an additional --- "known-shape" constraint. --- --- This class is (currently) only required for 'mgenerate' / 'rgenerate' / --- 'sgenerate'. -class Elt a => KnownElt a where - -- | Create an empty array. The given shape must have size zero; this may or may not be checked. - memptyArray :: IShX sh -> Mixed sh a - - -- | Create uninitialised vectors for this array type, given the shape of - -- this vector and an example for the contents. - mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a) - - mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a) - - --- Arrays of scalars are basically just arrays of scalars. -instance Storable a => Elt (Primitive a) where - mshape (M_Primitive sh _) = sh - mindex (M_Primitive _ a) i = Primitive (X.index a i) - mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i) - mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) - mfromListOuter l@(arr1 :| _) = - let sh = SUnknown (length l) :$% mshape arr1 - in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l))) - mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr) - - mlift :: forall sh1 sh2. - StaticShX sh2 - -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) - -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) - mlift ssh2 f (M_Primitive _ a) - | Refl <- lemAppNil @sh1 - , Refl <- lemAppNil @sh2 - , let result = f ZKX a - = M_Primitive (X.shape ssh2 result) result - - mlift2 :: forall sh1 sh2 sh3. - StaticShX sh3 - -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a) - -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a) - mlift2 ssh3 f (M_Primitive _ a) (M_Primitive _ b) - | Refl <- lemAppNil @sh1 - , Refl <- lemAppNil @sh2 - , Refl <- lemAppNil @sh3 - , let result = f ZKX a b - = M_Primitive (X.shape ssh3 result) result - - mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 - => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) - mcast ssh1 sh2 _ (M_Primitive sh1' arr) = - let (_, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' - in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr) - - mtranspose perm (M_Primitive sh arr) = - M_Primitive (shxPermutePrefix perm sh) - (X.transpose (ssxFromShape sh) perm arr) - - type ShapeTree (Primitive a) = () - mshapeTree _ = () - mshapeTreeEq _ () () = True - mshapeTreeEmpty _ () = False - mshowShapeTree _ () = "()" - mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x - - -- TODO: this use of toVector is suboptimal - mvecsWritePartial - :: forall sh' sh s. - IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () - mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do - let arrsh = X.shape (ssxFromShape sh') arr - offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) - VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) - - mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v - --- [PRIMITIVE ELEMENT TYPES LIST] -deriving via Primitive Int instance Elt Int -deriving via Primitive Int64 instance Elt Int64 -deriving via Primitive Int32 instance Elt Int32 -deriving via Primitive CInt instance Elt CInt -deriving via Primitive Double instance Elt Double -deriving via Primitive Float instance Elt Float -deriving via Primitive () instance Elt () - -instance Storable a => KnownElt (Primitive a) where - memptyArray sh = M_Primitive sh (X.empty sh) - mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh) - mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 - --- [PRIMITIVE ELEMENT TYPES LIST] -deriving via Primitive Int instance KnownElt Int -deriving via Primitive Int64 instance KnownElt Int64 -deriving via Primitive Int32 instance KnownElt Int32 -deriving via Primitive CInt instance KnownElt CInt -deriving via Primitive Double instance KnownElt Double -deriving via Primitive Float instance KnownElt Float -deriving via Primitive () instance KnownElt () - --- Arrays of pairs are pairs of arrays. -instance (Elt a, Elt b) => Elt (a, b) where - mshape (M_Tup2 a _) = mshape a - mindex (M_Tup2 a b) i = (mindex a i, mindex b i) - mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) - mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) - mfromListOuter l = - M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l)) - (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l)) - mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b) - mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b) - mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y) - - mcast ssh1 sh2 psh' (M_Tup2 a b) = - M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b) - - mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b) - - type ShapeTree (a, b) = (ShapeTree a, ShapeTree b) - mshapeTree (x, y) = (mshapeTree x, mshapeTree y) - mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' - mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 - mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" - mvecsWrite sh i (x, y) (MV_Tup2 a b) = do - mvecsWrite sh i x a - mvecsWrite sh i y b - mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do - mvecsWritePartial sh i x a - mvecsWritePartial sh i y b - mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b - -instance (KnownElt a, KnownElt b) => KnownElt (a, b) where - memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh) - mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y - mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b) - --- Arrays of arrays are just arrays, but with more dimensions. -instance Elt a => Elt (Mixed sh' a) where - -- TODO: this is quadratic in the nesting depth because it repeatedly - -- truncates the shape vector to one a little shorter. Fix with a - -- moverlongShape method, a prefix of which is mshape. - mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh - mshape (M_Nest sh arr) - = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr)) - - mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a - mindex (M_Nest _ arr) i = mindexPartial arr i - - mindexPartial :: forall sh1 sh2. - Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) - mindexPartial (M_Nest sh arr) i - | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) - - mscalar = M_Nest ZSX - - mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) - mfromListOuter l@(arr :| _) = - M_Nest (SUnknown (length l) :$% mshape arr) - (mfromListOuter ((\(M_Nest _ a) -> a) <$> l)) - - mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr) - - mlift :: forall sh1 sh2. - StaticShX sh2 - -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) - -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) - mlift ssh2 f (M_Nest sh1 arr) = - let result = mlift (ssxAppend ssh2 ssh') f' arr - (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result) - in M_Nest sh2 result - where - ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr))) - - f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b - f' sshT - | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) - , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) - = f (ssxAppend ssh' sshT) - - mlift2 :: forall sh1 sh2 sh3. - StaticShX sh3 - -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) - -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a) - mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) = - let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2 - (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result) - in M_Nest sh3 result - where - ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1))) - - f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b - f' sshT - | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) - , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) - , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) - = f (ssxAppend ssh' sshT) - - mcast :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2 - => StaticShX sh1 -> IShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a) - mcast ssh1 sh2 _ (M_Nest sh1T arr) - | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') - , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') - = let (_, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T - in M_Nest (shxAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr) - - mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) - => Perm is -> Mixed sh (Mixed sh' a) - -> Mixed (PermutePrefix is sh) (Mixed sh' a) - mtranspose perm (M_Nest sh arr) - | let sh' = shxDropSh @sh @sh' (mshape arr) sh - , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh') - , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh')) - , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') - , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') - , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') - = M_Nest (shxPermutePrefix perm sh) - (mtranspose perm arr) - - type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a) - - mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) - mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr))))) - - mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - - mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t - - mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - - mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs - - mvecsWritePartial :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) - -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) - -> ST s () - mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) - | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs - - mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs - -instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where - memptyArray sh = M_Nest sh (memptyArray (shxAppend sh (shxCompleteZeros (knownShX @sh')))) - - mvecsUnsafeNew sh example - | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) - | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh'))) - where - sh' = mshape example - - mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) - - --- | Create an array given a size and a function that computes the element at a --- given index. --- --- __WARNING__: It is required that every @a@ returned by the argument to --- 'mgenerate' has the same shape. For example, the following will throw a --- runtime error: --- --- > foo :: Mixed [Nothing] (Mixed [Nothing] Double) --- > foo = mgenerate (10 :.: ZIR) $ \(i :.: ZIR) -> --- > mgenerate (i :.: ZIR) $ \(j :.: ZIR) -> --- > ... --- --- because the size of the inner 'mgenerate' is not always the same (it depends --- on @i@). Nested arrays in @ox-arrays@ are always stored fully flattened, so --- the entire hierarchy (after distributing out tuples) must be a rectangular --- array. The type of 'mgenerate' allows this requirement to be broken very --- easily, hence the runtime check. -mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a -mgenerate sh f = case shxEnum sh of - [] -> memptyArray sh - firstidx : restidxs -> - let firstelem = f (ixxZero' sh) - shapetree = mshapeTree firstelem - in if mshapeTreeEmpty (Proxy @a) shapetree - then memptyArray sh - else runST $ do - vecs <- mvecsUnsafeNew sh firstelem - mvecsWrite sh firstidx firstelem vecs - -- TODO: This is likely fine if @a@ is big, but if @a@ is a - -- scalar this array copying inefficient. Should improve this. - forM_ restidxs $ \idx -> do - let val = f idx - when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ - error "Data.Array.Nested mgenerate: generated values do not have equal shapes" - mvecsWrite sh idx val vecs - mvecsFreeze sh vecs - -msumOuter1P :: forall sh n a. (Storable a, NumElt a) - => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) -msumOuter1P (M_Primitive (n :$% sh) arr) = - let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX - in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr) - -msumOuter1 :: forall sh n a. (NumElt a, PrimElt a) - => Mixed (n : sh) a -> Mixed sh a -msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive - -mappend :: forall n m sh a. Elt a - => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a -mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 - where - sn :$% sh = mshape arr1 - sm :$% _ = mshape arr2 - ssh = ssxFromShape sh - snm :: SMayNat () SNat (AddMaybe n m) - snm = case (sn, sm) of - (SUnknown{}, _) -> SUnknown () - (SKnown{}, SUnknown{}) -> SUnknown () - (SKnown n, SKnown m) -> SKnown (snatPlus n m) - - f :: forall sh' b. Storable b - => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b - f ssh' = X.append (ssxAppend ssh ssh') - -mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) -mfromVectorP sh v = M_Primitive sh (X.fromVector sh v) - -mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a -mfromVector sh v = fromPrimitive (mfromVectorP sh v) - -mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a -mtoVectorP (M_Primitive _ v) = X.toVector v - -mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a -mtoVector arr = mtoVectorP (toPrimitive arr) - -mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a -mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise? - -mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a -mfromList1Prim l = - let ssh = SUnknown () :!% ZKX - xarr = X.fromList1 ssh l - in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr - -mtoList1 :: Elt a => Mixed '[n] a -> [a] -mtoList1 = map munScalar . mtoListOuter - -mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a -mfromListPrim l = - let ssh = SUnknown () :!% ZKX - xarr = X.fromList1 ssh l - in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr - -mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a -mfromListPrimLinear sh l = - let M_Primitive _ xarr = toPrimitive (mfromListPrim l) - in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr) - -munScalar :: Elt a => Mixed '[] a -> a -munScalar arr = mindex arr ZIX - -mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) - => StaticShX sh -> IShX sh2 - -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) - -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b) -mrerankP ssh sh2 f (M_Primitive sh arr) = - let sh1 = shxDropSSX sh ssh - in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2) - (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2) - (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) - arr) - --- | See the caveats at @X.rerank@. -mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) - => StaticShX sh -> IShX sh2 - -> (Mixed sh1 a -> Mixed sh2 b) - -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b -mrerank ssh sh2 f (toPrimitive -> arr) = - fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr - -mreplicate :: forall sh sh' a. Elt a - => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a -mreplicate sh arr = - let ssh' = ssxFromShape (mshape arr) - in mlift (ssxAppend (ssxFromShape sh) ssh') - (\(sshT :: StaticShX shT) -> - case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of - Refl -> X.replicate sh (ssxAppend ssh' sshT)) - arr - -mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) -mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x) - -mreplicateScal :: forall sh a. PrimElt a - => IShX sh -> a -> Mixed sh a -mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x) - -mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a -mslice i n arr = - let _ :$% sh = mshape arr - in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr - -msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a -msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr - -mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a -mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr - -mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a -mreshape sh' arr = - mlift (ssxFromShape sh') - (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh') - arr - -miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a -miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn) - -masXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a) -masXArrayPrimP (M_Primitive sh arr) = (sh, arr) - -masXArrayPrim :: PrimElt a => Mixed sh a -> (IShX sh, XArray sh a) -masXArrayPrim = masXArrayPrimP . toPrimitive - -mfromXArrayPrimP :: StaticShX sh -> XArray sh a -> Mixed sh (Primitive a) -mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr - -mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a -mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP - -mliftPrim :: PrimElt a - => (a -> a) - -> Mixed sh a -> Mixed sh a -mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr)) - -mliftPrim2 :: PrimElt a - => (a -> a -> a) - -> Mixed sh a -> Mixed sh a -> Mixed sh a -mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) = - fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2)) diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs deleted file mode 100644 index c2f9405..0000000 --- a/src/Data/Array/Nested/Ranked.hs +++ /dev/null @@ -1,446 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE InstanceSigs #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Nested.Ranked where - -import Prelude hiding (mappend) - -import Control.DeepSeq (NFData) -import Control.Monad.ST -import Data.Array.RankedS qualified as S -import Data.Bifunctor (first) -import Data.Coerce (coerce) -import Data.Kind (Type) -import Data.List.NonEmpty (NonEmpty) -import Data.Proxy -import Data.Type.Equality -import Data.Vector.Storable qualified as VS -import Foreign.Storable (Storable) -import GHC.Float qualified (log1p, expm1, log1pexp, log1mexp) -import GHC.TypeLits -import GHC.TypeNats qualified as TN - -import Data.Array.Mixed (XArray(..)) -import Data.Array.Mixed qualified as X -import Data.Array.Mixed.Internal.Arith -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Shape -import Data.Array.Mixed.Types -import Data.Array.Nested.Mixed -import Data.Array.Nested.Shape - - --- | A rank-typed array: the number of dimensions of the array (its /rank/) is --- represented on the type level as a 'Nat'. --- --- Valid elements of a ranked arrays are described by the 'Elt' type class. --- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are --- supported (and are represented as a single, flattened, struct-of-arrays --- array internally). --- --- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's. -type Ranked :: Nat -> Type -> Type -newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) -deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a) -deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a) -deriving instance Ord (Mixed '[] a) => Ord (Ranked 0 a) -deriving instance NFData (Mixed (Replicate n Nothing) a) => NFData (Ranked n a) - --- just unwrap the newtype and defer to the general instance for nested arrays -newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a)) -deriving instance Show (Mixed sh (Mixed (Replicate n Nothing) a)) => Show (Mixed sh (Ranked n a)) - -newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a)) - --- 'Ranked' and 'Shaped' can already be used at the top level of an array nest; --- these instances allow them to also be used as elements of arrays, thus --- making them first-class in the API. -instance Elt a => Elt (Ranked n a) where - mshape (M_Ranked arr) = mshape arr - mindex (M_Ranked arr) i = Ranked (mindex arr i) - - mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a) - mindexPartial (M_Ranked arr) i = - coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ - mindexPartial arr i - - mscalar (Ranked x) = M_Ranked (M_Nest ZSX x) - - mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a) - mfromListOuter l = M_Ranked (mfromListOuter (coerce l)) - - mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)] - mtoListOuter (M_Ranked arr) = - coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr) - - mlift :: forall sh1 sh2. - StaticShX sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) - -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) - mlift ssh2 f (M_Ranked arr) = - coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ - mlift ssh2 f arr - - mlift2 :: forall sh1 sh2 sh3. - StaticShX sh3 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) - -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a) - mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) = - coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ - mlift2 ssh3 f arr1 arr2 - - mcast ssh1 sh2 psh' (M_Ranked arr) = M_Ranked (mcast ssh1 sh2 psh' arr) - - mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr) - - type ShapeTree (Ranked n a) = (IShR n, ShapeTree a) - - mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr) - - mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - - mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t - - mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - - mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () - mvecsWrite sh idx (Ranked arr) vecs = - mvecsWrite sh idx arr - (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) - vecs) - - mvecsWritePartial :: forall sh sh' s. - IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) - -> MixedVecs s (sh ++ sh') (Ranked n a) - -> ST s () - mvecsWritePartial sh idx arr vecs = - mvecsWritePartial sh idx - (coerce @(Mixed sh' (Ranked n a)) - @(Mixed sh' (Mixed (Replicate n Nothing) a)) - arr) - (coerce @(MixedVecs s (sh ++ sh') (Ranked n a)) - @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) - vecs) - - mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) - mvecsFreeze sh vecs = - coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) - @(Mixed sh (Ranked n a)) - <$> mvecsFreeze sh - (coerce @(MixedVecs s sh (Ranked n a)) - @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) - vecs) - -instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where - memptyArray :: forall sh. IShX sh -> Mixed sh (Ranked n a) - memptyArray i - | Dict <- lemKnownReplicate (SNat @n) - = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ - memptyArray i - - mvecsUnsafeNew idx (Ranked arr) - | Dict <- lemKnownReplicate (SNat @n) - = MV_Ranked <$> mvecsUnsafeNew idx arr - - mvecsNewEmpty _ - | Dict <- lemKnownReplicate (SNat @n) - = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) - - -arithPromoteRanked :: forall n a. PrimElt a - => (forall sh. Mixed sh a -> Mixed sh a) - -> Ranked n a -> Ranked n a -arithPromoteRanked = coerce - -arithPromoteRanked2 :: forall n a. PrimElt a - => (forall sh. Mixed sh a -> Mixed sh a -> Mixed sh a) - -> Ranked n a -> Ranked n a -> Ranked n a -arithPromoteRanked2 = coerce - -instance (NumElt a, PrimElt a) => Num (Ranked n a) where - (+) = arithPromoteRanked2 (+) - (-) = arithPromoteRanked2 (-) - (*) = arithPromoteRanked2 (*) - negate = arithPromoteRanked negate - abs = arithPromoteRanked abs - signum = arithPromoteRanked signum - fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicateScal" - -instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Ranked n a) where - fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicateScal" - recip = arithPromoteRanked recip - (/) = arithPromoteRanked2 (/) - -instance (FloatElt a, NumElt a, PrimElt a) => Floating (Ranked n a) where - pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicateScal" - exp = arithPromoteRanked exp - log = arithPromoteRanked log - sqrt = arithPromoteRanked sqrt - (**) = arithPromoteRanked2 (**) - logBase = arithPromoteRanked2 logBase - sin = arithPromoteRanked sin - cos = arithPromoteRanked cos - tan = arithPromoteRanked tan - asin = arithPromoteRanked asin - acos = arithPromoteRanked acos - atan = arithPromoteRanked atan - sinh = arithPromoteRanked sinh - cosh = arithPromoteRanked cosh - tanh = arithPromoteRanked tanh - asinh = arithPromoteRanked asinh - acosh = arithPromoteRanked acosh - atanh = arithPromoteRanked atanh - log1p = arithPromoteRanked GHC.Float.log1p - expm1 = arithPromoteRanked GHC.Float.expm1 - log1pexp = arithPromoteRanked GHC.Float.log1pexp - log1mexp = arithPromoteRanked GHC.Float.log1mexp - - -rshape :: forall n a. Elt a => Ranked n a -> IShR n -rshape (Ranked arr) = shCvtXR' (mshape arr) - -rindex :: Elt a => Ranked n a -> IIxR n -> a -rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) - -rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a -rindexPartial (Ranked arr) idx = - Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) - (castWith (subst2 (lemReplicatePlusApp (ixrToSNat idx) (Proxy @m) (Proxy @Nothing))) arr) - (ixCvtRX idx)) - --- | __WARNING__: All values returned from the function must have equal shape. --- See the documentation of 'mgenerate' for more details. -rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a -rgenerate sh f - | sn@SNat <- shrToSNat sh - , Dict <- lemKnownReplicate sn - , Refl <- lemRankReplicate sn - = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR)) - --- | See the documentation of 'mlift'. -rlift :: forall n1 n2 a. Elt a - => SNat n2 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) - -> Ranked n1 a -> Ranked n2 a -rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr) - --- | See the documentation of 'mlift2'. -rlift2 :: forall n1 n2 n3 a. Elt a - => SNat n3 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b) - -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a -rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2) - -rsumOuter1P :: forall n a. - (Storable a, NumElt a) - => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) -rsumOuter1P (Ranked arr) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = Ranked (msumOuter1P arr) - -rsumOuter1 :: forall n a. (NumElt a, PrimElt a) - => Ranked (n + 1) a -> Ranked n a -rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive - -rtranspose :: forall n a. Elt a => [Int] -> Ranked n a -> Ranked n a -rtranspose perm arr - | sn@SNat <- shrToSNat (rshape arr) - , Dict <- lemKnownReplicate sn - , length perm <= fromIntegral (natVal (Proxy @n)) - = rlift sn - (\ssh' -> X.transposeUntyped (natSing @n) ssh' perm) - arr - | otherwise - = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array" - -rappend :: forall n a. Elt a - => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a -rappend arr1 arr2 - | sn@SNat <- shrToSNat (rshape arr1) - , Dict <- lemKnownReplicate sn - , Refl <- lemReplicateSucc @(Nothing @Nat) @n - = coerce (mappend @Nothing @Nothing @(Replicate n Nothing)) - arr1 arr2 - -rscalar :: Elt a => a -> Ranked 0 a -rscalar x = Ranked (mscalar x) - -rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a) -rfromVectorP sh v - | Dict <- lemKnownReplicate (shrToSNat sh) - = Ranked (mfromVectorP (shCvtRX sh) v) - -rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a -rfromVector sh v = rfromPrimitive (rfromVectorP sh v) - -rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a -rtoVectorP = coerce mtoVectorP - -rtoVector :: PrimElt a => Ranked n a -> VS.Vector a -rtoVector = coerce mtoVector - -rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a -rfromListOuter l - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) - -rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a -rfromList1 l = Ranked (mfromList1 l) - -rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a -rfromList1Prim l = Ranked (mfromList1Prim l) - -rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] -rtoListOuter (Ranked arr) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr) - -rtoList1 :: Elt a => Ranked 1 a -> [a] -rtoList1 = map runScalar . rtoListOuter - -rfromListPrim :: PrimElt a => [a] -> Ranked 1 a -rfromListPrim l = - let ssh = SUnknown () :!% ZKX - xarr = X.fromList1 ssh l - in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr - -rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a -rfromListPrimLinear sh l = - let M_Primitive _ xarr = toPrimitive (mfromListPrim l) - in Ranked $ fromPrimitive $ M_Primitive (shCvtRX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtRX sh) xarr) - -rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a -rfromOrthotope sn arr - | Refl <- lemRankReplicate sn - = let xarr = XArray arr - in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr)) - -runScalar :: Elt a => Ranked 0 a -> a -runScalar arr = rindex arr ZIR - -rrerankP :: forall n1 n2 n a b. (Storable a, Storable b) - => SNat n -> IShR n2 - -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b)) - -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b) -rrerankP sn sh2 f (Ranked arr) - | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat)) - , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat)) - = Ranked (mrerankP (ssxFromSNat sn) (shCvtRX sh2) - (\a -> let Ranked r = f (Ranked a) in r) - arr) - --- | If there is a zero-sized dimension in the @n@-prefix of the shape of the --- input array, then there is no way to deduce the full shape of the output --- array (more precisely, the @n2@ part): that could only come from calling --- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in --- this case; we choose to fill the @n2@ part of the output shape with zeros. --- --- For example, if: --- --- @ --- arr :: Ranked 5 Int -- of shape [3, 0, 4, 2, 21] --- f :: Ranked 2 Int -> Ranked 3 Float --- @ --- --- then: --- --- @ --- rrerank _ _ _ f arr :: Ranked 5 Float --- @ --- --- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the --- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended --- to return an array with shape all-0 here (it probably didn't), but there is --- no better number to put here absent a subarray of the input to pass to @f@. -rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b) - => SNat n -> IShR n2 - -> (Ranked n1 a -> Ranked n2 b) - -> Ranked (n + n1) a -> Ranked (n + n2) b -rrerank ssh sh2 f (rtoPrimitive -> arr) = - rfromPrimitive $ rrerankP ssh sh2 (rtoPrimitive . f . rfromPrimitive) arr - -rreplicate :: forall n m a. Elt a - => IShR n -> Ranked m a -> Ranked (n + m) a -rreplicate sh (Ranked arr) - | Refl <- lemReplicatePlusApp (shrToSNat sh) (Proxy @m) (Proxy @(Nothing @Nat)) - = Ranked (mreplicate (shCvtRX sh) arr) - -rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) -rreplicateScalP sh x - | Dict <- lemKnownReplicate (shrToSNat sh) - = Ranked (mreplicateScalP (shCvtRX sh) x) - -rreplicateScal :: forall n a. PrimElt a - => IShR n -> a -> Ranked n a -rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x) - -rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a -rslice i n arr - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = rlift (shrToSNat (rshape arr)) - (\_ -> X.sliceU i n) - arr - -rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a -rrev1 arr = - rlift (shrToSNat (rshape arr)) - (\(_ :: StaticShX sh') -> - case lemReplicateSucc @(Nothing @Nat) @n of - Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')) - arr - -rreshape :: forall n n' a. Elt a - => IShR n' -> Ranked n a -> Ranked n' a -rreshape sh' rarr@(Ranked arr) - | Dict <- lemKnownReplicate (shrToSNat (rshape rarr)) - , Dict <- lemKnownReplicate (shrToSNat sh') - = Ranked (mreshape (shCvtRX sh') arr) - -riota :: (Enum a, PrimElt a, Elt a) => Int -> Ranked 1 a -riota n = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota - -rasXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a) -rasXArrayPrimP (Ranked arr) = first shCvtXR' (masXArrayPrimP arr) - -rasXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a) -rasXArrayPrim (Ranked arr) = first shCvtXR' (masXArrayPrim arr) - -rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a) -rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) - -rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a -rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) - -rfromPrimitive :: PrimElt a => Ranked n (Primitive a) -> Ranked n a -rfromPrimitive (Ranked arr) = Ranked (fromPrimitive arr) - -rtoPrimitive :: PrimElt a => Ranked n a -> Ranked n (Primitive a) -rtoPrimitive (Ranked arr) = Ranked (toPrimitive arr) - -mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a -mtoRanked arr - | Refl <- lemAppNil @sh - , Refl <- lemAppNil @(Replicate (Rank sh) (Nothing @Nat)) - , Refl <- lemRankReplicate (shxRank (mshape arr)) - = Ranked (mcast (ssxFromShape (mshape arr)) (convSh (mshape arr)) (Proxy @'[]) arr) - where - convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing) - convSh ZSX = ZSX - convSh (smn :$% (sh :: IShX sh'T)) - | Refl <- lemReplicateSucc @(Nothing @Nat) @(Rank sh'T) - = SUnknown (fromSMayNat' smn) :$% convSh sh diff --git a/src/Data/Array/Nested/Shape.hs b/src/Data/Array/Nested/Shape.hs deleted file mode 100644 index 774b4bd..0000000 --- a/src/Data/Array/Nested/Shape.hs +++ /dev/null @@ -1,467 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveFoldable #-} -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE RoleAnnotations #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Nested.Shape where - -import Data.Array.Mixed.Types -import Data.Coerce (coerce) -import Data.Foldable qualified as Foldable -import Data.Functor.Const -import Data.Kind (Type, Constraint) -import Data.Monoid (Sum(..)) -import Data.Proxy -import Data.Type.Equality -import GHC.IsList (IsList) -import GHC.IsList qualified as IsList -import GHC.TypeLits -import GHC.TypeNats qualified as TN - -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape - - -type role ListR nominal representational -type ListR :: Nat -> Type -> Type -data ListR n i where - ZR :: ListR 0 i - (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i -deriving instance Eq i => Eq (ListR n i) -deriving instance Ord i => Ord (ListR n i) -deriving instance Functor (ListR n) -deriving instance Foldable (ListR n) -infixr 3 ::: - -instance Show i => Show (ListR n i) where - showsPrec _ = listrShow shows - -data UnconsListRRes i n1 = - forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i -listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1) -listrUncons (i ::: sh') = Just (UnconsListRRes sh' i) -listrUncons ZR = Nothing - -listrShow :: forall sh i. (i -> ShowS) -> ListR sh i -> ShowS -listrShow f l = showString "[" . go "" l . showString "]" - where - go :: String -> ListR sh' i -> ShowS - go _ ZR = id - go prefix (x ::: xs) = showString prefix . f x . go "," xs - -listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i -listrAppend ZR sh = sh -listrAppend (x ::: xs) sh = x ::: listrAppend xs sh - -listrFromList :: [i] -> (forall n. ListR n i -> r) -> r -listrFromList [] k = k ZR -listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l) - -listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i -listrIndex SZ (x ::: _) = x -listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs -listrIndex _ ZR = error "k + 1 <= 0" - -listrToSNat :: ListR n i -> SNat n -listrToSNat ZR = SNat -listrToSNat (_ ::: (l :: ListR n i)) | SNat <- listrToSNat l, Dict <- lemKnownNatSucc @n = SNat - -listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i -listrPermutePrefix = \perm sh -> - listrFromList perm $ \sperm -> - case (listrToSNat sperm, listrToSNat sh) of - (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of - LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post - EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post - GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" - ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" - where - listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) - listrSplitAt SZ sh = (ZR, sh) - listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh) - listrSplitAt SS{} ZR = error "m' + 1 <= 0" - - applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i - applyPermRFull _ ZR _ = ZR - applyPermRFull sm@SNat (i ::: perm) l = - TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) -> - case cmpNat (SNat @(idx + 1)) sm of - LTI -> listrIndex si l ::: applyPermRFull sm perm l - EQI -> listrIndex si l ::: applyPermRFull sm perm l - GTI -> error "listrPermutePrefix: Index in permutation out of range" - - --- | An index into a rank-typed array. -type role IxR nominal representational -type IxR :: Nat -> Type -> Type -newtype IxR n i = IxR (ListR n i) - deriving (Eq, Ord) - deriving newtype (Functor, Foldable) - -pattern ZIR :: forall n i. () => n ~ 0 => IxR n i -pattern ZIR = IxR ZR - -pattern (:.:) - :: forall {n1} {i}. - forall n. (n + 1 ~ n1) - => i -> IxR n i -> IxR n1 i -pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i)) - where i :.: IxR sh = IxR (i ::: sh) -infixr 3 :.: - -{-# COMPLETE ZIR, (:.:) #-} - -type IIxR n = IxR n Int - -instance Show i => Show (IxR n i) where - showsPrec _ (IxR l) = listrShow shows l - -ixrZero :: SNat n -> IIxR n -ixrZero SZ = ZIR -ixrZero (SS n) = 0 :.: ixrZero n - -ixCvtXR :: IIxX sh -> IIxR (Rank sh) -ixCvtXR ZIX = ZIR -ixCvtXR (n :.% idx) = n :.: ixCvtXR idx - -ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing) -ixCvtRX ZIR = ZIX -ixCvtRX (n :.: (idx :: IxR m Int)) = - castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m)) - (n :.% ixCvtRX idx) - -ixrToSNat :: IxR n i -> SNat n -ixrToSNat (IxR sh) = listrToSNat sh - -ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i -ixrPermutePrefix = coerce (listrPermutePrefix @i) - - -type role ShR nominal representational -type ShR :: Nat -> Type -> Type -newtype ShR n i = ShR (ListR n i) - deriving (Eq, Ord) - deriving newtype (Functor, Foldable) - -pattern ZSR :: forall n i. () => n ~ 0 => ShR n i -pattern ZSR = ShR ZR - -pattern (:$:) - :: forall {n1} {i}. - forall n. (n + 1 ~ n1) - => i -> ShR n i -> ShR n1 i -pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i)) - where i :$: (ShR sh) = ShR (i ::: sh) -infixr 3 :$: - -{-# COMPLETE ZSR, (:$:) #-} - -type IShR n = ShR n Int - -instance Show i => Show (ShR n i) where - showsPrec _ (ShR l) = listrShow shows l - -shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n -shCvtXR' ZSX = - castWith (subst2 (unsafeCoerceRefl :: 0 :~: n)) - ZSR -shCvtXR' (n :$% (idx :: IShX sh)) - | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) = - castWith (subst2 (lem1 @sh Refl)) - (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx)) - where - lem1 :: forall sh' n' k. - k : sh' :~: Replicate n' Nothing - -> Rank sh' + 1 :~: n' - lem1 Refl = unsafeCoerceRefl - - lem2 :: k : sh :~: Replicate n Nothing - -> sh :~: Replicate (Rank sh) Nothing - lem2 Refl = unsafeCoerceRefl - -shCvtRX :: IShR n -> IShX (Replicate n Nothing) -shCvtRX ZSR = ZSX -shCvtRX (n :$: (idx :: ShR m Int)) = - castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m)) - (SUnknown n :$% shCvtRX idx) - --- | The number of elements in an array described by this shape. -shrSize :: IShR n -> Int -shrSize ZSR = 1 -shrSize (n :$: sh) = n * shrSize sh - -shrToSNat :: ShR n i -> SNat n -shrToSNat (ShR sh) = listrToSNat sh - -shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i -shrPermutePrefix = coerce (listrPermutePrefix @i) - - --- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (ListR n i) where - type Item (ListR n i) = i - fromList = go (SNat @n) - where - go :: SNat n' -> [i] -> ListR n' i - go SZ [] = ZR - go (SS n) (i : is) = i ::: go n is - go _ _ = error "IsList(ListR): Mismatched list length" - toList = Foldable.toList - --- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (IxR n i) where - type Item (IxR n i) = i - fromList = IxR . IsList.fromList - toList = Foldable.toList - --- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (ShR n i) where - type Item (ShR n i) = i - fromList = ShR . IsList.fromList - toList = Foldable.toList - - -type role ListS nominal representational -type ListS :: [Nat] -> (Nat -> Type) -> Type -data ListS sh f where - ZS :: ListS '[] f - -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity - (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f -deriving instance (forall n. Eq (f n)) => Eq (ListS sh f) -deriving instance (forall n. Ord (f n)) => Ord (ListS sh f) -infixr 3 ::$ - -instance (forall n. Show (f n)) => Show (ListS sh f) where - showsPrec _ = listsShow shows - -data UnconsListSRes f sh1 = - forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n) -listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1) -listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x) -listsUncons ZS = Nothing - -listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g -listsFmap _ ZS = ZS -listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs - -listsFold :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m -listsFold _ ZS = mempty -listsFold f (x ::$ xs) = f x <> listsFold f xs - -listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS -listsShow f l = showString "[" . go "" l . showString "]" - where - go :: String -> ListS sh' f -> ShowS - go _ ZS = id - go prefix (x ::$ xs) = showString prefix . f x . go "," xs - -listsToList :: ListS sh (Const i) -> [i] -listsToList ZS = [] -listsToList (Const i ::$ is) = i : listsToList is - -listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f -listsAppend ZS idx' = idx' -listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx' - -listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f -listsTakeLenPerm PNil _ = ZS -listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh -listsTakeLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" - -listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f -listsDropLenPerm PNil sh = sh -listsDropLenPerm (_ `PCons` is) (_ ::$ sh) = listsDropLenPerm is sh -listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" - -listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f -listsPermute PNil _ = ZS -listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) = - case listsIndex (Proxy @is') (Proxy @sh) i sh of - (item, SNat) -> item ::$ listsPermute is sh - --- TODO: remove this SNat when the KnownNat constaint in ListS is removed -listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh)) -listsIndex _ _ SZ (n ::$ _) = (n, SNat) -listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) - | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = listsIndex p pT i sh -listsIndex _ _ _ ZS = error "Index into empty shape" - -shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh) -shsTakeLen = coerce (listsTakeLenPerm @SNat) - -shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh) -shsPermute = coerce (listsPermute @SNat) - -shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh) -shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh))) - -applyPermS :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f -applyPermS perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh) - -applyPermIxS :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i -applyPermIxS = coerce (applyPermS @(Const i)) - -applyPermShS :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh) -applyPermShS = coerce (applyPermS @SNat) - - --- | An index into a shape-typed array. --- --- For convenience, this contains regular 'Int's instead of bounded integers --- (traditionally called \"@Fin@\"). Note that because the shape of a --- shape-typed array is known statically, you can also retrieve the array shape --- from a 'KnownShape' dictionary. -type role IxS nominal representational -type IxS :: [Nat] -> Type -> Type -newtype IxS sh i = IxS (ListS sh (Const i)) - deriving (Eq, Ord) - -pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i -pattern ZIS = IxS ZS - -pattern (:.$) - :: forall {sh1} {i}. - forall n sh. (KnownNat n, n : sh ~ sh1) - => i -> IxS sh i -> IxS sh1 i -pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i))) - where i :.$ IxS shl = IxS (Const i ::$ shl) -infixr 3 :.$ - -{-# COMPLETE ZIS, (:.$) #-} - -type IIxS sh = IxS sh Int - -instance Show i => Show (IxS sh i) where - showsPrec _ (IxS l) = listsShow (\(Const i) -> shows i) l - -instance Functor (IxS sh) where - fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l) - -instance Foldable (IxS sh) where - foldMap f (IxS l) = listsFold (f . getConst) l - -ixsZero :: ShS sh -> IIxS sh -ixsZero ZSS = ZIS -ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh - -ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh -ixCvtXS ZSS ZIX = ZIS -ixCvtXS (_ :$$ sh) (n :.% idx) = n :.$ ixCvtXS sh idx - -ixCvtSX :: IIxS sh -> IIxX (MapJust sh) -ixCvtSX ZIS = ZIX -ixCvtSX (n :.$ sh) = n :.% ixCvtSX sh - - --- | The shape of a shape-typed array given as a list of 'SNat' values. -type role ShS nominal -type ShS :: [Nat] -> Type -newtype ShS sh = ShS (ListS sh SNat) - deriving (Eq, Ord) - -pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh -pattern ZSS = ShS ZS - -pattern (:$$) - :: forall {sh1}. - forall n sh. (KnownNat n, n : sh ~ sh1) - => SNat n -> ShS sh -> ShS sh1 -pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i)) - where i :$$ ShS shl = ShS (i ::$ shl) - -infixr 3 :$$ - -{-# COMPLETE ZSS, (:$$) #-} - -instance Show (ShS sh) where - showsPrec _ (ShS l) = listsShow (shows . fromSNat) l - -shsLength :: ShS sh -> Int -shsLength (ShS l) = getSum (listsFold (\_ -> Sum 1) l) - -shsToList :: ShS sh -> [Int] -shsToList ZSS = [] -shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh - -shCvtXS' :: forall sh. IShX (MapJust sh) -> ShS sh -shCvtXS' ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS -shCvtXS' (SKnown n@SNat :$% (idx :: IShX mjshT)) = - castWith (subst1 (lem Refl)) $ - n :$$ shCvtXS' @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) - idx) - where - lem :: forall sh1 sh' n. - Just n : sh1 :~: MapJust sh' - -> n : Tail sh' :~: sh' - lem Refl = unsafeCoerceRefl -shCvtXS' (SUnknown _ :$% _) = error "impossible" - -shCvtSX :: ShS sh -> IShX (MapJust sh) -shCvtSX ZSS = ZSX -shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh - -shsSize :: ShS sh -> Int -shsSize ZSS = 1 -shsSize (n :$$ sh) = fromSNat' n * shsSize sh - --- | Evidence for the static part of a shape. This pops up only when you are --- polymorphic in the element type of an array. -type KnownShS :: [Nat] -> Constraint -class KnownShS sh where knownShS :: ShS sh -instance KnownShS '[] where knownShS = ZSS -instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS - - --- | Untyped: length is checked at runtime. -instance KnownShS sh => IsList (ListS sh (Const i)) where - type Item (ListS sh (Const i)) = i - fromList topl = go (knownShS @sh) topl - where - go :: ShS sh' -> [i] -> ListS sh' (Const i) - go ZSS [] = ZS - go (_ :$$ sh) (i : is) = Const i ::$ go sh is - go _ _ = error $ "IsList(ListS): Mismatched list length (type says " - ++ show (shsLength (knownShS @sh)) ++ ", list has length " - ++ show (length topl) ++ ")" - toList = listsToList - --- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. -instance KnownShS sh => IsList (IxS sh i) where - type Item (IxS sh i) = i - fromList = IxS . IsList.fromList - toList = Foldable.toList - --- | Untyped: length and values are checked at runtime. -instance KnownShS sh => IsList (ShS sh) where - type Item (ShS sh) = Int - fromList topl = ShS (go (knownShS @sh) topl) - where - go :: ShS sh' -> [Int] -> ListS sh' SNat - go ZSS [] = ZS - go (sn :$$ sh) (i : is) - | i == fromSNat' sn = sn ::$ go sh is - | otherwise = error $ "IsList(ShS): Value does not match typing (type says " - ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" - go _ _ = error $ "IsList(ShS): Mismatched list length (type says " - ++ show (shsLength (knownShS @sh)) ++ ", list has length " - ++ show (length topl) ++ ")" - toList = shsToList diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs deleted file mode 100644 index 934433e..0000000 --- a/src/Data/Array/Nested/Shaped.hs +++ /dev/null @@ -1,379 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE InstanceSigs #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Nested.Shaped where - -import Prelude hiding (mappend) - -import Control.DeepSeq (NFData) -import Control.Monad.ST -import Data.Bifunctor (first) -import Data.Coerce (coerce) -import Data.Kind (Type) -import Data.List.NonEmpty (NonEmpty) -import Data.Proxy -import Data.Type.Equality -import Data.Vector.Storable qualified as VS -import Foreign.Storable (Storable) -import GHC.Float qualified (log1p, expm1, log1pexp, log1mexp) -import GHC.TypeLits - -import Data.Array.Mixed (XArray) -import Data.Array.Mixed qualified as X -import Data.Array.Mixed.Internal.Arith -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape -import Data.Array.Mixed.Types -import Data.Array.Nested.Lemmas -import Data.Array.Nested.Mixed -import Data.Array.Nested.Shape - - --- | A shape-typed array: the full shape of the array (the sizes of its --- dimensions) is represented on the type level as a list of 'Nat's. Note that --- these are "GHC.TypeLits" naturals, because we do not need induction over --- them and we want very large arrays to be possible. --- --- Like for 'Ranked', the valid elements are described by the 'Elt' type class, --- and 'Shaped' itself is again an instance of 'Elt' as well. --- --- 'Shaped' is a newtype around a 'Mixed' of 'Just's. -type Shaped :: [Nat] -> Type -> Type -newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) -deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a) -deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a) -deriving instance Ord (Mixed '[] a) => Ord (Shaped '[] a) -deriving instance NFData (Mixed (MapJust sh) a) => NFData (Shaped sh a) - --- just unwrap the newtype and defer to the general instance for nested arrays -newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a)) -deriving instance Show (Mixed sh (Mixed (MapJust sh') a)) => Show (Mixed sh (Shaped sh' a)) - -newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh' ) a)) - -instance Elt a => Elt (Shaped sh a) where - mshape (M_Shaped arr) = mshape arr - mindex (M_Shaped arr) i = Shaped (mindex arr i) - - mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) - mindexPartial (M_Shaped arr) i = - coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ - mindexPartial arr i - - mscalar (Shaped x) = M_Shaped (M_Nest ZSX x) - - mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a) - mfromListOuter l = M_Shaped (mfromListOuter (coerce l)) - - mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] - mtoListOuter (M_Shaped arr) - = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr) - - mlift :: forall sh1 sh2. - StaticShX sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) - -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) - mlift ssh2 f (M_Shaped arr) = - coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ - mlift ssh2 f arr - - mlift2 :: forall sh1 sh2 sh3. - StaticShX sh3 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) - -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a) - mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) = - coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ - mlift2 ssh3 f arr1 arr2 - - mcast ssh1 sh2 psh' (M_Shaped arr) = M_Shaped (mcast ssh1 sh2 psh' arr) - - mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr) - - type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) - - mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr) - - mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - - mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t - - mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - - mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () - mvecsWrite sh idx (Shaped arr) vecs = - mvecsWrite sh idx arr - (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) - vecs) - - mvecsWritePartial :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) - -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) - -> ST s () - mvecsWritePartial sh idx arr vecs = - mvecsWritePartial sh idx - (coerce @(Mixed sh2 (Shaped sh a)) - @(Mixed sh2 (Mixed (MapJust sh) a)) - arr) - (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a)) - @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) - vecs) - - mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) - mvecsFreeze sh vecs = - coerce @(Mixed sh' (Mixed (MapJust sh) a)) - @(Mixed sh' (Shaped sh a)) - <$> mvecsFreeze sh - (coerce @(MixedVecs s sh' (Shaped sh a)) - @(MixedVecs s sh' (Mixed (MapJust sh) a)) - vecs) - -instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where - memptyArray :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) - memptyArray i - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ - memptyArray i - - mvecsUnsafeNew idx (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = MV_Shaped <$> mvecsUnsafeNew idx arr - - mvecsNewEmpty _ - | Dict <- lemKnownMapJust (Proxy @sh) - = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) - - -arithPromoteShaped :: forall sh a. PrimElt a - => (forall shx. Mixed shx a -> Mixed shx a) - -> Shaped sh a -> Shaped sh a -arithPromoteShaped = coerce - -arithPromoteShaped2 :: forall sh a. PrimElt a - => (forall shx. Mixed shx a -> Mixed shx a -> Mixed shx a) - -> Shaped sh a -> Shaped sh a -> Shaped sh a -arithPromoteShaped2 = coerce - -instance (NumElt a, PrimElt a) => Num (Shaped sh a) where - (+) = arithPromoteShaped2 (+) - (-) = arithPromoteShaped2 (-) - (*) = arithPromoteShaped2 (*) - negate = arithPromoteShaped negate - abs = arithPromoteShaped abs - signum = arithPromoteShaped signum - fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicateScal" - -instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Shaped sh a) where - fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal" - recip = arithPromoteShaped recip - (/) = arithPromoteShaped2 (/) - -instance (FloatElt a, NumElt a, PrimElt a) => Floating (Shaped sh a) where - pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal" - exp = arithPromoteShaped exp - log = arithPromoteShaped log - sqrt = arithPromoteShaped sqrt - (**) = arithPromoteShaped2 (**) - logBase = arithPromoteShaped2 logBase - sin = arithPromoteShaped sin - cos = arithPromoteShaped cos - tan = arithPromoteShaped tan - asin = arithPromoteShaped asin - acos = arithPromoteShaped acos - atan = arithPromoteShaped atan - sinh = arithPromoteShaped sinh - cosh = arithPromoteShaped cosh - tanh = arithPromoteShaped tanh - asinh = arithPromoteShaped asinh - acosh = arithPromoteShaped acosh - atanh = arithPromoteShaped atanh - log1p = arithPromoteShaped GHC.Float.log1p - expm1 = arithPromoteShaped GHC.Float.expm1 - log1pexp = arithPromoteShaped GHC.Float.log1pexp - log1mexp = arithPromoteShaped GHC.Float.log1mexp - - -sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh -sshape (Shaped arr) = shCvtXS' (mshape arr) - -sindex :: Elt a => Shaped sh a -> IIxS sh -> a -sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) - -shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh -shsTakeIx _ _ ZIS = ZSS -shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx - -sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a -sindexPartial sarr@(Shaped arr) idx = - Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) - (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr) - (ixCvtSX idx)) - --- | __WARNING__: All values returned from the function must have equal shape. --- See the documentation of 'mgenerate' for more details. -sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a -sgenerate sh f = Shaped (mgenerate (shCvtSX sh) (f . ixCvtXS sh)) - --- | See the documentation of 'mlift'. -slift :: forall sh1 sh2 a. Elt a - => ShS sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) - -> Shaped sh1 a -> Shaped sh2 a -slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShape (shCvtSX sh2)) f arr) - --- | See the documentation of 'mlift'. -slift2 :: forall sh1 sh2 sh3 a. Elt a - => ShS sh3 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b) - -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a -slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShape (shCvtSX sh3)) f arr1 arr2) - -ssumOuter1P :: forall sh n a. (Storable a, NumElt a) - => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) -ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr) - -ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a) - => Shaped (n : sh) a -> Shaped sh a -ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive - -stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a) - => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a -stranspose perm sarr@(Shaped arr) - | Refl <- lemRankMapJust (sshape sarr) - , Refl <- lemMapJustTakeLen perm (sshape sarr) - , Refl <- lemMapJustDropLen perm (sshape sarr) - , Refl <- lemMapJustPermute perm (shsTakeLen perm (sshape sarr)) - , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh)) - = Shaped (mtranspose perm arr) - -sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a -sappend = coerce mappend - -sscalar :: Elt a => a -> Shaped '[] a -sscalar x = Shaped (mscalar x) - -sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a) -sfromVectorP sh v = Shaped (mfromVectorP (shCvtSX sh) v) - -sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a -sfromVector sh v = sfromPrimitive (sfromVectorP sh v) - -stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a -stoVectorP = coerce mtoVectorP - -stoVector :: PrimElt a => Shaped sh a -> VS.Vector a -stoVector = coerce mtoVector - -sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a -sfromListOuter sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy $ mfromListOuter (coerce l)) - -sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a -sfromList1 sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1 - -sfromList1Prim :: (PrimElt a, Elt a) => SNat n -> [a] -> Shaped '[n] a -sfromList1Prim sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1Prim - -stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a] -stoListOuter (Shaped arr) = coerce (mtoListOuter arr) - -stoList1 :: Elt a => Shaped '[n] a -> [a] -stoList1 = map sunScalar . stoListOuter - -sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a -sfromListPrim sn l - | Refl <- lemAppNil @'[Just n] - = let ssh = SUnknown () :!% ZKX - xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l) - in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr - -sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a -sfromListPrimLinear sh l = - let M_Primitive _ xarr = toPrimitive (mfromListPrim l) - in Shaped $ fromPrimitive $ M_Primitive (shCvtSX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtSX sh) xarr) - -sunScalar :: Elt a => Shaped '[] a -> a -sunScalar arr = sindex arr ZIS - -srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) - => ShS sh -> ShS sh2 - -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b)) - -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b) -srerankP sh sh2 f sarr@(Shaped arr) - | Refl <- lemMapJustApp sh (Proxy @sh1) - , Refl <- lemMapJustApp sh (Proxy @sh2) - = Shaped (mrerankP (ssxFromShape (shxTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (ssxFromShape (shCvtSX sh)))) - (shCvtSX sh2) - (\a -> let Shaped r = f (Shaped a) in r) - arr) - -srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) - => ShS sh -> ShS sh2 - -> (Shaped sh1 a -> Shaped sh2 b) - -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b -srerank sh sh2 f (stoPrimitive -> arr) = - sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr - -sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a -sreplicate sh (Shaped arr) - | Refl <- lemMapJustApp sh (Proxy @sh') - = Shaped (mreplicate (shCvtSX sh) arr) - -sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) -sreplicateScalP sh x = Shaped (mreplicateScalP (shCvtSX sh) x) - -sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a -sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x) - -sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a -sslice i n@SNat arr = - let _ :$$ sh = sshape arr - in slift (n :$$ sh) (\_ -> X.slice i n) arr - -srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a -srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr - -sreshape :: Elt a => ShS sh' -> Shaped sh a -> Shaped sh' a -sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr) - -siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a -siota sn = Shaped (miota sn) - -sasXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a) -sasXArrayPrimP (Shaped arr) = first shCvtXS' (masXArrayPrimP arr) - -sasXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a) -sasXArrayPrim (Shaped arr) = first shCvtXS' (masXArrayPrim arr) - -sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a) -sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShape (shCvtSX sh)) arr) - -sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a -sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShape (shCvtSX sh)) arr) - -sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a -sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr) - -stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a) -stoPrimitive (Shaped arr) = Shaped (toPrimitive arr) - -mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') - => Mixed sh a -> ShS sh' -> Shaped sh' a -mcastToShaped arr targetsh - | Refl <- lemAppNil @sh - , Refl <- lemAppNil @(MapJust sh') - , Refl <- lemRankMapJust targetsh - = Shaped (mcast (ssxFromShape (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr) -- cgit v1.2.3-70-g09d2