summaryrefslogtreecommitdiff
path: root/src/Fancy.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-03-27 22:58:51 +0100
committerTom Smeding <tom@tomsmeding.com>2024-03-27 22:58:51 +0100
commitae113c0249f3fe8be7df345081b1b51451cd3fdf (patch)
treefaf241d19fec975f7d77a4486505a57ed4cc0b8b /src/Fancy.hs
parentcb31e179971293c519a530d8ce8ccc004458b1c4 (diff)
Ranked interface
Diffstat (limited to 'src/Fancy.hs')
-rw-r--r--src/Fancy.hs123
1 files changed, 122 insertions, 1 deletions
diff --git a/src/Fancy.hs b/src/Fancy.hs
index 8019393..6b6d8d4 100644
--- a/src/Fancy.hs
+++ b/src/Fancy.hs
@@ -3,6 +3,7 @@
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
@@ -11,6 +12,7 @@ module Fancy where
import Control.Monad (forM_)
import Control.Monad.ST
+import Data.Coerce (coerce)
import Data.Kind
import Data.Proxy
import Data.Type.Equality
@@ -208,8 +210,10 @@ instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where
mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs
-mgenerate :: GMixed a => IxX sh -> (IxX sh -> a) -> Mixed sh a
+mgenerate :: forall sh a. (KnownShapeX sh, GMixed a) => IxX sh -> (IxX sh -> a) -> Mixed sh a
mgenerate sh f
+ | not (checkBounds sh (knownShapeX @sh)) =
+ error $ "mgenerate: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh)
-- We need to be very careful here to ensure that neither 'sh' nor
-- 'firstelem' that we pass to 'mvecsUnsafeNew' are empty.
| X.shapeSize sh == 0 = memptyArray sh
@@ -224,6 +228,11 @@ mgenerate sh f
forM_ (tail (X.enumShape sh)) $ \idx ->
mvecsWrite sh idx (f idx) vecs
mvecsFreeze sh vecs
+ where
+ checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool
+ checkBounds IZX SZX = True
+ checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (unSNat n') && checkBounds sh' ssh'
+ checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh'
type Ranked :: Nat -> Type -> Type
@@ -241,6 +250,70 @@ newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixe
instance (KnownNat n, GMixed a) => GMixed (Ranked n a) where
mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr
+ mindex (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i)
+
+ mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IxX sh -> Mixed sh' (Ranked n a)
+ mindexPartial (M_Ranked arr) i
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = coerce @(Mixed sh' (Mixed (Replicate n 'Nothing) a)) @(Mixed sh' (Ranked n a)) $
+ mindexPartial arr i
+
+ memptyArray :: forall sh. IxX sh -> Mixed sh (Ranked n a)
+ memptyArray i
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = coerce @(Mixed sh (Mixed (Replicate n 'Nothing) a)) @(Mixed sh (Ranked n a)) $
+ memptyArray i
+
+ mvecsNumElts (Ranked arr)
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = mvecsNumElts arr
+
+ mvecsUnsafeNew idx (Ranked arr)
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = MV_Ranked <$> mvecsUnsafeNew idx arr
+
+ mvecsWrite :: forall sh s. IxX sh -> IxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
+ mvecsWrite sh idx (Ranked arr) vecs
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = mvecsWrite sh idx arr
+ (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
+ vecs)
+
+ mvecsWritePartial :: forall sh sh' s. KnownShapeX sh'
+ => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Ranked n a)
+ -> MixedVecs s (sh ++ sh') (Ranked n a)
+ -> ST s ()
+ mvecsWritePartial sh idx arr vecs
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = mvecsWritePartial sh idx
+ (coerce @(Mixed sh' (Ranked n a))
+ @(Mixed sh' (Mixed (Replicate n Nothing) a))
+ arr)
+ (coerce @(MixedVecs s (sh ++ sh') (Ranked n a))
+ @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a))
+ vecs)
+
+ mvecsFreeze :: forall sh s. IxX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
+ mvecsFreeze sh vecs
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
+ @(Mixed sh (Ranked n a))
+ <$> mvecsFreeze sh
+ (coerce @(MixedVecs s sh (Ranked n a))
+ @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
+ vecs)
+
+
+data SShape sh where
+ ShNil :: SShape '[]
+ ShCons :: SNat n -> SShape sh -> SShape (n : sh)
+deriving instance Show (SShape sh)
+
+class KnownShape sh where knownShape :: SShape sh
+instance KnownShape '[] where knownShape = ShNil
+instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = ShCons knownNat knownShape
+
+-- instance (KnownShape sh, GMixed a) => GMixed (Shaped sh a) where
type IxR :: Nat -> Type
@@ -253,4 +326,52 @@ data IxS sh where
IZS :: IxS '[]
(::$) :: Int -> IxS sh -> IxS (n : sh)
+ixCvtXR :: IxX sh -> IxR (X.Rank sh)
+ixCvtXR IZX = IZR
+ixCvtXR (n ::@ sh) = n ::: ixCvtXR sh
+ixCvtXR (n ::? sh) = n ::: ixCvtXR sh
+
+ixCvtRX :: IxR n -> IxX (Replicate n Nothing)
+ixCvtRX IZR = IZX
+ixCvtRX (n ::: sh) = n ::? ixCvtRX sh
+lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n
+lemRankReplicate _ = go (knownNat @n)
+ where
+ go :: SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m
+ go SZ = Refl
+ go (SS n) | Refl <- go n = Refl
+
+lemReplicatePlusApp :: forall n m a. KnownNat n => Proxy n -> Proxy m -> Proxy a
+ -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a
+lemReplicatePlusApp _ _ _ = go (knownNat @n)
+ where
+ go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a
+ go SZ = Refl
+ go (SS n) | Refl <- go n = Refl
+
+
+rshape :: forall n a. (KnownNat n, GMixed a) => Ranked n a -> IxR n
+rshape (Ranked arr)
+ | Dict <- lemKnownReplicate (Proxy @n)
+ , Refl <- lemRankReplicate (Proxy @n)
+ = ixCvtXR (mshape arr)
+
+rindex :: GMixed a => Ranked n a -> IxR n -> a
+rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)
+
+rewriteMixed :: sh1 :~: sh2 -> Mixed sh1 a -> Mixed sh2 a
+rewriteMixed Refl x = x
+
+rindexPartial :: forall n m a. (KnownNat n, GMixed a) => Ranked (n + m) a -> IxR n -> Ranked m a
+rindexPartial (Ranked arr) idx
+ | Refl <- lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)
+ = Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
+ (rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr)
+ (ixCvtRX idx))
+
+rgenerate :: forall n a. (KnownNat n, GMixed a) => IxR n -> (IxR n -> a) -> Ranked n a
+rgenerate sh f
+ | Dict <- lemKnownReplicate (Proxy @n)
+ , Refl <- lemRankReplicate (Proxy @n)
+ = Ranked (mgenerate (ixCvtRX sh) (f . ixCvtXR))