From 43ddff2e7f1e9f4d8855f573384e26b63d34f697 Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Tue, 14 May 2024 23:30:53 +0200
Subject: WIP GHC nats

---
 src/Data/Array/Nested/Internal.hs | 221 +++++++++++++++++++++++++-------------
 1 file changed, 145 insertions(+), 76 deletions(-)

(limited to 'src/Data/Array/Nested')

diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index d041aff..54b567a 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -20,16 +20,51 @@
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE UndecidableInstances #-}
 {-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
 {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+{-# OPTIONS_GHC -Wno-unused-imports #-}
 
 {-|
 TODO:
 * We should be more consistent in whether functions take a 'StaticShX'
   argument or a 'KnownShapeX' constraint.
 
-* Document the choice of using 'INat' for ranks and 'Nat' for shapes. Point
-  being that we need to do induction over the former, but the latter need to be
-  able to get large.
+* Mikolaj wants these:
+
+    About your wishlist of operations: these are already there
+
+      OR.index
+      OR.append
+      OR.transpose
+
+    These can be easily lifted from the definition for XArray (5min work):
+
+      OR.scalar
+      OR.unScalar
+      OR.constant
+
+    These should not be hard:
+
+      OR.fromList
+      ORB.toList . OR.unravel
+      OR.ravel . ORB.fromList
+      OR.slice
+      OR.rev
+      OR.reshape
+
+    though it's a bit unfortunate that we end up needing toList. Looking in
+    horde-ad I see that you seem to need them to do certain operations in Haskell
+    that orthotope doesn't support?
+
+    For this one we'll need to see to what extent you really need it, and what API
+    you'd need precisely:
+
+      OR.rerank
+
+    and for these we have an API design question:
+
+      OR.toVector
+      OR.fromVector
 
 -}
 
@@ -51,10 +86,10 @@ import qualified Data.Vector.Storable as VS
 import qualified Data.Vector.Storable.Mutable as VSM
 import Foreign.Storable (Storable)
 import GHC.TypeLits
+import Unsafe.Coerce (unsafeCoerce)
 
-import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat)
+import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat, Dict(..))
 import qualified Data.Array.Mixed as X
-import Data.INat
 
 
 -- Invariant in the API
@@ -91,34 +126,71 @@ import Data.INat
 
 
 type family Replicate n a where
-  Replicate Z a = '[]
-  Replicate (S n) a = a : Replicate n a
+  Replicate 0 a = '[]
+  Replicate n a = a : Replicate (n - 1) a
 
 type family MapJust l where
   MapJust '[] = '[]
   MapJust (x : xs) = Just x : MapJust xs
 
-lemKnownReplicate :: forall n. KnownINat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing)
-lemKnownReplicate _ = X.lemKnownShapeX (go (inatSing @n))
+pattern SZ :: () => (n ~ 0) => SNat n
+pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl)
+  where SZ = SNat
+
+pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1
+pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl))
+  where SS = snatSucc
+
+{-# COMPLETE SZ, SS #-}
+
+snatSucc :: SNat n -> SNat (n + 1)
+snatSucc SNat = SNat
+
+data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1)
+snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1)
+snatPred snp1 =
+  withKnownNat snp1 $
+    case cmpNat (Proxy @1) (Proxy @np1) of
+      LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
+      EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
+      GTI -> Nothing
+
+subst1 :: forall f a b. a :~: b -> f a :~: f b
+subst1 Refl = Refl
+
+subst2 :: forall f c a b. a :~: b -> f a c :~: f b c
+subst2 Refl = Refl
+
+lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
+lemReplicateSucc = unsafeCoerce Refl
+
+lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing)
+lemKnownReplicate _ = X.lemKnownShapeX (go (natSing @n))
   where
-    go :: SINat m -> StaticShX (Replicate m Nothing)
+    go :: SNat m -> StaticShX (Replicate m Nothing)
     go SZ = ZKSX
