aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Mixed.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Mixed.hs')
-rw-r--r--src/Data/Array/Nested/Mixed.hs90
1 files changed, 62 insertions, 28 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index a2787b8..ecc4479 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -698,6 +698,7 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a)
+-- TODO: should we provide a function that's just memptyArrayUnsafe but with a size==0 check? That may save someone a transpose somewhere
memptyArray :: KnownElt a => IShX sh -> Mixed (Just 0 : sh) a
memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh)
@@ -745,18 +746,21 @@ mgenerate sh f = case shxEnum sh of
mvecsWrite sh idx val vecs
mvecsFreeze sh vecs
-msumOuter1P :: forall sh n a. (Storable a, NumElt a)
- => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
-msumOuter1P (M_Primitive (n :$% sh) arr) =
+msumOuter1PrimP :: forall sh n a. (Storable a, NumElt a)
+ => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
+msumOuter1PrimP (M_Primitive (n :$% sh) arr) =
let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX
in M_Primitive sh (X.sumOuter nssh (ssxFromShX sh) arr)
-msumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
- => Mixed (n : sh) a -> Mixed sh a
-msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive
+msumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a)
+ => Mixed (n : sh) a -> Mixed sh a
+msumOuter1Prim = fromPrimitive . msumOuter1PrimP @sh @n @a . toPrimitive
+
+msumAllPrimP :: (Storable a, NumElt a) => Mixed sh (Primitive a) -> a
+msumAllPrimP (M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr
msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a
-msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr
+msumAllPrim arr = msumAllPrimP (toPrimitive arr)
mappend :: forall n m sh a. Elt a
=> Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a
@@ -883,24 +887,54 @@ mzip a b
munzip :: Mixed sh (a, b) -> (Mixed sh a, Mixed sh b)
munzip (M_Tup2 a b) = (a, b)
-mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
- => StaticShX sh -> IShX sh2
- -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b))
- -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b)
-mrerankP ssh sh2 f (M_Primitive sh arr) =
- let sh1 = shxDropSSX ssh sh
- in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) ssh sh) sh2)
- (X.rerank ssh (ssxFromShX sh1) (ssxFromShX sh2)
- (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r)
- arr)
+mrerankPrimP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
+ => IShX sh2
+ -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b))
+ -> Mixed sh (Mixed sh1 (Primitive a)) -> Mixed sh (Mixed sh2 (Primitive b))
+mrerankPrimP sh2 f (M_Nest sh (M_Primitive shsh1 arr)) =
+ let sh1 = shxDropSh sh shsh1
+ in M_Nest sh $
+ M_Primitive (shxAppend sh sh2)
+ (X.rerank (ssxFromShX sh) (ssxFromShX sh1) (ssxFromShX sh2)
+ (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r)
+ arr)
--- | See the caveats at 'Data.Array.XArray.rerank'.
-mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
- => StaticShX sh -> IShX sh2
- -> (Mixed sh1 a -> Mixed sh2 b)
- -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b
-mrerank ssh sh2 f (toPrimitive -> arr) =
- fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr
+-- | If the shape of the outer array (@sh@) is empty (i.e. contains a zero),
+-- then there is no way to deduce the full shape of the output array (more
+-- precisely, the @sh2@ part): that could only come from calling @f@, and there
+-- are no subarrays to call @f@ on. @orthotope@ errors out in this case; we
+-- choose to fill the shape with zeros wherever we cannot deduce what it should
+-- be.
+--
+-- For example, if:
+--
+-- @
+-- -- arr has shape [3, 0, 4] and the inner arrays have shape [2, 21]
+-- arr :: Mixed '[Just 3, Just 0, Just 4] (Mixed '[Just 2, Nothing] Int)
+-- f :: Mixed '[Just 2, Nothing] Int -> Mixed '[Just 5, Nothing, Just 17] Float
+-- @
+--
+-- then:
+--
+-- @
+-- mrerankPrim _ f arr :: Mixed '[Just 3, Just 0, Just 4] (Mixed '[Just 5, Nothing, Just 17] Float)
+-- @
+--
+-- and the inner arrays of the result will have shape @[5, 0, 17]@. Note the
+-- @0@ in this shape: we don't know if @f@ intended to return an array with
+-- shape 0 here (it probably didn't), but there is no better number to put here
+-- absent a subarray of the input to pass to @f@.
+--
+-- In this particular case the fact that @sh@ is empty was evident from the
+-- type-level information, but the same situation occurs when @sh@ consists of
+-- @Nothing@s, and some of those happen to be zero at runtime.
+mrerankPrim :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
+ => IShX sh2
+ -> (Mixed sh1 a -> Mixed sh2 b)
+ -> Mixed sh (Mixed sh1 a) -> Mixed sh (Mixed sh2 b)
+mrerankPrim sh2 f (M_Nest sh arr) =
+ let M_Nest sh' arr' = mrerankPrimP sh2 (toPrimitive . f . fromPrimitive) (M_Nest sh (toPrimitive arr))
+ in M_Nest sh' (fromPrimitive arr')
mreplicate :: forall sh sh' a. Elt a
=> IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a
@@ -912,12 +946,12 @@ mreplicate sh arr =
Refl -> X.replicate sh (ssxAppend ssh' sshT))
arr
-mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a)
-mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x)
+mreplicatePrimP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a)
+mreplicatePrimP sh x = M_Primitive sh (X.replicateScal sh x)
-mreplicateScal :: forall sh a. PrimElt a
+mreplicatePrim :: forall sh a. PrimElt a
=> IShX sh -> a -> Mixed sh a
-mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x)
+mreplicatePrim sh x = fromPrimitive (mreplicatePrimP sh x)
msliceN :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a
msliceN i n arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.sliceU i n) arr