summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-04-14 13:02:29 +0200
committerTom Smeding <tom@tomsmeding.com>2024-04-14 13:02:29 +0200
commit478875e9d82e8c645cbea2e41362c312e892488a (patch)
tree0cc75c3a2e454e4220434b6034242973b884bfce
parent977f0fa379955cbf47fad7279786dea86e24ce43 (diff)
mlift2
-rw-r--r--src/Data/Array/Nested/Internal.hs58
1 files changed, 52 insertions, 6 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index eb4ef22..86b0fce 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -147,6 +147,10 @@ class Elt a where
=> (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
-> Mixed sh1 a -> Mixed sh2 a
+ mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3)
+ => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
+ -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a
+
-- ====== PRIVATE METHODS ====== --
-- Remember I said that this module needed better management of exports?
@@ -190,6 +194,15 @@ instance Storable a => Elt (Primitive a) where
, Refl <- X.lemAppNil @sh2
= M_Primitive (f Proxy a)
+ mlift2 :: forall sh1 sh2 sh3.
+ (Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a)
+ -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a)
+ mlift2 f (M_Primitive a) (M_Primitive b)
+ | Refl <- X.lemAppNil @sh1
+ , Refl <- X.lemAppNil @sh2
+ , Refl <- X.lemAppNil @sh3
+ = M_Primitive (f Proxy a b)
+
memptyArray sh = M_Primitive (X.generate sh (error "memptyArray Int: shape was not empty"))
mvecsNumElts _ = 1
mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh)
@@ -215,6 +228,7 @@ instance (Elt a, Elt b) => Elt (a, b) where
mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
mlift f (M_Tup2 a b) = M_Tup2 (mlift f a) (mlift f b)
+ mlift2 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 f a x) (mlift2 f b y)
memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh)
mvecsNumElts (x, y) = mvecsNumElts x * mvecsNumElts y
@@ -248,18 +262,34 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where
= M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
mlift :: forall sh1 sh2. KnownShapeX sh2
- => (forall sh3 b. (KnownShapeX sh3, Storable b) => Proxy sh3 -> XArray (sh1 ++ sh3) b -> XArray (sh2 ++ sh3) b)
+ => (forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b)
-> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)
mlift f (M_Nest arr)
| Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
= M_Nest (mlift f' arr)
where
- f' :: forall sh3 b. (KnownShapeX sh3, Storable b) => Proxy sh3 -> XArray ((sh1 ++ sh') ++ sh3) b -> XArray ((sh2 ++ sh') ++ sh3) b
+ f' :: forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b
f' _
- | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @sh3)
- , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @sh3)
- , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @sh3))
- = f (Proxy @(sh' ++ sh3))
+ | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
+ , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
+ , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @shT))
+ = f (Proxy @(sh' ++ shT))
+
+ mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3)
+ => (forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b)
+ -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a)
+ mlift2 f (M_Nest arr1) (M_Nest arr2)
+ | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
+ , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh3) (knownShapeX @sh'))
+ = M_Nest (mlift2 f' arr1 arr2)
+ where
+ f' :: forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b
+ f' _
+ | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
+ , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
+ , Refl <- X.lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT)
+ , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @shT))
+ = f (Proxy @(sh' ++ shT))
memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIdx (knownShapeX @sh'))))
@@ -384,6 +414,14 @@ instance (KnownINat n, Elt a) => Elt (Ranked n a) where
= coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $
mlift f arr
+ mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3)
+ => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
+ -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a)
+ mlift2 f (M_Ranked arr1) (M_Ranked arr2)
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $
+ mlift2 f arr1 arr2
+
memptyArray :: forall sh. IxX sh -> Mixed sh (Ranked n a)
memptyArray i
| Dict <- lemKnownReplicate (Proxy @n)
@@ -479,6 +517,14 @@ instance (KnownShape sh, Elt a) => Elt (Shaped sh a) where
= coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
mlift f arr
+ mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3)
+ => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
+ -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a)
+ mlift2 f (M_Shaped arr1) (M_Shaped arr2)
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $
+ mlift2 f arr1 arr2
+
memptyArray :: forall sh'. IxX sh' -> Mixed sh' (Shaped sh a)
memptyArray i
| Dict <- lemKnownMapJust (Proxy @sh)