From 6b5139c0a8d0c4e76c349f2847cc5629137f4536 Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Sat, 18 May 2024 14:05:09 +0200
Subject: Finish singletons refactor?

---
 src/Data/Array/Nested.hs          |  13 +++--
 src/Data/Array/Nested/Internal.hs | 115 ++++++++++++++++----------------------
 2 files changed, 54 insertions(+), 74 deletions(-)

(limited to 'src/Data')

diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 4b455da..45a03d4 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -3,9 +3,9 @@
 module Data.Array.Nested (
   -- * Ranked arrays
   Ranked,
-  ListR(ZR, (:::)), knownListR,
-  IxR(.., ZIR, (:.:)), IIxR, knownIxR,
-  ShR(.., ZSR, (:$:)), knownShR,
+  ListR(ZR, (:::)),
+  IxR(.., ZIR, (:.:)), IIxR,
+  ShR(.., ZSR, (:$:)),
   rshape, rindex, rindexPartial, rgenerate, rsumOuter1,
   rtranspose, rappend, rscalar, rfromVector, rtoVector, runScalar,
   rconstant, rfromList, rfromList1, rtoList, rtoList1,
@@ -19,7 +19,7 @@ module Data.Array.Nested (
   Shaped,
   ListS(ZS, (::$)),
   IxS(.., ZIS, (:.$)), IIxS,
-  ShS(..), KnownShape(..),
+  ShS(.., ZSS, (:$$)), KnownShS(..),
   sshape, sindex, sindexPartial, sgenerate, ssumOuter1,
   stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar,
   sconstant, sfromList, sfromList1, stoList, stoList1,
@@ -32,7 +32,7 @@ module Data.Array.Nested (
   -- * Mixed arrays
   Mixed,
   IxX(..), IIxX,
-  KnownShapeX(..), StaticShX(..),
+  KnownShX(..), StaticShX(..),
   mgenerate, mtranspose, mappend, mfromVector, mtoVector, munScalar,
   mconstant, mfromList, mtoList, mslice, mrev1, mreshape,
   -- ** Conversions
@@ -46,9 +46,9 @@ module Data.Array.Nested (
   -- * Further utilities / re-exports
   type (++),
   Storable,
+  SNat, pattern SNat,
   HList,
   Permutation,
-  makeNatList,
 ) where
 
 import Prelude hiding (mappend)
@@ -56,3 +56,4 @@ import Prelude hiding (mappend)
 import Data.Array.Mixed
 import Data.Array.Nested.Internal
 import Foreign.Storable
+import GHC.TypeLits
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 7d98975..a3c2b6d 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -1390,28 +1390,19 @@ sindexPartial sarr@(Shaped arr) idx =
 sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a
 sgenerate sh f = Shaped (mgenerate (shCvtSX sh) (f . ixCvtXS sh))
 
-{-
 -- | See the documentation of 'mlift'.
-slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a)
-      => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
+slift :: forall sh1 sh2 a. Elt a
+      => ShS sh2
+      -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
       -> Shaped sh1 a -> Shaped sh2 a
-slift f (Shaped arr)
-  | Dict <- lemKnownMapJust (Proxy @sh2)
-  = Shaped (mlift f arr)
+slift sh2 f (Shaped arr) = Shaped (mlift (X.staticShapeFrom (shCvtSX sh2)) f arr)
 
-ssumOuter1P :: forall sh n a.
-               (Storable a, Num a, KnownNat n, KnownShape sh)
+ssumOuter1P :: forall sh n a. (Storable a, Num a)
             => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
-ssumOuter1P (Shaped arr)
-  | Dict <- lemKnownMapJust (Proxy @sh)
-  = Shaped
-    . coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) (Primitive a))
-    . X.sumOuter (natSing @n :!$@ ZKSX) (knownShapeX @(MapJust sh))
-    . coerce @(Mixed (Just n : MapJust sh) (Primitive a)) @(XArray (Just n : MapJust sh) a)
-    $ arr
-
-ssumOuter1 :: forall sh n a.
-              (Storable a, Num a, PrimElt a, KnownNat n, KnownShape sh)
+ssumOuter1P (Shaped (M_Primitive (SKnown sn :$% sh) arr)) =
+  Shaped (M_Primitive sh (X.sumOuter (SKnown sn :!% ZKX) (X.staticShapeFrom sh) arr))
+
+ssumOuter1 :: forall sh n a. (Storable a, Num a, PrimElt a)
            => Shaped (n : sh) a -> Shaped sh a
 ssumOuter1 = coerce fromPrimitive . ssumOuter1P @sh @n @a . coerce toPrimitive
 
@@ -1457,30 +1448,27 @@ shIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) rest
   = shIndex p pT i sh rest
 shIndex _ _ _ ZSS _ = error "Index into empty shape"
 
-stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShape sh, Elt a) => HList SNat is -> Shaped sh a -> Shaped (X.PermutePrefix is sh) a
-stranspose perm (Shaped arr)
-  | Dict <- lemKnownMapJust (Proxy @sh)
-  , Refl <- lemRankMapJust (Proxy @sh)
-  , Refl <- lemCommMapJustTakeLen perm (knownShape @sh)
-  , Refl <- lemCommMapJustDropLen perm (knownShape @sh)
-  , Refl <- lemCommMapJustPermute perm (shTakeLen perm (knownShape @sh))
-  , Refl <- lemCommMapJustApp (shPermute perm (shTakeLen perm (knownShape @sh))) (Proxy @(X.DropLen is sh))
+stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, Elt a)
+           => HList SNat is -> Shaped sh a -> Shaped (X.PermutePrefix is sh) a
+stranspose perm sarr@(Shaped arr)
+  | Refl <- lemRankMapJust (sshape sarr)
+  , Refl <- lemCommMapJustTakeLen perm (sshape sarr)
+  , Refl <- lemCommMapJustDropLen perm (sshape sarr)
+  , Refl <- lemCommMapJustPermute perm (shTakeLen perm (sshape sarr))
+  , Refl <- lemCommMapJustApp (shPermute perm (shTakeLen perm (sshape sarr))) (Proxy @(X.DropLen is sh))
   = Shaped (mtranspose perm arr)
 
-sappend :: forall n m sh a. (KnownNat n, KnownNat m, KnownShape sh, Elt a)
-        => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a
-sappend | Dict <- lemKnownMapJust (Proxy @sh) = coerce mappend
+sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a
+sappend = coerce mappend
 
 sscalar :: Elt a => a -> Shaped '[] a
 sscalar x = Shaped (mscalar x)
 
-sfromVectorP :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped sh (Primitive a)
-sfromVectorP v
-  | Dict <- lemKnownMapJust (Proxy @sh)
-  = Shaped (mfromVectorP (shCvtSX (knownShape @sh)) v)
+sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a)
+sfromVectorP sh v = Shaped (mfromVectorP (shCvtSX sh) v)
 
-sfromVector :: forall sh a. (KnownShape sh, Storable a, PrimElt a) => VS.Vector a -> Shaped sh a
-sfromVector v = coerce fromPrimitive (sfromVectorP @sh @a v)
+sfromVector :: (Storable a, PrimElt a) => ShS sh -> VS.Vector a -> Shaped sh a
+sfromVector sh v = coerce fromPrimitive (sfromVectorP sh v)
 
 stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a
 stoVectorP = coerce mtoVectorP
@@ -1488,14 +1476,11 @@ stoVectorP = coerce mtoVectorP
 stoVector :: (Storable a, PrimElt a) => Shaped sh a -> VS.Vector a
 stoVector = coerce mtoVector
 
