R/gridworld.R

Defines functions gridworld_animate gridworld_plot_transition_graph gridworld_plot_policy gridworld_matrix gridworld_rc2s gridworld_s2rc gridworld_maze_MDP gridworld_init

Documented in gridworld_animate gridworld_init gridworld_matrix gridworld_maze_MDP gridworld_plot_policy gridworld_plot_transition_graph gridworld_rc2s gridworld_s2rc

#' Helper Functions for Gridworld MDPs
#'
#' Helper functions for gridworld MDPs to convert between state names and
#' gridworld positions, and for visualizing policies.
#'
#' Gridworlds are implemented with state names `s(row,col)`, where
#' `row` and `col` are locations in the matrix representing the gridworld.
#' The actions are `"up"`, `"right"`,  `"down"`, and  `"left"`.
#'
#' `gridworld_init()` initializes a new gridworld creating a matrix
#' of states with the given dimensions. Other action names
#' can be specified, but they must have the same effects in the same order
#' as above. Unreachable states (walls) and absorbing states can be defined.
#' This information can be used to build a custom gridworld MDP.
#'
#' Several helper functions are provided
#' to use states, look at the state layout, and plot policies on the
#' gridworld.
#'
#' `gridworld_maze_MDP()` helps to easily define maze-like gridworld MDPs.
#' By default, the goal state is absorbing, but with `restart = TRUE`, the
#' agent restarts the problem at the start state every time it reaches the goal
#' and receives the reward. Note that this implies that the goal state itself
#' becomes unreachable.
#' 
#' `gridworld_animate()` applies algorithms from [solve_MDP()] iteration 
#' by iteration and visualizes the state utilities. This helps to understand 
#' how the algorithms work.
#'
#' @name gridworld
#' @aliases gridworld
#' @family gridworld
#' @family MDP
#' @examples
#' # Defines states, actions and a transition model for a standard gridworld
#' gw <- gridworld_init(dim = c(7,7),
#'                 unreachable_states = c("s(2,2)", "s(7,3)", "s(3,6)"),
#'                 absorbing_states = "s(4,4)",
#'                 labels = list("s(4,4)" = "Black Hole")
#'                 )
#'
#' gw$states
#' gw$actions
#' gw$info
#'
#' # display the state labels in the gridworld
#' gridworld_matrix(gw)
#' gridworld_matrix(gw, what = "label")
#' gridworld_matrix(gw, what = "reachable")
#' gridworld_matrix(gw, what = "absorbing")
#'
#' # a transition function for regular moves in the gridworld is provided
#' gw$transition_prob("right", "s(1,1)", "s(1,2)")
#' gw$transition_prob("right", "s(2,1)", "s(2,2)")  ### we cannot move into an unreachable state
#' gw$transition_prob("right", "s(2,1)", "s(2,1)")  ### but the agent stays in place
#'
#' # convert between state names and row/column indices
#' gridworld_s2rc("s(1,1)")
#' gridworld_rc2s(c(1,1))
#'
#' # The information in gw can be used to build a custom MDP.
#'
#' # We modify the standard transition function so there is a 50% chance that
#' # you will get sucked into the black hole from the adjacent squares.
#' trans_black_hole <- function(action = NA, start.state = NA, end.state = NA) {
#'   # ignore the action next to the black hole
#'   if (start.state %in% c("s(3,3)", "s(3,4)", "s(3,5)", "s(4,3)", "s(4,5)",
#'                          "s(5,3)", "s(5,4)", "s(5,5)")) {
#'         if(end.state == "s(4,4)")
#'             return(.5)
#'         else
#'             return(gw$transition_prob(action, start.state, end.state) * .5)
#'   }
#'
#'   # use the standard gridworld movement
#'   gw$transition_prob(action, start.state, end.state)
#' }
#'
#' black_hole <- MDP(states = gw$states,
#'   actions = gw$actions,
#'   transition_prob = trans_black_hole,
#'   reward = rbind(R_(value = +1), R_(end.state = "s(4,4)", value = -100)),
#'   info = gw$info,
#'   name = "Black hole"
#'   )
#'
#' black_hole
#'
#' gridworld_plot_transition_graph(black_hole)
#'
#' # solve the problem
#' sol <- solve_MDP(black_hole)
#' gridworld_matrix(sol, what = "values")
#' gridworld_plot_policy(sol)
#' # the optimal policy is to fly around, but avoid the black hole.
#'
#' # Build a Maze: The Dyna Maze from Chapter 8 in the RL book
#'
#' Dyna_maze <- gridworld_maze_MDP(
#'                 dim = c(6,9),
#'                 start = "s(3,1)",
#'                 goal = "s(1,9)",
#'                 walls = c("s(2,3)", "s(3,3)", "s(4,3)",
#'                           "s(5,6)",
#'                           "s(1,8)", "s(2,8)", "s(3,8)"),
#'                 restart = TRUE,
#'                 discount = 0.95,
#'                 name = "Dyna Maze",
#'                 )
#' Dyna_maze
#'
#' gridworld_matrix(Dyna_maze)
#' gridworld_matrix(Dyna_maze, what = "labels")
#'
#' gridworld_plot_transition_graph(Dyna_maze)
#' # Note that the problem resets when the goal state is reached.
#'
#' sol <- solve_MDP(Dyna_maze)
#'
#' gridworld_matrix(sol, what = "values")
#' gridworld_matrix(sol, what = "actions")
#' gridworld_plot_policy(sol)
#' gridworld_plot_policy(sol, actions = "label", cex = 1, states = FALSE)
#' 
#' # visualize the first 3 iterations of value iteration
#' gridworld_animate(Dyna_maze, method = "value", n = 3)
#' @param dim vector of length two with the number of rows and columns of the gridworld.
#' @param action_labels vector with four action labels that move the agent up, right, down,
#'   and left.
#' @param unreachable_states a vector with state labels for unreachable states.
#'     These states will be excluded.
#' @param absorbing_states a vector with state labels for absorbing states.
#' @export
gridworld_init <-
  function(dim,
           action_labels =  c("up", "right", "down", "left"),
           unreachable_states = NULL,
           absorbing_states = NULL,
           labels = NULL) {
    # Creates the building blocks of a gridworld MDP: the state names, the
    # action names, a deterministic transition function, a default reward of 0
    # and an info list holding the grid dimensions and the state labels.
    #
    # dim                 c(<rows>, <columns>) of the grid.
    # action_labels       names for the moves up, right, down, left (in this
    #                     order).
    # unreachable_states  state names that are removed from the state space.
    # absorbing_states    state names the agent cannot leave.
    # labels              stored unchanged as info$gridworld_labels.
    if (!is.numeric(dim) || length(dim) != 2L || anyNA(dim) || any(dim < 1))
      stop("'dim' needs to be a vector of length two with the number of rows and columns.")

    # state names s(row,col) in column-major order
    S <- as.vector(outer(
      seq_len(dim[1]),
      seq_len(dim[2]),
      FUN = function(x, y)
        paste0("s(", x, ",", y, ")")
    ))

    if (!is.null(unreachable_states))
      S <- setdiff(S, unreachable_states)

    # Returns 1 if `action` taken in `start.state` leads to `end.state` and 0
    # otherwise. The closure uses S, action_labels and absorbing_states from
    # the enclosing call. (Not called `T` so the alias for TRUE is not masked.)
    transition_prob <- function(action,
                                start.state = NULL,
                                end.state = NULL) {
      ai <- pmatch(action, action_labels)

      # stay in place for unknown actions
      if (is.na(ai))
        return(as.integer(end.state == start.state))

      # absorbing states cannot be left
      if (!is.null(absorbing_states) &&
          start.state %in% absorbing_states)
        return(as.integer(end.state == start.state))

      rc <- gridworld_s2rc(start.state)
      target <- switch(ai,
                       rc + c(-1L, 0L),   # up
                       rc + c(0L, +1L),   # right
                       rc + c(+1L, 0L),   # down
                       rc + c(0L, -1L))   # left

      # switch() returns NULL for a label beyond the fourth one; such an
      # action has no movement defined and the agent stays in place.
      es <- if (is.null(target)) start.state else gridworld_rc2s(target)

      # moving off the grid or into an unreachable state is not possible
      if (!(es %in% S))
        es <- start.state

      as.integer(es == end.state)
    }

    list(
      states = S,
      actions = action_labels,
      transition_prob = transition_prob,
      reward = rbind(R_(value = 0)),
      info = list(gridworld_dim = dim, gridworld_labels = labels)
    )
  }

