From 6b5139c0a8d0c4e76c349f2847cc5629137f4536 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sat, 18 May 2024 14:05:09 +0200 Subject: Finish singletons refactor? --- src/Data/Array/Nested.hs | 13 +++-- src/Data/Array/Nested/Internal.hs | 115 ++++++++++++++++---------------------- test/Main.hs | 2 +- 3 files changed, 55 insertions(+), 75 deletions(-) diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 4b455da..45a03d4 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -3,9 +3,9 @@ module Data.Array.Nested ( -- * Ranked arrays Ranked, - ListR(ZR, (:::)), knownListR, - IxR(.., ZIR, (:.:)), IIxR, knownIxR, - ShR(.., ZSR, (:$:)), knownShR, + ListR(ZR, (:::)), + IxR(.., ZIR, (:.:)), IIxR, + ShR(.., ZSR, (:$:)), rshape, rindex, rindexPartial, rgenerate, rsumOuter1, rtranspose, rappend, rscalar, rfromVector, rtoVector, runScalar, rconstant, rfromList, rfromList1, rtoList, rtoList1, @@ -19,7 +19,7 @@ module Data.Array.Nested ( Shaped, ListS(ZS, (::$)), IxS(.., ZIS, (:.$)), IIxS, - ShS(..), KnownShape(..), + ShS(.., ZSS, (:$$)), KnownShS(..), sshape, sindex, sindexPartial, sgenerate, ssumOuter1, stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar, sconstant, sfromList, sfromList1, stoList, stoList1, @@ -32,7 +32,7 @@ module Data.Array.Nested ( -- * Mixed arrays Mixed, IxX(..), IIxX, - KnownShapeX(..), StaticShX(..), + KnownShX(..), StaticShX(..), mgenerate, mtranspose, mappend, mfromVector, mtoVector, munScalar, mconstant, mfromList, mtoList, mslice, mrev1, mreshape, -- ** Conversions @@ -46,9 +46,9 @@ module Data.Array.Nested ( -- * Further utilities / re-exports type (++), Storable, + SNat, pattern SNat, HList, Permutation, - makeNatList, ) where import Prelude hiding (mappend) @@ -56,3 +56,4 @@ import Prelude hiding (mappend) import Data.Array.Mixed import Data.Array.Nested.Internal import Foreign.Storable +import GHC.TypeLits diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 7d98975..a3c2b6d 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -1390,28 +1390,19 @@ sindexPartial sarr@(Shaped arr) idx = sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a sgenerate sh f = Shaped (mgenerate (shCvtSX sh) (f . ixCvtXS sh)) -{- -- | See the documentation of 'mlift'. -slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a) - => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) +slift :: forall sh1 sh2 a. Elt a + => ShS sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) -> Shaped sh1 a -> Shaped sh2 a -slift f (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh2) - = Shaped (mlift f arr) +slift sh2 f (Shaped arr) = Shaped (mlift (X.staticShapeFrom (shCvtSX sh2)) f arr) -ssumOuter1P :: forall sh n a. - (Storable a, Num a, KnownNat n, KnownShape sh) +ssumOuter1P :: forall sh n a. (Storable a, Num a) => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) -ssumOuter1P (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped - . coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) (Primitive a)) - . X.sumOuter (natSing @n :!$@ ZKSX) (knownShapeX @(MapJust sh)) - . coerce @(Mixed (Just n : MapJust sh) (Primitive a)) @(XArray (Just n : MapJust sh) a) - $ arr - -ssumOuter1 :: forall sh n a. - (Storable a, Num a, PrimElt a, KnownNat n, KnownShape sh) +ssumOuter1P (Shaped (M_Primitive (SKnown sn :$% sh) arr)) = + Shaped (M_Primitive sh (X.sumOuter (SKnown sn :!% ZKX) (X.staticShapeFrom sh) arr)) + +ssumOuter1 :: forall sh n a. (Storable a, Num a, PrimElt a) => Shaped (n : sh) a -> Shaped sh a ssumOuter1 = coerce fromPrimitive . ssumOuter1P @sh @n @a . coerce toPrimitive @@ -1457,30 +1448,27 @@ shIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) rest = shIndex p pT i sh rest shIndex _ _ _ ZSS _ = error "Index into empty shape" -stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShape sh, Elt a) => HList SNat is -> Shaped sh a -> Shaped (X.PermutePrefix is sh) a -stranspose perm (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - , Refl <- lemRankMapJust (Proxy @sh) - , Refl <- lemCommMapJustTakeLen perm (knownShape @sh) - , Refl <- lemCommMapJustDropLen perm (knownShape @sh) - , Refl <- lemCommMapJustPermute perm (shTakeLen perm (knownShape @sh)) - , Refl <- lemCommMapJustApp (shPermute perm (shTakeLen perm (knownShape @sh))) (Proxy @(X.DropLen is sh)) +stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, Elt a) + => HList SNat is -> Shaped sh a -> Shaped (X.PermutePrefix is sh) a +stranspose perm sarr@(Shaped arr) + | Refl <- lemRankMapJust (sshape sarr) + , Refl <- lemCommMapJustTakeLen perm (sshape sarr) + , Refl <- lemCommMapJustDropLen perm (sshape sarr) + , Refl <- lemCommMapJustPermute perm (shTakeLen perm (sshape sarr)) + , Refl <- lemCommMapJustApp (shPermute perm (shTakeLen perm (sshape sarr))) (Proxy @(X.DropLen is sh)) = Shaped (mtranspose perm arr) -sappend :: forall n m sh a. (KnownNat n, KnownNat m, KnownShape sh, Elt a) - => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a -sappend | Dict <- lemKnownMapJust (Proxy @sh) = coerce mappend +sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a +sappend = coerce mappend sscalar :: Elt a => a -> Shaped '[] a sscalar x = Shaped (mscalar x) -sfromVectorP :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped sh (Primitive a) -sfromVectorP v - | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped (mfromVectorP (shCvtSX (knownShape @sh)) v) +sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a) +sfromVectorP sh v = Shaped (mfromVectorP (shCvtSX sh) v) -sfromVector :: forall sh a. (KnownShape sh, Storable a, PrimElt a) => VS.Vector a -> Shaped sh a -sfromVector v = coerce fromPrimitive (sfromVectorP @sh @a v) +sfromVector :: (Storable a, PrimElt a) => ShS sh -> VS.Vector a -> Shaped sh a +sfromVector sh v = coerce fromPrimitive (sfromVectorP sh v) stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a stoVectorP = coerce mtoVectorP @@ -1488,14 +1476,11 @@ stoVectorP = coerce mtoVectorP stoVector :: (Storable a, PrimElt a) => Shaped sh a -> VS.Vector a stoVector = coerce mtoVector -sfromList1 :: forall n sh a. (KnownNat n, KnownShape sh, Elt a) - => NonEmpty (Shaped sh a) -> Shaped (n : sh) a -sfromList1 l - | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped (mfromList1 (coerce l)) +sfromList1 :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a +sfromList1 sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy $ mfromList1 (coerce l)) -sfromList :: (KnownNat n, Elt a) => NonEmpty a -> Shaped '[n] a -sfromList = Shaped . mfromList1 . fmap mscalar +sfromList :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a +sfromList sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList stoList :: Elt a => Shaped (n : sh) a -> [Shaped sh a] stoList (Shaped arr) = coerce (mtoList1 arr) @@ -1506,37 +1491,31 @@ stoList1 = map sunScalar . stoList sunScalar :: Elt a => Shaped '[] a -> a sunScalar arr = sindex arr ZIS -sconstantP :: forall sh a. (KnownShape sh, Storable a) => a -> Shaped sh (Primitive a) -sconstantP x - | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped (mconstantP (shCvtSX (knownShape @sh)) x) +sconstantP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) +sconstantP sh x = Shaped (mconstantP (shCvtSX sh) x) -sconstant :: forall sh a. (KnownShape sh, Storable a, PrimElt a) - => a -> Shaped sh a -sconstant x = coerce fromPrimitive (sconstantP @sh x) +sconstant :: (Storable a, PrimElt a) => ShS sh -> a -> Shaped sh a +sconstant sh x = coerce fromPrimitive (sconstantP sh x) -sslice :: (KnownShape sh, Elt a) => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a -sslice i n@SNat = slift $ \_ -> X.slice i n +sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a +sslice i n@SNat arr = + let _ :$$ sh = sshape arr + in slift (n :$$ sh) (\_ -> X.slice i n) arr -srev1 :: (KnownNat n, KnownShape sh, Elt a) => Shaped (n : sh) a -> Shaped (n : sh) a -srev1 = slift $ \_ -> X.rev1 +srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a +srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr -sreshape :: forall sh sh' a. (KnownShape sh, KnownShape sh', Elt a) - => ShS sh' -> Shaped sh a -> Shaped sh' a -sreshape sh' (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - , Dict <- lemKnownMapJust (Proxy @sh') - = Shaped (mreshape (shCvtSX sh') arr) +sreshape :: Elt a => ShS sh' -> Shaped sh a -> Shaped sh' a +sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr) -sasXArrayPrimP :: Shaped sh (Primitive a) -> XArray (MapJust sh) a -sasXArrayPrimP (Shaped arr) = masXArrayPrimP arr +sasXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a) +sasXArrayPrimP (Shaped arr) = first shCvtXS' (masXArrayPrimP arr) -sasXArrayPrim :: PrimElt a => Shaped sh a -> XArray (MapJust sh) a -sasXArrayPrim (Shaped arr) = masXArrayPrim arr +sasXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a) +sasXArrayPrim (Shaped arr) = first shCvtXS' (masXArrayPrim arr) -sfromXArrayPrimP :: XArray (MapJust sh) a -> Shaped sh (Primitive a) -sfromXArrayPrimP = Shaped . mfromXArrayPrimP +sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a) +sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (X.staticShapeFrom (shCvtSX sh)) arr) -sfromXArrayPrim :: PrimElt a => XArray (MapJust sh) a -> Shaped sh a -sfromXArrayPrim = Shaped . mfromXArrayPrim --} +sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a +sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (X.staticShapeFrom (shCvtSX sh)) arr) diff --git a/test/Main.hs b/test/Main.hs index 783d985..76c75c2 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -8,7 +8,7 @@ import Data.Array.Nested arr :: Ranked 2 (Shaped [2, 3] (Double, Int)) arr = rgenerate (3 :$: 4 :$: ZSR) $ \(i :.: j :.: ZIR) -> - sgenerate @[2, 3] $ \(k :.$ l :.$ ZIS) -> + sgenerate (SNat @2 :$$ SNat @3 :$$ ZSS) $ \(k :.$ l :.$ ZIS) -> let s = 24*i + 6*j + 3*k + l in (fromIntegral s, s) -- cgit v1.2.3-70-g09d2