aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-30 22:47:52 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-30 22:47:52 +0200
commit8b59d8ef4ff97936f2a753d1ce345e0404c26b2b (patch)
tree947f75cb43982fbdb551dc329f036b0591f3c2b2 /src/Data/Array/Mixed.hs
parentf0752d67cd188f438280e1f0c692dc1f5f14a190 (diff)
Clearer module purposes
Thanks Mikolaj for discussion
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r--src/Data/Array/Mixed.hs335
1 files changed, 0 insertions, 335 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
deleted file mode 100644
index 4a338a2..0000000
--- a/src/Data/Array/Mixed.hs
+++ /dev/null
@@ -1,335 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveGeneric #-}
-{-# LANGUAGE FlexibleInstances #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE ImportQualifiedPost #-}
-{-# LANGUAGE NoStarIsType #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE StandaloneKindSignatures #-}
-{-# LANGUAGE StrictData #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeOperators #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Mixed where
-
-import Control.DeepSeq (NFData(..))
-import Data.Array.Ranked qualified as ORB
-import Data.Array.RankedS qualified as S
-import Data.Coerce
-import Data.Kind
-import Data.Proxy
-import Data.Type.Equality
-import Data.Type.Ord
-import Data.Vector.Storable qualified as VS
-import Foreign.Storable (Storable)
-import GHC.Generics (Generic)
-import GHC.TypeLits
-
-import Data.Array.Mixed.Internal.Arith
-import Data.Array.Mixed.Lemmas
-import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
-import Data.Array.Mixed.Types
-
-
-type XArray :: [Maybe Nat] -> Type -> Type
-newtype XArray sh a = XArray (S.Array (Rank sh) a)
- deriving (Show, Eq, Generic)
-
--- | Only on scalars, because lexicographical ordering is strange on multi-dimensional arrays.
-deriving instance (Ord a, Storable a) => Ord (XArray '[] a)
-
-instance NFData a => NFData (XArray sh a)
-
-
-shape :: forall sh a. StaticShX sh -> XArray sh a -> IShX sh
-shape = \ssh (XArray arr) -> go ssh (S.shapeL arr)
- where
- go :: StaticShX sh' -> [Int] -> IShX sh'
- go ZKX [] = ZSX
- go (n :!% ssh) (i : l) = fromSMayNat (\_ -> SUnknown i) SKnown n :$% go ssh l
- go _ _ = error "Invalid shapeL"
-
-fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a
-fromVector sh v
- | Dict <- lemKnownNatRank sh
- = XArray (S.fromVector (shxToList sh) v)
-
-toVector :: Storable a => XArray sh a -> VS.Vector a
-toVector (XArray arr) = S.toVector arr
-
-scalar :: Storable a => a -> XArray '[] a
-scalar = XArray . S.scalar
-
--- | Will throw if the array does not have the casted-to shape.
-cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2
- => StaticShX sh1 -> IShX sh2 -> StaticShX sh'
- -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
-cast ssh1 sh2 ssh' (XArray arr)
- | Refl <- lemRankApp ssh1 ssh'
- , Refl <- lemRankApp (ssxFromShape sh2) ssh'
- = let arrsh :: IShX sh1
- (arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr))
- in if shxToList arrsh == shxToList sh2
- then XArray arr
- else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")"
-
-unScalar :: Storable a => XArray '[] a -> a
-unScalar (XArray a) = S.unScalar a
-
-replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a
-replicate sh ssh' (XArray arr)
- | Dict <- lemKnownNatRankSSX ssh'
- , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh) ssh')
- , Refl <- lemRankApp (ssxFromShape sh) ssh'
- = XArray (S.stretch (shxToList sh ++ S.shapeL arr) $
- S.reshape (map (const 1) (shxToList sh) ++ S.shapeL arr) $
- arr)
-
-replicateScal :: forall sh a. Storable a => IShX sh -> a -> XArray sh a
-replicateScal sh x
- | Dict <- lemKnownNatRank sh
- = XArray (S.constant (shxToList sh) x)
-
-generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a
-generate sh f = fromVector sh $ VS.generate (shxSize sh) (f . ixxFromLinear sh)
-
--- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a)
--- generateM sh f | Dict <- lemKnownNatRank sh =
--- XArray . S.fromVector (shxShapeL sh)
--- <$> VS.generateM (shxSize sh) (f . ixxFromLinear sh)
-
-indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a
-indexPartial (XArray arr) ZIX = XArray arr
-indexPartial (XArray arr) (i :.% idx) = indexPartial (XArray (S.index arr i)) idx
-
-index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a
-index xarr i
- | Refl <- lemAppNil @sh
- = let XArray arr' = indexPartial xarr i :: XArray '[] a
- in S.unScalar arr'
-
-append :: forall n m sh a. Storable a
- => StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a
-append ssh (XArray a) (XArray b)
- | Dict <- lemKnownNatRankSSX ssh
- = XArray (S.append a b)
-
--- | If the prefix of the shape of the input array (@sh@) is empty (i.e.
--- contains a zero), then there is no way to deduce the full shape of the output
--- array (more precisely, the @sh2@ part): that could only come from calling
--- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in
--- this case; we choose to fill the shape with zeros wherever we cannot deduce
--- what it should be.
---
--- For example, if:
---
--- @
--- arr :: XArray '[Just 3, Just 0, Just 4, Just 2, Nothing] Int -- of shape [3, 0, 4, 2, 21]
--- f :: XArray '[Just 2, Nothing] Int -> XArray '[Just 5, Nothing, Just 17] Float
--- @
---
--- then:
---
--- @
--- rerank _ _ _ f arr :: XArray '[Just 3, Just 0, Just 4, Just 5, Nothing, Just 17] Float
--- @
---
--- and this result will have shape @[3, 0, 4, 5, 0, 17]@. Note the second @0@
--- in this shape: we don't know if @f@ intended to return an array with shape 0
--- here (it probably didn't), but there is no better number to put here absent
--- a subarray of the input to pass to @f@.
---
--- In this particular case the fact that @sh@ is empty was evident from the
--- type-level information, but the same situation occurs when @sh@ consists of
--- @Nothing@s, and some of those happen to be zero at runtime.
-rerank :: forall sh sh1 sh2 a b.
- (Storable a, Storable b)
- => StaticShX sh -> StaticShX sh1 -> StaticShX sh2
- -> (XArray sh1 a -> XArray sh2 b)
- -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b
-rerank ssh ssh1 ssh2 f xarr@(XArray arr)
- | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)
- = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr)
- in if any (== 0) (shxToList sh)
- then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) [])
- else case () of
- () | Dict <- lemKnownNatRankSSX ssh
- , Dict <- lemKnownNatRankSSX ssh2
- , Refl <- lemRankApp ssh ssh1
- , Refl <- lemRankApp ssh ssh2
- -> XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2)
- (\a -> let XArray r = f (XArray a) in r)
- arr)
-
-rerankTop :: forall sh1 sh2 sh a b.
- (Storable a, Storable b)
- => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh
- -> (XArray sh1 a -> XArray sh2 b)
- -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b
-rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh
-
--- | The caveat about empty arrays at @rerank@ applies here too.
-rerank2 :: forall sh sh1 sh2 a b c.
- (Storable a, Storable b, Storable c)
- => StaticShX sh -> StaticShX sh1 -> StaticShX sh2
- -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c)
- -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c
-rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2)
- | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)
- = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1)
- in if any (== 0) (shxToList sh)
- then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) [])
- else case () of
- () | Dict <- lemKnownNatRankSSX ssh
- , Dict <- lemKnownNatRankSSX ssh2
- , Refl <- lemRankApp ssh ssh1
- , Refl <- lemRankApp ssh ssh2
- -> XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2)
- (\a b -> let XArray r = f (XArray a) (XArray b) in r)
- arr1 arr2)
-
-class KnownNatList l where makeNatList :: Perm l
-instance KnownNatList '[] where makeNatList = PNil
-instance (KnownNat n, KnownNatList l) => KnownNatList (n : l) where makeNatList = natSing `PCons` makeNatList
-
--- | The list argument gives indices into the original dimension list.
-transpose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh)
- => StaticShX sh
- -> Perm is
- -> XArray sh a
- -> XArray (PermutePrefix is sh) a
-transpose ssh perm (XArray arr)
- | Dict <- lemKnownNatRankSSX ssh
- , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh)
- , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm
- , Refl <- lemRankDropLen ssh perm
- = XArray (S.transpose (permToList' perm) arr)
-
--- | The list argument gives indices into the original dimension list.
---
--- The permutation (the list) must have length <= @n@. If it is longer, this
--- function throws.
-transposeUntyped :: forall n sh a.
- SNat n -> StaticShX sh -> [Int]
- -> XArray (Replicate n Nothing ++ sh) a -> XArray (Replicate n Nothing ++ sh) a
-transposeUntyped sn ssh perm (XArray arr)
- | length perm <= fromSNat' sn
- , Dict <- lemKnownNatRankSSX (ssxAppend (ssxReplicate sn) ssh)
- = XArray (S.transpose perm arr)
- | otherwise
- = error "Data.Array.Mixed.transposeUntyped: Permutation longer than length of unshaped prefix of shape type"
-
-transpose2 :: forall sh1 sh2 a.
- StaticShX sh1 -> StaticShX sh2
- -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a
-transpose2 ssh1 ssh2 (XArray arr)
- | Refl <- lemRankApp ssh1 ssh2
- , Refl <- lemRankApp ssh2 ssh1
- , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2)
- , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1)
- , Refl <- lemRankAppComm ssh1 ssh2
- , let n1 = ssxLength ssh1
- = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr)
-
-sumFull :: (Storable a, NumElt a) => XArray sh a -> a
-sumFull (XArray arr) =
- S.unScalar $
- numEltSum1Inner (SNat @0) $
- S.fromVector [product (S.shapeL arr)] $
- S.toVector arr
-
-sumInner :: forall sh sh' a. (Storable a, NumElt a)
- => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a
-sumInner ssh ssh' arr
- | Refl <- lemAppNil @sh
- = let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
- sh'F = shxFlatten sh' :$% ZSX
- ssh'F = ssxFromShape sh'F
-
- go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a
- go (XArray arr')
- | Refl <- lemRankApp ssh ssh'F
- , let sn = listxLengthSNat (let StaticShX l = ssh in l)
- = XArray (numEltSum1Inner sn arr')
-
- in go $
- transpose2 ssh'F ssh $
- reshapePartial ssh' ssh sh'F $
- transpose2 ssh ssh' $
- arr
-
-sumOuter :: forall sh sh' a. (Storable a, NumElt a)
- => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a
-sumOuter ssh ssh' arr
- | Refl <- lemAppNil @sh
- = let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
- shF = shxFlatten sh :$% ZSX
- in sumInner ssh' (ssxFromShape shF) $
- transpose2 (ssxFromShape shF) ssh' $
- reshapePartial ssh ssh' shF $
- arr
-
-fromListOuter :: forall n sh a. Storable a
- => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
-fromListOuter ssh l
- | Dict <- lemKnownNatRankSSX ssh
- = case ssh of
- SKnown m :!% _ | fromSNat' m /= length l ->
- error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ")" ++
- "does not match the type (" ++ show (fromSNat' m) ++ ")"
- _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l)))
-
-toListOuter :: Storable a => XArray (n : sh) a -> [XArray sh a]
-toListOuter (XArray arr) =
- case S.shapeL arr of
- 0 : _ -> []
- _ -> coerce (ORB.toList (S.unravel arr))
-
-fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a
-fromList1 ssh l =
- let n = length l
- in case ssh of
- SKnown m :!% _ | fromSNat' m /= n ->
- error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ")" ++
- "does not match the type (" ++ show (fromSNat' m) ++ ")"
- _ -> XArray (S.fromVector [n] (VS.fromListN n l))
-
-toList1 :: Storable a => XArray '[n] a -> [a]
-toList1 (XArray arr) = S.toList arr
-
--- | Throws if the given shape is not, in fact, empty.
-empty :: forall sh a. Storable a => IShX sh -> XArray sh a
-empty sh
- | Dict <- lemKnownNatRank sh
- = XArray (S.constant (shxToList sh)
- (error "Data.Array.Mixed.empty: shape was not empty"))
-
-slice :: SNat i -> SNat n -> XArray (Just (i + n + k) : sh) a -> XArray (Just n : sh) a
-slice i n (XArray arr) = XArray (S.slice [(fromSNat' i, fromSNat' n)] arr)
-
-sliceU :: Int -> Int -> XArray (Nothing : sh) a -> XArray (Nothing : sh) a
-sliceU i n (XArray arr) = XArray (S.slice [(i, n)] arr)
-
-rev1 :: XArray (n : sh) a -> XArray (n : sh) a
-rev1 (XArray arr) = XArray (S.rev [0] arr)
-
--- | Throws if the given array and the target shape do not have the same number of elements.
-reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a
-reshape ssh1 sh2 (XArray arr)
- | Dict <- lemKnownNatRankSSX ssh1
- , Dict <- lemKnownNatRank sh2
- = XArray (S.reshape (shxToList sh2) arr)
-
--- | Throws if the given array and the target shape do not have the same number of elements.
-reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
-reshapePartial ssh1 ssh' sh2 (XArray arr)
- | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh')
- , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh2) ssh')
- = XArray (S.reshape (shxToList sh2 ++ drop (ssxLength ssh1) (S.shapeL arr)) arr)
-
--- this was benchmarked to be (slightly) faster than S.iota, S.generate and S.fromVector(VS.enumFromTo).
-iota :: (Enum a, Storable a) => SNat n -> XArray '[Just n] a
-iota sn = XArray (S.fromVector [fromSNat' sn] (VS.fromListN (fromSNat' sn) [toEnum 0 .. toEnum (fromSNat' sn - 1)]))