#' @rdname gridworld
#' @param start,goal labels for the start state and the goal state.
#' @param walls a vector with state labels for walls. Walls will
#'              become unreachable states.
#' @param goal_reward reward to transition to the goal state.
#' @param step_cost cost of each action that does not lead to the goal state.
#' @param restart logical; if `TRUE` then the problem automatically restarts when
#'      the agent reaches the goal state.
#' @param discount,horizon MDP discount factor, and horizon.
#' @param info A list with additional information. Has to contain the gridworld
#'      dimensions as element `gridworld_dim`.
#' @param name a string to identify the MDP problem.
#' @export
gridworld_maze_MDP <- function(dim,
                               start,
                               goal,
                               walls = NULL,
                               action_labels =  c("up", "right", "down", "left"),
                               goal_reward = 1,
                               step_cost = 0,
                               restart = FALSE,
                               discount = 0.9,
                               horizon = Inf,
                               info = NULL,
                               name = NA) {
  # Builds a maze-like gridworld MDP. Walls become unreachable states and the
  # goal states are absorbing. With restart = TRUE an extra action "restart"
  # is added and every transition out of a goal state leads back to a start
  # state.
  #
  # NOTE(review): `info` is accepted but not used; the MDP always gets the
  # info list created by gridworld_init(). Confirm whether it should be merged.

  # `start` may be a vector of state names, so it must not be compared with
  # == / != directly inside if ().
  uniform_start <- isTRUE(start == "uniform")

  # one list entry per state (as.list), so several goal/start states work
  label_states <- function(label, states)
    structure(as.list(rep_len(label, length(states))), names = states)

  state_labels <- label_states("Goal", goal)

  # add the start labels to the labels that are passed on to gridworld_init()
  if (is.character(start) && !uniform_start)
    state_labels <- c(state_labels, label_states("Start", start))

  gw <-
    gridworld_init(
      dim,
      action_labels = action_labels,
      unreachable_states = walls,
      absorbing_states = goal,
      labels = state_labels
    )

  if (!restart)
    return(
      MDP(
        states = gw$states,
        actions = gw$actions,
        transition_prob = gw$transition_prob,
        reward = rbind(
          R_(value = -step_cost),
          R_(end.state = goal, value = goal_reward)
        ),
        discount = discount,
        horizon = horizon,
        start = start,
        info = gw$info,
        name = name
      )
    )

  # with restarts. We add the action restart which is only available in the goal state.
  if (uniform_start)
    start <- gw$states

  # redirect to the start state
  trans_restart <- function(action = NA,
                            start.state = NA,
                            end.state = NA) {
    # from a goal state every action leads to one of the start states
    if (start.state %in% goal) {
      if (end.state %in% start)
        return(1 / length(start))
      else
        return(0)
    }

    # regular move
    gw$transition_prob(action, start.state, end.state)
  }

  # note the goal state is now unreachable
  # Later reward rows refine earlier ones: "restart" is penalized with -Inf
  # everywhere, every action in a goal state is penalized with -Inf, and only
  # "restart" in a goal state is free.
  MDP(
    states = gw$states,
    actions = c(gw$actions, "restart"),
    transition_prob = trans_restart,
    reward = rbind(
      R_(value = -step_cost),
      R_(action = "restart", value = -Inf),
      R_(end.state = goal, value = goal_reward),
      R_(start.state = goal, value = -Inf),
      R_(start.state = goal, action = "restart", value = 0)
    ),
    discount = discount,
    horizon = horizon,
    start = start,
    info = gw$info,
    name = name
  )
}

