From c2fb07100ef8954ef51a5fabfe1b77cd40dd9b61 Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Tue, 14 May 2024 10:23:06 +0200
Subject: Introduce EltRepr (no more Primitive/Coercible in API)

---
 src/Data/Array/Nested.hs          |  1 +
 src/Data/Array/Nested/Internal.hs | 83 ++++++++++++++++++++++++++++-----------
 2 files changed, 60 insertions(+), 24 deletions(-)

(limited to 'src/Data/Array')

diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 9dc1113..758356c 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -34,6 +34,7 @@ module Data.Array.Nested (
 
   -- * Array elements
   Elt(mshape, mindex, mindexPartial, mscalar, mfromList, mtoList, mlift, mlift2),
+  PrimElt,
   Primitive(..),
 
   -- * Inductive natural numbers
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 7a5add7..29592ac 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -1,4 +1,5 @@
 {-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DefaultSignatures #-}
 {-# LANGUAGE DeriveFunctor #-}
 {-# LANGUAGE DerivingVia #-}
 {-# LANGUAGE FlexibleContexts #-}
@@ -132,6 +133,23 @@ shAppSplit p (_ :!$? ssh) (i :$? idx) = first (i :$?) (shAppSplit p ssh idx)
 -- @'Primitive' T@ instead.
 newtype Primitive a = Primitive a
 
+-- | Element types that are primitive; arrays of these types are just a newtype
+-- wrapper over an array.
+class PrimElt a where
+  fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a
+  toPrimitive :: Mixed sh a -> Mixed sh (Primitive a)
+
+  default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a
+  fromPrimitive = coerce
+
+  default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a)
+  toPrimitive = coerce
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+instance PrimElt Int
+instance PrimElt Double
+instance PrimElt ()
+
 
 -- | Mixed arrays: some dimensions are size-typed, some are not. Distributes
 -- over product-typed elements using a data family so that the full array is
@@ -488,8 +506,11 @@ mappend = mlift2 go
            => Proxy sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b
         go Proxy | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') = X.append
 
-mfromVector :: forall sh a. (KnownShapeX sh, Storable a) => IShX sh -> VS.Vector a -> Mixed sh (Primitive a)
-mfromVector sh v = M_Primitive (X.fromVector sh v)
+mfromVectorP :: forall sh a. (KnownShapeX sh, Storable a) => IShX sh -> VS.Vector a -> Mixed sh (Primitive a)
+mfromVectorP sh v = M_Primitive (X.fromVector sh v)
+
+mfromVector :: forall sh a. (KnownShapeX sh, Storable a, PrimElt a) => IShX sh -> VS.Vector a -> Mixed sh a
+mfromVector sh v = fromPrimitive (mfromVectorP sh v)
 
 mfromList1 :: (KnownShapeX '[n], Elt a) => NonEmpty a -> Mixed '[n] a
 mfromList1 = mfromList . fmap mscalar
@@ -503,11 +524,9 @@ munScalar arr = mindex arr ZIX
 mconstantP :: forall sh a. (KnownShapeX sh, Storable a) => IShX sh -> a -> Mixed sh (Primitive a)
 mconstantP sh x = M_Primitive (X.constant sh x)
 
--- | This 'Coercible' constraint holds for the scalar types for which 'Mixed'
--- is defined.
-mconstant :: forall sh a. (KnownShapeX sh, Storable a, Coercible (Mixed sh (Primitive a)) (Mixed sh a))
+mconstant :: forall sh a. (KnownShapeX sh, Storable a, PrimElt a)
           => IShX sh -> a -> Mixed sh a
-mconstant sh x = coerce (mconstantP sh x)
+mconstant sh x = fromPrimitive (mconstantP sh x)
 
 mslice :: (KnownShapeX sh, Elt a) => [(Int, Int)] -> Mixed sh a -> Mixed sh a
 mslice ivs = mlift $ \_ -> X.slice ivs
@@ -981,10 +1000,10 @@ rlift f (Ranked arr)
   | Dict <- lemKnownReplicate (Proxy @n2)
   = Ranked (mlift f arr)
 
-rsumOuter1 :: forall n a.
-              (Storable a, Num a, KnownINat n)
-           => Ranked (S n) (Primitive a) -> Ranked n (Primitive a)
-rsumOuter1 (Ranked arr)
+rsumOuter1P :: forall n a.
+               (Storable a, Num a, KnownINat n)
+            => Ranked (S n) (Primitive a) -> Ranked n (Primitive a)
+rsumOuter1P (Ranked arr)
   | Dict <- lemKnownReplicate (Proxy @n)
   = Ranked
     . coerce @(XArray (Replicate n 'Nothing) a) @(Mixed (Replicate n 'Nothing) (Primitive a))
@@ -992,6 +1011,11 @@ rsumOuter1 (Ranked arr)
     . coerce @(Mixed (Replicate (S n) Nothing) (Primitive a)) @(XArray (Replicate (S n) Nothing) a)
     $ arr
 
+rsumOuter1 :: forall n a.
+              (Storable a, Num a, PrimElt a, KnownINat n)
+           => Ranked (S n) a -> Ranked n a
+rsumOuter1 = coerce fromPrimitive . rsumOuter1P @n @a . coerce toPrimitive
+
 rtranspose :: forall n a. (KnownINat n, Elt a) => [Int] -> Ranked n a -> Ranked n a
 rtranspose perm (Ranked arr)
   | Dict <- lemKnownReplicate (Proxy @n)
@@ -1004,10 +1028,13 @@ rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend
 rscalar :: Elt a => a -> Ranked I0 a
 rscalar x = Ranked (mscalar x)
 
-rfromVector :: forall n a. (KnownINat n, Storable a) => IShR n -> VS.Vector a -> Ranked n (Primitive a)
-rfromVector sh v
+rfromVectorP :: forall n a. (KnownINat n, Storable a) => IShR n -> VS.Vector a -> Ranked n (Primitive a)
+rfromVectorP sh v
   | Dict <- lemKnownReplicate (Proxy @n)
-  = Ranked (mfromVector (shCvtRX sh) v)
+  = Ranked (mfromVectorP (shCvtRX sh) v)
+
+rfromVector :: forall n a. (KnownINat n, Storable a, PrimElt a) => IShR n -> VS.Vector a -> Ranked n a
+rfromVector sh v = coerce fromPrimitive (rfromVectorP sh v)
 
 rfromList :: forall n a. (KnownINat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (S n) a
 rfromList l
@@ -1031,9 +1058,9 @@ rconstantP sh x
   | Dict <- lemKnownReplicate (Proxy @n)
   = Ranked (mconstantP (shCvtRX sh) x)
 
-rconstant :: forall n a. (KnownINat n, Storable a, Coercible (Mixed (Replicate n Nothing) (Primitive a)) (Mixed (Replicate n Nothing) a))
+rconstant :: forall n a. (KnownINat n, Storable a, PrimElt a)
           => IShR n -> a -> Ranked n a
-rconstant sh x = coerce (rconstantP sh x)
+rconstant sh x = coerce fromPrimitive (rconstantP sh x)
 
 rslice :: (KnownINat n, Elt a) => [(Int, Int)] -> Ranked n a -> Ranked n a
 rslice ivs = rlift $ \_ -> X.slice ivs
@@ -1173,10 +1200,10 @@ slift f (Shaped arr)
   | Dict <- lemKnownMapJust (Proxy @sh2)
   = Shaped (mlift f arr)
 
-ssumOuter1 :: forall sh n a.
-              (Storable a, Num a, KnownNat n, KnownShape sh)
-           => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
-ssumOuter1 (Shaped arr)
+ssumOuter1P :: forall sh n a.
+               (Storable a, Num a, KnownNat n, KnownShape sh)
+            => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
+ssumOuter1P (Shaped arr)
   | Dict <- lemKnownMapJust (Proxy @sh)
   = Shaped
     . coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) (Primitive a))
@@ -1184,6 +1211,11 @@ ssumOuter1 (Shaped arr)
     . coerce @(Mixed (Just n : MapJust sh) (Primitive a)) @(XArray (Just n : MapJust sh) a)
     $ arr
 
+ssumOuter1 :: forall sh n a.
+              (Storable a, Num a, PrimElt a, KnownNat n, KnownShape sh)
+           => Shaped (n : sh) a -> Shaped sh a
+ssumOuter1 = coerce fromPrimitive . ssumOuter1P @sh @n @a . coerce toPrimitive
+
 stranspose :: forall sh a. (KnownShape sh, Elt a) => [Int] -> Shaped sh a -> Shaped sh a
 stranspose perm (Shaped arr)
   | Dict <- lemKnownMapJust (Proxy @sh)
@@ -1196,10 +1228,13 @@ sappend | Dict <- lemKnownMapJust (Proxy @sh) = coerce mappend
 sscalar :: Elt a => a -> Shaped '[] a
 sscalar x = Shaped (mscalar x)
 
-sfromVector :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped sh (Primitive a)
-sfromVector v
+sfromVectorP :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped sh (Primitive a)
+sfromVectorP v
   | Dict <- lemKnownMapJust (Proxy @sh)
-  = Shaped (mfromVector (shCvtSX (knownShape @sh)) v)
+  = Shaped (mfromVectorP (shCvtSX (knownShape @sh)) v)
+
+sfromVector :: forall sh a. (KnownShape sh, Storable a, PrimElt a) => VS.Vector a -> Shaped sh a
+sfromVector v = coerce fromPrimitive (sfromVectorP @sh @a v)
 
 sfromList :: forall n sh a. (KnownNat n, KnownShape sh, Elt a)
           => NonEmpty (Shaped sh a) -> Shaped (n : sh) a
@@ -1224,9 +1259,9 @@ sconstantP x
   | Dict <- lemKnownMapJust (Proxy @sh)
   = Shaped (mconstantP (shCvtSX (knownShape @sh)) x)
 
-sconstant :: forall sh a. (KnownShape sh, Storable a, Coercible (Mixed (MapJust sh) (Primitive a)) (Mixed (MapJust sh) a))
+sconstant :: forall sh a. (KnownShape sh, Storable a, PrimElt a)
           => a -> Shaped sh a
-sconstant x = coerce (sconstantP @sh x)
+sconstant x = coerce fromPrimitive (sconstantP @sh x)
 
 sslice :: (KnownShape sh, Elt a) => [(Int, Int)] -> Shaped sh a -> Shaped sh a
 sslice ivs = slift $ \_ -> X.slice ivs
-- 
cgit v1.2.3-70-g09d2