EntropyRegularizedLogisticRegressionSSLR: General Interface for EntropyRegularizedLogisticRegression...


View source: R/EntropyRegularizedLogisticRegression.R

Description

Model from the RSSL package: an R implementation of entropy regularized logistic regression as proposed by Grandvalet & Bengio (2005). An extra term is added to the objective function of logistic regression that penalizes the entropy of the posterior class probabilities on the unlabeled examples.
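In symbols, the criterion can be sketched as follows. This assumes two classes, the penalized-likelihood form of Grandvalet & Bengio (2005), and an L2 term weighted by `lambda`; the exact normalization of each term, and whether the intercept is penalized, are assumptions rather than details taken from the RSSL code.

```latex
\max_{w}\;
  \sum_{i=1}^{n} \log P(y_i \mid x_i; w)
  \;+\; \lambda_{\text{entropy}} \sum_{j=1}^{m} \sum_{k \in \{0,1\}}
        P(k \mid x^{u}_{j}; w)\, \log P(k \mid x^{u}_{j}; w)
  \;-\; \lambda \,\lVert w \rVert_2^2,
\qquad
P(1 \mid x; w) = \sigma(w^{\top} x) = \frac{1}{1 + e^{-w^{\top} x}}
```

Here $x_1,\dots,x_n$ are the labeled and $x^{u}_1,\dots,x^{u}_m$ the unlabeled observations. The middle term is minus the entropy of the posterior on the unlabeled points, so maximizing it favors confident predictions there, which moves the decision boundary away from dense regions of unlabeled data.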

Usage

EntropyRegularizedLogisticRegressionSSLR(
  lambda = 0,
  lambda_entropy = 1,
  intercept = TRUE,
  init = NA,
  scale = FALSE,
  x_center = FALSE
)

Arguments

lambda

Strength of the L2 (ridge) regularization on the weights

lambda_entropy

Weight of the entropy penalty on the unlabeled observations, relative to the likelihood of the labeled observations

intercept

logical; Whether an intercept should be included

init

Initial parameters for the gradient descent

scale

logical; Should the features be normalized? (default: FALSE)

x_center

logical; Should the features be centered? (default: FALSE)
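To make the arguments concrete, the base-R function below computes the objective that `lambda`, `lambda_entropy` and `intercept` control, for a 0/1 outcome. It is a hypothetical sketch: the name `erlr_objective` and the exact form of each term are assumptions, not the RSSL source. `scale` and `x_center` would presumably transform `X` and `X_u` column-wise before this step; that part is not sketched.

```r
# Hypothetical sketch, not the RSSL source: penalized objective (to maximize)
# for binary labels y in {0, 1}, labeled features X and unlabeled features X_u.
erlr_objective <- function(w, X, y, X_u, lambda = 0, lambda_entropy = 1,
                           intercept = TRUE) {
  if (intercept) {
    # Prepend a column of ones so w[1] acts as the intercept
    X   <- cbind(1, X)
    X_u <- cbind(1, X_u)
  }
  # Clamp probabilities away from 0 and 1 so log() stays finite
  eps <- 1e-12
  p   <- pmin(pmax(plogis(drop(X %*% w)), eps), 1 - eps)
  p_u <- pmin(pmax(plogis(drop(X_u %*% w)), eps), 1 - eps)

  # Log-likelihood of the labeled observations
  loglik <- sum(y * log(p) + (1 - y) * log(1 - p))
  # Entropy of the predicted class probabilities on the unlabeled observations
  entropy <- -sum(p_u * log(p_u) + (1 - p_u) * log(1 - p_u))
  # Assumed form: the L2 penalty here also covers the intercept
  loglik - lambda_entropy * entropy - lambda * sum(w^2)
}

# Usage: with all-zero weights every predicted probability is 0.5
X   <- matrix(c(-2, -1, 1, 2), ncol = 1)
y   <- c(0, 0, 1, 1)
X_u <- matrix(c(-0.5, 0.5), ncol = 1)
erlr_objective(c(0, 0), X, y, X_u)  # about -6 * log(2), i.e. -4.16
```

At zero weights the labeled term contributes 4 log(0.5) and the two unlabeled points each contribute an entropy of log(2), which is where the value -6 log(2) comes from.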

References

Grandvalet, Y. & Bengio, Y., 2005. Semi-supervised learning by entropy minimization. In L. K. Saul, Y. Weiss, & L. Bottou, eds. Advances in Neural Information Processing Systems 17. Cambridge, MA: MIT Press, pp. 529-536.

Examples

library(tidyverse)
library(caret)
library(tidymodels)
library(SSLR)

data(breast)

set.seed(1)
# 70/30 train/test split, stratified on Class
train.index <- createDataPartition(breast$Class, p = .7, list = FALSE)
train <- breast[ train.index,]
test  <- breast[-train.index,]

cls <- which(colnames(breast) == "Class")

# Keep the label for 20% of the training rows; mark the rest as unlabeled (NA)
labeled.index <- createDataPartition(train$Class, p = .2, list = FALSE)
train[-labeled.index,cls] <- NA

# Fit the model; rows with Class == NA are used as unlabeled data
m <- EntropyRegularizedLogisticRegressionSSLR() %>% fit(Class ~ ., data = train)

# Accuracy and Cohen's kappa on the test set
predict(m,test) %>%
  bind_cols(test) %>%
  metrics(truth = "Class", estimate = .pred_class)
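The same workflow can be sketched in base R alone, which may help when the SSLR, caret and tidymodels stack is not installed. This is a minimal illustration under stated assumptions, not the package's method: two simulated Gaussian clusters stand in for `breast`, the names `neg_objective`, `X_l` and `X_u` are hypothetical, the form of the objective is assumed, and `optim(method = "BFGS")` is simply the optimizer chosen here.

```r
# Hypothetical base-R sketch of entropy regularized logistic regression;
# not the SSLR/RSSL implementation. Simulated data replace `breast`.
set.seed(1)

# Two well-separated Gaussian clusters (the setting entropy minimization assumes)
n <- 200
y <- rep(c(0L, 1L), each = n / 2)
X <- matrix(rnorm(n * 2), ncol = 2) + ifelse(y == 1, 2, -2)

# Keep labels for 20% of the rows; the rest are treated as unlabeled
labeled <- sample(n, size = 0.2 * n)
X_l <- cbind(1, X[labeled, ])   # labeled design matrix, intercept column first
y_l <- y[labeled]
X_u <- cbind(1, X[-labeled, ])  # unlabeled design matrix

# Negative penalized objective, since optim() minimizes by default
neg_objective <- function(w, lambda = 0, lambda_entropy = 1) {
  eps <- 1e-12
  p   <- pmin(pmax(plogis(drop(X_l %*% w)), eps), 1 - eps)
  p_u <- pmin(pmax(plogis(drop(X_u %*% w)), eps), 1 - eps)
  loglik  <- sum(y_l * log(p) + (1 - y_l) * log(1 - p))
  entropy <- -sum(p_u * log(p_u) + (1 - p_u) * log(1 - p_u))
  -loglik + lambda_entropy * entropy + lambda * sum(w^2)
}

# A small L2 penalty keeps the weights finite when the labeled rows are separable
fit <- optim(par = rep(0, ncol(X_l)), fn = neg_objective,
             lambda = 0.1, method = "BFGS")
w_hat <- fit$par

# Accuracy on the rows whose labels were hidden
pred_u <- as.integer(plogis(drop(X_u %*% w_hat)) > 0.5)
accuracy <- mean(pred_u == y[-labeled])
accuracy
```

`lambda_entropy` can be passed the same way as `lambda`, through `optim()`'s `...`; larger values give the unlabeled rows more influence. With `lambda = 0` and separable labeled rows, both the likelihood and the entropy term keep improving as the weights grow, so the optimum is not attained at finite weights; the small `lambda` avoids that.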

Example output

── Attaching packages ─────────────────────────────────────── tidyverse 1.3.0 ──
✔ ggplot2 3.3.2     ✔ purrr   0.3.4
✔ tibble  3.0.4     ✔ dplyr   1.0.2
✔ tidyr   1.1.2     ✔ stringr 1.4.0
✔ readr   1.4.0     ✔ forcats 0.5.0
── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag()    masks stats::lag()
Loading required package: lattice

Attaching package: ‘caret’

The following object is masked from ‘package:purrr’:

    lift

── Attaching packages ────────────────────────────────────── tidymodels 0.1.2 ──
✔ broom     0.7.2      ✔ recipes   0.1.15
✔ dials     0.0.9      ✔ rsample   0.0.8
✔ infer     0.5.3      ✔ tune      0.1.2
✔ modeldata 0.1.0      ✔ workflows 0.2.1
✔ parsnip   0.1.4      ✔ yardstick 0.0.7
── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
✖ scales::discard()        masks purrr::discard()
✖ dplyr::filter()          masks stats::filter()
✖ recipes::fixed()         masks stringr::fixed()
✖ dplyr::lag()             masks stats::lag()
✖ caret::lift()            masks purrr::lift()
✖ yardstick::precision()   masks caret::precision()
✖ yardstick::recall()      masks caret::recall()
✖ yardstick::sensitivity() masks caret::sensitivity()
✖ yardstick::spec()        masks readr::spec()
✖ yardstick::specificity() masks caret::specificity()
✖ recipes::step()          masks stats::step()
# A tibble: 2 x 3
  .metric  .estimator .estimate
  <chr>    <chr>          <dbl>
1 accuracy binary        0.706 
2 kap      binary        0.0899
Warning message:
In matrix(object@w, nrow = ncol(X)) :
  data length [34] is not a sub-multiple or multiple of the number of rows [35]
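The warning is raised by `matrix(object@w, nrow = ncol(X))`: per the message, the stored weight vector has 34 entries while the design matrix has 35 columns, so `matrix()` recycles values to fill the missing row. The predictions and metrics above may therefore use a misaligned weight vector; the cause of the mismatch is not shown on this page. The snippet below uses a stand-in vector `w` (not the fitted model) to reproduce the warning and show the recycling in base R.

```r
# Stand-in for a 34-element weight vector reshaped to 35 rows
w <- seq_len(34)

# Capture the warning text instead of printing it
msg <- tryCatch(matrix(w, nrow = 35),
                warning = function(wrn) conditionMessage(wrn))
msg

# matrix() still returns a 35 x 1 result by recycling w
W <- suppressWarnings(matrix(w, nrow = 35))
dim(W)     # 35 1
W[35, 1]   # 1, i.e. the first value reused for the extra row
```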

SSLR documentation built on July 22, 2021, 9:08 a.m.