summaryrefslogtreecommitdiff
path: root/2019/16multi.hs
blob: 87046a55ecbe7191ac1abcd382821cbcd64cf792 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
{-# LANGUAGE BangPatterns, TypeApplications #-}
module Main where

import Control.Concurrent (getNumCapabilities, forkIO, forkOn)
import Control.Concurrent.Chan (newChan, writeChan, readChan)
import Control.Monad (forM_)
import qualified Data.Array.Unboxed as A
import qualified Data.Array.IO as IOA

import Debug.Trace

import Input


-- | Integer division rounding towards positive infinity.
--
-- Rounds up from the floored quotient whenever the division is inexact.
-- This is correct for every sign combination of @a@ and nonzero @b@, and it
-- cannot overflow on large @a@. The usual @(a + b - 1) `div` b@ shortcut is
-- only right for positive @b@ and wraps around near 'maxBound'.
-- Divides by zero if @b == 0@.
ceilDiv :: Int -> Int -> Int
ceilDiv a b = case a `divMod` b of
    (q, 0) -> q
    (q, _) -> q + 1

-- | The last decimal digit of a number, ignoring its sign: the absolute value
-- of the remainder after dividing by 10 (e.g. @1234@ gives @4@, @-17@ gives @7@).
lastDigit :: Int -> Int
lastDigit n = abs (n `rem` 10)

-- | Split @l@ into consecutive chunks of @length l `div` n@ elements.
--
-- The last chunk holds any remainder, so there can be more than @n@ chunks
-- (e.g. 10 elements with @n = 3@ give chunk sizes 3, 3, 3, 1).
--
-- The chunk size is clamped to at least 1. Without the clamp, a list shorter
-- than @n@ would give a size of 0, which 'blockBy' turns into an endless
-- stream of empty chunks. With it, every element ends up in a singleton chunk.
-- Divides by zero if @n == 0@.
divideList :: Int -> [a] -> [[a]]
divideList n l = blockBy (max 1 (length l `div` n)) l

-- | Cut a list into consecutive pieces of @n@ elements. The final piece
-- holds whatever is left over and may be shorter than @n@.
--
-- An empty list yields a single empty piece. Splitting continues while
-- 'splitAt' leaves a non-empty remainder. For @n <= 0@ and a non-empty list,
-- that remainder never shrinks, so the result is an infinite list of empty
-- pieces.
blockBy :: Int -> [a] -> [[a]]
blockBy n l = case splitAt n l of
    (chunk, [])   -> [chunk]
    (chunk, rest) -> chunk : blockBy n rest

-- | Shorthand for an immutable unboxed array indexed by 'Int'.
-- @Array Int@ is the signal/prefix-sum array type used by 'prefixSums' and
-- 'phase'.
type Array = A.UArray Int

-- | Exclusive prefix sums of an array.
--
-- For an input indexed @lo..hi@, the result is indexed @lo..hi+1@. Index
-- @lo + k@ holds the sum of the first @k@ input elements. The sum of the input
-- over the half-open index range @[from, to)@ is therefore
-- @result ! to - result ! from@.
--
-- The result's lower bound follows the input's rather than being fixed at 0.
-- This keeps the bounds in step with the @hi - lo + 2@ sums that 'scanl'
-- produces.
prefixSums :: Array Int -> Array Int
prefixSums input =
    let (lo, hi) = A.bounds input
    in A.listArray (lo, hi + 1) (scanl (+) 0 (A.elems input))

-- | Run one phase of the transform on a 0-indexed signal, in parallel.
--
-- Output element @i@ is 'lastDigit' of an alternating sum over runs of
-- @i + 1@ consecutive input elements. The run starting at index @i@ is added,
-- the run starting @2 * (i + 1)@ further on is subtracted, the next one is
-- added again, and so on to the end of the input. A run that crosses the end
-- is truncated. Each run's sum is a difference of two prefix sums, so element
-- @i@ takes roughly @n / (2 * (i + 1))@ steps.
--
-- The output indices @0..end@ are cut into equal contiguous chunks, one per
-- thread started with 'forkIO'. There are ten threads per capability:
-- low indices cost much more than high ones, and over-subscribing evens out
-- the load. Each thread writes its chunk into a shared unboxed mutable array
-- and then posts to a channel. Once every thread has posted, the array is
-- frozen and returned. Empty and single-element signals are supported.
phase :: Array Int -> IO (Array Int)
phase !input =
    let end = snd (A.bounds input)
        prefixes = prefixSums input
        -- Sum of input[from .. to - 1].
        rangeSum from to = prefixes A.! to - prefixes A.! from
        -- Alternating sum of runs of length `step`, the first starting at
        -- `offset`. Consecutive runs are `2 * step` apart and flip sign.
        -- The accumulator is strict so no chain of thunks builds up.
        intervalSum !acc offset step sign
            | offset + step > end =
                -- Last, possibly partial, run (if it starts inside the array).
                if offset > end
                    then acc
                    else acc + sign * rangeSum offset (end + 1)
            | otherwise =
                let acc' = acc + sign * rangeSum offset (offset + step)
                in intervalSum acc' (offset + 2 * step) step (negate sign)
        -- Output element i: runs of length i + 1, the first starting at i.
        digitAt i = lastDigit (intervalSum 0 i (i + 1) 1)
    in do
        chan <- newChan
        destArr <- IOA.newArray (0, end) 0 :: IO (IOA.IOUArray Int Int)

        -- nprocs <- getNumCapabilities
        -- traceM (show nprocs ++ " threads")
        -- let computePart idx =
        --         -- sum_{i=1}^k 1/i ~= ln k + gamma
        --         -- Find values k0=1,k1,...,kn=M such that sum_{i=kj}^{k{j+1}} 1/i = 1/n sum_{i=1}^M 1/i
        --         --   (n = nprocs, M = end+1)
        --         -- Approximately values such that ln k{j+1} - ln kj = 1/n ln M
        --         -- Claim kj = M^(j/n)
        --         -- Then ln k{j+1} - ln kj = (j+1)/n ln M - j/n ln M = 1/n ln M   \qed
        --         let kvals = [round @Double (fromIntegral (end+1) ** (fromIntegral j / fromIntegral nprocs))
        --                     | j <- [0..nprocs-1]]
        --                     ++ [end+1]
        --             intervals = (\l -> zip l (tail l)) (map pred kvals)
        --         in forkOn idx $ do
        --             let (low, high) = intervals !! idx
        --             traceM ("thread " ++ show idx ++ ": writing [" ++ show low ++ ".." ++ show high ++ "]")
        --             forM_ [low..high] $ \i -> IOA.writeArray destArr i (lastDigit (intervalSum 0 i (i+1) 1))
        --             traceM ("thread " ++ show idx ++ ": writeChan")
        --             writeChan chan ()

        nprocs <- (*10) <$> getNumCapabilities
        traceM (show nprocs ++ " threads")
        let numPerPart = (end + 1) `ceilDiv` nprocs
        -- Thread idx owns [idx * numPerPart .. (idx + 1) * numPerPart - 1],
        -- clipped to the array. The threads start at index 0 instead of this
        -- thread special-casing indices 0 and 1. That keeps every write in
        -- bounds for signals shorter than two elements. It also runs the two
        -- most expensive indices in parallel with the rest.
        let computePart idx =
                let low = idx * numPerPart
                    high = min ((idx + 1) * numPerPart - 1) end
                in forkIO $ do
                    traceM ("thread " ++ show idx ++ ": writing [" ++ show low ++ ".." ++ show high ++ "]")
                    forM_ [low..high] $ \i -> IOA.writeArray destArr i (digitAt i)
                    traceM ("thread " ++ show idx ++ ": writeChan")
                    writeChan chan ()

        mapM_ computePart [0..nprocs-1]  -- spawn the threads
        sequence_ (replicate nprocs (readChan chan))  -- wait for them to finish

        IOA.freeze destArr

-- | Apply 'phase' @times@ times in succession to the signal @list@ and return
-- the resulting digits.
--
-- The list is loaded into a 0-indexed array first. A non-positive @times@
-- returns the input unchanged instead of recursing forever.
fft :: Int -> [Int] -> IO [Int]
fft times list = A.elems <$> go times (A.listArray (0, length list - 1) list)
  where
    go n arr
        | n <= 0    = return arr
        | otherwise = phase arr >>= go (n - 1)

-- | Solve both parts of day 16.
--
-- Part 1 prints the first eight digits after 100 phases of the input signal.
-- Part 2 repeats the signal 10000 times and runs 100 phases. It then prints
-- the eight digits found at the offset given by the first seven input digits.
main :: IO ()
main = do
    rows <- getInput 16
    let signal = head rows
        digits = map (\c -> read [c]) signal
        offset = read (take 7 signal)
        firstEight xs = concatMap show (take 8 xs)

    part1 <- fft 100 digits
    putStrLn (firstEight part1)

    -- Note: Part 2 takes 3m20s for me
    part2 <- fft 100 (concat (replicate 10000 digits))
    putStrLn (firstEight (drop offset part2))