diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-04-14 10:12:41 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-04-14 10:12:41 +0200 |
commit | 3e74b0673caba7c04353c0cedb1d6e02de1fd007 (patch) | |
tree | 36e700e86199dd07b54cd3c28e8d09dc77f32c42 /src/Data/Array/Mixed.hs | |
parent | 8a81f7ea9eed9afaec948910caaf0a5c498de6c6 (diff) |
Move from unboxed to storable vectors
Mikolaj requires this to interface with hmatrix
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r-- | src/Data/Array/Mixed.hs | 72 |
1 files changed, 36 insertions, 36 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 2875203..2bbf81d 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -12,11 +12,11 @@ {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module Data.Array.Mixed where -import qualified Data.Array.RankedU as U +import qualified Data.Array.RankedS as S import Data.Kind import Data.Proxy import Data.Type.Equality -import qualified Data.Vector.Unboxed as VU +import qualified Data.Vector.Storable as VS import qualified GHC.TypeLits as GHC import Unsafe.Coerce (unsafeCoerce) @@ -77,7 +77,7 @@ type family Rank sh where Rank (_ : sh) = S (Rank sh) type XArray :: [Maybe GHC.Nat] -> Type -> Type -data XArray sh a = XArray (U.Array (GNat (Rank sh)) a) +data XArray sh a = XArray (S.Array (GNat (Rank sh)) a) deriving (Show) zeroIdx :: StaticShapeX sh -> IxX sh @@ -149,7 +149,7 @@ enumShape = \sh -> go sh id [] go (n ::@ sh) f = foldr (.) id [go sh (f . (i ::@)) | i <- [0 .. n-1]] go (n ::? sh) f = foldr (.) id [go sh (f . (i ::?)) | i <- [0 .. n-1]] -shapeLshape :: IxX sh -> U.ShapeL +shapeLshape :: IxX sh -> S.ShapeL shapeLshape IZX = [] shapeLshape (n ::@ sh) = n : shapeLshape sh shapeLshape (n ::? sh) = n : shapeLshape sh @@ -197,7 +197,7 @@ lemAppKnownShapeX (() :$? ssh) ssh' = Dict shape :: forall sh a. KnownShapeX sh => XArray sh a -> IxX sh -shape (XArray arr) = go (knownShapeX @sh) (U.shapeL arr) +shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr) where go :: StaticShapeX sh' -> [Int] -> IxX sh' go SZX [] = IZX @@ -205,48 +205,48 @@ shape (XArray arr) = go (knownShapeX @sh) (U.shapeL arr) go (() :$? ssh) (n : l) = n ::? go ssh l go _ _ = error "Invalid shapeL" -fromVector :: forall sh a. U.Unbox a => IxX sh -> VU.Vector a -> XArray sh a +fromVector :: forall sh a. S.Unbox a => IxX sh -> VS.Vector a -> XArray sh a fromVector sh v | Dict <- lemKnownNatRank sh , Dict <- gknownNat (Proxy @(Rank sh)) - = XArray (U.fromVector (shapeLshape sh) v) + = XArray (S.fromVector (shapeLshape sh) v) -toVector :: U.Unbox a => XArray sh a -> VU.Vector a -toVector (XArray arr) = U.toVector arr +toVector :: S.Unbox a => XArray sh a -> VS.Vector a +toVector (XArray arr) = S.toVector arr -scalar :: U.Unbox a => a -> XArray '[] a -scalar = XArray . U.scalar +scalar :: S.Unbox a => a -> XArray '[] a +scalar = XArray . S.scalar -unScalar :: U.Unbox a => XArray '[] a -> a -unScalar (XArray a) = U.unScalar a +unScalar :: S.Unbox a => XArray '[] a -> a +unScalar (XArray a) = S.unScalar a -generate :: U.Unbox a => IxX sh -> (IxX sh -> a) -> XArray sh a -generate sh f = fromVector sh $ VU.generate (shapeSize sh) (f . fromLinearIdx sh) +generate :: S.Unbox a => IxX sh -> (IxX sh -> a) -> XArray sh a +generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh) --- generateM :: (Monad m, U.Unbox a) => IxX sh -> (IxX sh -> m a) -> m (XArray sh a) +-- generateM :: (Monad m, S.Unbox a) => IxX sh -> (IxX sh -> m a) -> m (XArray sh a) -- generateM sh f | Dict <- lemKnownNatRank sh = --- XArray . U.fromVector (shapeLshape sh) --- <$> VU.generateM (shapeSize sh) (f . fromLinearIdx sh) +-- XArray . S.fromVector (shapeLshape sh) +-- <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh) -indexPartial :: U.Unbox a => XArray (sh ++ sh') a -> IxX sh -> XArray sh' a +indexPartial :: S.Unbox a => XArray (sh ++ sh') a -> IxX sh -> XArray sh' a indexPartial (XArray arr) IZX = XArray arr -indexPartial (XArray arr) (i ::@ idx) = indexPartial (XArray (U.index arr i)) idx -indexPartial (XArray arr) (i ::? idx) = indexPartial (XArray (U.index arr i)) idx +indexPartial (XArray arr) (i ::@ idx) = indexPartial (XArray (S.index arr i)) idx +indexPartial (XArray arr) (i ::? idx) = indexPartial (XArray (S.index arr i)) idx -index :: forall sh a. U.Unbox a => XArray sh a -> IxX sh -> a +index :: forall sh a. S.Unbox a => XArray sh a -> IxX sh -> a index xarr i | Refl <- lemAppNil @sh = let XArray arr' = indexPartial xarr i :: XArray '[] a - in U.unScalar arr' + in S.unScalar arr' -append :: forall sh a. (KnownShapeX sh, U.Unbox a) => XArray sh a -> XArray sh a -> XArray sh a +append :: forall sh a. (KnownShapeX sh, S.Unbox a) => XArray sh a -> XArray sh a -> XArray sh a append (XArray a) (XArray b) | Dict <- lemKnownNatRankSSX (knownShapeX @sh) , Dict <- gknownNat (Proxy @(Rank sh)) - = XArray (U.append a b) + = XArray (S.append a b) rerank :: forall sh sh1 sh2 a b. - (U.Unbox a, U.Unbox b) + (S.Unbox a, S.Unbox b) => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2 -> (XArray sh1 a -> XArray sh2 b) -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b @@ -259,21 +259,21 @@ rerank ssh ssh1 ssh2 f (XArray arr) , Refl <- lemRankApp ssh ssh2 , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough - = XArray (U.rerank @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2)) + = XArray (S.rerank @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2)) (\a -> unXArray (f (XArray a))) arr) where unXArray (XArray a) = a rerankTop :: forall sh sh1 sh2 a b. - (U.Unbox a, U.Unbox b) + (S.Unbox a, S.Unbox b) => StaticShapeX sh1 -> StaticShapeX sh2 -> StaticShapeX sh -> (XArray sh1 a -> XArray sh2 b) -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh rerank2 :: forall sh sh1 sh2 a b c. - (U.Unbox a, U.Unbox b, U.Unbox c) + (S.Unbox a, S.Unbox b, S.Unbox c) => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2 -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c) -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c @@ -286,7 +286,7 @@ rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2) , Refl <- lemRankApp ssh ssh2 , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough - = XArray (U.rerank2 @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2)) + = XArray (S.rerank2 @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2)) (\a b -> unXArray (f (XArray a) (XArray b))) arr1 arr2) where @@ -297,7 +297,7 @@ transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a transpose perm (XArray arr) | Dict <- lemKnownNatRankSSX (knownShapeX @sh) , Dict <- gknownNat (Proxy @(Rank sh)) - = XArray (U.transpose perm arr) + = XArray (S.transpose perm arr) transpose2 :: forall sh1 sh2 a. StaticShapeX sh1 -> StaticShapeX sh2 @@ -311,18 +311,18 @@ transpose2 ssh1 ssh2 (XArray arr) , Dict <- gknownNat (Proxy @(Rank (sh2 ++ sh1))) , Refl <- lemRankAppComm ssh1 ssh2 , let n1 = ssxLength ssh1 - = XArray (U.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) + = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) -sumFull :: (U.Unbox a, Num a) => XArray sh a -> a -sumFull (XArray arr) = U.sumA arr +sumFull :: (S.Unbox a, Num a) => XArray sh a -> a +sumFull (XArray arr) = S.sumA arr -sumInner :: forall sh sh' a. (U.Unbox a, Num a) +sumInner :: forall sh sh' a. (S.Unbox a, Num a) => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh a sumInner ssh ssh' | Refl <- lemAppNil @sh = rerank ssh ssh' SZX (scalar . sumFull) -sumOuter :: forall sh sh' a. (U.Unbox a, Num a) +sumOuter :: forall sh sh' a. (S.Unbox a, Num a) => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh' a sumOuter ssh ssh' | Refl <- lemAppNil @sh |