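Recently, while using Spark MLlib for feature processing, I ran into some problems with StringIndexer and IndexToString that the official documentation did not resolve. Only after reading the source code did things start to make sense, so here is what I found.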

For more, see my Big Data Learning Path series.

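What the documentation says

StringIndexer: string to index

StringIndexer encodes a string column as indices ordered by label frequency: the most frequent label gets index 0. For example, applying StringIndexer to the following data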

+---+--------+
|id |category|
+---+--------+
|0  |a       |
|1  |b       |
|2  |c       |
|3  |a       |
|4  |a       |
|5  |c       |
+---+--------+

gives the following:

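+---+--------+-------------+
|id |category|categoryIndex|
+---+--------+-------------+
|0  |a       |0.0          |
|1  |b       |2.0          |
|2  |c       |1.0          |
|3  |a       |0.0          |
|4  |a       |0.0          |
|5  |c       |1.0          |
+---+--------+-------------+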

The most frequent value, "a", gets index 0; the least frequent, "b", gets index 2.
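To make the ordering rule concrete, here is a minimal sketch in plain Scala with no Spark dependency. The object and method names (FrequencyIndexSketch, fitLabels, index) are my own, and the alphabetical tie-break is an assumption added only to keep the sketch deterministic; it is not taken from the Spark 2.2.0 source.

```scala
object FrequencyIndexSketch {
  // Order distinct labels by descending count. Ties are broken alphabetically
  // here only to make the sketch deterministic (an assumption of this sketch).
  def fitLabels(values: Seq[String]): Array[String] =
    values
      .groupBy(identity)
      .map { case (label, occurrences) => (label, occurrences.size) }
      .toSeq
      .sortBy { case (label, count) => (-count, label) }
      .map(_._1)
      .toArray

  // Replace each value with the position of its label, as a Double.
  def index(values: Seq[String], labels: Array[String]): Seq[Double] = {
    val labelToIndex = labels.zipWithIndex.toMap
    values.map(v => labelToIndex(v).toDouble)
  }

  def main(args: Array[String]): Unit = {
    val category = Seq("a", "b", "c", "a", "a", "c")
    val labels = fitLabels(category)
    println(labels.mkString("[", ", ", "]"))          // [a, c, b]
    println(index(category, labels).mkString(", "))   // 0.0, 2.0, 1.0, 0.0, 0.0, 1.0
  }
}
```

The labels array is the only state that comes out of fitting; everything else is a lookup into it.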

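For string values that did not appear in the training data, Spark offers several handling strategies:

error: throw an exception (the default)

skip: drop the row

keep: map every unseen value to one extra index, one past the largest fitted index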
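The three strategies can be sketched in plain Scala as follows. This is only an illustration of the behaviour described above, not Spark's code: HandleInvalidSketch, indexOne and indexAll are hypothetical names, and I throw IllegalArgumentException where Spark raises its own exception type.

```scala
object HandleInvalidSketch {
  // Returns Some(index) for a row that is kept, None for a row that is dropped.
  def indexOne(label: String, labels: Array[String], handleInvalid: String): Option[Double] = {
    val position = labels.indexOf(label)
    if (position >= 0) Some(position.toDouble)
    else handleInvalid match {
      case "keep"  => Some(labels.length.toDouble) // all unseen values share one extra index
      case "skip"  => None                         // the row is filtered out
      case "error" => throw new IllegalArgumentException(s"Unseen label: $label")
      case other   => throw new IllegalArgumentException(s"Unknown handleInvalid: $other")
    }
  }

  // Index a whole column, keeping the original value next to its index.
  def indexAll(values: Seq[String], labels: Array[String], handleInvalid: String): Seq[(String, Double)] =
    values.flatMap(v => indexOne(v, labels, handleInvalid).map(i => (v, i)))
}
```

With labels [a, c, b], "keep" maps both "e" and "f" to 3.0, "skip" returns only the rows whose label was seen, and "error" fails on the first unseen value.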

Here is a code sample based on Spark MLlib 2.2.0:

package xingoo.ml.features.tranformer

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.feature.StringIndexer

object StringIndexerTest {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("string-indexer").getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    // Training data: the labels are learned from this DataFrame
    val df = spark.createDataFrame(
      Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
    ).toDF("id", "category")

    // Data to transform: contains the unseen values "e" and "f"
    val df1 = spark.createDataFrame(
      Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "e"), (5, "f"))
    ).toDF("id", "category")

    val indexer = new StringIndexer()
      .setInputCol("category")
      .setOutputCol("categoryIndex")
      .setHandleInvalid("keep") // one of "skip", "keep", "error"

    val model = indexer.fit(df)
    val indexed = model.transform(df1)
    indexed.show(false)
  }
}

The result is:

+---+--------+-------------+
|id |category|categoryIndex|
+---+--------+-------------+
|0  |a       |0.0          |
|1  |b       |2.0          |
|2  |c       |1.0          |
|3  |a       |0.0          |
|4  |e       |3.0          |
|5  |f       |3.0          |
+---+--------+-------------+

IndexToString: index back to string

Converting indices back to strings only works in combination with a preceding StringIndexer:

package xingoo.ml.features.tranformer

import org.apache.spark.ml.attribute.Attribute
import org.apache.spark.ml.feature.{IndexToString, StringIndexer}
import org.apache.spark.sql.SparkSession

object IndexToString2 {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("dct").getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    val df = spark.createDataFrame(Seq(
      (0, "a"),
      (1, "b"),
      (2, "c"),
      (3, "a"),
      (4, "a"),
      (5, "c")
    )).toDF("id", "category")

    val indexer = new StringIndexer()
      .setInputCol("category")
      .setOutputCol("categoryIndex")
      .fit(df)
    val indexed = indexer.transform(df)

    println(s"Transformed string column '${indexer.getInputCol}' " +
      s"to indexed column '${indexer.getOutputCol}'")
    indexed.show()

    val inputColSchema = indexed.schema(indexer.getOutputCol)
    println(s"StringIndexer will store labels in output column metadata: " +
      s"${Attribute.fromStructField(inputColSchema).toString}\n")

    val converter = new IndexToString()
      .setInputCol("categoryIndex")
      .setOutputCol("originalCategory")
    val converted = converter.transform(indexed)

    println(s"Transformed indexed column '${converter.getInputCol}' back to original string " +
      s"column '${converter.getOutputCol}' using labels in metadata")
    converted.select("id", "categoryIndex", "originalCategory").show()
  }
}

The output is:

Transformed string column 'category' to indexed column 'categoryIndex'
+---+--------+-------------+
| id|category|categoryIndex|
+---+--------+-------------+
|  0|       a|          0.0|
|  1|       b|          2.0|
|  2|       c|          1.0|
|  3|       a|          0.0|
|  4|       a|          0.0|
|  5|       c|          1.0|
+---+--------+-------------+

StringIndexer will store labels in output column metadata: {"vals":["a","c","b"],"type":"nominal","name":"categoryIndex"}

Transformed indexed column 'categoryIndex' back to original string column 'originalCategory' using labels in metadata
+---+-------------+----------------+
| id|categoryIndex|originalCategory|
+---+-------------+----------------+
|  0|          0.0|               a|
|  1|          2.0|               b|
|  2|          1.0|               c|
|  3|          0.0|               a|
|  4|          0.0|               a|
|  5|          1.0|               c|
+---+-------------+----------------+

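The problem

Suppose the processing is complicated and ends up producing a brand-new DataFrame. How do you use IndexToString to turn that DataFrame's indices back into the original strings? Let's try: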

package xingoo.ml.features.tranformer

import org.apache.spark.ml.feature.{IndexToString, StringIndexer}
import org.apache.spark.sql.SparkSession

object IndexToString3 {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("dct").getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    val df = spark.createDataFrame(Seq(
      (0, "a"),
      (1, "b"),
      (2, "c"),
      (3, "a"),
      (4, "a"),
      (5, "c")
    )).toDF("id", "category")

    // A freshly built DataFrame that only holds indices
    val df2 = spark.createDataFrame(Seq(
      (0, 2.0),
      (1, 1.0),
      (2, 1.0),
      (3, 0.0)
    )).toDF("id", "index")

    val indexer = new StringIndexer()
      .setInputCol("category")
      .setOutputCol("categoryIndex")
      .fit(df)
    val indexed = indexer.transform(df)

    val converter = new IndexToString()
      .setInputCol("categoryIndex")
      .setOutputCol("originalCategory")
    val converted = converter.transform(df2)
    converted.show()
  }
}

Running it throws an exception:

18/07/05 20:20:32 INFO StateStoreCoordinatorRef: Registered StateStoreCoordinator endpoint
Exception in thread "main" java.lang.IllegalArgumentException: Field "categoryIndex" does not exist.
    at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:266)
    at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:266)
    at scala.collection.MapLike$class.getOrElse(MapLike.scala:128)
    at scala.collection.AbstractMap.getOrElse(Map.scala:59)
    at org.apache.spark.sql.types.StructType.apply(StructType.scala:265)
    at org.apache.spark.ml.feature.IndexToString.transformSchema(StringIndexer.scala:338)
    at org.apache.spark.ml.PipelineStage.transformSchema(Pipeline.scala:74)
    at org.apache.spark.ml.feature.IndexToString.transform(StringIndexer.scala:352)
    at xingoo.ml.features.tranformer.IndexToString3$.main(IndexToString3.scala:37)
    at xingoo.ml.features.tranformer.IndexToString3.main(IndexToString3.scala)

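Why is that? The immediate cause is in the message itself: df2 has no column named categoryIndex, since its index column is called index. Renaming the input column would not be enough, though. As far as I can tell from the source, the conversion would still fail, because df2's column does not carry the labels that IndexToString needs. Let's follow the source to see why.

Source code analysis

First we create a DataFrame holding the raw data: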

val df = spark.createDataFrame(Seq(
  (0, "a"),
  (1, "b"),
  (2, "c"),
  (3, "a"),
  (4, "a"),
  (5, "c")
)).toDF("id", "category")

Then we create the corresponding StringIndexer:

val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .setHandleInvalid("skip")
  .fit(df)

The fit call here is what trains the transformer. Stepping into fit():

override def fit(dataset: Dataset[_]): StringIndexerModel = {
  transformSchema(dataset.schema, logging = true)
  // Cast the input column to string, then count how often each string occurs
  val counts = dataset.na.drop(Array($(inputCol))).select(col($(inputCol)).cast(StringType))
    .rdd
    .map(_.getString(0))
    .countByValue()
  // counts is a map; for our data it holds {a->3, b->1, c->2}
  val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray
  // Sorted by count in descending order, this yields the array [a, c, b]
  // Store the labels and return the model (Spark ML models all follow this
  // fit-returns-a-model pattern, borrowed from scikit-learn)
  copyValues(new StringIndexerModel(uid, labels).setParent(this))
}

That gives us an array with the contents [a, c, b]. Next, transform performs the conversion:

val indexed = indexer.transform(df)

As you would expect, transform uses this array to convert the column in every row, but it also does something else:

override def transform(dataset: Dataset[_]): DataFrame = {
  ...
  // --------
  // Build a Metadata object from the labels. This is the crucial part!
  // The metadata is essentially a map with the contents:
  // {"ml_attr":{"vals":["a","c","b"],"type":"nominal","name":"categoryIndex"}}
  // --------
  val metadata = NominalAttribute.defaultAttr
    .withName($(outputCol)).withValues(filteredLabels).toMetadata()

  // With "skip", rows with unseen values are filtered out
  ...

  // The UDF below converts the column for each case; the logic is simple
  val indexer = udf { label: String =>
    ...
    if (labelToIndex.contains(label)) {
      labelToIndex(label) // known label: convert it
    } else if (keepInvalid) {
      labels.length // "keep": return one past the largest index (the array length)
    } else {
      ... // "error": throw an exception
    }
  }

  // Keep all existing columns and add one new column,
  // setting the Metadata on that column's StructField. This is the key step!
  filteredDataset.select(col("*"),
    indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata))
}

See it? This is the key point: the StructField of the newly added column is given a Metadata. Normally that Metadata is empty ({}), but here it carries the labels array.
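To check whether a given column actually carries those labels, a small helper like the one below can be handy. It is a sketch of my own (ColumnLabels and labelsOf are hypothetical names) built on Attribute.fromStructField, the same call the earlier example uses to print the metadata. I am assuming the column is numeric, since fromStructField rejects other types.

```scala
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.sql.DataFrame

object ColumnLabels {
  // Returns the labels stored in a numeric column's ML metadata, if there are any.
  // A column without "ml_attr" metadata does not resolve to a NominalAttribute,
  // so it falls through to None.
  def labelsOf(df: DataFrame, column: String): Option[Array[String]] =
    Attribute.fromStructField(df.schema(column)) match {
      case nominal: NominalAttribute => nominal.values
      case _                         => None
    }
}
```

For the DataFrames in this post, ColumnLabels.labelsOf(indexed, "categoryIndex") should return the array [a, c, b], while ColumnLabels.labelsOf(df2, "index") should return None, which is exactly the situation that trips up IndexToString.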

Next, let's see how IndexToString uses it. IndexToString is a Transformer, so it only has a transform method:

override def transform(dataset: Dataset[_]): DataFrame = {
  transformSchema(dataset.schema, logging = true)
  val inputColSchema = dataset.schema($(inputCol))
  // If the labels array is empty use column metadata
  // This is the key part:
  // if labels were set on the IndexToString, they are used directly;
  // otherwise they are read from the Metadata in the StructField
  // of the DataFrame that was passed in
  val values = if (!isDefined(labels) || $(labels).isEmpty) {
    Attribute.fromStructField(inputColSchema)
      .asInstanceOf[NominalAttribute].values.get
  } else {
    $(labels)
  }
  // Use these values to map each index back to its string
  val indexer = udf { index: Double =>
    val idx = index.toInt
    if (0 <= idx && idx < values.length) {
      values(idx)
    } else {
      throw new SparkException(s"Unseen index: $index ??")
    }
  }
  val outputColName = $(outputCol)
  dataset.select(col("*"),
    indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName))
}

Once you understand how StringIndexer and IndexToString work together, there are two ways to deal with the problem.

1 Attach the Metadata to the StructField

import org.apache.spark.sql.functions.col

// Copy the index column under a new name and attach the metadata
// taken from the column that StringIndexer produced
val df2 = spark.createDataFrame(Seq(
  (0, 2.0),
  (1, 1.0),
  (2, 1.0),
  (3, 0.0)
)).toDF("id", "index")
  .select(col("*"), col("index").as("formated_index", indexed.schema("categoryIndex").metadata))

val converter = new IndexToString()
  .setInputCol("formated_index")
  .setOutputCol("origin_col")
val converted = converter.transform(df2)
converted.show(false)

+---+-----+--------------+----------+
|id |index|formated_index|origin_col|
+---+-----+--------------+----------+
|0  |2.0  |2.0           |b         |
|1  |1.0  |1.0           |c         |
|2  |1.0  |1.0           |c         |
|3  |0.0  |0.0           |a         |
+---+-----+--------------+----------+

2 Take the labels from the DataFrame that StringIndexer produced

val df3 = spark.createDataFrame(Seq(
  (0, 2.0),
  (1, 1.0),
  (2, 1.0),
  (3, 0.0)
)).toDF("id", "index")

// Read the labels out of the metadata and pass them to IndexToString explicitly
val converter2 = new IndexToString()
  .setInputCol("index")
  .setOutputCol("origin_col")
  .setLabels(indexed.schema("categoryIndex").metadata.getMetadata("ml_attr").getStringArray("vals"))
val converted2 = converter2.transform(df3)
converted2.show(false)

+---+-----+----------+
|id |index|origin_col|
+---+-----+----------+
|0  |2.0  |b         |
|1  |1.0  |c         |
|2  |1.0  |c         |
|3  |0.0  |a         |
+---+-----+----------+

Both approaches produce the correct output.
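A variation on the second approach, if the fitted StringIndexerModel is still at hand, is to take the labels from the model instead of digging them out of the DataFrame metadata. The sketch below is my own wrapper (DecodeWithModelLabels and decode are hypothetical names) around the model's public labels field, which is what fit() stored in the source shown earlier. I believe newer Spark 3.x releases deprecate labels in favour of labelsArray, so treat this as written for the 2.2.0 API used in this post.

```scala
import org.apache.spark.ml.feature.{IndexToString, StringIndexerModel}
import org.apache.spark.sql.DataFrame

object DecodeWithModelLabels {
  // Map the indices in `indexCol` back to strings using the labels
  // that the fitted StringIndexerModel holds.
  def decode(model: StringIndexerModel, df: DataFrame, indexCol: String, outputCol: String): DataFrame =
    new IndexToString()
      .setInputCol(indexCol)
      .setOutputCol(outputCol)
      .setLabels(model.labels)
      .transform(df)
}
```

With the names used above, DecodeWithModelLabels.decode(indexer, df3, "index", "origin_col") should give the same result as the second approach, without needing the indexed DataFrame at all.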

The complete code is on GitHub:

https://github.com/xinghalo/spark-in-action/blob/master/src/xingoo/ml/features/tranformer/IndexToStringTest.scala

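In the end I still recommend reading the official documentation closely, although it is honestly rather thin in places. To understand how things really work, you have to sit down and read the source.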
