diff options
Diffstat (limited to 'src/Data/Array/Nested/Shaped/Shape.hs')
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 168 |
1 files changed, 125 insertions, 43 deletions
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index fbfc7f5..0d90e91 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -1,13 +1,12 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveFoldable #-} -{-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE MagicHash #-} {-# LANGUAGE NoStarIsType #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} @@ -18,9 +17,11 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE StrictData #-} +{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UnboxedTuples #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} @@ -37,17 +38,22 @@ import Data.Kind (Constraint, Type) import Data.Monoid (Sum(..)) import Data.Proxy import Data.Type.Equality -import GHC.Exts (withDict) +import GHC.Exts (Int(..), Int#, quotRemInt#, withDict, build) import GHC.Generics (Generic) import GHC.IsList (IsList) import GHC.IsList qualified as IsList import GHC.TypeLits import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Mixed.Shape.Internal import Data.Array.Nested.Permutation import Data.Array.Nested.Types +-- * Shaped lists + +-- | Note: The 'KnownNat' constraint on '(::$)' is deprecated and should be +-- removed in a future release. type role ListS nominal representational type ListS :: [Nat] -> (Nat -> Type) -> Type data ListS sh f where @@ -98,13 +104,15 @@ listsEqual (n ::$ sh) (m ::$ sh') = Just Refl listsEqual _ _ = Nothing +{-# INLINE listsFmap #-} listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g listsFmap _ ZS = ZS listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs -listsFold :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m -listsFold _ ZS = mempty -listsFold f (x ::$ xs) = f x <> listsFold f xs +{-# INLINE listsFoldMap #-} +listsFoldMap :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m +listsFoldMap _ ZS = mempty +listsFoldMap f (x ::$ xs) = f x <> listsFoldMap f xs listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS listsShow f l = showString "[" . go "" l . showString "]" @@ -114,15 +122,29 @@ listsShow f l = showString "[" . go "" l . showString "]" go prefix (x ::$ xs) = showString prefix . f x . go "," xs listsLength :: ListS sh f -> Int -listsLength = getSum . listsFold (\_ -> Sum 1) +listsLength = getSum . listsFoldMap (\_ -> Sum 1) listsRank :: ListS sh f -> SNat (Rank sh) listsRank ZS = SNat listsRank (_ ::$ sh) = snatSucc (listsRank sh) +listsFromList :: ShS sh -> [i] -> ListS sh (Const i) +listsFromList topsh topl = go topsh topl + where + go :: ShS sh' -> [i] -> ListS sh' (Const i) + go ZSS [] = ZS + go (_ :$$ sh) (i : is) = Const i ::$ go sh is + go _ _ = error $ "listsFromList: Mismatched list length (type says " + ++ show (shsLength topsh) ++ ", list has length " + ++ show (length topl) ++ ")" + +{-# INLINEABLE listsToList #-} listsToList :: ListS sh (Const i) -> [i] -listsToList ZS = [] -listsToList (Const i ::$ is) = i : listsToList is +listsToList list = build (\(cons :: i -> is -> is) (nil :: is) -> + let go :: ListS sh (Const i) -> is + go ZS = nil + go (Const i ::$ is) = i `cons` go is + in go list) listsHead :: ListS (n : sh) f -> f n listsHead (i ::$ _) = i @@ -144,14 +166,13 @@ listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx' listsZip :: ListS sh f -> ListS sh g -> ListS sh (Fun.Product f g) listsZip ZS ZS = ZS -listsZip (i ::$ is) (j ::$ js) = - Fun.Pair i j ::$ listsZip is js +listsZip (i ::$ is) (j ::$ js) = Fun.Pair i j ::$ listsZip is js +{-# INLINE listsZipWith #-} listsZipWith :: (forall a. f a -> g a -> h a) -> ListS sh f -> ListS sh g -> ListS sh h listsZipWith _ ZS ZS = ZS -listsZipWith f (i ::$ is) (j ::$ js) = - f i j ::$ listsZipWith f is js +listsZipWith f (i ::$ is) (j ::$ js) = f i j ::$ listsZipWith f is js listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f listsTakeLenPerm PNil _ = ZS @@ -180,11 +201,9 @@ listsIndex _ _ _ ZS = error "Index into empty shape" listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh) +-- * Shaped indices -- | An index into a shape-typed array. --- --- For convenience, this contains regular 'Int's instead of bounded integers --- (traditionally called \"@Fin@\"). type role IxS nominal representational type IxS :: [Nat] -> Type -> Type newtype IxS sh i = IxS (ListS sh (Const i)) @@ -193,6 +212,8 @@ newtype IxS sh i = IxS (ListS sh (Const i)) pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i pattern ZIS = IxS ZS +-- | Note: The 'KnownNat' constraint on '(:.$)' is deprecated and should be +-- removed in a future release. pattern (:.$) :: forall {sh1} {i}. forall n sh. (KnownNat n, n : sh ~ sh1) @@ -203,6 +224,8 @@ infixr 3 :.$ {-# COMPLETE ZIS, (:.$) #-} +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). type IIxS sh = IxS sh Int #ifdef OXAR_DEFAULT_SHOW_INSTANCES @@ -213,10 +236,18 @@ instance Show i => Show (IxS sh i) where #endif instance Functor (IxS sh) where + {-# INLINE fmap #-} fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l) instance Foldable (IxS sh) where - foldMap f (IxS l) = listsFold (f . getConst) l + {-# INLINE foldMap #-} + foldMap f (IxS l) = listsFoldMap (f . getConst) l + {-# INLINE foldr #-} + foldr _ z ZIS = z + foldr f z (x :.$ xs) = f x (foldr f z xs) + toList = ixsToList + null ZIS = False + null _ = True instance NFData i => NFData (IxS sh i) @@ -226,6 +257,13 @@ ixsLength (IxS l) = listsLength l ixsRank :: IxS sh i -> SNat (Rank sh) ixsRank (IxS l) = listsRank l +ixsFromList :: forall sh i. ShS sh -> [i] -> IxS sh i +ixsFromList = coerce (listsFromList @_ @i) + +{-# INLINEABLE ixsToList #-} +ixsToList :: forall sh i. IxS sh i -> [i] +ixsToList = coerce (listsToList @_ @i) + ixsZero :: ShS sh -> IIxS sh ixsZero ZSS = ZIS ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh @@ -242,14 +280,21 @@ ixsInit (IxS list) = IxS (listsInit list) ixsLast :: IxS (n : sh) i -> i ixsLast (IxS list) = getConst (listsLast list) +-- TODO: this takes a ShS because there are KnownNats inside IxS. +ixsCast :: ShS sh' -> IxS sh i -> IxS sh' i +ixsCast ZSS ZIS = ZIS +ixsCast (_ :$$ sh) (i :.$ idx) = i :.$ ixsCast sh idx +ixsCast _ _ = error "ixsCast: ranks don't match" + ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i ixsAppend = coerce (listsAppend @_ @(Const i)) -ixsZip :: IxS n i -> IxS n j -> IxS n (i, j) +ixsZip :: IxS sh i -> IxS sh j -> IxS sh (i, j) ixsZip ZIS ZIS = ZIS ixsZip (i :.$ is) (j :.$ js) = (i, j) :.$ ixsZip is js -ixsZipWith :: (i -> j -> k) -> IxS n i -> IxS n j -> IxS n k +{-# INLINE ixsZipWith #-} +ixsZipWith :: (i -> j -> k) -> IxS sh i -> IxS sh j -> IxS sh k ixsZipWith _ ZIS ZIS = ZIS ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js @@ -257,6 +302,8 @@ ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is ixsPermutePrefix = coerce (listsPermutePrefix @(Const i)) +-- * Shaped shapes + -- | The shape of a shape-typed array given as a list of 'SNat' values. -- -- Note that because the shape of a shape-typed array is known statically, you @@ -264,7 +311,10 @@ ixsPermutePrefix = coerce (listsPermutePrefix @(Const i)) type role ShS nominal type ShS :: [Nat] -> Type newtype ShS sh = ShS (ListS sh SNat) - deriving (Eq, Ord, Generic) + deriving (Generic) + +instance Eq (ShS sh) where _ == _ = True +instance Ord (ShS sh) where compare _ _ = EQ pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh pattern ZSS = ShS ZS @@ -309,9 +359,28 @@ shsSize :: ShS sh -> Int shsSize ZSS = 1 shsSize (n :$$ sh) = fromSNat' n * shsSize sh +-- | This is a partial @const@ that fails when the second argument +-- doesn't match the first. +shsFromList :: ShS sh -> [Int] -> ShS sh +shsFromList topsh topl = go topsh topl `seq` topsh + where + go :: ShS sh' -> [Int] -> () + go ZSS [] = () + go (sn :$$ sh) (i : is) + | i == fromSNat' sn = go sh is + | otherwise = error $ "shsFromList: Value does not match typing (type says " + ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" + go _ _ = error $ "shsFromList: Mismatched list length (type says " + ++ show (shsLength topsh) ++ ", list has length " + ++ show (length topl) ++ ")" + +{-# INLINEABLE shsToList #-} shsToList :: ShS sh -> [Int] -shsToList ZSS = [] -shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh +shsToList topsh = build (\(cons :: Int -> is -> is) (nil :: is) -> + let go :: ShS sh -> is + go ZSS = nil + go (sn :$$ sh) = fromSNat' sn `cons` go sh + in go topsh) shsHead :: ShS (n : sh) -> SNat n shsHead (ShS list) = listsHead list @@ -356,7 +425,7 @@ instance KnownShS '[] where knownShS = ZSS instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS withKnownShS :: forall sh r. ShS sh -> (KnownShS sh => r) -> r -withKnownShS k = withDict @(KnownShS sh) k +withKnownShS = withDict @(KnownShS sh) shsKnownShS :: ShS sh -> Dict KnownShS sh shsKnownShS ZSS = Dict @@ -366,18 +435,38 @@ shsOrthotopeShape :: ShS sh -> Dict O.Shape sh shsOrthotopeShape ZSS = Dict shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict +-- | This function is a hack made possible by the 'KnownNat' inside 'ListS'. +-- This function may be removed in a future release. +shsFromListS :: ListS sh f -> ShS sh +shsFromListS ZS = ZSS +shsFromListS (_ ::$ l) = SNat :$$ shsFromListS l + +-- | This function is a hack made possible by the 'KnownNat' inside 'IxS'. This +-- function may be removed in a future release. +shsFromIxS :: IxS sh i -> ShS sh +shsFromIxS (IxS l) = shsFromListS l + +shsEnum :: ShS sh -> [IIxS sh] +shsEnum = shsEnum' + +{-# INLINABLE shsEnum' #-} -- ensure this can be specialised at use site +shsEnum' :: Num i => ShS sh -> [IxS sh i] +shsEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shsSize sh - 1]] + where + suffixes = drop 1 (scanr (*) 1 (shsToList sh)) + + fromLin :: Num i => ShS sh -> [Int] -> Int# -> IxS sh i + fromLin ZSS _ _ = ZIS + fromLin (_ :$$ sh') (I# suff# : suffs) i# = + let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shsSize sh' + in fromIntegral (I# q#) :.$ fromLin sh' suffs r# + fromLin _ _ _ = error "impossible" + -- | Untyped: length is checked at runtime. instance KnownShS sh => IsList (ListS sh (Const i)) where type Item (ListS sh (Const i)) = i - fromList topl = go (knownShS @sh) topl - where - go :: ShS sh' -> [i] -> ListS sh' (Const i) - go ZSS [] = ZS - go (_ :$$ sh) (i : is) = Const i ::$ go sh is - go _ _ = error $ "IsList(ListS): Mismatched list length (type says " - ++ show (shsLength (knownShS @sh)) ++ ", list has length " - ++ show (length topl) ++ ")" + fromList = listsFromList (knownShS @sh) toList = listsToList -- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. @@ -389,15 +478,8 @@ instance KnownShS sh => IsList (IxS sh i) where -- | Untyped: length and values are checked at runtime. instance KnownShS sh => IsList (ShS sh) where type Item (ShS sh) = Int - fromList topl = ShS (go (knownShS @sh) topl) - where - go :: ShS sh' -> [Int] -> ListS sh' SNat - go ZSS [] = ZS - go (sn :$$ sh) (i : is) - | i == fromSNat' sn = sn ::$ go sh is - | otherwise = error $ "IsList(ShS): Value does not match typing (type says " - ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" - go _ _ = error $ "IsList(ShS): Mismatched list length (type says " - ++ show (shsLength (knownShS @sh)) ++ ", list has length " - ++ show (length topl) ++ ")" + fromList = shsFromList (knownShS @sh) toList = shsToList + +$(ixFromLinearStub "ixsFromLinear" [t| ShS |] [t| IxS |] [p| ZSS |] (\a b -> [p| (fromSNat' -> $a) :$$ $b |]) [| ZIS |] [| (:.$) |] [| shsToList |]) +{-# INLINEABLE ixsFromLinear #-} |
