summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-02-11 17:57:06 +0100
committerTom Smeding <tom@tomsmeding.com>2025-02-11 17:57:06 +0100
commita4ef1be4300872b7a4647a4074cf88294aa905e5 (patch)
tree587581cbd8b4e02111369d81d530736fe4c062fa /src
parent963378040c0a6b1819912b406de192ca9ddf0773 (diff)
Type-check keys, and provide unsafe versions of read ops
Diffstat (limited to 'src')
-rw-r--r--src/Data/Dependent/EnumMap/Strict/Internal.hs101
-rw-r--r--src/Data/Dependent/EnumMap/Strict/Unsafe.hs44
2 files changed, 126 insertions, 19 deletions
diff --git a/src/Data/Dependent/EnumMap/Strict/Internal.hs b/src/Data/Dependent/EnumMap/Strict/Internal.hs
index c2dcf95..12fb602 100644
--- a/src/Data/Dependent/EnumMap/Strict/Internal.hs
+++ b/src/Data/Dependent/EnumMap/Strict/Internal.hs
@@ -3,6 +3,7 @@
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
module Data.Dependent.EnumMap.Strict.Internal where
@@ -10,7 +11,9 @@ import Data.Bifunctor (bimap)
import Data.Dependent.Sum
import qualified Data.IntMap.Strict as IM
import Data.Kind (Type)
+import Data.Proxy
import Data.Some
+import Data.Type.Equality
import Text.Show (showListWith)
import Unsafe.Coerce (unsafeCoerce)
@@ -85,38 +88,67 @@ insertWithKey f k v (DEnumMap m) =
delete :: Enum1 k => k a -> DEnumMap k v -> DEnumMap k v
delete k (DEnumMap m) = DEnumMap (IM.delete (fst (fromEnum1 k)) m)
-adjust :: Enum1 k => (v a -> v a) -> k a -> DEnumMap k v -> DEnumMap k v
-adjust f k (DEnumMap m) = DEnumMap (IM.adjust (\(KV k' v) -> KV k' (f (coe1 v))) (fst (fromEnum1 k)) m)
+adjust :: (Enum1 k, TestEquality k) => (v a -> v a) -> k a -> DEnumMap k v -> DEnumMap k v
+adjust = adjust' typeCheckK
+
+adjustUnsafe :: Enum1 k => (v a -> v a) -> k a -> DEnumMap k v -> DEnumMap k v
+adjustUnsafe = adjust' don'tCheckK
+
+adjust' :: Enum1 k => Checker k a -> (v a -> v a) -> k a -> DEnumMap k v -> DEnumMap k v
+adjust' ch f k (DEnumMap m) =
+ let (i, _) = fromEnum1 k
+ in DEnumMap (IM.adjust (\(KV inf v) -> ch i k inf $ KV inf (f (coe1 v))) i m)
-- adjustWithKey
-- update
-- updateWithKey
-- updateLookupWithKey
-alter :: forall k v a. Enum1 k => (Maybe (v a) -> Maybe (v a)) -> k a -> DEnumMap k v -> DEnumMap k v
-alter f k (DEnumMap m) = DEnumMap (IM.alter f' i m)
+alter :: forall k v a. (Enum1 k, TestEquality k) => (Maybe (v a) -> Maybe (v a)) -> k a -> DEnumMap k v -> DEnumMap k v
+alter = alter' typeCheckK
+
+alterUnsafe :: forall k v a. Enum1 k => (Maybe (v a) -> Maybe (v a)) -> k a -> DEnumMap k v -> DEnumMap k v
+alterUnsafe = alter' don'tCheckK
+
+alter' :: forall k v a. Enum1 k => Checker k a -> (Maybe (v a) -> Maybe (v a)) -> k a -> DEnumMap k v -> DEnumMap k v
+alter' ch f k (DEnumMap m) = DEnumMap (IM.alter f' i m)
where
(i, inf) = fromEnum1 k
f' :: Maybe (KV k v) -> Maybe (KV k v)
f' Nothing = KV inf <$> f Nothing
- f' (Just (KV _ v)) = KV inf <$> f (Just (coe1 v))
+ f' (Just (KV inf' v)) = ch i k inf' $ KV inf <$> f (Just (coe1 v))
-- alterF
-- * Query
-- ** Lookup
-lookup :: Enum1 k => k a -> DEnumMap k v -> Maybe (v a)
-lookup k (DEnumMap m) = (\(KV _ v) -> coe1 v) <$> IM.lookup (fst (fromEnum1 k)) m
+lookup :: (Enum1 k, TestEquality k) => k a -> DEnumMap k v -> Maybe (v a)
+lookup = lookup' typeCheckK
+
+lookupUnsafe :: Enum1 k => k a -> DEnumMap k v -> Maybe (v a)
+lookupUnsafe = lookup' don'tCheckK
+
+lookup' :: Enum1 k => Checker k a -> k a -> DEnumMap k v -> Maybe (v a)
+lookup' ch k (DEnumMap m) =
+ let (i, _) = fromEnum1 k
+ in (\(KV inf v) -> ch i k inf $ coe1 v) <$> IM.lookup i m
-- (!?)
-- (!)
-findWithDefault :: Enum1 k => v a -> k a -> DEnumMap k v -> v a
-findWithDefault def k (DEnumMap m) =
- case IM.findWithDefault (KV undefined def) (fst (fromEnum1 k)) m of
- KV _ v -> coe1 v
+findWithDefault :: (Enum1 k, TestEquality k) => v a -> k a -> DEnumMap k v -> v a
+findWithDefault = findWithDefault' typeCheckK
+
+findWithDefaultUnsafe :: Enum1 k => v a -> k a -> DEnumMap k v -> v a
+findWithDefaultUnsafe = findWithDefault' don'tCheckK
+
+findWithDefault' :: Enum1 k => Checker k a -> v a -> k a -> DEnumMap k v -> v a
+findWithDefault' ch def k (DEnumMap m) =
+ let (i, _) = fromEnum1 k
+ in case IM.findWithDefault (KV undefined def) i m of
+ KV inf' v -> ch i k inf' $ coe1 v
member :: Enum1 k => k a -> DEnumMap k v -> Bool
member k (DEnumMap m) = IM.member (fst (fromEnum1 k)) m
@@ -139,15 +171,27 @@ size (DEnumMap m) = IM.size m
-- ** Union
-union :: DEnumMap k v -> DEnumMap k v -> DEnumMap k v
-union (DEnumMap m1) (DEnumMap m2) = DEnumMap (IM.union m1 m2)
+union :: (Enum1 k, TestEquality k) => DEnumMap k v -> DEnumMap k v -> DEnumMap k v
+union = unionWith const -- if we're checking, we need unionWith anyway, so might as well just delegate here already
+
+-- in the unsafe variant, we can make do with IM.union, which is slightly faster than IM.unionWith, so let's specialise
+unionUnsafe :: DEnumMap k v -> DEnumMap k v -> DEnumMap k v
+unionUnsafe (DEnumMap m1) (DEnumMap m2) = DEnumMap (IM.union m1 m2)
-unionWith :: forall k v. (forall a. v a -> v a -> v a)
- -> DEnumMap k v -> DEnumMap k v -> DEnumMap k v
-unionWith f (DEnumMap m1) (DEnumMap m2) = DEnumMap (IM.unionWith f' m1 m2)
+unionWith :: (Enum1 k, TestEquality k)
+ => (forall a. v a -> v a -> v a) -> DEnumMap k v -> DEnumMap k v -> DEnumMap k v
+unionWith f (m1 :: DEnumMap k v) = unionWith' (typeCheckSK (Proxy @k)) f m1
+
+unionWithUnsafe :: (forall a. v a -> v a -> v a) -> DEnumMap k v -> DEnumMap k v -> DEnumMap k v
+unionWithUnsafe f (m1 :: DEnumMap k v) = unionWith' (don'tCheckSK (Proxy @k)) f m1
+
+unionWith' :: CheckerSplit k
+ -> (forall a. v a -> v a -> v a)
+ -> DEnumMap k v -> DEnumMap k v -> DEnumMap k v
+unionWith' ch f (DEnumMap m1 :: DEnumMap k v) (DEnumMap m2) = DEnumMap (IM.unionWithKey f' m1 m2)
where
- f' :: KV k v -> KV k v -> KV k v
- f' (KV inf v1) (KV _ v2) = KV inf (f v1 (coe1 v2))
+ f' :: Int -> KV k v -> KV k v -> KV k v
+ f' i (KV inf1 v1) (KV inf2 v2) = ch i inf1 inf2 $ KV inf1 (f v1 (coe1 v2))
-- unionWithKey
-- unions
@@ -294,7 +338,26 @@ maxViewWithKey (DEnumMap m) =
<$> IM.maxViewWithKey m
--- * Unsafe helpers
+-- * Helpers
coe1 :: v a -> v b
coe1 = unsafeCoerce
+
+type CheckerSplit k = forall r. Int -> Enum1Info k -> Enum1Info k -> r -> r
+
+typeCheckSK :: forall k proxy. (Enum1 k, TestEquality k) => proxy k -> CheckerSplit k
+typeCheckSK _ i inf1 inf2 = case toEnum1 @k i inf1 of Some k -> typeCheckK i k inf2
+
+don'tCheckSK :: proxy k -> CheckerSplit k
+don'tCheckSK _ _ _ _ = id
+
+type Checker k a = forall r. Int -> k a -> Enum1Info k -> r -> r
+
+typeCheckK :: (Enum1 k, TestEquality k) => Checker k a
+typeCheckK i k1 inf cont
+ | Some k2 <- toEnum1 i inf
+ , Just Refl <- testEquality k1 k2 = cont
+ | otherwise = errorWithoutStackTrace "DEnumMap: keys with same Int but different types"
+
+don'tCheckK :: Checker k a
+don'tCheckK _ _ _ = id
diff --git a/src/Data/Dependent/EnumMap/Strict/Unsafe.hs b/src/Data/Dependent/EnumMap/Strict/Unsafe.hs
new file mode 100644
index 0000000..4d4a9eb
--- /dev/null
+++ b/src/Data/Dependent/EnumMap/Strict/Unsafe.hs
@@ -0,0 +1,44 @@
+{-|
+These are variants of the functions in "Data.Dependent.EnumMap.Strict" that do
+not type-check keys: they do not check that you don't create two keys with the
+same 'Int' and different types. As a result, these functions do not have a
+'Data.Type.Equality.TestEquality' constraint, and are faster.
+
+Be aware though, because one can easily create @unsafeCoerce@ with this API:
+
+@
+{-# LANGUAGE ScopedTypeVariables TypeFamilies #-}
+
+import qualified Data.Dependent.EnumMap.Strict as DE
+import qualified Data.Dependent.EnumMap.Strict.Unsafe as DEU
+
+import Data.Functor.Identity
+import Data.Maybe
+import Data.Some
+
+data Foo a = Foo Int
+ deriving (Show)
+
+instance DE.Enum1 Foo where
+ type Enum1Info Foo = ()
+ fromEnum1 (Foo i) = (i, ())
+ toEnum1 i () = Some (Foo i)
+
+unsafe :: forall a b. a -> b
+unsafe x = runIdentity $ fromJust $
+ DEU.lookupUnsafe (Foo 1 :: Foo b) $
+ DE.singleton (Foo 1 :: Foo a) (Identity x)
+@
+
+-}
+module Data.Dependent.EnumMap.Strict.Unsafe (
+ adjustUnsafe,
+ alterUnsafe,
+ lookupUnsafe,
+ findWithDefaultUnsafe,
+ unionUnsafe,
+ unionWithUnsafe,
+) where
+
+import Prelude ()
+import Data.Dependent.EnumMap.Strict.Internal