diff options
| -rw-r--r-- | src/Array.hs | 1 | ||||
| -rw-r--r-- | src/Fancy.hs | 123 | ||||
| -rw-r--r-- | src/Nats.hs | 4 | 
3 files changed, 127 insertions, 1 deletions
| diff --git a/src/Array.hs b/src/Array.hs index 5140eaf..29806d4 100644 --- a/src/Array.hs +++ b/src/Array.hs @@ -43,6 +43,7 @@ data StaticShapeX sh where    SZX :: StaticShapeX '[]    (:$@) :: SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh)    (:$?) :: () -> StaticShapeX sh -> StaticShapeX (Nothing : sh) +deriving instance Show (StaticShapeX sh)  type KnownShapeX :: [Maybe Nat] -> Constraint  class KnownShapeX sh where diff --git a/src/Fancy.hs b/src/Fancy.hs index 8019393..6b6d8d4 100644 --- a/src/Fancy.hs +++ b/src/Fancy.hs @@ -3,6 +3,7 @@  {-# LANGUAGE InstanceSigs #-}  {-# LANGUAGE PolyKinds #-}  {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-}  {-# LANGUAGE StandaloneKindSignatures #-}  {-# LANGUAGE TypeApplications #-}  {-# LANGUAGE TypeFamilies #-} @@ -11,6 +12,7 @@ module Fancy where  import Control.Monad (forM_)  import Control.Monad.ST +import Data.Coerce (coerce)  import Data.Kind  import Data.Proxy  import Data.Type.Equality @@ -208,8 +210,10 @@ instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where    mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs -mgenerate :: GMixed a => IxX sh -> (IxX sh -> a) -> Mixed sh a +mgenerate :: forall sh a. (KnownShapeX sh, GMixed a) => IxX sh -> (IxX sh -> a) -> Mixed sh a  mgenerate sh f +  | not (checkBounds sh (knownShapeX @sh)) = +      error $ "mgenerate: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh)    -- We need to be very careful here to ensure that neither 'sh' nor    -- 'firstelem' that we pass to 'mvecsUnsafeNew' are empty.    | X.shapeSize sh == 0 = memptyArray sh @@ -224,6 +228,11 @@ mgenerate sh f                    forM_ (tail (X.enumShape sh)) $ \idx ->                      mvecsWrite sh idx (f idx) vecs                    mvecsFreeze sh vecs +  where +    checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool +    checkBounds IZX SZX = True +    checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (unSNat n') && checkBounds sh' ssh' +    checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh'  type Ranked :: Nat -> Type -> Type @@ -241,6 +250,70 @@ newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixe  instance (KnownNat n, GMixed a) => GMixed (Ranked n a) where    mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr +  mindex (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i) + +  mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IxX sh -> Mixed sh' (Ranked n a) +  mindexPartial (M_Ranked arr) i +    | Dict <- lemKnownReplicate (Proxy @n) +    = coerce @(Mixed sh' (Mixed (Replicate n 'Nothing) a)) @(Mixed sh' (Ranked n a)) $ +        mindexPartial arr i + +  memptyArray :: forall sh. IxX sh -> Mixed sh (Ranked n a) +  memptyArray i +    | Dict <- lemKnownReplicate (Proxy @n) +    = coerce @(Mixed sh (Mixed (Replicate n 'Nothing) a)) @(Mixed sh (Ranked n a)) $ +        memptyArray i + +  mvecsNumElts (Ranked arr) +    | Dict <- lemKnownReplicate (Proxy @n) +    = mvecsNumElts arr + +  mvecsUnsafeNew idx (Ranked arr) +    | Dict <- lemKnownReplicate (Proxy @n) +    = MV_Ranked <$> mvecsUnsafeNew idx arr + +  mvecsWrite :: forall sh s. IxX sh -> IxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () +  mvecsWrite sh idx (Ranked arr) vecs +    | Dict <- lemKnownReplicate (Proxy @n) +    = mvecsWrite sh idx arr +        (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) +           vecs) + +  mvecsWritePartial :: forall sh sh' s. KnownShapeX sh' +                    => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Ranked n a) +                    -> MixedVecs s (sh ++ sh') (Ranked n a) +                    -> ST s () +  mvecsWritePartial sh idx arr vecs +    | Dict <- lemKnownReplicate (Proxy @n) +    = mvecsWritePartial sh idx +        (coerce @(Mixed sh' (Ranked n a)) +                @(Mixed sh' (Mixed (Replicate n Nothing) a)) +           arr) +        (coerce @(MixedVecs s (sh ++ sh') (Ranked n a)) +                @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) +           vecs) + +  mvecsFreeze :: forall sh s. IxX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) +  mvecsFreeze sh vecs +    | Dict <- lemKnownReplicate (Proxy @n) +    = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) +             @(Mixed sh (Ranked n a)) +        <$> mvecsFreeze sh +              (coerce @(MixedVecs s sh (Ranked n a)) +                      @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) +                      vecs) + + +data SShape sh where +  ShNil :: SShape '[] +  ShCons :: SNat n -> SShape sh -> SShape (n : sh) +deriving instance Show (SShape sh) + +class KnownShape sh where knownShape :: SShape sh +instance KnownShape '[] where knownShape = ShNil +instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = ShCons knownNat knownShape + +-- instance (KnownShape sh, GMixed a) => GMixed (Shaped sh a) where  type IxR :: Nat -> Type @@ -253,4 +326,52 @@ data IxS sh where    IZS :: IxS '[]    (::$) :: Int -> IxS sh -> IxS (n : sh) +ixCvtXR :: IxX sh -> IxR (X.Rank sh) +ixCvtXR IZX = IZR +ixCvtXR (n ::@ sh) = n ::: ixCvtXR sh +ixCvtXR (n ::? sh) = n ::: ixCvtXR sh + +ixCvtRX :: IxR n -> IxX (Replicate n Nothing) +ixCvtRX IZR = IZX +ixCvtRX (n ::: sh) = n ::? ixCvtRX sh + +lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n +lemRankReplicate _ = go (knownNat @n) +  where +    go :: SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m +    go SZ = Refl +    go (SS n) | Refl <- go n = Refl + +lemReplicatePlusApp :: forall n m a. KnownNat n => Proxy n -> Proxy m -> Proxy a +                    -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a +lemReplicatePlusApp _ _ _ = go (knownNat @n) +  where +    go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a +    go SZ = Refl +    go (SS n) | Refl <- go n = Refl + + +rshape :: forall n a. (KnownNat n, GMixed a) => Ranked n a -> IxR n +rshape (Ranked arr) +  | Dict <- lemKnownReplicate (Proxy @n) +  , Refl <- lemRankReplicate (Proxy @n) +  = ixCvtXR (mshape arr) + +rindex :: GMixed a => Ranked n a -> IxR n -> a +rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) + +rewriteMixed :: sh1 :~: sh2 -> Mixed sh1 a -> Mixed sh2 a +rewriteMixed Refl x = x + +rindexPartial :: forall n m a. (KnownNat n, GMixed a) => Ranked (n + m) a -> IxR n -> Ranked m a +rindexPartial (Ranked arr) idx +  | Refl <- lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing) +  = Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) +              (rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr) +              (ixCvtRX idx)) +rgenerate :: forall n a. (KnownNat n, GMixed a) => IxR n -> (IxR n -> a) -> Ranked n a +rgenerate sh f +  | Dict <- lemKnownReplicate (Proxy @n) +  , Refl <- lemRankReplicate (Proxy @n) +  = Ranked (mgenerate (ixCvtRX sh) (f . ixCvtXR)) diff --git a/src/Nats.hs b/src/Nats.hs index a9ad47c..fdc090e 100644 --- a/src/Nats.hs +++ b/src/Nats.hs @@ -42,6 +42,10 @@ snatKnown :: SNat n -> Dict KnownNat n  snatKnown SZ = Dict  snatKnown (SS n) | Dict <- snatKnown n = Dict +type family n + m where +  Z + m = m +  S n + m = S (n + m) +  type family GNat n where    GNat Z = 0    GNat (S n) = 1 G.+ GNat n | 
