diff options
Diffstat (limited to 'src/Data/Array/Nested')
| -rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 42 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Ranked.hs | 11 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Shaped.hs | 9 | 
3 files changed, 60 insertions, 2 deletions
| diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index b799190..4746f31 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -263,7 +263,7 @@ class Elt a where    mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a]    -- | Note: this library makes no particular guarantees about the shapes of -  -- arrays "inside" an empty array. With 'mlift' and 'mlift2' you can see the +  -- arrays "inside" an empty array. With 'mlift', 'mlift2' and 'mliftL' you can see the    -- full 'XArray' and as such you can distinguish different empty arrays by    -- the "shapes" of their elements. This information is meaningless, so you    -- should not use it. @@ -278,6 +278,14 @@ class Elt a where           -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)           -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a +  -- TODO: mliftL is currently unused. +  -- | All arrays in the input must have equal shapes, including subarrays +  -- inside their elements. +  mliftL :: forall sh1 sh2. +            StaticShX sh2 +         -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) +         -> NonEmpty (Mixed sh1 a) -> NonEmpty (Mixed sh2 a) +    mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2          => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a @@ -364,6 +372,16 @@ instance Storable a => Elt (Primitive a) where      , let result = f ZKX a b      = M_Primitive (X.shape ssh3 result) result +  mliftL :: forall sh1 sh2. +            StaticShX sh2 +         -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) +         -> NonEmpty (Mixed sh1 (Primitive a)) -> NonEmpty (Mixed sh2 (Primitive a)) +  mliftL ssh2 f l +    | Refl <- lemAppNil @sh1 +    , Refl <- lemAppNil @sh2 +    = fmap (\arr -> M_Primitive (X.shape ssh2 arr) arr) $ +        f ZKX (fmap (\(M_Primitive _ arr) -> arr) l) +    mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2          => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a)    mcast ssh1 sh2 _ (M_Primitive sh1' arr) = @@ -432,6 +450,11 @@ instance (Elt a, Elt b) => Elt (a, b) where    mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b)    mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b)    mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y) +  mliftL ssh2 f = +    let unzipT2l [] = ([], []) +        unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2) +        unzipT2 (M_Tup2 a b :| l) = let (l1, l2) = unzipT2l l in (a :| l1, b :| l2) +    in uncurry (NE.zipWith M_Tup2) . bimap (mliftL ssh2 f) (mliftL ssh2 f) . unzipT2    mcast ssh1 sh2 psh' (M_Tup2 a b) =      M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b) @@ -523,6 +546,23 @@ instance Elt a => Elt (Mixed sh' a) where          , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT)          = f (ssxAppend ssh' sshT) +  mliftL :: forall sh1 sh2. +            StaticShX sh2 +         -> (forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray (sh1 ++ shT) b) -> NonEmpty (XArray (sh2 ++ shT) b)) +         -> NonEmpty (Mixed sh1 (Mixed sh' a)) -> NonEmpty (Mixed sh2 (Mixed sh' a)) +  mliftL ssh2 f l@(M_Nest sh1 arr1 :| _) = +    let result = mliftL (ssxAppend ssh2 ssh') f' (fmap (\(M_Nest _ arr) -> arr) l) +        (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape (NE.head result)) +    in fmap (M_Nest sh2) result +    where +      ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1))) + +      f' :: forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray ((sh1 ++ sh') ++ shT) b) -> NonEmpty (XArray ((sh2 ++ sh') ++ shT) b) +      f' sshT +        | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) +        , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) +        = f (ssxAppend ssh' sshT) +    mcast :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2          => StaticShX sh1 -> IShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a)    mcast ssh1 sh2 _ (M_Nest sh1T arr) diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index d6e05e6..55ae59f 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -103,7 +103,16 @@ instance Elt a => Elt (Ranked n a) where           -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a)    mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) =      coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ -        mlift2 ssh3 f arr1 arr2 +      mlift2 ssh3 f arr1 arr2 + +  mliftL :: forall sh1 sh2. +            StaticShX sh2 +         -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) +         -> NonEmpty (Mixed sh1 (Ranked n a)) -> NonEmpty (Mixed sh2 (Ranked n a)) +  mliftL ssh2 f l = +    coerce @(NonEmpty (Mixed sh2 (Mixed (Replicate n Nothing) a))) +           @(NonEmpty (Mixed sh2 (Ranked n a))) $ +      mliftL ssh2 f (coerce l)    mcast ssh1 sh2 psh' (M_Ranked arr) = M_Ranked (mcast ssh1 sh2 psh' arr) diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index d1881c1..544a2fa 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -100,6 +100,15 @@ instance Elt a => Elt (Shaped sh a) where      coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $        mlift2 ssh3 f arr1 arr2 +  mliftL :: forall sh1 sh2. +            StaticShX sh2 +         -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) +         -> NonEmpty (Mixed sh1 (Shaped sh a)) -> NonEmpty (Mixed sh2 (Shaped sh a)) +  mliftL ssh2 f l = +    coerce @(NonEmpty (Mixed sh2 (Mixed (MapJust sh) a))) +           @(NonEmpty (Mixed sh2 (Shaped sh a))) $ +      mliftL ssh2 f (coerce l) +    mcast ssh1 sh2 psh' (M_Shaped arr) = M_Shaped (mcast ssh1 sh2 psh' arr)    mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr) | 
