mlapiEstimation: Base abstract class for all classification/regression models


Description

Base class for all estimators. It defines the minimal set of members and methods (with signatures) that have to be implemented in child classes.
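The base-class pattern described above can be sketched with plain R6: a parent whose `fit()` and `predict()` only raise errors, and a child that overrides both. This is a minimal illustration that assumes unimplemented methods should fail loudly; `AbstractEstimator` and `MeanEstimator` are hypothetical names, not mlapi's actual source.

```r
library(R6)

# Hypothetical abstract base class (NOT mlapi's source): it only declares
# the methods that every child estimator is expected to override.
AbstractEstimator = R6Class(
  classname = "AbstractEstimator",
  public = list(
    fit = function(x, y, ...) {
      stop("fit() is not implemented; override it in a child class")
    },
    predict = function(x, ...) {
      stop("predict() is not implemented; override it in a child class")
    }
  )
)

# Hypothetical child class that predicts the training mean for every row.
MeanEstimator = R6Class(
  classname = "MeanEstimator",
  inherit = AbstractEstimator,
  public = list(
    fit = function(x, y, ...) {
      private$center = mean(y)
      invisible(self)
    },
    predict = function(x, ...) {
      rep(private$center, nrow(x))
    }
  ),
  private = list(center = NULL)
)

m = MeanEstimator$new()
m$fit(matrix(1:6, ncol = 2), c(1, 2, 3))
m$predict(matrix(0, nrow = 2, ncol = 2))  # returns c(2, 2)
```

Calling `AbstractEstimator$new()$fit(1, 1)` directly raises the "not implemented" error, which is the point of keeping the base class abstract.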

Usage


Format

R6Class object.

Methods

$fit(x, y, ...)

Trains the model on the data x and the target y.

$predict(x, ...)

Makes predictions on new data (after the model has been trained).
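As a sketch of the calling contract for these two methods, the hypothetical `GuardedMeanModel` class below (not part of mlapi) stores its learned state in `fit()` and makes `predict()` fail until that state exists. The error message and the feature-count check are illustrative assumptions.

```r
library(R6)

# Hypothetical estimator (not part of mlapi): it remembers the training
# target mean and feature count; predict() refuses to run before fit().
GuardedMeanModel = R6Class(
  classname = "GuardedMeanModel",
  public = list(
    fit = function(x, y, ...) {
      stopifnot(nrow(x) == length(y))
      private$n_features = ncol(x)
      private$center = mean(y)
      invisible(self)
    },
    predict = function(x, ...) {
      # no learned state yet means the model was never trained
      if (is.null(private$center)) {
        stop("model is not trained yet: call $fit(x, y) first")
      }
      # new data must have the same number of features as the training data
      stopifnot(ncol(x) == private$n_features)
      rep(private$center, nrow(x))
    }
  ),
  private = list(center = NULL, n_features = NULL)
)

model = GuardedMeanModel$new()
# calling predict() on an untrained model raises an error
try(model$predict(matrix(0, 1, 3)))
model$fit(matrix(runif(12), ncol = 3), c(0, 1, 1, 0))
model$predict(matrix(0, 2, 3))  # returns c(0.5, 0.5)
```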

Arguments

x

A matrix-like object that should inherit from Matrix or matrix. The allowed classes should be defined in child classes.

y

The target: usually a vector, but it can also be a matrix-like object. The allowed classes should be defined in child classes.

...

Additional parameters with default values.
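The argument requirements above could be enforced in a child class with a small check such as the one below. `validate_fit_inputs()` is a hypothetical helper, not an mlapi function; it assumes a vector `y` is matched against `nrow(x)` via `length()` and a matrix-like `y` via `nrow()`.

```r
# Hypothetical helper (not part of mlapi) enforcing the argument contract:
# x must inherit from "matrix" or "Matrix", and y must have one entry
# (vector) or one row (matrix-like) per row of x.
validate_fit_inputs = function(x, y) {
  if (!inherits(x, c("matrix", "Matrix"))) {
    stop("x must inherit from 'matrix' or 'Matrix'")
  }
  # a vector target has no dim attribute; a matrix-like target does
  n_y = if (is.null(dim(y))) length(y) else nrow(y)
  if (nrow(x) != n_y) {
    stop("x and y must describe the same number of observations")
  }
  invisible(x)
}

x = matrix(1:6, ncol = 2)
validate_fit_inputs(x, c(1, 2, 3))       # passes: vector target
validate_fit_inputs(x, matrix(0, 3, 2))  # passes: matrix-like target
```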

Examples

SimpleLinearModel = R6::R6Class(
  classname = "mlapiSimpleLinearModel",
  inherit = mlapi::mlapiEstimation,
  public = list(
    initialize = function(tol = 1e-7) {
      private$tol = tol
      # this model works only with dense base R matrices
      super$set_internal_matrix_formats(dense = "matrix", sparse = NULL)
    },
    fit = function(x, y, ...) {
      # convert the input (e.g. a data.frame) to the internal dense format
      x = super$check_convert_input(x)
      stopifnot(is.vector(y))
      stopifnot(is.numeric(y))
      stopifnot(nrow(x) == length(y))

      private$n_features = ncol(x)
      # least-squares fit via base R's low-level QR-based solver
      private$coefficients = .lm.fit(x, y, tol = private$tol)[["coefficients"]]
      invisible(self)
    },
    # accept `...` to match the documented $predict(x, ...) signature
    predict = function(x, ...) {
      x = super$check_convert_input(x)
      stopifnot(ncol(x) == private$n_features)
      x %*% matrix(private$coefficients, ncol = 1)
    }
  ),
  private = list(
    tol = NULL,
    coefficients = NULL,
    n_features = NULL
  )
)

set.seed(1)
model = SimpleLinearModel$new()
x = matrix(sample(100 * 10, replace = TRUE), ncol = 10)
y = sample(c(0, 1), 100, replace = TRUE)
# fit() accepts a data.frame because check_convert_input() converts it
model$fit(as.data.frame(x), y)
res1 = model$predict(x)
# check the pipe-compatible S3 interface: predict(newdata, model)
res2 = predict(x, model)
identical(res1, res2)

mlapi documentation built on May 2, 2019, 6:59 a.m.