diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-05-19 12:55:58 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-19 12:55:58 +0200 |
commit | 390af124e6cb50c4a2cd9006662fc26eef02889a (patch) | |
tree | a72140daab1d4e63bb9d1ae06ca71c46b873bd07 /src/Data/Array | |
parent | b22366401ee4bfead4ba6789937acdb7274d175c (diff) |
Some IsList instances
Diffstat (limited to 'src/Data/Array')
-rw-r--r-- | src/Data/Array/Mixed.hs | 48 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 78 |
2 files changed, 111 insertions, 15 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 31a4e69..398a3de 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -30,6 +30,7 @@ import Data.Bifunctor (first) import Data.Coerce import Data.Functor.Const import Data.Kind +import Data.Monoid (Sum(..)) import Data.Proxy import Data.Type.Bool import Data.Type.Equality @@ -114,6 +115,9 @@ foldListX :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m foldListX _ ZX = mempty foldListX f (x ::% xs) = f x <> foldListX f xs +lengthListX :: ListX sh f -> Int +lengthListX = getSum . foldListX (\_ -> Sum 1) + showListX :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS showListX f l = showString "[" . go "" l . showString "]" where @@ -193,8 +197,7 @@ instance Functor (ShX sh) where fmap f (ShX l) = ShX (fmapListX (fromSMayNat (SUnknown . f) SKnown) l) lengthShX :: ShX sh i -> Int -lengthShX ZSX = 0 -lengthShX (_ :$% sh) = 1 + lengthShX sh +lengthShX (ShX l) = lengthListX l -- | The part of a shape that is statically known. @@ -218,6 +221,9 @@ infixr 3 :!% instance Show (StaticShX sh) where showsPrec _ (StaticShX l) = showListX (fromSMayNat shows (shows . fromSNat)) l +lengthStaticShX :: StaticShX sh -> Int +lengthStaticShX (StaticShX l) = lengthListX l + -- | Evidence for the static part of a shape. This pops up only when you are -- polymorphic in the element type of an array. @@ -228,36 +234,48 @@ instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SK instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX --- | Very untyped; length is checked at runtime. +-- | Very untyped: only length is checked (at runtime). instance KnownShX sh => IsList (ListX sh (Const i)) where type Item (ListX sh (Const i)) = i - fromList = go (knownShX @sh) + fromList topl = go (knownShX @sh) topl where go :: StaticShX sh' -> [i] -> ListX sh' (Const i) go ZKX [] = ZX go (_ :!% sh) (i : is) = Const i ::% go sh is - go _ _ = error "IsList(ListX): Mismatched list length" + go _ _ = error $ "IsList(ListX): Mismatched list length (type says " + ++ show typelen ++ ", list has length " + ++ show (length topl) ++ ")" + where typelen = let StaticShX l = knownShX @sh in lengthListX l toList = go where go :: ListX sh' (Const i) -> [i] go ZX = [] go (Const i ::% is) = i : go is --- | Very untyped; length is checked at runtime, and index bounds are *not checked*. +-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. instance KnownShX sh => IsList (IxX sh i) where type Item (IxX sh i) = i fromList = IxX . fromList toList (IxX l) = toList l --- | Very untyped; length is checked at runtime, and known dimensions are *not checked*. --- instance KnownShX sh => IsList (ShX sh i) where --- type Item (ShX sh i) = i --- fromList = ShX . fmapListX (\(Const i) -> _) . fromList --- toList = go --- where --- go :: ShX sh' i -> [i] --- go ZSX = [] --- go (Const i :$% is) = i : go is +-- | Untyped: length and known dimensions are checked (at runtime). +instance KnownShX sh => IsList (ShX sh Int) where + type Item (ShX sh Int) = Int + fromList = ShX . go (knownShX @sh) + where + go :: StaticShX sh' -> [Int] -> ListX sh' (SMayNat Int SNat) + go ZKX [] = ZX + go (SKnown sn :!% sh) (i : is) + | i == fromSNat' sn = SKnown sn ::% go sh is + | otherwise = error $ "IsList(ShX): Value does not match typing (type says " + ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" + go (SUnknown () :!% sh) (i : is) = SUnknown i ::% go sh is + go _ _ = error "IsList(ShX): Mismatched list length" + toList = go + where + go :: ShX sh' Int -> [Int] + go ZSX = [] + go (smn :$% sh) = fromSMayNat' smn : go sh type family Rank sh where diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 8777960..a61e7d6 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -85,11 +85,14 @@ import Data.Foldable (toList) import Data.Functor.Const import Data.Kind import Data.List.NonEmpty (NonEmpty(..)) +import Data.Monoid (Sum(..)) import Data.Proxy import Data.Type.Equality import qualified Data.Vector.Storable as VS import qualified Data.Vector.Storable.Mutable as VSM import Foreign.Storable (Storable) +import GHC.IsList (IsList) +import qualified GHC.IsList as IsList import GHC.TypeLits import qualified GHC.TypeNats as TypeNats import Unsafe.Coerce @@ -267,6 +270,34 @@ instance Show i => Show (ShR n i) where showsPrec _ (ShR l) = showListR shows l +-- | Untyped: length is checked at runtime. +instance KnownNat n => IsList (ListR n i) where + type Item (ListR n i) = i + fromList = go (SNat @n) + where + go :: SNat n' -> [i] -> ListR n' i + go SZ [] = ZR + go (SS n) (i : is) = i ::: go n is + go _ _ = error "IsList(ListR): Mismatched list length" + toList = go + where + go :: ListR n' i -> [i] + go ZR = [] + go (i ::: is) = i : go is + +-- | Untyped: length is checked at runtime. +instance KnownNat n => IsList (IxR n i) where + type Item (IxR n i) = i + fromList = IxR . IsList.fromList + toList (IxR idx) = IsList.toList idx + +-- | Untyped: length is checked at runtime. +instance KnownNat n => IsList (ShR n i) where + type Item (ShR n i) = i + fromList = ShR . IsList.fromList + toList (ShR idx) = IsList.toList idx + + type role ListS nominal representational type ListS :: [Nat] -> (Nat -> Type) -> Type data ListS sh f where @@ -360,6 +391,53 @@ infixr 3 :$$ instance Show (ShS sh) where showsPrec _ (ShS l) = showListS (shows . fromSNat) l +lengthShS :: ShS sh -> Int +lengthShS (ShS l) = getSum (foldListS (\_ -> Sum 1) l) + + +-- | Untyped: length is checked at runtime. +instance KnownShS sh => IsList (ListS sh (Const i)) where + type Item (ListS sh (Const i)) = i + fromList topl = go (knownShS @sh) topl + where + go :: ShS sh' -> [i] -> ListS sh' (Const i) + go ZSS [] = ZS + go (_ :$$ sh) (i : is) = Const i ::$ go sh is + go _ _ = error $ "IsList(ListS): Mismatched list length (type says " + ++ show (lengthShS (knownShS @sh)) ++ ", list has length " + ++ show (length topl) ++ ")" + toList = go + where + go :: ListS sh' (Const i) -> [i] + go ZS = [] + go (Const i ::$ is) = i : go is + +-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. +instance KnownShS sh => IsList (IxS sh i) where + type Item (IxS sh i) = i + fromList = IxS . IsList.fromList + toList (IxS idx) = IsList.toList idx + +-- | Untyped: length and values are checked at runtime. +instance KnownShS sh => IsList (ShS sh) where + type Item (ShS sh) = Int + fromList topl = ShS (go (knownShS @sh) topl) + where + go :: ShS sh' -> [Int] -> ListS sh' SNat + go ZSS [] = ZS + go (sn :$$ sh) (i : is) + | i == fromSNat' sn = sn ::$ go sh is + | otherwise = error $ "IsList(ShS): Value does not match typing (type says " + ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" + go _ _ = error $ "IsList(ShS): Mismatched list length (type says " + ++ show (lengthShS (knownShS @sh)) ++ ", list has length " + ++ show (length topl) ++ ")" + toList = go + where + go :: ShS sh' -> [Int] + go ZSS = [] + go (sn :$$ sh) = fromSNat' sn : go sh + -- | Wrapper type used as a tag to attach instances on. The instances on arrays -- of @'Primitive' a@ are more polymorphic than the direct instances for arrays |