Generalized Linear Model implementation in Java
This package implements generalized linear models (GLM) in Java.
Install
Add the following dependency to the dependencies section of your pom.xml:
<dependency>
  <groupId>com.github.chen0040</groupId>
  <artifactId>java-glm</artifactId>
  <version>1.0.6</version>
</dependency>
Features
The current implementation of GLM supports the same distribution families as the glm package in R:
- Normal
- Exponential
- Gamma
- InverseGaussian
- Poisson
- Bernoulli
- Binomial
- Categorical
- Multinomial
For solvers, the current implementation of GLM supports several variants of the iteratively re-weighted least squares (IRLS) estimation algorithm (see the sketch after this list):
- IRLS
- IRLS with QR factorization
- IRLS with SVD factorization
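The solver variant is selected via setSolverType(..). A minimal sketch; GlmIrls and GlmIrlsQr appear in the samples below, while GlmIrlsSvd is an assumed enum name for the SVD variant:

Glm glm = Glm.linear();
glm.setSolverType(GlmSolverType.GlmIrls);    // plain IRLS
glm.setSolverType(GlmSolverType.GlmIrlsQr);  // IRLS with QR factorization
glm.setSolverType(GlmSolverType.GlmIrlsSvd); // IRLS with SVD factorization (assumed name)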
Usage
Step 1: Create and train the GLM against the training data
Suppose you want to create a logistic regression model from GLM and train it against a data frame:
import com.github.chen0040.glm.solvers.Glm;
import com.github.chen0040.glm.enums.GlmSolverType;

DataFrame trainingData = loadTrainingData();

Glm glm = Glm.logistic();
glm.setSolverType(GlmSolverType.GlmIrls);
glm.fit(trainingData);
The "trainingData" is a data frame (Please refers to this link on how to create a data frame from file or from scratch)
The line "Glm.logistic()" create the logistic regression model, which can be easily changed to create other regression models (For example, calling "Glm.linear()" create a linear regression model)
The line "glm.fit(..)" performs the GLM training.
Step 2: Use the trained regression model to predict on new data
The trained GLM can then be run on the testing data; below is a Java code example for logistic regression:
DataFrame testingData = loadTestingData();

for(int i = 0; i < testingData.rowCount(); ++i){
   boolean predicted = glm.transform(testingData.row(i)) > 0.5;
   boolean actual = testingData.row(i).target() > 0.5;
   System.out.println("predicted(Irls): " + predicted + "\texpected: " + actual);
}
The "testingData" is a data frame
The line "glm.transform(..)" perform the regression
Sample code
Sample code for linear regression
The sample code below shows a linear regression example:
DataQuery.DataFrameQueryBuilder schema = DataQuery.blank()
.newInput("x1")
.newInput("x2")
.newOutput("y")
.end();
// y = 4 + 0.5 * x1 + 0.2 * x2
Sampler.DataSampleBuilder sampler = new Sampler()
.forColumn("x1").generate((name, index) -> randn() * 0.3 + index)
.forColumn("x2").generate((name, index) -> randn() * 0.3 + index * index)
.forColumn("y").generate((name, index) -> 4 + 0.5 * index + 0.2 * index * index + randn() * 0.3)
.end();
DataFrame trainingData = schema.build();
trainingData = sampler.sample(trainingData, 200);
System.out.println(trainingData.head(10));
DataFrame crossValidationData = schema.build();
crossValidationData = sampler.sample(crossValidationData, 40);
Glm glm = Glm.linear();
glm.setSolverType(GlmSolverType.GlmIrlsQr);
glm.fit(trainingData);
for(int i = 0; i < crossValidationData.rowCount(); ++i){
double predicted = glm.transform(crossValidationData.row(i));
double actual = crossValidationData.row(i).target();
System.out.println("predicted: " + predicted + "\texpected: " + actual);
}
System.out.println("Coefficients: " + glm.getCoefficients());
Sample code for logistic regression
The sample code below performs binary classification using logistic regression:
InputStream inputStream = new FileInputStream("heart_scale.txt");
DataFrame dataFrame = DataQuery.libsvm().from(inputStream).build();
for(int i=0; i < dataFrame.rowCount(); ++i){
DataRow row = dataFrame.row(i);
String targetColumn = row.getTargetColumnNames().get(0);
row.setTargetCell(targetColumn, row.getTargetCell(targetColumn) == -1 ? 0 : 1); // change output from (-1, +1) to (0, 1)
}
TupleTwo<DataFrame, DataFrame> miniFrames = dataFrame.shuffle().split(0.9);
DataFrame trainingData = miniFrames._1();
DataFrame crossValidationData = miniFrames._2();
Glm algorithm = Glm.logistic();
algorithm.setSolverType(GlmSolverType.GlmIrlsQr);
algorithm.fit(trainingData);
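// find the smallest predicted probability among the positive training samples;
// this minimum can serve as a data-driven decision threshold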
double threshold = 1.0;
for(int i = 0; i < trainingData.rowCount(); ++i){
double prob = algorithm.transform(trainingData.row(i));
if(trainingData.row(i).target() == 1 && prob < threshold){
threshold = prob;
}
}
logger.info("threshold: {}",threshold);
BinaryClassifierEvaluator evaluator = new BinaryClassifierEvaluator();
for(int i = 0; i < crossValidationData.rowCount(); ++i){
double prob = algorithm.transform(crossValidationData.row(i));
boolean predicted = prob > 0.5;
boolean actual = crossValidationData.row(i).target() > 0.5;
evaluator.evaluate(actual, predicted);
System.out.println("probability of positive: " + prob);
System.out.println("predicted: " + predicted + "\tactual: " + actual);
}
evaluator.report();
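Note that the cross-validation loop above uses a fixed threshold of 0.5; the "threshold" value computed from the training data could be substituted there to adjust the trade-off between false positives and false negatives.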
Sample code for multi-class classification
The sample code below performs multi-class classification using the logistic regression model as the underlying binary classifier:
InputStream irisStream = FileUtils.getResource("iris.data");
DataFrame irisData = DataQuery.csv(",")
.from(irisStream)
.selectColumn(0).asNumeric().asInput("Sepal Length")
.selectColumn(1).asNumeric().asInput("Sepal Width")
.selectColumn(2).asNumeric().asInput("Petal Length")
.selectColumn(3).asNumeric().asInput("Petal Width")
.selectColumn(4).asCategory().asOutput("Iris Type")
.build();
TupleTwo<DataFrame, DataFrame> parts = irisData.shuffle().split(0.9);
DataFrame trainingData = parts._1();
DataFrame crossValidationData = parts._2();
System.out.println(crossValidationData.head(10));
OneVsOneGlmClassifier multiClassClassifier = Glm.oneVsOne(Glm::logistic);
multiClassClassifier.fit(trainingData);
ClassifierEvaluator evaluator = new ClassifierEvaluator();
for(int i=0; i < crossValidationData.rowCount(); ++i) {
String predicted = multiClassClassifier.classify(crossValidationData.row(i));
String actual = crossValidationData.row(i).categoricalTarget();
System.out.println("predicted: " + predicted + "\tactual: " + actual);
evaluator.evaluate(actual, predicted);
}
evaluator.report();
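The one-vs-one strategy behind "Glm.oneVsOne(Glm::logistic)" trains one binary classifier for each pair of classes and classifies by pairwise voting; for the three iris classes this means 3 * 2 / 2 = 3 pairwise logistic regression models.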
Background on GLM
Introduction
GLM (the generalized linear model) generalizes linear regression to response variables from the exponential family of distributions via a model b = g(a), where g(a) is the inverse link function.
Therefore, for a regression characterized by the inverse link function g(a), the regression problem can be formulated as finding the model coefficients x in
g(A * x) = b + e
The objective is to find x that minimizes
min (g(A * x) - b).transpose * W * (g(A * x) - b)
Suppose we assume that e consists of uncorrelated noise variables with identical variance; then W = sigma^(-2) * I, and the objective
min (g(A * x) - b).transpose * W * (g(A * x) - b)
reduces to the OLS form:
min || g(A * x) - b ||^2
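For example, with the identity inverse link g(a) = a (the Normal family used by Glm.linear()), this objective becomes min || A * x - b ||^2, which is exactly ordinary least squares.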
Iteratively Re-weighted Least Squares estimation (IRLS)
In regression, we try to find a set of model coefficients x such that
A * x = b + e
where A is known as the model (design) matrix, b as the response vector, and e as the error term.
In OLS (ordinary least squares), we assume that the variance-covariance matrix is
V(e) = sigma^2 * W
where W is a symmetric positive definite diagonal matrix and sigma is the standard error of e.
The OLS objective is to find x_bar such that e.transpose * W * e is minimized (note that since W is positive definite, e.transpose * W * e is always positive). In other words, we are looking for x_bar such that (A * x_bar - b).transpose * W * (A * x_bar - b) is minimized.
Let
y = (A * x - b).transpose * W * (A * x - b)
Differentiating y with respect to x, we have
dy / dx = 2 * A.transpose * W * (A * x - b)
To find the minimum of y, set dy / dx = 0 at x = x_bar, which gives
A.transpose * W * (A * x_bar - b) = 0
Rearranging, we have
A.transpose * W * A * x_bar = A.transpose * W * b
Multiplying both sides by (A.transpose * W * A).inverse, we have
x_bar = (A.transpose * W * A).inverse * A.transpose * W * b
These weighted normal equations are commonly solved using IRLS.
The implementation of Glm in this package is based on iteratively re-weighted least squares estimation (IRLS).
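As an illustration, below is a minimal, self-contained IRLS sketch for logistic regression: each iteration recomputes working weights and a working response, then solves the weighted normal equations x = (A.transpose * W * A).inverse * A.transpose * W * z. This is a sketch of the general technique, not the actual internals of this package:

public class IrlsSketch {

    // A: n-by-p design matrix; b: n labels in {0, 1}
    static double[] irls(double[][] A, double[] b, int maxIters) {
        int n = A.length, p = A[0].length;
        double[] x = new double[p]; // coefficients, initialized to 0
        for (int iter = 0; iter < maxIters; iter++) {
            double[][] AtWA = new double[p][p];
            double[] AtWz = new double[p];
            for (int i = 0; i < n; i++) {
                double eta = 0.0; // linear predictor, the i-th entry of A * x
                for (int j = 0; j < p; j++) eta += A[i][j] * x[j];
                double mu = 1.0 / (1.0 + Math.exp(-eta));    // inverse logit link g(a)
                double w = Math.max(mu * (1.0 - mu), 1e-10); // IRLS weight, guarded from 0
                double z = eta + (b[i] - mu) / w;            // working response
                for (int j = 0; j < p; j++) {
                    AtWz[j] += A[i][j] * w * z;
                    for (int k = 0; k < p; k++) AtWA[j][k] += A[i][j] * w * A[i][k];
                }
            }
            x = solve(AtWA, AtWz); // x = (A' * W * A)^-1 * A' * W * z
        }
        return x;
    }

    // solves M * x = v by Gaussian elimination with partial pivoting
    static double[] solve(double[][] M, double[] v) {
        int p = v.length;
        for (int c = 0; c < p; c++) {
            int pivot = c;
            for (int r = c + 1; r < p; r++)
                if (Math.abs(M[r][c]) > Math.abs(M[pivot][c])) pivot = r;
            double[] rowTmp = M[c]; M[c] = M[pivot]; M[pivot] = rowTmp;
            double vTmp = v[c]; v[c] = v[pivot]; v[pivot] = vTmp;
            for (int r = c + 1; r < p; r++) {
                double f = M[r][c] / M[c][c];
                for (int k = c; k < p; k++) M[r][k] -= f * M[c][k];
                v[r] -= f * v[c];
            }
        }
        double[] x = new double[p];
        for (int r = p - 1; r >= 0; r--) {
            double s = v[r];
            for (int k = r + 1; k < p; k++) s -= M[r][k] * x[k];
            x[r] = s / M[r][r];
        }
        return x;
    }
}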