Overview
Java and Scala library for Apache Spark
A library that retrieves the feature importances of a spark.ml tree-model regression or classification result, labeled with their original column names.
It is licensed under MIT.
How to use (Java)
Maven
<dependency>
<groupId>org.riversun</groupId>
<artifactId>spark-ml-feature-importance-helper</artifactId>
<version>1.0.0</version>
</dependency>
Example
You can use this library from Java.
// Retrieve the trained GBT model from the fitted pipeline stage
GBTRegressionModel gbtModel = (GBTRegressionModel) (pipelineModel.stages()[stageIndex]);
// Run prediction on the held-out test data
Dataset<Row> predictions = pipelineModel.transform(testData);
// Get the schema from the prediction result Dataset (used to resolve column names)
StructType schema = predictions.schema();
// Get feature importances paired with column names, sorted descending by score
List<Importance> importanceList =
new FeatureImportance.Builder(gbtModel, schema)
.sort(Order.DESCENDING)
.build()
.getResult();
How to use (Scala)
build.sbt
libraryDependencies += "org.riversun" % "spark-ml-feature-importance-helper" % "1.0.0"
Example
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor}
import org.apache.spark.sql.SparkSession
import org.riversun.ml.spark.FeatureImportance
import org.riversun.ml.spark.FeatureImportance.Order
/**
 * End-to-end example: trains a gradient-boosted-tree regressor on the gem
 * price data set and prints the feature importances resolved to their
 * original column names via FeatureImportance.
 */
object GradientBoostedTreeRegressorExample {

  def main(args: Array[String]): Unit = {

    val spark = SparkSession
      .builder
      .appName("GradientBoostedTreeRegressorExample")
      .master("local[*]")
      .getOrCreate()

    val dataset = spark.read.format("csv")
      .option("header", "true")
      .option("inferSchema", "true")
      .load("data/mllib/gem_price.csv") // gem_price_ja.csv for Japanese

    // Index each categorical column into a numeric "<name>Index" column.
    val stringIndexers = Array("material", "shape", "brand", "shop")
      .map { colName =>
        new StringIndexer()
          .setInputCol(colName)
          .setOutputCol(colName + "Index")
      }

    // Assemble the indexed categoricals plus the numeric "weight" column
    // into a single feature vector.
    val assembler = new VectorAssembler()
      .setInputCols(stringIndexers.map(indexer => indexer.getOutputCol) :+ "weight")
      .setOutputCol("features")

    val gbtr = new GBTRegressor()
      .setLabelCol("price")
      .setFeaturesCol("features")
      .setPredictionCol("prediction")

    val pipeline = new Pipeline().setStages(stringIndexers :+ assembler :+ gbtr)

    // Fixed seed so the train/test split (and the example output) is reproducible.
    val splits = dataset.randomSplit(Array(0.7, 0.3), 1L)
    val trainingData = splits(0)
    val testData = splits(1)

    val model = pipeline.fit(trainingData)
    val predictions = model.transform(testData)

    // The last pipeline stage is the fitted GBT regression model.
    val gbtModel = model.stages.last.asInstanceOf[GBTRegressionModel]
    val schema = predictions.schema

    // Map raw importances back to column names, sorted descending by score.
    val importances = new FeatureImportance.Builder(gbtModel, schema)
      .sort(Order.DESCENDING)
      .build.getResult

    importances.forEach(println)

    spark.stop()
  }
}
Example result of feature importances
FeatureInfo [rank=0, score=0.35155564557381036, name=weight]
FeatureInfo [rank=1, score=0.23487364413432302, name=brandIndex]
FeatureInfo [rank=2, score=0.22461466434553393, name=materialIndex]
FeatureInfo [rank=3, score=0.09654096046037855, name=shapeIndex]
FeatureInfo [rank=4, score=0.09241508548595412, name=shopIndex]