From 70d2edeb338c6acbe9943c4f8b24225bcb912211 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sun, 14 Apr 2024 13:03:53 +0200 Subject: Num instances for Mixed, Ranked, Shaped --- src/Data/Array/Mixed.hs | 12 ++++++++++++ 1 file changed, 12 insertions(+) (limited to 'src/Data/Array/Mixed.hs') diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 7f25d84..d9eb5f0 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -111,6 +111,12 @@ shapeSize IZX = 1 shapeSize (n ::@ sh) = n * shapeSize sh shapeSize (n ::? sh) = n * shapeSize sh +-- | This may fail if @sh@ has @Nothing@s in it. +ssxToShape' :: StaticShapeX sh -> Maybe (IxX sh) +ssxToShape' SZX = Just IZX +ssxToShape' (n :$@ sh) = (fromIntegral (fromSNat n) ::@) <$> ssxToShape' sh +ssxToShape' (_ :$? _) = Nothing + fromLinearIdx :: IxX sh -> Int -> IxX sh fromLinearIdx = \sh i -> case go sh i of (idx, 0) -> idx @@ -221,6 +227,12 @@ scalar = XArray . S.scalar unScalar :: Storable a => XArray '[] a -> a unScalar (XArray a) = S.unScalar a +constant :: forall sh a. Storable a => IxX sh -> a -> XArray sh a +constant sh x + | Dict <- lemKnownINatRank sh + , Dict <- knownNatFromINat (Proxy @(Rank sh)) + = XArray (S.constant (shapeLshape sh) x) + generate :: Storable a => IxX sh -> (IxX sh -> a) -> XArray sh a generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh) -- cgit v1.2.3-70-g09d2