diff options
Diffstat (limited to 'src/Data/Array/Nested/Mixed.hs')
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 90 |
1 files changed, 62 insertions, 28 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index a2787b8..ecc4479 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -698,6 +698,7 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) +-- TODO: should we provide a function that's just memptyArrayUnsafe but with a size==0 check? That may save someone a transpose somewhere memptyArray :: KnownElt a => IShX sh -> Mixed (Just 0 : sh) a memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh) @@ -745,18 +746,21 @@ mgenerate sh f = case shxEnum sh of mvecsWrite sh idx val vecs mvecsFreeze sh vecs -msumOuter1P :: forall sh n a. (Storable a, NumElt a) - => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) -msumOuter1P (M_Primitive (n :$% sh) arr) = +msumOuter1PrimP :: forall sh n a. (Storable a, NumElt a) + => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) +msumOuter1PrimP (M_Primitive (n :$% sh) arr) = let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX in M_Primitive sh (X.sumOuter nssh (ssxFromShX sh) arr) -msumOuter1 :: forall sh n a. (NumElt a, PrimElt a) - => Mixed (n : sh) a -> Mixed sh a -msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive +msumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a) + => Mixed (n : sh) a -> Mixed sh a +msumOuter1Prim = fromPrimitive . msumOuter1PrimP @sh @n @a . toPrimitive + +msumAllPrimP :: (Storable a, NumElt a) => Mixed sh (Primitive a) -> a +msumAllPrimP (M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a -msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr +msumAllPrim arr = msumAllPrimP (toPrimitive arr) mappend :: forall n m sh a. Elt a => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a @@ -883,24 +887,54 @@ mzip a b munzip :: Mixed sh (a, b) -> (Mixed sh a, Mixed sh b) munzip (M_Tup2 a b) = (a, b) -mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) - => StaticShX sh -> IShX sh2 - -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) - -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b) -mrerankP ssh sh2 f (M_Primitive sh arr) = - let sh1 = shxDropSSX ssh sh - in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) ssh sh) sh2) - (X.rerank ssh (ssxFromShX sh1) (ssxFromShX sh2) - (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) - arr) +mrerankPrimP :: forall sh1 sh2 sh a b. (Storable a, Storable b) + => IShX sh2 + -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) + -> Mixed sh (Mixed sh1 (Primitive a)) -> Mixed sh (Mixed sh2 (Primitive b)) +mrerankPrimP sh2 f (M_Nest sh (M_Primitive shsh1 arr)) = + let sh1 = shxDropSh sh shsh1 + in M_Nest sh $ + M_Primitive (shxAppend sh sh2) + (X.rerank (ssxFromShX sh) (ssxFromShX sh1) (ssxFromShX sh2) + (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) + arr) --- | See the caveats at 'Data.Array.XArray.rerank'. -mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) - => StaticShX sh -> IShX sh2 - -> (Mixed sh1 a -> Mixed sh2 b) - -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b -mrerank ssh sh2 f (toPrimitive -> arr) = - fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr +-- | If the shape of the outer array (@sh@) is empty (i.e. contains a zero), +-- then there is no way to deduce the full shape of the output array (more +-- precisely, the @sh2@ part): that could only come from calling @f@, and there +-- are no subarrays to call @f@ on. @orthotope@ errors out in this case; we +-- choose to fill the shape with zeros wherever we cannot deduce what it should +-- be. +-- +-- For example, if: +-- +-- @ +-- -- arr has shape [3, 0, 4] and the inner arrays have shape [2, 21] +-- arr :: Mixed '[Just 3, Just 0, Just 4] (Mixed '[Just 2, Nothing] Int) +-- f :: Mixed '[Just 2, Nothing] Int -> Mixed '[Just 5, Nothing, Just 17] Float +-- @ +-- +-- then: +-- +-- @ +-- mrerankPrim _ f arr :: Mixed '[Just 3, Just 0, Just 4] (Mixed '[Just 5, Nothing, Just 17] Float) +-- @ +-- +-- and the inner arrays of the result will have shape @[5, 0, 17]@. Note the +-- @0@ in this shape: we don't know if @f@ intended to return an array with +-- shape 0 here (it probably didn't), but there is no better number to put here +-- absent a subarray of the input to pass to @f@. +-- +-- In this particular case the fact that @sh@ is empty was evident from the +-- type-level information, but the same situation occurs when @sh@ consists of +-- @Nothing@s, and some of those happen to be zero at runtime. +mrerankPrim :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) + => IShX sh2 + -> (Mixed sh1 a -> Mixed sh2 b) + -> Mixed sh (Mixed sh1 a) -> Mixed sh (Mixed sh2 b) +mrerankPrim sh2 f (M_Nest sh arr) = + let M_Nest sh' arr' = mrerankPrimP sh2 (toPrimitive . f . fromPrimitive) (M_Nest sh (toPrimitive arr)) + in M_Nest sh' (fromPrimitive arr') mreplicate :: forall sh sh' a. Elt a => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a @@ -912,12 +946,12 @@ mreplicate sh arr = Refl -> X.replicate sh (ssxAppend ssh' sshT)) arr -mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) -mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x) +mreplicatePrimP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) +mreplicatePrimP sh x = M_Primitive sh (X.replicateScal sh x) -mreplicateScal :: forall sh a. PrimElt a +mreplicatePrim :: forall sh a. PrimElt a => IShX sh -> a -> Mixed sh a -mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x) +mreplicatePrim sh x = fromPrimitive (mreplicatePrimP sh x) msliceN :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a msliceN i n arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.sliceU i n) arr |
