diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-02-11 00:11:53 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-02-15 11:06:40 +0100 |
commit | e6c20868375d2b7f6b31808844e1b48f78bca069 (patch) | |
tree | 5e3c3efa5c61eb11a28b486bccbbcac823a36614 /src/Data/SNat/Peano.hs | |
parent | c705bb4cf76d2e80f3e9ed900f901b697b378f79 (diff) |
WIP half-peano SNatspeano-snat
Diffstat (limited to 'src/Data/SNat/Peano.hs')
-rw-r--r-- | src/Data/SNat/Peano.hs | 232 |
1 files changed, 232 insertions, 0 deletions
diff --git a/src/Data/SNat/Peano.hs b/src/Data/SNat/Peano.hs new file mode 100644 index 0000000..a5109fa --- /dev/null +++ b/src/Data/SNat/Peano.hs @@ -0,0 +1,232 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeData #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UnboxedTuples #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +module Data.SNat.Peano ( + -- * Singleton naturals + Nat(..), + SNat(SZ, SS), + mkSNat, mkSNatFromGHC, NatFromGHC, + withSomeSNat, withSomeSNat', + fromSNat, fromSNat', + + -- * Computing with 'SNat' values + type (+), snatAdd, + type (-), snatSub, + type (*), snatMul, + Compare, SOrdering(..), snatCompare, OrdCond, type (<=), type (<), type (>=), type (>), + + -- * 'KnownNat' + KnownNat(..), + + -- * Interoperate with GHC naturals + pattern SNat', recoverGHC, GHCFromNat, lemGHCNatGHC, lemNatGHCNat, +) where + +import Control.DeepSeq +import Data.Proxy +import Data.Type.Equality +import Numeric.Natural +import qualified GHC.TypeNats as GHC +import Unsafe.Coerce + + +-- | Type-level Peano naturals. +type data Nat = Z | S Nat + +-- | A singleton for 'Nat'. The representation is just a 'Natural', not an +-- actual Peano natural / linked list. +newtype SNat (n :: Nat) = MkSNatUnsafe Natural + +instance Show (SNat n) where + showsPrec d (MkSNatUnsafe n) = showParen (d > 10) $ + showString ("mkSNat @" ++ show n) + +-- these are vacuous because it's a singleton +instance Eq (SNat n) where _ == _ = True +instance Ord (SNat n) where compare _ _ = EQ + +instance NFData (SNat n) where + rnf (MkSNatUnsafe n) = rnf n + +instance TestEquality SNat where + testEquality (MkSNatUnsafe n) (MkSNatUnsafe m) + | n == m = Just unsafeCoerceRefl + | otherwise = Nothing + +-- | The zero natural. Read this as: 'SZ :: SNat Z' +pattern SZ :: forall n. () => n ~ Z => SNat n +pattern SZ <- ((\(MkSNatUnsafe k) -> (# k, unsafeCoerceRefl @n @Z #)) + -> (# 0, Refl #)) + where SZ = MkSNatUnsafe 0 + +-- | /n/ plus one. Read this as: 'SS :: SNat n -> SNat (S n)' +pattern SS :: forall n. () => forall predn. S predn ~ n => SNat predn -> SNat n +pattern SS n <- (mkPredecessor -> Just (Predecessor n)) + where SS (MkSNatUnsafe n) = MkSNatUnsafe (n + 1) +-- A little experiment showed that mkPredecessor inlines sufficiently that no +-- Predecessor object every gets allocated. Let's hope this is also true in +-- more practical situations. + +{-# COMPLETE SZ, SS #-} + +data Predecessor n = forall predn. S predn ~ n => Predecessor (SNat predn) + +mkPredecessor :: forall n. SNat n -> Maybe (Predecessor n) +mkPredecessor (MkSNatUnsafe 0) = Nothing +mkPredecessor (MkSNatUnsafe k) = Just (yolo (MkSNatUnsafe (k-1))) + where + yolo :: forall m predn. SNat predn -> Predecessor m + yolo n | Refl <- unsafeCoerceRefl @(S predn) @m = Predecessor n + +-- | Convert a GHC type-level 'GHC.Nat' to a type-level Peano natural. Because +-- this type family performs induction on a GHC 'GHC.Nat', which only works +-- sensibly if the 'GHC.Nat' is monomorphic, using this on a bare type variable +-- will probably be unsuccessful. +type family NatFromGHC ghcn where + NatFromGHC 0 = Z + NatFromGHC n = S (NatFromGHC (n GHC.- 1)) + +-- | Convenience function to create an 'SNat'. Use with @-XDataKinds@ and +-- @-XTypeApplications@ like: +-- +-- >>> mkSNat @5 +-- +-- The 'GHC.KnownNat' constraint is automatically satisfied for any statically +-- known number. To construct an 'SNat' dynamically, you probably want +-- 'mkSNatFromGHC', or perhaps iterated 'SS'. +mkSNat :: forall ghcn. GHC.KnownNat ghcn => SNat (NatFromGHC ghcn) +mkSNat = MkSNatUnsafe (GHC.natVal (Proxy @ghcn)) + +-- | Convert a GHC 'GHC.SNat' to an 'SNat'. You can convert back using +-- the 'SNat'' pattern synonym, or more manually using 'recoverGHC'. +mkSNatFromGHC :: GHC.SNat ghcn -> SNat (NatFromGHC ghcn) +mkSNatFromGHC sn@GHC.SNat = MkSNatUnsafe (GHC.natVal sn) + +-- | Dynamically create an 'SNat' from an untyped 'Natural'. +withSomeSNat :: Natural -> (forall n. SNat n -> r) -> r +withSomeSNat n k = k (MkSNatUnsafe n) + +-- | Dynamically create an 'SNat' from an untyped 'Int'. Throws an exception if +-- the argument is negative. +withSomeSNat' :: Int -> (forall n. SNat n -> r) -> r +withSomeSNat' n k + | n < 0 = error $ "withSomeSNat': " ++ show n ++ " is negative" + | otherwise = k (MkSNatUnsafe (fromIntegral n)) + +-- | Get the untyped 'Natural' corresponding to the 'SNat'. +fromSNat :: SNat n -> Natural +fromSNat (MkSNatUnsafe n) = n + +-- | Unsafe! If @n@ is out of range for @Int@, this will simply wrap, not throw +-- an error! +fromSNat' :: SNat n -> Int +fromSNat' (MkSNatUnsafe n) = fromIntegral n + +-- | Convert a type-level Peano natural to a GHC type-level 'GHC.Nat'. +type family GHCFromNat n where + GHCFromNat Z = 0 + GHCFromNat (S n) = 1 GHC.+ GHCFromNat n + +-- | Convert an 'SNat' back to a GHC 'GHC.SNat'. If you use 'recoverGHC' after +-- 'mkSNatFromGHC', you will end up with an 'GHC.SNat (GHCFromNat (NatFromGHC +-- ghcn))'; use 'lemGHCToGHC' to rewrite that back to 'GHC.SNat ghcn'. +recoverGHC :: forall n. SNat n -> GHC.SNat (GHCFromNat n) +recoverGHC (MkSNatUnsafe n) = + GHC.withSomeSNat n $ \(ghcn :: GHC.SNat m) -> + unsafeCoerce @(GHC.SNat m) @(GHC.SNat (GHCFromNat n)) ghcn + +-- | 'GHCFromNat' and 'NatFromGHC' are inverses (first half). +lemGHCNatGHC :: GHCFromNat (NatFromGHC ghcn) :~: ghcn +lemGHCNatGHC = unsafeCoerceRefl + +-- | 'GHCFromNat' and 'NatFromGHC' are inverses (second half). +lemNatGHCNat :: NatFromGHC (GHCFromNat n) :~: n +lemNatGHCNat = unsafeCoerceRefl + +pattern SNat' :: forall n. () => GHC.KnownNat (GHCFromNat n) => SNat n +pattern SNat' <- (recoverGHC -> GHC.SNat) + where SNat' = case lemNatGHCNat @n of Refl -> mkSNat @(GHCFromNat n) +{-# COMPLETE SNat' #-} + +-- | Add type-level Peano naturals. +type family n + m where + Z + m = m + S n + m = S (n + m) + +-- | Add 'SNat's. +snatAdd :: SNat n -> SNat m -> SNat (n + m) +snatAdd (MkSNatUnsafe n) (MkSNatUnsafe m) = MkSNatUnsafe (n + m) + +-- | Subtract type-level Peano naturals. Does not reduce if the result would be negative. +type family n - m where + n - Z = n + S n - S m = n - m + +-- | Subtract 'SNat's. Returns 'Nothing' if the result would be negative. +snatSub :: SNat n -> SNat m -> Maybe (SNat (n - m)) +snatSub (MkSNatUnsafe n) (MkSNatUnsafe m) + | n >= m = Just (MkSNatUnsafe (n - m)) + | otherwise = Nothing + +-- | Multiply type-level Peano naturals. +type family n * m where + Z * m = Z + S n * m = m + n * m + +-- | Multiply 'SNat's. +snatMul :: SNat n -> SNat m -> SNat (n * m) +snatMul (MkSNatUnsafe n) (MkSNatUnsafe m) = MkSNatUnsafe (n * m) + +type family Compare n m where + Compare Z Z = EQ + Compare Z (S m) = LT + Compare (S n) Z = GT + Compare (S n) (S m) = Compare n m + +data SOrdering n m where + SLT :: Compare n m ~ LT => SOrdering n m + SEQ :: Compare n n ~ EQ => SOrdering n n + SGT :: Compare n m ~ GT => SOrdering n m + +snatCompare :: SNat n -> SNat m -> SOrdering n m +snatCompare (MkSNatUnsafe @n n) (MkSNatUnsafe @m m) = case compare n m of + LT | Refl <- unsafeCoerceRefl @(Compare n m) @LT -> + SLT + EQ | Refl <- unsafeCoerceRefl @n @m + , Refl <- unsafeCoerceRefl @(Compare n n) @EQ -> + SEQ + GT | Refl <- unsafeCoerceRefl @(Compare n m) @GT -> + SGT + +type family OrdCond ord lt eq gt where + OrdCond LT lt eq gt = lt + OrdCond EQ lt eq gt = eq + OrdCond GT lt eq gt = gt + +type n <= m = OrdCond (Compare n m) True True False ~ True +type n < m = Compare n m ~ LT +type n >= m = OrdCond (Compare n m) False True True ~ True +type n > m = Compare n m ~ GT + +-- | Pass an 'SNat' implicitly, in a constraint. +class KnownNat n where knownNat :: SNat n +instance KnownNat Z where knownNat = SZ +instance KnownNat n => KnownNat (S n) where knownNat = SS knownNat + +unsafeCoerceRefl :: forall a b. a :~: b +unsafeCoerceRefl = unsafeCoerce Refl |