aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-14 23:30:53 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-14 23:30:53 +0200
commit43ddff2e7f1e9f4d8855f573384e26b63d34f697 (patch)
tree86b6989d4a1b935fd6f6338b4699d8e5c0083a2c /src/Data/Array
parent77ab86ede90938fa43f7f9988ac10f7026440a1c (diff)
WIP GHC nats
Diffstat (limited to 'src/Data/Array')
-rw-r--r--src/Data/Array/Mixed.hs93
-rw-r--r--src/Data/Array/Nested/Internal.hs221
2 files changed, 184 insertions, 130 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index d782e9f..672b832 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -13,6 +13,7 @@
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed where
@@ -27,8 +28,10 @@ import Foreign.Storable (Storable)
import GHC.TypeLits
import Unsafe.Coerce (unsafeCoerce)
-import Data.INat
+-- | Evidence for the constraint @c a@.
+data Dict c a where
+ Dict :: c a => Dict c a
-- | The 'SNat' pattern synonym is complete, but it doesn't have a
-- @COMPLETE@ pragma. This copy of it does.
@@ -103,11 +106,11 @@ instance KnownShapeX sh => KnownShapeX (Nothing : sh) where
knownShapeX = () :!$? knownShapeX
type family Rank sh where
- Rank '[] = Z
- Rank (_ : sh) = S (Rank sh)
+ Rank '[] = 0
+ Rank (_ : sh) = 1 + Rank sh
type XArray :: [Maybe Nat] -> Type -> Type
-newtype XArray sh a = XArray (S.Array (FromINat (Rank sh)) a)
+newtype XArray sh a = XArray (S.Array (Rank sh) a)
deriving (Show)
zeroIxX :: StaticShX sh -> IIxX sh
@@ -217,22 +220,22 @@ staticShapeFrom (n :$@ sh) = n :!$@ staticShapeFrom sh
staticShapeFrom (_ :$? sh) = () :!$? staticShapeFrom sh
lemRankApp :: StaticShX sh1 -> StaticShX sh2
- -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank sh1) + FromINat (Rank sh2)
+ -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2
lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this
lemRankAppComm :: StaticShX sh1 -> StaticShX sh2
- -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank (sh2 ++ sh1))
+ -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1)
lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this
-lemKnownINatRank :: IShX sh -> Dict KnownINat (Rank sh)
-lemKnownINatRank ZSX = Dict
-lemKnownINatRank (_ :$@ sh) | Dict <- lemKnownINatRank sh = Dict
-lemKnownINatRank (_ :$? sh) | Dict <- lemKnownINatRank sh = Dict
+lemKnownNatRank :: IShX sh -> Dict KnownNat (Rank sh)
+lemKnownNatRank ZSX = Dict
+lemKnownNatRank (_ :$@ sh) | Dict <- lemKnownNatRank sh = Dict
+lemKnownNatRank (_ :$? sh) | Dict <- lemKnownNatRank sh = Dict
-lemKnownINatRankSSX :: StaticShX sh -> Dict KnownINat (Rank sh)
-lemKnownINatRankSSX ZKSX = Dict
-lemKnownINatRankSSX (_ :!$@ ssh) | Dict <- lemKnownINatRankSSX ssh = Dict
-lemKnownINatRankSSX (_ :!$? ssh) | Dict <- lemKnownINatRankSSX ssh = Dict
+lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh)
+lemKnownNatRankSSX ZKSX = Dict
+lemKnownNatRankSSX (_ :!$@ ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
+lemKnownNatRankSSX (_ :!$? ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
lemKnownShapeX :: StaticShX sh -> Dict KnownShapeX sh
lemKnownShapeX ZKSX = Dict
@@ -259,8 +262,7 @@ shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr)
fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a
fromVector sh v
- | Dict <- lemKnownINatRank sh
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
+ | Dict <- lemKnownNatRank sh
= XArray (S.fromVector (shapeLshape sh) v)
toVector :: Storable a => XArray sh a -> VS.Vector a
@@ -274,15 +276,14 @@ unScalar (XArray a) = S.unScalar a
constant :: forall sh a. Storable a => IShX sh -> a -> XArray sh a
constant sh x
- | Dict <- lemKnownINatRank sh
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
+ | Dict <- lemKnownNatRank sh
= XArray (S.constant (shapeLshape sh) x)
generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a
generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh)
-- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a)
--- generateM sh f | Dict <- lemKnownINatRank sh =
+-- generateM sh f | Dict <- lemKnownNatRank sh =
-- XArray . S.fromVector (shapeLshape sh)
-- <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh)
@@ -305,8 +306,7 @@ type family AddMaybe n m where
append :: forall n m sh a. (KnownShapeX sh, Storable a)
=> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a
append (XArray a) (XArray b)
- | Dict <- lemKnownINatRankSSX (knownShapeX @sh)
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
+ | Dict <- lemKnownNatRankSSX (knownShapeX @sh)
= XArray (S.append a b)
rerank :: forall sh sh1 sh2 a b.
@@ -315,15 +315,12 @@ rerank :: forall sh sh1 sh2 a b.
-> (XArray sh1 a -> XArray sh2 b)
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b
rerank ssh ssh1 ssh2 f (XArray arr)
- | Dict <- lemKnownINatRankSSX ssh
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
- , Dict <- lemKnownINatRankSSX ssh2
- , Dict <- knownNatFromINat (Proxy @(Rank sh2))
+ | Dict <- lemKnownNatRankSSX ssh
+ , Dict <- lemKnownNatRankSSX ssh2
, Refl <- lemRankApp ssh ssh1
, Refl <- lemRankApp ssh ssh2
- , Dict <- lemKnownINatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the
- , Dict <- knownNatFromINat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough
- = XArray (S.rerank @(FromINat (Rank sh)) @(FromINat (Rank sh1)) @(FromINat (Rank sh2))
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- should be redundant but the solver is not clever enough
+ = XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2)
(\a -> unXArray (f (XArray a)))
arr)
where
@@ -342,15 +339,12 @@ rerank2 :: forall sh sh1 sh2 a b c.
-> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c)
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c
rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2)
- | Dict <- lemKnownINatRankSSX ssh
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
- , Dict <- lemKnownINatRankSSX ssh2
- , Dict <- knownNatFromINat (Proxy @(Rank sh2))
+ | Dict <- lemKnownNatRankSSX ssh
+ , Dict <- lemKnownNatRankSSX ssh2
, Refl <- lemRankApp ssh ssh1
, Refl <- lemRankApp ssh ssh2
- , Dict <- lemKnownINatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the
- , Dict <- knownNatFromINat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough
- = XArray (S.rerank2 @(FromINat (Rank sh)) @(FromINat (Rank sh1)) @(FromINat (Rank sh2))
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the
+ = XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2)
(\a b -> unXArray (f (XArray a) (XArray b)))
arr1 arr2)
where
@@ -359,8 +353,7 @@ rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2)
-- | The list argument gives indices into the original dimension list.
transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a
transpose perm (XArray arr)
- | Dict <- lemKnownINatRankSSX (knownShapeX @sh)
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
+ | Dict <- lemKnownNatRankSSX (knownShapeX @sh)
= XArray (S.transpose perm arr)
transpose2 :: forall sh1 sh2 a.
@@ -369,10 +362,8 @@ transpose2 :: forall sh1 sh2 a.
transpose2 ssh1 ssh2 (XArray arr)
| Refl <- lemRankApp ssh1 ssh2
, Refl <- lemRankApp ssh2 ssh1
- , Dict <- lemKnownINatRankSSX (ssxAppend ssh1 ssh2)
- , Dict <- knownNatFromINat (Proxy @(Rank (sh1 ++ sh2)))
- , Dict <- lemKnownINatRankSSX (ssxAppend ssh2 ssh1)
- , Dict <- knownNatFromINat (Proxy @(Rank (sh2 ++ sh1)))
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2)
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1)
, Refl <- lemRankAppComm ssh1 ssh2
, let n1 = ssxLength ssh1
= XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr)
@@ -395,13 +386,12 @@ sumOuter ssh ssh'
fromList1 :: forall n sh a. Storable a
=> StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
fromList1 ssh l
- | Dict <- lemKnownINatRankSSX ssh
- , Dict <- knownNatFromINat (Proxy @(Rank (n : sh)))
+ | Dict <- lemKnownNatRankSSX ssh
= case ssh of
m@GHC_SNat :!$@ _ | natVal m /= fromIntegral (length l) ->
error $ "Data.Array.Mixed.fromList: length of list (" ++ show (length l) ++ ")" ++
"does not match the type (" ++ show (natVal m) ++ ")"
- _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (FromINat (Rank sh)) a] l)))
+ _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l)))
toList1 :: Storable a => XArray (n : sh) a -> [XArray sh a]
toList1 (XArray arr) = coerce (ORB.toList (S.unravel arr))
@@ -409,8 +399,7 @@ toList1 (XArray arr) = coerce (ORB.toList (S.unravel arr))
-- | Throws if the given shape is not, in fact, empty.
empty :: forall sh a. Storable a => IShX sh -> XArray sh a
empty sh
- | Dict <- lemKnownINatRank sh
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
+ | Dict <- lemKnownNatRank sh
= XArray (S.constant (shapeLshape sh)
(error "Data.Array.Mixed.empty: shape was not empty"))
@@ -423,17 +412,13 @@ rev1 (XArray arr) = XArray (S.rev [0] arr)
-- | Throws if the given array and the target shape do not have the same number of elements.
reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a
reshape ssh1 sh2 (XArray arr)
- | Dict <- lemKnownINatRankSSX ssh1
- , Dict <- knownNatFromINat (Proxy @(Rank sh1))
- , Dict <- lemKnownINatRank sh2
- , Dict <- knownNatFromINat (Proxy @(Rank sh2))
+ | Dict <- lemKnownNatRankSSX ssh1
+ , Dict <- lemKnownNatRank sh2
= XArray (S.reshape (shapeLshape sh2) arr)
-- | Throws if the given array and the target shape do not have the same number of elements.
reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
reshapePartial ssh1 ssh' sh2 (XArray arr)
- | Dict <- lemKnownINatRankSSX (ssxAppend ssh1 ssh')
- , Dict <- knownNatFromINat (Proxy @(Rank (sh1 ++ sh')))
- , Dict <- lemKnownINatRankSSX (ssxAppend (staticShapeFrom sh2) ssh')
- , Dict <- knownNatFromINat (Proxy @(Rank (sh2 ++ sh')))
+ | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh')
+ , Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh2) ssh')
= XArray (S.reshape (shapeLshape sh2 ++ drop (length sh2) (S.shapeL arr)) arr)
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index d041aff..54b567a 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -20,16 +20,51 @@
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+{-# OPTIONS_GHC -Wno-unused-imports #-}
{-|
TODO:
* We should be more consistent in whether functions take a 'StaticShX'
argument or a 'KnownShapeX' constraint.
-* Document the choice of using 'INat' for ranks and 'Nat' for shapes. Point
- being that we need to do induction over the former, but the latter need to be
- able to get large.
+* Mikolaj wants these:
+
+ About your wishlist of operations: these are already there
+
+ OR.index
+ OR.append
+ OR.transpose
+
+ These can be easily lifted from the definition for XArray (5min work):
+
+ OR.scalar
+ OR.unScalar
+ OR.constant
+
+ These should not be hard:
+
+ OR.fromList
+ ORB.toList . OR.unravel
+ OR.ravel . ORB.fromList
+ OR.slice
+ OR.rev
+ OR.reshape
+
+ though it's a bit unfortunate that we end up needing toList. Looking in
+ horde-ad I see that you seem to need them to do certain operations in Haskell
+ that orthotope doesn't support?
+
+ For this one we'll need to see to what extent you really need it, and what API
+ you'd need precisely:
+
+ OR.rerank
+
+ and for these we have an API design question:
+
+ OR.toVector
+ OR.fromVector
-}
@@ -51,10 +86,10 @@ import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as VSM
import Foreign.Storable (Storable)
import GHC.TypeLits
+import Unsafe.Coerce (unsafeCoerce)
-import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat)
+import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat, Dict(..))
import qualified Data.Array.Mixed as X
-import Data.INat
-- Invariant in the API
@@ -91,34 +126,71 @@ import Data.INat
type family Replicate n a where
- Replicate Z a = '[]
- Replicate (S n) a = a : Replicate n a
+ Replicate 0 a = '[]
+ Replicate n a = a : Replicate (n - 1) a
type family MapJust l where
MapJust '[] = '[]
MapJust (x : xs) = Just x : MapJust xs
-lemKnownReplicate :: forall n. KnownINat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing)
-lemKnownReplicate _ = X.lemKnownShapeX (go (inatSing @n))
+pattern SZ :: () => (n ~ 0) => SNat n
+pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl)
+ where SZ = SNat
+
+pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1
+pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl))
+ where SS = snatSucc
+
+{-# COMPLETE SZ, SS #-}
+
+snatSucc :: SNat n -> SNat (n + 1)
+snatSucc SNat = SNat
+
+data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1)
+snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1)
+snatPred snp1 =
+ withKnownNat snp1 $
+ case cmpNat (Proxy @1) (Proxy @np1) of
+ LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
+ EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
+ GTI -> Nothing
+
+subst1 :: forall f a b. a :~: b -> f a :~: f b
+subst1 Refl = Refl
+
+subst2 :: forall f c a b. a :~: b -> f a c :~: f b c
+subst2 Refl = Refl
+
+lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
+lemReplicateSucc = unsafeCoerce Refl
+
+lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing)
+lemKnownReplicate _ = X.lemKnownShapeX (go (natSing @n))
where
- go :: SINat m -> StaticShX (Replicate m Nothing)
+ go :: SNat m -> StaticShX (Replicate m Nothing)
go SZ = ZKSX
- go (SS n) = () :!$? go n
+ go (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = () :!$? go n
-lemRankReplicate :: forall n. KnownINat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n
-lemRankReplicate _ = go (inatSing @n)
+lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n
+lemRankReplicate _ = go (natSing @n)
where
- go :: SINat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m
+ go :: forall m. SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m
go SZ = Refl
- go (SS n) | Refl <- go n = Refl
-
-lemReplicatePlusApp :: forall n m a. KnownINat n => Proxy n -> Proxy m -> Proxy a
- -> Replicate (n +! m) a :~: Replicate n a ++ Replicate m a
-lemReplicatePlusApp _ _ _ = go (inatSing @n)
+ go (SS (n :: SNat nm1))
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1
+ , Refl <- go n
+ = Refl
+
+lemReplicatePlusApp :: forall n m a. KnownNat n => Proxy n -> Proxy m -> Proxy a
+ -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a
+lemReplicatePlusApp _ _ _ = go (natSing @n)
where
- go :: SINat n' -> Replicate (n' +! m) a :~: Replicate n' a ++ Replicate m a
+ go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a
go SZ = Refl
- go (SS n) | Refl <- go n = Refl
+ go (SS (n :: SNat n'm1))
+ | Refl <- lemReplicateSucc @a @n'm1
+ , Refl <- go n
+ = sym (lemReplicateSucc @a @(n'm1 + m))
shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh')
shAppSplit _ ZKSX idx = (ZSX, idx)
@@ -575,18 +647,15 @@ deriving via Mixed sh (Primitive Double) instance KnownShapeX sh => Num (Mixed s
-- | A rank-typed array: the number of dimensions of the array (its /rank/) is
--- represented on the type level as a 'INat'.
+-- represented on the type level as a 'Nat'.
--
-- Valid elements of a ranked arrays are described by the 'Elt' type class.
-- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are
-- supported (and are represented as a single, flattened, struct-of-arrays
-- array internally).
--
--- Note that this 'INat' is not a "GHC.TypeLits" natural, because we want a
--- type-level natural that supports induction.
---
-- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's.
-type Ranked :: INat -> Type -> Type
+type Ranked :: Nat -> Type -> Type
newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a)
@@ -616,7 +685,7 @@ newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixe
-- 'Ranked' and 'Shaped' can already be used at the top level of an array nest;
-- these instances allow them to also be used as elements of arrays, thus
-- making them first-class in the API.
-instance (Elt a, KnownINat n) => Elt (Ranked n a) where
+instance (Elt a, KnownNat n) => Elt (Ranked n a) where
mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr
mindex (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i)
@@ -848,37 +917,37 @@ rewriteMixed Refl x = x
-- ====== API OF RANKED ARRAYS ====== --
-arithPromoteRanked :: forall n a. KnownINat n
+arithPromoteRanked :: forall n a. KnownNat n
=> (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a)
-> Ranked n a -> Ranked n a
arithPromoteRanked | Dict <- lemKnownReplicate (Proxy @n) = coerce
-arithPromoteRanked2 :: forall n a. KnownINat n
+arithPromoteRanked2 :: forall n a. KnownNat n
=> (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a -> Mixed sh a)
-> Ranked n a -> Ranked n a -> Ranked n a
arithPromoteRanked2 | Dict <- lemKnownReplicate (Proxy @n) = coerce
-instance (KnownINat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where
+instance (KnownNat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where
(+) = arithPromoteRanked2 (+)
(-) = arithPromoteRanked2 (-)
(*) = arithPromoteRanked2 (*)
negate = arithPromoteRanked negate
abs = arithPromoteRanked abs
signum = arithPromoteRanked signum
- fromInteger n = case inatSing @n of
+ fromInteger n = case natSing @n of
SZ -> Ranked (M_Primitive (X.scalar (fromInteger n)))
- SS _ -> error "Data.Array.Nested.fromIntegral(Ranked): \
- \Rank non-zero, use explicit mconstant"
+ _ -> error "Data.Array.Nested.fromIntegral(Ranked): \
+ \Rank non-zero, use explicit mconstant"
-- [PRIMITIVE ELEMENT TYPES LIST] (really, a partial list of just the numeric types)
-deriving via Ranked n (Primitive Int) instance KnownINat n => Num (Ranked n Int)
-deriving via Ranked n (Primitive Double) instance KnownINat n => Num (Ranked n Double)
+deriving via Ranked n (Primitive Int) instance KnownNat n => Num (Ranked n Int)
+deriving via Ranked n (Primitive Double) instance KnownNat n => Num (Ranked n Double)
type role ListR nominal representational
-type ListR :: INat -> Type -> Type
+type ListR :: Nat -> Type -> Type
data ListR n i where
- ZR :: ListR Z i
- (:::) :: forall n {i}. i -> ListR n i -> ListR (S n) i
+ ZR :: ListR 0 i
+ (:::) :: forall n {i}. i -> ListR n i -> ListR (1 + n) i
deriving instance Show i => Show (ListR n i)
deriving instance Eq i => Eq (ListR n i)
deriving instance Ord i => Ord (ListR n i)
@@ -892,23 +961,23 @@ listRToList :: ListR n i -> [i]
listRToList ZR = []
listRToList (i ::: is) = i : listRToList is
-knownListR :: ListR n i -> Dict KnownINat n
+knownListR :: ListR n i -> Dict KnownNat n
knownListR ZR = Dict
knownListR (_ ::: l) | Dict <- knownListR l = Dict
-- | An index into a rank-typed array.
type role IxR nominal representational
-type IxR :: INat -> Type -> Type
+type IxR :: Nat -> Type -> Type
newtype IxR n i = IxR (ListR n i)
deriving (Show, Eq, Ord)
deriving newtype (Functor, Foldable)
-pattern ZIR :: forall n i. () => n ~ Z => IxR n i
+pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
pattern ZIR = IxR ZR
pattern (:.:)
:: forall {n1} {i}.
- forall n. (S n ~ n1)
+ forall n. (1 + n ~ n1)
=> i -> IxR n i -> IxR n1 i
pattern i :.: sh <- (unconsIxR -> Just (UnconsIxRRes sh i))
where i :.: IxR sh = IxR (i ::: sh)
@@ -916,30 +985,30 @@ pattern i :.: sh <- (unconsIxR -> Just (UnconsIxRRes sh i))
infixr 3 :.:
data UnconsIxRRes i n1 =
- forall n. ((S n) ~ n1) => UnconsIxRRes (IxR n i) i
+ forall n. (1 + n ~ n1) => UnconsIxRRes (IxR n i) i
unconsIxR :: IxR n1 i -> Maybe (UnconsIxRRes i n1)
unconsIxR (IxR (i ::: sh')) = Just (UnconsIxRRes (IxR sh') i)
unconsIxR (IxR ZR) = Nothing
type IIxR n = IxR n Int
-knownIxR :: IxR n i -> Dict KnownINat n
+knownIxR :: IxR n i -> Dict KnownNat n
knownIxR (IxR sh) = knownListR sh
type role ShR nominal representational
-type ShR :: INat -> Type -> Type
+type ShR :: Nat -> Type -> Type
newtype ShR n i = ShR (ListR n i)
deriving (Show, Eq, Ord)
deriving newtype (Functor, Foldable)
type IShR n = ShR n Int
-pattern ZSR :: forall n i. () => n ~ Z => ShR n i
+pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
pattern ZSR = ShR ZR
pattern (:$:)
:: forall {n1} {i}.
- forall n. (S n ~ n1)
+ forall n. (1 + n ~ n1)
=> i -> ShR n i -> ShR n1 i
pattern i :$: sh <- (unconsShR -> Just (UnconsShRRes sh i))
where i :$: (ShR sh) = ShR (i ::: sh)
@@ -947,15 +1016,15 @@ pattern i :$: sh <- (unconsShR -> Just (UnconsShRRes sh i))
infixr 3 :$:
data UnconsShRRes i n1 =
- forall n. S n ~ n1 => UnconsShRRes (ShR n i) i
+ forall n. 1 + n ~ n1 => UnconsShRRes (ShR n i) i
unconsShR :: ShR n1 i -> Maybe (UnconsShRRes i n1)
unconsShR (ShR (i ::: sh')) = Just (UnconsShRRes (ShR sh') i)
unconsShR (ShR ZR) = Nothing
-knownShR :: ShR n i -> Dict KnownINat n
+knownShR :: ShR n i -> Dict KnownNat n
knownShR (ShR sh) = knownListR sh
-zeroIxR :: SINat n -> IIxR n
+zeroIxR :: SNat n -> IIxR n
zeroIxR SZ = ZIR
zeroIxR (SS n) = 0 :.: zeroIxR n
@@ -982,7 +1051,7 @@ shapeSizeR ZSR = 1
shapeSizeR (n :$: sh) = n * shapeSizeR sh
-rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IShR n
+rshape :: forall n a. (KnownNat n, Elt a) => Ranked n a -> IShR n
rshape (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
, Refl <- lemRankReplicate (Proxy @n)
@@ -991,7 +1060,7 @@ rshape (Ranked arr)
rindex :: Elt a => Ranked n a -> IIxR n -> a
rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)
-rindexPartial :: forall n m a. (KnownINat n, Elt a) => Ranked (n +! m) a -> IIxR n -> Ranked m a
+rindexPartial :: forall n m a. (KnownNat n, Elt a) => Ranked (n + m) a -> IIxR n -> Ranked m a
rindexPartial (Ranked arr) idx =
Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
(rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr)
@@ -1007,7 +1076,7 @@ rgenerate sh f
= Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR))
-- | See the documentation of 'mlift'.
-rlift :: forall n1 n2 a. (KnownINat n2, Elt a)
+rlift :: forall n1 n2 a. (KnownNat n2, Elt a)
=> (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
-> Ranked n1 a -> Ranked n2 a
rlift f (Ranked arr)
@@ -1015,39 +1084,39 @@ rlift f (Ranked arr)
= Ranked (mlift f arr)
rsumOuter1P :: forall n a.
- (Storable a, Num a, KnownINat n)
- => Ranked (S n) (Primitive a) -> Ranked n (Primitive a)
+ (Storable a, Num a, KnownNat n, 1 <= n)
+ => Ranked n (Primitive a) -> Ranked (n - 1) (Primitive a)
rsumOuter1P (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
= Ranked
- . coerce @(XArray (Replicate n 'Nothing) a) @(Mixed (Replicate n 'Nothing) (Primitive a))
- . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate n Nothing))
- . coerce @(Mixed (Replicate (S n) Nothing) (Primitive a)) @(XArray (Replicate (S n) Nothing) a)
+ . coerce @(XArray (Replicate (n - 1) 'Nothing) a) @(Mixed (Replicate (n - 1) 'Nothing) (Primitive a))
+ . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate (n - 1) Nothing))
+ . coerce @(Mixed (Replicate n Nothing) (Primitive a)) @(XArray (Replicate n Nothing) a)
$ arr
rsumOuter1 :: forall n a.
- (Storable a, Num a, PrimElt a, KnownINat n)
- => Ranked (S n) a -> Ranked n a
+ (Storable a, Num a, PrimElt a, KnownNat n, 1 <= n)
+ => Ranked n a -> Ranked (n - 1) a
rsumOuter1 = coerce fromPrimitive . rsumOuter1P @n @a . coerce toPrimitive
-rtranspose :: forall n a. (KnownINat n, Elt a) => [Int] -> Ranked n a -> Ranked n a
+rtranspose :: forall n a. (KnownNat n, Elt a) => [Int] -> Ranked n a -> Ranked n a
rtranspose perm (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
= Ranked (mtranspose perm arr)
-rappend :: forall n a. (KnownINat n, Elt a)
- => Ranked (S n) a -> Ranked (S n) a -> Ranked (S n) a
+rappend :: forall n a. (KnownNat n, Elt a, 1 <= n)
+ => Ranked n a -> Ranked n a -> Ranked n a
rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend
-rscalar :: Elt a => a -> Ranked I0 a
+rscalar :: Elt a => a -> Ranked 0 a
rscalar x = Ranked (mscalar x)
-rfromVectorP :: forall n a. (KnownINat n, Storable a) => IShR n -> VS.Vector a -> Ranked n (Primitive a)
+rfromVectorP :: forall n a. (KnownNat n, Storable a) => IShR n -> VS.Vector a -> Ranked n (Primitive a)
rfromVectorP sh v
| Dict <- lemKnownReplicate (Proxy @n)
= Ranked (mfromVectorP (shCvtRX sh) v)
-rfromVector :: forall n a. (KnownINat n, Storable a, PrimElt a) => IShR n -> VS.Vector a -> Ranked n a
+rfromVector :: forall n a. (KnownNat n, Storable a, PrimElt a) => IShR n -> VS.Vector a -> Ranked n a
rfromVector sh v = coerce fromPrimitive (rfromVectorP sh v)
rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a
@@ -1056,39 +1125,39 @@ rtoVectorP = coerce mtoVectorP
rtoVector :: (Storable a, PrimElt a) => Ranked n a -> VS.Vector a
rtoVector = coerce mtoVector
-rfromList1 :: forall n a. (KnownINat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (S n) a
+rfromList1 :: forall n a. (KnownNat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (1 + n) a
rfromList1 l
| Dict <- lemKnownReplicate (Proxy @n)
= Ranked (mfromList1 (coerce l))
-rfromList :: Elt a => NonEmpty a -> Ranked I1 a
+rfromList :: Elt a => NonEmpty a -> Ranked 1 a
rfromList = Ranked . mfromList1 . fmap mscalar
-rtoList :: Elt a => Ranked (S n) a -> [Ranked n a]
+rtoList :: Elt a => Ranked (1 + n) a -> [Ranked n a]
rtoList (Ranked arr) = coerce (mtoList1 arr)
-rtoList1 :: Elt a => Ranked I1 a -> [a]
+rtoList1 :: Elt a => Ranked 1 a -> [a]
rtoList1 = map runScalar . rtoList
-runScalar :: Elt a => Ranked I0 a -> a
+runScalar :: Elt a => Ranked 0 a -> a
runScalar arr = rindex arr ZIR
-rconstantP :: forall n a. (KnownINat n, Storable a) => IShR n -> a -> Ranked n (Primitive a)
+rconstantP :: forall n a. (KnownNat n, Storable a) => IShR n -> a -> Ranked n (Primitive a)
rconstantP sh x
| Dict <- lemKnownReplicate (Proxy @n)
= Ranked (mconstantP (shCvtRX sh) x)
-rconstant :: forall n a. (KnownINat n, Storable a, PrimElt a)
+rconstant :: forall n a. (KnownNat n, Storable a, PrimElt a)
=> IShR n -> a -> Ranked n a
rconstant sh x = coerce fromPrimitive (rconstantP sh x)
-rslice :: (KnownINat n, Elt a) => [(Int, Int)] -> Ranked n a -> Ranked n a
+rslice :: (KnownNat n, Elt a) => [(Int, Int)] -> Ranked n a -> Ranked n a
rslice ivs = rlift $ \_ -> X.slice ivs
-rrev1 :: (KnownINat n, Elt a) => Ranked (S n) a -> Ranked (S n) a
+rrev1 :: (KnownNat n, Elt a, 1 <= n) => Ranked n a -> Ranked n a
rrev1 = rlift $ \_ -> X.rev1
-rreshape :: forall n n' a. (KnownINat n, KnownINat n', Elt a)
+rreshape :: forall n n' a. (KnownNat n, KnownNat n', Elt a)
=> IShR n' -> Ranked n a -> Ranked n' a
rreshape sh' (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)