#' @rdname gridworld
#' @param model,x a solved gridworld MDP.
#' @param s a state label.
#' @param rc a vector of length two with the row and column coordinate of a
#'   state in the gridworld matrix.
#' @export
gridworld_s2rc <- function(s) {
  # Converts a state label of the form "s(<row>,<col>)" into the integer
  # vector c(row, col). Only the first element of `s` is parsed.
  tokens <- strsplit(s, "s\\(|,|\\)")[[1]]

  # the first token is the text in front of "s(" and is dropped
  coords <- as.integer(tokens[-1])

  well_formed <- length(coords) == 2 && !any(is.na(coords))
  if (well_formed)
    return(coords)

  stop("Malformed gridworld state label ",
       sQuote(s),
       ". Needs to be 's(<row>,<col>)'.")
}

#' @rdname gridworld
#' @export
gridworld_rc2s <- function(rc) {
  # Builds the state label "s(<row>,<col>)" from the first two elements of rc.
  row <- rc[1]
  col <- rc[2]
  paste0("s(", row, ",", col, ")")
}

#' @rdname gridworld
#' @param what What should be returned in the matrix. Options are:
#'  `"states"`, `"labels"`, `"values"`, `"actions"`, `"absorbing"`, and
#'  `"reachable"`.
#' @export
gridworld_matrix <- function(model,
                             epoch = 1L,
                             what = "states") {
  # Arranges information about the states of a gridworld in a matrix with the
  # layout of the gridworld. Cells that are not states of the model (walls)
  # get the fill value of the requested `what`.
  what <- match.arg(what,
                    c(
                      "states",
                      "labels",
                      "values",
                      "actions",
                      "absorbing",
                      "reachable"
                    ))

  if (length(model$info$gridworld_dim) < 2L)
    stop("'model' is not a gridworld. The element info$gridworld_dim is missing.")

  nrows <- model$info$gridworld_dim[1]
  ncols <- model$info$gridworld_dim[2]

  # names of all cells of the grid, including the ones missing in the model
  all_states <- gridworld_init(dim = c(nrows, ncols))$states

  # named vector with one entry per grid cell
  grid_vector <- function(fill)
    structure(rep(fill, length(all_states)), names = all_states)

  x <- switch(
    what,
    states = {
      l <- grid_vector(NA_character_)
      l[model$states] <- model$states
      l
    },

    labels = {
      l <- grid_vector("")
      # reachable_states() is aligned with model$states (see the "reachable"
      # branch), so it has to be translated into state names before it can
      # index the full grid. A bare logical index would be recycled.
      l[model$states[!reachable_states(model)]] <- "X"
      l[!(all_states %in% model$states)] <- "X"

      # labels for states outside the grid would extend the vector
      labels <- model$info$gridworld_labels
      labels <- labels[names(labels) %in% all_states]
      if (length(labels) > 0L)
        l[names(labels)] <- unlist(labels)

      l
    },

    values = {
      l <- grid_vector(NA_real_)
      p <- policy(model, drop = FALSE)[[epoch]]
      # index by name; as.character() guards against a factor state column
      l[as.character(p$state)] <- p$U
      l
    },

    actions = {
      l <- grid_vector(NA_character_)
      p <- policy(model, drop = FALSE)[[epoch]]
      l[as.character(p$state)] <- as.character(p$action)
      l
    },

    absorbing = {
      # cells that are not part of the model count as absorbing
      l <- grid_vector(TRUE)
      l[model$states] <- absorbing_states(model)
      l
    },

    reachable = {
      # cells that are not part of the model count as not reachable
      l <- grid_vector(FALSE)
      l[model$states] <- reachable_states(model)
      l
    }
  )

  matrix(x, nrow = nrows, ncol = ncols)
}

