From 3c8f13c8310de646b15c6f2745cfe190db7610db Mon Sep 17 00:00:00 2001 From: Mikolaj Konarski Date: Wed, 14 May 2025 19:43:21 +0200 Subject: Move Arith, XArray and Convert --- bench/Main.hs | 2 +- ox-arrays.cabal | 6 +- src/Data/Array/Arith.hs | 43 ++++ src/Data/Array/Mixed/Internal/Arith.hs | 43 ---- src/Data/Array/Mixed/XArray.hs | 348 ------------------------------ src/Data/Array/Nested.hs | 2 +- src/Data/Array/Nested/Convert.hs | 86 ++++++++ src/Data/Array/Nested/Internal/Convert.hs | 86 -------- src/Data/Array/Nested/Mixed.hs | 6 +- src/Data/Array/Nested/Ranked.hs | 4 +- src/Data/Array/Nested/Shaped.hs | 4 +- src/Data/Array/XArray.hs | 348 ++++++++++++++++++++++++++++++ 12 files changed, 489 insertions(+), 489 deletions(-) create mode 100644 src/Data/Array/Arith.hs delete mode 100644 src/Data/Array/Mixed/Internal/Arith.hs delete mode 100644 src/Data/Array/Mixed/XArray.hs create mode 100644 src/Data/Array/Nested/Convert.hs delete mode 100644 src/Data/Array/Nested/Internal/Convert.hs create mode 100644 src/Data/Array/XArray.hs diff --git a/bench/Main.hs b/bench/Main.hs index e5ef4d1..ef03b1a 100644 --- a/bench/Main.hs +++ b/bench/Main.hs @@ -15,7 +15,7 @@ import Numeric.LinearAlgebra qualified as LA import Test.Tasty.Bench import Text.Show (showListWith) -import Data.Array.Mixed.XArray (XArray(..)) +import Data.Array.XArray (XArray(..)) import Data.Array.Nested import Data.Array.Nested.Mixed (Mixed(M_Primitive), mliftPrim, mliftPrim2, toPrimitive) import Data.Array.Nested.Ranked (liftRanked1, liftRanked2) diff --git a/ox-arrays.cabal b/ox-arrays.cabal index f62f961..9cfc7dd 100644 --- a/ox-arrays.cabal +++ b/ox-arrays.cabal @@ -47,19 +47,19 @@ library -- put this module on top so ghci considers it the "main" module Data.Array.Nested - Data.Array.Mixed.Internal.Arith + Data.Array.Arith Data.Array.Mixed.Lemmas Data.Array.Mixed.Permutation Data.Array.Mixed.Types - Data.Array.Mixed.XArray - Data.Array.Nested.Internal.Convert Data.Array.Nested.Internal.Lemmas + Data.Array.Nested.Convert Data.Array.Nested.Mixed Data.Array.Nested.Ranked Data.Array.Nested.Shaped Data.Array.Nested.Mixed.Shape Data.Array.Nested.Ranked.Shape Data.Array.Nested.Shaped.Shape + Data.Array.XArray Data.Bag if flag(trace-wrappers) diff --git a/src/Data/Array/Arith.hs b/src/Data/Array/Arith.hs new file mode 100644 index 0000000..1eae737 --- /dev/null +++ b/src/Data/Array/Arith.hs @@ -0,0 +1,43 @@ +{-# LANGUAGE ImportQualifiedPost #-} +module Data.Array.Arith ( + module Data.Array.Arith, + module Data.Array.Strided.Arith, +) where + +import Data.Array.Internal qualified as OI +import Data.Array.Internal.RankedG qualified as RG +import Data.Array.Internal.RankedS qualified as RS + +import Data.Array.Strided qualified as AS +import Data.Array.Strided.Arith + +-- for liftVEltwise1 +import Data.Array.Strided.Arith.Internal (stridesDense) +import Data.Vector.Storable qualified as VS +import Foreign.Storable +import GHC.TypeLits + + +fromO :: RS.Array n a -> AS.Array n a +fromO (RS.A (RG.A sh (OI.T strides offset vec))) = AS.Array sh strides offset vec + +toO :: AS.Array n a -> RS.Array n a +toO (AS.Array sh strides offset vec) = RS.A (RG.A sh (OI.T strides offset vec)) + +liftO1 :: (AS.Array n a -> AS.Array n' b) + -> RS.Array n a -> RS.Array n' b +liftO1 f = toO . f . fromO + +liftO2 :: (AS.Array n a -> AS.Array n1 b -> AS.Array n2 c) + -> RS.Array n a -> RS.Array n1 b -> RS.Array n2 c +liftO2 f x y = toO (f (fromO x) (fromO y)) + +liftVEltwise1 :: (Storable a, Storable b) + => SNat n + -> (VS.Vector a -> VS.Vector b) + -> RS.Array n a -> RS.Array n b +liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) + | Just (blockOff, blockSz) <- stridesDense sh offset strides = + let vec' = f (VS.slice blockOff blockSz vec) + in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec')) + | otherwise = RS.fromVector sh (f (RS.toVector arr)) diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs deleted file mode 100644 index ebda388..0000000 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ /dev/null @@ -1,43 +0,0 @@ -{-# LANGUAGE ImportQualifiedPost #-} -module Data.Array.Mixed.Internal.Arith ( - module Data.Array.Mixed.Internal.Arith, - module Data.Array.Strided.Arith, -) where - -import Data.Array.Internal qualified as OI -import Data.Array.Internal.RankedG qualified as RG -import Data.Array.Internal.RankedS qualified as RS - -import Data.Array.Strided qualified as AS -import Data.Array.Strided.Arith - --- for liftVEltwise1 -import Data.Array.Strided.Arith.Internal (stridesDense) -import Data.Vector.Storable qualified as VS -import Foreign.Storable -import GHC.TypeLits - - -fromO :: RS.Array n a -> AS.Array n a -fromO (RS.A (RG.A sh (OI.T strides offset vec))) = AS.Array sh strides offset vec - -toO :: AS.Array n a -> RS.Array n a -toO (AS.Array sh strides offset vec) = RS.A (RG.A sh (OI.T strides offset vec)) - -liftO1 :: (AS.Array n a -> AS.Array n' b) - -> RS.Array n a -> RS.Array n' b -liftO1 f = toO . f . fromO - -liftO2 :: (AS.Array n a -> AS.Array n1 b -> AS.Array n2 c) - -> RS.Array n a -> RS.Array n1 b -> RS.Array n2 c -liftO2 f x y = toO (f (fromO x) (fromO y)) - -liftVEltwise1 :: (Storable a, Storable b) - => SNat n - -> (VS.Vector a -> VS.Vector b) - -> RS.Array n a -> RS.Array n b -liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) - | Just (blockOff, blockSz) <- stridesDense sh offset strides = - let vec' = f (VS.slice blockOff blockSz vec) - in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec')) - | otherwise = RS.fromVector sh (f (RS.toVector arr)) diff --git a/src/Data/Array/Mixed/XArray.hs b/src/Data/Array/Mixed/XArray.hs deleted file mode 100644 index 502d5d9..0000000 --- a/src/Data/Array/Mixed/XArray.hs +++ /dev/null @@ -1,348 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE NoStarIsType #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE StrictData #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Mixed.XArray where - -import Control.DeepSeq (NFData) -import Data.Array.Internal qualified as OI -import Data.Array.Internal.RankedG qualified as ORG -import Data.Array.Internal.RankedS qualified as ORS -import Data.Array.Ranked qualified as ORB -import Data.Array.RankedS qualified as S -import Data.Coerce -import Data.Foldable (toList) -import Data.Kind -import Data.List.NonEmpty (NonEmpty) -import Data.Proxy -import Data.Type.Equality -import Data.Type.Ord -import Data.Vector.Storable qualified as VS -import Foreign.Storable (Storable) -import GHC.Generics (Generic) -import GHC.TypeLits - -import Data.Array.Mixed.Internal.Arith -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Types -import Data.Array.Nested.Mixed.Shape - - -type XArray :: [Maybe Nat] -> Type -> Type -newtype XArray sh a = XArray (S.Array (Rank sh) a) - deriving (Show, Eq, Ord, Generic) - -instance NFData (XArray sh a) - - -shape :: forall sh a. StaticShX sh -> XArray sh a -> IShX sh -shape = \ssh (XArray arr) -> go ssh (S.shapeL arr) - where - go :: StaticShX sh' -> [Int] -> IShX sh' - go ZKX [] = ZSX - go (n :!% ssh) (i : l) = fromSMayNat (\_ -> SUnknown i) SKnown n :$% go ssh l - go _ _ = error "Invalid shapeL" - -fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a -fromVector sh v - | Dict <- lemKnownNatRank sh - = XArray (S.fromVector (shxToList sh) v) - -toVector :: Storable a => XArray sh a -> VS.Vector a -toVector (XArray arr) = S.toVector arr - --- | This allows observing the strides in the underlying orthotope array. This --- can be useful for optimisation, but should be considered an implementation --- detail: strides may change in new versions of this library without notice. -arrayStrides :: XArray sh a -> [Int] -arrayStrides (XArray (ORS.A (ORG.A _ (OI.T strides _ _)))) = strides - -scalar :: Storable a => a -> XArray '[] a -scalar = XArray . S.scalar - --- | Will throw if the array does not have the casted-to shape. -cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2 - => StaticShX sh1 -> IShX sh2 -> StaticShX sh' - -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a -cast ssh1 sh2 ssh' (XArray arr) - | Refl <- lemRankApp ssh1 ssh' - , Refl <- lemRankApp (ssxFromShape sh2) ssh' - = let arrsh :: IShX sh1 - (arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr)) - in if shxToList arrsh == shxToList sh2 - then XArray arr - else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")" - -unScalar :: Storable a => XArray '[] a -> a -unScalar (XArray a) = S.unScalar a - -replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a -replicate sh ssh' (XArray arr) - | Dict <- lemKnownNatRankSSX ssh' - , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh) ssh') - , Refl <- lemRankApp (ssxFromShape sh) ssh' - = XArray (S.stretch (shxToList sh ++ S.shapeL arr) $ - S.reshape (map (const 1) (shxToList sh) ++ S.shapeL arr) - arr) - -replicateScal :: forall sh a. Storable a => IShX sh -> a -> XArray sh a -replicateScal sh x - | Dict <- lemKnownNatRank sh - = XArray (S.constant (shxToList sh) x) - -generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a -generate sh f = fromVector sh $ VS.generate (shxSize sh) (f . ixxFromLinear sh) - --- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a) --- generateM sh f | Dict <- lemKnownNatRank sh = --- XArray . S.fromVector (shxShapeL sh) --- <$> VS.generateM (shxSize sh) (f . ixxFromLinear sh) - -indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a -indexPartial (XArray arr) ZIX = XArray arr -indexPartial (XArray arr) (i :.% idx) = indexPartial (XArray (S.index arr i)) idx - -index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a -index xarr i - | Refl <- lemAppNil @sh - = let XArray arr' = indexPartial xarr i :: XArray '[] a - in S.unScalar arr' - -append :: forall n m sh a. Storable a - => StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a -append ssh (XArray a) (XArray b) - | Dict <- lemKnownNatRankSSX ssh - = XArray (S.append a b) - --- | All arrays must have the same shape, except possibly for the outermost --- dimension. -concat :: Storable a - => StaticShX sh -> NonEmpty (XArray (Nothing : sh) a) -> XArray (Nothing : sh) a -concat ssh l - | Dict <- lemKnownNatRankSSX ssh - = XArray (S.concatOuter (coerce (toList l))) - --- | If the prefix of the shape of the input array (@sh@) is empty (i.e. --- contains a zero), then there is no way to deduce the full shape of the output --- array (more precisely, the @sh2@ part): that could only come from calling --- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in --- this case; we choose to fill the shape with zeros wherever we cannot deduce --- what it should be. --- --- For example, if: --- --- @ --- arr :: XArray '[Just 3, Just 0, Just 4, Just 2, Nothing] Int -- of shape [3, 0, 4, 2, 21] --- f :: XArray '[Just 2, Nothing] Int -> XArray '[Just 5, Nothing, Just 17] Float --- @ --- --- then: --- --- @ --- rerank _ _ _ f arr :: XArray '[Just 3, Just 0, Just 4, Just 5, Nothing, Just 17] Float --- @ --- --- and this result will have shape @[3, 0, 4, 5, 0, 17]@. Note the second @0@ --- in this shape: we don't know if @f@ intended to return an array with shape 0 --- here (it probably didn't), but there is no better number to put here absent --- a subarray of the input to pass to @f@. --- --- In this particular case the fact that @sh@ is empty was evident from the --- type-level information, but the same situation occurs when @sh@ consists of --- @Nothing@s, and some of those happen to be zero at runtime. -rerank :: forall sh sh1 sh2 a b. - (Storable a, Storable b) - => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 - -> (XArray sh1 a -> XArray sh2 b) - -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b -rerank ssh ssh1 ssh2 f xarr@(XArray arr) - | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) - = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr) - in if 0 `elem` shxToList sh - then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) - else case () of - () | Dict <- lemKnownNatRankSSX ssh - , Dict <- lemKnownNatRankSSX ssh2 - , Refl <- lemRankApp ssh ssh1 - , Refl <- lemRankApp ssh ssh2 - -> XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2) - (\a -> let XArray r = f (XArray a) in r) - arr) - -rerankTop :: forall sh1 sh2 sh a b. - (Storable a, Storable b) - => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh - -> (XArray sh1 a -> XArray sh2 b) - -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b -rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh - --- | The caveat about empty arrays at @rerank@ applies here too. -rerank2 :: forall sh sh1 sh2 a b c. - (Storable a, Storable b, Storable c) - => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 - -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c) - -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c -rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2) - | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) - = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1) - in if 0 `elem` shxToList sh - then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) - else case () of - () | Dict <- lemKnownNatRankSSX ssh - , Dict <- lemKnownNatRankSSX ssh2 - , Refl <- lemRankApp ssh ssh1 - , Refl <- lemRankApp ssh ssh2 - -> XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2) - (\a b -> let XArray r = f (XArray a) (XArray b) in r) - arr1 arr2) - --- | The list argument gives indices into the original dimension list. -transpose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh) - => StaticShX sh - -> Perm is - -> XArray sh a - -> XArray (PermutePrefix is sh) a -transpose ssh perm (XArray arr) - | Dict <- lemKnownNatRankSSX ssh - , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh) - , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm - , Refl <- lemRankDropLen ssh perm - = XArray (S.transpose (permToList' perm) arr) - --- | The list argument gives indices into the original dimension list. --- --- The permutation (the list) must have length <= @n@. If it is longer, this --- function throws. -transposeUntyped :: forall n sh a. - SNat n -> StaticShX sh -> [Int] - -> XArray (Replicate n Nothing ++ sh) a -> XArray (Replicate n Nothing ++ sh) a -transposeUntyped sn ssh perm (XArray arr) - | length perm <= fromSNat' sn - , Dict <- lemKnownNatRankSSX (ssxAppend (ssxReplicate sn) ssh) - = XArray (S.transpose perm arr) - | otherwise - = error "Data.Array.Mixed.transposeUntyped: Permutation longer than length of unshaped prefix of shape type" - -transpose2 :: forall sh1 sh2 a. - StaticShX sh1 -> StaticShX sh2 - -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a -transpose2 ssh1 ssh2 (XArray arr) - | Refl <- lemRankApp ssh1 ssh2 - , Refl <- lemRankApp ssh2 ssh1 - , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2) - , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1) - , Refl <- lemRankAppComm ssh1 ssh2 - , let n1 = ssxLength ssh1 - = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) - -sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a -sumFull _ (XArray arr) = - S.unScalar $ - liftO1 (numEltSum1Inner (SNat @0)) $ - S.fromVector [product (S.shapeL arr)] $ - S.toVector arr - -sumInner :: forall sh sh' a. (Storable a, NumElt a) - => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a -sumInner ssh ssh' arr - | Refl <- lemAppNil @sh - = let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) - sh'F = shxFlatten sh' :$% ZSX - ssh'F = ssxFromShape sh'F - - go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a - go (XArray arr') - | Refl <- lemRankApp ssh ssh'F - , let sn = listxRank (let StaticShX l = ssh in l) - = XArray (liftO1 (numEltSum1Inner sn) arr') - - in go $ - transpose2 ssh'F ssh $ - reshapePartial ssh' ssh sh'F $ - transpose2 ssh ssh' $ - arr - -sumOuter :: forall sh sh' a. (Storable a, NumElt a) - => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a -sumOuter ssh ssh' arr - | Refl <- lemAppNil @sh - = let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) - shF = shxFlatten sh :$% ZSX - in sumInner ssh' (ssxFromShape shF) $ - transpose2 (ssxFromShape shF) ssh' $ - reshapePartial ssh ssh' shF $ - arr - -fromListOuter :: forall n sh a. Storable a - => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a -fromListOuter ssh l - | Dict <- lemKnownNatRankSSX ssh - = case ssh of - SKnown m :!% _ | fromSNat' m /= length l -> - error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ")" ++ - "does not match the type (" ++ show (fromSNat' m) ++ ")" - _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l))) - -toListOuter :: Storable a => XArray (n : sh) a -> [XArray sh a] -toListOuter (XArray arr) = - case S.shapeL arr of - 0 : _ -> [] - _ -> coerce (ORB.toList (S.unravel arr)) - -fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a -fromList1 ssh l = - let n = length l - in case ssh of - SKnown m :!% _ | fromSNat' m /= n -> - error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ")" ++ - "does not match the type (" ++ show (fromSNat' m) ++ ")" - _ -> XArray (S.fromVector [n] (VS.fromListN n l)) - -toList1 :: Storable a => XArray '[n] a -> [a] -toList1 (XArray arr) = S.toList arr - --- | Throws if the given shape is not, in fact, empty. -empty :: forall sh a. Storable a => IShX sh -> XArray sh a -empty sh - | Dict <- lemKnownNatRank sh - , shxSize sh == 0 - = XArray (S.fromVector (shxToList sh) VS.empty) - | otherwise - = error $ "Data.Array.Mixed.empty: shape was not empty: " ++ show sh - -slice :: SNat i -> SNat n -> XArray (Just (i + n + k) : sh) a -> XArray (Just n : sh) a -slice i n (XArray arr) = XArray (S.slice [(fromSNat' i, fromSNat' n)] arr) - -sliceU :: Int -> Int -> XArray (Nothing : sh) a -> XArray (Nothing : sh) a -sliceU i n (XArray arr) = XArray (S.slice [(i, n)] arr) - -rev1 :: XArray (n : sh) a -> XArray (n : sh) a -rev1 (XArray arr) = XArray (S.rev [0] arr) - --- | Throws if the given array and the target shape do not have the same number of elements. -reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a -reshape ssh1 sh2 (XArray arr) - | Dict <- lemKnownNatRankSSX ssh1 - , Dict <- lemKnownNatRank sh2 - = XArray (S.reshape (shxToList sh2) arr) - --- | Throws if the given array and the target shape do not have the same number of elements. -reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a -reshapePartial ssh1 ssh' sh2 (XArray arr) - | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh') - , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh2) ssh') - = XArray (S.reshape (shxToList sh2 ++ drop (ssxLength ssh1) (S.shapeL arr)) arr) - --- this was benchmarked to be (slightly) faster than S.iota, S.generate and S.fromVector(VS.enumFromTo). -iota :: (Enum a, Storable a) => SNat n -> XArray '[Just n] a -iota sn = XArray (S.fromVector [fromSNat' sn] (VS.fromListN (fromSNat' sn) [toEnum 0 .. toEnum (fromSNat' sn - 1)])) diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index af195ee..9faf6d7 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -105,7 +105,7 @@ import Prelude hiding (mappend, mconcat) import Data.Array.Mixed.Permutation import Data.Array.Mixed.Types -import Data.Array.Nested.Internal.Convert +import Data.Array.Nested.Convert import Data.Array.Nested.Mixed import Data.Array.Nested.Ranked import Data.Array.Nested.Shaped diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs new file mode 100644 index 0000000..639f5fd --- /dev/null +++ b/src/Data/Array/Nested/Convert.hs @@ -0,0 +1,86 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +module Data.Array.Nested.Convert where + +import Control.Category +import Data.Proxy +import Data.Type.Equality + +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Types +import Data.Array.Nested.Internal.Lemmas +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Ranked +import Data.Array.Nested.Shaped +import Data.Array.Nested.Shaped.Shape + + +stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a +stoRanked sarr@(Shaped arr) + | Refl <- lemRankMapJust (sshape sarr) + = mtoRanked arr + +rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a +rcastToShaped (Ranked arr) targetsh + | Refl <- lemRankReplicate (shxRank (shCvtSX targetsh)) + , Refl <- lemRankMapJust targetsh + = mcastToShaped arr targetsh + +-- | The only constructor that performs runtime shape checking is 'CastXS''. +-- For the other construtors, the types ensure that the shapes are already +-- compatible. To convert between 'Ranked' and 'Shaped', go via 'Mixed'. +data Castable a b where + CastId :: Castable a a + CastCmp :: Castable b c -> Castable a b -> Castable a c + + CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b) + CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b) + + CastXR :: Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b) + CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b) + CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh' + -> Castable a b -> Castable (Mixed sh a) (Shaped sh' b) + + CastRR :: Castable a b -> Castable (Ranked n a) (Ranked n b) + CastSS :: Castable a b -> Castable (Shaped sh a) (Shaped sh b) + CastXX :: Castable a b -> Castable (Mixed sh a) (Mixed sh b) + +instance Category Castable where + id = CastId + (.) = CastCmp + +castCastable :: (Elt a, Elt b) => Castable a b -> a -> b +castCastable = \c x -> munScalar (go c (mscalar x)) + where + -- The 'esh' is the extension shape: the casting happens under a whole + -- bunch of additional dimensions that it does not touch. These dimensions + -- are 'esh'. + -- The strategy is to unwind step-by-step to a large Mixed array, and to + -- perform the required checks and castings when re-nesting back up. + go :: Castable a b -> Mixed esh a -> Mixed esh b + go CastId x = x + go (CastCmp c1 c2) x = go c1 (go c2 x) + go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x) + go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x) + go (CastXR @_ @_ @sh c) (M_Nest @esh esh x) = + M_Ranked (M_Nest esh (mcastSafe @(MCastApp esh sh esh (Replicate (Rank sh) Nothing) MCastId MCastForget) Proxy + (go c x))) + go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x)) + go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x) + | Refl <- lemRankAppMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh'))) + (go c x))) + go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x)) + go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x)) + go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x) + + lemRankAppMapJust :: Rank sh ~ Rank sh' + => Proxy esh -> Proxy sh -> Proxy sh' + -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh') + lemRankAppMapJust _ _ _ = unsafeCoerceRefl diff --git a/src/Data/Array/Nested/Internal/Convert.hs b/src/Data/Array/Nested/Internal/Convert.hs deleted file mode 100644 index 611b45e..0000000 --- a/src/Data/Array/Nested/Internal/Convert.hs +++ /dev/null @@ -1,86 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeAbstractions #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -module Data.Array.Nested.Internal.Convert where - -import Control.Category -import Data.Proxy -import Data.Type.Equality - -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Types -import Data.Array.Nested.Internal.Lemmas -import Data.Array.Nested.Mixed -import Data.Array.Nested.Ranked -import Data.Array.Nested.Shaped -import Data.Array.Nested.Mixed.Shape -import Data.Array.Nested.Shaped.Shape - - -stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a -stoRanked sarr@(Shaped arr) - | Refl <- lemRankMapJust (sshape sarr) - = mtoRanked arr - -rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a -rcastToShaped (Ranked arr) targetsh - | Refl <- lemRankReplicate (shxRank (shCvtSX targetsh)) - , Refl <- lemRankMapJust targetsh - = mcastToShaped arr targetsh - --- | The only constructor that performs runtime shape checking is 'CastXS''. --- For the other construtors, the types ensure that the shapes are already --- compatible. To convert between 'Ranked' and 'Shaped', go via 'Mixed'. -data Castable a b where - CastId :: Castable a a - CastCmp :: Castable b c -> Castable a b -> Castable a c - - CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b) - CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b) - - CastXR :: Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b) - CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b) - CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh' - -> Castable a b -> Castable (Mixed sh a) (Shaped sh' b) - - CastRR :: Castable a b -> Castable (Ranked n a) (Ranked n b) - CastSS :: Castable a b -> Castable (Shaped sh a) (Shaped sh b) - CastXX :: Castable a b -> Castable (Mixed sh a) (Mixed sh b) - -instance Category Castable where - id = CastId - (.) = CastCmp - -castCastable :: (Elt a, Elt b) => Castable a b -> a -> b -castCastable = \c x -> munScalar (go c (mscalar x)) - where - -- The 'esh' is the extension shape: the casting happens under a whole - -- bunch of additional dimensions that it does not touch. These dimensions - -- are 'esh'. - -- The strategy is to unwind step-by-step to a large Mixed array, and to - -- perform the required checks and castings when re-nesting back up. - go :: Castable a b -> Mixed esh a -> Mixed esh b - go CastId x = x - go (CastCmp c1 c2) x = go c1 (go c2 x) - go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x) - go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x) - go (CastXR @_ @_ @sh c) (M_Nest @esh esh x) = - M_Ranked (M_Nest esh (mcastSafe @(MCastApp esh sh esh (Replicate (Rank sh) Nothing) MCastId MCastForget) Proxy - (go c x))) - go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x)) - go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x) - | Refl <- lemRankAppMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh') - = M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh'))) - (go c x))) - go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x)) - go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x)) - go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x) - - lemRankAppMapJust :: Rank sh ~ Rank sh' - => Proxy esh -> Proxy sh -> Proxy sh' - -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh') - lemRankAppMapJust _ _ _ = unsafeCoerceRefl diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 50a1b71..ec19c21 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -42,12 +42,12 @@ import GHC.Generics (Generic) import GHC.TypeLits import Unsafe.Coerce (unsafeCoerce) -import Data.Array.Mixed.Internal.Arith +import Data.Array.Arith import Data.Array.Mixed.Lemmas import Data.Array.Mixed.Permutation import Data.Array.Mixed.Types -import Data.Array.Mixed.XArray (XArray(..)) -import Data.Array.Mixed.XArray qualified as X +import Data.Array.XArray (XArray(..)) +import Data.Array.XArray qualified as X import Data.Array.Nested.Mixed.Shape import Data.Bag diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs index fb5caa9..e2074ac 100644 --- a/src/Data/Array/Nested/Ranked.hs +++ b/src/Data/Array/Nested/Ranked.hs @@ -41,8 +41,8 @@ import GHC.TypeNats qualified as TN import Data.Array.Mixed.Lemmas import Data.Array.Mixed.Permutation import Data.Array.Mixed.Types -import Data.Array.Mixed.XArray (XArray(..)) -import Data.Array.Mixed.XArray qualified as X +import Data.Array.XArray (XArray(..)) +import Data.Array.XArray qualified as X import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Ranked.Shape diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs index ba767cd..4bccbc4 100644 --- a/src/Data/Array/Nested/Shaped.hs +++ b/src/Data/Array/Nested/Shaped.hs @@ -41,8 +41,8 @@ import GHC.TypeLits import Data.Array.Mixed.Lemmas import Data.Array.Mixed.Permutation import Data.Array.Mixed.Types -import Data.Array.Mixed.XArray (XArray) -import Data.Array.Mixed.XArray qualified as X +import Data.Array.XArray (XArray) +import Data.Array.XArray qualified as X import Data.Array.Nested.Internal.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs new file mode 100644 index 0000000..d8d564e --- /dev/null +++ b/src/Data/Array/XArray.hs @@ -0,0 +1,348 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.XArray where + +import Control.DeepSeq (NFData) +import Data.Array.Internal qualified as OI +import Data.Array.Internal.RankedG qualified as ORG +import Data.Array.Internal.RankedS qualified as ORS +import Data.Array.Ranked qualified as ORB +import Data.Array.RankedS qualified as S +import Data.Coerce +import Data.Foldable (toList) +import Data.Kind +import Data.List.NonEmpty (NonEmpty) +import Data.Proxy +import Data.Type.Equality +import Data.Type.Ord +import Data.Vector.Storable qualified as VS +import Foreign.Storable (Storable) +import GHC.Generics (Generic) +import GHC.TypeLits + +import Data.Array.Arith +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Permutation +import Data.Array.Mixed.Types +import Data.Array.Nested.Mixed.Shape + + +type XArray :: [Maybe Nat] -> Type -> Type +newtype XArray sh a = XArray (S.Array (Rank sh) a) + deriving (Show, Eq, Ord, Generic) + +instance NFData (XArray sh a) + + +shape :: forall sh a. StaticShX sh -> XArray sh a -> IShX sh +shape = \ssh (XArray arr) -> go ssh (S.shapeL arr) + where + go :: StaticShX sh' -> [Int] -> IShX sh' + go ZKX [] = ZSX + go (n :!% ssh) (i : l) = fromSMayNat (\_ -> SUnknown i) SKnown n :$% go ssh l + go _ _ = error "Invalid shapeL" + +fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a +fromVector sh v + | Dict <- lemKnownNatRank sh + = XArray (S.fromVector (shxToList sh) v) + +toVector :: Storable a => XArray sh a -> VS.Vector a +toVector (XArray arr) = S.toVector arr + +-- | This allows observing the strides in the underlying orthotope array. This +-- can be useful for optimisation, but should be considered an implementation +-- detail: strides may change in new versions of this library without notice. +arrayStrides :: XArray sh a -> [Int] +arrayStrides (XArray (ORS.A (ORG.A _ (OI.T strides _ _)))) = strides + +scalar :: Storable a => a -> XArray '[] a +scalar = XArray . S.scalar + +-- | Will throw if the array does not have the casted-to shape. +cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2 + => StaticShX sh1 -> IShX sh2 -> StaticShX sh' + -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a +cast ssh1 sh2 ssh' (XArray arr) + | Refl <- lemRankApp ssh1 ssh' + , Refl <- lemRankApp (ssxFromShape sh2) ssh' + = let arrsh :: IShX sh1 + (arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr)) + in if shxToList arrsh == shxToList sh2 + then XArray arr + else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")" + +unScalar :: Storable a => XArray '[] a -> a +unScalar (XArray a) = S.unScalar a + +replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a +replicate sh ssh' (XArray arr) + | Dict <- lemKnownNatRankSSX ssh' + , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh) ssh') + , Refl <- lemRankApp (ssxFromShape sh) ssh' + = XArray (S.stretch (shxToList sh ++ S.shapeL arr) $ + S.reshape (map (const 1) (shxToList sh) ++ S.shapeL arr) + arr) + +replicateScal :: forall sh a. Storable a => IShX sh -> a -> XArray sh a +replicateScal sh x + | Dict <- lemKnownNatRank sh + = XArray (S.constant (shxToList sh) x) + +generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a +generate sh f = fromVector sh $ VS.generate (shxSize sh) (f . ixxFromLinear sh) + +-- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a) +-- generateM sh f | Dict <- lemKnownNatRank sh = +-- XArray . S.fromVector (shxShapeL sh) +-- <$> VS.generateM (shxSize sh) (f . ixxFromLinear sh) + +indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a +indexPartial (XArray arr) ZIX = XArray arr +indexPartial (XArray arr) (i :.% idx) = indexPartial (XArray (S.index arr i)) idx + +index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a +index xarr i + | Refl <- lemAppNil @sh + = let XArray arr' = indexPartial xarr i :: XArray '[] a + in S.unScalar arr' + +append :: forall n m sh a. Storable a + => StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a +append ssh (XArray a) (XArray b) + | Dict <- lemKnownNatRankSSX ssh + = XArray (S.append a b) + +-- | All arrays must have the same shape, except possibly for the outermost +-- dimension. +concat :: Storable a + => StaticShX sh -> NonEmpty (XArray (Nothing : sh) a) -> XArray (Nothing : sh) a +concat ssh l + | Dict <- lemKnownNatRankSSX ssh + = XArray (S.concatOuter (coerce (toList l))) + +-- | If the prefix of the shape of the input array (@sh@) is empty (i.e. +-- contains a zero), then there is no way to deduce the full shape of the output +-- array (more precisely, the @sh2@ part): that could only come from calling +-- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in +-- this case; we choose to fill the shape with zeros wherever we cannot deduce +-- what it should be. +-- +-- For example, if: +-- +-- @ +-- arr :: XArray '[Just 3, Just 0, Just 4, Just 2, Nothing] Int -- of shape [3, 0, 4, 2, 21] +-- f :: XArray '[Just 2, Nothing] Int -> XArray '[Just 5, Nothing, Just 17] Float +-- @ +-- +-- then: +-- +-- @ +-- rerank _ _ _ f arr :: XArray '[Just 3, Just 0, Just 4, Just 5, Nothing, Just 17] Float +-- @ +-- +-- and this result will have shape @[3, 0, 4, 5, 0, 17]@. Note the second @0@ +-- in this shape: we don't know if @f@ intended to return an array with shape 0 +-- here (it probably didn't), but there is no better number to put here absent +-- a subarray of the input to pass to @f@. +-- +-- In this particular case the fact that @sh@ is empty was evident from the +-- type-level information, but the same situation occurs when @sh@ consists of +-- @Nothing@s, and some of those happen to be zero at runtime. +rerank :: forall sh sh1 sh2 a b. + (Storable a, Storable b) + => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 + -> (XArray sh1 a -> XArray sh2 b) + -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b +rerank ssh ssh1 ssh2 f xarr@(XArray arr) + | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) + = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr) + in if 0 `elem` shxToList sh + then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) + else case () of + () | Dict <- lemKnownNatRankSSX ssh + , Dict <- lemKnownNatRankSSX ssh2 + , Refl <- lemRankApp ssh ssh1 + , Refl <- lemRankApp ssh ssh2 + -> XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2) + (\a -> let XArray r = f (XArray a) in r) + arr) + +rerankTop :: forall sh1 sh2 sh a b. + (Storable a, Storable b) + => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh + -> (XArray sh1 a -> XArray sh2 b) + -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b +rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh + +-- | The caveat about empty arrays at @rerank@ applies here too. +rerank2 :: forall sh sh1 sh2 a b c. + (Storable a, Storable b, Storable c) + => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 + -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c) + -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c +rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2) + | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) + = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1) + in if 0 `elem` shxToList sh + then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) + else case () of + () | Dict <- lemKnownNatRankSSX ssh + , Dict <- lemKnownNatRankSSX ssh2 + , Refl <- lemRankApp ssh ssh1 + , Refl <- lemRankApp ssh ssh2 + -> XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2) + (\a b -> let XArray r = f (XArray a) (XArray b) in r) + arr1 arr2) + +-- | The list argument gives indices into the original dimension list. +transpose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh) + => StaticShX sh + -> Perm is + -> XArray sh a + -> XArray (PermutePrefix is sh) a +transpose ssh perm (XArray arr) + | Dict <- lemKnownNatRankSSX ssh + , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh) + , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm + , Refl <- lemRankDropLen ssh perm + = XArray (S.transpose (permToList' perm) arr) + +-- | The list argument gives indices into the original dimension list. +-- +-- The permutation (the list) must have length <= @n@. If it is longer, this +-- function throws. +transposeUntyped :: forall n sh a. + SNat n -> StaticShX sh -> [Int] + -> XArray (Replicate n Nothing ++ sh) a -> XArray (Replicate n Nothing ++ sh) a +transposeUntyped sn ssh perm (XArray arr) + | length perm <= fromSNat' sn + , Dict <- lemKnownNatRankSSX (ssxAppend (ssxReplicate sn) ssh) + = XArray (S.transpose perm arr) + | otherwise + = error "Data.Array.Mixed.transposeUntyped: Permutation longer than length of unshaped prefix of shape type" + +transpose2 :: forall sh1 sh2 a. + StaticShX sh1 -> StaticShX sh2 + -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a +transpose2 ssh1 ssh2 (XArray arr) + | Refl <- lemRankApp ssh1 ssh2 + , Refl <- lemRankApp ssh2 ssh1 + , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2) + , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1) + , Refl <- lemRankAppComm ssh1 ssh2 + , let n1 = ssxLength ssh1 + = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) + +sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a +sumFull _ (XArray arr) = + S.unScalar $ + liftO1 (numEltSum1Inner (SNat @0)) $ + S.fromVector [product (S.shapeL arr)] $ + S.toVector arr + +sumInner :: forall sh sh' a. (Storable a, NumElt a) + => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a +sumInner ssh ssh' arr + | Refl <- lemAppNil @sh + = let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) + sh'F = shxFlatten sh' :$% ZSX + ssh'F = ssxFromShape sh'F + + go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a + go (XArray arr') + | Refl <- lemRankApp ssh ssh'F + , let sn = listxRank (let StaticShX l = ssh in l) + = XArray (liftO1 (numEltSum1Inner sn) arr') + + in go $ + transpose2 ssh'F ssh $ + reshapePartial ssh' ssh sh'F $ + transpose2 ssh ssh' $ + arr + +sumOuter :: forall sh sh' a. (Storable a, NumElt a) + => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a +sumOuter ssh ssh' arr + | Refl <- lemAppNil @sh + = let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) + shF = shxFlatten sh :$% ZSX + in sumInner ssh' (ssxFromShape shF) $ + transpose2 (ssxFromShape shF) ssh' $ + reshapePartial ssh ssh' shF $ + arr + +fromListOuter :: forall n sh a. Storable a + => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a +fromListOuter ssh l + | Dict <- lemKnownNatRankSSX ssh + = case ssh of + SKnown m :!% _ | fromSNat' m /= length l -> + error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ")" ++ + "does not match the type (" ++ show (fromSNat' m) ++ ")" + _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l))) + +toListOuter :: Storable a => XArray (n : sh) a -> [XArray sh a] +toListOuter (XArray arr) = + case S.shapeL arr of + 0 : _ -> [] + _ -> coerce (ORB.toList (S.unravel arr)) + +fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a +fromList1 ssh l = + let n = length l + in case ssh of + SKnown m :!% _ | fromSNat' m /= n -> + error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ")" ++ + "does not match the type (" ++ show (fromSNat' m) ++ ")" + _ -> XArray (S.fromVector [n] (VS.fromListN n l)) + +toList1 :: Storable a => XArray '[n] a -> [a] +toList1 (XArray arr) = S.toList arr + +-- | Throws if the given shape is not, in fact, empty. +empty :: forall sh a. Storable a => IShX sh -> XArray sh a +empty sh + | Dict <- lemKnownNatRank sh + , shxSize sh == 0 + = XArray (S.fromVector (shxToList sh) VS.empty) + | otherwise + = error $ "Data.Array.Mixed.empty: shape was not empty: " ++ show sh + +slice :: SNat i -> SNat n -> XArray (Just (i + n + k) : sh) a -> XArray (Just n : sh) a +slice i n (XArray arr) = XArray (S.slice [(fromSNat' i, fromSNat' n)] arr) + +sliceU :: Int -> Int -> XArray (Nothing : sh) a -> XArray (Nothing : sh) a +sliceU i n (XArray arr) = XArray (S.slice [(i, n)] arr) + +rev1 :: XArray (n : sh) a -> XArray (n : sh) a +rev1 (XArray arr) = XArray (S.rev [0] arr) + +-- | Throws if the given array and the target shape do not have the same number of elements. +reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a +reshape ssh1 sh2 (XArray arr) + | Dict <- lemKnownNatRankSSX ssh1 + , Dict <- lemKnownNatRank sh2 + = XArray (S.reshape (shxToList sh2) arr) + +-- | Throws if the given array and the target shape do not have the same number of elements. +reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a +reshapePartial ssh1 ssh' sh2 (XArray arr) + | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh') + , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh2) ssh') + = XArray (S.reshape (shxToList sh2 ++ drop (ssxLength ssh1) (S.shapeL arr)) arr) + +-- this was benchmarked to be (slightly) faster than S.iota, S.generate and S.fromVector(VS.enumFromTo). +iota :: (Enum a, Storable a) => SNat n -> XArray '[Just n] a +iota sn = XArray (S.fromVector [fromSNat' sn] (VS.fromListN (fromSNat' sn) [toEnum 0 .. toEnum (fromSNat' sn - 1)])) -- cgit v1.2.3-70-g09d2