aboutsummaryrefslogtreecommitdiff
path: root/src/Data/SNat/Peano.hs
blob: a5109fa1d02e34073af0a284b34f30cc7e10cea8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
module Data.SNat.Peano (
  -- * Singleton naturals
  Nat(..),
  SNat(SZ, SS),
  mkSNat, mkSNatFromGHC, NatFromGHC,
  withSomeSNat, withSomeSNat',
  fromSNat, fromSNat',

  -- * Computing with 'SNat' values
  type (+), snatAdd,
  type (-), snatSub,
  type (*), snatMul,
  Compare, SOrdering(..), snatCompare, OrdCond, type (<=), type (<), type (>=), type (>),

  -- * 'KnownNat'
  KnownNat(..),

  -- * Interoperate with GHC naturals
  pattern SNat', recoverGHC, GHCFromNat, lemGHCNatGHC, lemNatGHCNat,
) where

import Control.DeepSeq
import Data.Proxy
import Data.Type.Equality
import Numeric.Natural
import qualified GHC.TypeNats as GHC
import Unsafe.Coerce


-- | Type-level Peano naturals: 'Z' is zero and @'S' n@ is the successor of
-- @n@. Declared with @-XTypeData@, so 'Z' and 'S' exist only at the type
-- level; runtime witnesses for these indices are provided by 'SNat'.
type data Nat = Z | S Nat

-- | A singleton for 'Nat'. The representation is just a 'Natural', not an
-- actual Peano natural / linked list.
--
-- Invariant: the wrapped 'Natural' equals the value of the type index @n@.
-- The raw constructor @MkSNatUnsafe@ is not exported (only the 'SZ' and 'SS'
-- pattern synonyms are bundled with 'SNat' in the export list), so the
-- functions of this module are responsible for upholding the invariant.
newtype SNat (n :: Nat) = MkSNatUnsafe Natural

-- | Renders as the expression that rebuilds the value, e.g. @mkSNat \@3@,
-- wrapped in parentheses when shown in argument position.
instance Show (SNat n) where
  showsPrec prec (MkSNatUnsafe val) =
    showParen (prec > 10) (showString "mkSNat @" . shows val)

-- Because 'SNat' is a singleton, any two values of type @SNat n@ denote the
-- same number, so equality and ordering between them are trivial.
instance Eq (SNat n) where
  _ == _ = True

instance Ord (SNat n) where
  compare _ _ = EQ

-- | Fully evaluates the underlying 'Natural'.
instance NFData (SNat n) where
  rnf sn = case sn of
    MkSNatUnsafe k -> rnf k

-- | Two 'SNat's are reported equal exactly when their runtime 'Natural's are
-- equal; the type-level equality is then asserted with @unsafeCoerceRefl@,
-- relying on the 'SNat' representation invariant.
instance TestEquality SNat where
  testEquality (MkSNatUnsafe a) (MkSNatUnsafe b) =
    if a == b then Just unsafeCoerceRefl else Nothing

