summaryrefslogtreecommitdiff
path: root/2021/9.fut
blob: 10b5f82a8d51c0f6e553fe651d45b89656bf16d4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
-- Pair two equally-shaped 2D arrays element by element, row by row.
let zip2d 'a 'b [h] [w] (a: [h][w]a) (b: [h][w]b) : [h][w](a, b) =
  map2 zip a b

-- First component of a pair.
let fst 'a 'b ((first, _): (a, b)) : a = first
-- Second component of a pair.
let snd 'a 'b ((_, second): (a, b)) : b = second

-- Read field[y, x]; coordinates outside the h-by-w grid yield `dflt`.
let get 'a [h] [w] (dflt: a) (field: [h][w]a) (y: i64) (x: i64) : a =
  let inside = 0 <= y && y < h && 0 <= x && x < w
  in if inside then field[y, x] else dflt

-- Apply `f` at every cell to that cell's value and its four orthogonal
-- neighbours, passed as (up, right, down, left). Neighbours that fall
-- outside the grid are replaced by `dflt` (via `get`).
let stencil 'a 'b [h] [w] (dflt: a) (field: [h][w]a) (f: a -> (a,a,a,a) -> b) : [h][w]b =
  let at (r: i64) (c: i64) : a = get dflt field r c
  in tabulate_2d h w (\r c ->
       let up    = at (r - 1) c
       let right = at r (c + 1)
       let down  = at (r + 1) c
       let left  = at r (c - 1)
       in f (at r c) (up, right, down, left))

-- Map `f` over the four components of `tup` and combine the results with
-- `g`, folding from the left: g (g (g (f a) (f b)) (f c)) (f d).
let red4 'a 'b (f: a -> b) (g: b -> b -> b) (tup: (a, a, a, a)) : b =
  let (p, q, r, s) = tup
  let two = g (f p) (f q)
  let three = g two (f r)
  in g three (f s)

-- Mark low points. A cell strictly lower than all four neighbours gets
-- its risk level (height + 1). Off-grid neighbours count as height 9.
-- Every other cell gets the sentinel 255.
let markpits [h] [w] (field: [h][w]u8) : [h][w]u8 =
  let risk (v: u8) (env: (u8, u8, u8, u8)) : u8 =
    let lowest = red4 (\n -> v < n) (&&) env
    in if lowest then v + 1 else 255
  in stencil 9 field risk

-- Risk levels of all low points, in row-major order.
let pits [h] [w] (field: [h][w]u8) : []u8 =
  markpits field |> flatten |> filter (\r -> r != 255)

-- One flood-fill step. A cell is in the basin afterwards if its height is
-- below 9 and it, or any of its four neighbours, is already in the basin.
-- Off-grid neighbours are treated as (9, false).
let grow [h] [w] (field: [h][w]u8) (basin: [h][w]bool) : [h][w]bool =
  let step ((height, inside): (u8, bool))
           (env: ((u8, bool), (u8, bool), (u8, bool), (u8, bool))) : bool =
    height < 9 && (inside || red4 (\(_, b) -> b) (||) env)
  in stencil (9, false) (zip2d field basin) step

-- Basin mask. Seeds the mask with the low points found by `markpits`,
-- then applies `grow` repeatedly until an application changes nothing.
let floodfill [h] [w] (field: [h][w]u8) : *[h][w]bool =
  let differs (x: [h][w]bool) (y: [h][w]bool) : bool =
    flatten (map2 (map2 (!=)) x y) |> reduce_comm (||) false
  let seed = map (map (!= 255)) (markpits field)
  let (fixed, _) =
        loop (prev, cur) = (seed, grow field seed) while differs prev cur do
          (cur, grow field cur)
  in fixed

-- One label-propagation step. Each labelled cell (id != i64.highest) takes
-- the smallest id among itself and its four neighbours. Unlabelled cells
-- and off-grid neighbours carry i64.highest, so they never win the minimum,
-- and unlabelled cells stay unlabelled.
let merge [h] [w] (ids: [h][w]i64) : [h][w]i64 =
  let none = i64.highest
  let relabel (n: i64) ((a, b, c, d): (i64, i64, i64, i64)) : i64 =
    if n == none
      then none
      else i64.min n (i64.min (i64.min (i64.min a b) c) d)
  in stencil none ids relabel

