From ddb57cb49a60b6173712341940195e0275ef1c9d Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 8 Aug 2018 22:58:13 +0200 Subject: Haskell solver that uses rules --- hsolve/FSu.hs | 152 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 152 insertions(+) create mode 100644 hsolve/FSu.hs (limited to 'hsolve/FSu.hs') diff --git a/hsolve/FSu.hs b/hsolve/FSu.hs new file mode 100644 index 0000000..1bd250a --- /dev/null +++ b/hsolve/FSu.hs @@ -0,0 +1,152 @@ +{-# LANGUAGE BangPatterns, ScopedTypeVariables, RankNTypes #-} +module FSu(solve) where + +import Control.Monad +import Control.Monad.ST +import Data.Array.ST +import Data.STRef + + +type Value = Int -- Sudoku value +type Index = Int -- Sudoku index + +type Arr s a = STArray s Int a + +data State s = State { stateMark :: Arr s Bool, stateResults :: STRef s [[Maybe Value]] } + +solve :: [Maybe Value] -> [[Maybe Value]] +solve input = runST $ do + arr <- newListArray (0,80) $ map (maybe (-1) id) input :: ST s (Arr s Value) + mark <- newArray (0,8) False :: ST s (Arr s Bool) + results <- newSTRef [] :: ST s (STRef s [[Maybe Value]]) + solveAt arr 0 (State mark results) + readSTRef results + +obtainResult :: Arr s Value -> ST s [Maybe Value] +obtainResult arr = do + elems <- getElems arr + return $ [if v == (-1) then Nothing else Just v | v <- elems] + +solveAt :: Arr s Value -> Index -> State s -> ST s () +solveAt arr !i st = do + valid <- isValid arr st + when valid $ do + if i == 81 + then do + res <- obtainResult arr + modifySTRef' (stateResults st) (res:) + else do + v <- readArray arr i + if v /= (-1) + then solveAt arr (i+1) st + else do + poss <- getPoss arr i st + tryAll arr i poss st + +tryAll :: Arr s Value -> Index -> [Value] -> State s -> ST s () +tryAll _ _ [] _ = return () +tryAll arr !i (v:vs) st = do + writeArray arr i v + solveAt arr (i+1) st + writeArray arr i (-1) + tryAll arr i vs st + +-- assumes the considered position is empty +getPoss :: forall s. Arr s Value -> Index -> State s -> ST s [Value] +getPoss arr i st = do + fillArray mark 0 8 True + goRow (rowOf i) 0 + goCol (colOf i) 0 + goBlock (blockOrigin (blockOf i)) 0 + bs <- liftM (zip [0..8]) (getElems mark) + return $ map fst $ filter snd bs + where + mark = stateMark st + + goRow :: Int -> Int -> ST s () + goRow _ 9 = return () + goRow r j = readArray arr (9 * r + j) >>= \v -> + when (v /= (-1)) (writeArray mark v False) >> goRow r (j+1) + + goCol :: Int -> Int -> ST s () + goCol _ 9 = return () + goCol c j = readArray arr (9 * j + c) >>= \v -> + when (v /= (-1)) (writeArray mark v False) >> goCol c (j+1) + + goBlock :: Int -> Int -> ST s () + goBlock _ 9 = return () + goBlock b j = readArray arr (b + 9 * (j `quot` 3) + j `rem` 3) >>= \v -> + when (v /= (-1)) (writeArray mark v False) >> goBlock b (j+1) + +isValid :: forall s. Arr s Value -> State s -> ST s Bool +isValid arr st = do + goRows 0 >>= \r1 -> if r1 + then goCols 0 >>= \r2 -> if r2 + then goBlocks 0 + else return False + else return False + where + goRows, goCols, goBlocks :: Int -> ST s Bool + goRows 9 = return True + goRows i = isValidRow arr i st >>= \r -> if r then goRows (i+1) else return False + + goCols 9 = return True + goCols i = isValidCol arr i st >>= \r -> if r then goCols (i+1) else return False + + goBlocks 9 = return True + goBlocks i = isValidBlock arr i st >>= \r -> if r then goBlocks (i+1) else return False + +isValidRow :: Arr s Value -> Index -> State s -> ST s Bool +isValidRow arr r st = indexSetNoDups arr [9 * r + i | i <- [0..8]] st + +isValidCol :: Arr s Value -> Index -> State s -> ST s Bool +isValidCol arr c st = indexSetNoDups arr [9 * i + c | i <- [0..8]] st + +isValidBlock :: Arr s Value -> Index -> State s -> ST s Bool +isValidBlock arr b st = indexSetNoDups arr [blockOrigin b + 9 * y + x| y <- [0..2], x <- [0..2]] st + +indexSetNoDups :: forall s. Arr s Value -> [Index] -> State s -> ST s Bool +indexSetNoDups arr set st = do + fillArray mark 0 8 False + applyInMark set + where + mark = stateMark st + + applyInMark :: [Int] -> ST s Bool + applyInMark [] = return True + applyInMark (i:is) = + readArray arr i >>= \v -> + if v == (-1) + then applyInMark is + else do + b <- readArray mark v + if b + then return False + else writeArray mark v True >> applyInMark is + +fillArray :: Arr s a -> Int -> Int -> a -> ST s () +fillArray arr !i1 !i2 v + | i1 <= i2 = do + writeArray arr i1 v + fillArray arr (i1 + 1) i2 v + | otherwise = return () + +rowOf :: Index -> Index +rowOf i = i `quot` 9 + +colOf :: Index -> Index +colOf i = i `rem` 9 + +blockOf :: Index -> Index +blockOf i = 3 * (i `quot` 27) + (i `rem` 9) `quot` 3 + +blockOrigin :: Index -> Index +blockOrigin 0 = 0 +blockOrigin 1 = 3 +blockOrigin 2 = 6 +blockOrigin 3 = 27 +blockOrigin 4 = 30 +blockOrigin 5 = 33 +blockOrigin 6 = 54 +blockOrigin 7 = 57 +blockOrigin 8 = 60 -- cgit v1.2.3-70-g09d2