-- | The zero natural. Read this as: @SZ :: SNat Z@.
--
-- Matching succeeds when the underlying 'Natural' is @0@ and then brings
-- @n ~ Z@ into scope. That equality is not checked by the compiler: the view
-- function pairs the runtime value with an unsafely coerced 'Refl', and the
-- literal @0@ is matched before the 'Refl'. (The unboxed tuple presumably
-- avoids allocating a boxed pair — NOTE(review): confirm.)
pattern SZ :: forall n. () => n ~ Z => SNat n
pattern SZ <- ((\(MkSNatUnsafe k) -> (# k, unsafeCoerceRefl @n @Z #))
               -> (# 0, Refl #))
  where SZ = MkSNatUnsafe 0

-- | /n/ plus one. Read this as: @SS :: SNat n -> SNat (S n)@.
--
-- Matching succeeds for any non-zero 'SNat' and binds its predecessor
-- (the underlying value minus one); constructing adds one to the value.
pattern SS :: forall n. () => forall predn. S predn ~ n => SNat predn -> SNat n
pattern SS n <- (mkPredecessor -> Just (Predecessor n))
  where SS (MkSNatUnsafe n) = MkSNatUnsafe (n + 1)
-- A little experiment showed that mkPredecessor inlines sufficiently that no
-- Predecessor object ever gets allocated. Let's hope this is also true in
-- more practical situations.

-- Every 'SNat' is either zero ('SZ') or non-zero ('SS').
{-# COMPLETE SZ, SS #-}

-- | Evidence that @n@ is a successor: the predecessor 'SNat' together with the
-- equation @S predn ~ n@. This is what the view function behind 'SS' returns.
data Predecessor n = forall predn. S predn ~ n => Predecessor (SNat predn)

-- | The predecessor of a non-zero 'SNat', or 'Nothing' for zero.
mkPredecessor :: forall n. SNat n -> Maybe (Predecessor n)
mkPredecessor (MkSNatUnsafe k)
  | k == 0 = Nothing
  | otherwise = Just (assertSucc (MkSNatUnsafe (k - 1)))
  where
    -- Claims @S p ~ m@ without proof; justified by the 'SNat' invariant,
    -- since the wrapped value is exactly one less than the original.
    assertSucc :: forall m p. SNat p -> Predecessor m
    assertSucc prev = case unsafeCoerceRefl @(S p) @m of
      Refl -> Predecessor prev

-- | Convert a GHC type-level 'GHC.Nat' to a type-level Peano natural.
--
-- The family recurses on its GHC argument: @0@ becomes 'Z' and any other
-- literal @k@ becomes @'S' (NatFromGHC (k - 1))@. Because the second equation
-- is only selected once GHC knows the argument is not @0@, applying this to a
-- bare type variable will usually leave it unreduced.
type family NatFromGHC (ghcn :: GHC.Nat) :: Nat where
  NatFromGHC 0 = Z
  NatFromGHC k = S (NatFromGHC (k GHC.- 1))

-- | Convenience function to create an 'SNat'. Use with @-XDataKinds@ and
-- @-XTypeApplications@ like:
--
-- >>> mkSNat @5
-- mkSNat @5
--
-- The 'GHC.KnownNat' constraint is automatically satisfied for any statically
-- known number. To construct an 'SNat' dynamically, you probably want
-- 'mkSNatFromGHC' or 'withSomeSNat', or perhaps iterated 'SS'.
mkSNat :: forall ghcn. GHC.KnownNat ghcn => SNat (NatFromGHC ghcn)
mkSNat = MkSNatUnsafe (GHC.natVal (Proxy @ghcn))

-- | Convert a GHC 'GHC.SNat' to an 'SNat'. The reverse direction is available
-- through the @SNat'@ pattern synonym, or more manually through 'recoverGHC'.
mkSNatFromGHC :: GHC.SNat ghcn -> SNat (NatFromGHC ghcn)
mkSNatFromGHC ghcSn = case ghcSn of
  -- Matching 'GHC.SNat' brings @GHC.KnownNat ghcn@ into scope for 'GHC.natVal'.
  GHC.SNat -> MkSNatUnsafe (GHC.natVal ghcSn)

-- | Dynamically create an 'SNat' from an untyped 'Natural'. The continuation
-- must work for every index @n@, because the index is not known statically.
withSomeSNat :: Natural -> (forall n. SNat n -> r) -> r
withSomeSNat nat cont = cont (MkSNatUnsafe nat)

-- | Dynamically create an 'SNat' from an untyped 'Int'. Calls 'error' when the
-- argument is negative, since naturals cannot represent it.
withSomeSNat' :: Int -> (forall n. SNat n -> r) -> r
withSomeSNat' n k =
  if n >= 0
    then k (MkSNatUnsafe (fromIntegral n))
    else error ("withSomeSNat': " ++ show n ++ " is negative")

-- | Get the untyped 'Natural' corresponding to the 'SNat'.
fromSNat :: SNat n -> Natural
fromSNat sn = case sn of
  MkSNatUnsafe k -> k

-- | Get the value of the 'SNat' as an 'Int'.
--
-- Unsafe! The conversion goes through 'fromIntegral', so a value out of range
-- for @Int@ simply wraps instead of throwing an error.
fromSNat' :: SNat n -> Int
fromSNat' = fromIntegral . fromSNat

-- | Convert a type-level Peano natural to a GHC type-level 'GHC.Nat' by
-- counting its 'S' constructors.
type family GHCFromNat (n :: Nat) :: GHC.Nat where
  GHCFromNat Z = 0
  GHCFromNat (S p) = 1 GHC.+ GHCFromNat p

-- | Convert an 'SNat' back to a GHC 'GHC.SNat'. If you use 'recoverGHC' after
-- 'mkSNatFromGHC', you will end up with a
-- @GHC.SNat (GHCFromNat (NatFromGHC ghcn))@; use 'lemGHCNatGHC' to rewrite
-- that back to @GHC.SNat ghcn@.
recoverGHC :: forall n. SNat n -> GHC.SNat (GHCFromNat n)
recoverGHC (MkSNatUnsafe n) =
  -- 'GHC.withSomeSNat' hands us a GHC singleton with an unknown index @m@; the
  -- 'SNat' invariant (value equals index) is what justifies retyping it to
  -- @GHCFromNat n@.
  GHC.withSomeSNat n $ \(ghcn :: GHC.SNat m) ->
    unsafeCoerce @(GHC.SNat m) @(GHC.SNat (GHCFromNat n)) ghcn

-- | 'GHCFromNat' undoes 'NatFromGHC': converting a GHC natural to Peano and
-- back gives the original. This is asserted with an unsafe coercion rather
-- than proven, since 'NatFromGHC' does not reduce on a type variable.
lemGHCNatGHC :: forall ghcn. GHCFromNat (NatFromGHC ghcn) :~: ghcn
lemGHCNatGHC = unsafeCoerceRefl @(GHCFromNat (NatFromGHC ghcn)) @ghcn

-- | 'NatFromGHC' undoes 'GHCFromNat': converting a Peano natural to GHC and
-- back gives the original. Likewise asserted with an unsafe coercion.
lemNatGHCNat :: forall n. NatFromGHC (GHCFromNat n) :~: n
lemNatGHCNat = unsafeCoerceRefl @(NatFromGHC (GHCFromNat n)) @n

-- | Bidirectional bridge between 'SNat' and GHC's 'GHC.KnownNat'.
--
-- Matching @SNat'@ against an @SNat n@ brings
-- @GHC.KnownNat (GHCFromNat n)@ into scope (via 'recoverGHC'). Using @SNat'@
-- as an expression requires that constraint and builds the 'SNat' from it
-- (via 'mkSNat', rewritten with 'lemNatGHCNat').
pattern SNat' :: forall n. () => GHC.KnownNat (GHCFromNat n) => SNat n
pattern SNat' <- (recoverGHC -> GHC.SNat)
  where SNat' = case lemNatGHCNat @n of Refl -> mkSNat @(GHCFromNat n)
{-# COMPLETE SNat' #-}

-- | Add type-level Peano naturals, by recursion on the left operand.
type family (a :: Nat) + (b :: Nat) :: Nat where
  Z + b = b
  S a + b = S (a + b)

-- | Add 'SNat's. The runtime sum of the underlying 'Natural's corresponds to
-- the type-level addition above.
snatAdd :: SNat n -> SNat m -> SNat (n + m)
snatAdd (MkSNatUnsafe x) (MkSNatUnsafe y) = MkSNatUnsafe (x + y)

-- | Subtract type-level Peano naturals. Does not reduce if the result would
-- be negative (for example, @Z - S Z@ is stuck).
type family (a :: Nat) - (b :: Nat) :: Nat where
  a - Z = a
  S a - S b = a - b

-- | Subtract 'SNat's. Returns 'Nothing' if the result would be negative.
snatSub :: SNat n -> SNat m -> Maybe (SNat (n - m))
snatSub (MkSNatUnsafe x) (MkSNatUnsafe y)
  | x < y = Nothing
  | otherwise = Just (MkSNatUnsafe (x - y))

-- | Multiply type-level Peano naturals, by recursion on the left operand:
-- @S n * m@ reduces to @m + (n * m)@.
--
-- NOTE: this module declares no fixity for its type-level @+@ or @*@, so both
-- get the Haskell default of @infixl 9@. Without the explicit parentheses in
-- the second equation, @m + n * m@ would parse as @(m + n) * m@, which never
-- terminates (e.g. @S Z * S Z@ would reduce to itself). Users mixing these
-- operators should parenthesise as well.
type family n * m where
  Z * m = Z
  S n * m = m + (n * m)

-- | Multiply 'SNat's. The runtime product of the underlying 'Natural's
-- corresponds to the type-level multiplication above.
snatMul :: SNat n -> SNat m -> SNat (n * m)
snatMul (MkSNatUnsafe n) (MkSNatUnsafe m) = MkSNatUnsafe (n * m)

-- | Compare two type-level Peano naturals, yielding a promoted 'Ordering'.
type family Compare (n :: Nat) (m :: Nat) :: Ordering where
  Compare Z Z = EQ
  Compare Z (S _) = LT
  Compare (S _) Z = GT
  Compare (S a) (S b) = Compare a b

-- | Singleton result of comparing two 'SNat's (see 'snatCompare'). Matching on
-- a constructor brings the corresponding 'Compare' equation into scope;
-- 'SEQ' additionally identifies the two indices.
data SOrdering n m where
  SLT :: Compare n m ~ LT => SOrdering n m
  SEQ :: Compare n n ~ EQ => SOrdering n n
  SGT :: Compare n m ~ GT => SOrdering n m

-- | Compare two 'SNat's at runtime and reflect the outcome at the type level.
-- The type equalities are asserted with @unsafeCoerceRefl@, relying on the
-- 'SNat' invariant that runtime values match their type indices.
snatCompare :: forall n m. SNat n -> SNat m -> SOrdering n m
snatCompare (MkSNatUnsafe x) (MkSNatUnsafe y) =
  case compare x y of
    LT -> case unsafeCoerceRefl @(Compare n m) @LT of
      Refl -> SLT
    EQ -> case unsafeCoerceRefl @n @m of
      Refl -> case unsafeCoerceRefl @(Compare n n) @EQ of
        Refl -> SEQ
    GT -> case unsafeCoerceRefl @(Compare n m) @GT of
      Refl -> SGT

-- | Select one of three branches according to a promoted 'Ordering'.
type family OrdCond (o :: Ordering) (lt :: k) (eq :: k) (gt :: k) :: k where
  OrdCond LT lt _ _ = lt
  OrdCond EQ _ eq _ = eq
  OrdCond GT _ _ gt = gt

-- | @n <= m@ holds when 'Compare' yields 'LT' or 'EQ'.
type n <= m = OrdCond (Compare n m) 'True 'True 'False ~ 'True

-- | @n < m@ holds when 'Compare' yields 'LT'.
type n < m = Compare n m ~ 'LT

-- | @n >= m@ holds when 'Compare' yields 'EQ' or 'GT'.
type n >= m = OrdCond (Compare n m) 'False 'True 'True ~ 'True

-- | @n > m@ holds when 'Compare' yields 'GT'.
type n > m = Compare n m ~ 'GT

-- | Pass an 'SNat' implicitly, in a constraint. The instances build the
-- singleton by structural recursion on the index, so every concrete 'Nat'
-- has an instance.
class KnownNat n where
  knownNat :: SNat n

instance KnownNat Z where
  knownNat = SZ

instance KnownNat n => KnownNat (S n) where
  knownNat = SS knownNat

-- | Produce an equality proof between arbitrary types without checking it.
-- Every use in this module relies either on the 'SNat' invariant (the wrapped
-- 'Natural' equals the type index) or on a fact about this module's type
-- families; using it for an equation that does not hold breaks type safety.
unsafeCoerceRefl :: forall a b. a :~: b
unsafeCoerceRefl = unsafeCoerce (Refl :: a :~: a)