# Packages are attached in the same order as before, so masking is unchanged.
library("mdplearning")
library("appl")
library("pomdpplus")
library("tidyverse")
library("seewave") ## For KL divergence only

# Chunk defaults for this document: no caching, no messages in the output.
knitr::opts_chunk$set(cache = FALSE, message = FALSE)

# Use the black-and-white ggplot2 theme for every figure below.
theme_set(theme_bw())
Identify available solutions in the log that match the desired parameters
# Remote log in which previously computed solutions are looked up.
log_dir <- "https://raw.githubusercontent.com/cboettig/pomdp-solutions-library/master/tuna-100"

# Log entries matching the requested model/sigma_m, ordered by r.
meta <- arrange(
  appl::meta_from_log(data.frame(model = "ricker", sigma_m = 0.3), log_dir),
  r
)

# drop repeated values
meta <- meta[-c(4, 9), ]

# Print the remaining entries.
meta
Read in the POMDP problem specification from the log
# Index of the model that later chunks treat as the true one.
true_i <- 2

# The shared problem settings are read from the first log entry.
setup <- meta[1, ]
discount <- setup$discount
sigma_g <- setup$sigma_g
sigma_m <- setup$sigma_m

# Vector of all possible states; actions and observations reuse the same grid.
states <- seq(0, setup$max_state, length.out = setup$n_states)
actions <- states
obs <- states

# Reward is the elementwise minimum of its two arguments.
reward_fn <- function(x, h) pmin(x, h)

# One model per row of `meta`, named by that row's r value.
models <- models_from_log(meta, reward_fn)
## Not valid for non-uniform
## NOTE(review): original author's remark; what it refers to is not clear from
## this file.
names(models) <- meta$r

# Solutions for the same log entries, read from the log.
alphas <- alphas_from_log(meta, log_dir)
Reformat model solutions for use by the MDP functions as well:
# Pull the "transition" element out of every model, keeping the model names.
transitions <- lapply(models, function(m) m[["transition"]])
names(transitions) <- names(models)

# Reward and observation objects are taken from the first model only.
observation <- models[[1]]$observation
reward <- models[[1]]$reward
Compute the deterministic optimum solution:
# First element of the list returned by f_from_log(meta).
f <- f_from_log(meta)[[1]]

# S_star is the point in the range of `states` that minimises
# x / discount - f(x, 0).
S_star <- optimize(
  function(x) x / discount - f(x, 0),
  interval = range(states)
)$minimum

# Amount by which each state exceeds S_star, floored at zero.
h <- pmax(states - S_star, 0)

# Index of the action closest to each value of h.
# vapply() keeps the result an integer vector whatever the input; the lambda
# argument is renamed so it no longer shadows the outer `h`.
policy <- vapply(
  h,
  function(h_i) which.min(abs(h_i - actions)),
  integer(1)
)

# `value` and `state` are both the state index.
# NOTE(review): `value` looks like a placeholder so that `det` has the same
# columns as the other policy tables it is bound with -- confirm.
det <- data.frame(
  policy,
  value = seq_along(states),
  state = seq_along(states)
)
The following is evaluated only if the log contains intermediate policies:
# Intermediate solutions recorded in the log, one list element per model.
# NOTE(review): intermediates_from_log is reached with `:::`, i.e. it is not
# part of appl's exported interface.
inter <- appl:::intermediates_from_log(meta, log_dir = log_dir)

# Only the first model's intermediates are checked here. seq_along() is used
# below so that a later model with no intermediates contributes zero rows
# instead of failing on the c(1, 0) sequence that 1:length() would produce.
if (length(inter[[1]]) > 0) {
  df1 <- purrr::map_df(
    seq_along(models),
    function(j) {
      model_alphas <- inter[[j]]
      m <- models[[j]]
      # One policy table per intermediate solution of model j.
      purrr::map_df(
        seq_along(model_alphas),
        function(i) {
          compute_policy(model_alphas[[i]], m$transition, m$observation, m$reward)
        },
        .id = "intermediate"
      )
    },
    .id = "model_id"
  )

  # One panel per model, one line per intermediate solution.
  df1 %>%
    ggplot(aes(states[state], states[state] - actions[policy], col = intermediate)) +
    geom_line() +
    facet_wrap(~model_id, scales = "free") +
    coord_cartesian(ylim = c(0, 0.5))
}
# MDP policy with no prior supplied.
unif <- mdp_compute_policy(transitions, reward, discount)

# All prior weight on the first model.
prior <- replace(numeric(length(models)), 1, 1)
low <- mdp_compute_policy(transitions, reward, discount, prior)

# All prior weight on the true model.
prior <- replace(numeric(length(models)), true_i, 1)
true <- mdp_compute_policy(transitions, reward, discount, prior)

# Overlay the policies, including the deterministic one, coloured by source.
ggplot(
  bind_rows(unif = unif, low = low, true = true, det = det, .id = "model"),
  aes(states[state], states[state] - actions[policy], col = model)
) +
  geom_line()
