diff options
| -rw-r--r-- | ox-arrays.cabal | 2 | ||||
| -rw-r--r-- | src/Array.hs | 36 | ||||
| -rw-r--r-- | src/Fancy.hs | 37 | ||||
| -rw-r--r-- | src/Nats.hs | 54 | 
4 files changed, 87 insertions, 42 deletions
| diff --git a/ox-arrays.cabal b/ox-arrays.cabal index 0c74972..8193d8d 100644 --- a/ox-arrays.cabal +++ b/ox-arrays.cabal @@ -9,10 +9,10 @@ library    exposed-modules:      Array      Fancy +    Nats    build-depends:      base >=4.18,      ghc-typelits-knownnat, -    ghc-typelits-natnormalise,      orthotope,      vector    hs-source-dirs: src diff --git a/src/Array.hs b/src/Array.hs index 25db19e..5140eaf 100644 --- a/src/Array.hs +++ b/src/Array.hs @@ -8,8 +8,6 @@  {-# LANGUAGE TypeApplications #-}  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}  module Array where  import qualified Data.Array.RankedU as U @@ -17,12 +15,10 @@ import Data.Kind  import Data.Proxy  import Data.Type.Equality  import qualified Data.Vector.Unboxed as VU -import GHC.TypeLits  import Unsafe.Coerce (unsafeCoerce) +import Nats -data Dict c a where -  Dict :: c a => Dict c a  type family l1 ++ l2 where    '[] ++ l2 = l2 @@ -54,16 +50,16 @@ class KnownShapeX sh where  instance KnownShapeX '[] where    knownShapeX = SZX  instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where -  knownShapeX = natSing @n :$@ knownShapeX +  knownShapeX = knownNat :$@ knownShapeX  instance KnownShapeX sh => KnownShapeX (Nothing : sh) where    knownShapeX = () :$? knownShapeX  type family Rank sh where -  Rank '[] = 0 -  Rank (_ : sh) = 1 + Rank sh +  Rank '[] = Z +  Rank (_ : sh) = S (Rank sh)  type XArray :: [Maybe Nat] -> Type -> Type -data XArray sh a = XArray (U.Array (Rank sh) a) +data XArray sh a = XArray (U.Array (GNat (Rank sh)) a)  zeroIdx :: StaticShapeX sh -> IxX sh  zeroIdx SZX = IZX @@ -150,27 +146,33 @@ lemKnownNatRank (_ ::? sh) | Dict <- lemKnownNatRank sh = Dict  lemKnownShapeX :: StaticShapeX sh -> Dict KnownShapeX sh  lemKnownShapeX SZX = Dict -lemKnownShapeX (SNat :$@ ssh) | Dict <- lemKnownShapeX ssh = Dict +lemKnownShapeX (n :$@ ssh) | Dict <- lemKnownShapeX ssh, Dict <- snatKnown n = Dict  lemKnownShapeX (() :$? ssh) | Dict <- lemKnownShapeX ssh = Dict -lemKnownShapeX (_ :$@ _) = error "SNat does not have a COMPLETE pragma"  lemAppKnownShapeX :: StaticShapeX sh1 -> StaticShapeX sh2 -> Dict KnownShapeX (sh1 ++ sh2)  lemAppKnownShapeX SZX ssh' = lemKnownShapeX ssh' -lemAppKnownShapeX (SNat :$@ ssh) ssh' | Dict <- lemAppKnownShapeX ssh ssh' = Dict -lemAppKnownShapeX (() :$? ssh) ssh' | Dict <- lemAppKnownShapeX ssh ssh' = Dict -lemAppKnownShapeX (_ :$@ _) _ = error "SNat does not have a COMPLETE pragma" +lemAppKnownShapeX (n :$@ ssh) ssh' +  | Dict <- lemAppKnownShapeX ssh ssh' +  , Dict <- snatKnown n +  = Dict +lemAppKnownShapeX (() :$? ssh) ssh' +  | Dict <- lemAppKnownShapeX ssh ssh' +  = Dict  shape :: forall sh a. KnownShapeX sh => XArray sh a -> IxX sh  shape (XArray arr) = go (knownShapeX @sh) (U.shapeL arr)    where      go :: StaticShapeX sh' -> [Int] -> IxX sh'      go SZX [] = IZX -    go (n :$@ ssh) (_ : l) = fromIntegral (fromSNat n) ::@ go ssh l +    go (n :$@ ssh) (_ : l) = fromIntegral (unSNat n) ::@ go ssh l      go (() :$? ssh) (n : l) = n ::? go ssh l      go _ _ = error "Invalid shapeL" -fromVector :: U.Unbox a => IxX sh -> VU.Vector a -> XArray sh a  -fromVector sh v | Dict <- lemKnownNatRank sh = XArray (U.fromVector (shapeLshape sh) v) +fromVector :: forall sh a. U.Unbox a => IxX sh -> VU.Vector a -> XArray sh a +fromVector sh v +  | Dict <- lemKnownNatRank sh +  , Dict <- gknownNat (Proxy @(Rank sh)) +  = XArray (U.fromVector (shapeLshape sh) v)  toVector :: U.Unbox a => XArray sh a -> VU.Vector a  toVector (XArray arr) = U.toVector arr diff --git a/src/Fancy.hs b/src/Fancy.hs index 41272f0..8019393 100644 --- a/src/Fancy.hs +++ b/src/Fancy.hs @@ -1,5 +1,4 @@  {-# LANGUAGE DataKinds #-} -{-# LANGUAGE FlexibleInstances #-}  {-# LANGUAGE GADTs #-}  {-# LANGUAGE InstanceSigs #-}  {-# LANGUAGE PolyKinds #-} @@ -8,9 +7,6 @@  {-# LANGUAGE TypeApplications #-}  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} --- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}  module Fancy where  import Control.Monad (forM_) @@ -21,16 +17,15 @@ import Data.Type.Equality  import Data.Type.Ord  import qualified Data.Vector.Unboxed as VU  import qualified Data.Vector.Unboxed.Mutable as VUM -import GHC.TypeLits -import Unsafe.Coerce (unsafeCoerce)  import Array (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++))  import qualified Array as X +import Nats  type family Replicate n a where -  Replicate 0 a = '[] -  Replicate n a = a : Replicate (n - 1) a +  Replicate Z a = '[] +  Replicate (S n) a = a : Replicate n a  type family MapJust l where    MapJust '[] = '[] @@ -39,18 +34,12 @@ type family MapJust l where  lemCompareFalse1 :: (0 < n, 1 > n) => Proxy n -> a  lemCompareFalse1 = error "Incoherence" -lemKnownReplicate :: forall n. KnownNat n => Proxy n -> X.Dict KnownShapeX (Replicate n Nothing) -lemKnownReplicate _ = X.lemKnownShapeX (go (natSing @n)) +lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing) +lemKnownReplicate _ = X.lemKnownShapeX (go (knownNat @n))    where -    go :: forall m. SNat m -> StaticShapeX (Replicate m Nothing) -    go SNat = case cmpNat (Proxy @1) (Proxy @m) of -                LTI | Refl <- (unsafeCoerce Refl :: Nothing : Replicate (m - 1) Nothing :~: Replicate m Nothing) -> () :$? go (SNat @(m - 1)) -                EQI -> () :$? SZX -                GTI -> case cmpNat (Proxy @0) (Proxy @m) of -                         LTI -> lemCompareFalse1 (Proxy @m) -                         EQI -> SZX -                         GTI -> error "0 > natural" -    go _ = error "COMPLETE" +    go :: SNat m -> StaticShapeX (Replicate m Nothing) +    go SZ = SZX +    go (SS n) = () :$? go n  type Mixed :: [Maybe Nat] -> Type -> Type @@ -177,7 +166,7 @@ instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where    -- TODO: this is quadratic in the nesting level    mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh    mshape (M_Nest arr) -    | X.Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') +    | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh')      = ixAppPrefix (knownShapeX @sh) (mshape arr)      where        ixAppPrefix :: StaticShapeX sh1 -> IxX (sh1 ++ sh') -> IxX sh1 @@ -213,7 +202,7 @@ instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where                      -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)                      -> ST s ()    mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs) -    | X.Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) +    | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))      , Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')      = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.ixAppend sh12 sh') idx arr vecs @@ -251,13 +240,13 @@ newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixe  instance (KnownNat n, GMixed a) => GMixed (Ranked n a) where -  mshape (M_Ranked arr) | X.Dict <- lemKnownReplicate (Proxy @n) = mshape arr +  mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr  type IxR :: Nat -> Type  data IxR n where -  IZR :: IxR 0 -  (:::) :: Int -> IxR n -> IxR (n + 1) +  IZR :: IxR Z +  (:::) :: Int -> IxR n -> IxR (S n)  type IxS :: [Nat] -> Type  data IxS sh where diff --git a/src/Nats.hs b/src/Nats.hs new file mode 100644 index 0000000..a9ad47c --- /dev/null +++ b/src/Nats.hs @@ -0,0 +1,54 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Nats where + +import Data.Proxy +import Numeric.Natural +import qualified GHC.TypeLits as G + + +data Dict c a where +  Dict :: c a => Dict c a + +data Nat = Z | S Nat +  deriving (Show) + +data SNat n where +  SZ :: SNat Z +  SS :: SNat n -> SNat (S n) +deriving instance Show (SNat n) + +class KnownNat n where knownNat :: SNat n +instance KnownNat Z where knownNat = SZ +instance KnownNat n => KnownNat (S n) where knownNat = SS knownNat + +unSNat :: SNat n -> Natural +unSNat SZ = 0 +unSNat (SS n) = 1 + unSNat n + +unNat :: Nat -> Natural +unNat Z = 0 +unNat (S n) = 1 + unNat n + +snatKnown :: SNat n -> Dict KnownNat n +snatKnown SZ = Dict +snatKnown (SS n) | Dict <- snatKnown n = Dict + +type family GNat n where +  GNat Z = 0 +  GNat (S n) = 1 G.+ GNat n + +gknownNat :: KnownNat n => Proxy n -> Dict G.KnownNat (GNat n) +gknownNat (Proxy @n) = go (knownNat @n) +  where +    go :: SNat m -> Dict G.KnownNat (GNat m) +    go SZ = Dict +    go (SS n) | Dict <- go n = Dict | 
