{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | Backtracking Sudoku solver that enumerates every completion of a grid.
--
-- A grid is a row-major list of 81 cells: 'Nothing' is an empty cell and
-- @Just v@ is a filled cell whose value @v@ is zero-based (@0..8@).
module FSu (solve) where

import Control.Monad
import Control.Monad.ST
import Data.Array.Base (unsafeRead, unsafeWrite)
import Data.Array.ST
import Data.Maybe (fromMaybe)
import Data.STRef

-- | A cell value, zero-based (@0..8@).
type Value = Int

-- | A row-major cell index into the 81-cell board (@0..80@).
type Index = Int

-- | Mutable unboxed array indexed by 'Int'.
type Arr s a = STUArray s Int a

-- | Scratch state shared by the whole search.
data State s = State
  { stateMark :: Arr s Bool
    -- ^ 9-slot scratch bitmap, reused (and clobbered) by 'getPoss' and
    -- 'indexSetNoDups'.
  , stateResults :: STRef s [[Maybe Value]]
    -- ^ Solutions found so far, most recently found first.
  }

-- | Internal marker for an empty cell in the mutable board.
emptyCell :: Value
emptyCell = -1

-- | All completions of the given grid, most recently found first.
--
-- Returns @[]@ when the input is malformed (not exactly 81 cells, or a
-- given outside @0..8@), when the givens already conflict, or when the
-- puzzle has no solution.  Every solution is collected before the result
-- is returned, so a grid with a huge number of completions (e.g. an almost
-- empty one) will not finish in practice.
solve :: [Maybe Value] -> [[Maybe Value]]
solve input
  | not (wellFormed input) = []
  | otherwise = runST $ do
      arr <- newListArray (0, 80) (map (fromMaybe emptyCell) input) :: ST s (Arr s Value)
      mark <- newArray (0, 8) False :: ST s (Arr s Bool)
      results <- newSTRef [] :: ST s (STRef s [[Maybe Value]])
      let st = State mark results
      -- Every value the search places comes from 'getPoss', which excludes
      -- anything already in the cell's row, column and block, so a valid
      -- board stays valid.  Checking the givens once here is therefore
      -- equivalent to re-checking the whole board at every search step.
      valid <- isValid arr st
      when valid $ solveAt arr 0 st
      readSTRef results

-- | True when the grid has exactly 81 cells and every given is in @0..8@.
-- Givens are later used as indices into the 9-slot mark array via
-- 'unsafeWrite', so out-of-range values must be rejected up front.
-- 'take' bounds the length check so an infinite list is rejected rather
-- than looping.
wellFormed :: [Maybe Value] -> Bool
wellFormed cells =
  length (take 82 cells) == 81 && all (maybe True isCellValue) cells
  where
    isCellValue v = v >= 0 && v <= 8

-- | Snapshot the board as a row-major list, mapping empty cells to 'Nothing'.
obtainResult :: Arr s Value -> ST s [Maybe Value]
obtainResult arr = map toCell <$> getElems arr
  where
    toCell v
      | v == emptyCell = Nothing
      | otherwise = Just v

-- | Depth-first search from cell @i@ onward, recording every complete board
-- in 'stateResults'.  Assumes the board currently contains no conflicts.
solveAt :: Arr s Value -> Index -> State s -> ST s ()
solveAt arr !i st
  | i == 81 = do
      res <- obtainResult arr
      modifySTRef' (stateResults st) (res :)
  | otherwise = do
      v <- unsafeRead arr i
      if v /= emptyCell
        then solveAt arr (i + 1) st -- a given: move on
        else do
          poss <- getPoss arr i st
          tryAll arr i poss st

-- | Try each candidate in cell @i@ in turn, recursing and then restoring
-- the cell to empty so the board is unchanged on return.
tryAll :: Arr s Value -> Index -> [Value] -> State s -> ST s ()
tryAll arr !i vs st = forM_ vs $ \v -> do
  unsafeWrite arr i v
  solveAt arr (i + 1) st
  unsafeWrite arr i emptyCell

-- | Values that can be placed in cell @i@, in ascending order: those not
-- already present in its row, column or 3x3 block.  Assumes cell @i@ is
-- empty (otherwise its own value is excluded too).  Overwrites the shared
-- mark array.
getPoss :: forall s. Arr s Value -> Index -> State s -> ST s [Value]
getPoss arr i st = do
  fillArrayBool mark 0 8 True
  mapM_ strike (rowCells (rowOf i) ++ colCells (colOf i) ++ blockCells (blockOf i))
  flags <- getElems mark
  return [v | (v, True) <- zip [0 ..] flags]
  where
    mark = stateMark st
    -- Rule out whatever value (if any) sits in cell j.
    strike :: Index -> ST s ()
    strike j = do
      v <- unsafeRead arr j
      when (v /= emptyCell) $ unsafeWrite mark v False

-- | True when no row, column or block contains a repeated value.
-- Stops at the first offending unit (rows, then columns, then blocks).
isValid :: Arr s Value -> State s -> ST s Bool
isValid arr st = allOk checks
  where
    checks =
      [isValidRow arr k st | k <- [0 .. 8]]
        ++ [isValidCol arr k st | k <- [0 .. 8]]
        ++ [isValidBlock arr k st | k <- [0 .. 8]]
    allOk [] = return True
    allOk (c : cs) = do
      ok <- c
      if ok then allOk cs else return False

-- | True when row @r@ has no repeated value.
isValidRow :: Arr s Value -> Index -> State s -> ST s Bool
isValidRow arr r = indexSetNoDups arr (rowCells r)

-- | True when column @c@ has no repeated value.
isValidCol :: Arr s Value -> Index -> State s -> ST s Bool
isValidCol arr c = indexSetNoDups arr (colCells c)

-- | True when block @b@ (numbered 0..8 row-major) has no repeated value.
isValidBlock :: Arr s Value -> Index -> State s -> ST s Bool
isValidBlock arr b = indexSetNoDups arr (blockCells b)

-- | True when the filled cells among the given indices hold pairwise
-- distinct values; empty cells are ignored.  Overwrites the shared mark
-- array.
indexSetNoDups :: forall s. Arr s Value -> [Index] -> State s -> ST s Bool
indexSetNoDups arr set st = do
  fillArrayBool mark 0 8 False
  go set
  where
    mark = stateMark st
    go :: [Index] -> ST s Bool
    go [] = return True
    go (i : is) = do
      v <- unsafeRead arr i
      if v == emptyCell
        then go is
        else do
          seen <- unsafeRead mark v
          if seen
            then return False
            else unsafeWrite mark v True >> go is

-- | Write @v@ into every slot of @arr@ from @i1@ to @i2@ inclusive
-- (nothing when @i1 > i2@).  No bounds checking is performed.
fillArrayBool :: Arr s Bool -> Int -> Int -> Bool -> ST s ()
fillArrayBool arr !i1 !i2 v = forM_ [i1 .. i2] $ \j -> unsafeWrite arr j v

-- | Cell indices of row @r@, left to right.
rowCells :: Index -> [Index]
rowCells r = [9 * r + j | j <- [0 .. 8]]

-- | Cell indices of column @c@, top to bottom.
colCells :: Index -> [Index]
colCells c = [9 * j + c | j <- [0 .. 8]]

-- | Cell indices of block @b@, row-major within the block.
blockCells :: Index -> [Index]
blockCells b = [blockOrigin b + 9 * y + x | y <- [0 .. 2], x <- [0 .. 2]]

-- | Row (0..8) containing cell @i@.
rowOf :: Index -> Index
rowOf i = i `quot` 9

-- | Column (0..8) containing cell @i@.
colOf :: Index -> Index
colOf i = i `rem` 9

-- | Block (0..8, numbered row-major) containing cell @i@.
blockOf :: Index -> Index
blockOf i = 3 * (i `quot` 27) + (i `rem` 9) `quot` 3

-- | Index of the top-left cell of block @b@: block row @b quot 3@ starts
-- 27 cells apart, block column @b rem 3@ starts 3 cells apart.
blockOrigin :: Index -> Index
blockOrigin b = 27 * (b `quot` 3) + 3 * (b `rem` 3)