EntropyRegularizedLogisticRegression: Entropy Regularized Logistic Regression

View source: R/EntropyRegularizedLogisticRegression.R

EntropyRegularizedLogisticRegression R Documentation

Entropy Regularized Logistic Regression

Description

R implementation of entropy regularized logistic regression as proposed by Grandvalet & Bengio (2005). An extra term is added to the objective function of logistic regression that penalizes the entropy of the posterior measured on the unlabeled examples.
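
The objective described above can be sketched in base R for the two-class case. This is an illustration only, not the package's internal code: the function name erlr_objective, the 0/1 label coding, summing rather than averaging the terms, and applying the L2 penalty to every coefficient are all assumptions made here. The sketch is written as a quantity to minimize (negative log-likelihood plus penalties), which is equivalent to maximizing the likelihood minus the entropy term.

```r
# Sketch of an entropy-regularized logistic objective for two classes.
# Assumptions (not taken from the RSSL source): y is coded 0/1, terms are
# summed rather than averaged, and the L2 penalty covers every coefficient.
sigmoid <- function(z) 1 / (1 + exp(-z))

erlr_objective <- function(w, X, y, X_u, lambda = 0, lambda_entropy = 1) {
  eps <- 1e-12                       # guards against log(0)
  p_l <- sigmoid(drop(X %*% w))      # posteriors on the labeled rows
  p_u <- sigmoid(drop(X_u %*% w))    # posteriors on the unlabeled rows

  # Negative log-likelihood of the labeled data
  nll <- -sum(y * log(p_l + eps) + (1 - y) * log(1 - p_l + eps))

  # Shannon entropy of the predicted posteriors on the unlabeled data
  entropy <- -sum(p_u * log(p_u + eps) + (1 - p_u) * log(1 - p_u + eps))

  nll + lambda_entropy * entropy + lambda * sum(w^2)
}

# Toy data: an intercept column plus one feature
set.seed(1)
X   <- cbind(1, c(-2, -1, 1, 2))
y   <- c(0, 0, 1, 1)
X_u <- cbind(1, rnorm(20))

# Objective at the all-zero weight vector
erlr_objective(c(0, 0), X, y, X_u, lambda = 0.01, lambda_entropy = 1)
```

At the all-zero weight vector every posterior equals 0.5, so the entropy term is at its maximum; a larger lambda_entropy pushes the optimizer toward weights that make confident predictions on the unlabeled rows.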

Usage

EntropyRegularizedLogisticRegression(X, y, X_u = NULL, lambda = 0,
  lambda_entropy = 1, intercept = TRUE, init = NA, scale = FALSE,
  x_center = FALSE)

Arguments

X

matrix; Design matrix for labeled data

y

factor or integer vector; Label vector

X_u

matrix; Design matrix for unlabeled data

lambda

numeric; L2 regularization parameter (default: 0)

lambda_entropy

numeric; Weight of the entropy penalty on the unlabeled observations relative to the loss on the labeled observations (default: 1)

intercept

logical; Whether an intercept should be included

init

Initial parameters for the gradient descent

scale

logical; Should the features be normalized? (default: FALSE)

x_center

logical; Should the features be centered? (default: FALSE)

Value

S4 object of class EntropyRegularizedLogisticRegression with the following slots:

w

weight vector

classnames

the names of the classes
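
A minimal sketch of the matrix interface, following the Usage signature and the slots listed above. The toy data and the variable names are hypothetical, and the fit is guarded so that it only runs when RSSL is installed; the resulting weights depend on the installed package version.

```r
# Hypothetical toy data: two Gaussian classes in two dimensions
set.seed(42)
X   <- rbind(matrix(rnorm(20, mean = -1), ncol = 2),    # 10 labeled rows, class "a"
             matrix(rnorm(20, mean =  1), ncol = 2))    # 10 labeled rows, class "b"
y   <- factor(rep(c("a", "b"), each = 10))
X_u <- rbind(matrix(rnorm(200, mean = -1), ncol = 2),   # 200 unlabeled rows
             matrix(rnorm(200, mean =  1), ncol = 2))

if (requireNamespace("RSSL", quietly = TRUE)) {
  library(RSSL)

  # Arguments as documented in the Usage section
  fit <- EntropyRegularizedLogisticRegression(
    X, y, X_u,
    lambda = 0.01, lambda_entropy = 1,
    intercept = TRUE, scale = FALSE, x_center = FALSE
  )

  # Slots documented in the Value section
  print(fit@w)           # weight vector
  print(fit@classnames)  # names of the classes
}
```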

References

Grandvalet, Y. & Bengio, Y., 2005. Semi-supervised learning by entropy minimization. In L. K. Saul, Y. Weiss, & L. Bottou, eds. Advances in Neural Information Processing Systems 17. Cambridge, MA: MIT Press, pp. 529-536.

Examples

library(RSSL)
library(ggplot2)
library(dplyr)


# An example where ERLR finds a low-density separator that is not
# the correct decision boundary.
set.seed(1)

# Generate 1000 objects, then remove labels at random with probability 0.98
df <- generateSlicedCookie(1000,expected=FALSE) %>%
  add_missinglabels_mar(Class~.,0.98)

# Supervised baseline and its entropy-regularized, semi-supervised counterpart
class_lr <- LogisticRegression(Class~.,df,lambda = 0.01)
class_erlr <- EntropyRegularizedLogisticRegression(Class~.,df,
                                lambda=0.01,lambda_entropy = 100)

# Plot the data together with the decision boundaries of both classifiers
ggplot(df,aes(x=X1,y=X2,color=Class)) +
  geom_point() +
  stat_classifier(aes(linetype=..classifier..),
                  classifiers = list("LR"=class_lr,"ERLR"=class_erlr)) +
  scale_y_continuous(limits=c(-2,2)) +
  scale_x_continuous(limits=c(-2,2))

# Compare the accuracy of both classifiers on a fresh, fully labeled sample
df_test <- generateSlicedCookie(1000,expected=FALSE)
mean(predict(class_lr,df_test)==df_test$Class)
mean(predict(class_erlr,df_test)==df_test$Class)
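
The two accuracy lines above compute the fraction of matching labels. As a small base R sketch, the same comparison can be wrapped in a helper that also returns a confusion table; evaluate_predictions and the toy vectors below are hypothetical and not part of RSSL.

```r
# Hypothetical helper: accuracy and confusion table for predicted labels
evaluate_predictions <- function(predicted, truth) {
  stopifnot(length(predicted) == length(truth))
  list(
    # Compare as character so differing factor level sets do not matter
    accuracy  = mean(as.character(predicted) == as.character(truth)),
    confusion = table(truth = truth, predicted = predicted)
  )
}

# Toy vectors standing in for predict(...) output and df_test$Class
truth     <- factor(c("a", "a", "b", "b", "b"))
predicted <- factor(c("a", "b", "b", "b", "a"), levels = levels(truth))

res <- evaluate_predictions(predicted, truth)
res$accuracy
res$confusion
```

With the objects from the example, a call such as evaluate_predictions(predict(class_erlr, df_test), df_test$Class) would show which class accounts for the errors of the low-density separator.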





jkrijthe/RSSL documentation built on Jan. 13, 2024, 1:56 a.m.