blob: 7429da28a8baed9c822f08e74d2cab60131e2ae2 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
|
{-# LANGUAGE BangPatterns, ScopedTypeVariables, RankNTypes #-}
module FSu(solve) where
import Control.Monad
import Control.Monad.ST
import Data.Array.Base (unsafeRead, unsafeWrite)
import Data.Array.ST
import Data.STRef
-- | A Sudoku cell value. Values are zero-based (@0..8@); the board encodes
-- an empty cell as @-1@.
type Value = Int

-- | A board position, @0..80@, in row-major order.
type Index = Int

-- | Unboxed mutable array indexed by 'Int'; used for both the board and the
-- scratch mark table.
type Arr s a = STUArray s Int a

-- | Mutable state threaded through the whole search.
data State s = State
  { stateMark :: Arr s Bool
    -- ^ scratch table of 9 flags, reused by the validity and candidate scans
  , stateResults :: STRef s [[Maybe Value]]
    -- ^ completed grids found so far, most recent first
  }
-- | Enumerate every completion of a 9x9 Sudoku grid.
--
-- The input lists the cells in row-major order. @Nothing@ marks an empty
-- cell and @Just v@ a given with @v@ in @0..8@ (values are zero-based,
-- matching the candidate range the search produces). Each solution uses
-- the same encoding. Solutions are prepended as they are found, so the
-- result is in reverse discovery order. A grid whose givens already
-- conflict yields @[]@.
--
-- Only the first 81 cells are used. Malformed input also yields @[]@:
-- either fewer than 81 cells, or a given outside @0..8@. Without this
-- check such input would reach the unchecked 'unsafeRead' /
-- 'unsafeWrite' calls as an out-of-range index (into the 9-entry mark
-- table) or leave board cells uninitialised.
solve :: [Maybe Value] -> [[Maybe Value]]
solve input
  | length cells /= 81 || not (all validCell cells) = []
  | otherwise = runST $ do
      arr <- newListArray (0, 80) (map (maybe (-1) id) cells) :: ST s (Arr s Value)
      mark <- newArray (0, 8) False :: ST s (Arr s Bool)
      results <- newSTRef [] :: ST s (STRef s [[Maybe Value]])
      solveAt arr 0 (State mark results)
      readSTRef results
  where
    -- Extra trailing cells were ignored before; keep that behaviour.
    cells = take 81 input
    validCell = maybe True (\v -> v >= 0 && v <= 8)
-- | Snapshot the board as a list of cells in index order. The @-1@ empty
-- sentinel becomes 'Nothing'; every other value is wrapped in 'Just'.
obtainResult :: Arr s Value -> ST s [Maybe Value]
obtainResult arr = map toCell <$> getElems arr
  where
    toCell (-1) = Nothing
    toCell v = Just v
-- | Depth-first search over cells @i..80@, recording every completed grid.
--
-- A cell that already holds a value (a given, or one placed by an outer
-- call) is skipped. An empty cell (@-1@) is filled with each candidate
-- from 'getPoss' in turn via 'tryAll'. Reaching index 81 means every cell
-- is filled, and the grid is prepended to 'stateResults'.
--
-- The full-board consistency check ('isValid') runs only on the entry
-- call (@i == 0@). Every value placed afterwards comes from 'getPoss'.
-- That function excludes anything already present in the cell's row,
-- column and block, so a board that starts consistent stays consistent.
-- Re-scanning all 27 units at every search node was pure overhead.
-- Callers must therefore enter at index 0, as 'solve' does.
solveAt :: Arr s Value -> Index -> State s -> ST s ()
solveAt arr !i st = do
  valid <- if i == 0 then isValid arr st else return True
  when valid $
    if i == 81
      then do
        res <- obtainResult arr
        modifySTRef' (stateResults st) (res :)
      else do
        v <- unsafeRead arr i
        if v /= (-1)
          then solveAt arr (i + 1) st
          else do
            poss <- getPoss arr i st
            tryAll arr i poss st
-- | For each candidate, in order: write it into cell @i@, search the
-- remaining cells from @i + 1@, then reset the cell to @-1@ before the
-- next candidate. With no candidates nothing happens.
tryAll :: Arr s Value -> Index -> [Value] -> State s -> ST s ()
tryAll arr i candidates st =
  forM_ candidates $ \v -> do
    unsafeWrite arr i v
    solveAt arr (i + 1) st
    unsafeWrite arr i (-1)
-- | List, in ascending order, the values in @0..8@ that do not already
-- occur in the row, column or 3x3 block of cell @i@.
--
-- Assumes the considered position is empty. Its own cell is part of the
-- row scan, so a value stored there would be struck out as well.
--
-- Uses 'stateMark' as scratch space and overwrites all nine entries.
getPoss :: forall s. Arr s Value -> Index -> State s -> ST s [Value]
getPoss arr i st = do
  fillArrayBool mark 0 8 True
  mapM_ strike rowCells
  mapM_ strike colCells
  mapM_ strike blockCells
  flags <- getElems mark
  return [v | (v, True) <- zip [0 .. 8] flags]
  where
    mark = stateMark st
    r = rowOf i
    c = colOf i
    origin = blockOrigin (blockOf i)
    rowCells = [9 * r + j | j <- [0 .. 8]]
    colCells = [9 * j + c | j <- [0 .. 8]]
    blockCells = [origin + 9 * (j `quot` 3) + j `rem` 3 | j <- [0 .. 8]]
    -- Clear the flag of whatever value occupies cell @ix@; empty cells
    -- (@-1@) leave the table untouched.
    strike :: Index -> ST s ()
    strike ix = do
      v <- unsafeRead arr ix
      unless (v == (-1)) $ unsafeWrite mark v False
-- | True when no row, column or 3x3 block of the board contains a repeated
-- value (empty cells are ignored). Rows 0..8 are checked first, then
-- columns, then blocks; checking stops at the first failing unit.
isValid :: forall s. Arr s Value -> State s -> ST s Bool
isValid arr st = allM [check arr unit st | check <- checks, unit <- [0 .. 8]]
  where
    checks :: [Arr s Value -> Index -> State s -> ST s Bool]
    checks = [isValidRow, isValidCol, isValidBlock]
    -- Run the actions left to right, short-circuiting on the first False.
    allM :: [ST s Bool] -> ST s Bool
    allM = foldr (\act rest -> act >>= \ok -> if ok then rest else return False) (return True)
-- | True when row @r@ (cells @9r .. 9r+8@) holds no repeated value.
isValidRow :: Arr s Value -> Index -> State s -> ST s Bool
isValidRow arr r = indexSetNoDups arr (take 9 [9 * r ..])
-- | True when column @c@ (cells @c, c+9, .., c+72@) holds no repeated value.
isValidCol :: Arr s Value -> Index -> State s -> ST s Bool
isValidCol arr c = indexSetNoDups arr (take 9 [c, c + 9 ..])
-- | True when 3x3 block @b@ holds no repeated value. Its cells are the
-- block's top-left index plus fixed offsets, scanned row by row.
isValidBlock :: Arr s Value -> Index -> State s -> ST s Bool
isValidBlock arr b = indexSetNoDups arr [origin + offset | offset <- [0, 1, 2, 9, 10, 11, 18, 19, 20]]
  where
    origin = blockOrigin b
-- | True when no value other than the empty sentinel @-1@ occurs twice
-- among the given board cells. Returns False at the first repeat.
--
-- Uses 'stateMark' as the "seen" table; its previous contents are
-- discarded.
indexSetNoDups :: forall s. Arr s Value -> [Index] -> State s -> ST s Bool
indexSetNoDups arr set st = fillArrayBool seen 0 8 False >> scan set
  where
    seen = stateMark st
    scan :: [Index] -> ST s Bool
    scan [] = return True
    scan (ix : rest) = do
      v <- unsafeRead arr ix
      case v of
        (-1) -> scan rest
        _ -> do
          dup <- unsafeRead seen v
          if dup then return False else unsafeWrite seen v True >> scan rest
-- | Write @v@ to every index from @i1@ to @i2@ inclusive (nothing when
-- @i1 > i2@). Indices are not bounds-checked.
fillArrayBool :: Arr s Bool -> Int -> Int -> Bool -> ST s ()
fillArrayBool arr !i1 !i2 v = forM_ [i1 .. i2] $ \k -> unsafeWrite arr k v
-- | Row (@0..8@) of a row-major board index.
rowOf :: Index -> Index
rowOf = (`quot` 9)
-- | Column (@0..8@) of a row-major board index.
colOf :: Index -> Index
colOf = (`rem` 9)
-- | Number (@0..8@, row-major) of the 3x3 block containing a board index.
blockOf :: Index -> Index
blockOf i = 3 * band + stack
  where
    -- which horizontal band of three rows
    band = i `quot` 27
    -- which vertical stack of three columns
    stack = (i `rem` 9) `quot` 3
-- | Board index of the top-left cell of 3x3 block @b@. Blocks are numbered
-- @0..8@ in row-major order, the same numbering 'blockOf' produces.
--
-- Computed arithmetically rather than with a nine-clause table: three
-- blocks per band, 27 cells per band, 3 columns per stack. An index
-- outside @0..8@ is a caller bug. It raises a descriptive error instead of
-- an anonymous non-exhaustive pattern failure, and the function is now
-- total in the @-Wincomplete-patterns@ sense.
blockOrigin :: Index -> Index
blockOrigin b
  | b < 0 || b > 8 = error ("FSu.blockOrigin: block index out of range: " ++ show b)
  | otherwise = 27 * (b `quot` 3) + 3 * (b `rem` 3)
|