diff options
Diffstat (limited to 'src/Data/Array')
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 58 | 
1 files changed, 52 insertions, 6 deletions
| diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index eb4ef22..86b0fce 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -147,6 +147,10 @@ class Elt a where          => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)          -> Mixed sh1 a -> Mixed sh2 a +  mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3) +         => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) +         -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a +    -- ====== PRIVATE METHODS ====== --    -- Remember I said that this module needed better management of exports? @@ -190,6 +194,15 @@ instance Storable a => Elt (Primitive a) where      , Refl <- X.lemAppNil @sh2      = M_Primitive (f Proxy a) +  mlift2 :: forall sh1 sh2 sh3. +            (Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a) +         -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a) +  mlift2 f (M_Primitive a) (M_Primitive b) +    | Refl <- X.lemAppNil @sh1 +    , Refl <- X.lemAppNil @sh2 +    , Refl <- X.lemAppNil @sh3 +    = M_Primitive (f Proxy a b) +    memptyArray sh = M_Primitive (X.generate sh (error "memptyArray Int: shape was not empty"))    mvecsNumElts _ = 1    mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh) @@ -215,6 +228,7 @@ instance (Elt a, Elt b) => Elt (a, b) where    mindex (M_Tup2 a b) i = (mindex a i, mindex b i)    mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)    mlift f (M_Tup2 a b) = M_Tup2 (mlift f a) (mlift f b) +  mlift2 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 f a x) (mlift2 f b y)    memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh)    mvecsNumElts (x, y) = mvecsNumElts x * mvecsNumElts y @@ -248,18 +262,34 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where      = M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)    mlift :: forall sh1 sh2. KnownShapeX sh2 -        => (forall sh3 b. (KnownShapeX sh3, Storable b) => Proxy sh3 -> XArray (sh1 ++ sh3) b -> XArray (sh2 ++ sh3) b) +        => (forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b)          -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)    mlift f (M_Nest arr)      | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))      = M_Nest (mlift f' arr)      where -      f' :: forall sh3 b. (KnownShapeX sh3, Storable b) => Proxy sh3 -> XArray ((sh1 ++ sh') ++ sh3) b -> XArray ((sh2 ++ sh') ++ sh3) b +      f' :: forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b +      f' _ +        | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) +        , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) +        , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @shT)) +        = f (Proxy @(sh' ++ shT)) + +  mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3) +         => (forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) +         -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a) +  mlift2 f (M_Nest arr1) (M_Nest arr2) +    | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) +    , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh3) (knownShapeX @sh')) +    = M_Nest (mlift2 f' arr1 arr2) +    where +      f' :: forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b        f' _ -        | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @sh3) -        , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @sh3) -        , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @sh3)) -        = f (Proxy @(sh' ++ sh3)) +        | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) +        , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) +        , Refl <- X.lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) +        , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @shT)) +        = f (Proxy @(sh' ++ shT))    memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIdx (knownShapeX @sh')))) @@ -384,6 +414,14 @@ instance (KnownINat n, Elt a) => Elt (Ranked n a) where      = coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $          mlift f arr +  mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3) +         => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) +         -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a) +  mlift2 f (M_Ranked arr1) (M_Ranked arr2) +    | Dict <- lemKnownReplicate (Proxy @n) +    = coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ +        mlift2 f arr1 arr2 +    memptyArray :: forall sh. IxX sh -> Mixed sh (Ranked n a)    memptyArray i      | Dict <- lemKnownReplicate (Proxy @n) @@ -479,6 +517,14 @@ instance (KnownShape sh, Elt a) => Elt (Shaped sh a) where      = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $          mlift f arr +  mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3) +         => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) +         -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a) +  mlift2 f (M_Shaped arr1) (M_Shaped arr2) +    | Dict <- lemKnownMapJust (Proxy @sh) +    = coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ +        mlift2 f arr1 arr2 +    memptyArray :: forall sh'. IxX sh' -> Mixed sh' (Shaped sh a)    memptyArray i      | Dict <- lemKnownMapJust (Proxy @sh) | 
