From 70d2edeb338c6acbe9943c4f8b24225bcb912211 Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Sun, 14 Apr 2024 13:03:53 +0200
Subject: Num instances for Mixed, Ranked, Shaped

---
 src/Data/Array/Mixed.hs           | 12 +++++++
 src/Data/Array/Nested/Internal.hs | 72 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 84 insertions(+)

(limited to 'src/Data/Array')

diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 7f25d84..d9eb5f0 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -111,6 +111,12 @@ shapeSize IZX = 1
 shapeSize (n ::@ sh) = n * shapeSize sh
 shapeSize (n ::? sh) = n * shapeSize sh
 
+-- | This may fail if @sh@ has @Nothing@s in it.
+ssxToShape' :: StaticShapeX sh -> Maybe (IxX sh)
+ssxToShape' SZX = Just IZX
+ssxToShape' (n :$@ sh) = (fromIntegral (fromSNat n) ::@) <$> ssxToShape' sh
+ssxToShape' (_ :$? _) = Nothing
+
 fromLinearIdx :: IxX sh -> Int -> IxX sh
 fromLinearIdx = \sh i -> case go sh i of
   (idx, 0) -> idx
@@ -221,6 +227,12 @@ scalar = XArray . S.scalar
 unScalar :: Storable a => XArray '[] a -> a
 unScalar (XArray a) = S.unScalar a
 
+constant :: forall sh a. Storable a => IxX sh -> a -> XArray sh a
+constant sh x
+  | Dict <- lemKnownINatRank sh
+  , Dict <- knownNatFromINat (Proxy @(Rank sh))
+  = XArray (S.constant (shapeLshape sh) x)
+
 generate :: Storable a => IxX sh -> (IxX sh -> a) -> XArray sh a
 generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh)
 
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 86b0fce..f7c383a 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -31,6 +31,7 @@ module Data.Array.Nested.Internal where
 
 import Control.Monad (forM_)
 import Control.Monad.ST
+import qualified Data.Array.RankedS as S
 import Data.Coerce (coerce, Coercible)
 import Data.Kind
 import Data.Proxy
@@ -353,6 +354,33 @@ mtranspose perm =
   mlift (\(Proxy @sh') -> X.rerankTop (knownShapeX @sh) (knownShapeX @sh) (knownShapeX @sh')
                             (X.transpose perm))
 
+mliftPrim :: (KnownShapeX sh, Storable a)
+          => (a -> a)
+          -> Mixed sh (Primitive a) -> Mixed sh (Primitive a)
+mliftPrim f (M_Primitive (X.XArray arr)) = M_Primitive (X.XArray (S.mapA f arr))
+
+mliftPrim2 :: (KnownShapeX sh, Storable a)
+           => (a -> a -> a)
+           -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) -> Mixed sh (Primitive a)
+mliftPrim2 f (M_Primitive (X.XArray arr1)) (M_Primitive (X.XArray arr2)) =
+  M_Primitive (X.XArray (S.zipWithA f arr1 arr2))
+
+instance (KnownShapeX sh, Storable a, Num a) => Num (Mixed sh (Primitive a)) where
+  (+) = mliftPrim2 (+)
+  (-) = mliftPrim2 (-)
+  (*) = mliftPrim2 (*)
+  negate = mliftPrim negate
+  abs = mliftPrim abs
+  signum = mliftPrim signum
+  fromInteger n =
+    case X.ssxToShape' (knownShapeX @sh) of
+      Just sh -> M_Primitive (X.constant sh (fromInteger n))
+      Nothing -> error "Data.Array.Nested.fromIntegral: \
+                       \Unknown components in shape, use explicit replicate"
+
+deriving via Mixed sh (Primitive Int) instance KnownShapeX sh => Num (Mixed sh Int)
+deriving via Mixed sh (Primitive Double) instance KnownShapeX sh => Num (Mixed sh Double)
+
 
 -- | A rank-typed array: the number of dimensions of the array (its /rank/) is
 -- represented on the type level as a 'INat'.
@@ -582,6 +610,28 @@ coerceMixedXArray = coerce
 
 -- ====== API OF RANKED ARRAYS ====== --
 
+arithPromoteRanked :: forall n a. KnownINat n
+                   => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a)
+                   -> Ranked n a -> Ranked n a
+arithPromoteRanked | Dict <- lemKnownReplicate (Proxy @n) = coerce
+
+arithPromoteRanked2 :: forall n a. KnownINat n
+                    => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a -> Mixed sh a)
+                    -> Ranked n a -> Ranked n a -> Ranked n a
+arithPromoteRanked2 | Dict <- lemKnownReplicate (Proxy @n) = coerce
+
+instance (KnownINat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where
+  (+) = arithPromoteRanked2 (+)
+  (-) = arithPromoteRanked2 (-)
+  (*) = arithPromoteRanked2 (*)
+  negate = arithPromoteRanked negate
+  abs = arithPromoteRanked abs
+  signum = arithPromoteRanked signum
+  fromInteger n | Dict <- lemKnownReplicate (Proxy @n) = Ranked (fromInteger n)
+
+deriving via Ranked n (Primitive Int) instance KnownINat n => Num (Ranked n Int)
+deriving via Ranked n (Primitive Double) instance KnownINat n => Num (Ranked n Double)
+
 -- | An index into a rank-typed array.
 type IxR :: INat -> Type
 data IxR n where
@@ -646,6 +696,28 @@ rtranspose perm (Ranked arr)
 
 -- ====== API OF SHAPED ARRAYS ====== --
 
+arithPromoteShaped :: forall sh a. KnownShape sh
+                   => (forall shx. KnownShapeX shx => Mixed shx a -> Mixed shx a)
+                   -> Shaped sh a -> Shaped sh a
+arithPromoteShaped | Dict <- lemKnownMapJust (Proxy @sh) = coerce
+
+arithPromoteShaped2 :: forall sh a. KnownShape sh
+                    => (forall shx. KnownShapeX shx => Mixed shx a -> Mixed shx a -> Mixed shx a)
+                    -> Shaped sh a -> Shaped sh a -> Shaped sh a
+arithPromoteShaped2 | Dict <- lemKnownMapJust (Proxy @sh) = coerce
+
+instance (KnownShape sh, Storable a, Num a) => Num (Shaped sh (Primitive a)) where
+  (+) = arithPromoteShaped2 (+)
+  (-) = arithPromoteShaped2 (-)
+  (*) = arithPromoteShaped2 (*)
+  negate = arithPromoteShaped negate
+  abs = arithPromoteShaped abs
+  signum = arithPromoteShaped signum
+  fromInteger n | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (fromInteger n)
+
+deriving via Shaped sh (Primitive Int) instance KnownShape sh => Num (Shaped sh Int)
+deriving via Shaped sh (Primitive Double) instance KnownShape sh => Num (Shaped sh Double)
+
 -- | An index into a shape-typed array.
 --
 -- For convenience, this contains regular 'Int's instead of bounded integers
-- 
cgit v1.2.3-70-g09d2