aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r--src/Data/Array/Mixed.hs12
1 files changed, 12 insertions, 0 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 7f25d84..d9eb5f0 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -111,6 +111,12 @@ shapeSize IZX = 1
shapeSize (n ::@ sh) = n * shapeSize sh
shapeSize (n ::? sh) = n * shapeSize sh
+-- | This may fail if @sh@ has @Nothing@s in it.
+ssxToShape' :: StaticShapeX sh -> Maybe (IxX sh)
+ssxToShape' SZX = Just IZX
+ssxToShape' (n :$@ sh) = (fromIntegral (fromSNat n) ::@) <$> ssxToShape' sh
+ssxToShape' (_ :$? _) = Nothing
+
fromLinearIdx :: IxX sh -> Int -> IxX sh
fromLinearIdx = \sh i -> case go sh i of
(idx, 0) -> idx
@@ -221,6 +227,12 @@ scalar = XArray . S.scalar
unScalar :: Storable a => XArray '[] a -> a
unScalar (XArray a) = S.unScalar a
+constant :: forall sh a. Storable a => IxX sh -> a -> XArray sh a
+constant sh x
+ | Dict <- lemKnownINatRank sh
+ , Dict <- knownNatFromINat (Proxy @(Rank sh))
+ = XArray (S.constant (shapeLshape sh) x)
+
generate :: Storable a => IxX sh -> (IxX sh -> a) -> XArray sh a
generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh)