{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-- | Mixed-shape arrays: 'XArray' wraps an orthotope 'S.Array' whose rank is
-- computed from a type-level shape of kind @[Maybe Nat]@.
module Data.Array.Mixed.XArray where

import Control.DeepSeq (NFData(..))
import Data.Array.Ranked qualified as ORB
import Data.Array.RankedS qualified as S
import Data.Coerce
import Data.Foldable (toList)
import Data.Kind
import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
import Data.Type.Equality
import Data.Type.Ord
import Data.Vector.Storable qualified as VS
import Foreign.Storable (Storable)
import GHC.Generics (Generic)
import GHC.TypeLits

import Data.Array.Mixed.Internal.Arith
import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Permutation
import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types


-- | An array whose type-level shape @sh@ lists one entry per dimension.
-- A dimension is either statically sized (@Just n@) or only known at
-- runtime (@Nothing@). The data is stored in an orthotope 'S.Array' of rank
-- @'Rank' sh@.
type XArray :: [Maybe Nat] -> Type -> Type
newtype XArray sh a = XArray (S.Array (Rank sh) a)
  deriving (Show, Eq, Generic)

-- | Only on scalars, because lexicographical ordering is strange on multi-dimensional arrays.
--
-- NOTE(review): the instance below is declared for every shape @sh@, not
-- only for @'[]@ — confirm whether the sentence above is still intended.
deriving instance (Ord a, Storable a) => Ord (XArray sh a)

-- | Uses the default 'Generic'-based 'rnf'.
instance NFData a => NFData (XArray sh a)

-- | The full shape of an array.
--
-- Dimensions marked unknown in the 'StaticShX' take their size from the
-- runtime shape of the underlying array ('SUnknown'). Statically known
-- dimensions are returned from the type ('SKnown'); their runtime size is
-- not checked against it here. Errors with @"Invalid shapeL"@ if the runtime
-- rank disagrees with the static shape.
shape :: forall sh a. StaticShX sh -> XArray sh a -> IShX sh
shape = \ssh (XArray arr) -> go ssh (S.shapeL arr)
  where
    go :: StaticShX sh' -> [Int] -> IShX sh'
    go ZKX [] = ZSX
    go (n :!% ssh) (i : l) = fromSMayNat (\_ -> SUnknown i) SKnown n :$% go ssh l
    go _ _ = error "Invalid shapeL"

-- | Build an array of the given shape from a flat storable vector of elements.
fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a
fromVector sh v
  | Dict <- lemKnownNatRank sh
  = XArray (S.fromVector (shxToList sh) v)

-- | All elements of the array as a flat storable vector.
toVector :: Storable a => XArray sh a -> VS.Vector a
toVector (XArray arr) = S.toVector arr

-- | A rank-0 array holding a single value.
scalar :: Storable a => a -> XArray '[] a
scalar = XArray . S.scalar

-- | Will throw if the array does not have the casted-to shape.
--
-- Only the prefix @sh1@ is compared against @sh2@. The data itself is reused
-- unchanged; only the type changes.
cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2
     => StaticShX sh1 -> IShX sh2 -> StaticShX sh'
     -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
cast ssh1 sh2 ssh' (XArray arr)
  | Refl <- lemRankApp ssh1 ssh'
  , Refl <- lemRankApp (ssxFromShape sh2) ssh'
  = let arrsh :: IShX sh1
        (arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr))
    in if shxToList arrsh == shxToList sh2
         then XArray arr
         else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")"

-- | The single element of a rank-0 array.
unScalar :: Storable a => XArray '[] a -> a
unScalar (XArray a) = S.unScalar a

-- | Replicate an array along new outer dimensions @sh@.
--
-- The input is first reshaped so that each new outer dimension has size 1.
-- Those dimensions are then stretched to the sizes given in @sh@.
replicate :: forall sh sh' a. Storable a
          => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a
replicate sh ssh' (XArray arr)
  | Dict <- lemKnownNatRankSSX ssh'
  , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh) ssh')
  , Refl <- lemRankApp (ssxFromShape sh) ssh'
  = XArray (S.stretch (shxToList sh ++ S.shapeL arr) $
            S.reshape (map (const 1) (shxToList sh) ++ S.shapeL arr) $
            arr)

-- | An array of the given shape with every element equal to @x@.
replicateScal :: forall sh a. Storable a => IShX sh -> a -> XArray sh a
replicateScal sh x
  | Dict <- lemKnownNatRank sh
  = XArray (S.constant (shxToList sh) x)

