diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-02-11 00:11:53 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-02-15 11:06:40 +0100 | 
| commit | e6c20868375d2b7f6b31808844e1b48f78bca069 (patch) | |
| tree | 5e3c3efa5c61eb11a28b486bccbbcac823a36614 | |
| parent | c705bb4cf76d2e80f3e9ed900f901b697b378f79 (diff) | |
WIP half-peano SNatspeano-snat
| -rw-r--r-- | ox-arrays.cabal | 1 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 129 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Permutation.hs | 88 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Shape.hs | 41 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Types.hs | 67 | ||||
| -rw-r--r-- | src/Data/SNat/Peano.hs | 232 | 
6 files changed, 358 insertions, 200 deletions
diff --git a/ox-arrays.cabal b/ox-arrays.cabal index 1253956..376107b 100644 --- a/ox-arrays.cabal +++ b/ox-arrays.cabal @@ -35,6 +35,7 @@ library      Data.Array.Nested.Internal.Ranked      Data.Array.Nested.Internal.Shape      Data.Array.Nested.Internal.Shaped +    Data.SNat.Peano    if flag(trace-wrappers)      exposed-modules: diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index a24efd6..dd4d8e7 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -24,14 +24,13 @@ import Foreign.C.Types  import Foreign.Marshal.Alloc (alloca)  import Foreign.Ptr  import Foreign.Storable (Storable(sizeOf), peek, poke) -import GHC.TypeLits  import GHC.TypeNats qualified as TypeNats  import Language.Haskell.TH  import System.IO.Unsafe +import Data.SNat.Peano  import Data.Array.Mixed.Internal.Arith.Foreign  import Data.Array.Mixed.Internal.Arith.Lists -import Data.Array.Mixed.Types (fromSNat')  -- TODO: need to sort strides for reduction-like functions so that the C inner-loop specialisation has some chance of working even after transposition @@ -41,8 +40,8 @@ import Data.Array.Mixed.Types (fromSNat')  liftVEltwise1 :: (Storable a, Storable b)                => SNat n                -> (VS.Vector a -> VS.Vector b) -              -> RS.Array n a -> RS.Array n b -liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) +              -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) b +liftVEltwise1 SNat' f arr@(RS.A (RG.A sh (OI.T strides offset vec)))    | Just (blockOff, blockSz) <- stridesDense sh offset strides =        let vec' = f (VS.slice blockOff blockSz vec)        in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec')) @@ -52,8 +51,8 @@ liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))  liftVEltwise2 :: (Storable a, Storable b, Storable c)                => SNat n                -> (Either a (VS.Vector a) -> Either b (VS.Vector b) -> VS.Vector c) -              -> RS.Array n a -> RS.Array n b -> RS.Array n c -liftVEltwise2 SNat f +              -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) b -> RS.Array (GHCFromNat n) c +liftVEltwise2 SNat' f      arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1)))      arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2)))    | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2 @@ -172,8 +171,8 @@ vectorRedInnerOp :: forall a b n. (Num a, Storable a)                   -> (Ptr a -> Ptr b)                   -> (Int64 -> Ptr b -> b -> Ptr b -> IO ())  -- ^ scale by constant                   -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ reduction kernel -                 -> RS.Array (n + 1) a -> RS.Array n a -vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec))) +                 -> RS.Array (GHCFromNat (S n)) a -> RS.Array (GHCFromNat n) a +vectorRedInnerOp sn@SNat' valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec)))    | null sh = error "unreachable"    | last sh <= 0 = RS.stretch (init sh) (RS.fromList (1 <$ init sh) [0])    | any (<= 0) (init sh) = RS.A (RG.A (init sh) (OI.T (0 <$ init strides) 0 VS.empty)) @@ -210,7 +209,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride                 VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR ->                   VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR ->                     fred (fromIntegral ndimsF) (ptrconv poutvR) pshF pstridesR (ptrconv pvecR) -           TypeNats.withSomeSNat (fromIntegral (ndimsF - 1)) $ \(SNat :: SNat lenFm1) -> +           TypeNats.withSomeSNat (fromIntegral (ndimsF - 1)) $ \(TypeNats.SNat :: TypeNats.SNat lenFm1) ->               RS.stretch (init sh)  -- replicate to original shape                 . RS.reshape (init shOnes)  -- add 1-sized dimensions where the original was replicated                 . RS.rev (map fst (filter snd (zip [0..] revDims)))  -- re-reverse the correct dimensions @@ -226,7 +225,7 @@ vectorRedFullOp :: forall a b n. (Num a, Storable a)                  -> (b -> a)                  -> (Ptr a -> Ptr b)                  -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b)  -- ^ reduction kernel -                -> RS.Array n a -> a +                -> RS.Array (GHCFromNat n) a -> a  vectorRedFullOp _ scaleval valbackconv ptrconv fred (RS.A (RG.A sh (OI.T strides offset vec)))    | null sh = vec VS.! offset  -- 0D array has one element    | any (<= 0) sh = 0 @@ -309,12 +308,12 @@ vectorDotprodInnerOp :: forall a b n. (Num a, Storable a)                       => SNat n                       -> (a -> b)                       -> (Ptr a -> Ptr b) -                     -> (SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a)  -- ^ elementwise multiplication +                     -> (SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a)  -- ^ elementwise multiplication                       -> (Int64 -> Ptr b -> b -> Ptr b -> IO ())  -- ^ scale by constant                       -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ reduction kernel                       -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ dotprod kernel -                     -> RS.Array (n + 1) a -> RS.Array (n + 1) a -> RS.Array n a -vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner +                     -> RS.Array (GHCFromNat (S n)) a -> RS.Array (GHCFromNat (S n)) a -> RS.Array (GHCFromNat n) a +vectorDotprodInnerOp sn@SNat' valconv ptrconv fmul fscale fred fdotinner      arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1)))      arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2)))    | null sh1 || null sh2 = error "unreachable" @@ -344,7 +343,7 @@ vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner            fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv)                      pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1))                      pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2)) -      RS.fromVector @_ @n (init sh1) <$> VS.unsafeFreeze outv +      RS.fromVector @_ @(GHCFromNat n) (init sh1) <$> VS.unsafeFreeze outv  {-# NOINLINE dotScalarVector #-}  dotScalarVector :: forall a b. (Num a, Storable a) @@ -398,7 +397,10 @@ $(fmap concat . forM typesList $ \arithtype -> do            c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))            c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))        sequence [SigD name <$> -                     [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] +                     [t| forall n. SNat n +                                -> RS.Array (GHCFromNat n) $ttyp +                                -> RS.Array (GHCFromNat n) $ttyp +                                -> RS.Array (GHCFromNat n) $ttyp |]                 ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]                     return $ FunD name [Clause [] (NormalB body) []]]) @@ -412,7 +414,10 @@ $(fmap concat . forM floatTypesList $ \arithtype -> do            c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))            c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))        sequence [SigD name <$> -                     [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] +                     [t| forall n. SNat n +                                -> RS.Array (GHCFromNat n) $ttyp +                                -> RS.Array (GHCFromNat n) $ttyp +                                -> RS.Array (GHCFromNat n) $ttyp |]                 ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]                     return $ FunD name [Clause [] (NormalB body) []]]) @@ -422,7 +427,7 @@ $(fmap concat . forM typesList $ \arithtype -> do        let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype))            c_op = varE (mkName ("c_unary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (auoEnum arithop)))        sequence [SigD name <$> -                     [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] +                     [t| forall n. SNat n -> RS.Array (GHCFromNat n) $ttyp -> RS.Array (GHCFromNat n) $ttyp |]                 ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |]                     return $ FunD name [Clause [] (NormalB body) []]]) @@ -432,7 +437,7 @@ $(fmap concat . forM floatTypesList $ \arithtype -> do        let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype))            c_op = varE (mkName ("c_funary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (afuoEnum arithop)))        sequence [SigD name <$> -                     [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] +                     [t| forall n. SNat n -> RS.Array (GHCFromNat n) $ttyp -> RS.Array (GHCFromNat n) $ttyp |]                 ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |]                     return $ FunD name [Clause [] (NormalB body) []]]) @@ -451,11 +456,11 @@ $(fmap concat . forM typesList $ \arithtype -> do            c_opfull = varE (mkName ("c_reducefull_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop)))            c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))        sequence [SigD name1 <$> -                     [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |] +                     [t| forall n. SNat n -> RS.Array (GHCFromNat (S n)) $ttyp -> RS.Array (GHCFromNat n) $ttyp |]                 ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op1 |]                     return $ FunD name1 [Clause [] (NormalB body) []]                 ,SigD namefull <$> -                     [t| forall n. SNat n -> RS.Array n $ttyp -> $ttyp |] +                     [t| forall n. SNat n -> RS.Array (GHCFromNat n) $ttyp -> $ttyp |]                 ,do body <- [| \sn -> vectorRedFullOp sn $scaleVar id id $c_opfull |]                     return $ FunD namefull [Clause [] (NormalB body) []]                 ]) @@ -478,7 +483,7 @@ $(fmap concat . forM typesList $ \arithtype -> do          c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))          c_red_op = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum RO_SUM)))      sequence [SigD name <$> -                   [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |] +                   [t| forall n. SNat n -> RS.Array (GHCFromNat (S n)) $ttyp -> RS.Array (GHCFromNat (S n)) $ttyp -> RS.Array (GHCFromNat n) $ttyp |]               ,do body <- [| \sn -> vectorDotprodInnerOp sn id id $mul_op $c_scale_op $c_red_op $c_op |]                   return $ FunD name [Clause [] (NormalB body) []]]) @@ -487,7 +492,7 @@ $(fmap concat . forM typesList $ \arithtype -> do  intWidBranch1 :: forall i n. (FiniteBits i, Storable i)                => (Int64 -> Ptr Int32 -> Ptr Int32 -> IO ())                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -              -> (SNat n -> RS.Array n i -> RS.Array n i) +              -> (SNat n -> RS.Array (GHCFromNat n) i -> RS.Array (GHCFromNat n) i)  intWidBranch1 f32 f64 sn    | finiteBitSize (undefined :: i) == 32 = liftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32)    | finiteBitSize (undefined :: i) == 64 = liftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64) @@ -503,7 +508,7 @@ intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i)                -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ())  -- sv                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> IO ())  -- vs                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- vv -              -> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i) +              -> (SNat n -> RS.Array (GHCFromNat n) i -> RS.Array (GHCFromNat n) i -> RS.Array (GHCFromNat n) i)  intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn    | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32)    | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64) @@ -516,7 +521,7 @@ intWidBranchRed1 :: forall i n. (FiniteBits i, Storable i, Integral i)                      -- int64                   -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ())  -- ^ scale by constant                   -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- ^ reduction kernel -                 -> (SNat n -> RS.Array (n + 1) i -> RS.Array n i) +                 -> (SNat n -> RS.Array (GHCFromNat (S n)) i -> RS.Array (GHCFromNat n) i)  intWidBranchRed1 fsc32 fred32 fsc64 fred64 sn    | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32    | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64 @@ -528,7 +533,7 @@ intWidBranchRedFull :: forall i n. (FiniteBits i, Storable i, Integral i)                      -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO Int32)  -- ^ reduction kernel                         -- int64                      -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO Int64)  -- ^ reduction kernel -                    -> (SNat n -> RS.Array n i -> i) +                    -> (SNat n -> RS.Array (GHCFromNat n) i -> i)  intWidBranchRedFull fsc fred32 fred64 sn    | finiteBitSize (undefined :: i) == 32 = vectorRedFullOp @i @Int32 sn fsc fromIntegral castPtr fred32    | finiteBitSize (undefined :: i) == 64 = vectorRedFullOp @i @Int64 sn fsc fromIntegral castPtr fred64 @@ -554,26 +559,26 @@ intWidBranchDotprod :: forall i n. (FiniteBits i, Storable i, Integral i, NumElt                      -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ())  -- ^ scale by constant                      -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- ^ reduction kernel                      -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- ^ dotprod kernel -                    -> (SNat n -> RS.Array (n + 1) i -> RS.Array (n + 1) i -> RS.Array n i) +                    -> (SNat n -> RS.Array (GHCFromNat (S n)) i -> RS.Array (GHCFromNat (S n)) i -> RS.Array (GHCFromNat n) i)  intWidBranchDotprod fsc32 fred32 fdot32 fsc64 fred64 fdot64 sn    | finiteBitSize (undefined :: i) == 32 = vectorDotprodInnerOp @i @Int32 sn fromIntegral castPtr numEltMul fsc32 fred32 fdot32    | finiteBitSize (undefined :: i) == 64 = vectorDotprodInnerOp @i @Int64 sn fromIntegral castPtr numEltMul fsc64 fred64 fdot64    | otherwise = error "Unsupported Int width"  class NumElt a where -  numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a -  numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a -  numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a -  numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a -  numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a -  numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a -  numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a -  numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a -  numEltSumFull :: SNat n -> RS.Array n a -> a -  numEltProductFull :: SNat n -> RS.Array n a -> a -  numEltMinIndex :: SNat n -> RS.Array n a -> [Int] -  numEltMaxIndex :: SNat n -> RS.Array n a -> [Int] -  numEltDotprodInner :: SNat n -> RS.Array (n + 1) a -> RS.Array (n + 1) a -> RS.Array n a +  numEltAdd :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a +  numEltSub :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a +  numEltMul :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a +  numEltNeg :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a +  numEltAbs :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a +  numEltSignum :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a +  numEltSum1Inner :: SNat n -> RS.Array (GHCFromNat (S n)) a -> RS.Array (GHCFromNat n) a +  numEltProduct1Inner :: SNat n -> RS.Array (GHCFromNat (S n)) a -> RS.Array (GHCFromNat n) a +  numEltSumFull :: SNat n -> RS.Array (GHCFromNat n) a -> a +  numEltProductFull :: SNat n -> RS.Array (GHCFromNat n) a -> a +  numEltMinIndex :: SNat n -> RS.Array (GHCFromNat n) a -> [Int] +  numEltMaxIndex :: SNat n -> RS.Array (GHCFromNat n) a -> [Int] +  numEltDotprodInner :: SNat n -> RS.Array (GHCFromNat (S n)) a -> RS.Array (GHCFromNat (S n)) a -> RS.Array (GHCFromNat n) a  instance NumElt Int32 where    numEltAdd = addVectorInt32 @@ -688,29 +693,29 @@ instance NumElt CInt where                                                   (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64  class FloatElt a where -  floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a -  floatEltPow :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a -  floatEltLogbase :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a -  floatEltRecip :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltExp :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltLog :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltSqrt :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltSin :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltCos :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltTan :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltAsin :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltAcos :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltAtan :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltSinh :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltCosh :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltTanh :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltAsinh :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltAcosh :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltAtanh :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltLog1p :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltExpm1 :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltLog1pexp :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltLog1mexp :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltDiv :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a +  floatEltPow :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a +  floatEltLogbase :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a +  floatEltRecip :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a +  floatEltExp :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a +  floatEltLog :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a +  floatEltSqrt :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a +  floatEltSin :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a +  floatEltCos :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a +  floatEltTan :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a +  floatEltAsin :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a +  floatEltAcos :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a +  floatEltAtan :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a +  floatEltSinh :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a +  floatEltCosh :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a +  floatEltTanh :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a +  floatEltAsinh :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a +  floatEltAcosh :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a +  floatEltAtanh :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a +  floatEltLog1p :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a +  floatEltExpm1 :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a +  floatEltLog1pexp :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a +  floatEltLog1mexp :: SNat n -> RS.Array (GHCFromNat n) a -> RS.Array (GHCFromNat n) a  instance FloatElt Float where    floatEltDiv = divVectorFloat diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs index 331d5e0..e85e67f 100644 --- a/src/Data/Array/Mixed/Permutation.hs +++ b/src/Data/Array/Mixed/Permutation.hs @@ -13,8 +13,6 @@  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-}  {-# LANGUAGE UndecidableInstances #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}  module Data.Array.Mixed.Permutation where  import Data.Coerce (coerce) @@ -24,13 +22,12 @@ import Data.Maybe (fromMaybe)  import Data.Proxy  import Data.Type.Bool  import Data.Type.Equality -import Data.Type.Ord  import GHC.TypeError -import GHC.TypeLits -import GHC.TypeNats qualified as TN +import Numeric.Natural  import Data.Array.Mixed.Shape  import Data.Array.Mixed.Types +import Data.SNat.Peano  -- * Permutations @@ -46,18 +43,19 @@ deriving instance Show (Perm list)  deriving instance Eq (Perm list)  permRank :: Perm list -> SNat (Rank list) -permRank PNil = SNat -permRank (_ `PCons` l) | SNat <- permRank l = SNat +permRank PNil = SZ +permRank (_ `PCons` l) = SS (permRank l)  permFromList :: [Int] -> (forall list. Perm list -> r) -> r  permFromList [] k = k PNil -permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case -  Just sn -> permFromList xs $ \list -> k (sn `PCons` list) -  Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x +permFromList (x : xs) k = +  withSomeSNat' x $ \sn -> +    permFromList xs $ \list -> +      k (sn `PCons` list)  permToList :: Perm list -> [Natural]  permToList PNil = mempty -permToList (x `PCons` l) = TN.fromSNat x : permToList l +permToList (x `PCons` l) = fromSNat x : permToList l  permToList' :: Perm list -> [Int]  permToList' = map fromIntegral . permToList @@ -68,48 +66,47 @@ permToList' = map fromIntegral . permToList  permCheckPermutation :: forall r list. Perm list -> (IsPermutation list => r) -> Maybe r  permCheckPermutation = \p k ->    let n = permRank p -  in case (provePerm1 (Proxy @list) n p, provePerm2 (SNat @0) n p) of +  in case (provePerm1 (Proxy @list) n p, provePerm2 SZ n p) of         (Just Refl, Just Refl) -> Just k         _ -> Nothing    where -    lemElemCount :: (0 <= n, Compare n m ~ LT) -                 => proxy n -> proxy m -> Elem n (Count 0 m) :~: True +    lemElemCount :: (Z <= n, n < m) +                 => proxy n -> proxy m -> Elem n (Count Z m) :~: True      lemElemCount _ _ = unsafeCoerceRefl -    lemCount :: (OrdCond (Compare i n) True False True ~ True) -             => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n +    lemCount :: i < n => proxy i -> proxy n -> Count i n :~: i : Count (S i) n      lemCount _ _ = unsafeCoerceRefl      lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True      lemElem _ _ = unsafeCoerceRefl      provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is' -               -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True) +               -> Maybe (AllElem' is' (Count Z (Rank isTop)) :~: True)      provePerm1 _ _ PNil = Just (Refl) -    provePerm1 p rtop@SNat (PCons sn@SNat perm) +    provePerm1 p rtop (PCons sn perm)        | Just Refl <- provePerm1 p rtop perm -      = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of -          (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl -          (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl +      = case (snatCompare SZ sn, snatCompare sn rtop) of +          (SLT, SLT) | Refl <- lemElemCount sn rtop -> Just Refl +          (SEQ, SLT) | Refl <- lemElemCount sn rtop -> Just Refl            _ -> Nothing        | otherwise        = Nothing      provePerm2 :: SNat i -> SNat n -> Perm is'                 -> Maybe (AllElem' (Count i n) is' :~: True) -    provePerm2 = \i@(SNat :: SNat i) n@SNat perm -> -      case cmpNat i n of -        EQI -> Just Refl -        LTI | Refl <- lemCount i n -            , Just Refl <- provePerm2 (SNat @(i + 1)) n perm +    provePerm2 = \i n perm -> +      case snatCompare i n of +        SEQ -> Just Refl +        SLT | Refl <- lemCount i n +            , Just Refl <- provePerm2 (SS i) n perm              -> checkElem i perm              | otherwise -> Nothing -        GTI -> error "unreachable" +        SGT -> error "unreachable"        where          checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True)          checkElem _ PNil = Nothing -        checkElem i@SNat (PCons k@SNat perm :: Perm is') = -          case sameNat i k of +        checkElem i (PCons k perm :: Perm is') = +          case testEquality i k of              Just Refl -> Just Refl              Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl                      | otherwise -> Nothing @@ -117,7 +114,7 @@ permCheckPermutation = \p k ->  -- | Utility class for generating permutations from type class information.  class KnownPerm l where makePerm :: Perm l  instance KnownPerm '[] where makePerm = PNil -instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = natSing `PCons` makePerm +instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = knownNat `PCons` makePerm  -- | Untyped permutations for ranked arrays  type PermR = [Int] @@ -139,13 +136,13 @@ type AllElem as bs = Assert (AllElem' as bs)  type family Count i n where    Count n n = '[] -  Count i n = i : Count (i + 1) n +  Count i n = i : Count (S i) n -type IsPermutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as) +type IsPermutation as = (AllElem as (Count Z (Rank as)), AllElem (Count Z (Rank as)) as)  type family Index i sh where -  Index 0 (n : sh) = n -  Index i (_ : sh) = Index (i - 1) sh +  Index Z (n : sh) = n +  Index (S i) (_ : sh) = Index i sh  type family Permute is sh where    Permute '[] sh = '[] @@ -178,9 +175,7 @@ listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) =  listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> f (Index i sh)  listxIndex _ _ SZ (n ::% _) = n -listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) -  | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') -  = listxIndex p pT i sh +listxIndex p pT (SS i) (_ ::% sh) = listxIndex p pT i sh  listxIndex _ _ _ ZX = error "Index into empty shape"  listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f @@ -199,7 +194,7 @@ ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh)  ssxPermute = coerce (listxPermute @(SMayNat () SNat))  ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh) -ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2) +ssxIndex p1 p2 i = coerce (listxIndex @(SMayNat () SNat) p1 p2 i)  ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh)  ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat)) @@ -236,23 +231,23 @@ permInverse = \perm k ->        where          toHList :: [Natural] -> (forall is'. Perm is' -> r) -> r          toHList [] k = k PNil -        toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l) +        toHList (n : ns) k = toHList ns $ \l -> withSomeSNat n $ \sn -> k (PCons sn l)      provePermInverse :: Perm is -> Perm is' -> StaticShX sh                       -> Maybe (Permute is' (Permute is sh) :~: sh)      provePermInverse perm perminv ssh = -      ssxGeq (ssxPermute perminv (ssxPermute perm ssh)) ssh +      testEquality (ssxPermute perminv (ssxPermute perm ssh)) ssh  type family MapSucc is where    MapSucc '[] = '[] -  MapSucc (i : is) = i + 1 : MapSucc is +  MapSucc (i : is) = S i : MapSucc is -permShift1 :: Perm l -> Perm (0 : MapSucc l) -permShift1 = (SNat @0 `PCons`) . permMapSucc +permShift1 :: Perm l -> Perm (Z : MapSucc l) +permShift1 = (SZ `PCons`) . permMapSucc    where      permMapSucc :: Perm l -> Perm (MapSucc l)      permMapSucc PNil = PNil -    permMapSucc ((SNat :: SNat i) `PCons` ns) = SNat @(i + 1) `PCons` permMapSucc ns +    permMapSucc (i `PCons` ns) = SS i `PCons` permMapSucc ns  -- * Lemmas @@ -266,8 +261,3 @@ lemRankDropLen :: forall is sh. (Rank is <= Rank sh)  lemRankDropLen ZKX PNil = Refl  lemRankDropLen (_ :!% sh) (_ `PCons` is) | Refl <- lemRankDropLen sh is = Refl  lemRankDropLen (_ :!% _) PNil = Refl -lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0" - -lemIndexSucc :: Proxy i -> Proxy a -> Proxy l -             -> Index (i + 1) (a : l) :~: Index i l -lemIndexSucc _ _ _ = unsafeCoerceRefl diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs index e5f8b67..bed812d 100644 --- a/src/Data/Array/Mixed/Shape.hs +++ b/src/Data/Array/Mixed/Shape.hs @@ -18,8 +18,6 @@  {-# LANGUAGE TypeOperators #-}  {-# LANGUAGE UndecidableInstances #-}  {-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}  module Data.Array.Mixed.Shape where  import Control.DeepSeq (NFData(..)) @@ -35,16 +33,16 @@ import GHC.Exts (withDict)  import GHC.Generics (Generic)  import GHC.IsList (IsList)  import GHC.IsList qualified as IsList -import GHC.TypeLits  import Data.Array.Mixed.Types +import Data.SNat.Peano  -- | The length of a type-level list. If the argument is a shape, then the  -- result is the rank of that shape.  type family Rank sh where -  Rank '[] = 0 -  Rank (_ : sh) = Rank sh + 1 +  Rank '[] = Z +  Rank (_ : sh) = S (Rank sh)  -- * Mixed lists @@ -91,8 +89,8 @@ listxLength :: ListX sh f -> Int  listxLength = getSum . listxFold (\_ -> Sum 1)  listxRank :: ListX sh f -> SNat (Rank sh) -listxRank ZX = SNat -listxRank (_ ::% l) | SNat <- listxRank l = SNat +listxRank ZX = SZ +listxRank (_ ::% l) = SS (listxRank l)  listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS  listxShow f l = showString "[" . go "" l . showString "]" @@ -255,7 +253,7 @@ type family AddMaybe n m where  smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m)  smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m)  smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) -smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m) +smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatAdd n m)  -- | This is a newtype over 'ListX'. @@ -288,7 +286,7 @@ instance Functor (ShX sh) where  instance NFData i => NFData (ShX sh i) where    rnf (ShX ZX) = ()    rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l) -  rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l) +  rnf (ShX (SKnown n ::% l)) = rnf n `seq` rnf (ShX l)  shxLength :: ShX sh i -> Int  shxLength (ShX l) = listxLength l @@ -300,8 +298,8 @@ shxRank (ShX list) = listxRank list  -- dimensions) are the same.  shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh')  shxEqual ZSX ZSX = Just Refl -shxEqual (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh') -  | Just Refl <- sameNat n m +shxEqual (SKnown n :$% sh) (SKnown m :$% sh') +  | Just Refl <- testEquality n m    , Just Refl <- shxEqual sh sh'    = Just Refl  shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh') @@ -422,19 +420,6 @@ instance TestEquality StaticShX where  ssxLength :: StaticShX sh -> Int  ssxLength (StaticShX l) = listxLength l --- | This suffices as an implementation of @geq@ in the @Data.GADT.Compare@ --- class of the @some@ package. -ssxGeq :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh') -ssxGeq ZKX ZKX = Just Refl -ssxGeq (SKnown n@SNat :!% sh) (SKnown m@SNat :!% sh') -  | Just Refl <- sameNat n m -  , Just Refl <- ssxGeq sh sh' -  = Just Refl -ssxGeq (SUnknown () :!% sh) (SUnknown () :!% sh') -  | Just Refl <- ssxGeq sh sh' -  = Just Refl -ssxGeq _ _ = Nothing -  ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')  ssxAppend ZKX sh' = sh'  ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' @@ -481,7 +466,7 @@ ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1  type KnownShX :: [Maybe Nat] -> Constraint  class KnownShX sh where knownShX :: StaticShX sh  instance KnownShX '[] where knownShX = ZKX -instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX +instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown knownNat :!% knownShX  instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX  withKnownShX :: forall sh r. StaticShX sh -> (KnownShX sh => r) -> r @@ -490,7 +475,7 @@ withKnownShX sh = withDict @(KnownShX sh) sh  -- * Flattening -type Flatten sh = Flatten' 1 sh +type Flatten sh = Flatten' (S Z) sh  type family Flatten' acc sh where    Flatten' acc '[] = Just acc @@ -499,7 +484,7 @@ type family Flatten' acc sh where  -- This function is currently unused  ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh) -ssxFlatten = go (SNat @1) +ssxFlatten = go (mkSNat @1)    where      go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh)      go acc ZKX = SKnown acc @@ -507,7 +492,7 @@ ssxFlatten = go (SNat @1)      go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh  shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh) -shxFlatten = go (SNat @1) +shxFlatten = go (mkSNat @1)    where      go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh)      go acc ZSX = SKnown acc diff --git a/src/Data/Array/Mixed/Types.hs b/src/Data/Array/Mixed/Types.hs index 13675d0..319246b 100644 --- a/src/Data/Array/Mixed/Types.hs +++ b/src/Data/Array/Mixed/Types.hs @@ -1,27 +1,14 @@ +{-# LANGUAGE AllowAmbiguousTypes #-}  {-# LANGUAGE DataKinds #-}  {-# LANGUAGE GADTs #-}  {-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE NoStarIsType #-} -{-# LANGUAGE PatternSynonyms #-}  {-# LANGUAGE PolyKinds #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-}  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}  module Data.Array.Mixed.Types (    -- * Reified evidence of a type class    Dict(..), -  -- * Type-level naturals -  pattern SZ, pattern SS, -  fromSNat', sameNat', -  snatPlus, snatMinus, snatMul, -  snatSucc, -    -- * Type-level lists    type (++),    Replicate, @@ -36,67 +23,25 @@ module Data.Array.Mixed.Types (  ) where  import Data.Type.Equality -import Data.Proxy -import GHC.TypeLits -import GHC.TypeNats qualified as TN  import Unsafe.Coerce qualified +import Data.SNat.Peano +  -- | Evidence for the constraint @c a@.  data Dict c a where    Dict :: c a => Dict c a -fromSNat' :: SNat n -> Int -fromSNat' = fromIntegral . fromSNat - -sameNat' :: SNat n -> SNat m -> Maybe (n :~: m) -sameNat' n@SNat m@SNat = sameNat n m - -pattern SZ :: () => (n ~ 0) => SNat n -pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl) -  where SZ = SNat - -pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1 -pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl)) -  where SS = snatSucc - -{-# COMPLETE SZ, SS #-} - -snatSucc :: SNat n -> SNat (n + 1) -snatSucc SNat = SNat - -data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1) -snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1) -snatPred snp1 = -  withKnownNat snp1 $ -    case cmpNat (Proxy @1) (Proxy @np1) of -      LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) -      EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) -      GTI -> Nothing - --- This should be a function in base -snatPlus :: SNat n -> SNat m -> SNat (n + m) -snatPlus n m = TN.withSomeSNat (TN.fromSNat n + TN.fromSNat m) Unsafe.Coerce.unsafeCoerce - --- This should be a function in base -snatMinus :: SNat n -> SNat m -> SNat (n - m) -snatMinus n m = let res = TN.fromSNat n - TN.fromSNat m in res `seq` TN.withSomeSNat res Unsafe.Coerce.unsafeCoerce - --- This should be a function in base -snatMul :: SNat n -> SNat m -> SNat (n * m) -snatMul n m = TN.withSomeSNat (TN.fromSNat n * TN.fromSNat m) Unsafe.Coerce.unsafeCoerce - -  -- | Type-level list append.  type family l1 ++ l2 where    '[] ++ l2 = l2    (x : xs) ++ l2 = x : xs ++ l2  type family Replicate n a where -  Replicate 0 a = '[] -  Replicate n a = a : Replicate (n - 1) a +  Replicate Z a = '[] +  Replicate (S n) a = a : Replicate n a -lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a +lemReplicateSucc :: (a : Replicate n a) :~: Replicate (S n) a  lemReplicateSucc = unsafeCoerceRefl  type family MapJust l where diff --git a/src/Data/SNat/Peano.hs b/src/Data/SNat/Peano.hs new file mode 100644 index 0000000..a5109fa --- /dev/null +++ b/src/Data/SNat/Peano.hs @@ -0,0 +1,232 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeData #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UnboxedTuples #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +module Data.SNat.Peano ( +  -- * Singleton naturals +  Nat(..), +  SNat(SZ, SS), +  mkSNat, mkSNatFromGHC, NatFromGHC, +  withSomeSNat, withSomeSNat', +  fromSNat, fromSNat', + +  -- * Computing with 'SNat' values +  type (+), snatAdd, +  type (-), snatSub, +  type (*), snatMul, +  Compare, SOrdering(..), snatCompare, OrdCond, type (<=), type (<), type (>=), type (>), + +  -- * 'KnownNat' +  KnownNat(..), + +  -- * Interoperate with GHC naturals +  pattern SNat', recoverGHC, GHCFromNat, lemGHCNatGHC, lemNatGHCNat, +) where + +import Control.DeepSeq +import Data.Proxy +import Data.Type.Equality +import Numeric.Natural +import qualified GHC.TypeNats as GHC +import Unsafe.Coerce + + +-- | Type-level Peano naturals. +type data Nat = Z | S Nat + +-- | A singleton for 'Nat'. The representation is just a 'Natural', not an +-- actual Peano natural / linked list. +newtype SNat (n :: Nat) = MkSNatUnsafe Natural + +instance Show (SNat n) where +  showsPrec d (MkSNatUnsafe n) = showParen (d > 10) $ +    showString ("mkSNat @" ++ show n) + +-- these are vacuous because it's a singleton +instance Eq (SNat n) where _ == _ = True +instance Ord (SNat n) where compare _ _ = EQ + +instance NFData (SNat n) where +  rnf (MkSNatUnsafe n) = rnf n + +instance TestEquality SNat where +  testEquality (MkSNatUnsafe n) (MkSNatUnsafe m) +    | n == m = Just unsafeCoerceRefl +    | otherwise = Nothing + +-- | The zero natural. Read this as: 'SZ :: SNat Z' +pattern SZ :: forall n. () => n ~ Z => SNat n +pattern SZ <- ((\(MkSNatUnsafe k) -> (# k, unsafeCoerceRefl @n @Z #)) +               -> (# 0, Refl #)) +  where SZ = MkSNatUnsafe 0 + +-- | /n/ plus one. Read this as: 'SS :: SNat n -> SNat (S n)' +pattern SS :: forall n. () => forall predn. S predn ~ n => SNat predn -> SNat n +pattern SS n <- (mkPredecessor -> Just (Predecessor n)) +  where SS (MkSNatUnsafe n) = MkSNatUnsafe (n + 1) +-- A little experiment showed that mkPredecessor inlines sufficiently that no +-- Predecessor object every gets allocated. Let's hope this is also true in +-- more practical situations. + +{-# COMPLETE SZ, SS #-} + +data Predecessor n = forall predn. S predn ~ n => Predecessor (SNat predn) + +mkPredecessor :: forall n. SNat n -> Maybe (Predecessor n) +mkPredecessor (MkSNatUnsafe 0) = Nothing +mkPredecessor (MkSNatUnsafe k) = Just (yolo (MkSNatUnsafe (k-1))) +  where +    yolo :: forall m predn. SNat predn -> Predecessor m +    yolo n | Refl <- unsafeCoerceRefl @(S predn) @m = Predecessor n + +-- | Convert a GHC type-level 'GHC.Nat' to a type-level Peano natural. Because +-- this type family performs induction on a GHC 'GHC.Nat', which only works +-- sensibly if the 'GHC.Nat' is monomorphic, using this on a bare type variable +-- will probably be unsuccessful. +type family NatFromGHC ghcn where +  NatFromGHC 0 = Z +  NatFromGHC n = S (NatFromGHC (n GHC.- 1)) + +-- | Convenience function to create an 'SNat'. Use with @-XDataKinds@ and +-- @-XTypeApplications@ like: +-- +-- >>> mkSNat @5 +-- +-- The 'GHC.KnownNat' constraint is automatically satisfied for any statically +-- known number. To construct an 'SNat' dynamically, you probably want +-- 'mkSNatFromGHC', or perhaps iterated 'SS'. +mkSNat :: forall ghcn. GHC.KnownNat ghcn => SNat (NatFromGHC ghcn) +mkSNat = MkSNatUnsafe (GHC.natVal (Proxy @ghcn)) + +-- | Convert a GHC 'GHC.SNat' to an 'SNat'. You can convert back using +-- the 'SNat'' pattern synonym, or more manually using 'recoverGHC'. +mkSNatFromGHC :: GHC.SNat ghcn -> SNat (NatFromGHC ghcn) +mkSNatFromGHC sn@GHC.SNat = MkSNatUnsafe (GHC.natVal sn) + +-- | Dynamically create an 'SNat' from an untyped 'Natural'. +withSomeSNat :: Natural -> (forall n. SNat n -> r) -> r +withSomeSNat n k = k (MkSNatUnsafe n) + +-- | Dynamically create an 'SNat' from an untyped 'Int'. Throws an exception if +-- the argument is negative. +withSomeSNat' :: Int -> (forall n. SNat n -> r) -> r +withSomeSNat' n k +  | n < 0 = error $ "withSomeSNat': " ++ show n ++ " is negative" +  | otherwise = k (MkSNatUnsafe (fromIntegral n)) + +-- | Get the untyped 'Natural' corresponding to the 'SNat'. +fromSNat :: SNat n -> Natural +fromSNat (MkSNatUnsafe n) = n + +-- | Unsafe! If @n@ is out of range for @Int@, this will simply wrap, not throw +-- an error! +fromSNat' :: SNat n -> Int +fromSNat' (MkSNatUnsafe n) = fromIntegral n + +-- | Convert a type-level Peano natural to a GHC type-level 'GHC.Nat'. +type family GHCFromNat n where +  GHCFromNat Z = 0 +  GHCFromNat (S n) = 1 GHC.+ GHCFromNat n + +-- | Convert an 'SNat' back to a GHC 'GHC.SNat'. If you use 'recoverGHC' after +-- 'mkSNatFromGHC', you will end up with an 'GHC.SNat (GHCFromNat (NatFromGHC +-- ghcn))'; use 'lemGHCToGHC' to rewrite that back to 'GHC.SNat ghcn'. +recoverGHC :: forall n. SNat n -> GHC.SNat (GHCFromNat n) +recoverGHC (MkSNatUnsafe n) = +  GHC.withSomeSNat n $ \(ghcn :: GHC.SNat m) -> +    unsafeCoerce @(GHC.SNat m) @(GHC.SNat (GHCFromNat n)) ghcn + +-- | 'GHCFromNat' and 'NatFromGHC' are inverses (first half). +lemGHCNatGHC :: GHCFromNat (NatFromGHC ghcn) :~: ghcn +lemGHCNatGHC = unsafeCoerceRefl + +-- | 'GHCFromNat' and 'NatFromGHC' are inverses (second half). +lemNatGHCNat :: NatFromGHC (GHCFromNat n) :~: n +lemNatGHCNat = unsafeCoerceRefl + +pattern SNat' :: forall n. () => GHC.KnownNat (GHCFromNat n) => SNat n +pattern SNat' <- (recoverGHC -> GHC.SNat) +  where SNat' = case lemNatGHCNat @n of Refl -> mkSNat @(GHCFromNat n) +{-# COMPLETE SNat' #-} + +-- | Add type-level Peano naturals. +type family n + m where +  Z + m = m +  S n + m = S (n + m) + +-- | Add 'SNat's. +snatAdd :: SNat n -> SNat m -> SNat (n + m) +snatAdd (MkSNatUnsafe n) (MkSNatUnsafe m) = MkSNatUnsafe (n + m) + +-- | Subtract type-level Peano naturals. Does not reduce if the result would be negative. +type family n - m where +  n - Z = n +  S n - S m = n - m + +-- | Subtract 'SNat's. Returns 'Nothing' if the result would be negative. +snatSub :: SNat n -> SNat m -> Maybe (SNat (n - m)) +snatSub (MkSNatUnsafe n) (MkSNatUnsafe m) +  | n >= m = Just (MkSNatUnsafe (n - m)) +  | otherwise  = Nothing + +-- | Multiply type-level Peano naturals. +type family n * m where +  Z * m = Z +  S n * m = m + n * m + +-- | Multiply 'SNat's. +snatMul :: SNat n -> SNat m -> SNat (n * m) +snatMul (MkSNatUnsafe n) (MkSNatUnsafe m) = MkSNatUnsafe (n * m) + +type family Compare n m where +  Compare Z Z = EQ +  Compare Z (S m) = LT +  Compare (S n) Z = GT +  Compare (S n) (S m) = Compare n m + +data SOrdering n m where +  SLT :: Compare n m ~ LT => SOrdering n m +  SEQ :: Compare n n ~ EQ => SOrdering n n +  SGT :: Compare n m ~ GT => SOrdering n m + +snatCompare :: SNat n -> SNat m -> SOrdering n m +snatCompare (MkSNatUnsafe @n n) (MkSNatUnsafe @m m) = case compare n m of +  LT | Refl <- unsafeCoerceRefl @(Compare n m) @LT -> +    SLT +  EQ | Refl <- unsafeCoerceRefl @n @m +     , Refl <- unsafeCoerceRefl @(Compare n n) @EQ -> +    SEQ +  GT | Refl <- unsafeCoerceRefl @(Compare n m) @GT -> +    SGT + +type family OrdCond ord lt eq gt where +  OrdCond LT lt eq gt = lt +  OrdCond EQ lt eq gt = eq +  OrdCond GT lt eq gt = gt + +type n <= m = OrdCond (Compare n m) True True False ~ True +type n < m = Compare n m ~ LT +type n >= m = OrdCond (Compare n m) False True True ~ True +type n > m = Compare n m ~ GT + +-- | Pass an 'SNat' implicitly, in a constraint. +class KnownNat n where knownNat :: SNat n +instance KnownNat Z where knownNat = SZ +instance KnownNat n => KnownNat (S n) where knownNat = SS knownNat + +unsafeCoerceRefl :: forall a b. a :~: b +unsafeCoerceRefl = unsafeCoerce Refl  | 
