diff options
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r-- | src/Data/Array/Mixed.hs | 335 |
1 files changed, 0 insertions, 335 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs deleted file mode 100644 index 4a338a2..0000000 --- a/src/Data/Array/Mixed.hs +++ /dev/null @@ -1,335 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE NoStarIsType #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE StrictData #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Mixed where - -import Control.DeepSeq (NFData(..)) -import Data.Array.Ranked qualified as ORB -import Data.Array.RankedS qualified as S -import Data.Coerce -import Data.Kind -import Data.Proxy -import Data.Type.Equality -import Data.Type.Ord -import Data.Vector.Storable qualified as VS -import Foreign.Storable (Storable) -import GHC.Generics (Generic) -import GHC.TypeLits - -import Data.Array.Mixed.Internal.Arith -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape -import Data.Array.Mixed.Types - - -type XArray :: [Maybe Nat] -> Type -> Type -newtype XArray sh a = XArray (S.Array (Rank sh) a) - deriving (Show, Eq, Generic) - --- | Only on scalars, because lexicographical ordering is strange on multi-dimensional arrays. -deriving instance (Ord a, Storable a) => Ord (XArray '[] a) - -instance NFData a => NFData (XArray sh a) - - -shape :: forall sh a. StaticShX sh -> XArray sh a -> IShX sh -shape = \ssh (XArray arr) -> go ssh (S.shapeL arr) - where - go :: StaticShX sh' -> [Int] -> IShX sh' - go ZKX [] = ZSX - go (n :!% ssh) (i : l) = fromSMayNat (\_ -> SUnknown i) SKnown n :$% go ssh l - go _ _ = error "Invalid shapeL" - -fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a -fromVector sh v - | Dict <- lemKnownNatRank sh - = XArray (S.fromVector (shxToList sh) v) - -toVector :: Storable a => XArray sh a -> VS.Vector a -toVector (XArray arr) = S.toVector arr - -scalar :: Storable a => a -> XArray '[] a -scalar = XArray . S.scalar - --- | Will throw if the array does not have the casted-to shape. -cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2 - => StaticShX sh1 -> IShX sh2 -> StaticShX sh' - -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a -cast ssh1 sh2 ssh' (XArray arr) - | Refl <- lemRankApp ssh1 ssh' - , Refl <- lemRankApp (ssxFromShape sh2) ssh' - = let arrsh :: IShX sh1 - (arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr)) - in if shxToList arrsh == shxToList sh2 - then XArray arr - else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")" - -unScalar :: Storable a => XArray '[] a -> a -unScalar (XArray a) = S.unScalar a - -replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a -replicate sh ssh' (XArray arr) - | Dict <- lemKnownNatRankSSX ssh' - , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh) ssh') - , Refl <- lemRankApp (ssxFromShape sh) ssh' - = XArray (S.stretch (shxToList sh ++ S.shapeL arr) $ - S.reshape (map (const 1) (shxToList sh) ++ S.shapeL arr) $ - arr) - -replicateScal :: forall sh a. Storable a => IShX sh -> a -> XArray sh a -replicateScal sh x - | Dict <- lemKnownNatRank sh - = XArray (S.constant (shxToList sh) x) - -generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a -generate sh f = fromVector sh $ VS.generate (shxSize sh) (f . ixxFromLinear sh) - --- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a) --- generateM sh f | Dict <- lemKnownNatRank sh = --- XArray . S.fromVector (shxShapeL sh) --- <$> VS.generateM (shxSize sh) (f . ixxFromLinear sh) - -indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a -indexPartial (XArray arr) ZIX = XArray arr -indexPartial (XArray arr) (i :.% idx) = indexPartial (XArray (S.index arr i)) idx - -index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a -index xarr i - | Refl <- lemAppNil @sh - = let XArray arr' = indexPartial xarr i :: XArray '[] a - in S.unScalar arr' - -append :: forall n m sh a. Storable a - => StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a -append ssh (XArray a) (XArray b) - | Dict <- lemKnownNatRankSSX ssh - = XArray (S.append a b) - --- | If the prefix of the shape of the input array (@sh@) is empty (i.e. --- contains a zero), then there is no way to deduce the full shape of the output --- array (more precisely, the @sh2@ part): that could only come from calling --- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in --- this case; we choose to fill the shape with zeros wherever we cannot deduce --- what it should be. --- --- For example, if: --- --- @ --- arr :: XArray '[Just 3, Just 0, Just 4, Just 2, Nothing] Int -- of shape [3, 0, 4, 2, 21] --- f :: XArray '[Just 2, Nothing] Int -> XArray '[Just 5, Nothing, Just 17] Float --- @ --- --- then: --- --- @ --- rerank _ _ _ f arr :: XArray '[Just 3, Just 0, Just 4, Just 5, Nothing, Just 17] Float --- @ --- --- and this result will have shape @[3, 0, 4, 5, 0, 17]@. Note the second @0@ --- in this shape: we don't know if @f@ intended to return an array with shape 0 --- here (it probably didn't), but there is no better number to put here absent --- a subarray of the input to pass to @f@. --- --- In this particular case the fact that @sh@ is empty was evident from the --- type-level information, but the same situation occurs when @sh@ consists of --- @Nothing@s, and some of those happen to be zero at runtime. -rerank :: forall sh sh1 sh2 a b. - (Storable a, Storable b) - => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 - -> (XArray sh1 a -> XArray sh2 b) - -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b -rerank ssh ssh1 ssh2 f xarr@(XArray arr) - | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) - = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr) - in if any (== 0) (shxToList sh) - then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) - else case () of - () | Dict <- lemKnownNatRankSSX ssh - , Dict <- lemKnownNatRankSSX ssh2 - , Refl <- lemRankApp ssh ssh1 - , Refl <- lemRankApp ssh ssh2 - -> XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2) - (\a -> let XArray r = f (XArray a) in r) - arr) - -rerankTop :: forall sh1 sh2 sh a b. - (Storable a, Storable b) - => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh - -> (XArray sh1 a -> XArray sh2 b) - -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b -rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh - --- | The caveat about empty arrays at @rerank@ applies here too. -rerank2 :: forall sh sh1 sh2 a b c. - (Storable a, Storable b, Storable c) - => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 - -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c) - -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c -rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2) - | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) - = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1) - in if any (== 0) (shxToList sh) - then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) - else case () of - () | Dict <- lemKnownNatRankSSX ssh - , Dict <- lemKnownNatRankSSX ssh2 - , Refl <- lemRankApp ssh ssh1 - , Refl <- lemRankApp ssh ssh2 - -> XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2) - (\a b -> let XArray r = f (XArray a) (XArray b) in r) - arr1 arr2) - -class KnownNatList l where makeNatList :: Perm l -instance KnownNatList '[] where makeNatList = PNil -instance (KnownNat n, KnownNatList l) => KnownNatList (n : l) where makeNatList = natSing `PCons` makeNatList - --- | The list argument gives indices into the original dimension list. -transpose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh) - => StaticShX sh - -> Perm is - -> XArray sh a - -> XArray (PermutePrefix is sh) a -transpose ssh perm (XArray arr) - | Dict <- lemKnownNatRankSSX ssh - , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh) - , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm - , Refl <- lemRankDropLen ssh perm - = XArray (S.transpose (permToList' perm) arr) - --- | The list argument gives indices into the original dimension list. --- --- The permutation (the list) must have length <= @n@. If it is longer, this --- function throws. -transposeUntyped :: forall n sh a. - SNat n -> StaticShX sh -> [Int] - -> XArray (Replicate n Nothing ++ sh) a -> XArray (Replicate n Nothing ++ sh) a -transposeUntyped sn ssh perm (XArray arr) - | length perm <= fromSNat' sn - , Dict <- lemKnownNatRankSSX (ssxAppend (ssxReplicate sn) ssh) - = XArray (S.transpose perm arr) - | otherwise - = error "Data.Array.Mixed.transposeUntyped: Permutation longer than length of unshaped prefix of shape type" - -transpose2 :: forall sh1 sh2 a. - StaticShX sh1 -> StaticShX sh2 - -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a -transpose2 ssh1 ssh2 (XArray arr) - | Refl <- lemRankApp ssh1 ssh2 - , Refl <- lemRankApp ssh2 ssh1 - , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2) - , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1) - , Refl <- lemRankAppComm ssh1 ssh2 - , let n1 = ssxLength ssh1 - = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) - -sumFull :: (Storable a, NumElt a) => XArray sh a -> a -sumFull (XArray arr) = - S.unScalar $ - numEltSum1Inner (SNat @0) $ - S.fromVector [product (S.shapeL arr)] $ - S.toVector arr - -sumInner :: forall sh sh' a. (Storable a, NumElt a) - => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a -sumInner ssh ssh' arr - | Refl <- lemAppNil @sh - = let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) - sh'F = shxFlatten sh' :$% ZSX - ssh'F = ssxFromShape sh'F - - go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a - go (XArray arr') - | Refl <- lemRankApp ssh ssh'F - , let sn = listxLengthSNat (let StaticShX l = ssh in l) - = XArray (numEltSum1Inner sn arr') - - in go $ - transpose2 ssh'F ssh $ - reshapePartial ssh' ssh sh'F $ - transpose2 ssh ssh' $ - arr - -sumOuter :: forall sh sh' a. (Storable a, NumElt a) - => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a -sumOuter ssh ssh' arr - | Refl <- lemAppNil @sh - = let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) - shF = shxFlatten sh :$% ZSX - in sumInner ssh' (ssxFromShape shF) $ - transpose2 (ssxFromShape shF) ssh' $ - reshapePartial ssh ssh' shF $ - arr - -fromListOuter :: forall n sh a. Storable a - => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a -fromListOuter ssh l - | Dict <- lemKnownNatRankSSX ssh - = case ssh of - SKnown m :!% _ | fromSNat' m /= length l -> - error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ")" ++ - "does not match the type (" ++ show (fromSNat' m) ++ ")" - _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l))) - -toListOuter :: Storable a => XArray (n : sh) a -> [XArray sh a] -toListOuter (XArray arr) = - case S.shapeL arr of - 0 : _ -> [] - _ -> coerce (ORB.toList (S.unravel arr)) - -fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a -fromList1 ssh l = - let n = length l - in case ssh of - SKnown m :!% _ | fromSNat' m /= n -> - error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ")" ++ - "does not match the type (" ++ show (fromSNat' m) ++ ")" - _ -> XArray (S.fromVector [n] (VS.fromListN n l)) - -toList1 :: Storable a => XArray '[n] a -> [a] -toList1 (XArray arr) = S.toList arr - --- | Throws if the given shape is not, in fact, empty. -empty :: forall sh a. Storable a => IShX sh -> XArray sh a -empty sh - | Dict <- lemKnownNatRank sh - = XArray (S.constant (shxToList sh) - (error "Data.Array.Mixed.empty: shape was not empty")) - -slice :: SNat i -> SNat n -> XArray (Just (i + n + k) : sh) a -> XArray (Just n : sh) a -slice i n (XArray arr) = XArray (S.slice [(fromSNat' i, fromSNat' n)] arr) - -sliceU :: Int -> Int -> XArray (Nothing : sh) a -> XArray (Nothing : sh) a -sliceU i n (XArray arr) = XArray (S.slice [(i, n)] arr) - -rev1 :: XArray (n : sh) a -> XArray (n : sh) a -rev1 (XArray arr) = XArray (S.rev [0] arr) - --- | Throws if the given array and the target shape do not have the same number of elements. -reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a -reshape ssh1 sh2 (XArray arr) - | Dict <- lemKnownNatRankSSX ssh1 - , Dict <- lemKnownNatRank sh2 - = XArray (S.reshape (shxToList sh2) arr) - --- | Throws if the given array and the target shape do not have the same number of elements. -reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a -reshapePartial ssh1 ssh' sh2 (XArray arr) - | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh') - , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh2) ssh') - = XArray (S.reshape (shxToList sh2 ++ drop (ssxLength ssh1) (S.shapeL arr)) arr) - --- this was benchmarked to be (slightly) faster than S.iota, S.generate and S.fromVector(VS.enumFromTo). -iota :: (Enum a, Storable a) => SNat n -> XArray '[Just n] a -iota sn = XArray (S.fromVector [fromSNat' sn] (VS.fromListN (fromSNat' sn) [toEnum 0 .. toEnum (fromSNat' sn - 1)])) |