我在解决一个机器学习问题的数据清理步骤中,试图将长尾中的所有元素归类到一个名为“其他”的共同类别中。例如,我有一个如下所示的数据框:
val df = sc.parallelize(Seq((1, "ABC"),(2, "ABC"),(3, "123"),(4, "FPK"),(5, "FPK"),(6, "ABC"),(7, "ABC"),(8, "980"),(9, "abc"),(10, "FPK"))).toDF("n", "s")
我想保留"ABC"
和"FPK"
类别,因为它们多次出现,但我不想为只出现一次的123,980,abc
设置不同的类别。所以我希望得到的是:
+---+------+| n| s|+---+------+| 1| ABC|| 2| ABC|| 3|Others|| 4| FPK|| 5| FPK|| 6| ABC|| 7| ABC|| 8|Others|| 9|Others|| 10| FPK|+---+------+
为了实现这个目标,我尝试了以下方法:
val newDF = df.withColumn("s",when($"s".isin("123","980","abc"),"Others").otherwise('s)
这样做是有效的。
但我想以编程方式决定哪些类别属于长尾,在我的例子中,就是在原始数据框中只出现一次的类别。所以我编写了以下代码来创建一个只包含这些类别的数据框:
val longTail = df.groupBy("s").agg(count("*").alias("cnt")).orderBy($"cnt".desc).filter($"cnt"<2)+---+---+| s|cnt|+---+---+|980| 1||abc| 1||123| 1|+---+---+
现在我想将这个longTail数据集中“s”列的值转换为列表,以便替换之前硬编码的列表。所以我尝试了以下方法:
val ar = longTail.select("s").collect().map(_(0)).toList
ar: List[Any] = List(123, 980, abc)
但当我尝试添加ar时
val newDF = df.withColumn("s",when($"s".isin(ar),"Others").otherwise('s))
我得到了以下错误:
java.lang.RuntimeException: Unsupported literal type class scala.collection.immutable.$colon$colon List(123, 980, abc)
我错过了什么?
回答:
这是正确的语法:
scala> df.withColumn("s", when($"s".isin(ar : _*), "Others").otherwise('s)).show+---+------+| n| s|+---+------+| 1| ABC|| 2| ABC|| 3|Others|| 4| FPK|| 5| FPK|| 6| ABC|| 7| ABC|| 8|Others|| 9|Others|| 10| FPK|+---+------+
这被称为重复参数。参见这里。