#' @rdname gridworld
#' @param epoch epoch for unconverged finite-horizon solutions.
#' @param actions how to show actions. Options are: 
#'  simple `"character"`, `"unicode"` arrows (needs to be supported by the used font), 
#'  `"label"` of the action, and  `"none"` to suppress showing the action.
#' @param states logical; show state names.
#' @param labels logical; show state labels. For `gridworld_init()`, a named
#'   list with labels for states (the names are the state names).
#' @param absorbing_state_action logical; show the value and the action for absorbing states.
#' @param main main title for the plot. If `NULL`, a title is created from the model name.
#' @param cex expansion factor for the action.
#' @param offset move the state labels out of the way (in fractions of a character width).
#' @param lines logical; draw lines to separate states.
#' @param ... further arguments are passed on to [graphics::image()].
#' @importFrom graphics image text box abline
#' @export
gridworld_plot_policy <-
  function(model,
           epoch = 1L,
           actions = "character",
           states = FALSE,
           labels = TRUE,
           absorbing_state_action = FALSE,
           main = NULL,
           cex = 1,
           offset = .5,
           lines = TRUE,
           ...) {
    # Draws the state values of a solved gridworld as an image and writes the
    # policy's action (and optionally the state names and labels) into the
    # cells.
    actions <- match.arg(actions, c("character", "unicode", "label", "none"))

    if (is.null(main))
      main <- paste("Policy:", model$name, paste0("(", model$solution$method, ")"))

    U <-
      gridworld_matrix(model, epoch = epoch, what = "values")

    # image() draws the first matrix row at the bottom, so the matrix is
    # transposed and the row order is reversed to show row 1 at the top.
    image(t(U)[, rev(seq_len(nrow(U))), drop = FALSE], main = main, axes = FALSE, ...)

    nrows <- model$info$gridworld_dim[1]
    ncols <- model$info$gridworld_dim[2]

    if (lines) {
      box()
      # Cell centers are spread evenly over [0, 1]; the borders lie halfway
      # between them. A single row/column has no inner borders (and would
      # lead to a division by zero).
      if (nrows > 1L) {
        frac <- 1 / ((nrows - 1L) * 2)
        abline(h = ((seq_len(nrows) * 2L) - 1L) * frac)
      }
      if (ncols > 1L) {
        frac <- 1 / ((ncols - 1L) * 2)
        abline(v = ((seq_len(ncols) * 2L) - 1L) * frac)
      }
    }

    # coordinates of the cell centers in column-major order (matches
    # as.vector() of the matrices below); a single row/column sits at 0
    ys <- if (nrows > 1L) (nrows - 1L):0 / (nrows - 1L) else 0
    xs <- if (ncols > 1L) 0:(ncols - 1L) / (ncols - 1L) else 0
    g <- expand.grid(y = ys, x = xs)

    g$state <-
      as.vector(gridworld_matrix(model, what = "states"))
    g$labels <-
      as.vector(gridworld_matrix(model, what = "labels"))
    g$actions <-
      as.vector(gridworld_matrix(model, epoch = epoch, what = "actions"))

    if (!absorbing_state_action)
      g$actions[gridworld_matrix(model, what = "absorbing")] <- NA

    # state names go above the cell center
    if (states)
      text(
        g$x,
        g$y,
        g$state,
        pos = 3,
        offset = offset,
        cex = .5 * cex
      )

    # state labels go below the cell center; the full column name is used
    # instead of relying on partial matching of `$`
    if (labels)
      text(
        g$x,
        g$y,
        g$labels,
        pos = 1,
        offset = offset,
        cex = .5 * cex
      )

    # translate the action names into the requested representation; actions
    # other than the four standard moves become NA for the arrow styles
    g$actions <- switch(
      actions,
      character = as.character(factor(
        g$actions,
        levels = c("up", "right", "down", "left"),
        labels = c("^", ">", "v", "<")
      )),

      unicode = as.character(factor(
        g$actions,
        levels = c("up", "right", "down", "left"),
        labels = c("\U2191", "\U2192", "\U2193", "\U2190")
      )),

      label = g$actions,

      none = NA
    )

    text(g$x, g$y, g$actions, cex = cex)
  }

