aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Internal.hs')
-rw-r--r--src/Data/Array/Nested/Internal.hs221
1 files changed, 145 insertions, 76 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index d041aff..54b567a 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -20,16 +20,51 @@
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+{-# OPTIONS_GHC -Wno-unused-imports #-}
{-|
TODO:
* We should be more consistent in whether functions take a 'StaticShX'
argument or a 'KnownShapeX' constraint.
-* Document the choice of using 'INat' for ranks and 'Nat' for shapes. Point
- being that we need to do induction over the former, but the latter need to be
- able to get large.
+* Mikolaj wants these:
+
+ About your wishlist of operations: these are already there
+
+ OR.index
+ OR.append
+ OR.transpose
+
+ These can be easily lifted from the definition for XArray (5min work):
+
+ OR.scalar
+ OR.unScalar
+ OR.constant
+
+ These should not be hard:
+
+ OR.fromList
+ ORB.toList . OR.unravel
+ OR.ravel . ORB.fromList
+ OR.slice
+ OR.rev
+ OR.reshape
+
+ though it's a bit unfortunate that we end up needing toList. Looking in
+ horde-ad I see that you seem to need them to do certain operations in Haskell
+ that orthotope doesn't support?
+
+ For this one we'll need to see to what extent you really need it, and what API
+ you'd need precisely:
+
+ OR.rerank
+
+ and for these we have an API design question:
+
+ OR.toVector
+ OR.fromVector
-}
@@ -51,10 +86,10 @@ import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as VSM
import Foreign.Storable (Storable)
import GHC.TypeLits
+import Unsafe.Coerce (unsafeCoerce)
-import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat)
+import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat, Dict(..))
import qualified Data.Array.Mixed as X
-import Data.INat
-- Invariant in the API
@@ -91,34 +126,71 @@ import Data.INat
type family Replicate n a where
- Replicate Z a = '[]
- Replicate (S n) a = a : Replicate n a
+ Replicate 0 a = '[]
+ Replicate n a = a : Replicate (n - 1) a
type family MapJust l where
MapJust '[] = '[]
MapJust (x : xs) = Just x : MapJust xs
-lemKnownReplicate :: forall n. KnownINat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing)
-lemKnownReplicate _ = X.lemKnownShapeX (go (inatSing @n))
+pattern SZ :: () => (n ~ 0) => SNat n
+pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl)
+ where SZ = SNat
+
+pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1
+pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl))
+ where SS = snatSucc
+
+{-# COMPLETE SZ, SS #-}
+
+snatSucc :: SNat n -> SNat (n + 1)
+snatSucc SNat = SNat
+
+data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1)
+snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1)
+snatPred snp1 =
+ withKnownNat snp1 $
+ case cmpNat (Proxy @1) (Proxy @np1) of
+ LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
+ EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
+ GTI -> Nothing
+
+subst1 :: forall f a b. a :~: b -> f a :~: f b
+subst1 Refl = Refl
+
+subst2 :: forall f c a b. a :~: b -> f a c :~: f b c
+subst2 Refl = Refl
+
+lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
+lemReplicateSucc = unsafeCoerce Refl
+
+lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing)
+lemKnownReplicate _ = X.lemKnownShapeX (go (natSing @n))
where
- go :: SINat m -> StaticShX (Replicate m Nothing)
+ go :: SNat m -> StaticShX (Replicate m Nothing)
go SZ = ZKSX
- go (SS n) = () :!$? go n
+ go (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = () :!$? go n
-lemRankReplicate :: forall n. KnownINat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n
-lemRankReplicate _ = go (inatSing @n)
+lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n
+lemRankReplicate _ = go (natSing @n)
where
- go :: SINat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m
+ go :: forall m. SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m
go SZ = Refl
- go (SS n) | Refl <- go n = Refl
-
-lemReplicatePlusApp :: forall n m a. KnownINat n => Proxy n -> Proxy m -> Proxy a
- -> Replicate (n +! m) a :~: Replicate n a ++ Replicate m a
-lemReplicatePlusApp _ _ _ = go (inatSing @n)
+ go (SS (n :: SNat nm1))
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1
+ , Refl <- go n
+ = Refl
+
+lemReplicatePlusApp :: forall n m a. KnownNat n => Proxy n -> Proxy m -> Proxy a
+ -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a
+lemReplicatePlusApp _ _ _ = go (natSing @n)
where
- go :: SINat n' -> Replicate (n' +! m) a :~: Replicate n' a ++ Replicate m a
+ go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a
go SZ = Refl
- go (SS n) | Refl <- go n = Refl
+ go (SS (n :: SNat n'm1))
+ | Refl <- lemReplicateSucc @a @n'm1
+ , Refl <- go n
+ = sym (lemReplicateSucc @a @(n'm1 + m))
shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh')
shAppSplit _ ZKSX idx = (ZSX, idx)
@@ -575,18 +647,15 @@ deriving via Mixed sh (Primitive Double) instance KnownShapeX sh => Num (Mixed s
-- | A rank-typed array: the number of dimensions of the array (its /rank/) is
--- represented on the type level as a 'INat'.
+-- represented on the type level as a 'Nat'.
--
-- Valid elements of a ranked arrays are described by the 'Elt' type class.
-- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are
-- supported (and are represented as a single, flattened, struct-of-arrays
-- array internally).
--
--- Note that this 'INat' is not a "GHC.TypeLits" natural, because we want a
--- type-level natural that supports induction.
---
-- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's.
-type Ranked :: INat -> Type -> Type
+type Ranked :: Nat -> Type -> Type
newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a)
@@ -616,7 +685,7 @@ newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixe
-- 'Ranked' and 'Shaped' can already be used at the top level of an array nest;
-- these instances allow them to also be used as elements of arrays, thus
-- making them first-class in the API.
-instance (Elt a, KnownINat n) => Elt (Ranked n a) where
+instance (Elt a, KnownNat n) => Elt (Ranked n a) where
mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr
mindex (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i)
@@ -848,37 +917,37 @@ rewriteMixed Refl x = x
-- ====== API OF RANKED ARRAYS ====== --
-arithPromoteRanked :: forall n a. KnownINat n
+arithPromoteRanked :: forall n a. KnownNat n
=> (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a)
-> Ranked n a -> Ranked n a
arithPromoteRanked | Dict <- lemKnownReplicate (Proxy @n) = coerce
-arithPromoteRanked2 :: forall n a. KnownINat n
+arithPromoteRanked2 :: forall n a. KnownNat n
=> (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a -> Mixed sh a)
-> Ranked n a -> Ranked n a -> Ranked n a
arithPromoteRanked2 | Dict <- lemKnownReplicate (Proxy @n) = coerce
-instance (KnownINat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where
+instance (KnownNat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where
(+) = arithPromoteRanked2 (+)
(-) = arithPromoteRanked2 (-)
(*) = arithPromoteRanked2 (*)
negate = arithPromoteRanked negate
abs = arithPromoteRanked abs
signum = arithPromoteRanked signum
- fromInteger n = case inatSing @n of
+ fromInteger n = case natSing @n of
SZ -> Ranked (M_Primitive (X.scalar (fromInteger n)))
- SS _ -> error "Data.Array.Nested.fromIntegral(Ranked): \
- \Rank non-zero, use explicit mconstant"
+ _ -> error "Data.Array.Nested.fromIntegral(Ranked): \
+ \Rank non-zero, use explicit mconstant"
-- [PRIMITIVE ELEMENT TYPES LIST] (really, a partial list of just the numeric types)
-deriving via Ranked n (Primitive Int) instance KnownINat n => Num (Ranked n Int)
-deriving via Ranked n (Primitive Double) instance KnownINat n => Num (Ranked n Double)
+deriving via Ranked n (Primitive Int) instance KnownNat n => Num (Ranked n Int)
+deriving via Ranked n (Primitive Double) instance KnownNat n => Num (Ranked n Double)
type role ListR nominal representational
-type ListR :: INat -> Type -> Type
+type ListR :: Nat -> Type -> Type
data ListR n i where
- ZR :: ListR Z i
- (:::) :: forall n {i}. i -> ListR n i -> ListR (S n) i
+ ZR :: ListR 0 i
+ (:::) :: forall n {i}. i -> ListR n i -> ListR (1 + n) i
deriving instance Show i => Show (ListR n i)
deriving instance Eq i => Eq (ListR n i)
deriving instance Ord i => Ord (ListR n i)
@@ -892,23 +961,23 @@ listRToList :: ListR n i -> [i]
listRToList ZR = []
listRToList (i ::: is) = i : listRToList is
-knownListR :: ListR n i -> Dict KnownINat n
+knownListR :: ListR n i -> Dict KnownNat n
knownListR ZR = Dict
knownListR (_ ::: l) | Dict <- knownListR l = Dict
-- | An index into a rank-typed array.
type role IxR nominal representational
-type IxR :: INat -> Type -> Type
+type IxR :: Nat -> Type -> Type
newtype IxR n i = IxR (ListR n i)
deriving (Show, Eq, Ord)
deriving newtype (Functor, Foldable)
-pattern ZIR :: forall n i. () => n ~ Z => IxR n i
+pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
pattern ZIR = IxR ZR
pattern (:.:)
:: forall {n1} {i}.
- forall n. (S n ~ n1)
+ forall n. (1 + n ~ n1)
=> i -> IxR n i -> IxR n1 i
pattern i :.: sh <- (unconsIxR -> Just (UnconsIxRRes sh i))
where i :.: IxR sh = IxR (i ::: sh)
@@ -916,30 +985,30 @@ pattern i :.: sh <- (unconsIxR -> Just (UnconsIxRRes sh i))
infixr 3 :.:
data UnconsIxRRes i n1 =
- forall n. ((S n) ~ n1) => UnconsIxRRes (IxR n i) i
+ forall n. (1 + n ~ n1) => UnconsIxRRes (IxR n i) i
unconsIxR :: IxR n1 i -> Maybe (UnconsIxRRes i n1)
unconsIxR (IxR (i ::: sh')) = Just (UnconsIxRRes (IxR sh') i)
unconsIxR (IxR ZR) = Nothing
type IIxR n = IxR n Int
-knownIxR :: IxR n i -> Dict KnownINat n
+knownIxR :: IxR n i -> Dict KnownNat n
knownIxR (IxR sh) = knownListR sh
type role ShR nominal representational
-type ShR :: INat -> Type -> Type
+type ShR :: Nat -> Type -> Type
newtype ShR n i = ShR (ListR n i)
deriving (Show, Eq, Ord)
deriving newtype (Functor, Foldable)
type IShR n = ShR n Int
-pattern ZSR :: forall n i. () => n ~ Z => ShR n i
+pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
pattern ZSR = ShR ZR
pattern (:$:)
:: forall {n1} {i}.
- forall n. (S n ~ n1)
+ forall n. (1 + n ~ n1)
=> i -> ShR n i -> ShR n1 i
pattern i :$: sh <- (unconsShR -> Just (UnconsShRRes sh i))
where i :$: (ShR sh) = ShR (i ::: sh)
@@ -947,15 +1016,15 @@ pattern i :$: sh <- (unconsShR -> Just (UnconsShRRes sh i))
infixr 3 :$:
data UnconsShRRes i n1 =
- forall n. S n ~ n1 => UnconsShRRes (ShR n i) i
+ forall n. 1 + n ~ n1 => UnconsShRRes (ShR n i) i
unconsShR :: ShR n1 i -> Maybe (UnconsShRRes i n1)
unconsShR (ShR (i ::: sh')) = Just (UnconsShRRes (ShR sh') i)
unconsShR (ShR ZR) = Nothing
-knownShR :: ShR n i -> Dict KnownINat n
+knownShR :: ShR n i -> Dict KnownNat n
knownShR (ShR sh) = knownListR sh
-zeroIxR :: SINat n -> IIxR n
+zeroIxR :: SNat n -> IIxR n
zeroIxR SZ = ZIR
zeroIxR (SS n) = 0 :.: zeroIxR n
@@ -982,7 +1051,7 @@ shapeSizeR ZSR = 1
shapeSizeR (n :$: sh) = n * shapeSizeR sh
-rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IShR n
+rshape :: forall n a. (KnownNat n, Elt a) => Ranked n a -> IShR n
rshape (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
, Refl <- lemRankReplicate (Proxy @n)
@@ -991,7 +1060,7 @@ rshape (Ranked arr)
rindex :: Elt a => Ranked n a -> IIxR n -> a
rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)
-rindexPartial :: forall n m a. (KnownINat n, Elt a) => Ranked (n +! m) a -> IIxR n -> Ranked m a
+rindexPartial :: forall n m a. (KnownNat n, Elt a) => Ranked (n + m) a -> IIxR n -> Ranked m a
rindexPartial (Ranked arr) idx =
Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
(rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr)
@@ -1007,7 +1076,7 @@ rgenerate sh f
= Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR))
-- | See the documentation of 'mlift'.
-rlift :: forall n1 n2 a. (KnownINat n2, Elt a)
+rlift :: forall n1 n2 a. (KnownNat n2, Elt a)
=> (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
-> Ranked n1 a -> Ranked n2 a
rlift f (Ranked arr)
@@ -1015,39 +1084,39 @@ rlift f (Ranked arr)
= Ranked (mlift f arr)
rsumOuter1P :: forall n a.
- (Storable a, Num a, KnownINat n)
- => Ranked (S n) (Primitive a) -> Ranked n (Primitive a)
+ (Storable a, Num a, KnownNat n, 1 <= n)
+ => Ranked n (Primitive a) -> Ranked (n - 1) (Primitive a)
rsumOuter1P (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
= Ranked
- . coerce @(XArray (Replicate n 'Nothing) a) @(Mixed (Replicate n 'Nothing) (Primitive a))
- . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate n Nothing))
- . coerce @(Mixed (Replicate (S n) Nothing) (Primitive a)) @(XArray (Replicate (S n) Nothing) a)
+ . coerce @(XArray (Replicate (n - 1) 'Nothing) a) @(Mixed (Replicate (n - 1) 'Nothing) (Primitive a))
+ . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate (n - 1) Nothing))
+ . coerce @(Mixed (Replicate n Nothing) (Primitive a)) @(XArray (Replicate n Nothing) a)
$ arr
rsumOuter1 :: forall n a.
- (Storable a, Num a, PrimElt a, KnownINat n)
- => Ranked (S n) a -> Ranked n a
+ (Storable a, Num a, PrimElt a, KnownNat n, 1 <= n)
+ => Ranked n a -> Ranked (n - 1) a
rsumOuter1 = coerce fromPrimitive . rsumOuter1P @n @a . coerce toPrimitive
-rtranspose :: forall n a. (KnownINat n, Elt a) => [Int] -> Ranked n a -> Ranked n a
+rtranspose :: forall n a. (KnownNat n, Elt a) => [Int] -> Ranked n a -> Ranked n a
rtranspose perm (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
= Ranked (mtranspose perm arr)
-rappend :: forall n a. (KnownINat n, Elt a)
- => Ranked (S n) a -> Ranked (S n) a -> Ranked (S n) a
+rappend :: forall n a. (KnownNat n, Elt a, 1 <= n)
+ => Ranked n a -> Ranked n a -> Ranked n a
rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend
-rscalar :: Elt a => a -> Ranked I0 a
+rscalar :: Elt a => a -> Ranked 0 a
rscalar x = Ranked (mscalar x)
-rfromVectorP :: forall n a. (KnownINat n, Storable a) => IShR n -> VS.Vector a -> Ranked n (Primitive a)
+rfromVectorP :: forall n a. (KnownNat n, Storable a) => IShR n -> VS.Vector a -> Ranked n (Primitive a)
rfromVectorP sh v
| Dict <- lemKnownReplicate (Proxy @n)
= Ranked (mfromVectorP (shCvtRX sh) v)
-rfromVector :: forall n a. (KnownINat n, Storable a, PrimElt a) => IShR n -> VS.Vector a -> Ranked n a
+rfromVector :: forall n a. (KnownNat n, Storable a, PrimElt a) => IShR n -> VS.Vector a -> Ranked n a
rfromVector sh v = coerce fromPrimitive (rfromVectorP sh v)
rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a
@@ -1056,39 +1125,39 @@ rtoVectorP = coerce mtoVectorP
rtoVector :: (Storable a, PrimElt a) => Ranked n a -> VS.Vector a
rtoVector = coerce mtoVector
-rfromList1 :: forall n a. (KnownINat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (S n) a
+rfromList1 :: forall n a. (KnownNat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (1 + n) a
rfromList1 l
| Dict <- lemKnownReplicate (Proxy @n)
= Ranked (mfromList1 (coerce l))
-rfromList :: Elt a => NonEmpty a -> Ranked I1 a
+rfromList :: Elt a => NonEmpty a -> Ranked 1 a
rfromList = Ranked . mfromList1 . fmap mscalar
-rtoList :: Elt a => Ranked (S n) a -> [Ranked n a]
+rtoList :: Elt a => Ranked (1 + n) a -> [Ranked n a]
rtoList (Ranked arr) = coerce (mtoList1 arr)
-rtoList1 :: Elt a => Ranked I1 a -> [a]
+rtoList1 :: Elt a => Ranked 1 a -> [a]
rtoList1 = map runScalar . rtoList
-runScalar :: Elt a => Ranked I0 a -> a
+runScalar :: Elt a => Ranked 0 a -> a
runScalar arr = rindex arr ZIR
-rconstantP :: forall n a. (KnownINat n, Storable a) => IShR n -> a -> Ranked n (Primitive a)
+rconstantP :: forall n a. (KnownNat n, Storable a) => IShR n -> a -> Ranked n (Primitive a)
rconstantP sh x
| Dict <- lemKnownReplicate (Proxy @n)
= Ranked (mconstantP (shCvtRX sh) x)
-rconstant :: forall n a. (KnownINat n, Storable a, PrimElt a)
+rconstant :: forall n a. (KnownNat n, Storable a, PrimElt a)
=> IShR n -> a -> Ranked n a
rconstant sh x = coerce fromPrimitive (rconstantP sh x)
-rslice :: (KnownINat n, Elt a) => [(Int, Int)] -> Ranked n a -> Ranked n a
+rslice :: (KnownNat n, Elt a) => [(Int, Int)] -> Ranked n a -> Ranked n a
rslice ivs = rlift $ \_ -> X.slice ivs
-rrev1 :: (KnownINat n, Elt a) => Ranked (S n) a -> Ranked (S n) a
+rrev1 :: (KnownNat n, Elt a, 1 <= n) => Ranked n a -> Ranked n a
rrev1 = rlift $ \_ -> X.rev1
-rreshape :: forall n n' a. (KnownINat n, KnownINat n', Elt a)
+rreshape :: forall n n' a. (KnownNat n, KnownNat n', Elt a)
=> IShR n' -> Ranked n a -> Ranked n' a
rreshape sh' (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)