diff options
Diffstat (limited to 'src/Fancy.hs')
| -rw-r--r-- | src/Fancy.hs | 37 | 
1 files changed, 13 insertions, 24 deletions
| diff --git a/src/Fancy.hs b/src/Fancy.hs index 41272f0..8019393 100644 --- a/src/Fancy.hs +++ b/src/Fancy.hs @@ -1,5 +1,4 @@  {-# LANGUAGE DataKinds #-} -{-# LANGUAGE FlexibleInstances #-}  {-# LANGUAGE GADTs #-}  {-# LANGUAGE InstanceSigs #-}  {-# LANGUAGE PolyKinds #-} @@ -8,9 +7,6 @@  {-# LANGUAGE TypeApplications #-}  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} --- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}  module Fancy where  import Control.Monad (forM_) @@ -21,16 +17,15 @@ import Data.Type.Equality  import Data.Type.Ord  import qualified Data.Vector.Unboxed as VU  import qualified Data.Vector.Unboxed.Mutable as VUM -import GHC.TypeLits -import Unsafe.Coerce (unsafeCoerce)  import Array (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++))  import qualified Array as X +import Nats  type family Replicate n a where -  Replicate 0 a = '[] -  Replicate n a = a : Replicate (n - 1) a +  Replicate Z a = '[] +  Replicate (S n) a = a : Replicate n a  type family MapJust l where    MapJust '[] = '[] @@ -39,18 +34,12 @@ type family MapJust l where  lemCompareFalse1 :: (0 < n, 1 > n) => Proxy n -> a  lemCompareFalse1 = error "Incoherence" -lemKnownReplicate :: forall n. KnownNat n => Proxy n -> X.Dict KnownShapeX (Replicate n Nothing) -lemKnownReplicate _ = X.lemKnownShapeX (go (natSing @n)) +lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing) +lemKnownReplicate _ = X.lemKnownShapeX (go (knownNat @n))    where -    go :: forall m. SNat m -> StaticShapeX (Replicate m Nothing) -    go SNat = case cmpNat (Proxy @1) (Proxy @m) of -                LTI | Refl <- (unsafeCoerce Refl :: Nothing : Replicate (m - 1) Nothing :~: Replicate m Nothing) -> () :$? go (SNat @(m - 1)) -                EQI -> () :$? SZX -                GTI -> case cmpNat (Proxy @0) (Proxy @m) of -                         LTI -> lemCompareFalse1 (Proxy @m) -                         EQI -> SZX -                         GTI -> error "0 > natural" -    go _ = error "COMPLETE" +    go :: SNat m -> StaticShapeX (Replicate m Nothing) +    go SZ = SZX +    go (SS n) = () :$? go n  type Mixed :: [Maybe Nat] -> Type -> Type @@ -177,7 +166,7 @@ instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where    -- TODO: this is quadratic in the nesting level    mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh    mshape (M_Nest arr) -    | X.Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') +    | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh')      = ixAppPrefix (knownShapeX @sh) (mshape arr)      where        ixAppPrefix :: StaticShapeX sh1 -> IxX (sh1 ++ sh') -> IxX sh1 @@ -213,7 +202,7 @@ instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where                      -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)                      -> ST s ()    mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs) -    | X.Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) +    | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))      , Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')      = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.ixAppend sh12 sh') idx arr vecs @@ -251,13 +240,13 @@ newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixe  instance (KnownNat n, GMixed a) => GMixed (Ranked n a) where -  mshape (M_Ranked arr) | X.Dict <- lemKnownReplicate (Proxy @n) = mshape arr +  mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr  type IxR :: Nat -> Type  data IxR n where -  IZR :: IxR 0 -  (:::) :: Int -> IxR n -> IxR (n + 1) +  IZR :: IxR Z +  (:::) :: Int -> IxR n -> IxR (S n)  type IxS :: [Nat] -> Type  data IxS sh where | 