Compare a uniform prior to individual cases:
# All prior weight on the first model.
prior <- replace(numeric(length(models)), 1, 1)
low <- compute_plus_policy(alphas, models, prior)

# All prior weight on the true model.
prior <- replace(numeric(length(models)), true_i, 1)
true <- compute_plus_policy(alphas, models, prior)

# No prior supplied.
unif <- compute_plus_policy(alphas, models) # e.g. 'planning only'

# All prior weight on the last model.
prior <- replace(numeric(length(models)), length(models), 1)
high <- compute_plus_policy(alphas, models, prior)

df <- dplyr::bind_rows(
  low = low,
  true = true,
  unif = unif,
  high = high,
  det = det,
  .id = "prior"
)

ggplot(
  df,
  aes(states[state], states[state] - actions[policy], col = prior, pch = prior)
) +
  # geom_point(alpha = 0.5, size = 3) +
  geom_line()
Historical catch and stock
set.seed(123)

# Map each historical value to the index of the nearest grid point.
data("scaled_data")
y <- sapply(scaled_data$y, function(obs_value) which.min(abs(states - obs_value)))
a <- sapply(scaled_data$a, function(act_value) which.min(abs(actions - act_value)))
Tmax <- length(y)

# scaling factor for data (1178363)
data("bluefin_tuna")
to_mt <- max(bluefin_tuna$total)
actions_mt <- to_mt * actions
states_mt <- to_mt * states

# Calendar years used to label the hindcast and the forecast.
year <- 1952:2009
future <- 2009:2067
# Hindcast over the historical observations and actions with the plus method.
plus_hindcast <- compare_plus(
  models = models,
  discount = discount,
  obs = y,
  action = a,
  alphas = alphas
)

# The same history run through the MDP method.
mdp_hindcast <- mdp_historical(
  transitions,
  reward,
  discount,
  state = y,
  action = a
)

# Give the MDP posterior the same names as the plus model posterior.
names(mdp_hindcast$posterior) <- names(plus_hindcast$model_posterior)
# Join the two hindcast tables on their shared columns (no `by` is given),
# convert the indices to the original scale and calendar years, then reshape
# to long form.
hindcast <- plus_hindcast$df %>%
  rename(plus = optimal, state = obs) %>%
  left_join(
    mdp_hindcast$df %>% rename(mdp = optimal)
  ) %>%
  mutate(
    "actual catch" = actions_mt[action],
    "estimated stock" = states_mt[state],
    plus = actions_mt[plus],
    mdp = actions_mt[mdp],
    time = year[time]
  ) %>%
  select(-state, -action) %>%
  gather(variable, stock, -time)

# One line per series.
ggplot(hindcast, aes(time, stock, color = variable)) +
  geom_line(lwd = 1) # + geom_point()
# delta function for true model distribution
# The index comes from `true_i` rather than a hard-coded 2, so this stays
# consistent with the chunks that already select the true model via `true_i`.
h_star <- array(0, dim = length(models))
h_star[true_i] <- 1

## Fn for the base-2 KL divergence from true model, in a friendly format
## (returns the second element of the list produced by seewave::kl.dist).
kl2 <- function(value) seewave::kl.dist(value, h_star, base = 2)[[2]]

# Stack both posteriors, add the calendar year for each row, and reshape to
# one row per (time, method, model); left grouped by time and method.
df <- bind_rows(
  plus = plus_hindcast$model_posterior,
  mdp = mdp_hindcast$posterior,
  .id = "method"
) %>%
  mutate(time = year[rep(seq_len(Tmax), 2)]) %>%
  gather(model, value, -time, -method) %>%
  group_by(time, method)

# KL divergence of each method's belief at each time step.
df %>%
  summarise(kl = kl2(value)) %>%
  ggplot(aes(time, kl, col = method)) +
  stat_summary(geom = "line", fun.y = mean, lwd = 1)
Show the final belief over models for pomdp and mdp:
# Final-row (time Tmax) belief over models from the plus hindcast.
plus_hindcast$model_posterior[Tmax, ] %>%
  as.numeric() %>%
  barplot()

# Final-row (time Tmax) belief over models from the MDP hindcast.
mdp_hindcast$posterior[Tmax, ] %>%
  as.numeric() %>%
  barplot()
posteriors <- bind_rows(
  plus = plus_hindcast$model_posterior,
  mdp = mdp_hindcast$posterior,
  .id = "method"
) %>%
  ## include a time column
  mutate(time = rep(seq_len(Tmax), 2)) %>%
  ## put into long orientation
  gather(model, value, -time, -method)

# Belief at the last time step; model names are converted back to numbers so
# they can sit on a numeric axis. The vertical line marks the true model's r.
ggplot(
  posteriors %>%
    filter(time == Tmax) %>%
    mutate(model = as.numeric(model)),
  aes(model, value, fill = method)
) +
  geom_bar(stat = "identity", position = "dodge") +
  geom_vline(xintercept = meta$r[true_i], lwd = 1) +
  labs(
    x = "r value",
    y = "probability density",
    title = "Final belief state"
  )
