diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/Data/Array/Nested.hs | 13 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 113 | 
2 files changed, 53 insertions, 73 deletions
| diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 4b455da..45a03d4 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -3,9 +3,9 @@  module Data.Array.Nested (    -- * Ranked arrays    Ranked, -  ListR(ZR, (:::)), knownListR, -  IxR(.., ZIR, (:.:)), IIxR, knownIxR, -  ShR(.., ZSR, (:$:)), knownShR, +  ListR(ZR, (:::)), +  IxR(.., ZIR, (:.:)), IIxR, +  ShR(.., ZSR, (:$:)),    rshape, rindex, rindexPartial, rgenerate, rsumOuter1,    rtranspose, rappend, rscalar, rfromVector, rtoVector, runScalar,    rconstant, rfromList, rfromList1, rtoList, rtoList1, @@ -19,7 +19,7 @@ module Data.Array.Nested (    Shaped,    ListS(ZS, (::$)),    IxS(.., ZIS, (:.$)), IIxS, -  ShS(..), KnownShape(..), +  ShS(.., ZSS, (:$$)), KnownShS(..),    sshape, sindex, sindexPartial, sgenerate, ssumOuter1,    stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar,    sconstant, sfromList, sfromList1, stoList, stoList1, @@ -32,7 +32,7 @@ module Data.Array.Nested (    -- * Mixed arrays    Mixed,    IxX(..), IIxX, -  KnownShapeX(..), StaticShX(..), +  KnownShX(..), StaticShX(..),    mgenerate, mtranspose, mappend, mfromVector, mtoVector, munScalar,    mconstant, mfromList, mtoList, mslice, mrev1, mreshape,    -- ** Conversions @@ -46,9 +46,9 @@ module Data.Array.Nested (    -- * Further utilities / re-exports    type (++),    Storable, +  SNat, pattern SNat,    HList,    Permutation, -  makeNatList,  ) where  import Prelude hiding (mappend) @@ -56,3 +56,4 @@ import Prelude hiding (mappend)  import Data.Array.Mixed  import Data.Array.Nested.Internal  import Foreign.Storable +import GHC.TypeLits diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 7d98975..a3c2b6d 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -1390,28 +1390,19 @@ sindexPartial sarr@(Shaped arr) idx =  sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a  sgenerate sh f = Shaped (mgenerate (shCvtSX sh) (f . ixCvtXS sh)) -{-  -- | See the documentation of 'mlift'. -slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a) -      => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) +slift :: forall sh1 sh2 a. Elt a +      => ShS sh2 +      -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)        -> Shaped sh1 a -> Shaped sh2 a -slift f (Shaped arr) -  | Dict <- lemKnownMapJust (Proxy @sh2) -  = Shaped (mlift f arr) +slift sh2 f (Shaped arr) = Shaped (mlift (X.staticShapeFrom (shCvtSX sh2)) f arr) -ssumOuter1P :: forall sh n a. -               (Storable a, Num a, KnownNat n, KnownShape sh) +ssumOuter1P :: forall sh n a. (Storable a, Num a)              => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) -ssumOuter1P (Shaped arr) -  | Dict <- lemKnownMapJust (Proxy @sh) -  = Shaped -    . coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) (Primitive a)) -    . X.sumOuter (natSing @n :!$@ ZKSX) (knownShapeX @(MapJust sh)) -    . coerce @(Mixed (Just n : MapJust sh) (Primitive a)) @(XArray (Just n : MapJust sh) a) -    $ arr +ssumOuter1P (Shaped (M_Primitive (SKnown sn :$% sh) arr)) = +  Shaped (M_Primitive sh (X.sumOuter (SKnown sn :!% ZKX) (X.staticShapeFrom sh) arr)) -ssumOuter1 :: forall sh n a. -              (Storable a, Num a, PrimElt a, KnownNat n, KnownShape sh) +ssumOuter1 :: forall sh n a. (Storable a, Num a, PrimElt a)             => Shaped (n : sh) a -> Shaped sh a  ssumOuter1 = coerce fromPrimitive . ssumOuter1P @sh @n @a . coerce toPrimitive @@ -1457,30 +1448,27 @@ shIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) rest    = shIndex p pT i sh rest  shIndex _ _ _ ZSS _ = error "Index into empty shape" -stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShape sh, Elt a) => HList SNat is -> Shaped sh a -> Shaped (X.PermutePrefix is sh) a -stranspose perm (Shaped arr) -  | Dict <- lemKnownMapJust (Proxy @sh) -  , Refl <- lemRankMapJust (Proxy @sh) -  , Refl <- lemCommMapJustTakeLen perm (knownShape @sh) -  , Refl <- lemCommMapJustDropLen perm (knownShape @sh) -  , Refl <- lemCommMapJustPermute perm (shTakeLen perm (knownShape @sh)) -  , Refl <- lemCommMapJustApp (shPermute perm (shTakeLen perm (knownShape @sh))) (Proxy @(X.DropLen is sh)) +stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, Elt a) +           => HList SNat is -> Shaped sh a -> Shaped (X.PermutePrefix is sh) a +stranspose perm sarr@(Shaped arr) +  | Refl <- lemRankMapJust (sshape sarr) +  , Refl <- lemCommMapJustTakeLen perm (sshape sarr) +  , Refl <- lemCommMapJustDropLen perm (sshape sarr) +  , Refl <- lemCommMapJustPermute perm (shTakeLen perm (sshape sarr)) +  , Refl <- lemCommMapJustApp (shPermute perm (shTakeLen perm (sshape sarr))) (Proxy @(X.DropLen is sh))    = Shaped (mtranspose perm arr) -sappend :: forall n m sh a. (KnownNat n, KnownNat m, KnownShape sh, Elt a) -        => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a -sappend | Dict <- lemKnownMapJust (Proxy @sh) = coerce mappend +sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a +sappend = coerce mappend  sscalar :: Elt a => a -> Shaped '[] a  sscalar x = Shaped (mscalar x) -sfromVectorP :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped sh (Primitive a) -sfromVectorP v -  | Dict <- lemKnownMapJust (Proxy @sh) -  = Shaped (mfromVectorP (shCvtSX (knownShape @sh)) v) +sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a) +sfromVectorP sh v = Shaped (mfromVectorP (shCvtSX sh) v) -sfromVector :: forall sh a. (KnownShape sh, Storable a, PrimElt a) => VS.Vector a -> Shaped sh a -sfromVector v = coerce fromPrimitive (sfromVectorP @sh @a v) +sfromVector :: (Storable a, PrimElt a) => ShS sh -> VS.Vector a -> Shaped sh a +sfromVector sh v = coerce fromPrimitive (sfromVectorP sh v)  stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a  stoVectorP = coerce mtoVectorP @@ -1488,14 +1476,11 @@ stoVectorP = coerce mtoVectorP  stoVector :: (Storable a, PrimElt a) => Shaped sh a -> VS.Vector a  stoVector = coerce mtoVector -sfromList1 :: forall n sh a. (KnownNat n, KnownShape sh, Elt a) -           => NonEmpty (Shaped sh a) -> Shaped (n : sh) a -sfromList1 l -  | Dict <- lemKnownMapJust (Proxy @sh) -  = Shaped (mfromList1 (coerce l)) +sfromList1 :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a +sfromList1 sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy $ mfromList1 (coerce l)) -sfromList :: (KnownNat n, Elt a) => NonEmpty a -> Shaped '[n] a -sfromList = Shaped . mfromList1 . fmap mscalar +sfromList :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a +sfromList sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList  stoList :: Elt a => Shaped (n : sh) a -> [Shaped sh a]  stoList (Shaped arr) = coerce (mtoList1 arr) @@ -1506,37 +1491,31 @@ stoList1 = map sunScalar . stoList  sunScalar :: Elt a => Shaped '[] a -> a  sunScalar arr = sindex arr ZIS -sconstantP :: forall sh a. (KnownShape sh, Storable a) => a -> Shaped sh (Primitive a) -sconstantP x -  | Dict <- lemKnownMapJust (Proxy @sh) -  = Shaped (mconstantP (shCvtSX (knownShape @sh)) x) +sconstantP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) +sconstantP sh x = Shaped (mconstantP (shCvtSX sh) x) -sconstant :: forall sh a. (KnownShape sh, Storable a, PrimElt a) -          => a -> Shaped sh a -sconstant x = coerce fromPrimitive (sconstantP @sh x) +sconstant :: (Storable a, PrimElt a) => ShS sh -> a -> Shaped sh a +sconstant sh x = coerce fromPrimitive (sconstantP sh x) -sslice :: (KnownShape sh, Elt a) => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a -sslice i n@SNat = slift $ \_ -> X.slice i n +sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a +sslice i n@SNat arr = +  let _ :$$ sh = sshape arr +  in slift (n :$$ sh) (\_ -> X.slice i n) arr -srev1 :: (KnownNat n, KnownShape sh, Elt a) => Shaped (n : sh) a -> Shaped (n : sh) a -srev1 = slift $ \_ -> X.rev1 +srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a +srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr -sreshape :: forall sh sh' a. (KnownShape sh, KnownShape sh', Elt a) -         => ShS sh' -> Shaped sh a -> Shaped sh' a -sreshape sh' (Shaped arr) -  | Dict <- lemKnownMapJust (Proxy @sh) -  , Dict <- lemKnownMapJust (Proxy @sh') -  = Shaped (mreshape (shCvtSX sh') arr) +sreshape :: Elt a => ShS sh' -> Shaped sh a -> Shaped sh' a +sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr) -sasXArrayPrimP :: Shaped sh (Primitive a) -> XArray (MapJust sh) a -sasXArrayPrimP (Shaped arr) = masXArrayPrimP arr +sasXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a) +sasXArrayPrimP (Shaped arr) = first shCvtXS' (masXArrayPrimP arr) -sasXArrayPrim :: PrimElt a => Shaped sh a -> XArray (MapJust sh) a -sasXArrayPrim (Shaped arr) = masXArrayPrim arr +sasXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a) +sasXArrayPrim (Shaped arr) = first shCvtXS' (masXArrayPrim arr) -sfromXArrayPrimP :: XArray (MapJust sh) a -> Shaped sh (Primitive a) -sfromXArrayPrimP = Shaped . mfromXArrayPrimP +sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a) +sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (X.staticShapeFrom (shCvtSX sh)) arr) -sfromXArrayPrim :: PrimElt a => XArray (MapJust sh) a -> Shaped sh a -sfromXArrayPrim = Shaped . mfromXArrayPrim --} +sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a +sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (X.staticShapeFrom (shCvtSX sh)) arr) | 
