summaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-14 10:23:06 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-14 10:23:06 +0200
commitc2fb07100ef8954ef51a5fabfe1b77cd40dd9b61 (patch)
treef58ebb1fb5e0128d7676d8067c8d86c967e83339 /src/Data/Array/Nested/Internal.hs
parent5cd4ed02db25a64ef879e1fa18431360a40de73b (diff)
Introduce EltRepr (no more Primitive/Coercible in API)
Diffstat (limited to 'src/Data/Array/Nested/Internal.hs')
-rw-r--r--src/Data/Array/Nested/Internal.hs83
1 files changed, 59 insertions, 24 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 7a5add7..29592ac 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -1,4 +1,5 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
@@ -132,6 +133,23 @@ shAppSplit p (_ :!$? ssh) (i :$? idx) = first (i :$?) (shAppSplit p ssh idx)
-- @'Primitive' T@ instead.
newtype Primitive a = Primitive a
+-- | Element types that are primitive; arrays of these types are just a newtype
+-- wrapper over an array.
+class PrimElt a where
+ fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a
+ toPrimitive :: Mixed sh a -> Mixed sh (Primitive a)
+
+ default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a
+ fromPrimitive = coerce
+
+ default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a)
+ toPrimitive = coerce
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+instance PrimElt Int
+instance PrimElt Double
+instance PrimElt ()
+
-- | Mixed arrays: some dimensions are size-typed, some are not. Distributes
-- over product-typed elements using a data family so that the full array is
@@ -488,8 +506,11 @@ mappend = mlift2 go
=> Proxy sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b
go Proxy | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') = X.append
-mfromVector :: forall sh a. (KnownShapeX sh, Storable a) => IShX sh -> VS.Vector a -> Mixed sh (Primitive a)
-mfromVector sh v = M_Primitive (X.fromVector sh v)
+mfromVectorP :: forall sh a. (KnownShapeX sh, Storable a) => IShX sh -> VS.Vector a -> Mixed sh (Primitive a)
+mfromVectorP sh v = M_Primitive (X.fromVector sh v)
+
+mfromVector :: forall sh a. (KnownShapeX sh, Storable a, PrimElt a) => IShX sh -> VS.Vector a -> Mixed sh a
+mfromVector sh v = fromPrimitive (mfromVectorP sh v)
mfromList1 :: (KnownShapeX '[n], Elt a) => NonEmpty a -> Mixed '[n] a
mfromList1 = mfromList . fmap mscalar
@@ -503,11 +524,9 @@ munScalar arr = mindex arr ZIX
mconstantP :: forall sh a. (KnownShapeX sh, Storable a) => IShX sh -> a -> Mixed sh (Primitive a)
mconstantP sh x = M_Primitive (X.constant sh x)
--- | This 'Coercible' constraint holds for the scalar types for which 'Mixed'
--- is defined.
-mconstant :: forall sh a. (KnownShapeX sh, Storable a, Coercible (Mixed sh (Primitive a)) (Mixed sh a))
+mconstant :: forall sh a. (KnownShapeX sh, Storable a, PrimElt a)
=> IShX sh -> a -> Mixed sh a
-mconstant sh x = coerce (mconstantP sh x)
+mconstant sh x = fromPrimitive (mconstantP sh x)
mslice :: (KnownShapeX sh, Elt a) => [(Int, Int)] -> Mixed sh a -> Mixed sh a
mslice ivs = mlift $ \_ -> X.slice ivs
@@ -981,10 +1000,10 @@ rlift f (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n2)
= Ranked (mlift f arr)
-rsumOuter1 :: forall n a.
- (Storable a, Num a, KnownINat n)
- => Ranked (S n) (Primitive a) -> Ranked n (Primitive a)
-rsumOuter1 (Ranked arr)
+rsumOuter1P :: forall n a.
+ (Storable a, Num a, KnownINat n)
+ => Ranked (S n) (Primitive a) -> Ranked n (Primitive a)
+rsumOuter1P (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
= Ranked
. coerce @(XArray (Replicate n 'Nothing) a) @(Mixed (Replicate n 'Nothing) (Primitive a))
@@ -992,6 +1011,11 @@ rsumOuter1 (Ranked arr)
. coerce @(Mixed (Replicate (S n) Nothing) (Primitive a)) @(XArray (Replicate (S n) Nothing) a)
$ arr
+rsumOuter1 :: forall n a.
+ (Storable a, Num a, PrimElt a, KnownINat n)
+ => Ranked (S n) a -> Ranked n a
+rsumOuter1 = coerce fromPrimitive . rsumOuter1P @n @a . coerce toPrimitive
+
rtranspose :: forall n a. (KnownINat n, Elt a) => [Int] -> Ranked n a -> Ranked n a
rtranspose perm (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
@@ -1004,10 +1028,13 @@ rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend
rscalar :: Elt a => a -> Ranked I0 a
rscalar x = Ranked (mscalar x)
-rfromVector :: forall n a. (KnownINat n, Storable a) => IShR n -> VS.Vector a -> Ranked n (Primitive a)
-rfromVector sh v
+rfromVectorP :: forall n a. (KnownINat n, Storable a) => IShR n -> VS.Vector a -> Ranked n (Primitive a)
+rfromVectorP sh v
| Dict <- lemKnownReplicate (Proxy @n)
- = Ranked (mfromVector (shCvtRX sh) v)
+ = Ranked (mfromVectorP (shCvtRX sh) v)
+
+rfromVector :: forall n a. (KnownINat n, Storable a, PrimElt a) => IShR n -> VS.Vector a -> Ranked n a
+rfromVector sh v = coerce fromPrimitive (rfromVectorP sh v)
rfromList :: forall n a. (KnownINat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (S n) a
rfromList l
@@ -1031,9 +1058,9 @@ rconstantP sh x
| Dict <- lemKnownReplicate (Proxy @n)
= Ranked (mconstantP (shCvtRX sh) x)
-rconstant :: forall n a. (KnownINat n, Storable a, Coercible (Mixed (Replicate n Nothing) (Primitive a)) (Mixed (Replicate n Nothing) a))
+rconstant :: forall n a. (KnownINat n, Storable a, PrimElt a)
=> IShR n -> a -> Ranked n a
-rconstant sh x = coerce (rconstantP sh x)
+rconstant sh x = coerce fromPrimitive (rconstantP sh x)
rslice :: (KnownINat n, Elt a) => [(Int, Int)] -> Ranked n a -> Ranked n a
rslice ivs = rlift $ \_ -> X.slice ivs
@@ -1173,10 +1200,10 @@ slift f (Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh2)
= Shaped (mlift f arr)
-ssumOuter1 :: forall sh n a.
- (Storable a, Num a, KnownNat n, KnownShape sh)
- => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
-ssumOuter1 (Shaped arr)
+ssumOuter1P :: forall sh n a.
+ (Storable a, Num a, KnownNat n, KnownShape sh)
+ => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
+ssumOuter1P (Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh)
= Shaped
. coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) (Primitive a))
@@ -1184,6 +1211,11 @@ ssumOuter1 (Shaped arr)
. coerce @(Mixed (Just n : MapJust sh) (Primitive a)) @(XArray (Just n : MapJust sh) a)
$ arr
+ssumOuter1 :: forall sh n a.
+ (Storable a, Num a, PrimElt a, KnownNat n, KnownShape sh)
+ => Shaped (n : sh) a -> Shaped sh a
+ssumOuter1 = coerce fromPrimitive . ssumOuter1P @sh @n @a . coerce toPrimitive
+
stranspose :: forall sh a. (KnownShape sh, Elt a) => [Int] -> Shaped sh a -> Shaped sh a
stranspose perm (Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh)
@@ -1196,10 +1228,13 @@ sappend | Dict <- lemKnownMapJust (Proxy @sh) = coerce mappend
sscalar :: Elt a => a -> Shaped '[] a
sscalar x = Shaped (mscalar x)
-sfromVector :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped sh (Primitive a)
-sfromVector v
+sfromVectorP :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped sh (Primitive a)
+sfromVectorP v
| Dict <- lemKnownMapJust (Proxy @sh)
- = Shaped (mfromVector (shCvtSX (knownShape @sh)) v)
+ = Shaped (mfromVectorP (shCvtSX (knownShape @sh)) v)
+
+sfromVector :: forall sh a. (KnownShape sh, Storable a, PrimElt a) => VS.Vector a -> Shaped sh a
+sfromVector v = coerce fromPrimitive (sfromVectorP @sh @a v)
sfromList :: forall n sh a. (KnownNat n, KnownShape sh, Elt a)
=> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
@@ -1224,9 +1259,9 @@ sconstantP x
| Dict <- lemKnownMapJust (Proxy @sh)
= Shaped (mconstantP (shCvtSX (knownShape @sh)) x)
-sconstant :: forall sh a. (KnownShape sh, Storable a, Coercible (Mixed (MapJust sh) (Primitive a)) (Mixed (MapJust sh) a))
+sconstant :: forall sh a. (KnownShape sh, Storable a, PrimElt a)
=> a -> Shaped sh a
-sconstant x = coerce (sconstantP @sh x)
+sconstant x = coerce fromPrimitive (sconstantP @sh x)
sslice :: (KnownShape sh, Elt a) => [(Int, Int)] -> Shaped sh a -> Shaped sh a
sslice ivs = slift $ \_ -> X.slice ivs