import Data.List
import System.Random
{- Data Types for the game -}
-- | The two sides of a tic-tac-toe game.
data Player
  = X
  | O
  deriving (Show, Eq)
-- | The board as a list of rows. A cell is 'Nothing' while free and
-- @Just p@ once player @p@ has taken it. The game code assumes 3 rows of
-- 3 cells (see 'place2coord' and the patterns in 'win').
type Board = [[Maybe Player]]
-- | A node of the Monte Carlo search tree, or a placeholder for a move
-- that has not been (or cannot be) turned into a node.
data Node
  = Node
      { wins :: Int
        -- ^ win count accumulated by 'backpropagation'
      , plays :: Int
        -- ^ number of playouts that went through this node
      , player :: Player
        -- ^ the player whose turn it is in 'state'
      , state :: Board
        -- ^ the board position this node represents
      , children :: [Node]
        -- ^ one entry per board cell (row-major): an explored 'Node',
        -- 'Empty' for an untried move, 'Forbidden' for an occupied cell
      }
  | Empty
    -- ^ a move that has not been explored yet
  | Forbidden
    -- ^ a move that cannot be played
  deriving (Show, Eq)
-- | One frame of the zipper: everything about a parent node except the
-- child we descended into, so the parent can be rebuilt on the way up.
data Choice = Place
  { winsP :: Int
    -- ^ the parent's win count
  , playsP :: Int
    -- ^ the parent's play count
  , playerP :: Player
    -- ^ the player to move at the parent
  , stateP :: Board
    -- ^ the parent's board
  , siblingsP :: ([Node], [Node])
    -- ^ the parent's children before and after the one descended into
  }
  deriving Show
-- | The path back to the root, innermost 'Choice' first (each descent
-- conses a new frame onto the front).
type Thread = [Choice]
-- | A focused node together with the path that leads to it.
type Zipper = (Thread, Node)
{- Tic-Tac-Toe functions -}
-- | Convert a cell index (0-8, row-major) into its (row, column) position
-- on a 3-wide board.
place2coord :: Int -> (Int, Int)
place2coord = (`divMod` 3)
-- | Does the board contain a completed line, i.e. three identical occupied
-- cells in a row, a column or a diagonal? Expects a 3x3 board.
win :: Board -> Bool
win board = any complete (board ++ transpose board ++ diagonals board)
  where
    -- a line is complete when its first cell is taken and all three agree
    complete [a, b, c] = a /= Nothing && a == b && b == c
    diagonals [[topL, _, topR], [_, centre, _], [botL, _, botR]] =
      [[topL, centre, botR], [topR, centre, botL]]
-- | A draw: no free cell remains and nobody has completed a line.
draw :: Board -> Bool
draw b
  | null (possibleMoves b) = not (win b)
  | otherwise = False
-- | Is the game over, either by a win or by a draw?
goal :: Board -> Bool
goal b = or [win b, draw b]
-- | Indices (0-8, row-major) of the cells that are still free.
possibleMoves :: Board -> [Int]
possibleMoves b = [ix | (Nothing, ix) <- zip (concat b) [0 ..]]
-- | The opponent of the given player.
nextPlayer :: Player -> Player
nextPlayer p = case p of
  X -> O
  O -> X