#' @rdname gridworld
#' @param hide_unreachable_states logical; do not show unreachable states.
#' @param remove.loops logical; do not show transitions from a state back to itself.
#' @param vertex.color,vertex.shape,vertex.size,vertex.label,edge.arrow.size
#'  see `igraph::igraph.plotting` for details. Set `vertex.label = NULL` to show the
#'  state labels on the graph.
#' @param margin a single number specifying the margin of the plot. Can be used if the
#'   graph does not fit inside the plotting area.
#' @param main a main title for the plot. Defaults to the name of the problem.
#' @param ... further arguments are passed on to `igraph::plot.igraph()`.
#' @export
gridworld_plot_transition_graph <-
  function(x,
           hide_unreachable_states = TRUE,
           remove.loops = TRUE,
           vertex.color = "gray",
           vertex.shape = "square",
           vertex.size = 10,
           vertex.label = NA,
           edge.arrow.size = .3,
           margin = .2,
           main = NULL,
           ...) {
    # Plots the transition graph of a gridworld with the vertices placed at
    # the grid positions of their states. Returns the plotted graph invisibly.
    g <- transition_graph(x)

    # One row c(row, col) per state. vapply() guarantees a 2 x n result, so a
    # model with a single state still produces a matrix.
    rc <- t(vapply(x$states, gridworld_s2rc, integer(2L)))

    # x is the column; y is the negated row so that row 1 ends up at the top
    layout <- cbind(rc[, 2L], -rc[, 1L])

    if (hide_unreachable_states) {
      reachable <- reachable_states(x)
      g <- induced_subgraph(g, V(g)[reachable])
      layout <- layout[reachable, , drop = FALSE]
    }

    V(g)$color <- vertex.color
    V(g)$shape <- vertex.shape
    V(g)$size <- vertex.size

    if (remove.loops)
      g <- igraph::simplify(g, remove.loops = TRUE)

    # keep the cells square: the x range is scaled by columns / rows
    asp <- x$info$gridworld_dim[2] / x$info$gridworld_dim[1]

    if (is.null(main))
      main <- paste("Transition Graph:", x$name)

    plot(
      g,
      layout = norm_coords(layout, xmin = -asp, xmax = asp),
      rescale = FALSE,
      xlim = c(-asp * (1 + margin), asp *  (1 + margin)),
      ylim = c(-(1 + margin),  (1 + margin)),
      edge.arrow.size = edge.arrow.size,
      vertex.label = vertex.label,
      edge.label = NA,
      main = main,
      ...
    )

    invisible(g)
  }

