{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeApplications #-}

-- | Repeated digit transform computed with prefix sums and worker threads.
--
-- Each 'phase' replaces element @i@ of the signal by the last digit of an
-- alternating sum of blocks of length @i + 1@.  'main' runs 100 phases on the
-- puzzle input and on the input repeated 10000 times.
--
-- NOTE(review): this looks like Advent of Code 2019, day 16 (inferred from
-- @getInput 16@) -- confirm.
module Main where

import Control.Concurrent (getNumCapabilities, forkIO, forkOn)
import Control.Concurrent.Chan (newChan, writeChan, readChan)
import Control.Monad (forM_)
import qualified Data.Array.Unboxed as A
import qualified Data.Array.IO as IOA
import Debug.Trace

import Input


-- | Integer division rounding up: @ceilDiv a b@ is @(a + b - 1) `div` b@.
--
-- Raises a divide-by-zero exception when @b@ is 0.
ceilDiv :: Int -> Int -> Int
ceilDiv a b = (a + b - 1) `div` b

-- | The last decimal digit of a number, ignoring its sign.
lastDigit :: Int -> Int
lastDigit = abs . (`rem` 10)

-- | Split a list into chunks of length @length l `div` n@.
--
-- When the length is not a multiple of @n@ the result has more than @n@
-- chunks, the final one being shorter.  A non-positive @n@ is treated as 1
-- (previously @n == 0@ divided by zero), and a chunk length of 0 (when @n@
-- exceeds the list length) is clamped to 1 by 'blockBy' instead of looping
-- forever.
divideList :: Int -> [a] -> [[a]]
divideList n l = blockBy (length l `div` max 1 n) l

-- | Split a list into consecutive chunks of @n@ elements; the last chunk may
-- be shorter.  An empty list yields @[[]]@.
--
-- A non-positive chunk size is clamped to 1: with @n <= 0@ the underlying
-- 'splitAt' would make no progress and the recursion would never terminate
-- on a non-empty list.
blockBy :: Int -> [a] -> [[a]]
blockBy n = go
  where
    size = max 1 n

    go l =
        let (pre, post) = splitAt size l
        in if null post then [pre] else pre : go post

-- | Unboxed arrays indexed by 'Int'.
type Array = A.UArray Int

-- | Prefix sums of an array, one element longer than the input: element @k@
-- of the result is the sum of the first @k@ input elements.
--
-- The result is always indexed from 0, so this assumes the input is as well.
prefixSums :: Array Int -> Array Int
prefixSums input =
    A.listArray (0, snd (A.bounds input) + 1) (scanl (+) 0 (A.elems input))

-- | Apply one phase of the transform.
--
-- Output element @i@ is the last digit of
-- @sum [i, 2i] - sum [3i+2, 4i+2] + sum [5i+4, 6i+4] - ...@ (inclusive index
-- ranges, clipped to the array), where every range sum is answered from the
-- prefix-sum table.
--
-- The output indices are divided evenly over @10 * getNumCapabilities@
-- threads, each of which writes its own disjoint slice of a shared mutable
-- array and signals completion on a channel.  Progress is reported through
-- 'traceM'.
--
-- Inputs of length 0 or 1 are handled: the original wrote indices 0 and 1
-- unconditionally, which is out of bounds for such arrays.  Those two
-- indices are now computed by the worker loop, which uses the very same
-- formula for them.
phase :: Array Int -> IO (Array Int)
phase !input = do
    chan <- newChan
    destArr <- IOA.newArray (0, end) 0 :: IO (IOA.IOUArray Int Int)

    -- Disabled alternative: one pinned thread per capability, with index
    -- ranges that grow geometrically so that every thread gets roughly the
    -- same amount of work (the cost of row i is proportional to 1/i).
    --
    -- nprocs <- getNumCapabilities
    -- traceM (show nprocs ++ " threads")
    -- let computePart idx =
    --         -- sum_{i=1}^k 1/i ~= ln k + gamma
    --         -- Find values k0=1,k1,...,kn=M such that sum_{i=kj}^{k{j+1}} 1/i = 1/n sum_{i=1}^M 1/i
    --         -- (n = nprocs, M = end+1)
    --         -- Approximately values such that ln k{j+1} - ln kj = 1/n ln M
    --         -- Claim kj = M^(j/n)
    --         -- Then ln k{j+1} - ln kj = (j+1)/n ln M - j/n ln M = 1/n ln M  \qed
    --         let kvals = [round @Double (fromIntegral (end+1) ** (fromIntegral j / fromIntegral nprocs))
    --                     | j <- [0..nprocs-1]]
    --                     ++ [end+1]
    --             intervals = (\l -> zip l (tail l)) (map pred kvals)
    --         in forkOn idx $ do
    --             let (low, high) = intervals !! idx
    --             traceM ("thread " ++ show idx ++ ": writing [" ++ show low ++ ".." ++ show high ++ "]")
    --             forM_ [low..high] $ \i ->
    --                 IOA.writeArray destArr i (lastDigit (intervalSum 0 i (i+1) 1))
    --             traceM ("thread " ++ show idx ++ ": writeChan")
    --             writeChan chan ()

    -- Deliberately oversubscribe: ten threads per capability.
    nprocs <- (* 10) <$> getNumCapabilities
    traceM (show nprocs ++ " threads")

    let numPerPart = (end + 1) `ceilDiv` nprocs
        -- Fork the worker responsible for slice number idx.  Slices past the
        -- end of the array are empty, but their worker still reports in.
        computePart idx =
            let low = idx * numPerPart
                high = min ((idx + 1) * numPerPart - 1) end
            in forkIO $ do
                traceM ("thread " ++ show idx ++ ": writing [" ++ show low ++ ".." ++ show high ++ "]")
                forM_ [low .. high] $ \i ->
                    IOA.writeArray destArr i (lastDigit (intervalSum 0 i (i + 1) 1))
                traceM ("thread " ++ show idx ++ ": writeChan")
                writeChan chan ()

    mapM_ computePart [0 .. nprocs - 1]           -- spawn the threads
    sequence_ (replicate nprocs (readChan chan))  -- wait for them to finish

    IOA.freeze destArr
  where
    -- Highest valid index of the input (-1 for an empty array).
    end = snd (A.bounds input)

    prefixes = prefixSums input

    -- Sum of the input over the half-open index range [from, to).
    rangeSum :: Int -> Int -> Int
    rangeSum from to = prefixes A.! to - prefixes A.! from

    -- Alternating sum of blocks of length step, starting at offset and
    -- skipping one block between consecutive summed blocks.  The final block
    -- is clipped to the end of the array.
    intervalSum :: Int -> Int -> Int -> Int -> Int
    intervalSum !acc offset step sign
      | offset > end = acc
      | offset + step > end = acc + sign * rangeSum offset (end + 1)
      | otherwise =
          intervalSum
              (acc + sign * rangeSum offset (offset + step))
              (offset + 2 * step)
              step
              (negate sign)

-- | Apply 'phase' the given number of times to a list of digits.
--
-- A non-positive count returns the input unchanged (a negative count used to
-- recurse without end).
fft :: Int -> [Int] -> IO [Int]
fft times list = A.elems <$> go times (A.listArray (0, length list - 1) list)
  where
    go :: Int -> Array Int -> IO (Array Int)
    go n arr
      | n <= 0 = return arr
      | otherwise = phase arr >>= go (n - 1)

-- | Read the input, then print the first eight digits after 100 phases, and
-- the eight digits found at the message offset (the first seven input
-- digits) after 100 phases of the input repeated 10000 times.
main :: IO ()
main = do
    inputLines <- getInput 16
    -- Only the first line is used; fail with a clear message instead of a
    -- bare `head` exception when there is none.
    inpString <- case inputLines of
        (firstLine : _) -> return firstLine
        [] -> ioError (userError "main: puzzle input is empty")

    -- Every character must be a decimal digit; `read` fails otherwise.
    let inpList = map (read . pure) inpString :: [Int]
    concatMap show . take 8 <$> fft 100 inpList >>= putStrLn

    let messageOffset = read (take 7 inpString) :: Int
    -- Note: Part 2 takes 3m20s for me
    concatMap show . take 8 . drop messageOffset
        <$> fft 100 (concat (replicate 10000 inpList))
        >>= putStrLn