Spark MLlib Feature Processing: StringIndexer and IndexToString, Usage Notes and Source Code Analysis

Recently, while doing feature processing with Spark MLlib, I ran into a few problems with StringIndexer and IndexToString, and the official documentation didn't resolve my doubts. I ended up reading the source code to figure things out, and this post walks through what I found.

Documentation

StringIndexer: String to Index

StringIndexer sorts the values of a string column by how often they occur, so the most frequent value gets index 0. For example, applying StringIndexer to the following table:
id | category
----|----------
0 | a
1 | b
2 | c
3 | a
4 | a
5 | c
yields the following:
id | category | categoryIndex
----|----------|---------------
0 | a | 0.0
1 | b | 2.0
2 | c | 1.0
3 | a | 0.0
4 | a | 0.0
5 | c | 1.0
As you can see, "a", which occurs most often, gets index 0, while "b", which occurs least often, gets index 2.
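The fitted StringIndexerModel exposes this frequency ordering directly through its labels field. A minimal sketch, assuming df holds the six rows above as a DataFrame:

```scala
import org.apache.spark.ml.feature.StringIndexer

// Fit on the six-row DataFrame shown above (df is assumed here).
val model = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .fit(df)

// labels is ordered by descending frequency:
// Array("a", "c", "b"), i.e. a -> 0.0, c -> 1.0, b -> 2.0
println(model.labels.mkString(", "))
```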

For string values that do not appear in the training set, Spark provides several handling options (handleInvalid):

error: throw an exception (the default)

skip: drop the rows that contain unseen values

keep: map every unseen value to a single new, maximum index

Here is a code sample based on Spark MLlib 2.2.0:

```scala
package xingoo.ml.features.tranformer

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.feature.StringIndexer

object StringIndexerTest {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("string-indexer").getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    val df = spark.createDataFrame(
      Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
    ).toDF("id", "category")

    val df1 = spark.createDataFrame(
      Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "e"), (5, "f"))
    ).toDF("id", "category")

    val indexer = new StringIndexer()
      .setInputCol("category")
      .setOutputCol("categoryIndex")
      .setHandleInvalid("keep") // options: skip, keep, error

    val model = indexer.fit(df)
    val indexed = model.transform(df1)
    indexed.show(false)
  }
}
```

The result is:

```
+---+--------+-------------+
|id |category|categoryIndex|
+---+--------+-------------+
|0  |a       |0.0          |
|1  |b       |2.0          |
|2  |c       |1.0          |
|3  |a       |0.0          |
|4  |e       |3.0          |
|5  |f       |3.0          |
+---+--------+-------------+
```

With handleInvalid set to "keep", the unseen values "e" and "f" are both mapped to 3.0, one past the largest training index.
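For comparison, here is a minimal sketch of the other two handleInvalid settings, reusing model and df1 from the example above:

```scala
// "skip": the rows with unseen values ("e" and "f") are dropped,
// so only four rows come back.
model.setHandleInvalid("skip").transform(df1).show(false)

// "error" (the default): transform throws as soon as it hits
// a value that was not seen during fit.
// model.setHandleInvalid("error").transform(df1).show(false)
```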

IndexToString: Index to String

Converting indices back to strings only works together with the StringIndexer used earlier:

```scala
package xingoo.ml.features.tranformer

import org.apache.spark.ml.attribute.Attribute
import org.apache.spark.ml.feature.{IndexToString, StringIndexer}
import org.apache.spark.sql.SparkSession

object IndexToString2 {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("dct").getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    val df = spark.createDataFrame(Seq(
      (0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")
    )).toDF("id", "category")

    val indexer = new StringIndexer()
      .setInputCol("category")
      .setOutputCol("categoryIndex")
      .fit(df)
    val indexed = indexer.transform(df)

    println(s"Transformed string column '${indexer.getInputCol}' " +
      s"to indexed column '${indexer.getOutputCol}'")
    indexed.show()

    val inputColSchema = indexed.schema(indexer.getOutputCol)
    println(s"StringIndexer will store labels in output column metadata: " +
      s"${Attribute.fromStructField(inputColSchema).toString}\n")

    val converter = new IndexToString()
      .setInputCol("categoryIndex")
      .setOutputCol("originalCategory")

    val converted = converter.transform(indexed)

    println(s"Transformed indexed column '${converter.getInputCol}' back to original string " +
      s"column '${converter.getOutputCol}' using labels in metadata")
    converted.select("id", "categoryIndex", "originalCategory").show()
  }
}
```

The output is as follows:

```
Transformed string column 'category' to indexed column 'categoryIndex'
+---+--------+-------------+
| id|category|categoryIndex|
+---+--------+-------------+
|  0|       a|          0.0|
|  1|       b|          2.0|
|  2|       c|          1.0|
|  3|       a|          0.0|
|  4|       a|          0.0|
|  5|       c|          1.0|
+---+--------+-------------+

StringIndexer will store labels in output column metadata: {"vals":["a","c","b"],"type":"nominal","name":"categoryIndex"}

Transformed indexed column 'categoryIndex' back to original string column 'originalCategory' using labels in metadata
+---+-------------+----------------+
| id|categoryIndex|originalCategory|
+---+-------------+----------------+
|  0|          0.0|               a|
|  1|          2.0|               b|
|  2|          1.0|               c|
|  3|          0.0|               a|
|  4|          0.0|               a|
|  5|          1.0|               c|
+---+-------------+----------------+
```
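The "vals" array in that metadata is exactly what IndexToString reads back. A minimal sketch of extracting the labels by hand, using indexed from the example above:

```scala
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}

// StringIndexer stores its labels as a nominal ML attribute
// in the output column's metadata.
val attr = Attribute
  .fromStructField(indexed.schema("categoryIndex"))
  .asInstanceOf[NominalAttribute]

val labels: Array[String] = attr.values.get // Array("a", "c", "b")
```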

Usage Problem

Suppose the processing in between is complex and produces a brand-new DataFrame. What if you now want to use IndexToString to map this DataFrame back to the original strings? Let's give it a try:

```scala
package xingoo.ml.features.tranformer

import org.apache.spark.ml.feature.{IndexToString, StringIndexer}
import org.apache.spark.sql.SparkSession

object IndexToString3 {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("dct").getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    val df = spark.createDataFrame(Seq(
      (0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")
    )).toDF("id", "category")

    // A freshly built DataFrame: its "index" column is a plain Double
    // column with no StringIndexer metadata attached.
    val df2 = spark.createDataFrame(Seq(
      (0, 2.0), (1, 1.0), (2, 1.0), (3, 0.0)
    )).toDF("id", "index")

    val indexer = new StringIndexer()
      .setInputCol("category")
      .setOutputCol("categoryIndex")
      .fit(df)
    val indexed = indexer.transform(df)

    val converter = new IndexToString()
      .setInputCol("index") // the column of df2 we want to convert back
      .setOutputCol("originalCategory")

    val converted = converter.transform(df2)
    converted.show()
  }
}
```
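This is where things go wrong: the index column of df2 carries no label metadata, so IndexToString has nothing to map the indices back to, and in Spark 2.2 the transform throws instead of returning a table. One workaround is to supply the labels explicitly via setLabels; a minimal sketch, reusing indexer and df2 from the code above:

```scala
// Bypass the metadata lookup by passing the labels in directly.
// indexer.labels is Array("a", "c", "b") from the fit on df,
// i.e. 0.0 -> "a", 1.0 -> "c", 2.0 -> "b".
val converterWithLabels = new IndexToString()
  .setInputCol("index")
  .setOutputCol("originalCategory")
  .setLabels(indexer.labels)

converterWithLabels.transform(df2).show()
```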
