diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-04-14 13:02:29 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-04-14 13:02:29 +0200 |
commit | 478875e9d82e8c645cbea2e41362c312e892488a (patch) | |
tree | 0cc75c3a2e454e4220434b6034242973b884bfce | |
parent | 977f0fa379955cbf47fad7279786dea86e24ce43 (diff) |
mlift2
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 58 |
1 files changed, 52 insertions, 6 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index eb4ef22..86b0fce 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -147,6 +147,10 @@ class Elt a where => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) -> Mixed sh1 a -> Mixed sh2 a + mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3) + => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) + -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a + -- ====== PRIVATE METHODS ====== -- -- Remember I said that this module needed better management of exports? @@ -190,6 +194,15 @@ instance Storable a => Elt (Primitive a) where , Refl <- X.lemAppNil @sh2 = M_Primitive (f Proxy a) + mlift2 :: forall sh1 sh2 sh3. + (Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a) + -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a) + mlift2 f (M_Primitive a) (M_Primitive b) + | Refl <- X.lemAppNil @sh1 + , Refl <- X.lemAppNil @sh2 + , Refl <- X.lemAppNil @sh3 + = M_Primitive (f Proxy a b) + memptyArray sh = M_Primitive (X.generate sh (error "memptyArray Int: shape was not empty")) mvecsNumElts _ = 1 mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh) @@ -215,6 +228,7 @@ instance (Elt a, Elt b) => Elt (a, b) where mindex (M_Tup2 a b) i = (mindex a i, mindex b i) mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) mlift f (M_Tup2 a b) = M_Tup2 (mlift f a) (mlift f b) + mlift2 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 f a x) (mlift2 f b y) memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh) mvecsNumElts (x, y) = mvecsNumElts x * mvecsNumElts y @@ -248,18 +262,34 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where = M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) mlift :: forall sh1 sh2. KnownShapeX sh2 - => (forall sh3 b. (KnownShapeX sh3, Storable b) => Proxy sh3 -> XArray (sh1 ++ sh3) b -> XArray (sh2 ++ sh3) b) + => (forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) mlift f (M_Nest arr) | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) = M_Nest (mlift f' arr) where - f' :: forall sh3 b. (KnownShapeX sh3, Storable b) => Proxy sh3 -> XArray ((sh1 ++ sh') ++ sh3) b -> XArray ((sh2 ++ sh') ++ sh3) b + f' :: forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b f' _ - | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @sh3) - , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @sh3) - , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @sh3)) - = f (Proxy @(sh' ++ sh3)) + | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) + , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) + , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @shT)) + = f (Proxy @(sh' ++ shT)) + + mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3) + => (forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) + -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a) + mlift2 f (M_Nest arr1) (M_Nest arr2) + | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) + , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh3) (knownShapeX @sh')) + = M_Nest (mlift2 f' arr1 arr2) + where + f' :: forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b + f' _ + | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) + , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) + , Refl <- X.lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) + , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @shT)) + = f (Proxy @(sh' ++ shT)) memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIdx (knownShapeX @sh')))) @@ -384,6 +414,14 @@ instance (KnownINat n, Elt a) => Elt (Ranked n a) where = coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ mlift f arr + mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3) + => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) + -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a) + mlift2 f (M_Ranked arr1) (M_Ranked arr2) + | Dict <- lemKnownReplicate (Proxy @n) + = coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ + mlift2 f arr1 arr2 + memptyArray :: forall sh. IxX sh -> Mixed sh (Ranked n a) memptyArray i | Dict <- lemKnownReplicate (Proxy @n) @@ -479,6 +517,14 @@ instance (KnownShape sh, Elt a) => Elt (Shaped sh a) where = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ mlift f arr + mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3) + => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) + -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a) + mlift2 f (M_Shaped arr1) (M_Shaped arr2) + | Dict <- lemKnownMapJust (Proxy @sh) + = coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ + mlift2 f arr1 arr2 + memptyArray :: forall sh'. IxX sh' -> Mixed sh' (Shaped sh a) memptyArray i | Dict <- lemKnownMapJust (Proxy @sh) |