-- | An array of the given shape whose elements are computed by @f@.
--
-- Each linear position in @[0, shxSize sh)@ is converted to a multi-index
-- with 'ixxFromLinear' before @f@ is applied to it.
generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a
generate sh f = fromVector sh $ VS.generate (shxSize sh) (f . ixxFromLinear sh)

-- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a)
-- generateM sh f | Dict <- lemKnownNatRank sh =
--   XArray . S.fromVector (shxShapeL sh)
--     <$> VS.generateM (shxSize sh) (f . ixxFromLinear sh)

-- | Index the outer dimensions @sh@ and return the remaining subarray of
-- shape @sh'@. Each index component is applied with 'S.index', outermost first.
indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a
indexPartial (XArray arr) ZIX = XArray arr
indexPartial (XArray arr) (i :.% idx) = indexPartial (XArray (S.index arr i)) idx

-- | Read one element using a full index.
index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a
index xarr i
  | Refl <- lemAppNil @sh
  = let XArray arr' = indexPartial xarr i :: XArray '[] a
    in S.unScalar arr'

-- | Concatenate two arrays along their outermost dimension.
append :: forall n m sh a. Storable a
       => StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a
append ssh (XArray a) (XArray b)
  | Dict <- lemKnownNatRankSSX ssh
  = XArray (S.append a b)

-- | All arrays must have the same shape, except possibly for the outermost
-- dimension.
concat :: Storable a
       => StaticShX sh -> NonEmpty (XArray (Nothing : sh) a) -> XArray (Nothing : sh) a
concat ssh l
  | Dict <- lemKnownNatRankSSX ssh
  = XArray (S.concatOuter (coerce (toList l)))

-- | If the prefix of the shape of the input array (@sh@) is empty (i.e.
-- contains a zero), then there is no way to deduce the full shape of the output
-- array (more precisely, the @sh2@ part): that could only come from calling
-- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in
-- this case; we choose to fill the shape with zeros wherever we cannot deduce
-- what it should be.
--
-- For example, if:
--
-- @
-- arr :: XArray '[Just 3, Just 0, Just 4, Just 2, Nothing] Int  -- of shape [3, 0, 4, 2, 21]
-- f :: XArray '[Just 2, Nothing] Int -> XArray '[Just 5, Nothing, Just 17] Float
-- @
--
-- then:
--
-- @
-- rerank _ _ _ f arr :: XArray '[Just 3, Just 0, Just 4, Just 5, Nothing, Just 17] Float
-- @
--
-- and this result will have shape @[3, 0, 4, 5, 0, 17]@. Note the second @0@
-- in this shape: we don't know if @f@ intended to return an array with shape 0
-- here (it probably didn't), but there is no better number to put here absent
-- a subarray of the input to pass to @f@.
--
-- In this particular case the fact that @sh@ is empty was evident from the
-- type-level information, but the same situation occurs when @sh@ consists of
-- @Nothing@s, and some of those happen to be zero at runtime.
rerank :: forall sh sh1 sh2 a b.
          (Storable a, Storable b)
       => StaticShX sh -> StaticShX sh1 -> StaticShX sh2
       -> (XArray sh1 a -> XArray sh2 b)
       -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b
rerank ssh ssh1 ssh2 f xarr@(XArray arr)
  | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)
  = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr)
    in if any (== 0) (shxToList sh)
         -- Empty outer prefix: f is never called; unknown sh2 dims become 0.
         then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) [])
         else case () of
           () | Dict <- lemKnownNatRankSSX ssh
              , Dict <- lemKnownNatRankSSX ssh2
              , Refl <- lemRankApp ssh ssh1
              , Refl <- lemRankApp ssh ssh2
              -> XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2)
                                  (\a -> let XArray r = f (XArray a) in r)
                                  arr)

-- | Like 'rerank', but @f@ acts on the /outer/ dimensions @sh1@ instead of
-- the inner ones. This is implemented by transposing @sh1@ to the back,
-- reranking, and transposing the result back.
rerankTop :: forall sh1 sh2 sh a b.
             (Storable a, Storable b)
          => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh
          -> (XArray sh1 a -> XArray sh2 b)
          -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b
rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh

