diff options
Diffstat (limited to 'src/Data/Array/Nested/Shaped')
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Base.hs | 27 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 168 |
2 files changed, 136 insertions, 59 deletions
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs index ddd44bf..98f1241 100644 --- a/src/Data/Array/Nested/Shaped/Base.hs +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -90,8 +90,8 @@ instance Elt a => Elt (Shaped sh a) where mscalar (Shaped x) = M_Shaped (M_Nest ZSX x) - mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a) - mfromListOuter l = M_Shaped (mfromListOuter (coerce l)) + mfromListOuterSN :: SNat n -> NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Just n : sh') (Shaped sh a) + mfromListOuterSN sn l = M_Shaped (mfromListOuterSN sn (coerce l)) mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] mtoListOuter (M_Shaped arr) @@ -136,7 +136,7 @@ instance Elt a => Elt (Shaped sh a) where mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + mshapeTreeIsEmpty _ (sh, t) = shsSize sh == 0 || mshapeTreeIsEmpty (Proxy @a) t mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" @@ -172,10 +172,10 @@ instance Elt a => Elt (Shaped sh a) where instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) - memptyArrayUnsafe i + memptyArrayUnsafe sh | Dict <- lemKnownMapJust (Proxy @sh) = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ - memptyArrayUnsafe i + memptyArrayUnsafe sh mvecsUnsafeNew idx (Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) @@ -203,15 +203,15 @@ instance (NumElt a, PrimElt a) => Num (Shaped sh a) where negate = liftShaped1 negate abs = liftShaped1 abs signum = liftShaped1 signum - fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicateScal" + fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicatePrim" instance (FloatElt a, PrimElt a) => Fractional (Shaped sh a) where - fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal" + fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicatePrim" recip = liftShaped1 recip (/) = liftShaped2 (/) instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where - pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal" + pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicatePrim" exp = liftShaped1 exp log = liftShaped1 log sqrt = liftShaped1 sqrt @@ -246,15 +246,10 @@ sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh sshape (Shaped arr) = shsFromShX (mshape arr) -- Needed already here, but re-exported in Data.Array.Nested.Convert. -shsFromShX :: forall sh. IShX (MapJust sh) -> ShS sh +shsFromShX :: forall sh i. ShX (MapJust sh) i -> ShS sh shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS -shsFromShX (SKnown n@SNat :$% (idx :: IShX mjshT)) = - castWith (subst1 (lem Refl)) $ +shsFromShX (SKnown n@SNat :$% (idx :: ShX mjshT i)) = + castWith (subst1 (sym (lemMapJustCons Refl))) $ n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) idx) - where - lem :: forall sh1 sh' n. - Just n : sh1 :~: MapJust sh' - -> n : Tail sh' :~: sh' - lem Refl = unsafeCoerceRefl shsFromShX (SUnknown _ :$% _) = error "impossible" diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index fbfc7f5..0d90e91 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -1,13 +1,12 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveFoldable #-} -{-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE MagicHash #-} {-# LANGUAGE NoStarIsType #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} @@ -18,9 +17,11 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE StrictData #-} +{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UnboxedTuples #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} @@ -37,17 +38,22 @@ import Data.Kind (Constraint, Type) import Data.Monoid (Sum(..)) import Data.Proxy import Data.Type.Equality -import GHC.Exts (withDict) +import GHC.Exts (Int(..), Int#, quotRemInt#, withDict, build) import GHC.Generics (Generic) import GHC.IsList (IsList) import GHC.IsList qualified as IsList import GHC.TypeLits import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Mixed.Shape.Internal import Data.Array.Nested.Permutation import Data.Array.Nested.Types +-- * Shaped lists + +-- | Note: The 'KnownNat' constraint on '(::$)' is deprecated and should be +-- removed in a future release. type role ListS nominal representational type ListS :: [Nat] -> (Nat -> Type) -> Type data ListS sh f where @@ -98,13 +104,15 @@ listsEqual (n ::$ sh) (m ::$ sh') = Just Refl listsEqual _ _ = Nothing +{-# INLINE listsFmap #-} listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g listsFmap _ ZS = ZS listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs -listsFold :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m -listsFold _ ZS = mempty -listsFold f (x ::$ xs) = f x <> listsFold f xs +{-# INLINE listsFoldMap #-} +listsFoldMap :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m +listsFoldMap _ ZS = mempty +listsFoldMap f (x ::$ xs) = f x <> listsFoldMap f xs listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS listsShow f l = showString "[" . go "" l . showString "]" @@ -114,15 +122,29 @@ listsShow f l = showString "[" . go "" l . showString "]" go prefix (x ::$ xs) = showString prefix . f x . go "," xs listsLength :: ListS sh f -> Int -listsLength = getSum . listsFold (\_ -> Sum 1) +listsLength = getSum . listsFoldMap (\_ -> Sum 1) listsRank :: ListS sh f -> SNat (Rank sh) listsRank ZS = SNat listsRank (_ ::$ sh) = snatSucc (listsRank sh) +listsFromList :: ShS sh -> [i] -> ListS sh (Const i) +listsFromList topsh topl = go topsh topl + where + go :: ShS sh' -> [i] -> ListS sh' (Const i) + go ZSS [] = ZS + go (_ :$$ sh) (i : is) = Const i ::$ go sh is + go _ _ = error $ "listsFromList: Mismatched list length (type says " + ++ show (shsLength topsh) ++ ", list has length " + ++ show (length topl) ++ ")" + +{-# INLINEABLE listsToList #-} listsToList :: ListS sh (Const i) -> [i] -listsToList ZS = [] -listsToList (Const i ::$ is) = i : listsToList is +listsToList list = build (\(cons :: i -> is -> is) (nil :: is) -> + let go :: ListS sh (Const i) -> is + go ZS = nil + go (Const i ::$ is) = i `cons` go is + in go list) listsHead :: ListS (n : sh) f -> f n listsHead (i ::$ _) = i @@ -144,14 +166,13 @@ listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx' listsZip :: ListS sh f -> ListS sh g -> ListS sh (Fun.Product f g) listsZip ZS ZS = ZS -listsZip (i ::$ is) (j ::$ js) = - Fun.Pair i j ::$ listsZip is js +listsZip (i ::$ is) (j ::$ js) = Fun.Pair i j ::$ listsZip is js +{-# INLINE listsZipWith #-} listsZipWith :: (forall a. f a -> g a -> h a) -> ListS sh f -> ListS sh g -> ListS sh h listsZipWith _ ZS ZS = ZS -listsZipWith f (i ::$ is) (j ::$ js) = - f i j ::$ listsZipWith f is js +listsZipWith f (i ::$ is) (j ::$ js) = f i j ::$ listsZipWith f is js listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f listsTakeLenPerm PNil _ = ZS @@ -180,11 +201,9 @@ listsIndex _ _ _ ZS = error "Index into empty shape" listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh) +-- * Shaped indices -- | An index into a shape-typed array. --- --- For convenience, this contains regular 'Int's instead of bounded integers --- (traditionally called \"@Fin@\"). type role IxS nominal representational type IxS :: [Nat] -> Type -> Type newtype IxS sh i = IxS (ListS sh (Const i)) @@ -193,6 +212,8 @@ newtype IxS sh i = IxS (ListS sh (Const i)) pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i pattern ZIS = IxS ZS +-- | Note: The 'KnownNat' constraint on '(:.$)' is deprecated and should be +-- removed in a future release. pattern (:.$) :: forall {sh1} {i}. forall n sh. (KnownNat n, n : sh ~ sh1) @@ -203,6 +224,8 @@ infixr 3 :.$ {-# COMPLETE ZIS, (:.$) #-} +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). type IIxS sh = IxS sh Int #ifdef OXAR_DEFAULT_SHOW_INSTANCES @@ -213,10 +236,18 @@ instance Show i => Show (IxS sh i) where #endif instance Functor (IxS sh) where + {-# INLINE fmap #-} fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l) instance Foldable (IxS sh) where - foldMap f (IxS l) = listsFold (f . getConst) l + {-# INLINE foldMap #-} + foldMap f (IxS l) = listsFoldMap (f . getConst) l + {-# INLINE foldr #-} + foldr _ z ZIS = z + foldr f z (x :.$ xs) = f x (foldr f z xs) + toList = ixsToList + null ZIS = False + null _ = True instance NFData i => NFData (IxS sh i) @@ -226,6 +257,13 @@ ixsLength (IxS l) = listsLength l ixsRank :: IxS sh i -> SNat (Rank sh) ixsRank (IxS l) = listsRank l +ixsFromList :: forall sh i. ShS sh -> [i] -> IxS sh i +ixsFromList = coerce (listsFromList @_ @i) + +{-# INLINEABLE ixsToList #-} +ixsToList :: forall sh i. IxS sh i -> [i] +ixsToList = coerce (listsToList @_ @i) + ixsZero :: ShS sh -> IIxS sh ixsZero ZSS = ZIS ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh @@ -242,14 +280,21 @@ ixsInit (IxS list) = IxS (listsInit list) ixsLast :: IxS (n : sh) i -> i ixsLast (IxS list) = getConst (listsLast list) +-- TODO: this takes a ShS because there are KnownNats inside IxS. +ixsCast :: ShS sh' -> IxS sh i -> IxS sh' i +ixsCast ZSS ZIS = ZIS +ixsCast (_ :$$ sh) (i :.$ idx) = i :.$ ixsCast sh idx +ixsCast _ _ = error "ixsCast: ranks don't match" + ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i ixsAppend = coerce (listsAppend @_ @(Const i)) -ixsZip :: IxS n i -> IxS n j -> IxS n (i, j) +ixsZip :: IxS sh i -> IxS sh j -> IxS sh (i, j) ixsZip ZIS ZIS = ZIS ixsZip (i :.$ is) (j :.$ js) = (i, j) :.$ ixsZip is js -ixsZipWith :: (i -> j -> k) -> IxS n i -> IxS n j -> IxS n k +{-# INLINE ixsZipWith #-} +ixsZipWith :: (i -> j -> k) -> IxS sh i -> IxS sh j -> IxS sh k ixsZipWith _ ZIS ZIS = ZIS ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js @@ -257,6 +302,8 @@ ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is ixsPermutePrefix = coerce (listsPermutePrefix @(Const i)) +-- * Shaped shapes + -- | The shape of a shape-typed array given as a list of 'SNat' values. -- -- Note that because the shape of a shape-typed array is known statically, you @@ -264,7 +311,10 @@ ixsPermutePrefix = coerce (listsPermutePrefix @(Const i)) type role ShS nominal type ShS :: [Nat] -> Type newtype ShS sh = ShS (ListS sh SNat) - deriving (Eq, Ord, Generic) + deriving (Generic) + +instance Eq (ShS sh) where _ == _ = True +instance Ord (ShS sh) where compare _ _ = EQ pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh pattern ZSS = ShS ZS @@ -309,9 +359,28 @@ shsSize :: ShS sh -> Int shsSize ZSS = 1 shsSize (n :$$ sh) = fromSNat' n * shsSize sh +-- | This is a partial @const@ that fails when the second argument +-- doesn't match the first. +shsFromList :: ShS sh -> [Int] -> ShS sh +shsFromList topsh topl = go topsh topl `seq` topsh + where + go :: ShS sh' -> [Int] -> () + go ZSS [] = () + go (sn :$$ sh) (i : is) + | i == fromSNat' sn = go sh is + | otherwise = error $ "shsFromList: Value does not match typing (type says " + ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" + go _ _ = error $ "shsFromList: Mismatched list length (type says " + ++ show (shsLength topsh) ++ ", list has length " + ++ show (length topl) ++ ")" + +{-# INLINEABLE shsToList #-} shsToList :: ShS sh -> [Int] -shsToList ZSS = [] -shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh +shsToList topsh = build (\(cons :: Int -> is -> is) (nil :: is) -> + let go :: ShS sh -> is + go ZSS = nil + go (sn :$$ sh) = fromSNat' sn `cons` go sh + in go topsh) shsHead :: ShS (n : sh) -> SNat n shsHead (ShS list) = listsHead list @@ -356,7 +425,7 @@ instance KnownShS '[] where knownShS = ZSS instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS withKnownShS :: forall sh r. ShS sh -> (KnownShS sh => r) -> r -withKnownShS k = withDict @(KnownShS sh) k +withKnownShS = withDict @(KnownShS sh) shsKnownShS :: ShS sh -> Dict KnownShS sh shsKnownShS ZSS = Dict @@ -366,18 +435,38 @@ shsOrthotopeShape :: ShS sh -> Dict O.Shape sh shsOrthotopeShape ZSS = Dict shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict +-- | This function is a hack made possible by the 'KnownNat' inside 'ListS'. +-- This function may be removed in a future release. +shsFromListS :: ListS sh f -> ShS sh +shsFromListS ZS = ZSS +shsFromListS (_ ::$ l) = SNat :$$ shsFromListS l + +-- | This function is a hack made possible by the 'KnownNat' inside 'IxS'. This +-- function may be removed in a future release. +shsFromIxS :: IxS sh i -> ShS sh +shsFromIxS (IxS l) = shsFromListS l + +shsEnum :: ShS sh -> [IIxS sh] +shsEnum = shsEnum' + +{-# INLINABLE shsEnum' #-} -- ensure this can be specialised at use site +shsEnum' :: Num i => ShS sh -> [IxS sh i] +shsEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shsSize sh - 1]] + where + suffixes = drop 1 (scanr (*) 1 (shsToList sh)) + + fromLin :: Num i => ShS sh -> [Int] -> Int# -> IxS sh i + fromLin ZSS _ _ = ZIS + fromLin (_ :$$ sh') (I# suff# : suffs) i# = + let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shsSize sh' + in fromIntegral (I# q#) :.$ fromLin sh' suffs r# + fromLin _ _ _ = error "impossible" + -- | Untyped: length is checked at runtime. instance KnownShS sh => IsList (ListS sh (Const i)) where type Item (ListS sh (Const i)) = i - fromList topl = go (knownShS @sh) topl - where - go :: ShS sh' -> [i] -> ListS sh' (Const i) - go ZSS [] = ZS - go (_ :$$ sh) (i : is) = Const i ::$ go sh is - go _ _ = error $ "IsList(ListS): Mismatched list length (type says " - ++ show (shsLength (knownShS @sh)) ++ ", list has length " - ++ show (length topl) ++ ")" + fromList = listsFromList (knownShS @sh) toList = listsToList -- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. @@ -389,15 +478,8 @@ instance KnownShS sh => IsList (IxS sh i) where -- | Untyped: length and values are checked at runtime. instance KnownShS sh => IsList (ShS sh) where type Item (ShS sh) = Int - fromList topl = ShS (go (knownShS @sh) topl) - where - go :: ShS sh' -> [Int] -> ListS sh' SNat - go ZSS [] = ZS - go (sn :$$ sh) (i : is) - | i == fromSNat' sn = sn ::$ go sh is - | otherwise = error $ "IsList(ShS): Value does not match typing (type says " - ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" - go _ _ = error $ "IsList(ShS): Mismatched list length (type says " - ++ show (shsLength (knownShS @sh)) ++ ", list has length " - ++ show (length topl) ++ ")" + fromList = shsFromList (knownShS @sh) toList = shsToList + +$(ixFromLinearStub "ixsFromLinear" [t| ShS |] [t| IxS |] [p| ZSS |] (\a b -> [p| (fromSNat' -> $a) :$$ $b |]) [| ZIS |] [| (:.$) |] [| shsToList |]) +{-# INLINEABLE ixsFromLinear #-} |
