diff options
Diffstat (limited to 'src/Data/Array/Mixed')
| -rw-r--r-- | src/Data/Array/Mixed/Permutation.hs | 5 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/XArray.hs | 331 | 
2 files changed, 336 insertions, 0 deletions
| diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs index 83a5ee4..6ff3bdc 100644 --- a/src/Data/Array/Mixed/Permutation.hs +++ b/src/Data/Array/Mixed/Permutation.hs @@ -61,6 +61,11 @@ permToList (x `PCons` l) = TN.fromSNat x : permToList l  permToList' :: Perm list -> [Int]  permToList' = map fromIntegral . permToList +-- | Utility class for generating permutations from type class information. +class KnownPerm l where makePerm :: Perm l +instance KnownPerm '[] where makePerm = PNil +instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = natSing `PCons` makePerm +  -- ** Applying permutations diff --git a/src/Data/Array/Mixed/XArray.hs b/src/Data/Array/Mixed/XArray.hs new file mode 100644 index 0000000..cc0a6a5 --- /dev/null +++ b/src/Data/Array/Mixed/XArray.hs @@ -0,0 +1,331 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Mixed.XArray where + +import Control.DeepSeq (NFData(..)) +import Data.Array.Ranked qualified as ORB +import Data.Array.RankedS qualified as S +import Data.Coerce +import Data.Kind +import Data.Proxy +import Data.Type.Equality +import Data.Type.Ord +import Data.Vector.Storable qualified as VS +import Foreign.Storable (Storable) +import GHC.Generics (Generic) +import GHC.TypeLits + +import Data.Array.Mixed.Internal.Arith +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Permutation +import Data.Array.Mixed.Shape +import Data.Array.Mixed.Types + + +type XArray :: [Maybe Nat] -> Type -> Type +newtype XArray sh a = XArray (S.Array (Rank sh) a) +  deriving (Show, Eq, Generic) + +-- | Only on scalars, because lexicographical ordering is strange on multi-dimensional arrays. +deriving instance (Ord a, Storable a) => Ord (XArray '[] a) + +instance NFData a => NFData (XArray sh a) + + +shape :: forall sh a. StaticShX sh -> XArray sh a -> IShX sh +shape = \ssh (XArray arr) -> go ssh (S.shapeL arr) +  where +    go :: StaticShX sh' -> [Int] -> IShX sh' +    go ZKX [] = ZSX +    go (n :!% ssh) (i : l) = fromSMayNat (\_ -> SUnknown i) SKnown n :$% go ssh l +    go _ _ = error "Invalid shapeL" + +fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a +fromVector sh v +  | Dict <- lemKnownNatRank sh +  = XArray (S.fromVector (shxToList sh) v) + +toVector :: Storable a => XArray sh a -> VS.Vector a +toVector (XArray arr) = S.toVector arr + +scalar :: Storable a => a -> XArray '[] a +scalar = XArray . S.scalar + +-- | Will throw if the array does not have the casted-to shape. +cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2 +     => StaticShX sh1 -> IShX sh2 -> StaticShX sh' +     -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a +cast ssh1 sh2 ssh' (XArray arr) +  | Refl <- lemRankApp ssh1 ssh' +  , Refl <- lemRankApp (ssxFromShape sh2) ssh' +  = let arrsh :: IShX sh1 +        (arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr)) +    in if shxToList arrsh == shxToList sh2 +         then XArray arr +         else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")" + +unScalar :: Storable a => XArray '[] a -> a +unScalar (XArray a) = S.unScalar a + +replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a +replicate sh ssh' (XArray arr) +  | Dict <- lemKnownNatRankSSX ssh' +  , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh) ssh') +  , Refl <- lemRankApp (ssxFromShape sh) ssh' +  = XArray (S.stretch (shxToList sh ++ S.shapeL arr) $ +            S.reshape (map (const 1) (shxToList sh) ++ S.shapeL arr) $ +              arr) + +replicateScal :: forall sh a. Storable a => IShX sh -> a -> XArray sh a +replicateScal sh x +  | Dict <- lemKnownNatRank sh +  = XArray (S.constant (shxToList sh) x) + +generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a +generate sh f = fromVector sh $ VS.generate (shxSize sh) (f . ixxFromLinear sh) + +-- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a) +-- generateM sh f | Dict <- lemKnownNatRank sh = +--   XArray . S.fromVector (shxShapeL sh) +--     <$> VS.generateM (shxSize sh) (f . ixxFromLinear sh) + +indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a +indexPartial (XArray arr) ZIX = XArray arr +indexPartial (XArray arr) (i :.% idx) = indexPartial (XArray (S.index arr i)) idx + +index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a +index xarr i +  | Refl <- lemAppNil @sh +  = let XArray arr' = indexPartial xarr i :: XArray '[] a +    in S.unScalar arr' + +append :: forall n m sh a. Storable a +       => StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a +append ssh (XArray a) (XArray b) +  | Dict <- lemKnownNatRankSSX ssh +  = XArray (S.append a b) + +-- | If the prefix of the shape of the input array (@sh@) is empty (i.e. +-- contains a zero), then there is no way to deduce the full shape of the output +-- array (more precisely, the @sh2@ part): that could only come from calling +-- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in +-- this case; we choose to fill the shape with zeros wherever we cannot deduce +-- what it should be. +-- +-- For example, if: +-- +-- @ +-- arr :: XArray '[Just 3, Just 0, Just 4, Just 2, Nothing] Int   -- of shape [3, 0, 4, 2, 21] +-- f :: XArray '[Just 2, Nothing] Int -> XArray '[Just 5, Nothing, Just 17] Float +-- @ +-- +-- then: +-- +-- @ +-- rerank _ _ _ f arr :: XArray '[Just 3, Just 0, Just 4, Just 5, Nothing, Just 17] Float +-- @ +-- +-- and this result will have shape @[3, 0, 4, 5, 0, 17]@. Note the second @0@ +-- in this shape: we don't know if @f@ intended to return an array with shape 0 +-- here (it probably didn't), but there is no better number to put here absent +-- a subarray of the input to pass to @f@. +-- +-- In this particular case the fact that @sh@ is empty was evident from the +-- type-level information, but the same situation occurs when @sh@ consists of +-- @Nothing@s, and some of those happen to be zero at runtime. +rerank :: forall sh sh1 sh2 a b. +          (Storable a, Storable b) +       => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 +       -> (XArray sh1 a -> XArray sh2 b) +       -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b +rerank ssh ssh1 ssh2 f xarr@(XArray arr) +  | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) +  = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr) +    in if any (== 0) (shxToList sh) +         then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) +         else case () of +           () | Dict <- lemKnownNatRankSSX ssh +              , Dict <- lemKnownNatRankSSX ssh2 +              , Refl <- lemRankApp ssh ssh1 +              , Refl <- lemRankApp ssh ssh2 +              -> XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2) +                          (\a -> let XArray r = f (XArray a) in r) +                          arr) + +rerankTop :: forall sh1 sh2 sh a b. +             (Storable a, Storable b) +          => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh +          -> (XArray sh1 a -> XArray sh2 b) +          -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b +rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh + +-- | The caveat about empty arrays at @rerank@ applies here too. +rerank2 :: forall sh sh1 sh2 a b c. +           (Storable a, Storable b, Storable c) +        => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 +        -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c) +        -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c +rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2) +  | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) +  = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1) +    in if any (== 0) (shxToList sh) +         then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) +         else case () of +           () | Dict <- lemKnownNatRankSSX ssh +              , Dict <- lemKnownNatRankSSX ssh2 +              , Refl <- lemRankApp ssh ssh1 +              , Refl <- lemRankApp ssh ssh2 +              -> XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2) +                          (\a b -> let XArray r = f (XArray a) (XArray b) in r) +                          arr1 arr2) + +-- | The list argument gives indices into the original dimension list. +transpose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh) +          => StaticShX sh +          -> Perm is +          -> XArray sh a +          -> XArray (PermutePrefix is sh) a +transpose ssh perm (XArray arr) +  | Dict <- lemKnownNatRankSSX ssh +  , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh) +  , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm +  , Refl <- lemRankDropLen ssh perm +  = XArray (S.transpose (permToList' perm) arr) + +-- | The list argument gives indices into the original dimension list. +-- +-- The permutation (the list) must have length <= @n@. If it is longer, this +-- function throws. +transposeUntyped :: forall n sh a. +                    SNat n -> StaticShX sh -> [Int] +                 -> XArray (Replicate n Nothing ++ sh) a -> XArray (Replicate n Nothing ++ sh) a +transposeUntyped sn ssh perm (XArray arr) +  | length perm <= fromSNat' sn +  , Dict <- lemKnownNatRankSSX (ssxAppend (ssxReplicate sn) ssh) +  = XArray (S.transpose perm arr) +  | otherwise +  = error "Data.Array.Mixed.transposeUntyped: Permutation longer than length of unshaped prefix of shape type" + +transpose2 :: forall sh1 sh2 a. +              StaticShX sh1 -> StaticShX sh2 +           -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a +transpose2 ssh1 ssh2 (XArray arr) +  | Refl <- lemRankApp ssh1 ssh2 +  , Refl <- lemRankApp ssh2 ssh1 +  , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2) +  , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1) +  , Refl <- lemRankAppComm ssh1 ssh2 +  , let n1 = ssxLength ssh1 +  = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) + +sumFull :: (Storable a, NumElt a) => XArray sh a -> a +sumFull (XArray arr) = +  S.unScalar $ +    numEltSum1Inner (SNat @0) $ +      S.fromVector [product (S.shapeL arr)] $ +        S.toVector arr + +sumInner :: forall sh sh' a. (Storable a, NumElt a) +         => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a +sumInner ssh ssh' arr +  | Refl <- lemAppNil @sh +  = let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) +        sh'F = shxFlatten sh' :$% ZSX +        ssh'F = ssxFromShape sh'F + +        go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a +        go (XArray arr') +          | Refl <- lemRankApp ssh ssh'F +          , let sn = listxLengthSNat (let StaticShX l = ssh in l) +          = XArray (numEltSum1Inner sn arr') + +    in go $ +       transpose2 ssh'F ssh $ +       reshapePartial ssh' ssh sh'F $ +       transpose2 ssh ssh' $ +         arr + +sumOuter :: forall sh sh' a. (Storable a, NumElt a) +         => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a +sumOuter ssh ssh' arr +  | Refl <- lemAppNil @sh +  = let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) +        shF = shxFlatten sh :$% ZSX +    in sumInner ssh' (ssxFromShape shF) $ +       transpose2 (ssxFromShape shF) ssh' $ +       reshapePartial ssh ssh' shF $ +         arr + +fromListOuter :: forall n sh a. Storable a +              => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a +fromListOuter ssh l +  | Dict <- lemKnownNatRankSSX ssh +  = case ssh of +      SKnown m :!% _ | fromSNat' m /= length l -> +        error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ")" ++ +                "does not match the type (" ++ show (fromSNat' m) ++ ")" +      _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l))) + +toListOuter :: Storable a => XArray (n : sh) a -> [XArray sh a] +toListOuter (XArray arr) = +  case S.shapeL arr of +    0 : _ -> [] +    _ -> coerce (ORB.toList (S.unravel arr)) + +fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a +fromList1 ssh l = +  let n = length l +  in case ssh of +       SKnown m :!% _ | fromSNat' m /= n -> +         error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ")" ++ +                 "does not match the type (" ++ show (fromSNat' m) ++ ")" +       _ -> XArray (S.fromVector [n] (VS.fromListN n l)) + +toList1 :: Storable a => XArray '[n] a -> [a] +toList1 (XArray arr) = S.toList arr + +-- | Throws if the given shape is not, in fact, empty. +empty :: forall sh a. Storable a => IShX sh -> XArray sh a +empty sh +  | Dict <- lemKnownNatRank sh +  = XArray (S.constant (shxToList sh) +                       (error "Data.Array.Mixed.empty: shape was not empty")) + +slice :: SNat i -> SNat n -> XArray (Just (i + n + k) : sh) a -> XArray (Just n : sh) a +slice i n (XArray arr) = XArray (S.slice [(fromSNat' i, fromSNat' n)] arr) + +sliceU :: Int -> Int -> XArray (Nothing : sh) a -> XArray (Nothing : sh) a +sliceU i n (XArray arr) = XArray (S.slice [(i, n)] arr) + +rev1 :: XArray (n : sh) a -> XArray (n : sh) a +rev1 (XArray arr) = XArray (S.rev [0] arr) + +-- | Throws if the given array and the target shape do not have the same number of elements. +reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a +reshape ssh1 sh2 (XArray arr) +  | Dict <- lemKnownNatRankSSX ssh1 +  , Dict <- lemKnownNatRank sh2 +  = XArray (S.reshape (shxToList sh2) arr) + +-- | Throws if the given array and the target shape do not have the same number of elements. +reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a +reshapePartial ssh1 ssh' sh2 (XArray arr) +  | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh') +  , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh2) ssh') +  = XArray (S.reshape (shxToList sh2 ++ drop (ssxLength ssh1) (S.shapeL arr)) arr) + +-- this was benchmarked to be (slightly) faster than S.iota, S.generate and S.fromVector(VS.enumFromTo). +iota :: (Enum a, Storable a) => SNat n -> XArray '[Just n] a +iota sn = XArray (S.fromVector [fromSNat' sn] (VS.fromListN (fromSNat' sn) [toEnum 0 .. toEnum (fromSNat' sn - 1)])) | 
