Description:
Base class for all estimators. Defines the minimal set of members and methods (with signatures) that have to be implemented in child classes.
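The description implies a simple contract: the base class names `fit()` and `predict()`, and each child class supplies the actual behavior. Below is a minimal sketch of that pattern in plain R6. It assumes nothing about mlapi's real implementation, and the classes `HypotheticalEstimator` and `HypotheticalMean` are illustrative only. The base stubs stop with an error, so a child that forgets to override a method fails loudly at call time instead of silently returning nothing.

```r
library(R6)

# Hypothetical abstract base class (illustration only, not mlapi's source).
HypotheticalEstimator = R6Class(
  classname = "HypotheticalEstimator",
  public = list(
    # Child classes are expected to override fit(); the stub signals the omission.
    fit = function(x, y, ...) {
      stop("fit() must be implemented in a child class")
    },
    # Same idea for predict().
    predict = function(x, ...) {
      stop("predict() must be implemented in a child class")
    }
  )
)

# Hypothetical child class: always predicts the mean of the training target.
HypotheticalMean = R6Class(
  classname = "HypotheticalMean",
  inherit = HypotheticalEstimator,
  public = list(
    fit = function(x, y, ...) {
      stopifnot(nrow(x) == length(y))
      private$mu = mean(y)
      invisible(self)
    },
    predict = function(x, ...) {
      rep(private$mu, nrow(x))
    }
  ),
  private = list(mu = NULL)
)

m = HypotheticalMean$new()
m$fit(matrix(1:6, ncol = 2), c(1, 2, 3))
m$predict(matrix(0, nrow = 2, ncol = 2))  # returns c(2, 2)
```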
Format:
An R6Class object.
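Because the estimator is an R6 object, it has reference semantics: a method such as `fit()` can update the object's private fields in place, which is why the example calls `model$fit(...)` without reassigning the result. A minimal illustration with a hypothetical `Counter` class (not part of mlapi):

```r
library(R6)

# Hypothetical class used only to illustrate R6 reference semantics.
Counter = R6Class(
  classname = "Counter",
  public = list(
    n = 0,
    add = function() {
      self$n = self$n + 1  # mutates the object itself; no copy is made
      invisible(self)
    }
  )
)

a = Counter$new()
b = a   # b refers to the same object as a, not a copy
a$add()
b$n     # 1: the change made through a is visible through b
```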
Methods:
$fit(x, y, ...)
$predict(x, ...)
Makes predictions on new data (after the model has been trained).
Arguments:
x: a matrix-like object that should inherit from Matrix or matrix. The allowed classes should be defined in child classes.
y: the target, usually a vector but possibly a matrix-like object. The allowed classes should be defined in child classes.
...: additional parameters with default values.
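The argument contract above could be enforced with a check along these lines. `check_estimator_input` is a hypothetical helper, not part of mlapi: it accepts a base `matrix` or any object inheriting from the `Matrix` class for `x`, and a vector or matrix-like `y` with one entry (or row) per row of `x`. Narrowing the allowed classes further is left to each child class, as the documentation says.

```r
# Hypothetical helper (not part of mlapi) that checks inputs against the
# documented contract; child classes would narrow the allowed classes further.
check_estimator_input = function(x, y = NULL) {
  # inherits() also considers S4 superclasses, so classes from the
  # Matrix package (e.g. dgCMatrix) pass the "Matrix" test.
  if (!(inherits(x, "matrix") || inherits(x, "Matrix")))
    stop("x must inherit from 'matrix' or 'Matrix'")
  if (!is.null(y)) {
    # y is usually a vector but may also be matrix-like.
    if (!(is.vector(y) || is.matrix(y) || inherits(y, "Matrix")))
      stop("y must be a vector or a matrix-like object")
    # NROW() gives the length of a vector or the row count of a matrix.
    if (NROW(y) != nrow(x))
      stop("y must have one entry (or row) per row of x")
  }
  invisible(TRUE)
}

check_estimator_input(matrix(1:6, ncol = 2), c(0, 1, 0))  # passes silently
try(check_estimator_input(data.frame(a = 1:3)))           # error: not a matrix
```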
Examples:
SimpleLinearModel = R6::R6Class(
  classname = "mlapiSimpleLinearModel",
  inherit = mlapi::mlapiEstimation,
  public = list(
    initialize = function(tol = 1e-7) {
      private$tol = tol
      # declare the internal formats this model works with:
      # a dense base matrix, no sparse format
      super$set_internal_matrix_formats(dense = "matrix", sparse = NULL)
    },
    fit = function(x, y, ...) {
      # convert the input (e.g. a data.frame) to the declared internal format
      x = super$check_convert_input(x)
      stopifnot(is.vector(y))
      stopifnot(is.numeric(y))
      stopifnot(nrow(x) == length(y))
      private$n_features = ncol(x)
      # least-squares fit on x as given (no intercept column is added)
      private$coefficients = .lm.fit(x, y, tol = private$tol)[["coefficients"]]
    },
    # signature matches the documented $predict(x, ...)
    predict = function(x, ...) {
      stopifnot(ncol(x) == private$n_features)
      # linear predictor: x %*% coefficients
      x %*% matrix(private$coefficients, ncol = 1)
    }
  ),
  private = list(
    tol = NULL,
    coefficients = NULL,
    n_features = NULL
  )
)
set.seed(1)
model = SimpleLinearModel$new()
x = matrix(sample(100 * 10, replace = TRUE), ncol = 10)
y = sample(c(0, 1), 100, replace = TRUE)
model$fit(as.data.frame(x), y)
res1 = model$predict(x)
# check pipe-compatible S3 interface
res2 = predict(x, model)
identical(res1, res2)
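The example's `fit()` hands `x` to `.lm.fit()` unchanged, so the model has no intercept unless `x` itself contains a column of ones. The following base-R sketch (it does not use mlapi) rebuilds the same data, checks that the coefficients agree with the formula interface `lm(y ~ x - 1)`, and checks that the `x %*% coefficients` step used by `predict()` reproduces the fitted values.

```r
set.seed(1)
x = matrix(sample(100 * 10, replace = TRUE), ncol = 10)
y = sample(c(0, 1), 100, replace = TRUE)

# What SimpleLinearModel$fit() computes with the default tol = 1e-7.
coefs = .lm.fit(x, y, tol = 1e-7)[["coefficients"]]

# The same no-intercept model through the formula interface.
ref = lm(y ~ x - 1)

# What SimpleLinearModel$predict() computes.
pred = x %*% matrix(coefs, ncol = 1)

all.equal(unname(coefs), unname(coef(ref)))          # TRUE
all.equal(as.vector(pred), unname(fitted(ref)))      # TRUE
```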