From 43ddff2e7f1e9f4d8855f573384e26b63d34f697 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Tue, 14 May 2024 23:30:53 +0200 Subject: WIP GHC nats --- src/Data/Array/Nested/Internal.hs | 221 +++++++++++++++++++++++++------------- 1 file changed, 145 insertions(+), 76 deletions(-) (limited to 'src/Data/Array/Nested') diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index d041aff..54b567a 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -20,16 +20,51 @@ {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +{-# OPTIONS_GHC -Wno-unused-imports #-} {-| TODO: * We should be more consistent in whether functions take a 'StaticShX' argument or a 'KnownShapeX' constraint. -* Document the choice of using 'INat' for ranks and 'Nat' for shapes. Point - being that we need to do induction over the former, but the latter need to be - able to get large. +* Mikolaj wants these: + + About your wishlist of operations: these are already there + + OR.index + OR.append + OR.transpose + + These can be easily lifted from the definition for XArray (5min work): + + OR.scalar + OR.unScalar + OR.constant + + These should not be hard: + + OR.fromList + ORB.toList . OR.unravel + OR.ravel . ORB.fromList + OR.slice + OR.rev + OR.reshape + + though it's a bit unfortunate that we end up needing toList. Looking in + horde-ad I see that you seem to need them to do certain operations in Haskell + that orthotope doesn't support? + + For this one we'll need to see to what extent you really need it, and what API + you'd need precisely: + + OR.rerank + + and for these we have an API design question: + + OR.toVector + OR.fromVector -} @@ -51,10 +86,10 @@ import qualified Data.Vector.Storable as VS import qualified Data.Vector.Storable.Mutable as VSM import Foreign.Storable (Storable) import GHC.TypeLits +import Unsafe.Coerce (unsafeCoerce) -import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat) +import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat, Dict(..)) import qualified Data.Array.Mixed as X -import Data.INat -- Invariant in the API @@ -91,34 +126,71 @@ import Data.INat type family Replicate n a where - Replicate Z a = '[] - Replicate (S n) a = a : Replicate n a + Replicate 0 a = '[] + Replicate n a = a : Replicate (n - 1) a type family MapJust l where MapJust '[] = '[] MapJust (x : xs) = Just x : MapJust xs -lemKnownReplicate :: forall n. KnownINat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing) -lemKnownReplicate _ = X.lemKnownShapeX (go (inatSing @n)) +pattern SZ :: () => (n ~ 0) => SNat n +pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl) + where SZ = SNat + +pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1 +pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl)) + where SS = snatSucc + +{-# COMPLETE SZ, SS #-} + +snatSucc :: SNat n -> SNat (n + 1) +snatSucc SNat = SNat + +data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1) +snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1) +snatPred snp1 = + withKnownNat snp1 $ + case cmpNat (Proxy @1) (Proxy @np1) of + LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) + EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) + GTI -> Nothing + +subst1 :: forall f a b. a :~: b -> f a :~: f b +subst1 Refl = Refl + +subst2 :: forall f c a b. a :~: b -> f a c :~: f b c +subst2 Refl = Refl + +lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a +lemReplicateSucc = unsafeCoerce Refl + +lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing) +lemKnownReplicate _ = X.lemKnownShapeX (go (natSing @n)) where - go :: SINat m -> StaticShX (Replicate m Nothing) + go :: SNat m -> StaticShX (Replicate m Nothing) go SZ = ZKSX - go (SS n) = () :!$? go n + go (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = () :!$? go n -lemRankReplicate :: forall n. KnownINat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n -lemRankReplicate _ = go (inatSing @n) +lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n +lemRankReplicate _ = go (natSing @n) where - go :: SINat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m + go :: forall m. SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m go SZ = Refl - go (SS n) | Refl <- go n = Refl - -lemReplicatePlusApp :: forall n m a. KnownINat n => Proxy n -> Proxy m -> Proxy a - -> Replicate (n +! m) a :~: Replicate n a ++ Replicate m a -lemReplicatePlusApp _ _ _ = go (inatSing @n) + go (SS (n :: SNat nm1)) + | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 + , Refl <- go n + = Refl + +lemReplicatePlusApp :: forall n m a. KnownNat n => Proxy n -> Proxy m -> Proxy a + -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a +lemReplicatePlusApp _ _ _ = go (natSing @n) where - go :: SINat n' -> Replicate (n' +! m) a :~: Replicate n' a ++ Replicate m a + go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a go SZ = Refl - go (SS n) | Refl <- go n = Refl + go (SS (n :: SNat n'm1)) + | Refl <- lemReplicateSucc @a @n'm1 + , Refl <- go n + = sym (lemReplicateSucc @a @(n'm1 + m)) shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh') shAppSplit _ ZKSX idx = (ZSX, idx) @@ -575,18 +647,15 @@ deriving via Mixed sh (Primitive Double) instance KnownShapeX sh => Num (Mixed s -- | A rank-typed array: the number of dimensions of the array (its /rank/) is --- represented on the type level as a 'INat'. +-- represented on the type level as a 'Nat'. -- -- Valid elements of a ranked arrays are described by the 'Elt' type class. -- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are -- supported (and are represented as a single, flattened, struct-of-arrays -- array internally). -- --- Note that this 'INat' is not a "GHC.TypeLits" natural, because we want a --- type-level natural that supports induction. --- -- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's. -type Ranked :: INat -> Type -> Type +type Ranked :: Nat -> Type -> Type newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a) @@ -616,7 +685,7 @@ newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixe -- 'Ranked' and 'Shaped' can already be used at the top level of an array nest; -- these instances allow them to also be used as elements of arrays, thus -- making them first-class in the API. -instance (Elt a, KnownINat n) => Elt (Ranked n a) where +instance (Elt a, KnownNat n) => Elt (Ranked n a) where mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr mindex (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i) @@ -848,37 +917,37 @@ rewriteMixed Refl x = x -- ====== API OF RANKED ARRAYS ====== -- -arithPromoteRanked :: forall n a. KnownINat n +arithPromoteRanked :: forall n a. KnownNat n => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a) -> Ranked n a -> Ranked n a arithPromoteRanked | Dict <- lemKnownReplicate (Proxy @n) = coerce -arithPromoteRanked2 :: forall n a. KnownINat n +arithPromoteRanked2 :: forall n a. KnownNat n => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a -> Mixed sh a) -> Ranked n a -> Ranked n a -> Ranked n a arithPromoteRanked2 | Dict <- lemKnownReplicate (Proxy @n) = coerce -instance (KnownINat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where +instance (KnownNat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where (+) = arithPromoteRanked2 (+) (-) = arithPromoteRanked2 (-) (*) = arithPromoteRanked2 (*) negate = arithPromoteRanked negate abs = arithPromoteRanked abs signum = arithPromoteRanked signum - fromInteger n = case inatSing @n of + fromInteger n = case natSing @n of SZ -> Ranked (M_Primitive (X.scalar (fromInteger n))) - SS _ -> error "Data.Array.Nested.fromIntegral(Ranked): \ - \Rank non-zero, use explicit mconstant" + _ -> error "Data.Array.Nested.fromIntegral(Ranked): \ + \Rank non-zero, use explicit mconstant" -- [PRIMITIVE ELEMENT TYPES LIST] (really, a partial list of just the numeric types) -deriving via Ranked n (Primitive Int) instance KnownINat n => Num (Ranked n Int) -deriving via Ranked n (Primitive Double) instance KnownINat n => Num (Ranked n Double) +deriving via Ranked n (Primitive Int) instance KnownNat n => Num (Ranked n Int) +deriving via Ranked n (Primitive Double) instance KnownNat n => Num (Ranked n Double) type role ListR nominal representational -type ListR :: INat -> Type -> Type +type ListR :: Nat -> Type -> Type data ListR n i where - ZR :: ListR Z i - (:::) :: forall n {i}. i -> ListR n i -> ListR (S n) i + ZR :: ListR 0 i + (:::) :: forall n {i}. i -> ListR n i -> ListR (1 + n) i deriving instance Show i => Show (ListR n i) deriving instance Eq i => Eq (ListR n i) deriving instance Ord i => Ord (ListR n i) @@ -892,23 +961,23 @@ listRToList :: ListR n i -> [i] listRToList ZR = [] listRToList (i ::: is) = i : listRToList is -knownListR :: ListR n i -> Dict KnownINat n +knownListR :: ListR n i -> Dict KnownNat n knownListR ZR = Dict knownListR (_ ::: l) | Dict <- knownListR l = Dict -- | An index into a rank-typed array. type role IxR nominal representational -type IxR :: INat -> Type -> Type +type IxR :: Nat -> Type -> Type newtype IxR n i = IxR (ListR n i) deriving (Show, Eq, Ord) deriving newtype (Functor, Foldable) -pattern ZIR :: forall n i. () => n ~ Z => IxR n i +pattern ZIR :: forall n i. () => n ~ 0 => IxR n i pattern ZIR = IxR ZR pattern (:.:) :: forall {n1} {i}. - forall n. (S n ~ n1) + forall n. (1 + n ~ n1) => i -> IxR n i -> IxR n1 i pattern i :.: sh <- (unconsIxR -> Just (UnconsIxRRes sh i)) where i :.: IxR sh = IxR (i ::: sh) @@ -916,30 +985,30 @@ pattern i :.: sh <- (unconsIxR -> Just (UnconsIxRRes sh i)) infixr 3 :.: data UnconsIxRRes i n1 = - forall n. ((S n) ~ n1) => UnconsIxRRes (IxR n i) i + forall n. (1 + n ~ n1) => UnconsIxRRes (IxR n i) i unconsIxR :: IxR n1 i -> Maybe (UnconsIxRRes i n1) unconsIxR (IxR (i ::: sh')) = Just (UnconsIxRRes (IxR sh') i) unconsIxR (IxR ZR) = Nothing type IIxR n = IxR n Int -knownIxR :: IxR n i -> Dict KnownINat n +knownIxR :: IxR n i -> Dict KnownNat n knownIxR (IxR sh) = knownListR sh type role ShR nominal representational -type ShR :: INat -> Type -> Type +type ShR :: Nat -> Type -> Type newtype ShR n i = ShR (ListR n i) deriving (Show, Eq, Ord) deriving newtype (Functor, Foldable) type IShR n = ShR n Int -pattern ZSR :: forall n i. () => n ~ Z => ShR n i +pattern ZSR :: forall n i. () => n ~ 0 => ShR n i pattern ZSR = ShR ZR pattern (:$:) :: forall {n1} {i}. - forall n. (S n ~ n1) + forall n. (1 + n ~ n1) => i -> ShR n i -> ShR n1 i pattern i :$: sh <- (unconsShR -> Just (UnconsShRRes sh i)) where i :$: (ShR sh) = ShR (i ::: sh) @@ -947,15 +1016,15 @@ pattern i :$: sh <- (unconsShR -> Just (UnconsShRRes sh i)) infixr 3 :$: data UnconsShRRes i n1 = - forall n. S n ~ n1 => UnconsShRRes (ShR n i) i + forall n. 1 + n ~ n1 => UnconsShRRes (ShR n i) i unconsShR :: ShR n1 i -> Maybe (UnconsShRRes i n1) unconsShR (ShR (i ::: sh')) = Just (UnconsShRRes (ShR sh') i) unconsShR (ShR ZR) = Nothing -knownShR :: ShR n i -> Dict KnownINat n +knownShR :: ShR n i -> Dict KnownNat n knownShR (ShR sh) = knownListR sh -zeroIxR :: SINat n -> IIxR n +zeroIxR :: SNat n -> IIxR n zeroIxR SZ = ZIR zeroIxR (SS n) = 0 :.: zeroIxR n @@ -982,7 +1051,7 @@ shapeSizeR ZSR = 1 shapeSizeR (n :$: sh) = n * shapeSizeR sh -rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IShR n +rshape :: forall n a. (KnownNat n, Elt a) => Ranked n a -> IShR n rshape (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) , Refl <- lemRankReplicate (Proxy @n) @@ -991,7 +1060,7 @@ rshape (Ranked arr) rindex :: Elt a => Ranked n a -> IIxR n -> a rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) -rindexPartial :: forall n m a. (KnownINat n, Elt a) => Ranked (n +! m) a -> IIxR n -> Ranked m a +rindexPartial :: forall n m a. (KnownNat n, Elt a) => Ranked (n + m) a -> IIxR n -> Ranked m a rindexPartial (Ranked arr) idx = Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) (rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr) @@ -1007,7 +1076,7 @@ rgenerate sh f = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR)) -- | See the documentation of 'mlift'. -rlift :: forall n1 n2 a. (KnownINat n2, Elt a) +rlift :: forall n1 n2 a. (KnownNat n2, Elt a) => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) -> Ranked n1 a -> Ranked n2 a rlift f (Ranked arr) @@ -1015,39 +1084,39 @@ rlift f (Ranked arr) = Ranked (mlift f arr) rsumOuter1P :: forall n a. - (Storable a, Num a, KnownINat n) - => Ranked (S n) (Primitive a) -> Ranked n (Primitive a) + (Storable a, Num a, KnownNat n, 1 <= n) + => Ranked n (Primitive a) -> Ranked (n - 1) (Primitive a) rsumOuter1P (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = Ranked - . coerce @(XArray (Replicate n 'Nothing) a) @(Mixed (Replicate n 'Nothing) (Primitive a)) - . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate n Nothing)) - . coerce @(Mixed (Replicate (S n) Nothing) (Primitive a)) @(XArray (Replicate (S n) Nothing) a) + . coerce @(XArray (Replicate (n - 1) 'Nothing) a) @(Mixed (Replicate (n - 1) 'Nothing) (Primitive a)) + . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate (n - 1) Nothing)) + . coerce @(Mixed (Replicate n Nothing) (Primitive a)) @(XArray (Replicate n Nothing) a) $ arr rsumOuter1 :: forall n a. - (Storable a, Num a, PrimElt a, KnownINat n) - => Ranked (S n) a -> Ranked n a + (Storable a, Num a, PrimElt a, KnownNat n, 1 <= n) + => Ranked n a -> Ranked (n - 1) a rsumOuter1 = coerce fromPrimitive . rsumOuter1P @n @a . coerce toPrimitive -rtranspose :: forall n a. (KnownINat n, Elt a) => [Int] -> Ranked n a -> Ranked n a +rtranspose :: forall n a. (KnownNat n, Elt a) => [Int] -> Ranked n a -> Ranked n a rtranspose perm (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mtranspose perm arr) -rappend :: forall n a. (KnownINat n, Elt a) - => Ranked (S n) a -> Ranked (S n) a -> Ranked (S n) a +rappend :: forall n a. (KnownNat n, Elt a, 1 <= n) + => Ranked n a -> Ranked n a -> Ranked n a rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend -rscalar :: Elt a => a -> Ranked I0 a +rscalar :: Elt a => a -> Ranked 0 a rscalar x = Ranked (mscalar x) -rfromVectorP :: forall n a. (KnownINat n, Storable a) => IShR n -> VS.Vector a -> Ranked n (Primitive a) +rfromVectorP :: forall n a. (KnownNat n, Storable a) => IShR n -> VS.Vector a -> Ranked n (Primitive a) rfromVectorP sh v | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mfromVectorP (shCvtRX sh) v) -rfromVector :: forall n a. (KnownINat n, Storable a, PrimElt a) => IShR n -> VS.Vector a -> Ranked n a +rfromVector :: forall n a. (KnownNat n, Storable a, PrimElt a) => IShR n -> VS.Vector a -> Ranked n a rfromVector sh v = coerce fromPrimitive (rfromVectorP sh v) rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a @@ -1056,39 +1125,39 @@ rtoVectorP = coerce mtoVectorP rtoVector :: (Storable a, PrimElt a) => Ranked n a -> VS.Vector a rtoVector = coerce mtoVector -rfromList1 :: forall n a. (KnownINat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (S n) a +rfromList1 :: forall n a. (KnownNat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (1 + n) a rfromList1 l | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mfromList1 (coerce l)) -rfromList :: Elt a => NonEmpty a -> Ranked I1 a +rfromList :: Elt a => NonEmpty a -> Ranked 1 a rfromList = Ranked . mfromList1 . fmap mscalar -rtoList :: Elt a => Ranked (S n) a -> [Ranked n a] +rtoList :: Elt a => Ranked (1 + n) a -> [Ranked n a] rtoList (Ranked arr) = coerce (mtoList1 arr) -rtoList1 :: Elt a => Ranked I1 a -> [a] +rtoList1 :: Elt a => Ranked 1 a -> [a] rtoList1 = map runScalar . rtoList -runScalar :: Elt a => Ranked I0 a -> a +runScalar :: Elt a => Ranked 0 a -> a runScalar arr = rindex arr ZIR -rconstantP :: forall n a. (KnownINat n, Storable a) => IShR n -> a -> Ranked n (Primitive a) +rconstantP :: forall n a. (KnownNat n, Storable a) => IShR n -> a -> Ranked n (Primitive a) rconstantP sh x | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mconstantP (shCvtRX sh) x) -rconstant :: forall n a. (KnownINat n, Storable a, PrimElt a) +rconstant :: forall n a. (KnownNat n, Storable a, PrimElt a) => IShR n -> a -> Ranked n a rconstant sh x = coerce fromPrimitive (rconstantP sh x) -rslice :: (KnownINat n, Elt a) => [(Int, Int)] -> Ranked n a -> Ranked n a +rslice :: (KnownNat n, Elt a) => [(Int, Int)] -> Ranked n a -> Ranked n a rslice ivs = rlift $ \_ -> X.slice ivs -rrev1 :: (KnownINat n, Elt a) => Ranked (S n) a -> Ranked (S n) a +rrev1 :: (KnownNat n, Elt a, 1 <= n) => Ranked n a -> Ranked n a rrev1 = rlift $ \_ -> X.rev1 -rreshape :: forall n n' a. (KnownINat n, KnownINat n', Elt a) +rreshape :: forall n n' a. (KnownNat n, KnownNat n', Elt a) => IShR n' -> Ranked n a -> Ranked n' a rreshape sh' (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) -- cgit v1.2.3-70-g09d2