predict_dt: Performs inference over a test dataset with a GBN

View source: R/dbn_inference.R

predict_dt    R Documentation

Description

This function performs inference over each row of a folded data.table, plots the results and reports accuracy metrics for the predictions. Because each prediction uses only a single row, the prediction horizon is at most 1. This function is also called by the generic predict method for "dbn.fit" objects. For long-term forecasting, see the forecast_ts function.
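To make the idea of a folded row concrete, here is a minimal base R sketch. It is an illustration under assumptions, not the implementation of fold_dt: the helper fold_sketch and the toy series are hypothetical, and the naming convention (t_0 for the slice being predicted, t_1 and later suffixes for earlier slices) is inferred from the pm_t_0 target used in the examples.

```r
# Hypothetical helper, NOT part of dbnR: fold one numeric series into rows
# that each hold `size` consecutive time slices.
fold_sketch <- function(x, size, name = "pm") {
  n <- length(x) - size + 1              # number of complete windows
  cols <- lapply(seq_len(size) - 1, function(lag) {
    # lag 0 is the newest value in each window, lag size - 1 the oldest
    x[(size - lag):(size - lag + n - 1)]
  })
  names(cols) <- paste0(name, "_t_", seq_len(size) - 1)
  as.data.frame(cols)
}

# Toy series; each resulting row is one self-contained prediction case
folded <- fold_sketch(c(10, 11, 12, 13, 14), size = 3)
print(folded)
```

Each row carries its own recent history, which is why a prediction made from one row cannot reach further than one step ahead.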

Usage

predict_dt(fit, dt, obj_nodes, verbose = T, look_ahead = F)

Arguments

fit

the fitted network: a "dbn.fit" object or, as in the examples, a Gaussian "bn.fit" object from bnlearn

dt

the test dataset. When fit is a DBN, this is the folded data.table returned by fold_dt

obj_nodes

the names of the nodes to predict. All of them are predicted at the same time

verbose

if TRUE, displays the metrics and plots the real values against the predictions

look_ahead

logical that defines whether the t_0 values of variables not listed in obj_nodes are used when predicting. Setting it to TRUE introduces look-ahead bias.
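The effect of look_ahead can be pictured as a choice of which columns of a folded row serve as evidence. The sketch below is an assumption about that selection logic, not the code inside predict_dt; the helper evidence_cols and the toy column names are hypothetical.

```r
# Hypothetical helper, NOT part of dbnR: pick the evidence columns of a folded row.
evidence_cols <- function(col_names, obj_nodes, look_ahead = FALSE) {
  candidates <- setdiff(col_names, obj_nodes)   # targets are never evidence
  if (!look_ahead) {
    # Drop the remaining t_0 columns so only past slices inform the prediction
    candidates <- candidates[!grepl("_t_0$", candidates)]
  }
  candidates
}

# Toy column names for a network with two variables and two time slices
cols <- c("pm_t_0", "torque_t_0", "pm_t_1", "torque_t_1")

past_only  <- evidence_cols(cols, obj_nodes = "pm_t_0")
with_ahead <- evidence_cols(cols, obj_nodes = "pm_t_0", look_ahead = TRUE)
print(past_only)    # past slices only
print(with_ahead)   # additionally includes torque_t_0
```

With look_ahead = TRUE the prediction sees values from the same instant it is trying to predict, which is appropriate only when those values would really be available at prediction time.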

Value

a data.table with the prediction results for each row
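The accuracy metrics are not named on this page. As a rough illustration of what comparing predictions with real values can involve, the base R sketch below computes the mean absolute error and the root mean squared error on toy vectors. The vectors real and pred are hypothetical stand-ins and do not reflect the column names of the returned data.table.

```r
# Toy stand-ins for observed and predicted values (hypothetical data)
real <- c(1.0, 2.0, 3.0, 4.0)
pred <- c(1.5, 1.5, 3.0, 5.0)

err  <- real - pred
mae  <- mean(abs(err))      # mean absolute error
rmse <- sqrt(mean(err^2))   # root mean squared error

cat("MAE:", mae, "\n")
cat("RMSE:", rmse, "\n")
```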

Examples

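library(dbnR)

size <- 3                    # number of time slices in the DBN
data(motor)
dt_train <- motor[200:900]   # training rows
dt_val <- motor[901:1000]    # held-out rows used as the test set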

# With a DBN: learn the structure, fold both datasets, then fit and predict
obj <- c("pm_t_0")   # target node in the t_0 time slice
net <- learn_dbn_struc(dt_train, size)
f_dt_train <- fold_dt(dt_train, size)
f_dt_val <- fold_dt(dt_val, size)
fit <- fit_dbn_params(net, f_dt_train, method = "mle-g")
res <- suppressWarnings(predict_dt(fit, f_dt_val, obj_nodes = obj, verbose = FALSE))

# With a Gaussian BN directly from bnlearn: no folding, so the target is "pm"
obj <- c("pm")
net <- bnlearn::mmhc(dt_train)
fit <- bnlearn::bn.fit(net, dt_train, method = "mle-g")
res <- suppressWarnings(predict_dt(fit, dt_val, obj_nodes = obj, verbose = FALSE))

dbnR documentation built on Oct. 5, 2022, 1:07 a.m.