From 8d495b7e6c21fc843f0538711c2203dfb213b7e1 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 20 May 2024 21:20:08 +0200 Subject: Generalise *rerank to type-changing functions --- src/Data/Array/Nested/Internal.hs | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) (limited to 'src/Data') diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 9b505ea..feb0662 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -899,10 +899,10 @@ mtoList = map munScalar . mtoList1 munScalar :: Elt a => Mixed '[] a -> a munScalar arr = mindex arr ZIX -mrerankP :: forall sh1 sh2 sh a. Storable a +mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) => StaticShX sh -> IShX sh2 - -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a)) - -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive a) + -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) + -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b) mrerankP ssh sh2 f (M_Primitive sh arr) = let sh1 = shDropSSX sh ssh in M_Primitive (X.shAppend (shTakeSSX (Proxy @sh1) sh ssh) sh2) @@ -910,10 +910,10 @@ mrerankP ssh sh2 f (M_Primitive sh arr) = (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) arr) -mrerank :: forall sh1 sh2 sh a. PrimElt a +mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) => StaticShX sh -> IShX sh2 - -> (Mixed sh1 a -> Mixed sh2 a) - -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) a + -> (Mixed sh1 a -> Mixed sh2 b) + -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b mrerank ssh sh2 f (toPrimitive -> arr) = fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr @@ -1412,10 +1412,10 @@ rfromOrthotope sn arr runScalar :: Elt a => Ranked 0 a -> a runScalar arr = rindex arr ZIR -rrerankP :: forall n1 n2 n a. Storable a +rrerankP :: forall n1 n2 n a b. (Storable a, Storable b) => SNat n -> IShR n2 - -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive a)) - -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive a) + -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b)) + -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b) rrerankP sn sh2 f (Ranked arr) | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat)) , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat)) @@ -1423,10 +1423,10 @@ rrerankP sn sh2 f (Ranked arr) (\a -> let Ranked r = f (Ranked a) in r) arr) -rrerank :: forall n1 n2 n a. PrimElt a +rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b) => SNat n -> IShR n2 - -> (Ranked n1 a -> Ranked n2 a) - -> Ranked (n + n1) a -> Ranked (n + n2) a + -> (Ranked n1 a -> Ranked n2 b) + -> Ranked (n + n1) a -> Ranked (n + n2) b rrerank ssh sh2 f (rtoPrimitive -> arr) = rfromPrimitive $ rrerankP ssh sh2 (rtoPrimitive . f . rfromPrimitive) arr @@ -1666,10 +1666,10 @@ stoList1 = map sunScalar . stoList sunScalar :: Elt a => Shaped '[] a -> a sunScalar arr = sindex arr ZIS -srerankP :: forall sh1 sh2 sh a. Storable a +srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) => ShS sh -> ShS sh2 - -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive a)) - -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive a) + -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b)) + -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b) srerankP sh sh2 f sarr@(Shaped arr) | Refl <- lemCommMapJustApp sh (Proxy @sh1) , Refl <- lemCommMapJustApp sh (Proxy @sh2) @@ -1678,10 +1678,10 @@ srerankP sh sh2 f sarr@(Shaped arr) (\a -> let Shaped r = f (Shaped a) in r) arr) -srerank :: forall sh1 sh2 sh a. PrimElt a +srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) => ShS sh -> ShS sh2 - -> (Shaped sh1 a -> Shaped sh2 a) - -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) a + -> (Shaped sh1 a -> Shaped sh2 b) + -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b srerank sh sh2 f (stoPrimitive -> arr) = sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr -- cgit v1.2.3-70-g09d2