-    go (SS n) = () :!$? go n
+    go (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = () :!$? go n
 
-lemRankReplicate :: forall n. KnownINat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n
-lemRankReplicate _ = go (inatSing @n)
+lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n
+lemRankReplicate _ = go (natSing @n)
   where
-    go :: SINat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m
+    go :: forall m. SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m
     go SZ = Refl
-    go (SS n) | Refl <- go n = Refl
-
-lemReplicatePlusApp :: forall n m a. KnownINat n => Proxy n -> Proxy m -> Proxy a
-                    -> Replicate (n +! m) a :~: Replicate n a ++ Replicate m a
-lemReplicatePlusApp _ _ _ = go (inatSing @n)
+    go (SS (n :: SNat nm1))
+      | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1
+      , Refl <- go n
+      = Refl
+
+lemReplicatePlusApp :: forall n m a. KnownNat n => Proxy n -> Proxy m -> Proxy a
+                    -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a
+lemReplicatePlusApp _ _ _ = go (natSing @n)
   where
-    go :: SINat n' -> Replicate (n' +! m) a :~: Replicate n' a ++ Replicate m a
+    go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a
     go SZ = Refl
-    go (SS n) | Refl <- go n = Refl
+    go (SS (n :: SNat n'm1))
+      | Refl <- lemReplicateSucc @a @n'm1
+      , Refl <- go n
+      = sym (lemReplicateSucc @a @(n'm1 + m))
 
 shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh')
 shAppSplit _ ZKSX idx = (ZSX, idx)
@@ -575,18 +647,15 @@ deriving via Mixed sh (Primitive Double) instance KnownShapeX sh => Num (Mixed s
 
 
 -- | A rank-typed array: the number of dimensions of the array (its /rank/) is
--- represented on the type level as a 'INat'.
+-- represented on the type level as a 'Nat'.
 --
 -- Valid elements of a ranked arrays are described by the 'Elt' type class.
 -- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are
 -- supported (and are represented as a single, flattened, struct-of-arrays
 -- array internally).
 --
--- Note that this 'INat' is not a "GHC.TypeLits" natural, because we want a
--- type-level natural that supports induction.
---
 -- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's.
-type Ranked :: INat -> Type -> Type
+type Ranked :: Nat -> Type -> Type
 newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
 deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a)
 
@@ -616,7 +685,7 @@ newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixe
 -- 'Ranked' and 'Shaped' can already be used at the top level of an array nest;
 -- these instances allow them to also be used as elements of arrays, thus
 -- making them first-class in the API.
-instance (Elt a, KnownINat n) => Elt (Ranked n a) where
+instance (Elt a, KnownNat n) => Elt (Ranked n a) where
   mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr
   mindex (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i)
 
@@ -848,37 +917,37 @@ rewriteMixed Refl x = x
 
 -- ====== API OF RANKED ARRAYS ====== --
 
-arithPromoteRanked :: forall n a. KnownINat n
+arithPromoteRanked :: forall n a. KnownNat n
                    => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a)
                    -> Ranked n a -> Ranked n a
 arithPromoteRanked | Dict <- lemKnownReplicate (Proxy @n) = coerce
 
-arithPromoteRanked2 :: forall n a. KnownINat n
+arithPromoteRanked2 :: forall n a. KnownNat n
                     => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a -> Mixed sh a)
                     -> Ranked n a -> Ranked n a -> Ranked n a
 arithPromoteRanked2 | Dict <- lemKnownReplicate (Proxy @n) = coerce
 
-instance (KnownINat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where
+instance (KnownNat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where
   (+) = arithPromoteRanked2 (+)
   (-) = arithPromoteRanked2 (-)
   (*) = arithPromoteRanked2 (*)
   negate = arithPromoteRanked negate
   abs = arithPromoteRanked abs
   signum = arithPromoteRanked signum
-  fromInteger n = case inatSing @n of
+  fromInteger n = case natSing @n of
                     SZ -> Ranked (M_Primitive (X.scalar (fromInteger n)))
-                    SS _ -> error "Data.Array.Nested.fromIntegral(Ranked): \
-                                  \Rank non-zero, use explicit mconstant"
+                    _ -> error "Data.Array.Nested.fromIntegral(Ranked): \
+                               \Rank non-zero, use explicit mconstant"
 
 -- [PRIMITIVE ELEMENT TYPES LIST] (really, a partial list of just the numeric types)
-deriving via Ranked n (Primitive Int) instance KnownINat n => Num (Ranked n Int)
-deriving via Ranked n (Primitive Double) instance KnownINat n => Num (Ranked n Double)
+deriving via Ranked n (Primitive Int) instance KnownNat n => Num (Ranked n Int)
+deriving via Ranked n (Primitive Double) instance KnownNat n => Num (Ranked n Double)
 
 type role ListR nominal representational
-type ListR :: INat -> Type -> Type
+type ListR :: Nat -> Type -> Type
 data ListR n i where
-  ZR :: ListR Z i
-  (:::) :: forall n {i}. i -> ListR n i -> ListR (S n) i
+  ZR :: ListR 0 i
+  (:::) :: forall n {i}. i -> ListR n i -> ListR (1 + n) i
 deriving instance Show i => Show (ListR n i)
 deriving instance Eq i => Eq (ListR n i)
 deriving instance Ord i => Ord (ListR n i)
@@ -892,23 +961,23 @@ listRToList :: ListR n i -> [i]
 listRToList ZR = []
 listRToList (i ::: is) = i : listRToList is
 
-knownListR :: ListR n i -> Dict KnownINat n
+knownListR :: ListR n i -> Dict KnownNat n
 knownListR ZR = Dict
 knownListR (_ ::: l) | Dict <- knownListR l = Dict
 
 -- | An index into a rank-typed array.
 type role IxR nominal representational
-type IxR :: INat -> Type -> Type
+type IxR :: Nat -> Type -> Type
 newtype IxR n i = IxR (ListR n i)
   deriving (Show, Eq, Ord)
   deriving newtype (Functor, Foldable)
 
-pattern ZIR :: forall n i. () => n ~ Z => IxR n i
+pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
 pattern ZIR = IxR ZR
 
 pattern (:.:)
   :: forall {n1} {i}.
-     forall n. (S n ~ n1)
+     forall n. (1 + n ~ n1)
   => i -> IxR n i -> IxR n1 i
 pattern i :.: sh <- (unconsIxR -> Just (UnconsIxRRes sh i))
   where i :.: IxR sh = IxR (i ::: sh)
@@ -916,30 +985,30 @@ pattern i :.: sh <- (unconsIxR -> Just (UnconsIxRRes sh i))
 infixr 3 :.:
 
 data UnconsIxRRes i n1 =
-  forall n. ((S n) ~ n1) => UnconsIxRRes (IxR n i) i
+  forall n. (1 + n ~ n1) => UnconsIxRRes (IxR n i) i
 unconsIxR :: IxR n1 i -> Maybe (UnconsIxRRes i n1)
 unconsIxR (IxR (i ::: sh')) = Just (UnconsIxRRes (IxR sh') i)
 unconsIxR (IxR ZR) = Nothing
 
 type IIxR n = IxR n Int
 
-knownIxR :: IxR n i -> Dict KnownINat n
+knownIxR :: IxR n i -> Dict KnownNat n
 knownIxR (IxR sh) = knownListR sh
 
 type role ShR nominal representational
-type ShR :: INat -> Type -> Type
+type ShR :: Nat -> Type -> Type
 newtype ShR n i = ShR (ListR n i)
   deriving (Show, Eq, Ord)
   deriving newtype (Functor, Foldable)
 
 type IShR n = ShR n Int
 
-pattern ZSR :: forall n i. () => n ~ Z => ShR n i
+pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
 pattern ZSR = ShR ZR
 
 pattern (:$:)
   :: forall {n1} {i}.
-     forall n. (S n ~ n1)
+     forall n. (1 + n ~ n1)
   => i -> ShR n i -> ShR n1 i
 pattern i :$: sh <- (unconsShR -> Just (UnconsShRRes sh i))
   where i :$: (ShR sh) = ShR (i ::: sh)
@@ -947,15 +1016,15 @@ pattern i :$: sh <- (unconsShR -> Just (UnconsShRRes sh i))
 infixr 3 :$:
 
 data UnconsShRRes i n1 =
-  forall n. S n ~ n1 => UnconsShRRes (ShR n i) i
+  forall n. 1 + n ~ n1 => UnconsShRRes (ShR n i) i
 unconsShR :: ShR n1 i -> Maybe (UnconsShRRes i n1)
 unconsShR (ShR (i ::: sh')) = Just (UnconsShRRes (ShR sh') i)
 unconsShR (ShR ZR) = Nothing
 
-knownShR :: ShR n i -> Dict KnownINat n
+knownShR :: ShR n i -> Dict KnownNat n
 knownShR (ShR sh) = knownListR sh
 
-zeroIxR :: SINat n -> IIxR n
+zeroIxR :: SNat n -> IIxR n
 zeroIxR SZ = ZIR
 zeroIxR (SS n) = 0 :.: zeroIxR n
 
@@ -982,7 +1051,7 @@ shapeSizeR ZSR = 1
 shapeSizeR (n :$: sh) = n * shapeSizeR sh
 
 
-rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IShR n
+rshape :: forall n a. (KnownNat n, Elt a) => Ranked n a -> IShR n
 rshape (Ranked arr)
   | Dict <- lemKnownReplicate (Proxy @n)
   , Refl <- lemRankReplicate (Proxy @n)
@@ -991,7 +1060,7 @@ rshape (Ranked arr)
 rindex :: Elt a => Ranked n a -> IIxR n -> a
 rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)
 
-rindexPartial :: forall n m a. (KnownINat n, Elt a) => Ranked (n +! m) a -> IIxR n -> Ranked m a
+rindexPartial :: forall n m a. (KnownNat n, Elt a) => Ranked (n + m) a -> IIxR n -> Ranked m a
 rindexPartial (Ranked arr) idx =
   Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
             (rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr)
@@ -1007,7 +1076,7 @@ rgenerate sh f
   = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR))
 
 -- | See the documentation of 'mlift'.
-rlift :: forall n1 n2 a. (KnownINat n2, Elt a)
+rlift :: forall n1 n2 a. (KnownNat n2, Elt a)
       => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
       -> Ranked n1 a -> Ranked n2 a
 rlift f (Ranked arr)
@@ -1015,39 +1084,39 @@ rlift f (Ranked arr)
   = Ranked (mlift f arr)
 
 rsumOuter1P :: forall n a.
-               (Storable a, Num a, KnownINat n)
-            => Ranked (S n) (Primitive a) -> Ranked n (Primitive a)
+               (Storable a, Num a, KnownNat n, 1 <= n)
+            => Ranked n (Primitive a) -> Ranked (n - 1) (Primitive a)
 rsumOuter1P (Ranked arr)
   | Dict <- lemKnownReplicate (Proxy @n)
   = Ranked
-    . coerce @(XArray (Replicate n 'Nothing) a) @(Mixed (Replicate n 'Nothing) (Primitive a))
-    . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate n Nothing))
-    . coerce @(Mixed (Replicate (S n) Nothing) (Primitive a)) @(XArray (Replicate (S n) Nothing) a)
+    . coerce @(XArray (Replicate (n - 1) 'Nothing) a) @(Mixed (Replicate (n - 1) 'Nothing) (Primitive a))
+    . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate (n - 1) Nothing))
+    . coerce @(Mixed (Replicate n Nothing) (Primitive a)) @(XArray (Replicate n Nothing) a)
     $ arr
 
 rsumOuter1 :: forall n a.
