From a4ef1be4300872b7a4647a4074cf88294aa905e5 Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Tue, 11 Feb 2025 17:57:06 +0100
Subject: Type-check keys, and provide unsafe versions of read ops

---
 src/Data/Dependent/EnumMap/Strict/Internal.hs | 101 +++++++++++++++++++++-----
 src/Data/Dependent/EnumMap/Strict/Unsafe.hs   |  44 +++++++++++
 2 files changed, 126 insertions(+), 19 deletions(-)
 create mode 100644 src/Data/Dependent/EnumMap/Strict/Unsafe.hs

(limited to 'src/Data/Dependent')

diff --git a/src/Data/Dependent/EnumMap/Strict/Internal.hs b/src/Data/Dependent/EnumMap/Strict/Internal.hs
index c2dcf95..12fb602 100644
--- a/src/Data/Dependent/EnumMap/Strict/Internal.hs
+++ b/src/Data/Dependent/EnumMap/Strict/Internal.hs
@@ -3,6 +3,7 @@
 {-# LANGUAGE RankNTypes #-}
 {-# LANGUAGE QuantifiedConstraints #-}
 {-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
 {-# LANGUAGE TypeFamilies #-}
 module Data.Dependent.EnumMap.Strict.Internal where
 
@@ -10,7 +11,9 @@ import Data.Bifunctor (bimap)
 import Data.Dependent.Sum
 import qualified Data.IntMap.Strict as IM
 import Data.Kind (Type)
+import Data.Proxy
 import Data.Some
+import Data.Type.Equality
 import Text.Show (showListWith)
 import Unsafe.Coerce (unsafeCoerce)
 
@@ -85,38 +88,67 @@ insertWithKey f k v (DEnumMap m) =
 delete :: Enum1 k => k a -> DEnumMap k v -> DEnumMap k v
 delete k (DEnumMap m) = DEnumMap (IM.delete (fst (fromEnum1 k)) m)
 
-adjust :: Enum1 k => (v a -> v a) -> k a -> DEnumMap k v -> DEnumMap k v
-adjust f k (DEnumMap m) = DEnumMap (IM.adjust (\(KV k' v) -> KV k' (f (coe1 v))) (fst (fromEnum1 k)) m)
+adjust :: (Enum1 k, TestEquality k) => (v a -> v a) -> k a -> DEnumMap k v -> DEnumMap k v
+adjust = adjust' typeCheckK
+
+adjustUnsafe :: Enum1 k => (v a -> v a) -> k a -> DEnumMap k v -> DEnumMap k v
+adjustUnsafe = adjust' don'tCheckK
+
+adjust' :: Enum1 k => Checker k a -> (v a -> v a) -> k a -> DEnumMap k v -> DEnumMap k v
+adjust' ch f k (DEnumMap m) =
+  let (i, _) = fromEnum1 k
+  in DEnumMap (IM.adjust (\(KV inf v) -> ch i k inf $ KV inf (f (coe1 v))) i m)
 
 -- adjustWithKey
 -- update
 -- updateWithKey
 -- updateLookupWithKey
 
-alter :: forall k v a. Enum1 k => (Maybe (v a) -> Maybe (v a)) -> k a -> DEnumMap k v -> DEnumMap k v
-alter f k (DEnumMap m) = DEnumMap (IM.alter f' i m)
+alter :: forall k v a. (Enum1 k, TestEquality k) => (Maybe (v a) -> Maybe (v a)) -> k a -> DEnumMap k v -> DEnumMap k v
+alter = alter' typeCheckK
+
+alterUnsafe :: forall k v a. Enum1 k => (Maybe (v a) -> Maybe (v a)) -> k a -> DEnumMap k v -> DEnumMap k v
+alterUnsafe = alter' don'tCheckK
+
+alter' :: forall k v a. Enum1 k => Checker k a -> (Maybe (v a) -> Maybe (v a)) -> k a -> DEnumMap k v -> DEnumMap k v
+alter' ch f k (DEnumMap m) = DEnumMap (IM.alter f' i m)
   where
     (i, inf) = fromEnum1 k
 
     f' :: Maybe (KV k v) -> Maybe (KV k v)
     f' Nothing = KV inf <$> f Nothing
-    f' (Just (KV _ v)) = KV inf <$> f (Just (coe1 v))
+    f' (Just (KV inf' v)) = ch i k inf' $ KV inf <$> f (Just (coe1 v))
 
 -- alterF
 
 -- * Query
 -- ** Lookup
 
-lookup :: Enum1 k => k a -> DEnumMap k v -> Maybe (v a)
-lookup k (DEnumMap m) = (\(KV _ v) -> coe1 v) <$> IM.lookup (fst (fromEnum1 k)) m
+lookup :: (Enum1 k, TestEquality k) => k a -> DEnumMap k v -> Maybe (v a)
+lookup = lookup' typeCheckK
+
+lookupUnsafe :: Enum1 k => k a -> DEnumMap k v -> Maybe (v a)
+lookupUnsafe = lookup' don'tCheckK
+
+lookup' :: Enum1 k => Checker k a -> k a -> DEnumMap k v -> Maybe (v a)
+lookup' ch k (DEnumMap m) =
+  let (i, _) = fromEnum1 k
+  in (\(KV inf v) -> ch i k inf $ coe1 v) <$> IM.lookup i m
 
 -- (!?)
 -- (!)
 
-findWithDefault :: Enum1 k => v a -> k a -> DEnumMap k v -> v a
-findWithDefault def k (DEnumMap m) =
-  case IM.findWithDefault (KV undefined def) (fst (fromEnum1 k)) m of
-    KV _ v -> coe1 v
+findWithDefault :: (Enum1 k, TestEquality k) => v a -> k a -> DEnumMap k v -> v a
+findWithDefault = findWithDefault' typeCheckK
+
+findWithDefaultUnsafe :: Enum1 k => v a -> k a -> DEnumMap k v -> v a
+findWithDefaultUnsafe = findWithDefault' don'tCheckK
+
+findWithDefault' :: Enum1 k => Checker k a -> v a -> k a -> DEnumMap k v -> v a
+findWithDefault' ch def k (DEnumMap m) =
+  let (i, _) = fromEnum1 k
+  in case IM.findWithDefault (KV undefined def) i m of
+       KV inf' v -> ch i k inf' $ coe1 v
 
 member :: Enum1 k => k a -> DEnumMap k v -> Bool
 member k (DEnumMap m) = IM.member (fst (fromEnum1 k)) m
@@ -139,15 +171,27 @@ size (DEnumMap m) = IM.size m
 
 -- ** Union
 
-union :: DEnumMap k v -> DEnumMap k v -> DEnumMap k v
-union (DEnumMap m1) (DEnumMap m2) = DEnumMap (IM.union m1 m2)
+union :: (Enum1 k, TestEquality k) => DEnumMap k v -> DEnumMap k v -> DEnumMap k v
+union = unionWith const  -- if we're checking, we need unionWith anyway, so might as well just delegate here already
+
+-- in the unsafe variant, we can make do with IM.union, which is slightly faster than IM.unionWith, so let's specialise
+unionUnsafe :: DEnumMap k v -> DEnumMap k v -> DEnumMap k v
+unionUnsafe (DEnumMap m1) (DEnumMap m2) = DEnumMap (IM.union m1 m2)
 
-unionWith :: forall k v. (forall a. v a -> v a -> v a)
-          -> DEnumMap k v -> DEnumMap k v -> DEnumMap k v
-unionWith f (DEnumMap m1) (DEnumMap m2) = DEnumMap (IM.unionWith f' m1 m2)
+unionWith :: (Enum1 k, TestEquality k)
+          => (forall a. v a -> v a -> v a) -> DEnumMap k v -> DEnumMap k v -> DEnumMap k v
+unionWith f (m1 :: DEnumMap k v) = unionWith' (typeCheckSK (Proxy @k)) f m1
+
+unionWithUnsafe :: (forall a. v a -> v a -> v a) -> DEnumMap k v -> DEnumMap k v -> DEnumMap k v
+unionWithUnsafe f (m1 :: DEnumMap k v) = unionWith' (don'tCheckSK (Proxy @k)) f m1
+
+unionWith' :: CheckerSplit k
+           -> (forall a. v a -> v a -> v a)
+           -> DEnumMap k v -> DEnumMap k v -> DEnumMap k v
+unionWith' ch f (DEnumMap m1 :: DEnumMap k v) (DEnumMap m2) = DEnumMap (IM.unionWithKey f' m1 m2)
   where
-    f' :: KV k v -> KV k v -> KV k v
-    f' (KV inf v1) (KV _ v2) = KV inf (f v1 (coe1 v2))
+    f' :: Int -> KV k v -> KV k v -> KV k v
+    f' i (KV inf1 v1) (KV inf2 v2) = ch i inf1 inf2 $ KV inf1 (f v1 (coe1 v2))
 
 -- unionWithKey
 -- unions
@@ -294,7 +338,26 @@ maxViewWithKey (DEnumMap m) =
     <$> IM.maxViewWithKey m
 
 
--- * Unsafe helpers
+-- * Helpers
 
 coe1 :: v a -> v b
 coe1 = unsafeCoerce
+
+type CheckerSplit k = forall r. Int -> Enum1Info k -> Enum1Info k -> r -> r
+
+typeCheckSK :: forall k proxy. (Enum1 k, TestEquality k) => proxy k -> CheckerSplit k
+typeCheckSK _ i inf1 inf2 = case toEnum1 @k i inf1 of Some k -> typeCheckK i k inf2
+
+don'tCheckSK :: proxy k -> CheckerSplit k
+don'tCheckSK _ _ _ _ = id
+
+type Checker k a = forall r. Int -> k a -> Enum1Info k -> r -> r
+
+typeCheckK :: (Enum1 k, TestEquality k) => Checker k a
+typeCheckK i k1 inf cont
+  | Some k2 <- toEnum1 i inf
+  , Just Refl <- testEquality k1 k2 = cont
+  | otherwise = errorWithoutStackTrace "DEnumMap: keys with same Int but different types"
+
+don'tCheckK :: Checker k a
+don'tCheckK _ _ _ = id
diff --git a/src/Data/Dependent/EnumMap/Strict/Unsafe.hs b/src/Data/Dependent/EnumMap/Strict/Unsafe.hs
new file mode 100644
index 0000000..4d4a9eb
--- /dev/null
+++ b/src/Data/Dependent/EnumMap/Strict/Unsafe.hs
@@ -0,0 +1,44 @@
+{-|
+These are variants of the functions in "Data.Dependent.EnumMap.Strict" that do
+not type-check keys: they do not check that you don't create two keys with the
+same 'Int' and different types. As a result, these functions do not have a
+'Data.Type.Equality.TestEquality' constraint, and are faster.
+
+Be aware though, because one can easily create @unsafeCoerce@ with this API:
+
+@
+{-# LANGUAGE ScopedTypeVariables TypeFamilies #-}
+
+import qualified Data.Dependent.EnumMap.Strict as DE
+import qualified Data.Dependent.EnumMap.Strict.Unsafe as DEU
+
+import Data.Functor.Identity
+import Data.Maybe
+import Data.Some
+
+data Foo a = Foo Int
+  deriving (Show)
+
+instance DE.Enum1 Foo where
+  type Enum1Info Foo = ()
+  fromEnum1 (Foo i) = (i, ())
+  toEnum1 i () = Some (Foo i)
+
+unsafe :: forall a b. a -> b
+unsafe x = runIdentity $ fromJust $
+             DEU.lookupUnsafe (Foo 1 :: Foo b) $
+               DE.singleton (Foo 1 :: Foo a) (Identity x)
+@
+
+-}
+module Data.Dependent.EnumMap.Strict.Unsafe (
+  adjustUnsafe,
+  alterUnsafe,
+  lookupUnsafe,
+  findWithDefaultUnsafe,
+  unionUnsafe,
+  unionWithUnsafe,
+) where
+
+import Prelude ()
+import Data.Dependent.EnumMap.Strict.Internal
-- 
cgit v1.2.3-70-g09d2