summaryrefslogtreecommitdiff
path: root/hsolve/FSu.hs
diff options
context:
space:
mode:
Diffstat (limited to 'hsolve/FSu.hs')
-rw-r--r--hsolve/FSu.hs152
1 files changed, 152 insertions, 0 deletions
diff --git a/hsolve/FSu.hs b/hsolve/FSu.hs
new file mode 100644
index 0000000..1bd250a
--- /dev/null
+++ b/hsolve/FSu.hs
@@ -0,0 +1,152 @@
+{-# LANGUAGE BangPatterns, ScopedTypeVariables, RankNTypes #-}
+module FSu(solve) where
+
+import Control.Monad
+import Control.Monad.ST
+import Data.Array.ST
+import Data.STRef
+
+
+type Value = Int -- Sudoku value
+type Index = Int -- Sudoku index
+
+type Arr s a = STArray s Int a
+
+data State s = State { stateMark :: Arr s Bool, stateResults :: STRef s [[Maybe Value]] }
+
+solve :: [Maybe Value] -> [[Maybe Value]]
+solve input = runST $ do
+ arr <- newListArray (0,80) $ map (maybe (-1) id) input :: ST s (Arr s Value)
+ mark <- newArray (0,8) False :: ST s (Arr s Bool)
+ results <- newSTRef [] :: ST s (STRef s [[Maybe Value]])
+ solveAt arr 0 (State mark results)
+ readSTRef results
+
+obtainResult :: Arr s Value -> ST s [Maybe Value]
+obtainResult arr = do
+ elems <- getElems arr
+ return $ [if v == (-1) then Nothing else Just v | v <- elems]
+
+solveAt :: Arr s Value -> Index -> State s -> ST s ()
+solveAt arr !i st = do
+ valid <- isValid arr st
+ when valid $ do
+ if i == 81
+ then do
+ res <- obtainResult arr
+ modifySTRef' (stateResults st) (res:)
+ else do
+ v <- readArray arr i
+ if v /= (-1)
+ then solveAt arr (i+1) st
+ else do
+ poss <- getPoss arr i st
+ tryAll arr i poss st
+
+tryAll :: Arr s Value -> Index -> [Value] -> State s -> ST s ()
+tryAll _ _ [] _ = return ()
+tryAll arr !i (v:vs) st = do
+ writeArray arr i v
+ solveAt arr (i+1) st
+ writeArray arr i (-1)
+ tryAll arr i vs st
+
+-- assumes the considered position is empty
+getPoss :: forall s. Arr s Value -> Index -> State s -> ST s [Value]
+getPoss arr i st = do
+ fillArray mark 0 8 True
+ goRow (rowOf i) 0
+ goCol (colOf i) 0
+ goBlock (blockOrigin (blockOf i)) 0
+ bs <- liftM (zip [0..8]) (getElems mark)
+ return $ map fst $ filter snd bs
+ where
+ mark = stateMark st
+
+ goRow :: Int -> Int -> ST s ()
+ goRow _ 9 = return ()
+ goRow r j = readArray arr (9 * r + j) >>= \v ->
+ when (v /= (-1)) (writeArray mark v False) >> goRow r (j+1)
+
+ goCol :: Int -> Int -> ST s ()
+ goCol _ 9 = return ()
+ goCol c j = readArray arr (9 * j + c) >>= \v ->
+ when (v /= (-1)) (writeArray mark v False) >> goCol c (j+1)
+
+ goBlock :: Int -> Int -> ST s ()
+ goBlock _ 9 = return ()
+ goBlock b j = readArray arr (b + 9 * (j `quot` 3) + j `rem` 3) >>= \v ->
+ when (v /= (-1)) (writeArray mark v False) >> goBlock b (j+1)
+
+isValid :: forall s. Arr s Value -> State s -> ST s Bool
+isValid arr st = do
+ goRows 0 >>= \r1 -> if r1
+ then goCols 0 >>= \r2 -> if r2
+ then goBlocks 0
+ else return False
+ else return False
+ where
+ goRows, goCols, goBlocks :: Int -> ST s Bool
+ goRows 9 = return True
+ goRows i = isValidRow arr i st >>= \r -> if r then goRows (i+1) else return False
+
+ goCols 9 = return True
+ goCols i = isValidCol arr i st >>= \r -> if r then goCols (i+1) else return False
+
+ goBlocks 9 = return True
+ goBlocks i = isValidBlock arr i st >>= \r -> if r then goBlocks (i+1) else return False
+
+isValidRow :: Arr s Value -> Index -> State s -> ST s Bool
+isValidRow arr r st = indexSetNoDups arr [9 * r + i | i <- [0..8]] st
+
+isValidCol :: Arr s Value -> Index -> State s -> ST s Bool
+isValidCol arr c st = indexSetNoDups arr [9 * i + c | i <- [0..8]] st
+
+isValidBlock :: Arr s Value -> Index -> State s -> ST s Bool
+isValidBlock arr b st = indexSetNoDups arr [blockOrigin b + 9 * y + x| y <- [0..2], x <- [0..2]] st
+
+indexSetNoDups :: forall s. Arr s Value -> [Index] -> State s -> ST s Bool
+indexSetNoDups arr set st = do
+ fillArray mark 0 8 False
+ applyInMark set
+ where
+ mark = stateMark st
+
+ applyInMark :: [Int] -> ST s Bool
+ applyInMark [] = return True
+ applyInMark (i:is) =
+ readArray arr i >>= \v ->
+ if v == (-1)
+ then applyInMark is
+ else do
+ b <- readArray mark v
+ if b
+ then return False
+ else writeArray mark v True >> applyInMark is
+
+fillArray :: Arr s a -> Int -> Int -> a -> ST s ()
+fillArray arr !i1 !i2 v
+ | i1 <= i2 = do
+ writeArray arr i1 v
+ fillArray arr (i1 + 1) i2 v
+ | otherwise = return ()
+
+rowOf :: Index -> Index
+rowOf i = i `quot` 9
+
+colOf :: Index -> Index
+colOf i = i `rem` 9
+
+blockOf :: Index -> Index
+blockOf i = 3 * (i `quot` 27) + (i `rem` 9) `quot` 3
+
+blockOrigin :: Index -> Index
+blockOrigin 0 = 0
+blockOrigin 1 = 3
+blockOrigin 2 = 6
+blockOrigin 3 = 27
+blockOrigin 4 = 30
+blockOrigin 5 = 33
+blockOrigin 6 = 54
+blockOrigin 7 = 57
+blockOrigin 8 = 60