aboutsummaryrefslogtreecommitdiff
path: root/src/Fancy.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-03-27 21:13:31 +0100
committerTom Smeding <tom@tomsmeding.com>2024-03-27 21:13:31 +0100
commitcb31e179971293c519a530d8ce8ccc004458b1c4 (patch)
treea760f9ca2ea4048f1410a2b24500560e35f8ab19 /src/Fancy.hs
parent95f48df1b97529311a41245bbaaf4781b5ffaa4b (diff)
Nats
Diffstat (limited to 'src/Fancy.hs')
-rw-r--r--src/Fancy.hs37
1 files changed, 13 insertions, 24 deletions
diff --git a/src/Fancy.hs b/src/Fancy.hs
index 41272f0..8019393 100644
--- a/src/Fancy.hs
+++ b/src/Fancy.hs
@@ -1,5 +1,4 @@
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE PolyKinds #-}
@@ -8,9 +7,6 @@
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
--- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
module Fancy where
import Control.Monad (forM_)
@@ -21,16 +17,15 @@ import Data.Type.Equality
import Data.Type.Ord
import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Unboxed.Mutable as VUM
-import GHC.TypeLits
-import Unsafe.Coerce (unsafeCoerce)
import Array (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++))
import qualified Array as X
+import Nats
type family Replicate n a where
- Replicate 0 a = '[]
- Replicate n a = a : Replicate (n - 1) a
+ Replicate Z a = '[]
+ Replicate (S n) a = a : Replicate n a
type family MapJust l where
MapJust '[] = '[]
@@ -39,18 +34,12 @@ type family MapJust l where
lemCompareFalse1 :: (0 < n, 1 > n) => Proxy n -> a
lemCompareFalse1 = error "Incoherence"
-lemKnownReplicate :: forall n. KnownNat n => Proxy n -> X.Dict KnownShapeX (Replicate n Nothing)
-lemKnownReplicate _ = X.lemKnownShapeX (go (natSing @n))
+lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing)
+lemKnownReplicate _ = X.lemKnownShapeX (go (knownNat @n))
where
- go :: forall m. SNat m -> StaticShapeX (Replicate m Nothing)
- go SNat = case cmpNat (Proxy @1) (Proxy @m) of
- LTI | Refl <- (unsafeCoerce Refl :: Nothing : Replicate (m - 1) Nothing :~: Replicate m Nothing) -> () :$? go (SNat @(m - 1))
- EQI -> () :$? SZX
- GTI -> case cmpNat (Proxy @0) (Proxy @m) of
- LTI -> lemCompareFalse1 (Proxy @m)
- EQI -> SZX
- GTI -> error "0 > natural"
- go _ = error "COMPLETE"
+ go :: SNat m -> StaticShapeX (Replicate m Nothing)
+ go SZ = SZX
+ go (SS n) = () :$? go n
type Mixed :: [Maybe Nat] -> Type -> Type
@@ -177,7 +166,7 @@ instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where
-- TODO: this is quadratic in the nesting level
mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh
mshape (M_Nest arr)
- | X.Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh')
+ | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh')
= ixAppPrefix (knownShapeX @sh) (mshape arr)
where
ixAppPrefix :: StaticShapeX sh1 -> IxX (sh1 ++ sh') -> IxX sh1
@@ -213,7 +202,7 @@ instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where
-> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
-> ST s ()
mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs)
- | X.Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
+ | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
, Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
= mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.ixAppend sh12 sh') idx arr vecs
@@ -251,13 +240,13 @@ newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixe
instance (KnownNat n, GMixed a) => GMixed (Ranked n a) where
- mshape (M_Ranked arr) | X.Dict <- lemKnownReplicate (Proxy @n) = mshape arr
+ mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr
type IxR :: Nat -> Type
data IxR n where
- IZR :: IxR 0
- (:::) :: Int -> IxR n -> IxR (n + 1)
+ IZR :: IxR Z
+ (:::) :: Int -> IxR n -> IxR (S n)
type IxS :: [Nat] -> Type
data IxS sh where