spark-ml-feature-importance-helper

A library that returns the feature importances of a spark.ml tree model (regression or classification) together with the column name of each feature.

License: MIT
GroupId: org.riversun
ArtifactId: spark-ml-feature-importance-helper
Last Version: 1.0.0
Type: jar
Description: A library that returns the feature importances of a spark.ml tree model (regression or classification) together with the column name of each feature.
Project URL: https://github.com/riversun/spark-ml-feature-importance-helper
Source Code Management: https://github.com/riversun/spark-ml-feature-importance-helper

How to add to project

Maven

<!-- https://jarcasting.com/artifacts/org.riversun/spark-ml-feature-importance-helper/ -->
<dependency>
    <groupId>org.riversun</groupId>
    <artifactId>spark-ml-feature-importance-helper</artifactId>
    <version>1.0.0</version>
</dependency>

Gradle (Groovy DSL)

// https://jarcasting.com/artifacts/org.riversun/spark-ml-feature-importance-helper/
implementation 'org.riversun:spark-ml-feature-importance-helper:1.0.0'

Gradle (Kotlin DSL)

// https://jarcasting.com/artifacts/org.riversun/spark-ml-feature-importance-helper/
implementation ("org.riversun:spark-ml-feature-importance-helper:1.0.0")

Buildr

'org.riversun:spark-ml-feature-importance-helper:jar:1.0.0'

Ivy

<dependency org="org.riversun" name="spark-ml-feature-importance-helper" rev="1.0.0">
  <artifact name="spark-ml-feature-importance-helper" type="jar" />
</dependency>

Groovy Grape

@Grapes(
@Grab(group='org.riversun', module='spark-ml-feature-importance-helper', version='1.0.0')
)

SBT

libraryDependencies += "org.riversun" % "spark-ml-feature-importance-helper" % "1.0.0"

Leiningen

[org.riversun/spark-ml-feature-importance-helper "1.0.0"]

Dependencies

compile (2)

Group / Artifact                      Type   Version
org.apache.spark : spark-core_2.12    jar    2.4.3
org.apache.spark : spark-mllib_2.12   jar    2.4.3

test (1)

Group / Artifact                      Type   Version
junit : junit                         jar    4.7

Project Modules

There are no modules declared in this project.

Overview

Java and Scala library for Apache Spark

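It returns the feature importances of a spark.ml tree model (regression or classification) together with the column name of each feature, so you do not have to map feature vector indices back to columns by hand.

It is licensed under the MIT License.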

How to use (Java)

Maven

<dependency>
	<groupId>org.riversun</groupId>
	<artifactId>spark-ml-feature-importance-helper</artifactId>
	<version>1.0.0</version>
</dependency>

Example

You can use this library from Java.

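import java.util.List;

import org.apache.spark.ml.regression.GBTRegressionModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.StructType;
import org.riversun.ml.spark.FeatureImportance;
import org.riversun.ml.spark.FeatureImportance.Order;
// Also import the library's Importance class (the element type of the result list)

// Get the trained model from its pipeline stage
// (stageIndex is the position of the GBT stage in the fitted pipeline)
GBTRegressionModel gbtModel = (GBTRegressionModel) (pipelineModel.stages()[stageIndex]);

// Run the prediction
Dataset<Row> predictions = pipelineModel.transform(testData);

// Get the schema from the resulting Dataset
StructType schema = predictions.schema();

// Get the feature importances with column names, sorted from highest to lowest
List<Importance> importanceList =
       new FeatureImportance.Builder(gbtModel, schema)
         .sort(Order.DESCENDING)
         .build()
         .getResult();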
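
To make the idea behind the result list concrete, here is a minimal plain-Java sketch of pairing each importance score with a column name and ranking the pairs in descending order. It does not use Spark and is not this library's implementation: the class RankedImportanceSketch, the nested Ranked type, and the scores are all hypothetical and exist only for illustration.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

public class RankedImportanceSketch {

    /** Hypothetical value type for this sketch; not the library's Importance class. */
    static final class Ranked {
        final int rank;
        final String name;
        final double score;

        Ranked(int rank, String name, double score) {
            this.rank = rank;
            this.name = name;
            this.score = score;
        }

        @Override
        public String toString() {
            return "Ranked[rank=" + rank + ", score=" + score + ", name=" + name + "]";
        }
    }

    /** Pairs names[i] with scores[i], sorts by score (highest first), and assigns ranks from 0. */
    static List<Ranked> rank(String[] names, double[] scores) {
        if (names.length != scores.length) {
            throw new IllegalArgumentException("names and scores must have the same length");
        }
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < names.length; i++) {
            order.add(i);
        }
        // Sort the feature indices by their score, highest first
        order.sort(Comparator.comparingDouble((Integer i) -> scores[i]).reversed());

        List<Ranked> result = new ArrayList<>();
        for (int r = 0; r < order.size(); r++) {
            int i = order.get(r);
            result.add(new Ranked(r, names[i], scores[i]));
        }
        return result;
    }

    public static void main(String[] args) {
        // Names in the order the features were assembled; scores are made-up values
        String[] names = {"materialIndex", "shapeIndex", "brandIndex", "shopIndex", "weight"};
        double[] scores = {0.22, 0.10, 0.23, 0.09, 0.36};
        rank(names, scores).forEach(System.out::println);
    }
}
```

The point of the sketch is the index-to-name pairing: a model's importance vector only carries positions, so the names have to come from somewhere else, which is why the Builder in the example above takes the schema as its second argument.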

How to use (Scala)

build.sbt

libraryDependencies += "org.riversun" % "spark-ml-feature-importance-helper" % "1.0.0"

Example

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor}
import org.apache.spark.sql.SparkSession
import org.riversun.ml.spark.FeatureImportance
import org.riversun.ml.spark.FeatureImportance.Order

object GradientBoostedTreeRegressorExample {

  def main(args: Array[String]): Unit = {

    val spark = SparkSession
      .builder
      .appName("GradientBoostedTreeRegressorExample")
      .master("local[*]")
      .getOrCreate()

    val dataset = spark.read.format("csv")
      .option("header", "true")
      .option("inferSchema", "true")
      .load("data/mllib/gem_price.csv") // use gem_price_ja.csv for the Japanese version

    val stringIndexers = Array("material", "shape", "brand", "shop")
      .map { colName =>
        new StringIndexer()
          .setInputCol(colName)
          .setOutputCol(colName + "Index")
      }

    val assembler = new VectorAssembler()
      .setInputCols(stringIndexers.map(indexer => indexer.getOutputCol) :+ "weight")
      .setOutputCol("features")

    val gbtr = new GBTRegressor()
      .setLabelCol("price")
      .setFeaturesCol("features")
      .setPredictionCol("prediction")

    val pipeline = new Pipeline().setStages(stringIndexers :+ assembler :+ gbtr)

    val splits = dataset.randomSplit(Array(0.7, 0.3), 1L)
    val trainingData = splits(0)
    val testData = splits(1)

    val model = pipeline.fit(trainingData)

    val predictions = model.transform(testData)

    val gbtModel = model.stages.last.asInstanceOf[GBTRegressionModel]
    val schema = predictions.schema

    val importances = new FeatureImportance.Builder(gbtModel, schema)
      .sort(Order.DESCENDING)
      .build.getResult

    importances.forEach(println)

    spark.stop()
  }
}

Example result of feature importances

FeatureInfo [rank=0, score=0.35155564557381036, name=weight]
FeatureInfo [rank=1, score=0.23487364413432302, name=brandIndex]
FeatureInfo [rank=2, score=0.22461466434553393, name=materialIndex]
FeatureInfo [rank=3, score=0.09654096046037855, name=shapeIndex]
FeatureInfo [rank=4, score=0.09241508548595412, name=shopIndex]
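
The scores in this example result add up to 1 (up to rounding), so a sorted list like this can be read as shares of the total importance. As one possible way to use it, the plain-Java sketch below counts how many of the top-ranked features are needed to reach a given share. The class CumulativeImportanceSketch and the method featuresToCover are hypothetical names for this illustration and are not part of the library.

```java
public class CumulativeImportanceSketch {

    /**
     * Returns how many leading scores are needed for their running sum to
     * reach the threshold. Expects the scores sorted from highest to lowest.
     */
    static int featuresToCover(double[] sortedScoresDescending, double threshold) {
        double cumulative = 0.0;
        for (int i = 0; i < sortedScoresDescending.length; i++) {
            cumulative += sortedScoresDescending[i];
            if (cumulative >= threshold) {
                return i + 1;
            }
        }
        // Threshold never reached: every feature is needed
        return sortedScoresDescending.length;
    }

    public static void main(String[] args) {
        // Scores copied from the example result, already in descending order
        double[] scores = {
            0.35155564557381036, // weight
            0.23487364413432302, // brandIndex
            0.22461466434553393, // materialIndex
            0.09654096046037855, // shapeIndex
            0.09241508548595412  // shopIndex
        };
        System.out.println(featuresToCover(scores, 0.8));
    }
}
```

With these scores and a threshold of 0.8, the first three features (weight, brandIndex, materialIndex) are needed, since their scores add up to about 0.81.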

Versions

Version
1.0.0