summaryrefslogtreecommitdiff
path: root/2019/18.hs
blob: a60421a9229c50ddbc5ec018bf481f78ca3b0e3c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
{-# LANGUAGE TupleSections #-}
module Main where

import Control.Monad
import qualified Data.Array.Unboxed as A
import qualified Data.Array.ST as STA
import Data.Char
import Data.List
import qualified Data.IntMap as IntMap
import Data.IntMap (IntMap)
import qualified Data.Map.Strict as Map
import Data.Map.Strict (Map)
import Data.Maybe
import qualified Data.Set as Set
import Data.Set (Set)

-- import Debug.Trace

import Input
import qualified SmallIntSet as SIS
import SmallIntSet (SmallIntSet)


-- All-pairs shortest paths via Floyd-Warshall.
-- A matrix entry of -1 stands for "no known connection". Only the indices in
-- nodeList are used as endpoints or as intermediate nodes; every other cell of
-- the matrix is passed through unchanged.
transitiveClosure :: [Int] -> A.UArray (Int, Int) Int -> A.UArray (Int, Int) Int
transitiveClosure nodeList initMatrix = STA.runSTUArray $ do
    dists <- STA.thaw initMatrix
    -- The intermediate node varies slowest, then the source, then the destination.
    let triples = [(via, src, dst) | via <- nodeList, src <- nodeList, dst <- nodeList]
    forM_ triples $ \(via, src, dst) -> do
        direct <- STA.readArray dists (src, dst)
        toVia <- STA.readArray dists (src, via)
        fromVia <- STA.readArray dists (via, dst)
        let detour = toVia + fromVia
            bothLegsExist = toVia /= -1 && fromVia /= -1
        -- Take the detour when it exists and either nothing was known yet or it is shorter.
        when (bothLegsExist && (direct == -1 || detour < direct)) $
            STA.writeArray dists (src, dst) detour
    return dists

-- Board coordinate (x, y); main builds the board with x as the column index
-- and y as the row index.
type Pos = (Int, Int)
-- Offset (dx, dy). NOTE(review): not referenced by any code visible in this file.
type Dir = (Int, Int)

-- Breadth-first flood fill from startPos over the board.
-- Walks through free '.' cells only; every other non-wall cell that is reached
-- (a "thing") is recorded with its step distance from startPos, and the search
-- does not continue through it. startPos itself is never part of the result.
-- The search ends at the first layer that contains no free cell.
-- Coordinates that are missing from the board are treated as walls, so a board
-- without a surrounding wall does not crash the lookup.
reachableFrom :: Map Pos Char -> Pos -> Map Pos Int
reachableFrom bd startPos = go 0 (Set.singleton startPos) (Set.singleton startPos) Map.empty
  where
    -- Total lookup: anything off the board behaves like '#'.
    cellAt pos = Map.findWithDefault '#' pos bd

    go dist seen boundary result =
        let boundary' = [pos
                        | (x, y) <- Set.toList boundary
                        , (dx, dy) <- [(-1,0), (0,-1), (1,0), (0,1)]
                        , let pos = (x + dx, y + dy)
                        , cellAt pos /= '#'
                        , pos `Set.notMember` seen]
            -- Things end the walk in their direction; free cells form the next layer.
            (things, frees) = partition (\pos -> cellAt pos /= '.') boundary'
            -- Left-biased union: a distance recorded earlier is kept.
            result' = result <> Map.fromList [(pos, dist + 1) | pos <- things]
        in if null frees
               then result'
               else go (dist + 1) (seen <> Set.fromList boundary') (Set.fromList frees) result'

-- Condensed maze graph, indexed by node number. Node numbering (as assigned
-- by charToNode in implicitGraph):
--   node: [0  1 .. 26   27 .. 52]
--   char: [@, a .. z,   A .. Z  ]
data Implicit = Implicit (A.Array Int [Int])        -- edge list: the neighbour nodes of each node
                         (A.UArray (Int, Int) Int)  -- distance matrix, with closure taken; -1 means unconnected
  deriving (Show)

-- Condenses the board into a graph over its non-'.' cells.
-- Starting from startPos, reachableFrom is run from every newly discovered
-- cell, which yields each node's direct neighbours and the distances to them.
-- Nodes are then renumbered by their board character: '@' is 0, 'a'..'z' are
-- 1..26 and 'A'..'Z' are 27..52. The result carries the neighbour lists and
-- the distance matrix after transitiveClosure (-1 for unconnected pairs).
-- Calls error if a discovered cell holds anything but '@' or an ASCII letter.
implicitGraph :: Map Pos Char -> Pos -> Implicit
implicitGraph bd startPos =
    let posGraph = fst (go startPos Map.empty Set.empty)
        -- Re-key both levels of the adjacency map from positions to board characters.
        mapGraph = Map.mapKeys (bd Map.!) (Map.map (Map.mapKeys (bd Map.!)) posGraph)
        charToNode :: Char -> Int
        charToNode '@' = 0
        charToNode c | isAsciiLower c = 1 + ord c - ord 'a'
                     | isAsciiUpper c = 1 + 26 + ord c - ord 'A'
                     | otherwise = error ("implicitGraph: cannot map board character " ++ show c ++ " to a node")
        -- Neighbour list per node; nodes that never occur keep the empty list.
        arrGraph = A.accumArray (const id) [] (0, 2 * 26)
                        [(charToNode from, map charToNode (Map.keys tomap))
                        | (from, tomap) <- Map.assocs mapGraph]
        -- Direct distances only; every other pair starts out as -1 (unconnected).
        distArr = A.accumArray (const id) (-1) ((0, 0), (2 * 26, 2 * 26))
                        [((charToNode from, charToNode to), dist)
                        | (from, tomap) <- Map.assocs mapGraph
                        , (to, dist) <- Map.assocs tomap]
        -- Restrict the closure to nodes that actually occur on the board.
        nodeList = map charToNode (Map.keys mapGraph)
    in Implicit arrGraph (transitiveClosure nodeList distArr)
  where
    -- Depth-first discovery: records reachableFrom for curPos, then recurses into
    -- every neighbour that had not been visited when curPos was entered, threading
    -- the graph and the visited set through the fold.
    go :: Pos -> Map Pos (Map Pos Int) -> Set Pos -> (Map Pos (Map Pos Int), Set Pos)
    go curPos graph seen
        | curPos `Set.member` seen = (graph, seen)
        | otherwise =
            let reach = reachableFrom bd curPos
                newNodes = Map.keysSet reach Set.\\ seen
                graph' = Map.insert curPos reach graph
                seen' = Set.insert curPos seen
            in Set.foldl' (\(gr, sn) node -> go node gr sn) (graph', seen') newNodes

-- Finds the keys not yet held that can be reached from node start while
-- holding keys, each with a distance from start.
-- The traversal walks the edge list depth-first. It may step onto '@' and onto
-- any key node, but onto a door node only when the matching key is held. It
-- records keys that are not held yet ("pearls") and does not continue through
-- them; it continues through everything else it may step onto.
-- NOTE(review): each step adds the closure distance between adjacent nodes and
-- a node is only ever visited once, so the distance reported for a key is the
-- one along the first route discovered -- confirm this is always the shortest.
reachable :: Implicit -> SmallIntSet -> Int -> IntMap Int
reachable (Implicit graph distarr) keys start = snd (go 0 (SIS.singleton start) start IntMap.empty)
  where
    -- dist is the distance accumulated from start to the node `at`.
    go dist seen at result =
        let -- Unvisited neighbours: nodes 0..26 are always passable, a door node c
            -- (> 26) only if key (c - 26) is held.
            nexts = filter (\c -> c `SIS.notMember` seen && (c <= 26 || (c - 26) `SIS.member` keys))
                           (graph A.! at)
            -- Pearls are key nodes (1..26) whose key is not held yet.
            (nextPearls, nextNonpearls) = partition (\c -> 0 < c && c <= 26 && c `SIS.notMember` keys) nexts
            result' = result <> IntMap.fromList [(c, dist + distarr A.! (at, c)) | c <- nextPearls]
            seen' = seen <> SIS.fromList nexts
        in if null nexts
               then (seen, result')
               -- Strict left fold: the (seen, result) pair is forced at every step
               -- instead of building a chain of thunks.
               else foldl' (\(sn, rs) c -> go (dist + distarr A.! (at, c)) sn c rs) (seen', result') nextNonpearls

-- Best-first search over states (current node, set of held keys), ordered by
-- distance travelled plus a heuristic estimate of the distance still to go.
-- Returns the distance of the first state popped from which no further key is
-- reachable, together with the keys in reverse order of collection (the most
-- recently collected key first).
searchBFS :: SmallIntSet -> Implicit -> (Int, [Int])
searchBFS allKeys implicit@(Implicit _ distarr) =
    go 0 (Set.singleton (heuristic 0 SIS.empty, 0, 0, SIS.empty, [])) Map.empty
  where
    -- queue entries: (f-value, distance, node, keys, key order)
    -- visited: (node, keys) => distance recorded when that state was last expanded
    -- ctr counts expansions; nothing reads it at present (kept for debugging).
    go :: Int -> Set (Int, Int, Int, SmallIntSet, [Int]) -> Map (Int, SmallIntSet) Int -> (Int, [Int])
    go ctr pqueue visited
        | not (IntMap.null frontier) = go (ctr + 1) grownQueue visitedNow
        | fval == dist = (dist, keyorder)
        | otherwise = error ("heurval - dist = " ++ show (fval - dist) ++ " in terminal state!")
      where
        ((fval, dist, curnode, keys, keyorder), restQueue) = Set.deleteFindMin pqueue
        frontier = reachable implicit keys curnode
        -- A successor is queued only if its state was never expanded, or was
        -- expanded at a strictly larger distance.
        improves total state =
            case Map.lookup state visited of
                Nothing -> True
                Just known -> total < known
        successors =
            [ (total + heuristic target heldAfter, total, target, heldAfter, target : keyorder)
            | (target, stepDist) <- IntMap.assocs frontier
            , let total = dist + stepDist
            , let heldAfter = SIS.insert target keys
            , improves total (target, heldAfter)
            ]
        visitedNow = Map.insert (curnode, keys) dist visited
        grownQueue = Set.union restQueue (Set.fromList successors)

    -- Estimate of the distance still to travel (intended as a lower bound): the
    -- sum of the (n - 1) smallest pairwise matrix distances among the n keys not
    -- yet held. The current node is ignored.
    heuristic :: Int -> SmallIntSet -> Int
    heuristic _ keys = sum (take (SIS.size missing - 1) (sort pairDists))
      where
        missing = allKeys SIS.\\ keys
        missingList = SIS.toList missing
        -- Every unordered pair once: each key paired with the keys listed after it.
        pairDists =
            [ distarr A.! (a, b)
            | (a, later) <- zip missingList (drop 1 (tails missingList))
            , b <- later
            ]

-- Loads the board for puzzle 18 via getInput, locates the '@' start cell,
-- builds the condensed graph and prints the distance found by searchBFS.
-- Fails with a descriptive IO error if the board contains no '@'.
main :: IO ()
main = do
    stringbd <- getInput 18
    -- Key the board by (x, y): x is the column index, y the row index.
    let bd = Map.fromList [((x, y), c) | (y, row) <- zip [0..] stringbd, (x, c) <- zip [0..] row]
    -- First '@' in key order.
    startpos <- case [pos | (pos, '@') <- Map.assocs bd] of
        pos : _ -> return pos
        []      -> ioError (userError "18: no '@' start position found in input")

    let imgraph = implicitGraph bd startpos
        -- Key 'a' is numbered 1, matching the node numbering of the graph.
        allKeys = SIS.fromList [ord c - ord 'a' + 1 | c <- Map.elems bd, isLower c]
    print (fst (searchBFS allKeys imgraph))