diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-05-30 22:47:52 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-30 22:47:52 +0200 |
commit | 8b59d8ef4ff97936f2a753d1ce345e0404c26b2b (patch) | |
tree | 947f75cb43982fbdb551dc329f036b0591f3c2b2 /src/Data/Array/Mixed | |
parent | f0752d67cd188f438280e1f0c692dc1f5f14a190 (diff) |
Clearer module purposes
Thanks Mikolaj for discussion
Diffstat (limited to 'src/Data/Array/Mixed')
-rw-r--r-- | src/Data/Array/Mixed/Permutation.hs | 5 | ||||
-rw-r--r-- | src/Data/Array/Mixed/XArray.hs | 331 |
2 files changed, 336 insertions, 0 deletions
diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs index 83a5ee4..6ff3bdc 100644 --- a/src/Data/Array/Mixed/Permutation.hs +++ b/src/Data/Array/Mixed/Permutation.hs @@ -61,6 +61,11 @@ permToList (x `PCons` l) = TN.fromSNat x : permToList l permToList' :: Perm list -> [Int] permToList' = map fromIntegral . permToList +-- | Utility class for generating permutations from type class information. +class KnownPerm l where makePerm :: Perm l +instance KnownPerm '[] where makePerm = PNil +instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = natSing `PCons` makePerm + -- ** Applying permutations diff --git a/src/Data/Array/Mixed/XArray.hs b/src/Data/Array/Mixed/XArray.hs new file mode 100644 index 0000000..cc0a6a5 --- /dev/null +++ b/src/Data/Array/Mixed/XArray.hs @@ -0,0 +1,331 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Mixed.XArray where + +import Control.DeepSeq (NFData(..)) +import Data.Array.Ranked qualified as ORB +import Data.Array.RankedS qualified as S +import Data.Coerce +import Data.Kind +import Data.Proxy +import Data.Type.Equality +import Data.Type.Ord +import Data.Vector.Storable qualified as VS +import Foreign.Storable (Storable) +import GHC.Generics (Generic) +import GHC.TypeLits + +import Data.Array.Mixed.Internal.Arith +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Permutation +import Data.Array.Mixed.Shape +import Data.Array.Mixed.Types + + +type XArray :: [Maybe Nat] -> Type -> Type +newtype XArray sh a = XArray (S.Array (Rank sh) a) + deriving (Show, Eq, Generic) + +-- | Only on scalars, because lexicographical ordering is strange on multi-dimensional arrays. +deriving instance (Ord a, Storable a) => Ord (XArray '[] a) + +instance NFData a => NFData (XArray sh a) + + +shape :: forall sh a. StaticShX sh -> XArray sh a -> IShX sh +shape = \ssh (XArray arr) -> go ssh (S.shapeL arr) + where + go :: StaticShX sh' -> [Int] -> IShX sh' + go ZKX [] = ZSX + go (n :!% ssh) (i : l) = fromSMayNat (\_ -> SUnknown i) SKnown n :$% go ssh l + go _ _ = error "Invalid shapeL" + +fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a +fromVector sh v + | Dict <- lemKnownNatRank sh + = XArray (S.fromVector (shxToList sh) v) + +toVector :: Storable a => XArray sh a -> VS.Vector a +toVector (XArray arr) = S.toVector arr + +scalar :: Storable a => a -> XArray '[] a +scalar = XArray . S.scalar + +-- | Will throw if the array does not have the casted-to shape. +cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2 + => StaticShX sh1 -> IShX sh2 -> StaticShX sh' + -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a +cast ssh1 sh2 ssh' (XArray arr) + | Refl <- lemRankApp ssh1 ssh' + , Refl <- lemRankApp (ssxFromShape sh2) ssh' + = let arrsh :: IShX sh1 + (arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr)) + in if shxToList arrsh == shxToList sh2 + then XArray arr + else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")" + +unScalar :: Storable a => XArray '[] a -> a +unScalar (XArray a) = S.unScalar a + +replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a +replicate sh ssh' (XArray arr) + | Dict <- lemKnownNatRankSSX ssh' + , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh) ssh') + , Refl <- lemRankApp (ssxFromShape sh) ssh' + = XArray (S.stretch (shxToList sh ++ S.shapeL arr) $ + S.reshape (map (const 1) (shxToList sh) ++ S.shapeL arr) $ + arr) + +replicateScal :: forall sh a. Storable a => IShX sh -> a -> XArray sh a +replicateScal sh x + | Dict <- lemKnownNatRank sh + = XArray (S.constant (shxToList sh) x) + +generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a +generate sh f = fromVector sh $ VS.generate (shxSize sh) (f . ixxFromLinear sh) + +-- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a) +-- generateM sh f | Dict <- lemKnownNatRank sh = +-- XArray . S.fromVector (shxShapeL sh) +-- <$> VS.generateM (shxSize sh) (f . ixxFromLinear sh) + +indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a +indexPartial (XArray arr) ZIX = XArray arr +indexPartial (XArray arr) (i :.% idx) = indexPartial (XArray (S.index arr i)) idx + +index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a +index xarr i + | Refl <- lemAppNil @sh + = let XArray arr' = indexPartial xarr i :: XArray '[] a + in S.unScalar arr' + +append :: forall n m sh a. Storable a + => StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a +append ssh (XArray a) (XArray b) + | Dict <- lemKnownNatRankSSX ssh + = XArray (S.append a b) + +-- | If the prefix of the shape of the input array (@sh@) is empty (i.e. +-- contains a zero), then there is no way to deduce the full shape of the output +-- array (more precisely, the @sh2@ part): that could only come from calling +-- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in +-- this case; we choose to fill the shape with zeros wherever we cannot deduce +-- what it should be. +-- +-- For example, if: +-- +-- @ +-- arr :: XArray '[Just 3, Just 0, Just 4, Just 2, Nothing] Int -- of shape [3, 0, 4, 2, 21] +-- f :: XArray '[Just 2, Nothing] Int -> XArray '[Just 5, Nothing, Just 17] Float +-- @ +-- +-- then: +-- +-- @ +-- rerank _ _ _ f arr :: XArray '[Just 3, Just 0, Just 4, Just 5, Nothing, Just 17] Float +-- @ +-- +-- and this result will have shape @[3, 0, 4, 5, 0, 17]@. Note the second @0@ +-- in this shape: we don't know if @f@ intended to return an array with shape 0 +-- here (it probably didn't), but there is no better number to put here absent +-- a subarray of the input to pass to @f@. +-- +-- In this particular case the fact that @sh@ is empty was evident from the +-- type-level information, but the same situation occurs when @sh@ consists of +-- @Nothing@s, and some of those happen to be zero at runtime. +rerank :: forall sh sh1 sh2 a b. + (Storable a, Storable b) + => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 + -> (XArray sh1 a -> XArray sh2 b) + -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b +rerank ssh ssh1 ssh2 f xarr@(XArray arr) + | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) + = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr) + in if any (== 0) (shxToList sh) + then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) + else case () of + () | Dict <- lemKnownNatRankSSX ssh + , Dict <- lemKnownNatRankSSX ssh2 + , Refl <- lemRankApp ssh ssh1 + , Refl <- lemRankApp ssh ssh2 + -> XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2) + (\a -> let XArray r = f (XArray a) in r) + arr) + +rerankTop :: forall sh1 sh2 sh a b. + (Storable a, Storable b) + => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh + -> (XArray sh1 a -> XArray sh2 b) + -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b +rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh + +-- | The caveat about empty arrays at @rerank@ applies here too. +rerank2 :: forall sh sh1 sh2 a b c. + (Storable a, Storable b, Storable c) + => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 + -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c) + -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c +rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2) + | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) + = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1) + in if any (== 0) (shxToList sh) + then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) + else case () of + () | Dict <- lemKnownNatRankSSX ssh + , Dict <- lemKnownNatRankSSX ssh2 + , Refl <- lemRankApp ssh ssh1 + , Refl <- lemRankApp ssh ssh2 + -> XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2) + (\a b -> let XArray r = f (XArray a) (XArray b) in r) + arr1 arr2) + +-- | The list argument gives indices into the original dimension list. +transpose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh) + => StaticShX sh + -> Perm is + -> XArray sh a + -> XArray (PermutePrefix is sh) a +transpose ssh perm (XArray arr) + | Dict <- lemKnownNatRankSSX ssh + , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh) + , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm + , Refl <- lemRankDropLen ssh perm + = XArray (S.transpose (permToList' perm) arr) + +-- | The list argument gives indices into the original dimension list. +-- +-- The permutation (the list) must have length <= @n@. If it is longer, this +-- function throws. +transposeUntyped :: forall n sh a. + SNat n -> StaticShX sh -> [Int] + -> XArray (Replicate n Nothing ++ sh) a -> XArray (Replicate n Nothing ++ sh) a +transposeUntyped sn ssh perm (XArray arr) + | length perm <= fromSNat' sn + , Dict <- lemKnownNatRankSSX (ssxAppend (ssxReplicate sn) ssh) + = XArray (S.transpose perm arr) + | otherwise + = error "Data.Array.Mixed.transposeUntyped: Permutation longer than length of unshaped prefix of shape type" + +transpose2 :: forall sh1 sh2 a. + StaticShX sh1 -> StaticShX sh2 + -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a +transpose2 ssh1 ssh2 (XArray arr) + | Refl <- lemRankApp ssh1 ssh2 + , Refl <- lemRankApp ssh2 ssh1 + , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2) + , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1) + , Refl <- lemRankAppComm ssh1 ssh2 + , let n1 = ssxLength ssh1 + = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) + +sumFull :: (Storable a, NumElt a) => XArray sh a -> a +sumFull (XArray arr) = + S.unScalar $ + numEltSum1Inner (SNat @0) $ + S.fromVector [product (S.shapeL arr)] $ + S.toVector arr + +sumInner :: forall sh sh' a. (Storable a, NumElt a) + => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a +sumInner ssh ssh' arr + | Refl <- lemAppNil @sh + = let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) + sh'F = shxFlatten sh' :$% ZSX + ssh'F = ssxFromShape sh'F + + go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a + go (XArray arr') + | Refl <- lemRankApp ssh ssh'F + , let sn = listxLengthSNat (let StaticShX l = ssh in l) + = XArray (numEltSum1Inner sn arr') + + in go $ + transpose2 ssh'F ssh $ + reshapePartial ssh' ssh sh'F $ + transpose2 ssh ssh' $ + arr + +sumOuter :: forall sh sh' a. (Storable a, NumElt a) + => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a +sumOuter ssh ssh' arr + | Refl <- lemAppNil @sh + = let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) + shF = shxFlatten sh :$% ZSX + in sumInner ssh' (ssxFromShape shF) $ + transpose2 (ssxFromShape shF) ssh' $ + reshapePartial ssh ssh' shF $ + arr + +fromListOuter :: forall n sh a. Storable a + => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a +fromListOuter ssh l + | Dict <- lemKnownNatRankSSX ssh + = case ssh of + SKnown m :!% _ | fromSNat' m /= length l -> + error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ")" ++ + "does not match the type (" ++ show (fromSNat' m) ++ ")" + _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l))) + +toListOuter :: Storable a => XArray (n : sh) a -> [XArray sh a] +toListOuter (XArray arr) = + case S.shapeL arr of + 0 : _ -> [] + _ -> coerce (ORB.toList (S.unravel arr)) + +fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a +fromList1 ssh l = + let n = length l + in case ssh of + SKnown m :!% _ | fromSNat' m /= n -> + error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ")" ++ + "does not match the type (" ++ show (fromSNat' m) ++ ")" + _ -> XArray (S.fromVector [n] (VS.fromListN n l)) + +toList1 :: Storable a => XArray '[n] a -> [a] +toList1 (XArray arr) = S.toList arr + +-- | Throws if the given shape is not, in fact, empty. +empty :: forall sh a. Storable a => IShX sh -> XArray sh a +empty sh + | Dict <- lemKnownNatRank sh + = XArray (S.constant (shxToList sh) + (error "Data.Array.Mixed.empty: shape was not empty")) + +slice :: SNat i -> SNat n -> XArray (Just (i + n + k) : sh) a -> XArray (Just n : sh) a +slice i n (XArray arr) = XArray (S.slice [(fromSNat' i, fromSNat' n)] arr) + +sliceU :: Int -> Int -> XArray (Nothing : sh) a -> XArray (Nothing : sh) a +sliceU i n (XArray arr) = XArray (S.slice [(i, n)] arr) + +rev1 :: XArray (n : sh) a -> XArray (n : sh) a +rev1 (XArray arr) = XArray (S.rev [0] arr) + +-- | Throws if the given array and the target shape do not have the same number of elements. +reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a +reshape ssh1 sh2 (XArray arr) + | Dict <- lemKnownNatRankSSX ssh1 + , Dict <- lemKnownNatRank sh2 + = XArray (S.reshape (shxToList sh2) arr) + +-- | Throws if the given array and the target shape do not have the same number of elements. +reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a +reshapePartial ssh1 ssh' sh2 (XArray arr) + | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh') + , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh2) ssh') + = XArray (S.reshape (shxToList sh2 ++ drop (ssxLength ssh1) (S.shapeL arr)) arr) + +-- this was benchmarked to be (slightly) faster than S.iota, S.generate and S.fromVector(VS.enumFromTo). +iota :: (Enum a, Storable a) => SNat n -> XArray '[Just n] a +iota sn = XArray (S.fromVector [fromSNat' sn] (VS.fromListN (fromSNat' sn) [toEnum 0 .. toEnum (fromSNat' sn - 1)])) |