diff options
| -rw-r--r-- | src/Data/Array/Mixed.hs | 13 | ||||
| -rw-r--r-- | src/Data/Array/Nested.hs | 6 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 63 | 
3 files changed, 56 insertions, 26 deletions
| diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 748914c..d894b85 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -514,8 +514,17 @@ cast ssh1 sh2 ssh' (XArray arr)  unScalar :: Storable a => XArray '[] a -> a  unScalar (XArray a) = S.unScalar a -replicate :: forall sh a. Storable a => IShX sh -> a -> XArray sh a -replicate sh x +replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a +replicate sh ssh' (XArray arr) +  | Dict <- lemKnownNatRankSSX ssh' +  , Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh) ssh') +  , Refl <- lemRankApp (staticShapeFrom sh) ssh' +  = XArray (S.stretch (shapeLshape sh ++ S.shapeL arr) $ +            S.reshape (map (const 1) (shapeLshape sh) ++ S.shapeL arr) $ +              arr) + +replicateScal :: forall sh a. Storable a => IShX sh -> a -> XArray sh a +replicateScal sh x    | Dict <- lemKnownNatRank sh    = XArray (S.constant (shapeLshape sh) x) diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 49923d8..968ea18 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -9,7 +9,7 @@ module Data.Array.Nested (    rshape, rindex, rindexPartial, rgenerate, rsumOuter1,    rtranspose, rappend, rscalar, rfromVector, rtoVector, runScalar,    rrerank, -  rreplicate, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1, +  rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1,    rslice, rrev1, rreshape, riota,    -- ** Lifting orthotope operations to 'Ranked' arrays    rlift, rlift2, @@ -26,7 +26,7 @@ module Data.Array.Nested (    sshape, sindex, sindexPartial, sgenerate, ssumOuter1,    stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar,    srerank, -  sreplicate, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1, +  sreplicate, sreplicateScal, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1,    sslice, srev1, sreshape, siota,    -- ** Lifting orthotope operations to 'Shaped' arrays    slift, slift2, @@ -41,7 +41,7 @@ module Data.Array.Nested (    mshape, mindex, mindexPartial, mgenerate, msumOuter1,    mtranspose, mappend, mscalar, mfromVector, mtoVector, munScalar,    mrerank, -  mreplicate, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1, +  mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1,    mslice, mrev1, mreshape, miota,    -- ** Lifting orthotope operations to 'Mixed' arrays    mlift, mlift2, diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index c70da54..b7308fa 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -967,12 +967,22 @@ mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)  mrerank ssh sh2 f (toPrimitive -> arr) =    fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr -mreplicateP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) -mreplicateP sh x = M_Primitive sh (X.replicate sh x) +mreplicate :: forall sh sh' a. Elt a +           => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a +mreplicate sh arr = +  let ssh' = X.staticShapeFrom (mshape arr) +  in mlift (X.ssxAppend (X.staticShapeFrom sh) ssh') +           (\(sshT :: StaticShX shT) -> +              case X.lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of +                Refl -> X.replicate sh (ssxAppend ssh' sshT)) +           arr -mreplicate :: forall sh a. PrimElt a -          => IShX sh -> a -> Mixed sh a -mreplicate sh x = fromPrimitive (mreplicateP sh x) +mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) +mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x) + +mreplicateScal :: forall sh a. PrimElt a +               => IShX sh -> a -> Mixed sh a +mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x)  mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a  mslice i n arr = @@ -1352,15 +1362,15 @@ instance (NumElt a, PrimElt a) => Num (Ranked n a) where    negate = arithPromoteRanked negate    abs = arithPromoteRanked abs    signum = arithPromoteRanked signum -  fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicate" +  fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicateScal"  instance (NumElt a, PrimElt a, Fractional a) => Fractional (Ranked n a) where -  fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicate" +  fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicateScal"    recip = arithPromoteRanked recip    (/) = arithPromoteRanked2 (/)  instance (NumElt a, PrimElt a, Floating a) => Floating (Ranked n a) where -  pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicate" +  pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicateScal"    exp = arithPromoteRanked exp    log = arithPromoteRanked log    sqrt = arithPromoteRanked sqrt @@ -1576,14 +1586,20 @@ rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b)  rrerank ssh sh2 f (rtoPrimitive -> arr) =    rfromPrimitive $ rrerankP ssh sh2 (rtoPrimitive . f . rfromPrimitive) arr -rreplicateP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) -rreplicateP sh x +rreplicate :: forall n m a. Elt a +           => IShR n -> Ranked m a -> Ranked (n + m) a +rreplicate sh (Ranked arr) +  | Refl <- lemReplicatePlusApp (snatFromShR sh) (Proxy @m) (Proxy @(Nothing @Nat)) +  = Ranked (mreplicate (shCvtRX sh) arr) + +rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) +rreplicateScalP sh x    | Dict <- lemKnownReplicate (snatFromShR sh) -  = Ranked (mreplicateP (shCvtRX sh) x) +  = Ranked (mreplicateScalP (shCvtRX sh) x) -rreplicate :: forall n a. PrimElt a -          => IShR n -> a -> Ranked n a -rreplicate sh x = rfromPrimitive (rreplicateP sh x) +rreplicateScal :: forall n a. PrimElt a +               => IShR n -> a -> Ranked n a +rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x)  rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a  rslice i n arr @@ -1654,15 +1670,15 @@ instance (NumElt a, PrimElt a) => Num (Shaped sh a) where    negate = arithPromoteShaped negate    abs = arithPromoteShaped abs    signum = arithPromoteShaped signum -  fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicate" +  fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicateScal"  instance (NumElt a, PrimElt a, Fractional a) => Fractional (Shaped sh a) where -  fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicate" +  fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal"    recip = arithPromoteShaped recip    (/) = arithPromoteShaped2 (/)  instance (NumElt a, PrimElt a, Floating a) => Floating (Shaped sh a) where -  pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicate" +  pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal"    exp = arithPromoteShaped exp    log = arithPromoteShaped log    sqrt = arithPromoteShaped sqrt @@ -1884,11 +1900,16 @@ srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)  srerank sh sh2 f (stoPrimitive -> arr) =    sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr -sreplicateP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) -sreplicateP sh x = Shaped (mreplicateP (shCvtSX sh) x) +sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a +sreplicate sh (Shaped arr) +  | Refl <- lemCommMapJustApp sh (Proxy @sh') +  = Shaped (mreplicate (shCvtSX sh) arr) + +sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) +sreplicateScalP sh x = Shaped (mreplicateScalP (shCvtSX sh) x) -sreplicate :: PrimElt a => ShS sh -> a -> Shaped sh a -sreplicate sh x = sfromPrimitive (sreplicateP sh x) +sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a +sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x)  sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a  sslice i n@SNat arr = | 
