summaryrefslogtreecommitdiff
path: root/2019/SmallIntSet.hs
blob: 182b850935c6f1744d8177e304447ba3e2a3de38 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
module SmallIntSet (
    SmallIntSet,
    toList, fromList, size, empty, singleton, insert, member, notMember, union, (\\)
) where

import Data.Bits
import Data.List (intercalate)


-- | Bit width of a machine 'Int'. Set elements must be strictly below this
-- value, because each element occupies one bit of the underlying 'Int'.
nBits :: Int
nBits = finiteBitSize (0 :: Int)

-- | A set of integers in the range @[0, nBits)@, packed into one 'Int'
-- bitmask: element @n@ is present exactly when bit @n@ is set.
-- 'Eq' and 'Ord' compare the raw bitmasks.
newtype SmallIntSet
    = SmallIntSet Int
    deriving (Ord, Eq)

-- | Renders the set in braces, elements comma-separated in the order
-- produced by 'toList' (ascending), e.g. @{1,5,9}@; the empty set is @{}@.
instance Show SmallIntSet where
    show set = concat ["{", intercalate "," (show <$> toList set), "}"]

-- | Combining two sets with '<>' takes their 'union'.
instance Semigroup SmallIntSet where
    (<>) = union

-- | The all-zero bitmask 'empty' is the identity for 'union'
-- (@0 .|. b == b@), so sets also form a 'Monoid'. This enables
-- 'mempty', 'mconcat' and 'foldMap' for building sets.
instance Monoid SmallIntSet where
    mempty = empty

-- | The elements of the set in ascending order, found by testing every
-- bit position from 0 up to @nBits - 1@.
toList :: SmallIntSet -> [Int]
toList (SmallIntSet mask) = filter (testBit mask) [0 .. nBits - 1]

-- | Builds a set by inserting each list element into 'empty'. Duplicates
-- are harmless. An out-of-range element raises the 'checkValid' error when
-- the resulting set is evaluated.
fromList :: [Int] -> SmallIntSet
fromList [] = empty
fromList (n : rest) = insert n (fromList rest)

-- | Number of elements, i.e. the count of set bits in the mask.
size :: SmallIntSet -> Int
size set = popCount mask
  where
    SmallIntSet mask = set

-- | The set with no elements (no bits set).
empty :: SmallIntSet
empty = SmallIntSet zeroBits

-- | The set containing only @n@. 'checkValid' is forced first, so an
-- @n@ outside @[0, nBits)@ raises an error rather than building a mask.
singleton :: Int -> SmallIntSet
singleton n = checkValid n `seq` SmallIntSet (bit n)

-- | Adds @n@ to the set by setting bit @n@. Errors (via 'checkValid')
-- when @n@ is outside @[0, nBits)@.
insert :: Int -> SmallIntSet -> SmallIntSet
insert n (SmallIntSet mask) =
    checkValid n `seq` SmallIntSet (setBit mask n)

-- | Whether @n@ is in the set, i.e. whether bit @n@ is set. Errors (via
-- 'checkValid') when @n@ is outside @[0, nBits)@.
member :: Int -> SmallIntSet -> Bool
member n (SmallIntSet mask) =
    checkValid n `seq` testBit mask n

-- | Negation of 'member'. It performs the same 'checkValid' bounds check,
-- so an out-of-range @n@ is an error rather than 'True'.
notMember :: Int -> SmallIntSet -> Bool
notMember n set = not (member n set)

-- | Elements present in either set (bitwise OR of the masks).
union :: SmallIntSet -> SmallIntSet -> SmallIntSet
union (SmallIntSet lhs) (SmallIntSet rhs) = SmallIntSet $ lhs .|. rhs

-- | Set difference: the elements of the first set that are not in the
-- second (clear every bit of the first mask that is set in the second).
(\\) :: SmallIntSet -> SmallIntSet -> SmallIntSet
(\\) (SmallIntSet kept) (SmallIntSet removed) =
    SmallIntSet (kept .&. complement removed)

-- | Bounds check for set elements. It returns 'True' when @n@ lies in
-- @[0, nBits)@ and calls 'error' otherwise. Callers force it with 'seq'
-- before using @n@ as a bit index, so out-of-range elements are rejected
-- loudly instead of reaching the bit operations.
checkValid :: Int -> Bool
checkValid n
    | n < 0 || n >= nBits = error $ "SmallIntSet bounds violated with " ++ show n
    | otherwise           = True