#' @rdname gridworld
#' @param n number of iterations to animate.
#' @param method a MDP solution method for [solve_MDP()].
#' @param zlim limits for visualizing the state value.
#' @export
gridworld_animate <- function(x,
                              method,
                              n,
                              zlim = NULL,
                              ...) {
  # Runs solve_MDP() one iteration at a time, feeding the value function of
  # each run into the next one, and plots the policy after every iteration.
  # `...` is passed on to solve_MDP(). Called for its plots; returns NULL
  # invisibly.
  U <- NULL

  # seq_len() performs no iteration for n = 0 (seq(0) would yield c(1, 0))
  for (i in seq_len(n)) {
    # warnings are suppressed since a single iteration is not expected to
    # converge -- TODO confirm that this is the only warning hidden here
    sol <- suppressWarnings(solve_MDP(
      x,
      N = 1,
      U = U,
      method = method,
      ...
    ))
    U <- value_function(sol)

    # zlim is only passed on when given, so the plotting default stays in
    # effect otherwise
    if (!is.null(zlim))
      gridworld_plot_policy(
        sol,
        main = sol$solution$method,
        sub = paste("Iteration", i),
        zlim = zlim
      )
    else
      gridworld_plot_policy(
        sol,
        main = sol$solution$method,
        sub = paste("Iteration", i)
      )
  }

  invisible(NULL)
}

Try the pomdp package in your browser

Any scripts or data that you put into this service are public.

pomdp documentation built on May 29, 2024, 2:04 a.m.