-              (Storable a, Num a, PrimElt a, KnownINat n)
-           => Ranked (S n) a -> Ranked n a
+  (Storable a, Num a, PrimElt a, KnownNat n, 1 <= n)
+           => Ranked n a -> Ranked (n - 1) a
 rsumOuter1 = coerce fromPrimitive . rsumOuter1P @n @a . coerce toPrimitive
 
-rtranspose :: forall n a. (KnownINat n, Elt a) => [Int] -> Ranked n a -> Ranked n a
+rtranspose :: forall n a. (KnownNat n, Elt a) => [Int] -> Ranked n a -> Ranked n a
 rtranspose perm (Ranked arr)
   | Dict <- lemKnownReplicate (Proxy @n)
   = Ranked (mtranspose perm arr)
 
-rappend :: forall n a. (KnownINat n, Elt a)
-        => Ranked (S n) a -> Ranked (S n) a -> Ranked (S n) a
+rappend :: forall n a. (KnownNat n, Elt a, 1 <= n)
+        => Ranked n a -> Ranked n a -> Ranked n a
 rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend
 
-rscalar :: Elt a => a -> Ranked I0 a
+rscalar :: Elt a => a -> Ranked 0 a
 rscalar x = Ranked (mscalar x)
 
-rfromVectorP :: forall n a. (KnownINat n, Storable a) => IShR n -> VS.Vector a -> Ranked n (Primitive a)
+rfromVectorP :: forall n a. (KnownNat n, Storable a) => IShR n -> VS.Vector a -> Ranked n (Primitive a)
 rfromVectorP sh v
   | Dict <- lemKnownReplicate (Proxy @n)
   = Ranked (mfromVectorP (shCvtRX sh) v)
 
