summaryrefslogtreecommitdiff
path: root/2021/9.fut
diff options
context:
space:
mode:
Diffstat (limited to '2021/9.fut')
-rw-r--r--2021/9.fut91
1 files changed, 91 insertions, 0 deletions
diff --git a/2021/9.fut b/2021/9.fut
new file mode 100644
index 0000000..10b5f82
--- /dev/null
+++ b/2021/9.fut
@@ -0,0 +1,91 @@
+let zip2d 'a 'b [h] [w] (a: [h][w]a) (b: [h][w]b) : [h][w](a, b) =
+ map2 (map2 (\x y -> (x, y))) a b
+
+let fst 'a 'b (t: (a, b)) : a = let (x, _) = t in x
+let snd 'a 'b (t: (a, b)) : b = let (_, y) = t in y
+
+let get 'a [h] [w] (def: a) (field: [h][w]a) (y: i64) (x: i64) : a =
+ if x < 0 || y < 0 || x >= w || y >= h
+ then def
+ else field[y, x]
+
+let stencil 'a 'b [h] [w] (def: a) (field: [h][w]a) (f: a -> (a,a,a,a) -> b) : [h][w]b =
+ tabulate h (\y ->
+ tabulate w (\x ->
+ f (get def field y x)
+ (get def field (y-1) x
+ ,get def field y (x+1)
+ ,get def field (y+1) x
+ ,get def field y (x-1))))
+
+let red4 'a 'b (f: a -> b) (g: b -> b -> b) (tup: (a, a, a, a)) : b =
+ let (a, b, c, d) = tup
+ in g (g (g (f a) (f b)) (f c)) (f d)
+
+let markpits [h] [w] (field: [h][w]u8) : [h][w]u8 =
+ stencil 9 field
+ (\v env -> if red4 (\d -> v < d) (&&) env
+ then v + 1
+ else 255)
+
+let pits [h] [w] (field: [h][w]u8) : []u8 =
+ filter (!= 255) (flatten (markpits field))
+
+let grow [h] [w] (field: [h][w]u8) (basin: [h][w]bool) : [h][w]bool =
+ stencil (9, false) (zip2d field basin) (\(d, v) env -> d < 9 && (v || red4 snd (||) env))
+
+let floodfill [h] [w] (field: [h][w]u8) : *[h][w]bool =
+ let basin0 = map (map (!= 255)) (markpits field)
+ let basin1 = grow field basin0
+ let changed a b = reduce_comm (||) false (flatten (map2 (map2 (!=)) a b))
+ let (res, _) =
+ loop (a, b) = (basin0, basin1) while changed a b do
+ (b, grow field b)
+ in res
+
+let merge [h] [w] (ids: [h][w]i64) : [h][w]i64 =
+ stencil i64.highest ids
+ (\n env -> if n == i64.highest
+ then i64.highest
+ else i64.min n (red4 id i64.min env))
+
+let markbasins [h] [w] (basins: [h][w]bool) : *[h][w]i64 =
+ let numbers = unflatten h w (iota (h * w))
+ let ids0 = map2 (map2 (\n b -> if b then n else i64.highest)) numbers basins
+ let ids1 = merge ids0
+ let changed a b = reduce_comm (||) false (flatten (map2 (map2 (!=)) a b))
+ let (res, _) =
+ loop (a, b) = (ids0, ids1) while changed a b do
+ (b, merge b)
+ in res
+
+let collect_used_ids [h] [w] (ids: [h][w]i64) : []i64 =
+ let nums = flatten ids
+ let bitmap = scatter (tabulate (h * w) (\_ -> false)) nums (map (\_ -> true) nums)
+ in map snd (filter fst (zip bitmap (iota (h * w))))
+
+let ipow (base: i64) (exponent: i64) : i64 =
+ loop res = 1 for _i < exponent do res * base
+
+entry main [h] [w] (field: [h][w]u8) : (i32, i64) =
+ let ids = map (map (\n -> if n == i64.highest then -1 else n))
+ (markbasins (floodfill field))
+ let used = collect_used_ids ids
+ let sizes = map (\n -> reduce_comm (+) 0 (map ((==n) >-> i64.bool) (flatten ids))) used
+ let sizes_map = reduce_by_index (tabulate (h * w) (\_ -> 0))
+ (+) 0
+ sizes (map (\_ -> 1) sizes)
+ let count_size_cumulcount =
+ zip sizes_map (iota (h * w))
+ |> filter (\(cnt, _sz) -> cnt > 0)
+ |> reverse
+ |> (\arr -> zip arr (scan (+) 0 (map fst arr)))
+ |> take 3
+ let wanted_count =
+ map (\((cnt, _sz), cmlcnt) -> i64.max 0 (i64.min cnt (3 - (cmlcnt - cnt))))
+ count_size_cumulcount
+ let part2 = reduce_comm (*) 1
+ (map2 (\wc ((_cnt, sz), _cmlcnt) -> ipow sz wc)
+ wanted_count count_size_cumulcount)
+ in (reduce_comm (+) 0 (map i32.u8 (pits field))
+ ,part2)