-- | The caveat about empty arrays at @rerank@ applies here too.
--
-- Only the shape of the first array is inspected for the empty-prefix check.
rerank2 :: forall sh sh1 sh2 a b c.
           (Storable a, Storable b, Storable c)
        => StaticShX sh -> StaticShX sh1 -> StaticShX sh2
        -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c)
        -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c
rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2)
  | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)
  = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1)
    in if any (== 0) (shxToList sh)
         then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) [])
         else case () of
           () | Dict <- lemKnownNatRankSSX ssh
              , Dict <- lemKnownNatRankSSX ssh2
              , Refl <- lemRankApp ssh ssh1
              , Refl <- lemRankApp ssh ssh2
              -> XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2)
                                   (\a b -> let XArray r = f (XArray a) (XArray b) in r)
                                   arr1 arr2)

-- | The list argument gives indices into the original dimension list.
--
-- The permutation @is@ is applied to the first @'Rank' is@ dimensions. The
-- remaining dimensions keep their position.
transpose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh)
          => StaticShX sh -> Perm is -> XArray sh a -> XArray (PermutePrefix is sh) a
transpose ssh perm (XArray arr)
  | Dict <- lemKnownNatRankSSX ssh
  , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh)
  , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm
  , Refl <- lemRankDropLen ssh perm
  = XArray (S.transpose (permToList' perm) arr)

-- | The list argument gives indices into the original dimension list.
--
-- The permutation (the list) must have length <= @n@. If it is longer, this
-- function throws.
transposeUntyped :: forall n sh a.
                    SNat n -> StaticShX sh -> [Int]
                 -> XArray (Replicate n Nothing ++ sh) a -> XArray (Replicate n Nothing ++ sh) a
transposeUntyped sn ssh perm (XArray arr)
  | length perm <= fromSNat' sn
  , Dict <- lemKnownNatRankSSX (ssxAppend (ssxReplicate sn) ssh)
  = XArray (S.transpose perm arr)
  | otherwise
  = error "Data.Array.Mixed.transposeUntyped: Permutation longer than length of unshaped prefix of shape type"

-- | Swap two groups of dimensions: an array of shape @sh1 ++ sh2@ becomes an
-- array of shape @sh2 ++ sh1@.
transpose2 :: forall sh1 sh2 a.
              StaticShX sh1 -> StaticShX sh2
           -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a
transpose2 ssh1 ssh2 (XArray arr)
  | Refl <- lemRankApp ssh1 ssh2
  , Refl <- lemRankApp ssh2 ssh1
  , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2)
  , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1)
  , Refl <- lemRankAppComm ssh1 ssh2
  , let n1 = ssxLength ssh1
  = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr)

-- | Sum of all elements. The array is flattened to a rank-1 array, which is
-- then reduced with 'numEltSum1Inner'.
sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a
sumFull _ (XArray arr) =
  S.unScalar $
    numEltSum1Inner (SNat @0) $
      S.fromVector [product (S.shapeL arr)] $
        S.toVector arr

-- | Sum over the inner dimensions @sh'@.
--
-- This works in four steps:
--
-- 1. Move @sh'@ to the front.
-- 2. Flatten @sh'@ into a single dimension.
-- 3. Move that dimension to the back.
-- 4. Sum the innermost dimension with 'numEltSum1Inner'.
sumInner :: forall sh sh' a. (Storable a, NumElt a)
         => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a
