View source: R/dbn_inference.R
predict_dt (R Documentation)
This function performs inference over each row of a folded data.table, plots the results and reports metrics of the prediction accuracy. Because each prediction uses only the information contained in a single row, the prediction horizon is at most 1. This function is also called by the generic predict method for "dbn.fit" objects. For long-term forecasting, use the forecast_ts function instead.
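To make "folded data.table" concrete, the base-R sketch below builds lagged copies of each column so that a single row holds a window of `size` consecutive time steps. It is not dbnR's fold_dt: the helper name fold_sketch is hypothetical, and the naming scheme (`_t_0` for the current step, `_t_k` for k steps earlier) is assumed to mirror the pm_t_0 node used in the examples.

```r
# Hypothetical helper, NOT dbnR's fold_dt: it only illustrates the idea of
# folding. Each output row holds `size` consecutive time steps.
# Assumed naming: "<var>_t_0" is the current step, "<var>_t_k" is k steps back.
fold_sketch <- function(df, size) {
  n <- nrow(df)
  stopifnot(size >= 1, n >= size)
  rows <- size:n                      # the first size - 1 rows lack a full history
  out <- list()
  for (k in 0:(size - 1)) {
    for (v in names(df)) {
      # value of variable v, k steps before each target row
      out[[paste0(v, "_t_", k)]] <- df[[v]][rows - k]
    }
  }
  as.data.frame(out)
}

toy <- data.frame(pm = c(10, 11, 12, 13), speed = c(1, 2, 3, 4))
fold_sketch(toy, 3)
```

Because a folded row only reaches back size - 1 steps, predicting its t_0 columns from the older slices is a one-step-ahead prediction, which is consistent with the horizon of at most 1 stated above.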
predict_dt(fit, dt, obj_nodes, verbose = TRUE, look_ahead = FALSE)
| Argument | Description |
|---|---|
| fit | the fitted network: a "dbn.fit" object or a Gaussian "bn.fit" object from bnlearn |
| dt | the test dataset; for a DBN, a dataset folded with fold_dt |
| obj_nodes | the nodes to be predicted. All of them are predicted at the same time |
| verbose | if TRUE, prints the metrics and plots the real values against the predictions |
| look_ahead | logical; whether the values of the variables in t_0 should be used as evidence when predicting, even if they are not in obj_nodes. Setting it to TRUE introduces look-ahead bias |
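The effect of look_ahead on the evidence can be sketched as a column filter over a folded dataset. The helper evidence_cols below is hypothetical and not part of dbnR; it assumes the `_t_0` suffix marks the current time slice, as in the pm_t_0 node from the examples.

```r
# Hypothetical helper illustrating the look_ahead argument (not part of dbnR).
# Returns the columns of a folded dataset that would serve as evidence
# when predicting obj_nodes.
evidence_cols <- function(col_names, obj_nodes, look_ahead = FALSE) {
  ev <- setdiff(col_names, obj_nodes)   # objective nodes are never evidence
  if (!look_ahead) {
    # Without look-ahead, drop every value observed at t_0: it would not be
    # known yet at prediction time.
    ev <- ev[!grepl("_t_0$", ev)]
  }
  ev
}

cols <- c("pm_t_0", "speed_t_0", "pm_t_1", "speed_t_1")
evidence_cols(cols, "pm_t_0")                     # returns c("pm_t_1", "speed_t_1")
evidence_cols(cols, "pm_t_0", look_ahead = TRUE)  # returns c("speed_t_0", "pm_t_1", "speed_t_1")
```

With look_ahead = TRUE, speed_t_0 is available as evidence for pm_t_0 even though both describe the same instant, which is exactly the look-ahead bias described in the table.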
Returns a data.table with the prediction results for each row of dt.
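Predicting all obj_nodes "at the same time" means they are inferred jointly from the same evidence. One standard way to do exact inference in a Gaussian network is to view it as a multivariate normal and condition on the evidence: E[x_a | x_b] = mu_a + Sigma_ab Sigma_bb^{-1} (x_b - mu_b). The base-R sketch below implements that formula; cond_gauss_mean is hypothetical, and it is not claimed to be how predict_dt works internally.

```r
# Conditional mean of jointly Gaussian variables (illustrative sketch):
#   E[x_a | x_b] = mu_a + Sigma_ab %*% solve(Sigma_bb) %*% (x_b - mu_b)
# mu: named mean vector; sigma: covariance matrix with matching dimnames;
# obj: names of the nodes to predict; ev_values: named vector of evidence.
cond_gauss_mean <- function(mu, sigma, obj, ev_values) {
  ev <- names(ev_values)
  stopifnot(all(obj %in% names(mu)), all(ev %in% names(mu)))
  s_ab <- sigma[obj, ev, drop = FALSE]   # cross-covariance targets/evidence
  s_bb <- sigma[ev, ev, drop = FALSE]    # covariance of the evidence
  delta <- ev_values - mu[ev]            # deviation of evidence from its mean
  adj <- drop(s_ab %*% solve(s_bb, delta))
  mu[obj] + adj                          # all targets predicted jointly
}

mu <- c(a = 0, b = 1)
sigma <- matrix(c(1, 0.5, 0.5, 2), 2, dimnames = list(c("a", "b"), c("a", "b")))
cond_gauss_mean(mu, sigma, "a", c(b = 3))  # 0 + (0.5 / 2) * (3 - 1) = 0.5
```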
library(dbnR)
size <- 3
data(motor)
dt_train <- motor[200:900]
dt_val <- motor[901:1000]

# With a DBN: fold the datasets so that each row holds `size` time slices
obj <- c("pm_t_0")
net <- learn_dbn_struc(dt_train, size)
f_dt_train <- fold_dt(dt_train, size)
f_dt_val <- fold_dt(dt_val, size)
fit <- fit_dbn_params(net, f_dt_train, method = "mle-g")
res <- suppressWarnings(predict_dt(fit, f_dt_val, obj_nodes = obj, verbose = FALSE))

# With a Gaussian BN learned directly with bnlearn (no folding)
obj <- c("pm")
net <- bnlearn::mmhc(dt_train)
fit <- bnlearn::bn.fit(net, dt_train, method = "mle-g")
res <- suppressWarnings(predict_dt(fit, dt_val, obj_nodes = obj, verbose = FALSE))
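This page does not say which accuracy metrics are reported when verbose = TRUE. As an assumption for illustration only, the sketch below computes the mean absolute error (MAE) and the standard deviation of the absolute error between real and predicted values; accuracy_sketch is a hypothetical helper, and predict_dt's own output may use different metrics.

```r
# Hypothetical accuracy summary; the metrics actually printed by predict_dt
# may differ. `real` and `pred` are numeric vectors of equal length.
accuracy_sketch <- function(real, pred) {
  stopifnot(is.numeric(real), is.numeric(pred), length(real) == length(pred))
  abs_err <- abs(real - pred)
  c(mae = mean(abs_err), sd_abs_err = sd(abs_err))
}

accuracy_sketch(real = c(1, 2, 3, 4), pred = c(1.5, 2, 2, 4))
```

Reporting the spread of the absolute error next to its mean helps distinguish a model that is consistently slightly off from one that is usually exact but occasionally far off.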