-- | Put the given player's mark on cell @pos@ (0-8, row-major) and return
-- the new board. The target cell is overwritten without checking that it
-- is free.
move :: Int -> Player -> Board -> Board
move pos p b = replaceAt row (replaceAt col (Just p) (b !! row)) b
  where
    (row, col) = place2coord pos
    -- swap the element at index i for v, keeping everything else in place
    replaceAt i v xs = take i xs ++ v : drop (i + 1) xs
{- Utility functions -}
-- | Upper confidence bound of a child node: its mean win rate
-- @w / ni@ plus the exploration term @sqrt (2 * ln n / ni)@, where @w@ is
-- the child's wins, @ni@ its plays and @n@ the total plays of its siblings.
--
-- A node that has never been played (@ni <= 0@) gets positive infinity.
-- Without this guard, @0 / 0@ yields NaN, which has no consistent ordering
-- and would corrupt the sort in 'maxConfidence'.
confidence :: Int -> Int -> Int -> Double
confidence w ni n
  | ni <= 0 = 1 / 0
  | otherwise = mu + interval
  where
    mu = fromIntegral w / ni'
    interval = sqrt (2 * log (fromIntegral n) / ni')
    ni' = fromIntegral ni
-- | Index of the child with the highest 'confidence'; ties go to the lowest
-- index. 'Empty' and 'Forbidden' entries get the sentinel score 1000
-- (after negation), so they sort behind explored nodes. Fails on an empty
-- list.
maxConfidence :: [Node] -> Int
maxConfidence ns = bestIndex
  where
    -- sorting ascending on the negated bound puts the best child first
    (_, bestIndex) = head (sort (zip (map score ns) [0 ..]))
    score node
      | placeholder node = 1000
      | otherwise = negate (confidence (wins node) (plays node) visits)
    -- total plays across the explored children only
    visits = sum [plays node | node <- ns, not (placeholder node)]
    placeholder node = node == Forbidden || node == Empty
-- | Does this node still have at least one unexplored ('Empty') child?
anyEmpty :: Node -> Bool
anyEmpty n = Empty `elem` children n
-- | Is every child 'Forbidden', leaving no move to explore from here?
allForbidden :: Node -> Bool
allForbidden n = null [c | c <- children n, c /= Forbidden]
-- | Descend one level into child @pos@ of node @n@. The result records
-- everything about @n@ except that child, so 'backpropagation' can
-- rebuild @n@ later.
nextStep :: Int -> Node -> Choice
nextStep pos n = Place
  { winsP = wins n
  , playsP = plays n
  , playerP = player n
  , stateP = state n
  , siblingsP = (before, after)
  }
  where
    cs = children n
    before = take pos cs
    after = drop (pos + 1) cs
-- | Choose uniformly at random one position of @xs@ whose element equals
-- @x@. Returns that index and the updated generator. Fails when @x@ does
-- not occur in @xs@.
pickRandom :: Eq a => StdGen -> [a] -> a -> (Int, StdGen)
pickRandom g xs x = case hits of
  [] -> error "No choices to be made!"
  _ -> (hits !! k, g')
  where
    hits = [i | (y, i) <- zip xs [0 ..], y == x]
    (k, g') = randomR (0, length hits - 1) g
-- | Pick the index of one of the unexplored ('Empty') children at random.
randomChild :: StdGen -> [Node] -> (Int, StdGen)
randomChild g ns = pickRandom g ns unexplored
  where
    unexplored = Empty
-- | Pick a free cell of the board at random. Returns its index (0-8,
-- row-major) and the updated generator.
randomMove :: StdGen -> Board -> (Int, StdGen)
randomMove g b = pickRandom g cells Nothing
  where
    cells = [cell | row <- b, cell <- row]
{- MCTS functions -}
-- | Walk down from the focused node, always following the child chosen by
-- 'maxConfidence'. Stop at the first node that still has an unexplored
-- child or whose children are all 'Forbidden'. Landing on an 'Empty' or
-- 'Forbidden' placeholder is an error.
select :: Zipper -> Zipper
select z@(t, n) = case n of
  Empty -> error "Select reached empty node"
  Forbidden -> error "Select reached forbidden node"
  _
    | anyEmpty n || allForbidden n -> z
    | otherwise -> select (nextStep best n : t, children n !! best)
  where
    best = maxConfidence (children n)
-- | Expand the focused node by playing one of its unexplored ('Empty')
-- moves, chosen at random, and move the focus onto the resulting child.
-- The child starts with no wins or plays and has the opponent to move. Its
-- children are 'Empty' for free cells and 'Forbidden' for occupied ones.
--
-- NOTE: this does not check whether the game is already over. A finished
-- board still gets 'Empty' children for its free cells.
expansion :: StdGen -> Zipper -> (Zipper, StdGen)
expansion g (t, n) = ((nextStep pos n : t, child), g')
  where
    (pos, g') = randomChild g (children n)
    board = move pos (player n) (state n)
    child = Node
      { wins = 0
      , plays = 0
      , player = nextPlayer (player n)
      , state = board
      , children = [if cell == Nothing then Empty else Forbidden | cell <- concat board]
      }
-- | Play uniformly random moves from board @b@, with @p@ to move, until the
-- game ends; no tree nodes are created. Returns the score (1 if X won,
-- -1 if O won, 0 for a draw), the final board and the updated generator.
-- A winning line is attributed to the player who moved last, i.e. the
-- opponent of the player now to move.
simulation :: StdGen -> Player -> Board -> (Int, Board, StdGen)
simulation g p b
  | win b = (if p == X then -1 else 1, b, g)
  | draw b = (0, b, g)
  | otherwise =
      let (pos, g') = randomMove g b
      in simulation g' (nextPlayer p) (move pos p b)
-- | Propagate a playout score from the focused node back up to the root,
-- rebuilding the tree from the zipper frames on the way. Returns the new
-- root.
--
-- Every node on the path (the focused node, each ancestor and the root)
-- gets exactly one more play. It also gets one more win whenever 'winner'
-- credits the player who made the move leading to that node.
backpropagation :: Int -> Zipper -> Node
backpropagation score (thread, focus) = climb thread (credit focus)
  where
    -- Record one playout on a node. The win belongs to the player who moved
    -- into the node: the parent's player compares this node against its
    -- siblings in 'select'/'play', so it must see its own results.
    credit node =
      node
        { wins = wins node + winner score (nextPlayer (player node))
        , plays = plays node + 1
        }
    -- Zip up one level: rebuild the parent around the already-credited
    -- child, then credit the parent itself (once).
    climb [] node = node
    climb (c : cs) node = climb cs (credit parent)
      where
        (left, right) = siblingsP c
        parent = Node (winsP c) (playsP c) (playerP c) (stateP c) (left ++ node : right)
-- | How many wins a playout score (1 = X won, -1 = O won, 0 = draw) is
-- worth to the given player. A draw deliberately counts as a win for both
-- players.
winner :: Int -> Player -> Int
winner score p = case (score, p) of
  (0, _) -> 1
  (1, X) -> 1
  (-1, O) -> 1
  _ -> 0
-- | One MCTS iteration:
--
-- 1. Select a node.
-- 2. If its game is already over, score the finished position as it is.
-- 3. Otherwise expand one untried move and play the rest of the game out
--    at random.
-- 4. Backpropagate the score.
--
-- Returns the new root and the generator for the next iteration.
mcts :: StdGen -> Node -> (Node, StdGen)
mcts g n
  -- Finished game: never expand past it. Expanding would play moves after
  -- the game has ended and make 'simulation' credit the wrong side. It
  -- also lets later iterations keep refining the rest of the tree instead
  -- of stopping for good.
  | goal (state leaf) = (backpropagation finalScore z, gFinal)
  | anyEmpty leaf = (backpropagation score z', g'')
  -- nothing to expand and nothing to score (not expected for trees built
  -- by 'expansion'): leave everything unchanged
  | otherwise = (n, g)
  where
    z@(_, leaf) = select ([], n)
    -- 'simulation' returns immediately on a finished board
    (finalScore, _, gFinal) = simulation g (player leaf) (state leaf)
    (z'@(_, child), g') = expansion g z
    -- Continue from the generator returned by 'expansion'. Reusing @g@
    -- would replay the random numbers that chose the expanded move.
    (score, _, g'') = simulation g' (player child) (state child)
-- | Run 'mcts' @it@ times, threading the generator through the iterations,
-- and return the final tree. A non-positive count returns the tree
-- unchanged.
iter :: Int -> StdGen -> Node -> Node
iter it g n
  | it > 0 = let (n', g') = mcts g n in iter (it - 1) g' n'
  | otherwise = n
-- | Follow the tree from the given node, at each step moving to the child
-- picked by 'maxConfidence'. Return the board where the walk stops: either
-- a finished game or a node with no explored child. A placeholder
-- ('Empty' or 'Forbidden') yields the degenerate board @[[]]@.
play :: Node -> Board
play n = case n of
  Empty -> [[]]
  Forbidden -> [[]]
  _
    | goal (state n) || all placeholder (children n) -> state n
    | otherwise -> play (children n !! maxConfidence (children n))
  where
    placeholder c = c == Empty || c == Forbidden
{- initial states -}
-- | The empty 3x3 starting board.
s0 :: Board
s0 = replicate 3 (replicate 3 Nothing)
-- | Children of the root: all 9 moves are still unexplored.
emptyChildren :: [Node]
emptyChildren = take 9 (repeat Empty)
-- | Root of the search tree: the empty board, X to move, no statistics yet.
n0 :: Node
n0 = Node
  { wins = 0
  , plays = 0
  , player = X
  , state = s0
  , children = emptyChildren
  }
-- | Grow a search tree with 10000 MCTS iterations from 'n0', then print
-- the board reached by 'play'.
main :: IO ()
main = do
  gen <- newStdGen
  print (play (iter 10000 gen n0))