sumInner ssh ssh' arr
  | Refl <- lemAppNil @sh
  = let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
        sh'F = shxFlatten sh' :$% ZSX
        ssh'F = ssxFromShape sh'F

        -- Reduce the single flattened innermost dimension.
        go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a
        go (XArray arr')
          | Refl <- lemRankApp ssh ssh'F
          , let sn = listxRank (let StaticShX l = ssh in l)
          = XArray (numEltSum1Inner sn arr')

    in go $
       transpose2 ssh'F ssh $
       reshapePartial ssh' ssh sh'F $
       transpose2 ssh ssh' $
         arr

-- | Sum over the outer dimensions @sh@. The outer dimensions are flattened
-- into one, moved to the back, and then reduced with 'sumInner'.
sumOuter :: forall sh sh' a. (Storable a, NumElt a)
         => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a
sumOuter ssh ssh' arr
  | Refl <- lemAppNil @sh
  = let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
        shF = shxFlatten sh :$% ZSX
    in sumInner ssh' (ssxFromShape shF) $
       transpose2 (ssxFromShape shF) ssh' $
       reshapePartial ssh ssh' shF $
         arr

-- | Stack a list of arrays along a new outermost dimension.
--
-- Throws if the outermost dimension is statically known and differs from the
-- length of the list.
--
-- NOTE(review): for an empty list the inner shape cannot be recovered from
-- the elements, so the result then depends on how 'S.ravel' handles that
-- case — confirm.
fromListOuter :: forall n sh a. Storable a
              => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
fromListOuter ssh l
  | Dict <- lemKnownNatRankSSX ssh
  = let len = length l
    in case ssh of
         SKnown m :!% _ | fromSNat' m /= len ->
           error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show len ++ ") "
                   ++ "does not match the type (" ++ show (fromSNat' m) ++ ")"
         _ -> XArray (S.ravel (ORB.fromList [len] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l)))

-- | Split an array along its outermost dimension. Returns @[]@ directly when
-- that dimension has size 0, without calling 'S.unravel'.
toListOuter :: Storable a => XArray (n : sh) a -> [XArray sh a]
toListOuter (XArray arr) =
  case S.shapeL arr of
    0 : _ -> []
    _ -> coerce (ORB.toList (S.unravel arr))

-- | A rank-1 array from a list.
--
-- Throws if the dimension is statically known and differs from the length of
-- the list.
fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a
fromList1 ssh l =
  let n = length l
  in case ssh of
       SKnown m :!% _ | fromSNat' m /= n ->
         error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ") "
                 ++ "does not match the type (" ++ show (fromSNat' m) ++ ")"
       _ -> XArray (S.fromVector [n] (VS.fromListN n l))

-- | The elements of a rank-1 array as a list.
toList1 :: Storable a => XArray '[n] a -> [a]
toList1 (XArray arr) = S.toList arr

-- | Throws if the given shape is not, in fact, empty.
empty :: forall sh a. Storable a => IShX sh -> XArray sh a
empty sh
  | Dict <- lemKnownNatRank sh
  = XArray (S.constant (shxToList sh)
                       (error "Data.Array.Mixed.empty: shape was not empty"))

-- | Take @n@ elements starting at offset @i@ in the outermost dimension. The
-- type requires the static outer size to be @i + n + k@ for some @k@.
slice :: SNat i -> SNat n -> XArray (Just (i + n + k) : sh) a -> XArray (Just n : sh) a
slice i n (XArray arr) = XArray (S.slice [(fromSNat' i, fromSNat' n)] arr)

-- | Take @n@ elements starting at offset @i@ in a runtime-sized outermost
-- dimension. This function does no bounds checking of its own.
sliceU :: Int -> Int -> XArray (Nothing : sh) a -> XArray (Nothing : sh) a
sliceU i n (XArray arr) = XArray (S.slice [(i, n)] arr)

-- | Reverse the outermost dimension.
rev1 :: XArray (n : sh) a -> XArray (n : sh) a
rev1 (XArray arr) = XArray (S.rev [0] arr)

-- | Throws if the given array and the target shape do not have the same number of elements.
reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a
reshape ssh1 sh2 (XArray arr)
  | Dict <- lemKnownNatRankSSX ssh1
  , Dict <- lemKnownNatRank sh2
  = XArray (S.reshape (shxToList sh2) arr)

-- | Throws if the given array and the target shape do not have the same number of elements.
--
-- Only the outer @sh1@ part is reshaped into @sh2@. The trailing @sh'@
-- dimensions are kept as they are in the runtime shape.
reshapePartial :: forall sh1 sh2 sh' a. Storable a
               => StaticShX sh1 -> StaticShX sh' -> IShX sh2
               -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
reshapePartial ssh1 ssh' sh2 (XArray arr)
  | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh')
  , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh2) ssh')
  = XArray (S.reshape (shxToList sh2 ++ drop (ssxLength ssh1) (S.shapeL arr)) arr)

-- | The rank-1 array @[toEnum 0, toEnum 1, .., toEnum (n - 1)]@.
--
-- this was benchmarked to be (slightly) faster than S.iota, S.generate and S.fromVector(VS.enumFromTo).
iota :: (Enum a, Storable a) => SNat n -> XArray '[Just n] a
iota sn = XArray (S.fromVector [fromSNat' sn] (VS.fromListN (fromSNat' sn) [toEnum 0 .. toEnum (fromSNat' sn - 1)]))