Once you understand how StringIndexer and IndexToString work internally, there are two ways to deal with the problem.
1 Add the metadata to the StructField

    import org.apache.spark.ml.feature.IndexToString
    import org.apache.spark.sql.functions.col

    // `indexed` is the DataFrame produced earlier by StringIndexer.
    // Copy the metadata of its "categoryIndex" column onto a new column.
    val df2 = spark.createDataFrame(Seq(
      (0, 2.0),
      (1, 1.0),
      (2, 1.0),
      (3, 0.0)
    )).toDF("id", "index")
      .select(col("*"), col("index").as("formated_index", indexed.schema("categoryIndex").metadata))

    val converter = new IndexToString()
      .setInputCol("formated_index")
      .setOutputCol("origin_col")

    val converted = converter.transform(df2)
    converted.show(false)

    +---+-----+--------------+----------+
    |id |index|formated_index|origin_col|
    +---+-----+--------------+----------+
    |0  |2.0  |2.0           |b         |
    |1  |1.0  |1.0           |c         |
    |2  |1.0  |1.0           |c         |
    |3  |0.0  |0.0           |a         |
    +---+-----+--------------+----------+

2 Read the labels from the DataFrame that StringIndexer produced earlier

    val df3 = spark.createDataFrame(Seq(
      (0, 2.0),
      (1, 1.0),
      (2, 1.0),
      (3, 0.0)
    )).toDF("id", "index")

    // Pass the label array explicitly, so the input column needs no metadata.
    val converter2 = new IndexToString()
      .setInputCol("index")
      .setOutputCol("origin_col")
      .setLabels(indexed.schema("categoryIndex").metadata.getMetadata("ml_attr").getStringArray("vals"))

    val converted2 = converter2.transform(df3)
    converted2.show(false)

    +---+-----+----------+
    |id |index|origin_col|
    +---+-----+----------+
    |0  |2.0  |b         |
    |1  |1.0  |c         |
    |2  |1.0  |c         |
    |3  |0.0  |a         |
    +---+-----+----------+

Both approaches produce the correct output.
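Both approaches work for the same reason: each one hands IndexToString the ordered label array, and the transformer then looks up each index in that array. The sketch below is a simplified, Spark-free model of that lookup, not Spark's actual implementation. The object and method names are hypothetical, and the label order (a, c, b) is taken from the output shown above.

```scala
// Simplified model of the index-to-label lookup; names are hypothetical.
object IndexToStringSketch {
  // Labels ordered by index, matching the output above:
  // a -> 0.0, c -> 1.0, b -> 2.0
  val labels: Array[String] = Array("a", "c", "b")

  // Map one index back to its label, rejecting indices outside the label array.
  def indexToLabel(index: Double, labels: Array[String]): String = {
    val i = index.toInt
    require(i >= 0 && i < labels.length, s"Unseen index: $index")
    labels(i)
  }

  def main(args: Array[String]): Unit = {
    val indices = Seq(2.0, 1.0, 1.0, 0.0)
    val restored = indices.map(indexToLabel(_, labels))
    println(restored.mkString(", ")) // prints "b, c, c, a"
  }
}
```

In this model, the first approach corresponds to attaching `labels` to the column as metadata, and the second to passing `labels` in directly through `setLabels`.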
The complete code is available on GitHub:
https://github.com/xinghalo/spark-in-action/blob/master/src/xingoo/ml/features/tranformer/IndexToStringTest.scala