diff options
author | Tom Smeding <tom.smeding@gmail.com> | 2018-08-08 23:17:38 +0200 |
---|---|---|
committer | Tom Smeding <tom.smeding@gmail.com> | 2018-08-08 23:22:48 +0200 |
commit | 9aabb53bfda096dafebfddaab8df274d82ac01b4 (patch) | |
tree | 5af82517b8119f4b792315d77d8cde03e8ebf884 /hsolve/FSu.hs | |
parent | ff7dbb854805b0f7cc2a75854c836673d7fb1ac6 (diff) |
Unsafe array operations in FSu
Another 15% faster on suvh.txt, for a total of 32%
Diffstat (limited to 'hsolve/FSu.hs')
-rw-r--r-- | hsolve/FSu.hs | 25 |
1 files changed, 13 insertions, 12 deletions
diff --git a/hsolve/FSu.hs b/hsolve/FSu.hs index 1d0eea3..dc0f4c6 100644 --- a/hsolve/FSu.hs +++ b/hsolve/FSu.hs @@ -3,6 +3,7 @@ module FSu(solve) where import Control.Monad import Control.Monad.ST +import Data.Array.Base (unsafeRead, unsafeWrite) import Data.Array.ST import Data.STRef @@ -36,7 +37,7 @@ solveAt arr !i st = do res <- obtainResult arr modifySTRef' (stateResults st) (res:) else do - v <- readArray arr i + v <- unsafeRead arr i if v /= (-1) then solveAt arr (i+1) st else do @@ -46,9 +47,9 @@ solveAt arr !i st = do tryAll :: Arr s Value -> Index -> [Value] -> State s -> ST s () tryAll _ _ [] _ = return () tryAll arr !i (v:vs) st = do - writeArray arr i v + unsafeWrite arr i v solveAt arr (i+1) st - writeArray arr i (-1) + unsafeWrite arr i (-1) tryAll arr i vs st -- assumes the considered position is empty @@ -65,18 +66,18 @@ getPoss arr i st = do goRow :: Int -> Int -> ST s () goRow _ 9 = return () - goRow r j = readArray arr (9 * r + j) >>= \v -> - when (v /= (-1)) (writeArray mark v False) >> goRow r (j+1) + goRow r j = unsafeRead arr (9 * r + j) >>= \v -> + when (v /= (-1)) (unsafeWrite mark v False) >> goRow r (j+1) goCol :: Int -> Int -> ST s () goCol _ 9 = return () - goCol c j = readArray arr (9 * j + c) >>= \v -> - when (v /= (-1)) (writeArray mark v False) >> goCol c (j+1) + goCol c j = unsafeRead arr (9 * j + c) >>= \v -> + when (v /= (-1)) (unsafeWrite mark v False) >> goCol c (j+1) goBlock :: Int -> Int -> ST s () goBlock _ 9 = return () - goBlock b j = readArray arr (b + 9 * (j `quot` 3) + j `rem` 3) >>= \v -> - when (v /= (-1)) (writeArray mark v False) >> goBlock b (j+1) + goBlock b j = unsafeRead arr (b + 9 * (j `quot` 3) + j `rem` 3) >>= \v -> + when (v /= (-1)) (unsafeWrite mark v False) >> goBlock b (j+1) isValid :: forall s. Arr s Value -> State s -> ST s Bool isValid arr st = do @@ -115,14 +116,14 @@ indexSetNoDups arr set st = do applyInMark :: [Int] -> ST s Bool applyInMark [] = return True applyInMark (i:is) = - readArray arr i >>= \v -> + unsafeRead arr i >>= \v -> if v == (-1) then applyInMark is else do - b <- readArray mark v + b <- unsafeRead mark v if b then return False - else writeArray mark v True >> applyInMark is + else unsafeWrite mark v True >> applyInMark is fillArrayBool :: Arr s Bool -> Int -> Int -> Bool -> ST s () fillArrayBool arr !i1 !i2 v |