diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-03-27 10:51:25 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-03-27 10:51:25 +0100 | 
| commit | 95f48df1b97529311a41245bbaaf4781b5ffaa4b (patch) | |
| tree | 8a5bf9024ec132f702bfd6f025b06d3a5054c2dd | |
| parent | 9c98118e0a0ff9be463bc9e7afe4253a4de3d433 (diff) | |
GHC typenats are bad
| -rw-r--r-- | ox-arrays.cabal | 1 | ||||
| -rw-r--r-- | src/Fancy.hs | 34 | 
2 files changed, 33 insertions, 2 deletions
| diff --git a/ox-arrays.cabal b/ox-arrays.cabal index aea0a94..0c74972 100644 --- a/ox-arrays.cabal +++ b/ox-arrays.cabal @@ -12,6 +12,7 @@ library    build-depends:      base >=4.18,      ghc-typelits-knownnat, +    ghc-typelits-natnormalise,      orthotope,      vector    hs-source-dirs: src diff --git a/src/Fancy.hs b/src/Fancy.hs index 821073e..41272f0 100644 --- a/src/Fancy.hs +++ b/src/Fancy.hs @@ -1,14 +1,16 @@  {-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleInstances #-}  {-# LANGUAGE GADTs #-}  {-# LANGUAGE InstanceSigs #-}  {-# LANGUAGE PolyKinds #-}  {-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-}  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-}  {-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE FlexibleInstances #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}  module Fancy where  import Control.Monad (forM_) @@ -16,9 +18,11 @@ import Control.Monad.ST  import Data.Kind  import Data.Proxy  import Data.Type.Equality +import Data.Type.Ord  import qualified Data.Vector.Unboxed as VU  import qualified Data.Vector.Unboxed.Mutable as VUM  import GHC.TypeLits +import Unsafe.Coerce (unsafeCoerce)  import Array (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++))  import qualified Array as X @@ -32,6 +36,22 @@ type family MapJust l where    MapJust '[] = '[]    MapJust (x : xs) = Just x : MapJust xs +lemCompareFalse1 :: (0 < n, 1 > n) => Proxy n -> a +lemCompareFalse1 = error "Incoherence" + +lemKnownReplicate :: forall n. KnownNat n => Proxy n -> X.Dict KnownShapeX (Replicate n Nothing) +lemKnownReplicate _ = X.lemKnownShapeX (go (natSing @n)) +  where +    go :: forall m. SNat m -> StaticShapeX (Replicate m Nothing) +    go SNat = case cmpNat (Proxy @1) (Proxy @m) of +                LTI | Refl <- (unsafeCoerce Refl :: Nothing : Replicate (m - 1) Nothing :~: Replicate m Nothing) -> () :$? go (SNat @(m - 1)) +                EQI -> () :$? SZX +                GTI -> case cmpNat (Proxy @0) (Proxy @m) of +                         LTI -> lemCompareFalse1 (Proxy @m) +                         EQI -> SZX +                         GTI -> error "0 > natural" +    go _ = error "COMPLETE" +  type Mixed :: [Maybe Nat] -> Type -> Type  data family Mixed sh a @@ -223,6 +243,16 @@ newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)  type Shaped :: [Nat] -> Type -> Type  newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) +newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a)) +newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a)) + +newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a)) +newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a)) + + +instance (KnownNat n, GMixed a) => GMixed (Ranked n a) where +  mshape (M_Ranked arr) | X.Dict <- lemKnownReplicate (Proxy @n) = mshape arr +  type IxR :: Nat -> Type  data IxR n where | 
