Spark MLlib Feature Processing: StringIndexer and IndexToString Usage and Source Code Analysis (3)

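With the mechanics of StringIndexer and IndexToString understood, the problem can be worked around as follows. IndexToString needs the ordered label array. It normally reads this from the ml_attr metadata that StringIndexer attaches to its output column. A DataFrame built from scratch carries no such metadata, so the labels have to be supplied in one of two ways.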

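1. Add the metadata to the StructField

Copy the metadata of the indexed column onto a new column with Column.as(alias, metadata), so that IndexToString can find the labels on its own.

    import org.apache.spark.ml.feature.IndexToString
    import org.apache.spark.sql.functions.col

    // `spark` is the SparkSession; `indexed` is the DataFrame produced by
    // StringIndexer in the earlier parts, whose "categoryIndex" column
    // carries the ml_attr metadata holding the labels.
    val df2 = spark.createDataFrame(Seq(
      (0, 2.0),
      (1, 1.0),
      (2, 1.0),
      (3, 0.0)
    )).toDF("id", "index")
      // Alias "index" as a new column that carries the copied metadata
      .select(col("*"), col("index").as("formated_index", indexed.schema("categoryIndex").metadata))

    val converter = new IndexToString()
      .setInputCol("formated_index")
      .setOutputCol("origin_col")

    val converted = converter.transform(df2)
    converted.show(false)

    +---+-----+--------------+----------+
    |id |index|formated_index|origin_col|
    +---+-----+--------------+----------+
    |0  |2.0  |2.0           |b         |
    |1  |1.0  |1.0           |c         |
    |2  |1.0  |1.0           |c         |
    |3  |0.0  |0.0           |a         |
    +---+-----+--------------+----------+

2. Take the labels from the DataFrame produced earlier by StringIndexer

Read the label array out of the indexed column's metadata and pass it to IndexToString explicitly with setLabels.

    // A plain DataFrame: the "index" column has no metadata
    val df3 = spark.createDataFrame(Seq(
      (0, 2.0),
      (1, 1.0),
      (2, 1.0),
      (3, 0.0)
    )).toDF("id", "index")

    // The labels are stored under ml_attr -> vals on the StringIndexer output column
    val converter2 = new IndexToString()
      .setInputCol("index")
      .setOutputCol("origin_col")
      .setLabels(indexed.schema("categoryIndex").metadata.getMetadata("ml_attr").getStringArray("vals"))

    val converted2 = converter2.transform(df3)
    converted2.show(false)

    +---+-----+----------+
    |id |index|origin_col|
    +---+-----+----------+
    |0  |2.0  |b         |
    |1  |1.0  |c         |
    |2  |1.0  |c         |
    |3  |0.0  |a         |
    +---+-----+----------+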
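Both methods rely on the `indexed` DataFrame from the earlier parts still being available. If only the label array has survived (for example, saved alongside a model), one option is to rebuild the same ml_attr metadata with the `NominalAttribute` API instead of copying it. The labels can then be read back with `Attribute.fromStructField`, which avoids walking the raw metadata keys. The following is a sketch only. The object name `LabelMetadataSketch`, its helper names, and the label order `a, c, b` (inferred from the outputs above) are assumptions, and the code does not come from the linked repository.

```scala
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.feature.IndexToString
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{DoubleType, Metadata, StructField}

// Hypothetical helper object, for illustration only.
object LabelMetadataSketch {

  // Build ml_attr metadata equivalent to what StringIndexer attaches,
  // from a label array ordered by index (labels(0) is index 0.0, etc.).
  def labelMetadata(colName: String, labels: Array[String]): Metadata =
    NominalAttribute.defaultAttr
      .withName(colName)
      .withValues(labels)
      .toMetadata()

  // Read the label array back from a numeric column's StructField
  // using the typed attribute API.
  def labelsOf(field: StructField): Array[String] =
    Attribute.fromStructField(field) match {
      case nominal: NominalAttribute =>
        nominal.values.getOrElse(sys.error(s"no labels stored on ${field.name}"))
      case other =>
        sys.error(s"${field.name} is not a nominal attribute: $other")
    }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("label-metadata").getOrCreate()
    import spark.implicits._

    // Assumed label order: a -> 0, c -> 1, b -> 2 (as in the outputs above).
    val meta = labelMetadata("formated_index", Array("a", "c", "b"))

    // Attach the rebuilt metadata instead of copying it from `indexed`.
    val df = Seq((0, 2.0), (1, 1.0), (2, 1.0), (3, 0.0)).toDF("id", "index")
      .select(col("*"), col("index").as("formated_index", meta))

    println(labelsOf(df.schema("formated_index")).mkString(","))

    new IndexToString()
      .setInputCol("formated_index")
      .setOutputCol("origin_col")
      .transform(df)
      .show(false)

    spark.stop()
  }
}
```

Whether to copy the metadata or rebuild it mostly depends on whether the original indexed DataFrame, or only its labels, is still at hand.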

Both approaches produce the correct output.
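Both approaches work because they give IndexToString the same ordered label array. After that, converting a row is essentially an array lookup by the integer part of the index value. The plain-Scala sketch below mimics that lookup for illustration. It is not Spark's source, and `IndexLookupSketch` and `indexToLabel` are hypothetical names. Spark raises its own exception for an out-of-range index; here `require` stands in for that check.

```scala
// Hypothetical sketch of the per-row lookup IndexToString performs
// once the label array is known (from metadata or from setLabels).
object IndexLookupSketch {

  def indexToLabel(labels: IndexedSeq[String])(index: Double): String = {
    val i = index.toInt // indices are stored as doubles, e.g. 2.0
    require(i >= 0 && i < labels.length,
      s"Unseen index: $index with labels ${labels.mkString("[", ",", "]")}")
    labels(i)
  }

  def main(args: Array[String]): Unit = {
    val labels = Vector("a", "c", "b") // ordered as in the outputs above
    val indices = Seq(2.0, 1.0, 1.0, 0.0)
    println(indices.map(i => indexToLabel(labels)(i)).mkString(",")) // prints b,c,c,a
  }
}
```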

The complete code is available on GitHub:

https://github.com/xinghalo/spark-in-action/blob/master/src/xingoo/ml/features/tranformer/IndexToStringTest.scala

Copyright notice: unless otherwise noted, all articles on this site are original.

When reposting, please credit the source: https://www.heiqu.com/wpysgj.html