aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-02-11 00:11:53 +0100
committerTom Smeding <tom@tomsmeding.com>2025-02-15 11:06:40 +0100
commite6c20868375d2b7f6b31808844e1b48f78bca069 (patch)
tree5e3c3efa5c61eb11a28b486bccbbcac823a36614
parentc705bb4cf76d2e80f3e9ed900f901b697b378f79 (diff)
WIP half-peano SNatspeano-snat
-rw-r--r--ox-arrays.cabal1
-rw-r--r--src/Data/Array/Mixed/Internal/Arith.hs129
-rw-r--r--src/Data/Array/Mixed/Permutation.hs88
-rw-r--r--src/Data/Array/Mixed/Shape.hs41
-rw-r--r--src/Data/Array/Mixed/Types.hs67
-rw-r--r--src/Data/SNat/Peano.hs232
6 files changed, 358 insertions, 200 deletions
diff --git a/ox-arrays.cabal b/ox-arrays.cabal
index 1253956..376107b 100644
--- a/ox-arrays.cabal
+++ b/ox-arrays.cabal
@@ -35,6 +35,7 @@ library
Data.Array.Nested.Internal.Ranked
Data.Array.Nested.Internal.Shape
Data.Array.Nested.Internal.Shaped
+ Data.SNat.Peano
if flag(trace-wrappers)
exposed-modules:
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index a24efd6..dd4d8e7 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -24,14 +24,13 @@ import Foreign.C.Types
import Foreign.Marshal.Alloc (alloca)
import Foreign.Ptr
import Foreign.Storable (Storable(sizeOf), peek, poke)
-import GHC.TypeLits
import GHC.TypeNats qualified as TypeNats
import Language.Haskell.TH
import System.IO.Unsafe
+import Data.SNat.Peano
import Data.Array.Mixed.Internal.Arith.Foreign
import Data.Array.Mixed.Internal.Arith.Lists
-import Data.Array.Mixed.Types (fromSNat')
-- TODO: need to sort strides for reduction-like functions so that the C inner-loop specialisation has some chance of working even after transposition
@@ -41,8 +40,8 @@ import Data.Array.Mixed.Types (fromSNat')
liftVEltwise1 :: (Storable a, Storable b)
=> SNat n
-> (VS.Vector a -> VS.Vector b)
- -> RS.Array n a -> RS.Array n b
-liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
+ -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) b
+liftVEltwise1 SNat' f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
| Just (blockOff, blockSz) <- stridesDense sh offset strides =
let vec' = f (VS.slice blockOff blockSz vec)
in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec'))
@@ -52,8 +51,8 @@ liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
liftVEltwise2 :: (Storable a, Storable b, Storable c)
=> SNat n
-> (Either a (VS.Vector a) -> Either b (VS.Vector b) -> VS.Vector c)
- -> RS.Array n a -> RS.Array n b -> RS.Array n c
-liftVEltwise2 SNat f
+ -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) b -> RS.Array (GHCFromNat n) c
+liftVEltwise2 SNat' f
arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1)))
arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2)))
| sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
@@ -172,8 +171,8 @@ vectorRedInnerOp :: forall a b n. (Num a, Storable a)
-> (Ptr a -> Ptr b)
-> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- ^ scale by constant
-> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
- -> RS.Array (n + 1) a -> RS.Array n a
-vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec)))
+ -> RS.Array (GHCFromNat (S n)) a -> RS.Array (GHCFromNat n) a
+vectorRedInnerOp sn@SNat' valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec)))
| null sh = error "unreachable"
| last sh <= 0 = RS.stretch (init sh) (RS.fromList (1 <$ init sh) [0])
| any (<= 0) (init sh) = RS.A (RG.A (init sh) (OI.T (0 <$ init strides) 0 VS.empty))
@@ -210,7 +209,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride
VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR ->
VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR ->
fred (fromIntegral ndimsF) (ptrconv poutvR) pshF pstridesR (ptrconv pvecR)
- TypeNats.withSomeSNat (fromIntegral (ndimsF - 1)) $ \(SNat :: SNat lenFm1) ->
+ TypeNats.withSomeSNat (fromIntegral (ndimsF - 1)) $ \(TypeNats.SNat :: TypeNats.SNat lenFm1) ->
RS.stretch (init sh) -- replicate to original shape
. RS.reshape (init shOnes) -- add 1-sized dimensions where the original was replicated
. RS.rev (map fst (filter snd (zip [0..] revDims))) -- re-reverse the correct dimensions
@@ -226,7 +225,7 @@ vectorRedFullOp :: forall a b n. (Num a, Storable a)
-> (b -> a)
-> (Ptr a -> Ptr b)
-> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b) -- ^ reduction kernel
- -> RS.Array n a -> a
+ -> RS.Array (GHCFromNat n) a -> a
vectorRedFullOp _ scaleval valbackconv ptrconv fred (RS.A (RG.A sh (OI.T strides offset vec)))
| null sh = vec VS.! offset -- 0D array has one element
| any (<= 0) sh = 0
@@ -309,12 +308,12 @@ vectorDotprodInnerOp :: forall a b n. (Num a, Storable a)
=> SNat n
-> (a -> b)
-> (Ptr a -> Ptr b)
- -> (SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a) -- ^ elementwise multiplication
+ -> (SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a) -- ^ elementwise multiplication
-> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- ^ scale by constant
-> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
-> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ dotprod kernel
- -> RS.Array (n + 1) a -> RS.Array (n + 1) a -> RS.Array n a
-vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner
+ -> RS.Array (GHCFromNat (S n)) a -> RS.Array (GHCFromNat (S n)) a -> RS.Array (GHCFromNat n) a
+vectorDotprodInnerOp sn@SNat' valconv ptrconv fmul fscale fred fdotinner
arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1)))
arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2)))
| null sh1 || null sh2 = error "unreachable"
@@ -344,7 +343,7 @@ vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner
fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv)
pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1))
pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2))
- RS.fromVector @_ @n (init sh1) <$> VS.unsafeFreeze outv
+ RS.fromVector @_ @(GHCFromNat n) (init sh1) <$> VS.unsafeFreeze outv
{-# NOINLINE dotScalarVector #-}
dotScalarVector :: forall a b. (Num a, Storable a)
@@ -398,7 +397,10 @@ $(fmap concat . forM typesList $ \arithtype -> do
c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
sequence [SigD name <$>
- [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
+ [t| forall n. SNat n
+ -> RS.Array (GHCFromNat n) $ttyp
+ -> RS.Array (GHCFromNat n) $ttyp
+ -> RS.Array (GHCFromNat n) $ttyp |]
,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
return $ FunD name [Clause [] (NormalB body) []]])
@@ -412,7 +414,10 @@ $(fmap concat . forM floatTypesList $ \arithtype -> do
c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
sequence [SigD name <$>
- [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
+ [t| forall n. SNat n
+ -> RS.Array (GHCFromNat n) $ttyp
+ -> RS.Array (GHCFromNat n) $ttyp
+ -> RS.Array (GHCFromNat n) $ttyp |]
,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
return $ FunD name [Clause [] (NormalB body) []]])
@@ -422,7 +427,7 @@ $(fmap concat . forM typesList $ \arithtype -> do
let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype))
c_op = varE (mkName ("c_unary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (auoEnum arithop)))
sequence [SigD name <$>
- [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |]
+ [t| forall n. SNat n -> RS.Array (GHCFromNat n) $ttyp -> RS.Array (GHCFromNat n) $ttyp |]
,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |]
return $ FunD name [Clause [] (NormalB body) []]])
@@ -432,7 +437,7 @@ $(fmap concat . forM floatTypesList $ \arithtype -> do
let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype))
c_op = varE (mkName ("c_funary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (afuoEnum arithop)))
sequence [SigD name <$>
- [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |]
+ [t| forall n. SNat n -> RS.Array (GHCFromNat n) $ttyp -> RS.Array (GHCFromNat n) $ttyp |]
,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |]
return $ FunD name [Clause [] (NormalB body) []]])
@@ -451,11 +456,11 @@ $(fmap concat . forM typesList $ \arithtype -> do
c_opfull = varE (mkName ("c_reducefull_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop)))
c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))
sequence [SigD name1 <$>
- [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |]
+ [t| forall n. SNat n -> RS.Array (GHCFromNat (S n)) $ttyp -> RS.Array (GHCFromNat n) $ttyp |]
,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op1 |]
return $ FunD name1 [Clause [] (NormalB body) []]
,SigD namefull <$>
- [t| forall n. SNat n -> RS.Array n $ttyp -> $ttyp |]
+ [t| forall n. SNat n -> RS.Array (GHCFromNat n) $ttyp -> $ttyp |]
,do body <- [| \sn -> vectorRedFullOp sn $scaleVar id id $c_opfull |]
return $ FunD namefull [Clause [] (NormalB body) []]
])
@@ -478,7 +483,7 @@ $(fmap concat . forM typesList $ \arithtype -> do
c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))
c_red_op = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum RO_SUM)))
sequence [SigD name <$>
- [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |]
+ [t| forall n. SNat n -> RS.Array (GHCFromNat (S n)) $ttyp -> RS.Array (GHCFromNat (S n)) $ttyp -> RS.Array (GHCFromNat n) $ttyp |]
,do body <- [| \sn -> vectorDotprodInnerOp sn id id $mul_op $c_scale_op $c_red_op $c_op |]
return $ FunD name [Clause [] (NormalB body) []]])
@@ -487,7 +492,7 @@ $(fmap concat . forM typesList $ \arithtype -> do
intWidBranch1 :: forall i n. (FiniteBits i, Storable i)
=> (Int64 -> Ptr Int32 -> Ptr Int32 -> IO ())
-> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())
- -> (SNat n -> RS.Array n i -> RS.Array n i)
+ -> (SNat n -> RS.Array (GHCFromNat n) i -> RS.Array (GHCFromNat n) i)
intWidBranch1 f32 f64 sn
| finiteBitSize (undefined :: i) == 32 = liftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32)
| finiteBitSize (undefined :: i) == 64 = liftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64)
@@ -503,7 +508,7 @@ intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i)
-> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- sv
-> (Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> IO ()) -- vs
-> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- vv
- -> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i)
+ -> (SNat n -> RS.Array (GHCFromNat n) i -> RS.Array (GHCFromNat n) i -> RS.Array (GHCFromNat n) i)
intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn
| finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32)
| finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64)
@@ -516,7 +521,7 @@ intWidBranchRed1 :: forall i n. (FiniteBits i, Storable i, Integral i)
-- int64
-> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- ^ scale by constant
-> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel
- -> (SNat n -> RS.Array (n + 1) i -> RS.Array n i)
+ -> (SNat n -> RS.Array (GHCFromNat (S n)) i -> RS.Array (GHCFromNat n) i)
intWidBranchRed1 fsc32 fred32 fsc64 fred64 sn
| finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32
| finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64
@@ -528,7 +533,7 @@ intWidBranchRedFull :: forall i n. (FiniteBits i, Storable i, Integral i)
-> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO Int32) -- ^ reduction kernel
-- int64
-> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO Int64) -- ^ reduction kernel
- -> (SNat n -> RS.Array n i -> i)
+ -> (SNat n -> RS.Array (GHCFromNat n) i -> i)
intWidBranchRedFull fsc fred32 fred64 sn
| finiteBitSize (undefined :: i) == 32 = vectorRedFullOp @i @Int32 sn fsc fromIntegral castPtr fred32
| finiteBitSize (undefined :: i) == 64 = vectorRedFullOp @i @Int64 sn fsc fromIntegral castPtr fred64
@@ -554,26 +559,26 @@ intWidBranchDotprod :: forall i n. (FiniteBits i, Storable i, Integral i, NumElt
-> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- ^ scale by constant
-> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel
-> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ dotprod kernel
- -> (SNat n -> RS.Array (n + 1) i -> RS.Array (n + 1) i -> RS.Array n i)
+ -> (SNat n -> RS.Array (GHCFromNat (S n)) i -> RS.Array (GHCFromNat (S n)) i -> RS.Array (GHCFromNat n) i)
intWidBranchDotprod fsc32 fred32 fdot32 fsc64 fred64 fdot64 sn
| finiteBitSize (undefined :: i) == 32 = vectorDotprodInnerOp @i @Int32 sn fromIntegral castPtr numEltMul fsc32 fred32 fdot32
| finiteBitSize (undefined :: i) == 64 = vectorDotprodInnerOp @i @Int64 sn fromIntegral castPtr numEltMul fsc64 fred64 fdot64
| otherwise = error "Unsupported Int width"
class NumElt a where
- numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a
- numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a
- numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a
- numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
- numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
- numEltSumFull :: SNat n -> RS.Array n a -> a
- numEltProductFull :: SNat n -> RS.Array n a -> a
- numEltMinIndex :: SNat n -> RS.Array n a -> [Int]
- numEltMaxIndex :: SNat n -> RS.Array n a -> [Int]
- numEltDotprodInner :: SNat n -> RS.Array (n + 1) a -> RS.Array (n + 1) a -> RS.Array n a
+ numEltAdd :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a
+ numEltSub :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a
+ numEltMul :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a
+ numEltNeg :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a
+ numEltAbs :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a
+ numEltSignum :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a
+ numEltSum1Inner :: SNat n -> RS.Array (GHCFromNat (S n)) a -> RS.Array (GHCFromNat n) a
+ numEltProduct1Inner :: SNat n -> RS.Array (GHCFromNat (S n)) a -> RS.Array (GHCFromNat n) a
+ numEltSumFull :: SNat n -> RS.Array (GHCFromNat n) a -> a
+ numEltProductFull :: SNat n -> RS.Array (GHCFromNat n) a -> a
+ numEltMinIndex :: SNat n -> RS.Array (GHCFromNat n) a -> [Int]
+ numEltMaxIndex :: SNat n -> RS.Array (GHCFromNat n) a -> [Int]
+ numEltDotprodInner :: SNat n -> RS.Array (GHCFromNat (S n)) a -> RS.Array (GHCFromNat (S n)) a -> RS.Array (GHCFromNat n) a
instance NumElt Int32 where
numEltAdd = addVectorInt32
@@ -688,29 +693,29 @@ instance NumElt CInt where
(c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64
class FloatElt a where
- floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- floatEltPow :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- floatEltLogbase :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- floatEltRecip :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltExp :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltLog :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltSqrt :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltSin :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltCos :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltTan :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltAsin :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltAcos :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltAtan :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltSinh :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltCosh :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltTanh :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltAsinh :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltAcosh :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltAtanh :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltLog1p :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltExpm1 :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltLog1pexp :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltLog1mexp :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltDiv :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a
+ floatEltPow :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a
+ floatEltLogbase :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a
+ floatEltRecip :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a
+ floatEltExp :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a
+ floatEltLog :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a
+ floatEltSqrt :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a
+ floatEltSin :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a
+ floatEltCos :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a
+ floatEltTan :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a
+ floatEltAsin :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a
+ floatEltAcos :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a
+ floatEltAtan :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a
+ floatEltSinh :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a
+ floatEltCosh :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a
+ floatEltTanh :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a
+ floatEltAsinh :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a
+ floatEltAcosh :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a
+ floatEltAtanh :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a
+ floatEltLog1p :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a
+ floatEltExpm1 :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a
+ floatEltLog1pexp :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a
+ floatEltLog1mexp :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a
instance FloatElt Float where
floatEltDiv = divVectorFloat
diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs
index 331d5e0..e85e67f 100644
--- a/src/Data/Array/Mixed/Permutation.hs
+++ b/src/Data/Array/Mixed/Permutation.hs
@@ -13,8 +13,6 @@
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed.Permutation where
import Data.Coerce (coerce)
@@ -24,13 +22,12 @@ import Data.Maybe (fromMaybe)
import Data.Proxy
import Data.Type.Bool
import Data.Type.Equality
-import Data.Type.Ord
import GHC.TypeError
-import GHC.TypeLits
-import GHC.TypeNats qualified as TN
+import Numeric.Natural
import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
+import Data.SNat.Peano
-- * Permutations
@@ -46,18 +43,19 @@ deriving instance Show (Perm list)
deriving instance Eq (Perm list)
permRank :: Perm list -> SNat (Rank list)
-permRank PNil = SNat
-permRank (_ `PCons` l) | SNat <- permRank l = SNat
+permRank PNil = SZ
+permRank (_ `PCons` l) = SS (permRank l)
permFromList :: [Int] -> (forall list. Perm list -> r) -> r
permFromList [] k = k PNil
-permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case
- Just sn -> permFromList xs $ \list -> k (sn `PCons` list)
- Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x
+permFromList (x : xs) k =
+ withSomeSNat' x $ \sn ->
+ permFromList xs $ \list ->
+ k (sn `PCons` list)
permToList :: Perm list -> [Natural]
permToList PNil = mempty
-permToList (x `PCons` l) = TN.fromSNat x : permToList l
+permToList (x `PCons` l) = fromSNat x : permToList l
permToList' :: Perm list -> [Int]
permToList' = map fromIntegral . permToList
@@ -68,48 +66,47 @@ permToList' = map fromIntegral . permToList
permCheckPermutation :: forall r list. Perm list -> (IsPermutation list => r) -> Maybe r
permCheckPermutation = \p k ->
let n = permRank p
- in case (provePerm1 (Proxy @list) n p, provePerm2 (SNat @0) n p) of
+ in case (provePerm1 (Proxy @list) n p, provePerm2 SZ n p) of
(Just Refl, Just Refl) -> Just k
_ -> Nothing
where
- lemElemCount :: (0 <= n, Compare n m ~ LT)
- => proxy n -> proxy m -> Elem n (Count 0 m) :~: True
+ lemElemCount :: (Z <= n, n < m)
+ => proxy n -> proxy m -> Elem n (Count Z m) :~: True
lemElemCount _ _ = unsafeCoerceRefl
- lemCount :: (OrdCond (Compare i n) True False True ~ True)
- => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n
+ lemCount :: i < n => proxy i -> proxy n -> Count i n :~: i : Count (S i) n
lemCount _ _ = unsafeCoerceRefl
lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True
lemElem _ _ = unsafeCoerceRefl
provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is'
- -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True)
+ -> Maybe (AllElem' is' (Count Z (Rank isTop)) :~: True)
provePerm1 _ _ PNil = Just (Refl)
- provePerm1 p rtop@SNat (PCons sn@SNat perm)
+ provePerm1 p rtop (PCons sn perm)
| Just Refl <- provePerm1 p rtop perm
- = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of
- (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
- (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
+ = case (snatCompare SZ sn, snatCompare sn rtop) of
+ (SLT, SLT) | Refl <- lemElemCount sn rtop -> Just Refl
+ (SEQ, SLT) | Refl <- lemElemCount sn rtop -> Just Refl
_ -> Nothing
| otherwise
= Nothing
provePerm2 :: SNat i -> SNat n -> Perm is'
-> Maybe (AllElem' (Count i n) is' :~: True)
- provePerm2 = \i@(SNat :: SNat i) n@SNat perm ->
- case cmpNat i n of
- EQI -> Just Refl
- LTI | Refl <- lemCount i n
- , Just Refl <- provePerm2 (SNat @(i + 1)) n perm
+ provePerm2 = \i n perm ->
+ case snatCompare i n of
+ SEQ -> Just Refl
+ SLT | Refl <- lemCount i n
+ , Just Refl <- provePerm2 (SS i) n perm
-> checkElem i perm
| otherwise -> Nothing
- GTI -> error "unreachable"
+ SGT -> error "unreachable"
where
checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True)
checkElem _ PNil = Nothing
- checkElem i@SNat (PCons k@SNat perm :: Perm is') =
- case sameNat i k of
+ checkElem i (PCons k perm :: Perm is') =
+ case testEquality i k of
Just Refl -> Just Refl
Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl
| otherwise -> Nothing
@@ -117,7 +114,7 @@ permCheckPermutation = \p k ->
-- | Utility class for generating permutations from type class information.
class KnownPerm l where makePerm :: Perm l
instance KnownPerm '[] where makePerm = PNil
-instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = natSing `PCons` makePerm
+instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = knownNat `PCons` makePerm
-- | Untyped permutations for ranked arrays
type PermR = [Int]
@@ -139,13 +136,13 @@ type AllElem as bs = Assert (AllElem' as bs)
type family Count i n where
Count n n = '[]
- Count i n = i : Count (i + 1) n
+ Count i n = i : Count (S i) n
-type IsPermutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as)
+type IsPermutation as = (AllElem as (Count Z (Rank as)), AllElem (Count Z (Rank as)) as)
type family Index i sh where
- Index 0 (n : sh) = n
- Index i (_ : sh) = Index (i - 1) sh
+ Index Z (n : sh) = n
+ Index (S i) (_ : sh) = Index i sh
type family Permute is sh where
Permute '[] sh = '[]
@@ -178,9 +175,7 @@ listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) =
listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> f (Index i sh)
listxIndex _ _ SZ (n ::% _) = n
-listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f))
- | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
- = listxIndex p pT i sh
+listxIndex p pT (SS i) (_ ::% sh) = listxIndex p pT i sh
listxIndex _ _ _ ZX = error "Index into empty shape"
listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f
@@ -199,7 +194,7 @@ ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh)
ssxPermute = coerce (listxPermute @(SMayNat () SNat))
ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh)
-ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2)
+ssxIndex p1 p2 i = coerce (listxIndex @(SMayNat () SNat) p1 p2 i)
ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh)
ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat))
@@ -236,23 +231,23 @@ permInverse = \perm k ->
where
toHList :: [Natural] -> (forall is'. Perm is' -> r) -> r
toHList [] k = k PNil
- toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l)
+ toHList (n : ns) k = toHList ns $ \l -> withSomeSNat n $ \sn -> k (PCons sn l)
provePermInverse :: Perm is -> Perm is' -> StaticShX sh
-> Maybe (Permute is' (Permute is sh) :~: sh)
provePermInverse perm perminv ssh =
- ssxGeq (ssxPermute perminv (ssxPermute perm ssh)) ssh
+ testEquality (ssxPermute perminv (ssxPermute perm ssh)) ssh
type family MapSucc is where
MapSucc '[] = '[]
- MapSucc (i : is) = i + 1 : MapSucc is
+ MapSucc (i : is) = S i : MapSucc is
-permShift1 :: Perm l -> Perm (0 : MapSucc l)
-permShift1 = (SNat @0 `PCons`) . permMapSucc
+permShift1 :: Perm l -> Perm (Z : MapSucc l)
+permShift1 = (SZ `PCons`) . permMapSucc
where
permMapSucc :: Perm l -> Perm (MapSucc l)
permMapSucc PNil = PNil
- permMapSucc ((SNat :: SNat i) `PCons` ns) = SNat @(i + 1) `PCons` permMapSucc ns
+ permMapSucc (i `PCons` ns) = SS i `PCons` permMapSucc ns
-- * Lemmas
@@ -266,8 +261,3 @@ lemRankDropLen :: forall is sh. (Rank is <= Rank sh)
lemRankDropLen ZKX PNil = Refl
lemRankDropLen (_ :!% sh) (_ `PCons` is) | Refl <- lemRankDropLen sh is = Refl
lemRankDropLen (_ :!% _) PNil = Refl
-lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0"
-
-lemIndexSucc :: Proxy i -> Proxy a -> Proxy l
- -> Index (i + 1) (a : l) :~: Index i l
-lemIndexSucc _ _ _ = unsafeCoerceRefl
diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs
index e5f8b67..bed812d 100644
--- a/src/Data/Array/Mixed/Shape.hs
+++ b/src/Data/Array/Mixed/Shape.hs
@@ -18,8 +18,6 @@
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed.Shape where
import Control.DeepSeq (NFData(..))
@@ -35,16 +33,16 @@ import GHC.Exts (withDict)
import GHC.Generics (Generic)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
-import GHC.TypeLits
import Data.Array.Mixed.Types
+import Data.SNat.Peano
-- | The length of a type-level list. If the argument is a shape, then the
-- result is the rank of that shape.
type family Rank sh where
- Rank '[] = 0
- Rank (_ : sh) = Rank sh + 1
+ Rank '[] = Z
+ Rank (_ : sh) = S (Rank sh)
-- * Mixed lists
@@ -91,8 +89,8 @@ listxLength :: ListX sh f -> Int
listxLength = getSum . listxFold (\_ -> Sum 1)
listxRank :: ListX sh f -> SNat (Rank sh)
-listxRank ZX = SNat
-listxRank (_ ::% l) | SNat <- listxRank l = SNat
+listxRank ZX = SZ
+listxRank (_ ::% l) = SS (listxRank l)
listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS
listxShow f l = showString "[" . go "" l . showString "]"
@@ -255,7 +253,7 @@ type family AddMaybe n m where
smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m)
smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m)
smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m)
-smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m)
+smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatAdd n m)
-- | This is a newtype over 'ListX'.
@@ -288,7 +286,7 @@ instance Functor (ShX sh) where
instance NFData i => NFData (ShX sh i) where
rnf (ShX ZX) = ()
rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l)
- rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l)
+ rnf (ShX (SKnown n ::% l)) = rnf n `seq` rnf (ShX l)
shxLength :: ShX sh i -> Int
shxLength (ShX l) = listxLength l
@@ -300,8 +298,8 @@ shxRank (ShX list) = listxRank list
-- dimensions) are the same.
shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh')
shxEqual ZSX ZSX = Just Refl
-shxEqual (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh')
- | Just Refl <- sameNat n m
+shxEqual (SKnown n :$% sh) (SKnown m :$% sh')
+ | Just Refl <- testEquality n m
, Just Refl <- shxEqual sh sh'
= Just Refl
shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh')
@@ -422,19 +420,6 @@ instance TestEquality StaticShX where
ssxLength :: StaticShX sh -> Int
ssxLength (StaticShX l) = listxLength l
--- | This suffices as an implementation of @geq@ in the @Data.GADT.Compare@
--- class of the @some@ package.
-ssxGeq :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh')
-ssxGeq ZKX ZKX = Just Refl
-ssxGeq (SKnown n@SNat :!% sh) (SKnown m@SNat :!% sh')
- | Just Refl <- sameNat n m
- , Just Refl <- ssxGeq sh sh'
- = Just Refl
-ssxGeq (SUnknown () :!% sh) (SUnknown () :!% sh')
- | Just Refl <- ssxGeq sh sh'
- = Just Refl
-ssxGeq _ _ = Nothing
-
ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
ssxAppend ZKX sh' = sh'
ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh'
@@ -481,7 +466,7 @@ ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1
type KnownShX :: [Maybe Nat] -> Constraint
class KnownShX sh where knownShX :: StaticShX sh
instance KnownShX '[] where knownShX = ZKX
-instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX
+instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown knownNat :!% knownShX
instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX
withKnownShX :: forall sh r. StaticShX sh -> (KnownShX sh => r) -> r
@@ -490,7 +475,7 @@ withKnownShX sh = withDict @(KnownShX sh) sh
-- * Flattening
-type Flatten sh = Flatten' 1 sh
+type Flatten sh = Flatten' (S Z) sh
type family Flatten' acc sh where
Flatten' acc '[] = Just acc
@@ -499,7 +484,7 @@ type family Flatten' acc sh where
-- This function is currently unused
ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh)
-ssxFlatten = go (SNat @1)
+ssxFlatten = go (mkSNat @1)
where
go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh)
go acc ZKX = SKnown acc
@@ -507,7 +492,7 @@ ssxFlatten = go (SNat @1)
go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh
shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh)
-shxFlatten = go (SNat @1)
+shxFlatten = go (mkSNat @1)
where
go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh)
go acc ZSX = SKnown acc
diff --git a/src/Data/Array/Mixed/Types.hs b/src/Data/Array/Mixed/Types.hs
index 13675d0..319246b 100644
--- a/src/Data/Array/Mixed/Types.hs
+++ b/src/Data/Array/Mixed/Types.hs
@@ -1,27 +1,14 @@
+{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
-{-# LANGUAGE NoStarIsType #-}
-{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
-{-# LANGUAGE ViewPatterns #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed.Types (
-- * Reified evidence of a type class
Dict(..),
- -- * Type-level naturals
- pattern SZ, pattern SS,
- fromSNat', sameNat',
- snatPlus, snatMinus, snatMul,
- snatSucc,
-
-- * Type-level lists
type (++),
Replicate,
@@ -36,67 +23,25 @@ module Data.Array.Mixed.Types (
) where
import Data.Type.Equality
-import Data.Proxy
-import GHC.TypeLits
-import GHC.TypeNats qualified as TN
import Unsafe.Coerce qualified
+import Data.SNat.Peano
+
-- | Evidence for the constraint @c a@.
data Dict c a where
Dict :: c a => Dict c a
-fromSNat' :: SNat n -> Int
-fromSNat' = fromIntegral . fromSNat
-
-sameNat' :: SNat n -> SNat m -> Maybe (n :~: m)
-sameNat' n@SNat m@SNat = sameNat n m
-
-pattern SZ :: () => (n ~ 0) => SNat n
-pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl)
- where SZ = SNat
-
-pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1
-pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl))
- where SS = snatSucc
-
-{-# COMPLETE SZ, SS #-}
-
-snatSucc :: SNat n -> SNat (n + 1)
-snatSucc SNat = SNat
-
-data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1)
-snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1)
-snatPred snp1 =
- withKnownNat snp1 $
- case cmpNat (Proxy @1) (Proxy @np1) of
- LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
- EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
- GTI -> Nothing
-
--- This should be a function in base
-snatPlus :: SNat n -> SNat m -> SNat (n + m)
-snatPlus n m = TN.withSomeSNat (TN.fromSNat n + TN.fromSNat m) Unsafe.Coerce.unsafeCoerce
-
--- This should be a function in base
-snatMinus :: SNat n -> SNat m -> SNat (n - m)
-snatMinus n m = let res = TN.fromSNat n - TN.fromSNat m in res `seq` TN.withSomeSNat res Unsafe.Coerce.unsafeCoerce
-
--- This should be a function in base
-snatMul :: SNat n -> SNat m -> SNat (n * m)
-snatMul n m = TN.withSomeSNat (TN.fromSNat n * TN.fromSNat m) Unsafe.Coerce.unsafeCoerce
-
-
-- | Type-level list append.
type family l1 ++ l2 where
'[] ++ l2 = l2
(x : xs) ++ l2 = x : xs ++ l2
type family Replicate n a where
- Replicate 0 a = '[]
- Replicate n a = a : Replicate (n - 1) a
+ Replicate Z a = '[]
+ Replicate (S n) a = a : Replicate n a
-lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
+lemReplicateSucc :: (a : Replicate n a) :~: Replicate (S n) a
lemReplicateSucc = unsafeCoerceRefl
type family MapJust l where
diff --git a/src/Data/SNat/Peano.hs b/src/Data/SNat/Peano.hs
new file mode 100644
index 0000000..a5109fa
--- /dev/null
+++ b/src/Data/SNat/Peano.hs
@@ -0,0 +1,232 @@
+{-# LANGUAGE AllowAmbiguousTypes #-}
+{-# LANGUAGE ConstraintKinds #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE NoStarIsType #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeAbstractions #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeData #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UnboxedTuples #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+module Data.SNat.Peano (
+ -- * Singleton naturals
+ Nat(..),
+ SNat(SZ, SS),
+ mkSNat, mkSNatFromGHC, NatFromGHC,
+ withSomeSNat, withSomeSNat',
+ fromSNat, fromSNat',
+
+ -- * Computing with 'SNat' values
+ type (+), snatAdd,
+ type (-), snatSub,
+ type (*), snatMul,
+ Compare, SOrdering(..), snatCompare, OrdCond, type (<=), type (<), type (>=), type (>),
+
+ -- * 'KnownNat'
+ KnownNat(..),
+
+ -- * Interoperate with GHC naturals
+ pattern SNat', recoverGHC, GHCFromNat, lemGHCNatGHC, lemNatGHCNat,
+) where
+
+import Control.DeepSeq
+import Data.Proxy
+import Data.Type.Equality
+import Numeric.Natural
+import qualified GHC.TypeNats as GHC
+import Unsafe.Coerce
+
+
+-- | Type-level Peano naturals.
+type data Nat = Z | S Nat
+
+-- | A singleton for 'Nat'. The representation is just a 'Natural', not an
+-- actual Peano natural / linked list.
+newtype SNat (n :: Nat) = MkSNatUnsafe Natural
+
+instance Show (SNat n) where
+ showsPrec d (MkSNatUnsafe n) = showParen (d > 10) $
+ showString ("mkSNat @" ++ show n)
+
+-- these are vacuous because it's a singleton
+instance Eq (SNat n) where _ == _ = True
+instance Ord (SNat n) where compare _ _ = EQ
+
+instance NFData (SNat n) where
+ rnf (MkSNatUnsafe n) = rnf n
+
+instance TestEquality SNat where
+ testEquality (MkSNatUnsafe n) (MkSNatUnsafe m)
+ | n == m = Just unsafeCoerceRefl
+ | otherwise = Nothing
+
+-- | The zero natural. Read this as: 'SZ :: SNat Z'
+pattern SZ :: forall n. () => n ~ Z => SNat n
+pattern SZ <- ((\(MkSNatUnsafe k) -> (# k, unsafeCoerceRefl @n @Z #))
+ -> (# 0, Refl #))
+ where SZ = MkSNatUnsafe 0
+
+-- | /n/ plus one. Read this as: 'SS :: SNat n -> SNat (S n)'
+pattern SS :: forall n. () => forall predn. S predn ~ n => SNat predn -> SNat n
+pattern SS n <- (mkPredecessor -> Just (Predecessor n))
+ where SS (MkSNatUnsafe n) = MkSNatUnsafe (n + 1)
+-- A little experiment showed that mkPredecessor inlines sufficiently that no
+-- Predecessor object every gets allocated. Let's hope this is also true in
+-- more practical situations.
+
+{-# COMPLETE SZ, SS #-}
+
+data Predecessor n = forall predn. S predn ~ n => Predecessor (SNat predn)
+
+mkPredecessor :: forall n. SNat n -> Maybe (Predecessor n)
+mkPredecessor (MkSNatUnsafe 0) = Nothing
+mkPredecessor (MkSNatUnsafe k) = Just (yolo (MkSNatUnsafe (k-1)))
+ where
+ yolo :: forall m predn. SNat predn -> Predecessor m
+ yolo n | Refl <- unsafeCoerceRefl @(S predn) @m = Predecessor n
+
+-- | Convert a GHC type-level 'GHC.Nat' to a type-level Peano natural. Because
+-- this type family performs induction on a GHC 'GHC.Nat', which only works
+-- sensibly if the 'GHC.Nat' is monomorphic, using this on a bare type variable
+-- will probably be unsuccessful.
+type family NatFromGHC ghcn where
+ NatFromGHC 0 = Z
+ NatFromGHC n = S (NatFromGHC (n GHC.- 1))
+
+-- | Convenience function to create an 'SNat'. Use with @-XDataKinds@ and
+-- @-XTypeApplications@ like:
+--
+-- >>> mkSNat @5
+--
+-- The 'GHC.KnownNat' constraint is automatically satisfied for any statically
+-- known number. To construct an 'SNat' dynamically, you probably want
+-- 'mkSNatFromGHC', or perhaps iterated 'SS'.
+mkSNat :: forall ghcn. GHC.KnownNat ghcn => SNat (NatFromGHC ghcn)
+mkSNat = MkSNatUnsafe (GHC.natVal (Proxy @ghcn))
+
+-- | Convert a GHC 'GHC.SNat' to an 'SNat'. You can convert back using
+-- the 'SNat'' pattern synonym, or more manually using 'recoverGHC'.
+mkSNatFromGHC :: GHC.SNat ghcn -> SNat (NatFromGHC ghcn)
+mkSNatFromGHC sn@GHC.SNat = MkSNatUnsafe (GHC.natVal sn)
+
+-- | Dynamically create an 'SNat' from an untyped 'Natural'.
+withSomeSNat :: Natural -> (forall n. SNat n -> r) -> r
+withSomeSNat n k = k (MkSNatUnsafe n)
+
+-- | Dynamically create an 'SNat' from an untyped 'Int'. Throws an exception if
+-- the argument is negative.
+withSomeSNat' :: Int -> (forall n. SNat n -> r) -> r
+withSomeSNat' n k
+ | n < 0 = error $ "withSomeSNat': " ++ show n ++ " is negative"
+ | otherwise = k (MkSNatUnsafe (fromIntegral n))
+
+-- | Get the untyped 'Natural' corresponding to the 'SNat'.
+fromSNat :: SNat n -> Natural
+fromSNat (MkSNatUnsafe n) = n
+
+-- | Unsafe! If @n@ is out of range for @Int@, this will simply wrap, not throw
+-- an error!
+fromSNat' :: SNat n -> Int
+fromSNat' (MkSNatUnsafe n) = fromIntegral n
+
+-- | Convert a type-level Peano natural to a GHC type-level 'GHC.Nat'.
+type family GHCFromNat n where
+ GHCFromNat Z = 0
+ GHCFromNat (S n) = 1 GHC.+ GHCFromNat n
+
+-- | Convert an 'SNat' back to a GHC 'GHC.SNat'. If you use 'recoverGHC' after
+-- 'mkSNatFromGHC', you will end up with an 'GHC.SNat (GHCFromNat (NatFromGHC
+-- ghcn))'; use 'lemGHCToGHC' to rewrite that back to 'GHC.SNat ghcn'.
+recoverGHC :: forall n. SNat n -> GHC.SNat (GHCFromNat n)
+recoverGHC (MkSNatUnsafe n) =
+ GHC.withSomeSNat n $ \(ghcn :: GHC.SNat m) ->
+ unsafeCoerce @(GHC.SNat m) @(GHC.SNat (GHCFromNat n)) ghcn
+
+-- | 'GHCFromNat' and 'NatFromGHC' are inverses (first half).
+lemGHCNatGHC :: GHCFromNat (NatFromGHC ghcn) :~: ghcn
+lemGHCNatGHC = unsafeCoerceRefl
+
+-- | 'GHCFromNat' and 'NatFromGHC' are inverses (second half).
+lemNatGHCNat :: NatFromGHC (GHCFromNat n) :~: n
+lemNatGHCNat = unsafeCoerceRefl
+
+pattern SNat' :: forall n. () => GHC.KnownNat (GHCFromNat n) => SNat n
+pattern SNat' <- (recoverGHC -> GHC.SNat)
+ where SNat' = case lemNatGHCNat @n of Refl -> mkSNat @(GHCFromNat n)
+{-# COMPLETE SNat' #-}
+
+-- | Add type-level Peano naturals.
+type family n + m where
+ Z + m = m
+ S n + m = S (n + m)
+
+-- | Add 'SNat's.
+snatAdd :: SNat n -> SNat m -> SNat (n + m)
+snatAdd (MkSNatUnsafe n) (MkSNatUnsafe m) = MkSNatUnsafe (n + m)
+
+-- | Subtract type-level Peano naturals. Does not reduce if the result would be negative.
+type family n - m where
+ n - Z = n
+ S n - S m = n - m
+
+-- | Subtract 'SNat's. Returns 'Nothing' if the result would be negative.
+snatSub :: SNat n -> SNat m -> Maybe (SNat (n - m))
+snatSub (MkSNatUnsafe n) (MkSNatUnsafe m)
+ | n >= m = Just (MkSNatUnsafe (n - m))
+ | otherwise = Nothing
+
+-- | Multiply type-level Peano naturals.
+type family n * m where
+ Z * m = Z
+ S n * m = m + n * m
+
+-- | Multiply 'SNat's.
+snatMul :: SNat n -> SNat m -> SNat (n * m)
+snatMul (MkSNatUnsafe n) (MkSNatUnsafe m) = MkSNatUnsafe (n * m)
+
+type family Compare n m where
+ Compare Z Z = EQ
+ Compare Z (S m) = LT
+ Compare (S n) Z = GT
+ Compare (S n) (S m) = Compare n m
+
+data SOrdering n m where
+ SLT :: Compare n m ~ LT => SOrdering n m
+ SEQ :: Compare n n ~ EQ => SOrdering n n
+ SGT :: Compare n m ~ GT => SOrdering n m
+
+snatCompare :: SNat n -> SNat m -> SOrdering n m
+snatCompare (MkSNatUnsafe @n n) (MkSNatUnsafe @m m) = case compare n m of
+ LT | Refl <- unsafeCoerceRefl @(Compare n m) @LT ->
+ SLT
+ EQ | Refl <- unsafeCoerceRefl @n @m
+ , Refl <- unsafeCoerceRefl @(Compare n n) @EQ ->
+ SEQ
+ GT | Refl <- unsafeCoerceRefl @(Compare n m) @GT ->
+ SGT
+
+type family OrdCond ord lt eq gt where
+ OrdCond LT lt eq gt = lt
+ OrdCond EQ lt eq gt = eq
+ OrdCond GT lt eq gt = gt
+
+type n <= m = OrdCond (Compare n m) True True False ~ True
+type n < m = Compare n m ~ LT
+type n >= m = OrdCond (Compare n m) False True True ~ True
+type n > m = Compare n m ~ GT
+
+-- | Pass an 'SNat' implicitly, in a constraint.
+class KnownNat n where knownNat :: SNat n
+instance KnownNat Z where knownNat = SZ
+instance KnownNat n => KnownNat (S n) where knownNat = SS knownNat
+
+unsafeCoerceRefl :: forall a b. a :~: b
+unsafeCoerceRefl = unsafeCoerce Refl