-rfromVector :: forall n a. (KnownINat n, Storable a, PrimElt a) => IShR n -> VS.Vector a -> Ranked n a
+rfromVector :: forall n a. (KnownNat n, Storable a, PrimElt a) => IShR n -> VS.Vector a -> Ranked n a
 rfromVector sh v = coerce fromPrimitive (rfromVectorP sh v)
 
 rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a
@@ -1056,39 +1125,39 @@ rtoVectorP = coerce mtoVectorP
 rtoVector :: (Storable a, PrimElt a) => Ranked n a -> VS.Vector a
 rtoVector = coerce mtoVector
 
-rfromList1 :: forall n a. (KnownINat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (S n) a
+rfromList1 :: forall n a. (KnownNat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (1 + n) a
 rfromList1 l
   | Dict <- lemKnownReplicate (Proxy @n)
   = Ranked (mfromList1 (coerce l))
 
-rfromList :: Elt a => NonEmpty a -> Ranked I1 a
+rfromList :: Elt a => NonEmpty a -> Ranked 1 a
 rfromList = Ranked . mfromList1 . fmap mscalar
 
-rtoList :: Elt a => Ranked (S n) a -> [Ranked n a]
+rtoList :: Elt a => Ranked (1 + n) a -> [Ranked n a]
 rtoList (Ranked arr) = coerce (mtoList1 arr)
 
-rtoList1 :: Elt a => Ranked I1 a -> [a]
+rtoList1 :: Elt a => Ranked 1 a -> [a]
 rtoList1 = map runScalar . rtoList
 
-runScalar :: Elt a => Ranked I0 a -> a
+runScalar :: Elt a => Ranked 0 a -> a
 runScalar arr = rindex arr ZIR
 
-rconstantP :: forall n a. (KnownINat n, Storable a) => IShR n -> a -> Ranked n (Primitive a)
+rconstantP :: forall n a. (KnownNat n, Storable a) => IShR n -> a -> Ranked n (Primitive a)
 rconstantP sh x
   | Dict <- lemKnownReplicate (Proxy @n)
   = Ranked (mconstantP (shCvtRX sh) x)
 
-rconstant :: forall n a. (KnownINat n, Storable a, PrimElt a)
+rconstant :: forall n a. (KnownNat n, Storable a, PrimElt a)
           => IShR n -> a -> Ranked n a
 rconstant sh x = coerce fromPrimitive (rconstantP sh x)
 
-rslice :: (KnownINat n, Elt a) => [(Int, Int)] -> Ranked n a -> Ranked n a
+rslice :: (KnownNat n, Elt a) => [(Int, Int)] -> Ranked n a -> Ranked n a
 rslice ivs = rlift $ \_ -> X.slice ivs
 
-rrev1 :: (KnownINat n, Elt a) => Ranked (S n) a -> Ranked (S n) a
+rrev1 :: (KnownNat n, Elt a, 1 <= n) => Ranked n a -> Ranked n a
 rrev1 = rlift $ \_ -> X.rev1
 
-rreshape :: forall n n' a. (KnownINat n, KnownINat n', Elt a)
+rreshape :: forall n n' a. (KnownNat n, KnownNat n', Elt a)
          => IShR n' -> Ranked n a -> Ranked n' a
 rreshape sh' (Ranked arr)
   | Dict <- lemKnownReplicate (Proxy @n)
-- 
cgit v1.2.3-70-g09d2