diff options
-rw-r--r-- | ox-arrays.cabal | 2 | ||||
-rw-r--r-- | src/Array.hs | 36 | ||||
-rw-r--r-- | src/Fancy.hs | 37 | ||||
-rw-r--r-- | src/Nats.hs | 54 |
4 files changed, 87 insertions, 42 deletions
diff --git a/ox-arrays.cabal b/ox-arrays.cabal index 0c74972..8193d8d 100644 --- a/ox-arrays.cabal +++ b/ox-arrays.cabal @@ -9,10 +9,10 @@ library exposed-modules: Array Fancy + Nats build-depends: base >=4.18, ghc-typelits-knownnat, - ghc-typelits-natnormalise, orthotope, vector hs-source-dirs: src diff --git a/src/Array.hs b/src/Array.hs index 25db19e..5140eaf 100644 --- a/src/Array.hs +++ b/src/Array.hs @@ -8,8 +8,6 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module Array where import qualified Data.Array.RankedU as U @@ -17,12 +15,10 @@ import Data.Kind import Data.Proxy import Data.Type.Equality import qualified Data.Vector.Unboxed as VU -import GHC.TypeLits import Unsafe.Coerce (unsafeCoerce) +import Nats -data Dict c a where - Dict :: c a => Dict c a type family l1 ++ l2 where '[] ++ l2 = l2 @@ -54,16 +50,16 @@ class KnownShapeX sh where instance KnownShapeX '[] where knownShapeX = SZX instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where - knownShapeX = natSing @n :$@ knownShapeX + knownShapeX = knownNat :$@ knownShapeX instance KnownShapeX sh => KnownShapeX (Nothing : sh) where knownShapeX = () :$? knownShapeX type family Rank sh where - Rank '[] = 0 - Rank (_ : sh) = 1 + Rank sh + Rank '[] = Z + Rank (_ : sh) = S (Rank sh) type XArray :: [Maybe Nat] -> Type -> Type -data XArray sh a = XArray (U.Array (Rank sh) a) +data XArray sh a = XArray (U.Array (GNat (Rank sh)) a) zeroIdx :: StaticShapeX sh -> IxX sh zeroIdx SZX = IZX @@ -150,27 +146,33 @@ lemKnownNatRank (_ ::? sh) | Dict <- lemKnownNatRank sh = Dict lemKnownShapeX :: StaticShapeX sh -> Dict KnownShapeX sh lemKnownShapeX SZX = Dict -lemKnownShapeX (SNat :$@ ssh) | Dict <- lemKnownShapeX ssh = Dict +lemKnownShapeX (n :$@ ssh) | Dict <- lemKnownShapeX ssh, Dict <- snatKnown n = Dict lemKnownShapeX (() :$? ssh) | Dict <- lemKnownShapeX ssh = Dict -lemKnownShapeX (_ :$@ _) = error "SNat does not have a COMPLETE pragma" lemAppKnownShapeX :: StaticShapeX sh1 -> StaticShapeX sh2 -> Dict KnownShapeX (sh1 ++ sh2) lemAppKnownShapeX SZX ssh' = lemKnownShapeX ssh' -lemAppKnownShapeX (SNat :$@ ssh) ssh' | Dict <- lemAppKnownShapeX ssh ssh' = Dict -lemAppKnownShapeX (() :$? ssh) ssh' | Dict <- lemAppKnownShapeX ssh ssh' = Dict -lemAppKnownShapeX (_ :$@ _) _ = error "SNat does not have a COMPLETE pragma" +lemAppKnownShapeX (n :$@ ssh) ssh' + | Dict <- lemAppKnownShapeX ssh ssh' + , Dict <- snatKnown n + = Dict +lemAppKnownShapeX (() :$? ssh) ssh' + | Dict <- lemAppKnownShapeX ssh ssh' + = Dict shape :: forall sh a. KnownShapeX sh => XArray sh a -> IxX sh shape (XArray arr) = go (knownShapeX @sh) (U.shapeL arr) where go :: StaticShapeX sh' -> [Int] -> IxX sh' go SZX [] = IZX - go (n :$@ ssh) (_ : l) = fromIntegral (fromSNat n) ::@ go ssh l + go (n :$@ ssh) (_ : l) = fromIntegral (unSNat n) ::@ go ssh l go (() :$? ssh) (n : l) = n ::? go ssh l go _ _ = error "Invalid shapeL" -fromVector :: U.Unbox a => IxX sh -> VU.Vector a -> XArray sh a -fromVector sh v | Dict <- lemKnownNatRank sh = XArray (U.fromVector (shapeLshape sh) v) +fromVector :: forall sh a. U.Unbox a => IxX sh -> VU.Vector a -> XArray sh a +fromVector sh v + | Dict <- lemKnownNatRank sh + , Dict <- gknownNat (Proxy @(Rank sh)) + = XArray (U.fromVector (shapeLshape sh) v) toVector :: U.Unbox a => XArray sh a -> VU.Vector a toVector (XArray arr) = U.toVector arr diff --git a/src/Fancy.hs b/src/Fancy.hs index 41272f0..8019393 100644 --- a/src/Fancy.hs +++ b/src/Fancy.hs @@ -1,5 +1,4 @@ {-# LANGUAGE DataKinds #-} -{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE PolyKinds #-} @@ -8,9 +7,6 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} --- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} module Fancy where import Control.Monad (forM_) @@ -21,16 +17,15 @@ import Data.Type.Equality import Data.Type.Ord import qualified Data.Vector.Unboxed as VU import qualified Data.Vector.Unboxed.Mutable as VUM -import GHC.TypeLits -import Unsafe.Coerce (unsafeCoerce) import Array (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++)) import qualified Array as X +import Nats type family Replicate n a where - Replicate 0 a = '[] - Replicate n a = a : Replicate (n - 1) a + Replicate Z a = '[] + Replicate (S n) a = a : Replicate n a type family MapJust l where MapJust '[] = '[] @@ -39,18 +34,12 @@ type family MapJust l where lemCompareFalse1 :: (0 < n, 1 > n) => Proxy n -> a lemCompareFalse1 = error "Incoherence" -lemKnownReplicate :: forall n. KnownNat n => Proxy n -> X.Dict KnownShapeX (Replicate n Nothing) -lemKnownReplicate _ = X.lemKnownShapeX (go (natSing @n)) +lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing) +lemKnownReplicate _ = X.lemKnownShapeX (go (knownNat @n)) where - go :: forall m. SNat m -> StaticShapeX (Replicate m Nothing) - go SNat = case cmpNat (Proxy @1) (Proxy @m) of - LTI | Refl <- (unsafeCoerce Refl :: Nothing : Replicate (m - 1) Nothing :~: Replicate m Nothing) -> () :$? go (SNat @(m - 1)) - EQI -> () :$? SZX - GTI -> case cmpNat (Proxy @0) (Proxy @m) of - LTI -> lemCompareFalse1 (Proxy @m) - EQI -> SZX - GTI -> error "0 > natural" - go _ = error "COMPLETE" + go :: SNat m -> StaticShapeX (Replicate m Nothing) + go SZ = SZX + go (SS n) = () :$? go n type Mixed :: [Maybe Nat] -> Type -> Type @@ -177,7 +166,7 @@ instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where -- TODO: this is quadratic in the nesting level mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh mshape (M_Nest arr) - | X.Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') + | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') = ixAppPrefix (knownShapeX @sh) (mshape arr) where ixAppPrefix :: StaticShapeX sh1 -> IxX (sh1 ++ sh') -> IxX sh1 @@ -213,7 +202,7 @@ instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) -> ST s () mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs) - | X.Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) + | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) , Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.ixAppend sh12 sh') idx arr vecs @@ -251,13 +240,13 @@ newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixe instance (KnownNat n, GMixed a) => GMixed (Ranked n a) where - mshape (M_Ranked arr) | X.Dict <- lemKnownReplicate (Proxy @n) = mshape arr + mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr type IxR :: Nat -> Type data IxR n where - IZR :: IxR 0 - (:::) :: Int -> IxR n -> IxR (n + 1) + IZR :: IxR Z + (:::) :: Int -> IxR n -> IxR (S n) type IxS :: [Nat] -> Type data IxS sh where diff --git a/src/Nats.hs b/src/Nats.hs new file mode 100644 index 0000000..a9ad47c --- /dev/null +++ b/src/Nats.hs @@ -0,0 +1,54 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Nats where + +import Data.Proxy +import Numeric.Natural +import qualified GHC.TypeLits as G + + +data Dict c a where + Dict :: c a => Dict c a + +data Nat = Z | S Nat + deriving (Show) + +data SNat n where + SZ :: SNat Z + SS :: SNat n -> SNat (S n) +deriving instance Show (SNat n) + +class KnownNat n where knownNat :: SNat n +instance KnownNat Z where knownNat = SZ +instance KnownNat n => KnownNat (S n) where knownNat = SS knownNat + +unSNat :: SNat n -> Natural +unSNat SZ = 0 +unSNat (SS n) = 1 + unSNat n + +unNat :: Nat -> Natural +unNat Z = 0 +unNat (S n) = 1 + unNat n + +snatKnown :: SNat n -> Dict KnownNat n +snatKnown SZ = Dict +snatKnown (SS n) | Dict <- snatKnown n = Dict + +type family GNat n where + GNat Z = 0 + GNat (S n) = 1 G.+ GNat n + +gknownNat :: KnownNat n => Proxy n -> Dict G.KnownNat (GNat n) +gknownNat (Proxy @n) = go (knownNat @n) + where + go :: SNat m -> Dict G.KnownNat (GNat m) + go SZ = Dict + go (SS n) | Dict <- go n = Dict |