diff options
Diffstat (limited to 'src/Data')
| -rw-r--r-- | src/Data/Array/Nested.hs | 3 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 18 | 
2 files changed, 14 insertions, 7 deletions
| diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 9a291e6..292eba4 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -38,8 +38,7 @@ module Data.Array.Nested (    Mixed,    IxX(..), IIxX,    KnownShX(..), StaticShX(..), -  -- TODO: missing msumOuter1? -  mshape, mindex, mindexPartial, mgenerate, +  mshape, mindex, mindexPartial, mgenerate, msumOuter1,    mtranspose, mappend, mscalar, mfromVector, mtoVector, munScalar,    mrerank,    mreplicate, mfromListOuter, mfromList1, mtoListOuter, mtoList1, diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 3ff7967..f3f8f7d 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -884,6 +884,16 @@ mgenerate sh f = case X.enumShape sh of                    mvecsWrite sh idx val vecs                  mvecsFreeze sh vecs +msumOuter1P :: forall sh n a. (Storable a, Num a) +            => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) +msumOuter1P (M_Primitive (n :$% sh) arr) = +  let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX +  in M_Primitive sh (X.sumOuter nssh (X.staticShapeFrom sh) arr) + +msumOuter1 :: forall sh n a. (Num a, PrimElt a) +           => Mixed (n : sh) a -> Mixed sh a +msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive +  mappend :: forall n m sh a. Elt a          => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a  mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 @@ -1458,10 +1468,9 @@ rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f ar  rsumOuter1P :: forall n a.                 (Storable a, Num a)              => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) -rsumOuter1P (Ranked (M_Primitive sh arr)) +rsumOuter1P (Ranked arr)    | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n -  , _ :$% shT <- sh -  = Ranked (M_Primitive shT (X.sumOuter (SUnknown () :!% ZKX) (X.staticShapeFrom shT) arr)) +  = Ranked (msumOuter1P arr)  rsumOuter1 :: forall n a. (Num a, PrimElt a)             => Ranked (n + 1) a -> Ranked n a @@ -1741,8 +1750,7 @@ slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (X.staticShapeFrom (sh  ssumOuter1P :: forall sh n a. (Storable a, Num a)              => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) -ssumOuter1P (Shaped (M_Primitive (SKnown sn :$% sh) arr)) = -  Shaped (M_Primitive sh (X.sumOuter (SKnown sn :!% ZKX) (X.staticShapeFrom sh) arr)) +ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr)  ssumOuter1 :: forall sh n a. (Num a, PrimElt a)             => Shaped (n : sh) a -> Shaped sh a | 
