diff options
| -rw-r--r-- | src/Data/Array/Mixed.hs | 72 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 43 | 
2 files changed, 58 insertions, 57 deletions
| diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 2875203..2bbf81d 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -12,11 +12,11 @@  {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}  module Data.Array.Mixed where -import qualified Data.Array.RankedU as U +import qualified Data.Array.RankedS as S  import Data.Kind  import Data.Proxy  import Data.Type.Equality -import qualified Data.Vector.Unboxed as VU +import qualified Data.Vector.Storable as VS  import qualified GHC.TypeLits as GHC  import Unsafe.Coerce (unsafeCoerce) @@ -77,7 +77,7 @@ type family Rank sh where    Rank (_ : sh) = S (Rank sh)  type XArray :: [Maybe GHC.Nat] -> Type -> Type -data XArray sh a = XArray (U.Array (GNat (Rank sh)) a) +data XArray sh a = XArray (S.Array (GNat (Rank sh)) a)    deriving (Show)  zeroIdx :: StaticShapeX sh -> IxX sh @@ -149,7 +149,7 @@ enumShape = \sh -> go sh id []      go (n ::@ sh) f = foldr (.) id [go sh (f . (i ::@)) | i <- [0 .. n-1]]      go (n ::? sh) f = foldr (.) id [go sh (f . (i ::?)) | i <- [0 .. n-1]] -shapeLshape :: IxX sh -> U.ShapeL +shapeLshape :: IxX sh -> S.ShapeL  shapeLshape IZX = []  shapeLshape (n ::@ sh) = n : shapeLshape sh  shapeLshape (n ::? sh) = n : shapeLshape sh @@ -197,7 +197,7 @@ lemAppKnownShapeX (() :$? ssh) ssh'    = Dict  shape :: forall sh a. KnownShapeX sh => XArray sh a -> IxX sh -shape (XArray arr) = go (knownShapeX @sh) (U.shapeL arr) +shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr)    where      go :: StaticShapeX sh' -> [Int] -> IxX sh'      go SZX [] = IZX @@ -205,48 +205,48 @@ shape (XArray arr) = go (knownShapeX @sh) (U.shapeL arr)      go (() :$? ssh) (n : l) = n ::? go ssh l      go _ _ = error "Invalid shapeL" -fromVector :: forall sh a. U.Unbox a => IxX sh -> VU.Vector a -> XArray sh a +fromVector :: forall sh a. S.Unbox a => IxX sh -> VS.Vector a -> XArray sh a  fromVector sh v    | Dict <- lemKnownNatRank sh    , Dict <- gknownNat (Proxy @(Rank sh)) -  = XArray (U.fromVector (shapeLshape sh) v) +  = XArray (S.fromVector (shapeLshape sh) v) -toVector :: U.Unbox a => XArray sh a -> VU.Vector a -toVector (XArray arr) = U.toVector arr +toVector :: S.Unbox a => XArray sh a -> VS.Vector a +toVector (XArray arr) = S.toVector arr -scalar :: U.Unbox a => a -> XArray '[] a -scalar = XArray . U.scalar +scalar :: S.Unbox a => a -> XArray '[] a +scalar = XArray . S.scalar -unScalar :: U.Unbox a => XArray '[] a -> a -unScalar (XArray a) = U.unScalar a +unScalar :: S.Unbox a => XArray '[] a -> a +unScalar (XArray a) = S.unScalar a -generate :: U.Unbox a => IxX sh -> (IxX sh -> a) -> XArray sh a -generate sh f = fromVector sh $ VU.generate (shapeSize sh) (f . fromLinearIdx sh) +generate :: S.Unbox a => IxX sh -> (IxX sh -> a) -> XArray sh a +generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh) --- generateM :: (Monad m, U.Unbox a) => IxX sh -> (IxX sh -> m a) -> m (XArray sh a) +-- generateM :: (Monad m, S.Unbox a) => IxX sh -> (IxX sh -> m a) -> m (XArray sh a)  -- generateM sh f | Dict <- lemKnownNatRank sh = ---   XArray . U.fromVector (shapeLshape sh) ---     <$> VU.generateM (shapeSize sh) (f . fromLinearIdx sh) +--   XArray . S.fromVector (shapeLshape sh) +--     <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh) -indexPartial :: U.Unbox a => XArray (sh ++ sh') a -> IxX sh -> XArray sh' a +indexPartial :: S.Unbox a => XArray (sh ++ sh') a -> IxX sh -> XArray sh' a  indexPartial (XArray arr) IZX = XArray arr -indexPartial (XArray arr) (i ::@ idx) = indexPartial (XArray (U.index arr i)) idx -indexPartial (XArray arr) (i ::? idx) = indexPartial (XArray (U.index arr i)) idx +indexPartial (XArray arr) (i ::@ idx) = indexPartial (XArray (S.index arr i)) idx +indexPartial (XArray arr) (i ::? idx) = indexPartial (XArray (S.index arr i)) idx -index :: forall sh a. U.Unbox a => XArray sh a -> IxX sh -> a +index :: forall sh a. S.Unbox a => XArray sh a -> IxX sh -> a  index xarr i    | Refl <- lemAppNil @sh    = let XArray arr' = indexPartial xarr i :: XArray '[] a -    in U.unScalar arr' +    in S.unScalar arr' -append :: forall sh a. (KnownShapeX sh, U.Unbox a) => XArray sh a -> XArray sh a -> XArray sh a +append :: forall sh a. (KnownShapeX sh, S.Unbox a) => XArray sh a -> XArray sh a -> XArray sh a  append (XArray a) (XArray b)    | Dict <- lemKnownNatRankSSX (knownShapeX @sh)    , Dict <- gknownNat (Proxy @(Rank sh)) -  = XArray (U.append a b) +  = XArray (S.append a b)  rerank :: forall sh sh1 sh2 a b. -          (U.Unbox a, U.Unbox b) +          (S.Unbox a, S.Unbox b)         => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2         -> (XArray sh1 a -> XArray sh2 b)         -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b @@ -259,21 +259,21 @@ rerank ssh ssh1 ssh2 f (XArray arr)    , Refl <- lemRankApp ssh ssh2    , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)  -- these two should be redundant but the    , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2)))    -- solver is not clever enough -  = XArray (U.rerank @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2)) +  = XArray (S.rerank @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2))                  (\a -> unXArray (f (XArray a)))                  arr)    where      unXArray (XArray a) = a  rerankTop :: forall sh sh1 sh2 a b. -             (U.Unbox a, U.Unbox b) +             (S.Unbox a, S.Unbox b)            => StaticShapeX sh1 -> StaticShapeX sh2 -> StaticShapeX sh            -> (XArray sh1 a -> XArray sh2 b)            -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b  rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh  rerank2 :: forall sh sh1 sh2 a b c. -           (U.Unbox a, U.Unbox b, U.Unbox c) +           (S.Unbox a, S.Unbox b, S.Unbox c)          => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2          -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c)          -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c @@ -286,7 +286,7 @@ rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2)    , Refl <- lemRankApp ssh ssh2    , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)  -- these two should be redundant but the    , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2)))  -- solver is not clever enough -  = XArray (U.rerank2 @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2)) +  = XArray (S.rerank2 @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2))                  (\a b -> unXArray (f (XArray a) (XArray b)))                  arr1 arr2)    where @@ -297,7 +297,7 @@ transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a  transpose perm (XArray arr)    | Dict <- lemKnownNatRankSSX (knownShapeX @sh)    , Dict <- gknownNat (Proxy @(Rank sh)) -  = XArray (U.transpose perm arr) +  = XArray (S.transpose perm arr)  transpose2 :: forall sh1 sh2 a.                StaticShapeX sh1 -> StaticShapeX sh2 @@ -311,18 +311,18 @@ transpose2 ssh1 ssh2 (XArray arr)    , Dict <- gknownNat (Proxy @(Rank (sh2 ++ sh1)))    , Refl <- lemRankAppComm ssh1 ssh2    , let n1 = ssxLength ssh1 -  = XArray (U.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) +  = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) -sumFull :: (U.Unbox a, Num a) => XArray sh a -> a -sumFull (XArray arr) = U.sumA arr +sumFull :: (S.Unbox a, Num a) => XArray sh a -> a +sumFull (XArray arr) = S.sumA arr -sumInner :: forall sh sh' a. (U.Unbox a, Num a) +sumInner :: forall sh sh' a. (S.Unbox a, Num a)           => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh a  sumInner ssh ssh'    | Refl <- lemAppNil @sh    = rerank ssh ssh' SZX (scalar . sumFull) -sumOuter :: forall sh sh' a. (U.Unbox a, Num a) +sumOuter :: forall sh sh' a. (S.Unbox a, Num a)           => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh' a  sumOuter ssh ssh'    | Refl <- lemAppNil @sh diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index bdded69..41fb1fd 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -35,8 +35,9 @@ import Data.Coerce (coerce, Coercible)  import Data.Kind  import Data.Proxy  import Data.Type.Equality -import qualified Data.Vector.Unboxed as VU -import qualified Data.Vector.Unboxed.Mutable as VUM +import qualified Data.Vector.Storable as VS +import qualified Data.Vector.Storable.Mutable as VSM +import Foreign.Storable (Storable)  import qualified GHC.TypeLits as GHC  import Data.Array.Mixed (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++), pattern GHC_SNat) @@ -119,11 +120,11 @@ deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a))  type MixedVecs :: Type -> [Maybe GHC.Nat] -> Type -> Type  data family MixedVecs s sh a -newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VU.MVector s a) +newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a) -newtype instance MixedVecs s sh Int = MV_Int (VU.MVector s Int) -newtype instance MixedVecs s sh Double = MV_Double (VU.MVector s Double) -newtype instance MixedVecs s sh () = MV_Nil (VU.MVector s ())  -- no content, MVector optimises this +newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int) +newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double) +newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ())  -- no content, MVector optimises this  -- etc.  data instance MixedVecs s sh (a, b) = MV_Tup2 (MixedVecs s sh a) (MixedVecs s sh b) @@ -143,7 +144,7 @@ class Elt a where    mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a    mlift :: forall sh1 sh2. KnownShapeX sh2 -        => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) +        => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)          -> Mixed sh1 a -> Mixed sh2 a    -- ====== PRIVATE METHODS ====== -- @@ -176,7 +177,7 @@ class Elt a where  -- Arrays of scalars are basically just arrays of scalars. -instance VU.Unbox a => Elt (Primitive a) where +instance Storable a => Elt (Primitive a) where    mshape (M_Primitive a) = X.shape a    mindex (M_Primitive a) i = Primitive (X.index a i)    mindexPartial (M_Primitive a) i = M_Primitive (X.indexPartial a i) @@ -191,18 +192,18 @@ instance VU.Unbox a => Elt (Primitive a) where    memptyArray sh = M_Primitive (X.generate sh (error "memptyArray Int: shape was not empty"))    mvecsNumElts _ = 1 -  mvecsUnsafeNew sh _ = MV_Primitive <$> VUM.unsafeNew (X.shapeSize sh) -  mvecsWrite sh i (Primitive x) (MV_Primitive v) = VUM.write v (X.toLinearIdx sh i) x +  mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh) +  mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (X.toLinearIdx sh i) x    -- TODO: this use of toVector is suboptimal    mvecsWritePartial -    :: forall sh' sh s. (KnownShapeX sh', VU.Unbox a) +    :: forall sh' sh s. KnownShapeX sh'      => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()    mvecsWritePartial sh i (M_Primitive arr) (MV_Primitive v) = do      let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIdx' (X.shape arr))) -    VU.copy (VUM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr) +    VS.copy (VSM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr) -  mvecsFreeze sh (MV_Primitive v) = M_Primitive . X.fromVector sh <$> VU.freeze v +  mvecsFreeze sh (MV_Primitive v) = M_Primitive . X.fromVector sh <$> VS.freeze v  deriving via Primitive Int instance Elt Int  deriving via Primitive Double instance Elt Double @@ -247,13 +248,13 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where      = M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)    mlift :: forall sh1 sh2. KnownShapeX sh2 -        => (forall sh3 b. (KnownShapeX sh3, VU.Unbox b) => Proxy sh3 -> XArray (sh1 ++ sh3) b -> XArray (sh2 ++ sh3) b) +        => (forall sh3 b. (KnownShapeX sh3, Storable b) => Proxy sh3 -> XArray (sh1 ++ sh3) b -> XArray (sh2 ++ sh3) b)          -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)    mlift f (M_Nest arr)      | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))      = M_Nest (mlift f' arr)      where -      f' :: forall sh3 b. (KnownShapeX sh3, VU.Unbox b) => Proxy sh3 -> XArray ((sh1 ++ sh') ++ sh3) b -> XArray ((sh2 ++ sh') ++ sh3) b +      f' :: forall sh3 b. (KnownShapeX sh3, Storable b) => Proxy sh3 -> XArray ((sh1 ++ sh') ++ sh3) b -> XArray ((sh2 ++ sh') ++ sh3) b        f' _          | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @sh3)          , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @sh3) @@ -374,7 +375,7 @@ instance (KnownNat n, Elt a) => Elt (Ranked n a) where          mindexPartial arr i    mlift :: forall sh1 sh2. KnownShapeX sh2 -        => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) +        => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)          -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a)    mlift f (M_Ranked arr)      | Dict <- lemKnownReplicate (Proxy @n) @@ -469,7 +470,7 @@ instance (KnownShape sh, Elt a) => Elt (Shaped sh a) where          mindexPartial arr i    mlift :: forall sh1 sh2. KnownShapeX sh2 -        => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) +        => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)          -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a)    mlift f (M_Shaped arr)      | Dict <- lemKnownMapJust (Proxy @sh) @@ -568,14 +569,14 @@ rgenerate sh f    = Ranked (mgenerate (ixCvtRX sh) (f . ixCvtXR))  rlift :: forall n1 n2 a. (KnownNat n2, Elt a) -      => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) +      => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)        -> Ranked n1 a -> Ranked n2 a  rlift f (Ranked arr)    | Dict <- lemKnownReplicate (Proxy @n2)    = Ranked (mlift f arr)  rsumOuter1 :: forall n a. -              (VU.Unbox a, Num a, KnownNat n, forall sh. Coercible (Mixed sh a) (XArray sh a)) +              (Storable a, Num a, KnownNat n, forall sh. Coercible (Mixed sh a) (XArray sh a))             => Ranked (S n) a -> Ranked n a  rsumOuter1 (Ranked arr)    | Dict <- lemKnownReplicate (Proxy @n) @@ -636,14 +637,14 @@ sgenerate sh f    = Shaped (mgenerate (ixCvtSX sh) (f . ixCvtXS (knownShape @sh)))  slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a) -      => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) +      => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)        -> Shaped sh1 a -> Shaped sh2 a  slift f (Shaped arr)    | Dict <- lemKnownMapJust (Proxy @sh2)    = Shaped (mlift f arr)  ssumOuter1 :: forall sh n a. -              (VU.Unbox a, Num a, GHC.KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a)) +              (Storable a, Num a, GHC.KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a))             => Shaped (n : sh) a -> Shaped sh a  ssumOuter1 (Shaped arr)    | Dict <- lemKnownMapJust (Proxy @sh) | 
