-- | Interpreter for a small program made of bitmask updates and memory
-- stores. The program is executed twice, once with the mask applied to the
-- stored value ('exec1') and once with the mask applied to the address
-- ('exec2'); the sum of all memory cells is printed after each run.
module Main where

import Data.Bits
import qualified Data.IntMap.Strict as IM
import Data.IntMap.Strict (IntMap)
import Data.List (foldl')
import Text.Read (readMaybe)

import Input

-- | Every sub-list of the input, with the relative order of elements kept.
-- All sub-lists that omit the head come before those that contain it.
subsets :: [a] -> [[a]]
subsets [] = [[]]
subsets (x:xs) = rest ++ map (x:) rest
  where
    -- Computed once and shared by both halves of the result.
    rest = subsets xs

-- | A parsed mask. The first field holds the bits whose mask character was
-- @'1'@, the second the bits whose mask character was @'0'@.
data Mask = Mask Int Int deriving (Show)

-- | Force the bits of the first mask field to 1 and the bits of the second
-- mask field to 0; all other bits of the number pass through unchanged.
applyMask :: Mask -> Int -> Int
applyMask (Mask set reset) n = (n .|. set) .&. complement reset

-- | One program instruction: replace the current mask, or store a value
-- (second field) at an address (first field).
data Instr
  = SetMask Mask
  | Store Int Int
  deriving (Show)

-- | Parse one program line.
--
-- A line whose second character is @\'a\'@ is treated as a mask assignment
-- and everything after its first 7 characters is handed to 'parseMask'.
-- Otherwise the first 4 characters are skipped, an integer address is read,
-- 4 more characters are skipped and the remainder is read as the value.
--
-- Calls 'error', quoting the offending line, when neither form matches.
parseInstr :: String -> Instr
parseInstr str
  -- Matching on the list shape instead of indexing keeps short or empty
  -- lines from crashing before the fallback branch is reached.
  | (_ : 'a' : _) <- str = SetMask (parseMask (drop 7 str))
  | [(idx, rest)] <- reads (drop 4 str)
  , Just val <- readMaybe (drop 4 rest) = Store idx val
  | otherwise = error ("Can't parse Instr: " ++ show str)

-- | Build a 'Mask' from its textual form. Characters are paired with bit
-- positions 35 down to 0, left to right; @\'1\'@ characters populate the
-- first field, @\'0\'@ characters the second, and any other character
-- contributes to neither. Characters beyond the 36th are ignored.
parseMask :: String -> Mask
parseMask s = Mask (bitsFor '1') (bitsFor '0')
  where
    pairs :: [(Char, Int)]
    pairs = zip s (map bit [35, 34 .. 0])

    -- The bit values are distinct powers of two, so their sum is their union.
    bitsFor c = sum [b | (ch, b) <- pairs, ch == c]

-- | Interpreter state: the mask currently in effect and the memory contents.
-- Both fields are strict so that folding over a program does not pile up
-- unevaluated map updates.
data State = State
  { sMask :: !Mask
  , sMem :: !(IntMap Int)
  }
  deriving (Show)

-- | Execute one instruction with the mask applied to stored values: a store
-- writes @applyMask mask val@ at the given address.
exec1 :: State -> Instr -> State
exec1 (State _ mem) (SetMask mask) = State mask mem
exec1 (State mask mem) (Store idx val) =
  State mask (IM.insert idx (applyMask mask val) mem)

-- | Expand a number into every value obtainable under the mask when the
-- mask is read as: bits in the first field are forced to 1, bits in the
-- second field are left as they are, and each of the remaining low 36 bits
-- ("floating" bits) takes both 0 and 1. With @k@ floating bits the result
-- has @2^k@ entries.
--
-- NOTE: relies on 'Int' being wider than 36 bits so that @bit 36@ is
-- representable.
expandWithMask :: Mask -> Int -> [Int]
expandWithMask (Mask ones zeros) n =
  map ((base .|.) . sum) (subsets floatingBits)
  where
    -- Low-36-bit positions mentioned by neither mask field.
    floating = (bit 36 - 1) .&. complement (ones .|. zeros)
    -- The number with forced ones applied and every floating bit cleared.
    base = (n .|. ones) .&. complement floating
    floatingBits = [bit i | i <- [0 .. 35], testBit floating i]

-- | Execute one instruction with the mask applied to addresses: a store
-- writes the unmodified value at every address produced by
-- 'expandWithMask'.
--
-- NOTE: with the all-zero initial mask used by 'main', every one of the 36
-- bits is floating, so a store executed before any mask assignment would
-- expand to @2^36@ addresses.
exec2 :: State -> Instr -> State
exec2 (State _ mem) (SetMask mask) = State mask mem
exec2 (State mask mem) (Store idx val) =
  -- Every address receives the same value, so insertion order is irrelevant;
  -- a strict left fold avoids nesting one pending insert per address.
  State mask (foldl' (\m addr -> IM.insert addr val m) mem (expandWithMask mask idx))

-- | Read the program, run it under both interpretations starting from an
-- empty memory and an all-zero mask, and print the sum of the stored values
-- for each.
main :: IO ()
main = do
  -- NOTE(review): presumably 'getInput' yields the input lines for puzzle
  -- 14; its definition lives in the Input module, confirm there.
  input <- map parseInstr <$> getInput 14
  let initial = State (Mask 0 0) mempty
      memorySum = sum . IM.elems . sMem
  print (memorySum (foldl' exec1 initial input))
  print (memorySum (foldl' exec2 initial input))