-- Label every basin cell with the smallest row-major index (r * w + c)
-- found in its 4-connected region. Cells outside the mask get
-- i64.highest. Repeats `merge` until the labelling is stable.
let markbasins [h] [w] (basins: [h][w]bool) : *[h][w]i64 =
  let label0 = tabulate_2d h w (\r c -> if basins[r, c] then r * w + c else i64.highest)
  let differs (x: [h][w]i64) (y: [h][w]i64) : bool =
    flatten (map2 (map2 (!=)) x y) |> reduce_comm (||) false
  let (labels, _) =
        loop (prev, cur) = (label0, merge label0) while differs prev cur do
          (cur, merge cur)
  in labels

-- Distinct values of `ids` that are valid flat indices (0 <= id < h*w),
-- in increasing order. Other values (e.g. -1) are out-of-bounds writes,
-- which `scatter` ignores.
let collect_used_ids [h] [w] (ids: [h][w]i64) : []i64 =
  let flat = flatten ids
  let seen = scatter (replicate (h * w) false) flat (map (\_ -> true) flat)
  in filter (\i -> seen[i]) (iota (h * w))

-- `base` raised to `exponent` by repeated (wrapping) multiplication.
-- A non-positive exponent yields 1.
let ipow (base: i64) (exponent: i64) : i64 =
  replicate (i64.max 0 exponent) base |> reduce (*) 1

-- Entry point for 2021 day 9. Returns
--   ( sum of the risk levels of all low points,
--     product of the sizes of the three largest basins ).
entry main [h] [w] (field: [h][w]u8) : (i32, i64) =
  -- Basin id of every cell, or -1 for cells that belong to no basin.
  let ids = map (map (\n -> if n == i64.highest then -1 else n))
                (markbasins (floodfill field))
  let flat_ids = flatten ids
  let used = collect_used_ids ids
  -- Cell count per basin id, computed in a single histogram pass instead of
  -- rescanning the whole grid once per id. The -1 ids are out of bounds
  -- and therefore ignored by reduce_by_index.
  let cells_per_id = reduce_by_index (tabulate (h * w) (\_ -> 0i64))
                                     (+) 0
                                     flat_ids (map (\_ -> 1) flat_ids)
  let sizes = map (\n -> cells_per_id[n]) used
  -- sizes_map[s] = number of basins with exactly s cells. A basin may
  -- cover the whole grid, so s ranges over 1..h*w and needs h*w+1 slots.
  -- With only h*w slots, a size-h*w basin would be silently dropped.
  let nslots = h * w + 1
  let sizes_map = reduce_by_index (replicate nslots 0i64)
                                  (+) 0
                                  sizes (map (\_ -> 1) sizes)
  -- (count, size) for every size that occurs, largest size first.
  let by_size_desc =
        zip sizes_map (iota nslots)
        |> filter (\(cnt, _sz) -> cnt > 0)
        |> reverse
  -- Pair each entry with the running number of basins up to and including it.
  let with_cumul = zip by_size_desc (scan (+) 0 (map fst by_size_desc))
  -- At most three distinct sizes can contribute to the top three basins.
  -- Clamp, because a plain `take 3` fails when fewer distinct sizes exist.
  let count_size_cumulcount = take (i64.min 3 (length with_cumul)) with_cumul
  -- Number of basins of each size that belong to the top three.
  let wanted_count =
        map (\((cnt, _sz), cmlcnt) -> i64.max 0 (i64.min cnt (3 - (cmlcnt - cnt))))
            count_size_cumulcount
  let part2 = reduce_comm (*) 1
                (map2 (\wc ((_cnt, sz), _cmlcnt) -> ipow sz wc)
                      wanted_count count_size_cumulcount)
  in (reduce_comm (+) 0 (map i32.u8 (pits field))
     ,part2)