summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-03-27 21:13:31 +0100
committerTom Smeding <tom@tomsmeding.com>2024-03-27 21:13:31 +0100
commitcb31e179971293c519a530d8ce8ccc004458b1c4 (patch)
treea760f9ca2ea4048f1410a2b24500560e35f8ab19 /src
parent95f48df1b97529311a41245bbaaf4781b5ffaa4b (diff)
Nats
Diffstat (limited to 'src')
-rw-r--r--src/Array.hs36
-rw-r--r--src/Fancy.hs37
-rw-r--r--src/Nats.hs54
3 files changed, 86 insertions, 41 deletions
diff --git a/src/Array.hs b/src/Array.hs
index 25db19e..5140eaf 100644
--- a/src/Array.hs
+++ b/src/Array.hs
@@ -8,8 +8,6 @@
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Array where
import qualified Data.Array.RankedU as U
@@ -17,12 +15,10 @@ import Data.Kind
import Data.Proxy
import Data.Type.Equality
import qualified Data.Vector.Unboxed as VU
-import GHC.TypeLits
import Unsafe.Coerce (unsafeCoerce)
+import Nats
-data Dict c a where
- Dict :: c a => Dict c a
type family l1 ++ l2 where
'[] ++ l2 = l2
@@ -54,16 +50,16 @@ class KnownShapeX sh where
instance KnownShapeX '[] where
knownShapeX = SZX
instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where
- knownShapeX = natSing @n :$@ knownShapeX
+ knownShapeX = knownNat :$@ knownShapeX
instance KnownShapeX sh => KnownShapeX (Nothing : sh) where
knownShapeX = () :$? knownShapeX
type family Rank sh where
- Rank '[] = 0
- Rank (_ : sh) = 1 + Rank sh
+ Rank '[] = Z
+ Rank (_ : sh) = S (Rank sh)
type XArray :: [Maybe Nat] -> Type -> Type
-data XArray sh a = XArray (U.Array (Rank sh) a)
+data XArray sh a = XArray (U.Array (GNat (Rank sh)) a)
zeroIdx :: StaticShapeX sh -> IxX sh
zeroIdx SZX = IZX
@@ -150,27 +146,33 @@ lemKnownNatRank (_ ::? sh) | Dict <- lemKnownNatRank sh = Dict
lemKnownShapeX :: StaticShapeX sh -> Dict KnownShapeX sh
lemKnownShapeX SZX = Dict
-lemKnownShapeX (SNat :$@ ssh) | Dict <- lemKnownShapeX ssh = Dict
+lemKnownShapeX (n :$@ ssh) | Dict <- lemKnownShapeX ssh, Dict <- snatKnown n = Dict
lemKnownShapeX (() :$? ssh) | Dict <- lemKnownShapeX ssh = Dict
-lemKnownShapeX (_ :$@ _) = error "SNat does not have a COMPLETE pragma"
lemAppKnownShapeX :: StaticShapeX sh1 -> StaticShapeX sh2 -> Dict KnownShapeX (sh1 ++ sh2)
lemAppKnownShapeX SZX ssh' = lemKnownShapeX ssh'
-lemAppKnownShapeX (SNat :$@ ssh) ssh' | Dict <- lemAppKnownShapeX ssh ssh' = Dict
-lemAppKnownShapeX (() :$? ssh) ssh' | Dict <- lemAppKnownShapeX ssh ssh' = Dict
-lemAppKnownShapeX (_ :$@ _) _ = error "SNat does not have a COMPLETE pragma"
+lemAppKnownShapeX (n :$@ ssh) ssh'
+ | Dict <- lemAppKnownShapeX ssh ssh'
+ , Dict <- snatKnown n
+ = Dict
+lemAppKnownShapeX (() :$? ssh) ssh'
+ | Dict <- lemAppKnownShapeX ssh ssh'
+ = Dict
shape :: forall sh a. KnownShapeX sh => XArray sh a -> IxX sh
shape (XArray arr) = go (knownShapeX @sh) (U.shapeL arr)
where
go :: StaticShapeX sh' -> [Int] -> IxX sh'
go SZX [] = IZX
- go (n :$@ ssh) (_ : l) = fromIntegral (fromSNat n) ::@ go ssh l
+ go (n :$@ ssh) (_ : l) = fromIntegral (unSNat n) ::@ go ssh l
go (() :$? ssh) (n : l) = n ::? go ssh l
go _ _ = error "Invalid shapeL"
-fromVector :: U.Unbox a => IxX sh -> VU.Vector a -> XArray sh a
-fromVector sh v | Dict <- lemKnownNatRank sh = XArray (U.fromVector (shapeLshape sh) v)
+fromVector :: forall sh a. U.Unbox a => IxX sh -> VU.Vector a -> XArray sh a
+fromVector sh v
+ | Dict <- lemKnownNatRank sh
+ , Dict <- gknownNat (Proxy @(Rank sh))
+ = XArray (U.fromVector (shapeLshape sh) v)
toVector :: U.Unbox a => XArray sh a -> VU.Vector a
toVector (XArray arr) = U.toVector arr
diff --git a/src/Fancy.hs b/src/Fancy.hs
index 41272f0..8019393 100644
--- a/src/Fancy.hs
+++ b/src/Fancy.hs
@@ -1,5 +1,4 @@
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE PolyKinds #-}
@@ -8,9 +7,6 @@
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
--- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
module Fancy where
import Control.Monad (forM_)
@@ -21,16 +17,15 @@ import Data.Type.Equality
import Data.Type.Ord
import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Unboxed.Mutable as VUM
-import GHC.TypeLits
-import Unsafe.Coerce (unsafeCoerce)
import Array (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++))
import qualified Array as X
+import Nats
type family Replicate n a where
- Replicate 0 a = '[]
- Replicate n a = a : Replicate (n - 1) a
+ Replicate Z a = '[]
+ Replicate (S n) a = a : Replicate n a
type family MapJust l where
MapJust '[] = '[]
@@ -39,18 +34,12 @@ type family MapJust l where
lemCompareFalse1 :: (0 < n, 1 > n) => Proxy n -> a
lemCompareFalse1 = error "Incoherence"
-lemKnownReplicate :: forall n. KnownNat n => Proxy n -> X.Dict KnownShapeX (Replicate n Nothing)
-lemKnownReplicate _ = X.lemKnownShapeX (go (natSing @n))
+lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing)
+lemKnownReplicate _ = X.lemKnownShapeX (go (knownNat @n))
where
- go :: forall m. SNat m -> StaticShapeX (Replicate m Nothing)
- go SNat = case cmpNat (Proxy @1) (Proxy @m) of
- LTI | Refl <- (unsafeCoerce Refl :: Nothing : Replicate (m - 1) Nothing :~: Replicate m Nothing) -> () :$? go (SNat @(m - 1))
- EQI -> () :$? SZX
- GTI -> case cmpNat (Proxy @0) (Proxy @m) of
- LTI -> lemCompareFalse1 (Proxy @m)
- EQI -> SZX
- GTI -> error "0 > natural"
- go _ = error "COMPLETE"
+ go :: SNat m -> StaticShapeX (Replicate m Nothing)
+ go SZ = SZX
+ go (SS n) = () :$? go n
type Mixed :: [Maybe Nat] -> Type -> Type
@@ -177,7 +166,7 @@ instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where
-- TODO: this is quadratic in the nesting level
mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh
mshape (M_Nest arr)
- | X.Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh')
+ | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh')
= ixAppPrefix (knownShapeX @sh) (mshape arr)
where
ixAppPrefix :: StaticShapeX sh1 -> IxX (sh1 ++ sh') -> IxX sh1
@@ -213,7 +202,7 @@ instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where
-> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
-> ST s ()
mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs)
- | X.Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
+ | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
, Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
= mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.ixAppend sh12 sh') idx arr vecs
@@ -251,13 +240,13 @@ newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixe
instance (KnownNat n, GMixed a) => GMixed (Ranked n a) where
- mshape (M_Ranked arr) | X.Dict <- lemKnownReplicate (Proxy @n) = mshape arr
+ mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr
type IxR :: Nat -> Type
data IxR n where
- IZR :: IxR 0
- (:::) :: Int -> IxR n -> IxR (n + 1)
+ IZR :: IxR Z
+ (:::) :: Int -> IxR n -> IxR (S n)
type IxS :: [Nat] -> Type
data IxS sh where
diff --git a/src/Nats.hs b/src/Nats.hs
new file mode 100644
index 0000000..a9ad47c
--- /dev/null
+++ b/src/Nats.hs
@@ -0,0 +1,54 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Nats where
+
+import Data.Proxy
+import Numeric.Natural
+import qualified GHC.TypeLits as G
+
+
+data Dict c a where
+ Dict :: c a => Dict c a
+
+data Nat = Z | S Nat
+ deriving (Show)
+
+data SNat n where
+ SZ :: SNat Z
+ SS :: SNat n -> SNat (S n)
+deriving instance Show (SNat n)
+
+class KnownNat n where knownNat :: SNat n
+instance KnownNat Z where knownNat = SZ
+instance KnownNat n => KnownNat (S n) where knownNat = SS knownNat
+
+unSNat :: SNat n -> Natural
+unSNat SZ = 0
+unSNat (SS n) = 1 + unSNat n
+
+unNat :: Nat -> Natural
+unNat Z = 0
+unNat (S n) = 1 + unNat n
+
+snatKnown :: SNat n -> Dict KnownNat n
+snatKnown SZ = Dict
+snatKnown (SS n) | Dict <- snatKnown n = Dict
+
+type family GNat n where
+ GNat Z = 0
+ GNat (S n) = 1 G.+ GNat n
+
+gknownNat :: KnownNat n => Proxy n -> Dict G.KnownNat (GNat n)
+gknownNat (Proxy @n) = go (knownNat @n)
+ where
+ go :: SNat m -> Dict G.KnownNat (GNat m)
+ go SZ = Dict
+ go (SS n) | Dict <- go n = Dict