From ae113c0249f3fe8be7df345081b1b51451cd3fdf Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 27 Mar 2024 22:58:51 +0100 Subject: Ranked interface --- src/Array.hs | 1 + src/Fancy.hs | 123 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++- src/Nats.hs | 4 ++ 3 files changed, 127 insertions(+), 1 deletion(-) diff --git a/src/Array.hs b/src/Array.hs index 5140eaf..29806d4 100644 --- a/src/Array.hs +++ b/src/Array.hs @@ -43,6 +43,7 @@ data StaticShapeX sh where SZX :: StaticShapeX '[] (:$@) :: SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh) (:$?) :: () -> StaticShapeX sh -> StaticShapeX (Nothing : sh) +deriving instance Show (StaticShapeX sh) type KnownShapeX :: [Maybe Nat] -> Constraint class KnownShapeX sh where diff --git a/src/Fancy.hs b/src/Fancy.hs index 8019393..6b6d8d4 100644 --- a/src/Fancy.hs +++ b/src/Fancy.hs @@ -3,6 +3,7 @@ {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} @@ -11,6 +12,7 @@ module Fancy where import Control.Monad (forM_) import Control.Monad.ST +import Data.Coerce (coerce) import Data.Kind import Data.Proxy import Data.Type.Equality @@ -208,8 +210,10 @@ instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs -mgenerate :: GMixed a => IxX sh -> (IxX sh -> a) -> Mixed sh a +mgenerate :: forall sh a. (KnownShapeX sh, GMixed a) => IxX sh -> (IxX sh -> a) -> Mixed sh a mgenerate sh f + | not (checkBounds sh (knownShapeX @sh)) = + error $ "mgenerate: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh) -- We need to be very careful here to ensure that neither 'sh' nor -- 'firstelem' that we pass to 'mvecsUnsafeNew' are empty. | X.shapeSize sh == 0 = memptyArray sh @@ -224,6 +228,11 @@ mgenerate sh f forM_ (tail (X.enumShape sh)) $ \idx -> mvecsWrite sh idx (f idx) vecs mvecsFreeze sh vecs + where + checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool + checkBounds IZX SZX = True + checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (unSNat n') && checkBounds sh' ssh' + checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh' type Ranked :: Nat -> Type -> Type @@ -241,6 +250,70 @@ newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixe instance (KnownNat n, GMixed a) => GMixed (Ranked n a) where mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr + mindex (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i) + + mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IxX sh -> Mixed sh' (Ranked n a) + mindexPartial (M_Ranked arr) i + | Dict <- lemKnownReplicate (Proxy @n) + = coerce @(Mixed sh' (Mixed (Replicate n 'Nothing) a)) @(Mixed sh' (Ranked n a)) $ + mindexPartial arr i + + memptyArray :: forall sh. IxX sh -> Mixed sh (Ranked n a) + memptyArray i + | Dict <- lemKnownReplicate (Proxy @n) + = coerce @(Mixed sh (Mixed (Replicate n 'Nothing) a)) @(Mixed sh (Ranked n a)) $ + memptyArray i + + mvecsNumElts (Ranked arr) + | Dict <- lemKnownReplicate (Proxy @n) + = mvecsNumElts arr + + mvecsUnsafeNew idx (Ranked arr) + | Dict <- lemKnownReplicate (Proxy @n) + = MV_Ranked <$> mvecsUnsafeNew idx arr + + mvecsWrite :: forall sh s. IxX sh -> IxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () + mvecsWrite sh idx (Ranked arr) vecs + | Dict <- lemKnownReplicate (Proxy @n) + = mvecsWrite sh idx arr + (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) + vecs) + + mvecsWritePartial :: forall sh sh' s. KnownShapeX sh' + => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Ranked n a) + -> MixedVecs s (sh ++ sh') (Ranked n a) + -> ST s () + mvecsWritePartial sh idx arr vecs + | Dict <- lemKnownReplicate (Proxy @n) + = mvecsWritePartial sh idx + (coerce @(Mixed sh' (Ranked n a)) + @(Mixed sh' (Mixed (Replicate n Nothing) a)) + arr) + (coerce @(MixedVecs s (sh ++ sh') (Ranked n a)) + @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) + vecs) + + mvecsFreeze :: forall sh s. IxX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) + mvecsFreeze sh vecs + | Dict <- lemKnownReplicate (Proxy @n) + = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) + @(Mixed sh (Ranked n a)) + <$> mvecsFreeze sh + (coerce @(MixedVecs s sh (Ranked n a)) + @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) + vecs) + + +data SShape sh where + ShNil :: SShape '[] + ShCons :: SNat n -> SShape sh -> SShape (n : sh) +deriving instance Show (SShape sh) + +class KnownShape sh where knownShape :: SShape sh +instance KnownShape '[] where knownShape = ShNil +instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = ShCons knownNat knownShape + +-- instance (KnownShape sh, GMixed a) => GMixed (Shaped sh a) where type IxR :: Nat -> Type @@ -253,4 +326,52 @@ data IxS sh where IZS :: IxS '[] (::$) :: Int -> IxS sh -> IxS (n : sh) +ixCvtXR :: IxX sh -> IxR (X.Rank sh) +ixCvtXR IZX = IZR +ixCvtXR (n ::@ sh) = n ::: ixCvtXR sh +ixCvtXR (n ::? sh) = n ::: ixCvtXR sh + +ixCvtRX :: IxR n -> IxX (Replicate n Nothing) +ixCvtRX IZR = IZX +ixCvtRX (n ::: sh) = n ::? ixCvtRX sh +lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n +lemRankReplicate _ = go (knownNat @n) + where + go :: SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m + go SZ = Refl + go (SS n) | Refl <- go n = Refl + +lemReplicatePlusApp :: forall n m a. KnownNat n => Proxy n -> Proxy m -> Proxy a + -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a +lemReplicatePlusApp _ _ _ = go (knownNat @n) + where + go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a + go SZ = Refl + go (SS n) | Refl <- go n = Refl + + +rshape :: forall n a. (KnownNat n, GMixed a) => Ranked n a -> IxR n +rshape (Ranked arr) + | Dict <- lemKnownReplicate (Proxy @n) + , Refl <- lemRankReplicate (Proxy @n) + = ixCvtXR (mshape arr) + +rindex :: GMixed a => Ranked n a -> IxR n -> a +rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) + +rewriteMixed :: sh1 :~: sh2 -> Mixed sh1 a -> Mixed sh2 a +rewriteMixed Refl x = x + +rindexPartial :: forall n m a. (KnownNat n, GMixed a) => Ranked (n + m) a -> IxR n -> Ranked m a +rindexPartial (Ranked arr) idx + | Refl <- lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing) + = Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) + (rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr) + (ixCvtRX idx)) + +rgenerate :: forall n a. (KnownNat n, GMixed a) => IxR n -> (IxR n -> a) -> Ranked n a +rgenerate sh f + | Dict <- lemKnownReplicate (Proxy @n) + , Refl <- lemRankReplicate (Proxy @n) + = Ranked (mgenerate (ixCvtRX sh) (f . ixCvtXR)) diff --git a/src/Nats.hs b/src/Nats.hs index a9ad47c..fdc090e 100644 --- a/src/Nats.hs +++ b/src/Nats.hs @@ -42,6 +42,10 @@ snatKnown :: SNat n -> Dict KnownNat n snatKnown SZ = Dict snatKnown (SS n) | Dict <- snatKnown n = Dict +type family n + m where + Z + m = m + S n + m = S (n + m) + type family GNat n where GNat Z = 0 GNat (S n) = 1 G.+ GNat n -- cgit v1.2.3-70-g09d2