diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-02-11 00:11:53 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-02-15 11:06:40 +0100 |
commit | e6c20868375d2b7f6b31808844e1b48f78bca069 (patch) | |
tree | 5e3c3efa5c61eb11a28b486bccbbcac823a36614 | |
parent | c705bb4cf76d2e80f3e9ed900f901b697b378f79 (diff) |
WIP half-peano SNatspeano-snat
-rw-r--r-- | ox-arrays.cabal | 1 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 129 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Permutation.hs | 88 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Shape.hs | 41 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Types.hs | 67 | ||||
-rw-r--r-- | src/Data/SNat/Peano.hs | 232 |
6 files changed, 358 insertions, 200 deletions
diff --git a/ox-arrays.cabal b/ox-arrays.cabal index 1253956..376107b 100644 --- a/ox-arrays.cabal +++ b/ox-arrays.cabal @@ -35,6 +35,7 @@ library Data.Array.Nested.Internal.Ranked Data.Array.Nested.Internal.Shape Data.Array.Nested.Internal.Shaped + Data.SNat.Peano if flag(trace-wrappers) exposed-modules: diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index a24efd6..dd4d8e7 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -24,14 +24,13 @@ import Foreign.C.Types import Foreign.Marshal.Alloc (alloca) import Foreign.Ptr import Foreign.Storable (Storable(sizeOf), peek, poke) -import GHC.TypeLits import GHC.TypeNats qualified as TypeNats import Language.Haskell.TH import System.IO.Unsafe +import Data.SNat.Peano import Data.Array.Mixed.Internal.Arith.Foreign import Data.Array.Mixed.Internal.Arith.Lists -import Data.Array.Mixed.Types (fromSNat') -- TODO: need to sort strides for reduction-like functions so that the C inner-loop specialisation has some chance of working even after transposition @@ -41,8 +40,8 @@ import Data.Array.Mixed.Types (fromSNat') liftVEltwise1 :: (Storable a, Storable b) => SNat n -> (VS.Vector a -> VS.Vector b) - -> RS.Array n a -> RS.Array n b -liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) + -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) b +liftVEltwise1 SNat' f arr@(RS.A (RG.A sh (OI.T strides offset vec))) | Just (blockOff, blockSz) <- stridesDense sh offset strides = let vec' = f (VS.slice blockOff blockSz vec) in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec')) @@ -52,8 +51,8 @@ liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) liftVEltwise2 :: (Storable a, Storable b, Storable c) => SNat n -> (Either a (VS.Vector a) -> Either b (VS.Vector b) -> VS.Vector c) - -> RS.Array n a -> RS.Array n b -> RS.Array n c -liftVEltwise2 SNat f + -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) b -> RS.Array (GHCFromNat n) c +liftVEltwise2 SNat' f arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1))) arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2))) | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2 @@ -172,8 +171,8 @@ vectorRedInnerOp :: forall a b n. (Num a, Storable a) -> (Ptr a -> Ptr b) -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- ^ scale by constant -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel - -> RS.Array (n + 1) a -> RS.Array n a -vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec))) + -> RS.Array (GHCFromNat (S n)) a -> RS.Array (GHCFromNat n) a +vectorRedInnerOp sn@SNat' valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec))) | null sh = error "unreachable" | last sh <= 0 = RS.stretch (init sh) (RS.fromList (1 <$ init sh) [0]) | any (<= 0) (init sh) = RS.A (RG.A (init sh) (OI.T (0 <$ init strides) 0 VS.empty)) @@ -210,7 +209,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR -> VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR -> fred (fromIntegral ndimsF) (ptrconv poutvR) pshF pstridesR (ptrconv pvecR) - TypeNats.withSomeSNat (fromIntegral (ndimsF - 1)) $ \(SNat :: SNat lenFm1) -> + TypeNats.withSomeSNat (fromIntegral (ndimsF - 1)) $ \(TypeNats.SNat :: TypeNats.SNat lenFm1) -> RS.stretch (init sh) -- replicate to original shape . RS.reshape (init shOnes) -- add 1-sized dimensions where the original was replicated . RS.rev (map fst (filter snd (zip [0..] revDims))) -- re-reverse the correct dimensions @@ -226,7 +225,7 @@ vectorRedFullOp :: forall a b n. (Num a, Storable a) -> (b -> a) -> (Ptr a -> Ptr b) -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b) -- ^ reduction kernel - -> RS.Array n a -> a + -> RS.Array (GHCFromNat n) a -> a vectorRedFullOp _ scaleval valbackconv ptrconv fred (RS.A (RG.A sh (OI.T strides offset vec))) | null sh = vec VS.! offset -- 0D array has one element | any (<= 0) sh = 0 @@ -309,12 +308,12 @@ vectorDotprodInnerOp :: forall a b n. (Num a, Storable a) => SNat n -> (a -> b) -> (Ptr a -> Ptr b) - -> (SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a) -- ^ elementwise multiplication + -> (SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a) -- ^ elementwise multiplication -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- ^ scale by constant -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ dotprod kernel - -> RS.Array (n + 1) a -> RS.Array (n + 1) a -> RS.Array n a -vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner + -> RS.Array (GHCFromNat (S n)) a -> RS.Array (GHCFromNat (S n)) a -> RS.Array (GHCFromNat n) a +vectorDotprodInnerOp sn@SNat' valconv ptrconv fmul fscale fred fdotinner arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1))) arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2))) | null sh1 || null sh2 = error "unreachable" @@ -344,7 +343,7 @@ vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv) pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1)) pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2)) - RS.fromVector @_ @n (init sh1) <$> VS.unsafeFreeze outv + RS.fromVector @_ @(GHCFromNat n) (init sh1) <$> VS.unsafeFreeze outv {-# NOINLINE dotScalarVector #-} dotScalarVector :: forall a b. (Num a, Storable a) @@ -398,7 +397,10 @@ $(fmap concat . forM typesList $ \arithtype -> do c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) sequence [SigD name <$> - [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] + [t| forall n. SNat n + -> RS.Array (GHCFromNat n) $ttyp + -> RS.Array (GHCFromNat n) $ttyp + -> RS.Array (GHCFromNat n) $ttyp |] ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] return $ FunD name [Clause [] (NormalB body) []]]) @@ -412,7 +414,10 @@ $(fmap concat . forM floatTypesList $ \arithtype -> do c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) sequence [SigD name <$> - [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] + [t| forall n. SNat n + -> RS.Array (GHCFromNat n) $ttyp + -> RS.Array (GHCFromNat n) $ttyp + -> RS.Array (GHCFromNat n) $ttyp |] ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] return $ FunD name [Clause [] (NormalB body) []]]) @@ -422,7 +427,7 @@ $(fmap concat . forM typesList $ \arithtype -> do let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype)) c_op = varE (mkName ("c_unary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (auoEnum arithop))) sequence [SigD name <$> - [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] + [t| forall n. SNat n -> RS.Array (GHCFromNat n) $ttyp -> RS.Array (GHCFromNat n) $ttyp |] ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] return $ FunD name [Clause [] (NormalB body) []]]) @@ -432,7 +437,7 @@ $(fmap concat . forM floatTypesList $ \arithtype -> do let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype)) c_op = varE (mkName ("c_funary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (afuoEnum arithop))) sequence [SigD name <$> - [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] + [t| forall n. SNat n -> RS.Array (GHCFromNat n) $ttyp -> RS.Array (GHCFromNat n) $ttyp |] ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] return $ FunD name [Clause [] (NormalB body) []]]) @@ -451,11 +456,11 @@ $(fmap concat . forM typesList $ \arithtype -> do c_opfull = varE (mkName ("c_reducefull_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop))) c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL))) sequence [SigD name1 <$> - [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |] + [t| forall n. SNat n -> RS.Array (GHCFromNat (S n)) $ttyp -> RS.Array (GHCFromNat n) $ttyp |] ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op1 |] return $ FunD name1 [Clause [] (NormalB body) []] ,SigD namefull <$> - [t| forall n. SNat n -> RS.Array n $ttyp -> $ttyp |] + [t| forall n. SNat n -> RS.Array (GHCFromNat n) $ttyp -> $ttyp |] ,do body <- [| \sn -> vectorRedFullOp sn $scaleVar id id $c_opfull |] return $ FunD namefull [Clause [] (NormalB body) []] ]) @@ -478,7 +483,7 @@ $(fmap concat . forM typesList $ \arithtype -> do c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL))) c_red_op = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum RO_SUM))) sequence [SigD name <$> - [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |] + [t| forall n. SNat n -> RS.Array (GHCFromNat (S n)) $ttyp -> RS.Array (GHCFromNat (S n)) $ttyp -> RS.Array (GHCFromNat n) $ttyp |] ,do body <- [| \sn -> vectorDotprodInnerOp sn id id $mul_op $c_scale_op $c_red_op $c_op |] return $ FunD name [Clause [] (NormalB body) []]]) @@ -487,7 +492,7 @@ $(fmap concat . forM typesList $ \arithtype -> do intWidBranch1 :: forall i n. (FiniteBits i, Storable i) => (Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) - -> (SNat n -> RS.Array n i -> RS.Array n i) + -> (SNat n -> RS.Array (GHCFromNat n) i -> RS.Array (GHCFromNat n) i) intWidBranch1 f32 f64 sn | finiteBitSize (undefined :: i) == 32 = liftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32) | finiteBitSize (undefined :: i) == 64 = liftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64) @@ -503,7 +508,7 @@ intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i) -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- sv -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> IO ()) -- vs -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- vv - -> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i) + -> (SNat n -> RS.Array (GHCFromNat n) i -> RS.Array (GHCFromNat n) i -> RS.Array (GHCFromNat n) i) intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32) | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64) @@ -516,7 +521,7 @@ intWidBranchRed1 :: forall i n. (FiniteBits i, Storable i, Integral i) -- int64 -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- ^ scale by constant -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel - -> (SNat n -> RS.Array (n + 1) i -> RS.Array n i) + -> (SNat n -> RS.Array (GHCFromNat (S n)) i -> RS.Array (GHCFromNat n) i) intWidBranchRed1 fsc32 fred32 fsc64 fred64 sn | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32 | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64 @@ -528,7 +533,7 @@ intWidBranchRedFull :: forall i n. (FiniteBits i, Storable i, Integral i) -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO Int32) -- ^ reduction kernel -- int64 -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO Int64) -- ^ reduction kernel - -> (SNat n -> RS.Array n i -> i) + -> (SNat n -> RS.Array (GHCFromNat n) i -> i) intWidBranchRedFull fsc fred32 fred64 sn | finiteBitSize (undefined :: i) == 32 = vectorRedFullOp @i @Int32 sn fsc fromIntegral castPtr fred32 | finiteBitSize (undefined :: i) == 64 = vectorRedFullOp @i @Int64 sn fsc fromIntegral castPtr fred64 @@ -554,26 +559,26 @@ intWidBranchDotprod :: forall i n. (FiniteBits i, Storable i, Integral i, NumElt -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- ^ scale by constant -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ dotprod kernel - -> (SNat n -> RS.Array (n + 1) i -> RS.Array (n + 1) i -> RS.Array n i) + -> (SNat n -> RS.Array (GHCFromNat (S n)) i -> RS.Array (GHCFromNat (S n)) i -> RS.Array (GHCFromNat n) i) intWidBranchDotprod fsc32 fred32 fdot32 fsc64 fred64 fdot64 sn | finiteBitSize (undefined :: i) == 32 = vectorDotprodInnerOp @i @Int32 sn fromIntegral castPtr numEltMul fsc32 fred32 fdot32 | finiteBitSize (undefined :: i) == 64 = vectorDotprodInnerOp @i @Int64 sn fromIntegral castPtr numEltMul fsc64 fred64 fdot64 | otherwise = error "Unsupported Int width" class NumElt a where - numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a - numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a - numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a - numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a - numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a - numEltSumFull :: SNat n -> RS.Array n a -> a - numEltProductFull :: SNat n -> RS.Array n a -> a - numEltMinIndex :: SNat n -> RS.Array n a -> [Int] - numEltMaxIndex :: SNat n -> RS.Array n a -> [Int] - numEltDotprodInner :: SNat n -> RS.Array (n + 1) a -> RS.Array (n + 1) a -> RS.Array n a + numEltAdd :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a + numEltSub :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a + numEltMul :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a + numEltNeg :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a + numEltAbs :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a + numEltSignum :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a + numEltSum1Inner :: SNat n -> RS.Array (GHCFromNat (S n)) a -> RS.Array (GHCFromNat n) a + numEltProduct1Inner :: SNat n -> RS.Array (GHCFromNat (S n)) a -> RS.Array (GHCFromNat n) a + numEltSumFull :: SNat n -> RS.Array (GHCFromNat n) a -> a + numEltProductFull :: SNat n -> RS.Array (GHCFromNat n) a -> a + numEltMinIndex :: SNat n -> RS.Array (GHCFromNat n) a -> [Int] + numEltMaxIndex :: SNat n -> RS.Array (GHCFromNat n) a -> [Int] + numEltDotprodInner :: SNat n -> RS.Array (GHCFromNat (S n)) a -> RS.Array (GHCFromNat (S n)) a -> RS.Array (GHCFromNat n) a instance NumElt Int32 where numEltAdd = addVectorInt32 @@ -688,29 +693,29 @@ instance NumElt CInt where (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64 class FloatElt a where - floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - floatEltPow :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - floatEltLogbase :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - floatEltRecip :: SNat n -> RS.Array n a -> RS.Array n a - floatEltExp :: SNat n -> RS.Array n a -> RS.Array n a - floatEltLog :: SNat n -> RS.Array n a -> RS.Array n a - floatEltSqrt :: SNat n -> RS.Array n a -> RS.Array n a - floatEltSin :: SNat n -> RS.Array n a -> RS.Array n a - floatEltCos :: SNat n -> RS.Array n a -> RS.Array n a - floatEltTan :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAsin :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAcos :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAtan :: SNat n -> RS.Array n a -> RS.Array n a - floatEltSinh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltCosh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltTanh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAsinh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAcosh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAtanh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltLog1p :: SNat n -> RS.Array n a -> RS.Array n a - floatEltExpm1 :: SNat n -> RS.Array n a -> RS.Array n a - floatEltLog1pexp :: SNat n -> RS.Array n a -> RS.Array n a - floatEltLog1mexp :: SNat n -> RS.Array n a -> RS.Array n a + floatEltDiv :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a + floatEltPow :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a + floatEltLogbase :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a + floatEltRecip :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a + floatEltExp :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a + floatEltLog :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a + floatEltSqrt :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a + floatEltSin :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a + floatEltCos :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a + floatEltTan :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a + floatEltAsin :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a + floatEltAcos :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a + floatEltAtan :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a + floatEltSinh :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a + floatEltCosh :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a + floatEltTanh :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a + floatEltAsinh :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a + floatEltAcosh :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a + floatEltAtanh :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a + floatEltLog1p :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a + floatEltExpm1 :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a + floatEltLog1pexp :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a + floatEltLog1mexp :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a instance FloatElt Float where floatEltDiv = divVectorFloat diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs index 331d5e0..e85e67f 100644 --- a/src/Data/Array/Mixed/Permutation.hs +++ b/src/Data/Array/Mixed/Permutation.hs @@ -13,8 +13,6 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module Data.Array.Mixed.Permutation where import Data.Coerce (coerce) @@ -24,13 +22,12 @@ import Data.Maybe (fromMaybe) import Data.Proxy import Data.Type.Bool import Data.Type.Equality -import Data.Type.Ord import GHC.TypeError -import GHC.TypeLits -import GHC.TypeNats qualified as TN +import Numeric.Natural import Data.Array.Mixed.Shape import Data.Array.Mixed.Types +import Data.SNat.Peano -- * Permutations @@ -46,18 +43,19 @@ deriving instance Show (Perm list) deriving instance Eq (Perm list) permRank :: Perm list -> SNat (Rank list) -permRank PNil = SNat -permRank (_ `PCons` l) | SNat <- permRank l = SNat +permRank PNil = SZ +permRank (_ `PCons` l) = SS (permRank l) permFromList :: [Int] -> (forall list. Perm list -> r) -> r permFromList [] k = k PNil -permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case - Just sn -> permFromList xs $ \list -> k (sn `PCons` list) - Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x +permFromList (x : xs) k = + withSomeSNat' x $ \sn -> + permFromList xs $ \list -> + k (sn `PCons` list) permToList :: Perm list -> [Natural] permToList PNil = mempty -permToList (x `PCons` l) = TN.fromSNat x : permToList l +permToList (x `PCons` l) = fromSNat x : permToList l permToList' :: Perm list -> [Int] permToList' = map fromIntegral . permToList @@ -68,48 +66,47 @@ permToList' = map fromIntegral . permToList permCheckPermutation :: forall r list. Perm list -> (IsPermutation list => r) -> Maybe r permCheckPermutation = \p k -> let n = permRank p - in case (provePerm1 (Proxy @list) n p, provePerm2 (SNat @0) n p) of + in case (provePerm1 (Proxy @list) n p, provePerm2 SZ n p) of (Just Refl, Just Refl) -> Just k _ -> Nothing where - lemElemCount :: (0 <= n, Compare n m ~ LT) - => proxy n -> proxy m -> Elem n (Count 0 m) :~: True + lemElemCount :: (Z <= n, n < m) + => proxy n -> proxy m -> Elem n (Count Z m) :~: True lemElemCount _ _ = unsafeCoerceRefl - lemCount :: (OrdCond (Compare i n) True False True ~ True) - => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n + lemCount :: i < n => proxy i -> proxy n -> Count i n :~: i : Count (S i) n lemCount _ _ = unsafeCoerceRefl lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True lemElem _ _ = unsafeCoerceRefl provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is' - -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True) + -> Maybe (AllElem' is' (Count Z (Rank isTop)) :~: True) provePerm1 _ _ PNil = Just (Refl) - provePerm1 p rtop@SNat (PCons sn@SNat perm) + provePerm1 p rtop (PCons sn perm) | Just Refl <- provePerm1 p rtop perm - = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of - (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl - (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl + = case (snatCompare SZ sn, snatCompare sn rtop) of + (SLT, SLT) | Refl <- lemElemCount sn rtop -> Just Refl + (SEQ, SLT) | Refl <- lemElemCount sn rtop -> Just Refl _ -> Nothing | otherwise = Nothing provePerm2 :: SNat i -> SNat n -> Perm is' -> Maybe (AllElem' (Count i n) is' :~: True) - provePerm2 = \i@(SNat :: SNat i) n@SNat perm -> - case cmpNat i n of - EQI -> Just Refl - LTI | Refl <- lemCount i n - , Just Refl <- provePerm2 (SNat @(i + 1)) n perm + provePerm2 = \i n perm -> + case snatCompare i n of + SEQ -> Just Refl + SLT | Refl <- lemCount i n + , Just Refl <- provePerm2 (SS i) n perm -> checkElem i perm | otherwise -> Nothing - GTI -> error "unreachable" + SGT -> error "unreachable" where checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True) checkElem _ PNil = Nothing - checkElem i@SNat (PCons k@SNat perm :: Perm is') = - case sameNat i k of + checkElem i (PCons k perm :: Perm is') = + case testEquality i k of Just Refl -> Just Refl Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl | otherwise -> Nothing @@ -117,7 +114,7 @@ permCheckPermutation = \p k -> -- | Utility class for generating permutations from type class information. class KnownPerm l where makePerm :: Perm l instance KnownPerm '[] where makePerm = PNil -instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = natSing `PCons` makePerm +instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = knownNat `PCons` makePerm -- | Untyped permutations for ranked arrays type PermR = [Int] @@ -139,13 +136,13 @@ type AllElem as bs = Assert (AllElem' as bs) type family Count i n where Count n n = '[] - Count i n = i : Count (i + 1) n + Count i n = i : Count (S i) n -type IsPermutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as) +type IsPermutation as = (AllElem as (Count Z (Rank as)), AllElem (Count Z (Rank as)) as) type family Index i sh where - Index 0 (n : sh) = n - Index i (_ : sh) = Index (i - 1) sh + Index Z (n : sh) = n + Index (S i) (_ : sh) = Index i sh type family Permute is sh where Permute '[] sh = '[] @@ -178,9 +175,7 @@ listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) = listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> f (Index i sh) listxIndex _ _ SZ (n ::% _) = n -listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) - | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = listxIndex p pT i sh +listxIndex p pT (SS i) (_ ::% sh) = listxIndex p pT i sh listxIndex _ _ _ ZX = error "Index into empty shape" listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f @@ -199,7 +194,7 @@ ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh) ssxPermute = coerce (listxPermute @(SMayNat () SNat)) ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh) -ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2) +ssxIndex p1 p2 i = coerce (listxIndex @(SMayNat () SNat) p1 p2 i) ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh) ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat)) @@ -236,23 +231,23 @@ permInverse = \perm k -> where toHList :: [Natural] -> (forall is'. Perm is' -> r) -> r toHList [] k = k PNil - toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l) + toHList (n : ns) k = toHList ns $ \l -> withSomeSNat n $ \sn -> k (PCons sn l) provePermInverse :: Perm is -> Perm is' -> StaticShX sh -> Maybe (Permute is' (Permute is sh) :~: sh) provePermInverse perm perminv ssh = - ssxGeq (ssxPermute perminv (ssxPermute perm ssh)) ssh + testEquality (ssxPermute perminv (ssxPermute perm ssh)) ssh type family MapSucc is where MapSucc '[] = '[] - MapSucc (i : is) = i + 1 : MapSucc is + MapSucc (i : is) = S i : MapSucc is -permShift1 :: Perm l -> Perm (0 : MapSucc l) -permShift1 = (SNat @0 `PCons`) . permMapSucc +permShift1 :: Perm l -> Perm (Z : MapSucc l) +permShift1 = (SZ `PCons`) . permMapSucc where permMapSucc :: Perm l -> Perm (MapSucc l) permMapSucc PNil = PNil - permMapSucc ((SNat :: SNat i) `PCons` ns) = SNat @(i + 1) `PCons` permMapSucc ns + permMapSucc (i `PCons` ns) = SS i `PCons` permMapSucc ns -- * Lemmas @@ -266,8 +261,3 @@ lemRankDropLen :: forall is sh. (Rank is <= Rank sh) lemRankDropLen ZKX PNil = Refl lemRankDropLen (_ :!% sh) (_ `PCons` is) | Refl <- lemRankDropLen sh is = Refl lemRankDropLen (_ :!% _) PNil = Refl -lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0" - -lemIndexSucc :: Proxy i -> Proxy a -> Proxy l - -> Index (i + 1) (a : l) :~: Index i l -lemIndexSucc _ _ _ = unsafeCoerceRefl diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs index e5f8b67..bed812d 100644 --- a/src/Data/Array/Mixed/Shape.hs +++ b/src/Data/Array/Mixed/Shape.hs @@ -18,8 +18,6 @@ {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module Data.Array.Mixed.Shape where import Control.DeepSeq (NFData(..)) @@ -35,16 +33,16 @@ import GHC.Exts (withDict) import GHC.Generics (Generic) import GHC.IsList (IsList) import GHC.IsList qualified as IsList -import GHC.TypeLits import Data.Array.Mixed.Types +import Data.SNat.Peano -- | The length of a type-level list. If the argument is a shape, then the -- result is the rank of that shape. type family Rank sh where - Rank '[] = 0 - Rank (_ : sh) = Rank sh + 1 + Rank '[] = Z + Rank (_ : sh) = S (Rank sh) -- * Mixed lists @@ -91,8 +89,8 @@ listxLength :: ListX sh f -> Int listxLength = getSum . listxFold (\_ -> Sum 1) listxRank :: ListX sh f -> SNat (Rank sh) -listxRank ZX = SNat -listxRank (_ ::% l) | SNat <- listxRank l = SNat +listxRank ZX = SZ +listxRank (_ ::% l) = SS (listxRank l) listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS listxShow f l = showString "[" . go "" l . showString "]" @@ -255,7 +253,7 @@ type family AddMaybe n m where smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m) smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) -smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m) +smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatAdd n m) -- | This is a newtype over 'ListX'. @@ -288,7 +286,7 @@ instance Functor (ShX sh) where instance NFData i => NFData (ShX sh i) where rnf (ShX ZX) = () rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l) - rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l) + rnf (ShX (SKnown n ::% l)) = rnf n `seq` rnf (ShX l) shxLength :: ShX sh i -> Int shxLength (ShX l) = listxLength l @@ -300,8 +298,8 @@ shxRank (ShX list) = listxRank list -- dimensions) are the same. shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh') shxEqual ZSX ZSX = Just Refl -shxEqual (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh') - | Just Refl <- sameNat n m +shxEqual (SKnown n :$% sh) (SKnown m :$% sh') + | Just Refl <- testEquality n m , Just Refl <- shxEqual sh sh' = Just Refl shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh') @@ -422,19 +420,6 @@ instance TestEquality StaticShX where ssxLength :: StaticShX sh -> Int ssxLength (StaticShX l) = listxLength l --- | This suffices as an implementation of @geq@ in the @Data.GADT.Compare@ --- class of the @some@ package. -ssxGeq :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh') -ssxGeq ZKX ZKX = Just Refl -ssxGeq (SKnown n@SNat :!% sh) (SKnown m@SNat :!% sh') - | Just Refl <- sameNat n m - , Just Refl <- ssxGeq sh sh' - = Just Refl -ssxGeq (SUnknown () :!% sh) (SUnknown () :!% sh') - | Just Refl <- ssxGeq sh sh' - = Just Refl -ssxGeq _ _ = Nothing - ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') ssxAppend ZKX sh' = sh' ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' @@ -481,7 +466,7 @@ ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 type KnownShX :: [Maybe Nat] -> Constraint class KnownShX sh where knownShX :: StaticShX sh instance KnownShX '[] where knownShX = ZKX -instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX +instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown knownNat :!% knownShX instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX withKnownShX :: forall sh r. StaticShX sh -> (KnownShX sh => r) -> r @@ -490,7 +475,7 @@ withKnownShX sh = withDict @(KnownShX sh) sh -- * Flattening -type Flatten sh = Flatten' 1 sh +type Flatten sh = Flatten' (S Z) sh type family Flatten' acc sh where Flatten' acc '[] = Just acc @@ -499,7 +484,7 @@ type family Flatten' acc sh where -- This function is currently unused ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh) -ssxFlatten = go (SNat @1) +ssxFlatten = go (mkSNat @1) where go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh) go acc ZKX = SKnown acc @@ -507,7 +492,7 @@ ssxFlatten = go (SNat @1) go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh) -shxFlatten = go (SNat @1) +shxFlatten = go (mkSNat @1) where go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh) go acc ZSX = SKnown acc diff --git a/src/Data/Array/Mixed/Types.hs b/src/Data/Array/Mixed/Types.hs index 13675d0..319246b 100644 --- a/src/Data/Array/Mixed/Types.hs +++ b/src/Data/Array/Mixed/Types.hs @@ -1,27 +1,14 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE NoStarIsType #-} -{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module Data.Array.Mixed.Types ( -- * Reified evidence of a type class Dict(..), - -- * Type-level naturals - pattern SZ, pattern SS, - fromSNat', sameNat', - snatPlus, snatMinus, snatMul, - snatSucc, - -- * Type-level lists type (++), Replicate, @@ -36,67 +23,25 @@ module Data.Array.Mixed.Types ( ) where import Data.Type.Equality -import Data.Proxy -import GHC.TypeLits -import GHC.TypeNats qualified as TN import Unsafe.Coerce qualified +import Data.SNat.Peano + -- | Evidence for the constraint @c a@. data Dict c a where Dict :: c a => Dict c a -fromSNat' :: SNat n -> Int -fromSNat' = fromIntegral . fromSNat - -sameNat' :: SNat n -> SNat m -> Maybe (n :~: m) -sameNat' n@SNat m@SNat = sameNat n m - -pattern SZ :: () => (n ~ 0) => SNat n -pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl) - where SZ = SNat - -pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1 -pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl)) - where SS = snatSucc - -{-# COMPLETE SZ, SS #-} - -snatSucc :: SNat n -> SNat (n + 1) -snatSucc SNat = SNat - -data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1) -snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1) -snatPred snp1 = - withKnownNat snp1 $ - case cmpNat (Proxy @1) (Proxy @np1) of - LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) - EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) - GTI -> Nothing - --- This should be a function in base -snatPlus :: SNat n -> SNat m -> SNat (n + m) -snatPlus n m = TN.withSomeSNat (TN.fromSNat n + TN.fromSNat m) Unsafe.Coerce.unsafeCoerce - --- This should be a function in base -snatMinus :: SNat n -> SNat m -> SNat (n - m) -snatMinus n m = let res = TN.fromSNat n - TN.fromSNat m in res `seq` TN.withSomeSNat res Unsafe.Coerce.unsafeCoerce - --- This should be a function in base -snatMul :: SNat n -> SNat m -> SNat (n * m) -snatMul n m = TN.withSomeSNat (TN.fromSNat n * TN.fromSNat m) Unsafe.Coerce.unsafeCoerce - - -- | Type-level list append. type family l1 ++ l2 where '[] ++ l2 = l2 (x : xs) ++ l2 = x : xs ++ l2 type family Replicate n a where - Replicate 0 a = '[] - Replicate n a = a : Replicate (n - 1) a + Replicate Z a = '[] + Replicate (S n) a = a : Replicate n a -lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a +lemReplicateSucc :: (a : Replicate n a) :~: Replicate (S n) a lemReplicateSucc = unsafeCoerceRefl type family MapJust l where diff --git a/src/Data/SNat/Peano.hs b/src/Data/SNat/Peano.hs new file mode 100644 index 0000000..a5109fa --- /dev/null +++ b/src/Data/SNat/Peano.hs @@ -0,0 +1,232 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeData #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UnboxedTuples #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +module Data.SNat.Peano ( + -- * Singleton naturals + Nat(..), + SNat(SZ, SS), + mkSNat, mkSNatFromGHC, NatFromGHC, + withSomeSNat, withSomeSNat', + fromSNat, fromSNat', + + -- * Computing with 'SNat' values + type (+), snatAdd, + type (-), snatSub, + type (*), snatMul, + Compare, SOrdering(..), snatCompare, OrdCond, type (<=), type (<), type (>=), type (>), + + -- * 'KnownNat' + KnownNat(..), + + -- * Interoperate with GHC naturals + pattern SNat', recoverGHC, GHCFromNat, lemGHCNatGHC, lemNatGHCNat, +) where + +import Control.DeepSeq +import Data.Proxy +import Data.Type.Equality +import Numeric.Natural +import qualified GHC.TypeNats as GHC +import Unsafe.Coerce + + +-- | Type-level Peano naturals. +type data Nat = Z | S Nat + +-- | A singleton for 'Nat'. The representation is just a 'Natural', not an +-- actual Peano natural / linked list. +newtype SNat (n :: Nat) = MkSNatUnsafe Natural + +instance Show (SNat n) where + showsPrec d (MkSNatUnsafe n) = showParen (d > 10) $ + showString ("mkSNat @" ++ show n) + +-- these are vacuous because it's a singleton +instance Eq (SNat n) where _ == _ = True +instance Ord (SNat n) where compare _ _ = EQ + +instance NFData (SNat n) where + rnf (MkSNatUnsafe n) = rnf n + +instance TestEquality SNat where + testEquality (MkSNatUnsafe n) (MkSNatUnsafe m) + | n == m = Just unsafeCoerceRefl + | otherwise = Nothing + +-- | The zero natural. Read this as: 'SZ :: SNat Z' +pattern SZ :: forall n. () => n ~ Z => SNat n +pattern SZ <- ((\(MkSNatUnsafe k) -> (# k, unsafeCoerceRefl @n @Z #)) + -> (# 0, Refl #)) + where SZ = MkSNatUnsafe 0 + +-- | /n/ plus one. Read this as: 'SS :: SNat n -> SNat (S n)' +pattern SS :: forall n. () => forall predn. S predn ~ n => SNat predn -> SNat n +pattern SS n <- (mkPredecessor -> Just (Predecessor n)) + where SS (MkSNatUnsafe n) = MkSNatUnsafe (n + 1) +-- A little experiment showed that mkPredecessor inlines sufficiently that no +-- Predecessor object every gets allocated. Let's hope this is also true in +-- more practical situations. + +{-# COMPLETE SZ, SS #-} + +data Predecessor n = forall predn. S predn ~ n => Predecessor (SNat predn) + +mkPredecessor :: forall n. SNat n -> Maybe (Predecessor n) +mkPredecessor (MkSNatUnsafe 0) = Nothing +mkPredecessor (MkSNatUnsafe k) = Just (yolo (MkSNatUnsafe (k-1))) + where + yolo :: forall m predn. SNat predn -> Predecessor m + yolo n | Refl <- unsafeCoerceRefl @(S predn) @m = Predecessor n + +-- | Convert a GHC type-level 'GHC.Nat' to a type-level Peano natural. Because +-- this type family performs induction on a GHC 'GHC.Nat', which only works +-- sensibly if the 'GHC.Nat' is monomorphic, using this on a bare type variable +-- will probably be unsuccessful. +type family NatFromGHC ghcn where + NatFromGHC 0 = Z + NatFromGHC n = S (NatFromGHC (n GHC.- 1)) + +-- | Convenience function to create an 'SNat'. Use with @-XDataKinds@ and +-- @-XTypeApplications@ like: +-- +-- >>> mkSNat @5 +-- +-- The 'GHC.KnownNat' constraint is automatically satisfied for any statically +-- known number. To construct an 'SNat' dynamically, you probably want +-- 'mkSNatFromGHC', or perhaps iterated 'SS'. +mkSNat :: forall ghcn. GHC.KnownNat ghcn => SNat (NatFromGHC ghcn) +mkSNat = MkSNatUnsafe (GHC.natVal (Proxy @ghcn)) + +-- | Convert a GHC 'GHC.SNat' to an 'SNat'. You can convert back using +-- the 'SNat'' pattern synonym, or more manually using 'recoverGHC'. +mkSNatFromGHC :: GHC.SNat ghcn -> SNat (NatFromGHC ghcn) +mkSNatFromGHC sn@GHC.SNat = MkSNatUnsafe (GHC.natVal sn) + +-- | Dynamically create an 'SNat' from an untyped 'Natural'. +withSomeSNat :: Natural -> (forall n. SNat n -> r) -> r +withSomeSNat n k = k (MkSNatUnsafe n) + +-- | Dynamically create an 'SNat' from an untyped 'Int'. Throws an exception if +-- the argument is negative. +withSomeSNat' :: Int -> (forall n. SNat n -> r) -> r +withSomeSNat' n k + | n < 0 = error $ "withSomeSNat': " ++ show n ++ " is negative" + | otherwise = k (MkSNatUnsafe (fromIntegral n)) + +-- | Get the untyped 'Natural' corresponding to the 'SNat'. +fromSNat :: SNat n -> Natural +fromSNat (MkSNatUnsafe n) = n + +-- | Unsafe! If @n@ is out of range for @Int@, this will simply wrap, not throw +-- an error! +fromSNat' :: SNat n -> Int +fromSNat' (MkSNatUnsafe n) = fromIntegral n + +-- | Convert a type-level Peano natural to a GHC type-level 'GHC.Nat'. +type family GHCFromNat n where + GHCFromNat Z = 0 + GHCFromNat (S n) = 1 GHC.+ GHCFromNat n + +-- | Convert an 'SNat' back to a GHC 'GHC.SNat'. If you use 'recoverGHC' after +-- 'mkSNatFromGHC', you will end up with an 'GHC.SNat (GHCFromNat (NatFromGHC +-- ghcn))'; use 'lemGHCToGHC' to rewrite that back to 'GHC.SNat ghcn'. +recoverGHC :: forall n. SNat n -> GHC.SNat (GHCFromNat n) +recoverGHC (MkSNatUnsafe n) = + GHC.withSomeSNat n $ \(ghcn :: GHC.SNat m) -> + unsafeCoerce @(GHC.SNat m) @(GHC.SNat (GHCFromNat n)) ghcn + +-- | 'GHCFromNat' and 'NatFromGHC' are inverses (first half). +lemGHCNatGHC :: GHCFromNat (NatFromGHC ghcn) :~: ghcn +lemGHCNatGHC = unsafeCoerceRefl + +-- | 'GHCFromNat' and 'NatFromGHC' are inverses (second half). +lemNatGHCNat :: NatFromGHC (GHCFromNat n) :~: n +lemNatGHCNat = unsafeCoerceRefl + +pattern SNat' :: forall n. () => GHC.KnownNat (GHCFromNat n) => SNat n +pattern SNat' <- (recoverGHC -> GHC.SNat) + where SNat' = case lemNatGHCNat @n of Refl -> mkSNat @(GHCFromNat n) +{-# COMPLETE SNat' #-} + +-- | Add type-level Peano naturals. +type family n + m where + Z + m = m + S n + m = S (n + m) + +-- | Add 'SNat's. +snatAdd :: SNat n -> SNat m -> SNat (n + m) +snatAdd (MkSNatUnsafe n) (MkSNatUnsafe m) = MkSNatUnsafe (n + m) + +-- | Subtract type-level Peano naturals. Does not reduce if the result would be negative. +type family n - m where + n - Z = n + S n - S m = n - m + +-- | Subtract 'SNat's. Returns 'Nothing' if the result would be negative. +snatSub :: SNat n -> SNat m -> Maybe (SNat (n - m)) +snatSub (MkSNatUnsafe n) (MkSNatUnsafe m) + | n >= m = Just (MkSNatUnsafe (n - m)) + | otherwise = Nothing + +-- | Multiply type-level Peano naturals. +type family n * m where + Z * m = Z + S n * m = m + n * m + +-- | Multiply 'SNat's. +snatMul :: SNat n -> SNat m -> SNat (n * m) +snatMul (MkSNatUnsafe n) (MkSNatUnsafe m) = MkSNatUnsafe (n * m) + +type family Compare n m where + Compare Z Z = EQ + Compare Z (S m) = LT + Compare (S n) Z = GT + Compare (S n) (S m) = Compare n m + +data SOrdering n m where + SLT :: Compare n m ~ LT => SOrdering n m + SEQ :: Compare n n ~ EQ => SOrdering n n + SGT :: Compare n m ~ GT => SOrdering n m + +snatCompare :: SNat n -> SNat m -> SOrdering n m +snatCompare (MkSNatUnsafe @n n) (MkSNatUnsafe @m m) = case compare n m of + LT | Refl <- unsafeCoerceRefl @(Compare n m) @LT -> + SLT + EQ | Refl <- unsafeCoerceRefl @n @m + , Refl <- unsafeCoerceRefl @(Compare n n) @EQ -> + SEQ + GT | Refl <- unsafeCoerceRefl @(Compare n m) @GT -> + SGT + +type family OrdCond ord lt eq gt where + OrdCond LT lt eq gt = lt + OrdCond EQ lt eq gt = eq + OrdCond GT lt eq gt = gt + +type n <= m = OrdCond (Compare n m) True True False ~ True +type n < m = Compare n m ~ LT +type n >= m = OrdCond (Compare n m) False True True ~ True +type n > m = Compare n m ~ GT + +-- | Pass an 'SNat' implicitly, in a constraint. +class KnownNat n where knownNat :: SNat n +instance KnownNat Z where knownNat = SZ +instance KnownNat n => KnownNat (S n) where knownNat = SS knownNat + +unsafeCoerceRefl :: forall a b. a :~: b +unsafeCoerceRefl = unsafeCoerce Refl |