aboutsummaryrefslogtreecommitdiff
path: root/src/Data/SNat
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-02-11 00:11:53 +0100
committerTom Smeding <tom@tomsmeding.com>2025-02-15 11:06:40 +0100
commite6c20868375d2b7f6b31808844e1b48f78bca069 (patch)
tree5e3c3efa5c61eb11a28b486bccbbcac823a36614 /src/Data/SNat
parentc705bb4cf76d2e80f3e9ed900f901b697b378f79 (diff)
WIP half-peano SNatspeano-snat
Diffstat (limited to 'src/Data/SNat')
-rw-r--r--src/Data/SNat/Peano.hs232
1 files changed, 232 insertions, 0 deletions
diff --git a/src/Data/SNat/Peano.hs b/src/Data/SNat/Peano.hs
new file mode 100644
index 0000000..a5109fa
--- /dev/null
+++ b/src/Data/SNat/Peano.hs
@@ -0,0 +1,232 @@
+{-# LANGUAGE AllowAmbiguousTypes #-}
+{-# LANGUAGE ConstraintKinds #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE NoStarIsType #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeAbstractions #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeData #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UnboxedTuples #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+module Data.SNat.Peano (
+ -- * Singleton naturals
+ Nat(..),
+ SNat(SZ, SS),
+ mkSNat, mkSNatFromGHC, NatFromGHC,
+ withSomeSNat, withSomeSNat',
+ fromSNat, fromSNat',
+
+ -- * Computing with 'SNat' values
+ type (+), snatAdd,
+ type (-), snatSub,
+ type (*), snatMul,
+ Compare, SOrdering(..), snatCompare, OrdCond, type (<=), type (<), type (>=), type (>),
+
+ -- * 'KnownNat'
+ KnownNat(..),
+
+ -- * Interoperate with GHC naturals
+ pattern SNat', recoverGHC, GHCFromNat, lemGHCNatGHC, lemNatGHCNat,
+) where
+
+import Control.DeepSeq
+import Data.Proxy
+import Data.Type.Equality
+import Numeric.Natural
+import qualified GHC.TypeNats as GHC
+import Unsafe.Coerce
+
+
+-- | Type-level Peano naturals.
+type data Nat = Z | S Nat
+
+-- | A singleton for 'Nat'. The representation is just a 'Natural', not an
+-- actual Peano natural / linked list.
+newtype SNat (n :: Nat) = MkSNatUnsafe Natural
+
+instance Show (SNat n) where
+ showsPrec d (MkSNatUnsafe n) = showParen (d > 10) $
+ showString ("mkSNat @" ++ show n)
+
+-- these are vacuous because it's a singleton
+instance Eq (SNat n) where _ == _ = True
+instance Ord (SNat n) where compare _ _ = EQ
+
+instance NFData (SNat n) where
+ rnf (MkSNatUnsafe n) = rnf n
+
+instance TestEquality SNat where
+ testEquality (MkSNatUnsafe n) (MkSNatUnsafe m)
+ | n == m = Just unsafeCoerceRefl
+ | otherwise = Nothing
+
+-- | The zero natural. Read this as: 'SZ :: SNat Z'
+pattern SZ :: forall n. () => n ~ Z => SNat n
+pattern SZ <- ((\(MkSNatUnsafe k) -> (# k, unsafeCoerceRefl @n @Z #))
+ -> (# 0, Refl #))
+ where SZ = MkSNatUnsafe 0
+
+-- | /n/ plus one. Read this as: 'SS :: SNat n -> SNat (S n)'
+pattern SS :: forall n. () => forall predn. S predn ~ n => SNat predn -> SNat n
+pattern SS n <- (mkPredecessor -> Just (Predecessor n))
+ where SS (MkSNatUnsafe n) = MkSNatUnsafe (n + 1)
+-- A little experiment showed that mkPredecessor inlines sufficiently that no
+-- Predecessor object every gets allocated. Let's hope this is also true in
+-- more practical situations.
+
+{-# COMPLETE SZ, SS #-}
+
+data Predecessor n = forall predn. S predn ~ n => Predecessor (SNat predn)
+
+mkPredecessor :: forall n. SNat n -> Maybe (Predecessor n)
+mkPredecessor (MkSNatUnsafe 0) = Nothing
+mkPredecessor (MkSNatUnsafe k) = Just (yolo (MkSNatUnsafe (k-1)))
+ where
+ yolo :: forall m predn. SNat predn -> Predecessor m
+ yolo n | Refl <- unsafeCoerceRefl @(S predn) @m = Predecessor n
+
+-- | Convert a GHC type-level 'GHC.Nat' to a type-level Peano natural. Because
+-- this type family performs induction on a GHC 'GHC.Nat', which only works
+-- sensibly if the 'GHC.Nat' is monomorphic, using this on a bare type variable
+-- will probably be unsuccessful.
+type family NatFromGHC ghcn where
+ NatFromGHC 0 = Z
+ NatFromGHC n = S (NatFromGHC (n GHC.- 1))
+
+-- | Convenience function to create an 'SNat'. Use with @-XDataKinds@ and
+-- @-XTypeApplications@ like:
+--
+-- >>> mkSNat @5
+--
+-- The 'GHC.KnownNat' constraint is automatically satisfied for any statically
+-- known number. To construct an 'SNat' dynamically, you probably want
+-- 'mkSNatFromGHC', or perhaps iterated 'SS'.
+mkSNat :: forall ghcn. GHC.KnownNat ghcn => SNat (NatFromGHC ghcn)
+mkSNat = MkSNatUnsafe (GHC.natVal (Proxy @ghcn))
+
+-- | Convert a GHC 'GHC.SNat' to an 'SNat'. You can convert back using
+-- the 'SNat'' pattern synonym, or more manually using 'recoverGHC'.
+mkSNatFromGHC :: GHC.SNat ghcn -> SNat (NatFromGHC ghcn)
+mkSNatFromGHC sn@GHC.SNat = MkSNatUnsafe (GHC.natVal sn)
+
+-- | Dynamically create an 'SNat' from an untyped 'Natural'.
+withSomeSNat :: Natural -> (forall n. SNat n -> r) -> r
+withSomeSNat n k = k (MkSNatUnsafe n)
+
+-- | Dynamically create an 'SNat' from an untyped 'Int'. Throws an exception if
+-- the argument is negative.
+withSomeSNat' :: Int -> (forall n. SNat n -> r) -> r
+withSomeSNat' n k
+ | n < 0 = error $ "withSomeSNat': " ++ show n ++ " is negative"
+ | otherwise = k (MkSNatUnsafe (fromIntegral n))
+
+-- | Get the untyped 'Natural' corresponding to the 'SNat'.
+fromSNat :: SNat n -> Natural
+fromSNat (MkSNatUnsafe n) = n
+
+-- | Unsafe! If @n@ is out of range for @Int@, this will simply wrap, not throw
+-- an error!
+fromSNat' :: SNat n -> Int
+fromSNat' (MkSNatUnsafe n) = fromIntegral n
+
+-- | Convert a type-level Peano natural to a GHC type-level 'GHC.Nat'.
+type family GHCFromNat n where
+ GHCFromNat Z = 0
+ GHCFromNat (S n) = 1 GHC.+ GHCFromNat n
+
+-- | Convert an 'SNat' back to a GHC 'GHC.SNat'. If you use 'recoverGHC' after
+-- 'mkSNatFromGHC', you will end up with an 'GHC.SNat (GHCFromNat (NatFromGHC
+-- ghcn))'; use 'lemGHCToGHC' to rewrite that back to 'GHC.SNat ghcn'.
+recoverGHC :: forall n. SNat n -> GHC.SNat (GHCFromNat n)
+recoverGHC (MkSNatUnsafe n) =
+ GHC.withSomeSNat n $ \(ghcn :: GHC.SNat m) ->
+ unsafeCoerce @(GHC.SNat m) @(GHC.SNat (GHCFromNat n)) ghcn
+
+-- | 'GHCFromNat' and 'NatFromGHC' are inverses (first half).
+lemGHCNatGHC :: GHCFromNat (NatFromGHC ghcn) :~: ghcn
+lemGHCNatGHC = unsafeCoerceRefl
+
+-- | 'GHCFromNat' and 'NatFromGHC' are inverses (second half).
+lemNatGHCNat :: NatFromGHC (GHCFromNat n) :~: n
+lemNatGHCNat = unsafeCoerceRefl
+
+pattern SNat' :: forall n. () => GHC.KnownNat (GHCFromNat n) => SNat n
+pattern SNat' <- (recoverGHC -> GHC.SNat)
+ where SNat' = case lemNatGHCNat @n of Refl -> mkSNat @(GHCFromNat n)
+{-# COMPLETE SNat' #-}
+
+-- | Add type-level Peano naturals.
+type family n + m where
+ Z + m = m
+ S n + m = S (n + m)
+
+-- | Add 'SNat's.
+snatAdd :: SNat n -> SNat m -> SNat (n + m)
+snatAdd (MkSNatUnsafe n) (MkSNatUnsafe m) = MkSNatUnsafe (n + m)
+
+-- | Subtract type-level Peano naturals. Does not reduce if the result would be negative.
+type family n - m where
+ n - Z = n
+ S n - S m = n - m
+
+-- | Subtract 'SNat's. Returns 'Nothing' if the result would be negative.
+snatSub :: SNat n -> SNat m -> Maybe (SNat (n - m))
+snatSub (MkSNatUnsafe n) (MkSNatUnsafe m)
+ | n >= m = Just (MkSNatUnsafe (n - m))
+ | otherwise = Nothing
+
+-- | Multiply type-level Peano naturals.
+type family n * m where
+ Z * m = Z
+ S n * m = m + n * m
+
+-- | Multiply 'SNat's.
+snatMul :: SNat n -> SNat m -> SNat (n * m)
+snatMul (MkSNatUnsafe n) (MkSNatUnsafe m) = MkSNatUnsafe (n * m)
+
+type family Compare n m where
+ Compare Z Z = EQ
+ Compare Z (S m) = LT
+ Compare (S n) Z = GT
+ Compare (S n) (S m) = Compare n m
+
+data SOrdering n m where
+ SLT :: Compare n m ~ LT => SOrdering n m
+ SEQ :: Compare n n ~ EQ => SOrdering n n
+ SGT :: Compare n m ~ GT => SOrdering n m
+
+snatCompare :: SNat n -> SNat m -> SOrdering n m
+snatCompare (MkSNatUnsafe @n n) (MkSNatUnsafe @m m) = case compare n m of
+ LT | Refl <- unsafeCoerceRefl @(Compare n m) @LT ->
+ SLT
+ EQ | Refl <- unsafeCoerceRefl @n @m
+ , Refl <- unsafeCoerceRefl @(Compare n n) @EQ ->
+ SEQ
+ GT | Refl <- unsafeCoerceRefl @(Compare n m) @GT ->
+ SGT
+
+type family OrdCond ord lt eq gt where
+ OrdCond LT lt eq gt = lt
+ OrdCond EQ lt eq gt = eq
+ OrdCond GT lt eq gt = gt
+
+type n <= m = OrdCond (Compare n m) True True False ~ True
+type n < m = Compare n m ~ LT
+type n >= m = OrdCond (Compare n m) False True True ~ True
+type n > m = Compare n m ~ GT
+
+-- | Pass an 'SNat' implicitly, in a constraint.
+class KnownNat n where knownNat :: SNat n
+instance KnownNat Z where knownNat = SZ
+instance KnownNat n => KnownNat (S n) where knownNat = SS knownNat
+
+unsafeCoerceRefl :: forall a b. a :~: b
+unsafeCoerceRefl = unsafeCoerce Refl