All forecasts start from final stock, go forward an equal length of time:
x0 <- y[length(y)] # Final stock
Tmax <- length(y)

# The forecasts below run through parallel::mclapply(). With the default RNG,
# forked workers do not inherit a reproducible stream from set.seed(); the
# "L'Ecuyer-CMRG" generator is what the parallel documentation prescribes.
# This switches the session's RNG kind for everything that follows.
# NOTE(review): with prescheduling, results can still depend on the number
# of cores, which is machine-specific here.
set.seed(123, kind = "L'Ecuyer-CMRG")

# Number of replicate simulations per method.
reps <- 40

# detectCores() may return NA, and mclapply() rejects mc.cores > 1 on
# Windows; fall back to a single core in both cases.
mc.cores <- parallel::detectCores()
if (is.na(mc.cores) || .Platform$OS.type == "windows") {
  mc.cores <- 1L
}
Note also that forecasts start with the prior belief over states and the prior belief over models that were determined from the historical data.
## Need to average over model
## NOTE(review): state_prior is computed but not passed on -- the matching
## sim_plus() argument below is commented out, and no other visible code
## reads it.
state_prior <- rowSums(plus_hindcast$state_posterior[length(y), , ])

# `reps` independent simulations, each starting from the model belief held at
# the end of the historical record.
plus_forecast <- parallel::mclapply(
  seq_len(reps),
  function(i) {
    sim_plus(
      models = models,
      discount = discount,
      model_prior = as.numeric(plus_hindcast$model_posterior[length(y), ]),
      # state_prior = as.numeric(),
      x0 = x0,
      Tmax = Tmax,
      true_model = models[[true_i]],
      alphas = alphas
    )
  },
  mc.cores = mc.cores
)
We simulate replicates under MDP learning (with observation uncertainty):
# Seed with the "L'Ecuyer-CMRG" generator: under the default RNG, set.seed()
# does not make mclapply() workers reproducible. This switches the session's
# RNG kind for everything that follows.
# NOTE(review): with prescheduling, results can still depend on `mc.cores`.
set.seed(123, kind = "L'Ecuyer-CMRG")

# `reps` independent MDP-learning simulations, each starting from the model
# belief held at the end of the historical record and run against the true
# model's transition, reward and observation objects.
mdp_forecast <- parallel::mclapply(
  seq_len(reps),
  function(i) {
    mdp_learning(
      transition = transitions,
      reward = models[[true_i]]$reward,
      model_prior = as.numeric(mdp_hindcast$posterior[length(y), ]),
      discount = discount,
      x0 = x0,
      Tmax = Tmax,
      true_transition = transitions[[true_i]],
      observation = models[[true_i]]$observation
    )
  },
  mc.cores = mc.cores
)
# Historical record, with columns renamed to match the simulation output.
historical <- bluefin_tuna %>%
  select(time = tsyear, state = total, action = catch_landings) %>%
  mutate(method = "historical")

# Stack the replicate tables of both methods, convert indices to the original
# scale and calendar years, append the historical record, and reshape to long
# form.
forecast <- bind_rows(
  plus = map_df(plus_forecast, function(sim) sim[["df"]], .id = "rep"),
  mdp = map_df(mdp_forecast, function(sim) sim[["df"]], .id = "rep"),
  .id = "method"
) %>%
  select(-value, -obs) %>%
  mutate(
    state = states_mt[state],
    action = actions_mt[action],
    time = future[time]
  ) %>%
  bind_rows(historical) %>%
  rename("catch" = action, "stock" = state) %>%
  gather(variable, stock, -time, -rep, -method)

# Mean line and +/- 1 SD ribbon per method, one panel per variable.
ggplot(forecast, aes(time, stock)) +
  # geom_line(aes(group = interaction(rep,method), color = method), alpha=0.2) +
  stat_summary(aes(color = method), geom = "line", fun.y = mean, lwd = 1) +
  stat_summary(
    aes(fill = method),
    geom = "ribbon",
    fun.data = mean_sdl,
    fun.args = list(mult = 1),
    alpha = 0.25
  ) +
  # geom_line(data = hindcast, mapping = aes(time, stock, color = variable)) +
  facet_wrap(~variable, ncol = 1, scales = "free_y")
# Keep only the two methods' hindcast series, labelled so they line up with
# the "catch" rows of the forecast table.
hindcast2 <- hindcast %>%
  rename(method = variable) %>%
  filter(method %in% c("mdp", "plus")) %>%
  mutate(variable = "catch", rep = "NA")

# Forecast and hindcast combined on their shared columns (no `by` is given).
df <- full_join(forecast, hindcast2)

# Mean line and +/- 1 SD ribbon per method, one panel per variable.
ggplot(df, aes(time, stock)) +
  stat_summary(aes(color = method), geom = "line", fun.y = mean, lwd = 1) +
  stat_summary(
    aes(fill = method),
    geom = "ribbon",
    fun.data = mean_sdl,
    fun.args = list(mult = 1),
    alpha = 0.25
  ) +
  facet_wrap(~variable, ncol = 1, scales = "free_y")
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.