From c2fb07100ef8954ef51a5fabfe1b77cd40dd9b61 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Tue, 14 May 2024 10:23:06 +0200 Subject: Introduce EltRepr (no more Primitive/Coercible in API) --- src/Data/Array/Nested.hs | 1 + src/Data/Array/Nested/Internal.hs | 83 ++++++++++++++++++++++++++++----------- 2 files changed, 60 insertions(+), 24 deletions(-) diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 9dc1113..758356c 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -34,6 +34,7 @@ module Data.Array.Nested ( -- * Array elements Elt(mshape, mindex, mindexPartial, mscalar, mfromList, mtoList, mlift, mlift2), + PrimElt, Primitive(..), -- * Inductive natural numbers diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 7a5add7..29592ac 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE DefaultSignatures #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DerivingVia #-} {-# LANGUAGE FlexibleContexts #-} @@ -132,6 +133,23 @@ shAppSplit p (_ :!$? ssh) (i :$? idx) = first (i :$?) (shAppSplit p ssh idx) -- @'Primitive' T@ instead. newtype Primitive a = Primitive a +-- | Element types that are primitive; arrays of these types are just a newtype +-- wrapper over an array. +class PrimElt a where + fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a + toPrimitive :: Mixed sh a -> Mixed sh (Primitive a) + + default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a + fromPrimitive = coerce + + default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a) + toPrimitive = coerce + +-- [PRIMITIVE ELEMENT TYPES LIST] +instance PrimElt Int +instance PrimElt Double +instance PrimElt () + -- | Mixed arrays: some dimensions are size-typed, some are not. Distributes -- over product-typed elements using a data family so that the full array is @@ -488,8 +506,11 @@ mappend = mlift2 go => Proxy sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b go Proxy | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') = X.append -mfromVector :: forall sh a. (KnownShapeX sh, Storable a) => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) -mfromVector sh v = M_Primitive (X.fromVector sh v) +mfromVectorP :: forall sh a. (KnownShapeX sh, Storable a) => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) +mfromVectorP sh v = M_Primitive (X.fromVector sh v) + +mfromVector :: forall sh a. (KnownShapeX sh, Storable a, PrimElt a) => IShX sh -> VS.Vector a -> Mixed sh a +mfromVector sh v = fromPrimitive (mfromVectorP sh v) mfromList1 :: (KnownShapeX '[n], Elt a) => NonEmpty a -> Mixed '[n] a mfromList1 = mfromList . fmap mscalar @@ -503,11 +524,9 @@ munScalar arr = mindex arr ZIX mconstantP :: forall sh a. (KnownShapeX sh, Storable a) => IShX sh -> a -> Mixed sh (Primitive a) mconstantP sh x = M_Primitive (X.constant sh x) --- | This 'Coercible' constraint holds for the scalar types for which 'Mixed' --- is defined. -mconstant :: forall sh a. (KnownShapeX sh, Storable a, Coercible (Mixed sh (Primitive a)) (Mixed sh a)) +mconstant :: forall sh a. (KnownShapeX sh, Storable a, PrimElt a) => IShX sh -> a -> Mixed sh a -mconstant sh x = coerce (mconstantP sh x) +mconstant sh x = fromPrimitive (mconstantP sh x) mslice :: (KnownShapeX sh, Elt a) => [(Int, Int)] -> Mixed sh a -> Mixed sh a mslice ivs = mlift $ \_ -> X.slice ivs @@ -981,10 +1000,10 @@ rlift f (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n2) = Ranked (mlift f arr) -rsumOuter1 :: forall n a. - (Storable a, Num a, KnownINat n) - => Ranked (S n) (Primitive a) -> Ranked n (Primitive a) -rsumOuter1 (Ranked arr) +rsumOuter1P :: forall n a. + (Storable a, Num a, KnownINat n) + => Ranked (S n) (Primitive a) -> Ranked n (Primitive a) +rsumOuter1P (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = Ranked . coerce @(XArray (Replicate n 'Nothing) a) @(Mixed (Replicate n 'Nothing) (Primitive a)) @@ -992,6 +1011,11 @@ rsumOuter1 (Ranked arr) . coerce @(Mixed (Replicate (S n) Nothing) (Primitive a)) @(XArray (Replicate (S n) Nothing) a) $ arr +rsumOuter1 :: forall n a. + (Storable a, Num a, PrimElt a, KnownINat n) + => Ranked (S n) a -> Ranked n a +rsumOuter1 = coerce fromPrimitive . rsumOuter1P @n @a . coerce toPrimitive + rtranspose :: forall n a. (KnownINat n, Elt a) => [Int] -> Ranked n a -> Ranked n a rtranspose perm (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) @@ -1004,10 +1028,13 @@ rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend rscalar :: Elt a => a -> Ranked I0 a rscalar x = Ranked (mscalar x) -rfromVector :: forall n a. (KnownINat n, Storable a) => IShR n -> VS.Vector a -> Ranked n (Primitive a) -rfromVector sh v +rfromVectorP :: forall n a. (KnownINat n, Storable a) => IShR n -> VS.Vector a -> Ranked n (Primitive a) +rfromVectorP sh v | Dict <- lemKnownReplicate (Proxy @n) - = Ranked (mfromVector (shCvtRX sh) v) + = Ranked (mfromVectorP (shCvtRX sh) v) + +rfromVector :: forall n a. (KnownINat n, Storable a, PrimElt a) => IShR n -> VS.Vector a -> Ranked n a +rfromVector sh v = coerce fromPrimitive (rfromVectorP sh v) rfromList :: forall n a. (KnownINat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (S n) a rfromList l @@ -1031,9 +1058,9 @@ rconstantP sh x | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mconstantP (shCvtRX sh) x) -rconstant :: forall n a. (KnownINat n, Storable a, Coercible (Mixed (Replicate n Nothing) (Primitive a)) (Mixed (Replicate n Nothing) a)) +rconstant :: forall n a. (KnownINat n, Storable a, PrimElt a) => IShR n -> a -> Ranked n a -rconstant sh x = coerce (rconstantP sh x) +rconstant sh x = coerce fromPrimitive (rconstantP sh x) rslice :: (KnownINat n, Elt a) => [(Int, Int)] -> Ranked n a -> Ranked n a rslice ivs = rlift $ \_ -> X.slice ivs @@ -1173,10 +1200,10 @@ slift f (Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh2) = Shaped (mlift f arr) -ssumOuter1 :: forall sh n a. - (Storable a, Num a, KnownNat n, KnownShape sh) - => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) -ssumOuter1 (Shaped arr) +ssumOuter1P :: forall sh n a. + (Storable a, Num a, KnownNat n, KnownShape sh) + => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) +ssumOuter1P (Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = Shaped . coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) (Primitive a)) @@ -1184,6 +1211,11 @@ ssumOuter1 (Shaped arr) . coerce @(Mixed (Just n : MapJust sh) (Primitive a)) @(XArray (Just n : MapJust sh) a) $ arr +ssumOuter1 :: forall sh n a. + (Storable a, Num a, PrimElt a, KnownNat n, KnownShape sh) + => Shaped (n : sh) a -> Shaped sh a +ssumOuter1 = coerce fromPrimitive . ssumOuter1P @sh @n @a . coerce toPrimitive + stranspose :: forall sh a. (KnownShape sh, Elt a) => [Int] -> Shaped sh a -> Shaped sh a stranspose perm (Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) @@ -1196,10 +1228,13 @@ sappend | Dict <- lemKnownMapJust (Proxy @sh) = coerce mappend sscalar :: Elt a => a -> Shaped '[] a sscalar x = Shaped (mscalar x) -sfromVector :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped sh (Primitive a) -sfromVector v +sfromVectorP :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped sh (Primitive a) +sfromVectorP v | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped (mfromVector (shCvtSX (knownShape @sh)) v) + = Shaped (mfromVectorP (shCvtSX (knownShape @sh)) v) + +sfromVector :: forall sh a. (KnownShape sh, Storable a, PrimElt a) => VS.Vector a -> Shaped sh a +sfromVector v = coerce fromPrimitive (sfromVectorP @sh @a v) sfromList :: forall n sh a. (KnownNat n, KnownShape sh, Elt a) => NonEmpty (Shaped sh a) -> Shaped (n : sh) a @@ -1224,9 +1259,9 @@ sconstantP x | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (mconstantP (shCvtSX (knownShape @sh)) x) -sconstant :: forall sh a. (KnownShape sh, Storable a, Coercible (Mixed (MapJust sh) (Primitive a)) (Mixed (MapJust sh) a)) +sconstant :: forall sh a. (KnownShape sh, Storable a, PrimElt a) => a -> Shaped sh a -sconstant x = coerce (sconstantP @sh x) +sconstant x = coerce fromPrimitive (sconstantP @sh x) sslice :: (KnownShape sh, Elt a) => [(Int, Int)] -> Shaped sh a -> Shaped sh a sslice ivs = slift $ \_ -> X.slice ivs -- cgit v1.2.3-70-g09d2