-- hsolve/FSu.hs (blob 1d0eea3b6ba6ea10d674754edfd908b95d35dc36)
-- Backtracking Sudoku solver working on a mutable unboxed board in ST.
{-# LANGUAGE BangPatterns, ScopedTypeVariables, RankNTypes #-}
module FSu(solve) where

import Control.Monad
import Control.Monad.ST
import Data.Array.ST
import Data.STRef


-- | A cell value. Clues and candidates lie in @0..8@; inside the working
-- board array, -1 marks an empty cell.
type Value = Int

-- | A cell index into the 81-cell board in row-major order
-- (row = index `quot` 9, column = index `rem` 9).
type Index = Int

-- | Mutable unboxed array indexed by 'Int' (board cells or value flags).
type Arr s a = STUArray s Int a

-- | Mutable state shared by the whole search.
data State s = State
    { stateMark    :: !(Arr s Bool)
      -- ^ Scratch flags indexed by value @0..8@, reused by every uniqueness
      -- check and every candidate computation.
    , stateResults :: !(STRef s [[Maybe Value]])
      -- ^ Completed boards found so far, newest first.
    }

-- | Enumerate every completion of a 9x9 Sudoku board.
--
-- The board is given in row-major order, one entry per cell: 'Nothing' is an
-- empty cell and @Just v@ is a clue with @v@ in @0..8@. Every result uses the
-- same encoding with all cells filled. Results are accumulated by prepending,
-- so the most recently found solution comes first.
--
-- Input the solver cannot represent yields no solutions, just like a board
-- whose clues already conflict: a list that is not exactly 81 cells long, or
-- a clue outside @0..8@.
solve :: [Maybe Value] -> [[Maybe Value]]
solve input
    | not (wellFormed input) = []
    | otherwise = runST $ do
        -- -1 is the in-array sentinel for an empty cell.
        arr <- newListArray (0,80) $ map (maybe (-1) id) input :: ST s (Arr s Value)
        mark <- newArray (0,8) False :: ST s (Arr s Bool)
        results <- newSTRef [] :: ST s (STRef s [[Maybe Value]])
        solveAt arr 0 (State mark results)
        readSTRef results
  where
    -- Without this check, a short list would leave trailing cells without the
    -- -1 "empty" marker, extra cells would be silently dropped, and a clue
    -- outside 0..8 would fail when used as an index into the 0..8 mark array.
    wellFormed cells = length cells == 81 && all (maybe True isCellValue) cells
    isCellValue v = v >= 0 && v <= 8

-- | Snapshot the working board as a list of cells in index order: the empty
-- sentinel -1 becomes 'Nothing', any other value @v@ becomes @Just v@.
obtainResult :: Arr s Value -> ST s [Maybe Value]
obtainResult arr = map toCell <$> getElems arr
  where
    toCell (-1) = Nothing
    toCell v = Just v

-- | Depth-first search over the cells from index @i@ onward, prepending every
-- complete board to 'stateResults'.
--
-- The full-board 'isValid' check runs only on the initial call (@i == 0@,
-- made by 'solve'; recursive calls always pass @i + 1@). Once the starting
-- board is known to be consistent it cannot become inconsistent: the only
-- digits written during the search come from 'getPoss', which excludes every
-- value already present in the cell's row, column or block, and the other
-- write merely resets a cell to empty. Re-validating all 27 units at every
-- search node would therefore never prune anything and only cost time.
solveAt :: Arr s Value -> Index -> State s -> ST s ()
solveAt arr !i st = do
    valid <- if i == 0 then isValid arr st else return True
    when valid $
        if i == 81
        then do
            -- Every cell is filled: record a snapshot of the board.
            res <- obtainResult arr
            modifySTRef' (stateResults st) (res:)
        else do
            v <- readArray arr i
            if v /= (-1)
            then solveAt arr (i+1) st
            else do
                -- Empty cell: try each value that does not clash with its peers.
                poss <- getPoss arr i st
                tryAll arr i poss st

-- | Try each value of @poss@ at cell @i@ in turn: write it, search the
-- remaining cells, then reset the cell to empty (-1) before moving on.
-- With no candidates this does nothing.
tryAll :: Arr s Value -> Index -> [Value] -> State s -> ST s ()
tryAll arr i poss st = forM_ poss $ \candidate -> do
    writeArray arr i candidate
    solveAt arr (i + 1) st
    writeArray arr i (-1)

-- | Candidate values for cell @i@: every value in @0..8@ that does not already
-- appear in the cell's row, column or 3x3 block, in ascending order.
--
-- Assumes cell @i@ itself is empty; if it were filled, its own value would be
-- found while scanning its row and therefore excluded. Uses 'stateMark' as
-- scratch space, overwriting whatever flags it held.
getPoss :: forall s. Arr s Value -> Index -> State s -> ST s [Value]
getPoss arr i st = do
    fillArrayBool free 0 8 True
    forM_ peers $ \p -> do
        v <- readArray arr p
        when (v /= (-1)) $ writeArray free v False
    flags <- getElems free
    return [v | (v, True) <- zip [0..8] flags]
  where
    free = stateMark st
    row = rowOf i
    col = colOf i
    origin = blockOrigin (blockOf i)
    -- Row cells, then column cells, then block cells; cells shared between
    -- units are simply visited more than once.
    peers = [9 * row + k | k <- [0..8]]
        ++ [9 * k + col | k <- [0..8]]
        ++ [origin + 9 * (k `quot` 3) + k `rem` 3 | k <- [0..8]]

-- | Whether no row, column or 3x3 block of the board contains a repeated
-- value (empty cells ignored). Units are checked rows first, then columns,
-- then blocks, stopping at the first unit with a repeat. Uses 'stateMark' as
-- scratch space.
isValid :: forall s. Arr s Value -> State s -> ST s Bool
isValid arr st = allM unitChecks
  where
    unitChecks :: [ST s Bool]
    unitChecks =
        [ check arr k st
        | check <- [isValidRow, isValidCol, isValidBlock]
        , k <- [0..8]
        ]

    -- Run the actions in order, short-circuiting on the first False.
    allM :: [ST s Bool] -> ST s Bool
    allM = foldr (\act rest -> act >>= \ok -> if ok then rest else return False) (return True)

-- | Whether row @r@ (cells @9*r .. 9*r+8@) holds no repeated value.
isValidRow :: Arr s Value -> Index -> State s -> ST s Bool
isValidRow arr r = indexSetNoDups arr rowCells
  where
    rowCells = map (9 * r +) [0..8]

-- | Whether column @c@ (cells @c, c+9, .., c+72@) holds no repeated value.
isValidCol :: Arr s Value -> Index -> State s -> ST s Bool
isValidCol arr c st = indexSetNoDups arr colCells st
  where
    colCells = [c, c + 9 .. c + 72]

-- | Whether 3x3 block @b@ holds no repeated value. Its cells are visited
-- row by row, starting from the block's top-left cell.
isValidBlock :: Arr s Value -> Index -> State s -> ST s Bool
isValidBlock arr b st = indexSetNoDups arr blockCells st
  where
    -- Offsets of the nine block cells from the top-left one.
    blockCells = map (blockOrigin b +) [0, 1, 2, 9, 10, 11, 18, 19, 20]

-- | Whether the cells at the given indices contain no repeated value; empty
-- cells (-1) are ignored. Stops reading at the first repeat. Uses
-- 'stateMark' as the "already seen" flags, overwriting what it held.
indexSetNoDups :: forall s. Arr s Value -> [Index] -> State s -> ST s Bool
indexSetNoDups arr set st = do
    fillArrayBool seen 0 8 False
    foldr visit (return True) set
  where
    seen = stateMark st

    -- Check one cell, then continue with the rest of the scan unless the
    -- cell's value was already seen.
    visit :: Int -> ST s Bool -> ST s Bool
    visit ix rest = do
        v <- readArray arr ix
        if v == (-1)
            then rest
            else do
                dup <- readArray seen v
                if dup
                    then return False
                    else writeArray seen v True >> rest

-- | Set the elements of @arr@ at indices @i1 .. i2@ (inclusive) to @v@; an
-- empty range (@i1 > i2@) writes nothing.
--
-- Uses the bounds-checked 'writeArray'. 'unsafeWrite' is not exported by
-- "Data.Array.ST" (it lives in "Data.Array.Base"), and it takes a raw
-- zero-based offset instead of an array index, so it is only correct for
-- arrays whose lower bound is 0 and silently corrupts memory otherwise.
fillArrayBool :: Arr s Bool -> Int -> Int -> Bool -> ST s ()
fillArrayBool arr i1 i2 v = forM_ [i1 .. i2] $ \k -> writeArray arr k v

-- | Row (0-based) of a row-major cell index.
rowOf :: Index -> Index
rowOf = (`quot` 9)

-- | Column (0-based) of a row-major cell index.
colOf :: Index -> Index
colOf = (`rem` 9)

-- | Number of the 3x3 block containing a cell. Blocks are numbered 0..8 in
-- row-major order: three per band of three rows.
blockOf :: Index -> Index
blockOf i =
    let (row, col) = i `quotRem` 9
    in 3 * (row `quot` 3) + col `quot` 3

-- | Index of the top-left cell of block @b@. Blocks are numbered 0..8 in
-- row-major order, matching 'blockOf': block @b@ sits in band @b `quot` 3@
-- (27 cells per band) and in column group @b `rem` 3@ (3 cells per group).
--
-- Fails with a descriptive error for a block number outside @0..8@; the
-- previous lookup table failed there with a bare non-exhaustive-pattern error.
blockOrigin :: Index -> Index
blockOrigin b
    | b >= 0 && b <= 8 = 27 * band + 3 * group
    | otherwise = error ("FSu.blockOrigin: block index out of range 0..8: " ++ show b)
  where
    (band, group) = b `quotRem` 3