-sfromList1 :: forall n sh a. (KnownNat n, KnownShape sh, Elt a)
-           => NonEmpty (Shaped sh a) -> Shaped (n : sh) a
-sfromList1 l
-  | Dict <- lemKnownMapJust (Proxy @sh)
-  = Shaped (mfromList1 (coerce l))
+sfromList1 :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
+sfromList1 sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy $ mfromList1 (coerce l))
 
-sfromList :: (KnownNat n, Elt a) => NonEmpty a -> Shaped '[n] a
-sfromList = Shaped . mfromList1 . fmap mscalar
+sfromList :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a
+sfromList sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList
 
 stoList :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
 stoList (Shaped arr) = coerce (mtoList1 arr)
@@ -1506,37 +1491,31 @@ stoList1 = map sunScalar . stoList
 sunScalar :: Elt a => Shaped '[] a -> a
 sunScalar arr = sindex arr ZIS
 
-sconstantP :: forall sh a. (KnownShape sh, Storable a) => a -> Shaped sh (Primitive a)
-sconstantP x
-  | Dict <- lemKnownMapJust (Proxy @sh)
-  = Shaped (mconstantP (shCvtSX (knownShape @sh)) x)
+sconstantP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a)
+sconstantP sh x = Shaped (mconstantP (shCvtSX sh) x)
 
-sconstant :: forall sh a. (KnownShape sh, Storable a, PrimElt a)
-          => a -> Shaped sh a
-sconstant x = coerce fromPrimitive (sconstantP @sh x)
+sconstant :: (Storable a, PrimElt a) => ShS sh -> a -> Shaped sh a
+sconstant sh x = coerce fromPrimitive (sconstantP sh x)
 
-sslice :: (KnownShape sh, Elt a) => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a
-sslice i n@SNat = slift $ \_ -> X.slice i n
+sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a
+sslice i n@SNat arr =
+  let _ :$$ sh = sshape arr
+  in slift (n :$$ sh) (\_ -> X.slice i n) arr
 
-srev1 :: (KnownNat n, KnownShape sh, Elt a) => Shaped (n : sh) a -> Shaped (n : sh) a
-srev1 = slift $ \_ -> X.rev1
+srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a
+srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr
 
-sreshape :: forall sh sh' a. (KnownShape sh, KnownShape sh', Elt a)
-         => ShS sh' -> Shaped sh a -> Shaped sh' a
-sreshape sh' (Shaped arr)
-  | Dict <- lemKnownMapJust (Proxy @sh)
-  , Dict <- lemKnownMapJust (Proxy @sh')
-  = Shaped (mreshape (shCvtSX sh') arr)
+sreshape :: Elt a => ShS sh' -> Shaped sh a -> Shaped sh' a
+sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr)
 
-sasXArrayPrimP :: Shaped sh (Primitive a) -> XArray (MapJust sh) a
-sasXArrayPrimP (Shaped arr) = masXArrayPrimP arr
+sasXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a)
+sasXArrayPrimP (Shaped arr) = first shCvtXS' (masXArrayPrimP arr)
 
-sasXArrayPrim :: PrimElt a => Shaped sh a -> XArray (MapJust sh) a
-sasXArrayPrim (Shaped arr) = masXArrayPrim arr
+sasXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a)
+sasXArrayPrim (Shaped arr) = first shCvtXS' (masXArrayPrim arr)
 
-sfromXArrayPrimP :: XArray (MapJust sh) a -> Shaped sh (Primitive a)
-sfromXArrayPrimP = Shaped . mfromXArrayPrimP
+sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a)
+sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (X.staticShapeFrom (shCvtSX sh)) arr)
 
-sfromXArrayPrim :: PrimElt a => XArray (MapJust sh) a -> Shaped sh a
-sfromXArrayPrim = Shaped . mfromXArrayPrim
--}
+sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a
+sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (X.staticShapeFrom (shCvtSX sh)) arr)
-- 
cgit v1.2.3-70-g09d2