I'm trying to write a GridWorld simulation in Haskell using reinforcement learning (Q-learning). I'm stuck because my program keeps falling into an infinite loop at line 109. I've been staring at this problem for a week and have rewritten the code from scratch several times, so I could use a fresh pair of eyes.
This is the output of my program:
Initial Grid:
0 | 1 | 2 | 3 | 4 |
1 | 2 | 3 | 4 | 5 |
2 | 3 | 4 | 5 | 6 |
3 | 4 | 5 | 6 | 7 |
4 | 5 | 6 | 7 | 8 |
Training Q-learning agent...
Q-learning training finished.
Final Grid:
Execution never gets past line 109 (that line itself never completes):
visualizeGrid (\(x, y) -> maximum [finalQTable ((x, y), a) | a <- [minBound .. maxBound]])
The code compiles fine. In the error window (I'm using an online environment) I get the message Main: <<loop>>.
Here is the full code:
import System.Random
import Data.List (maximumBy)
import Data.Ord (comparing)
import Control.Monad (foldM)
import Debug.Trace (trace)
type Position = (Int, Int)
type Reward = Float
type GridWorld = Position -> Reward
data Action = Up | Down | MoveLeft | MoveRight deriving (Eq, Enum, Bounded)
instance Show Action where
  show Up        = "↑"
  show Down      = "↓"
  show MoveLeft  = "←"
  show MoveRight = "→"
step :: GridWorld -> Position -> Action -> (Position, Reward)
step world (x, y) action = case action of
  Up        -> ((x, max 0 (y-1)), world (x, max 0 (y-1)))
  Down      -> ((x, min 4 (y+1)), world (x, min 4 (y+1)))
  MoveLeft  -> ((max 0 (x-1), y), world (max 0 (x-1), y))
  MoveRight -> ((min 4 (x+1), y), world (min 4 (x+1), y))
type QTable = ((Position, Action) -> Reward)
trainQ :: GridWorld -> QTable -> Float -> Float -> Int -> IO QTable
trainQ world qtable alpha gamma episodes = do
  gen <- newStdGen
  let actions = [minBound .. maxBound] :: [Action]
      positions = [(x, y) | x <- [0..4], y <- [0..4]]
  finalQTable <- snd <$> foldM (\(g, q) _ -> do
      let (q', g') = trainEpisode world q alpha gamma actions positions g
      return (g', q')) (gen, qtable) [1..episodes]
  return finalQTable
trainEpisode :: GridWorld -> QTable -> Float -> Float -> [Action] -> [Position] -> StdGen -> (QTable, StdGen)
trainEpisode world qtable alpha gamma actions positions gen =
  let (startPos, newGen) = randomR (0, length positions - 1) gen
      startState = positions !! startPos
      (_, finalQTable, _) = foldl (\(prevPos, q, g) _ ->
        let (newPos, _) = step world prevPos (chooseAction q newPos actions)
            (updatedQTable, _) = trainStep world q alpha gamma actions newPos
        in (newPos, updatedQTable, g))
        (startState, qtable, newGen) [1..10]
  in (finalQTable, newGen)
trainStep :: GridWorld -> QTable -> Float -> Float -> [Action] -> Position -> (QTable, Position)
trainStep world qtable alpha gamma actions pos =
  let action = chooseAction qtable pos actions
      (newPos, reward) = step world pos action
      oldValue = qtable (pos, action)
      futureValue = maximum [qtable (newPos, a) | a <- actions]
      newValue = oldValue + alpha * (reward + gamma * futureValue - oldValue)
      qtable' = \s -> if s == (pos, action) then newValue else qtable s
  in (qtable', newPos)
chooseAction :: QTable -> Position -> [Action] -> Action
chooseAction qtable (x, y) actions =
  let validActions = filter (\action -> isValidAction action (x, y)) actions
      bestAction = maximumBy (comparing (\a -> qtable ((x, y), a))) validActions
  in bestAction
isValidAction :: Action -> Position -> Bool
isValidAction Up (_, y) = y > 0
isValidAction Down (_, y) = y < 4
isValidAction MoveLeft (x, _) = x > 0
isValidAction MoveRight (x, _) = x < 4
runAgent :: GridWorld -> QTable -> Position -> [Action] -> IO ()
runAgent world qtable pos actions = do
  putStrLn $ "Current Position: " ++ show pos
  let action = chooseAction qtable pos actions
  putStrLn $ "Chosen Action: " ++ show action
  let (newPos, reward) = step world pos action
  putStrLn $ "Action: " ++ show action ++ ", Reward: " ++ show reward
  if newPos == pos
    then do
      putStrLn "Agent is stuck! Generating random action."
      gen <- newStdGen
      let (randomActionIndex, newGen) = randomR (0, length actions - 1) gen
          randomAction = actions !! randomActionIndex
      putStrLn $ "Random Action: " ++ show randomAction
      let (newPos', reward') = step world pos randomAction
      putStrLn $ "Random Action Result: " ++ show randomAction ++ ", New Position: " ++ show newPos' ++ ", Reward: " ++ show reward'
      runAgent world qtable newPos' actions
    else do
      putStrLn $ "New Position: " ++ show newPos
      runAgent world qtable newPos actions
visualizeGrid :: GridWorld -> IO ()
visualizeGrid world = mapM_ putStrLn [concat [show (round (world (x, y))) ++ " | " | x <- [0..4]] | y <- [0..4]]
gridWorld :: GridWorld
gridWorld (x, y) = fromIntegral (x + y)
main :: IO ()
main = do
  putStrLn "Initial Grid:"
  visualizeGrid gridWorld
  putStrLn "Training Q-learning agent..."
  let initialQTable = \_ -> 0.0
      alpha = 0.1
      gamma = 0.9
      episodes = 4
  finalQTable <- trainQ gridWorld initialQTable alpha gamma episodes
  putStrLn "Q-learning training finished."
  putStrLn "Final Grid:"
  visualizeGrid (\(x, y) -> maximum [finalQTable ((x, y), a) | a <- [minBound .. maxBound]])
  putStrLn "Running agent:"
  runAgent gridWorld finalQTable (0, 0) [minBound .. maxBound]
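Because QTable is a plain function, each call to trainStep wraps the previous table in one more closure: a lookup of any key other than the one just updated falls through every earlier update, and each newValue stays an unevaluated thunk until something demands it. For comparison, here is a minimal sketch of a hypothetical alternative that stores the table in a finite map, assuming the containers package (Data.Map.Strict) is available in your environment. QMap, lookupQ and updateQ are made-up names; only the update formula is taken from trainStep.

```haskell
import qualified Data.Map.Strict as Map

type Position = (Int, Int)

-- Ord is needed so (Position, Action) can be a Map key.
data Action = Up | Down | MoveLeft | MoveRight
  deriving (Show, Eq, Ord, Enum, Bounded)

-- Hypothetical Map-backed Q-table. Missing keys count as 0.0,
-- matching the question's initial table (\_ -> 0.0).
type QMap = Map.Map (Position, Action) Float

lookupQ :: QMap -> Position -> Action -> Float
lookupQ q s a = Map.findWithDefault 0.0 (s, a) q

-- Same rule as trainStep:
--   Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))
-- Data.Map.Strict.insert evaluates the Float before storing it,
-- so no chain of pending updates accumulates.
updateQ :: Float -> Float -> QMap -> Position -> Action -> Float -> Position -> QMap
updateQ alpha gamma q s a reward s' =
  let oldValue    = lookupQ q s a
      futureValue = maximum [lookupQ q s' a' | a' <- [minBound .. maxBound]]
      newValue    = oldValue + alpha * (reward + gamma * futureValue - oldValue)
  in Map.insert (s, a) newValue q
```

With alpha = 0.1 and gamma = 0.9, moving right from (0, 0) to (1, 0) with reward 1.0 on an empty table stores 0.1; repeating the same transition raises it to 0.19.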
Edit:
On line 42:
let (newPos, _) = step world prevPos (chooseAction q newPos actions)
The newPos in the call to chooseAction was supposed to be prevPos, so after the edit this line should look like:
let (newPos, _) = step world prevPos (chooseAction q prevPos actions)
However, this alone does not make the code run properly. After this change, the command times out instead. While debugging, I've basically been switching between two scenarios: 1) an infinite loop (<<loop>>), and 2) "Command timed out". I suspect the problem has something to do with lazy evaluation, but that is only a guess.
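Separately from the <<loop>>, runAgent as written recurses in both branches of its if and has no stopping condition, so once training finishes it keeps printing moves forever; in an online environment that by itself would most likely end in "Command timed out". Below is a minimal sketch of a bounded variant, assuming a maximum step count and an optional goal predicate are acceptable ways to stop. agentPath, runAgentFor, maxSteps and isGoal are hypothetical names; the types, step, isValidAction and chooseAction are condensed copies of the question's definitions so the block compiles on its own.

```haskell
import Data.List (maximumBy)
import Data.Ord (comparing)

type Position = (Int, Int)
type Reward = Float
type GridWorld = Position -> Reward
type QTable = (Position, Action) -> Reward
data Action = Up | Down | MoveLeft | MoveRight deriving (Eq, Enum, Bounded)

gridWorld :: GridWorld
gridWorld (x, y) = fromIntegral (x + y)

-- Same clamped movement on a 5x5 grid as the question's `step`.
step :: GridWorld -> Position -> Action -> (Position, Reward)
step world (x, y) action = (p, world p)
  where
    p = case action of
      Up        -> (x, max 0 (y - 1))
      Down      -> (x, min 4 (y + 1))
      MoveLeft  -> (max 0 (x - 1), y)
      MoveRight -> (min 4 (x + 1), y)

isValidAction :: Action -> Position -> Bool
isValidAction Up        (_, y) = y > 0
isValidAction Down      (_, y) = y < 4
isValidAction MoveLeft  (x, _) = x > 0
isValidAction MoveRight (x, _) = x < 4

-- Greedy choice among the moves that stay on the grid.
chooseAction :: QTable -> Position -> Action
chooseAction q pos =
  maximumBy (comparing (\a -> q (pos, a)))
            (filter (`isValidAction` pos) [minBound .. maxBound])

-- Hypothetical bounded walk: stops after maxSteps moves or at a goal cell.
agentPath :: Int -> (Position -> Bool) -> GridWorld -> QTable -> Position -> [Position]
agentPath maxSteps isGoal world q pos
  | maxSteps <= 0 || isGoal pos = [pos]
  | otherwise =
      let (next, _) = step world pos (chooseAction q pos)
      in pos : agentPath (maxSteps - 1) isGoal world q next

-- IO wrapper in the spirit of runAgent, but guaranteed to finish.
runAgentFor :: Int -> (Position -> Bool) -> GridWorld -> QTable -> Position -> IO ()
runAgentFor maxSteps isGoal world q start =
  mapM_ (\p -> putStrLn ("Position: " ++ show p))
        (agentPath maxSteps isGoal world q start)
```

In main, a call such as runAgentFor 20 (== (4, 4)) gridWorld finalQTable (0, 0) would replace the unbounded runAgent call. Keeping the walk itself pure (agentPath) makes it easy to test separately from the printing.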
My question is simple yet hard: how do I make my code work properly?
Answer:
The <<loop>> is probably due to this line:
let (newPos, _) = step world prevPos (chooseAction q newPos actions)
--   ^^^^^^                                          ^^^^^^
since newPos is recursively defined in terms of itself. Did you mean prevPos in the chooseAction call?
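To see why such a binding loops, here is a minimal, self-contained sketch of the same data flow on a one-dimensional track. choose, move and walk are hypothetical stand-ins for chooseAction, step and the foldl in trainEpisode, not code from the question. The working fold picks the move from the previous position; the commented-out walkBad mirrors the original line, where next appears on both sides of its own definition, so forcing it needs its own value first, which GHC's runtime typically reports as <<loop>>.

```haskell
data Move = GoLeft | GoRight deriving (Eq, Show)

-- Hypothetical stand-in for chooseAction: head right until the wall at 4.
choose :: Int -> Move
choose p = if p < 4 then GoRight else GoLeft

-- Hypothetical stand-in for step: clamp to the track [0, 4].
move :: Int -> Move -> Int
move p GoRight = min 4 (p + 1)
move p GoLeft  = max 0 (p - 1)

-- Correct fold: the move is chosen from the previous position.
walk :: Int -> Int -> Int
walk start n = foldl (\prev _ -> move prev (choose prev)) start [1 .. n]

-- Shape of the original bug (do not evaluate):
-- walkBad start n = foldl (\prev _ -> let next = move prev (choose next) in next) start [1 .. n]
-- `next` needs `choose next`, which needs `next` again.
```

Here walk 0 3 reaches 3, and walk 0 10 ends at 4 after bouncing between 3 and 4 at the wall.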