diff options
Diffstat (limited to 'src/Fancy.hs')
-rw-r--r-- | src/Fancy.hs | 34 |
1 files changed, 32 insertions, 2 deletions
diff --git a/src/Fancy.hs b/src/Fancy.hs index 821073e..41272f0 100644 --- a/src/Fancy.hs +++ b/src/Fancy.hs @@ -1,14 +1,16 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE FlexibleInstances #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} module Fancy where import Control.Monad (forM_) @@ -16,9 +18,11 @@ import Control.Monad.ST import Data.Kind import Data.Proxy import Data.Type.Equality +import Data.Type.Ord import qualified Data.Vector.Unboxed as VU import qualified Data.Vector.Unboxed.Mutable as VUM import GHC.TypeLits +import Unsafe.Coerce (unsafeCoerce) import Array (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++)) import qualified Array as X @@ -32,6 +36,22 @@ type family MapJust l where MapJust '[] = '[] MapJust (x : xs) = Just x : MapJust xs +lemCompareFalse1 :: (0 < n, 1 > n) => Proxy n -> a +lemCompareFalse1 = error "Incoherence" + +lemKnownReplicate :: forall n. KnownNat n => Proxy n -> X.Dict KnownShapeX (Replicate n Nothing) +lemKnownReplicate _ = X.lemKnownShapeX (go (natSing @n)) + where + go :: forall m. SNat m -> StaticShapeX (Replicate m Nothing) + go SNat = case cmpNat (Proxy @1) (Proxy @m) of + LTI | Refl <- (unsafeCoerce Refl :: Nothing : Replicate (m - 1) Nothing :~: Replicate m Nothing) -> () :$? go (SNat @(m - 1)) + EQI -> () :$? SZX + GTI -> case cmpNat (Proxy @0) (Proxy @m) of + LTI -> lemCompareFalse1 (Proxy @m) + EQI -> SZX + GTI -> error "0 > natural" + go _ = error "COMPLETE" + type Mixed :: [Maybe Nat] -> Type -> Type data family Mixed sh a @@ -223,6 +243,16 @@ newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) type Shaped :: [Nat] -> Type -> Type newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) +newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a)) +newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a)) + +newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a)) +newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a)) + + +instance (KnownNat n, GMixed a) => GMixed (Ranked n a) where + mshape (M_Ranked arr) | X.Dict <- lemKnownReplicate (Proxy @n) = mshape arr + type IxR :: Nat -> Type data IxR n where |