Spark MLlib Tutorial
Author: 时海
Data Discretization

Converting continuous data into discrete data with Bucketizer:


import org.apache.spark.ml.feature.Bucketizer
import org.apache.spark.sql.SparkSession

object Test {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("Test")
      .master("local[4]")
      .getOrCreate()

    // 11 split points define 10 buckets: [0,1), [1,2) ... [9,10]
    // (only the last bucket is closed on the right)
    val splits = Array(0.0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)

    // Generate 10 random numbers in [0, 10)
    val data: Array[Double] = Array.fill(10)(scala.util.Random.nextDouble() * 10)

    // Build a single-column DataFrame from the raw values
    val dataFrame = spark.createDataFrame(data.map(Tuple1.apply)).toDF("oriData")

    val bucketizer = new Bucketizer()
      .setInputCol("oriData")
      .setOutputCol("group")
      .setSplits(splits)

    // Transform the raw values into bucket indices
    val group = bucketizer.transform(dataFrame)

    group.show(10, truncate = false)

    spark.stop()
  }
}

Output:
+------------------+-----+
|oriData           |group|
+------------------+-----+
|4.236412924972437 |4.0  |
|9.679094094908889 |9.0  |
|5.0415969197782875|5.0  |
|6.8388073417251585|6.0  |
|2.808077999205735 |2.0  |
|7.441193843816235 |7.0  |
|3.2421799004124985|3.0  |
|8.438279370242505 |8.0  |
|3.8095708765069567|3.0  |
|7.391096562456897 |7.0  |
+------------------+-----+
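
By default, Bucketizer throws an error when a value falls outside the range covered by the splits. A minimal sketch of two common ways around this, reusing the dataFrame from above (the names openSplits and openBucketizer are just for illustration): unbounded boundary splits, and the handleInvalid parameter:

// Infinite boundary splits guarantee every finite value lands in a bucket:
// (-inf, 0), [0, 5), [5, +inf)
val openSplits = Array(Double.NegativeInfinity, 0.0, 5.0, Double.PositiveInfinity)

val openBucketizer = new Bucketizer()
  .setInputCol("oriData")
  .setOutputCol("group")
  .setSplits(openSplits)
  // "keep" routes invalid values (e.g. NaN) into an extra bucket;
  // the alternatives are "error" (the default) and "skip"
  .setHandleInvalid("keep")

openBucketizer.transform(dataFrame).show(10, truncate = false)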


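If the split points are not known up front, Spark ML's QuantileDiscretizer can estimate them from the data using approximate quantiles, producing roughly equal-frequency buckets. A brief sketch reusing the dataFrame above (quantileGroup is an arbitrary output column name):

import org.apache.spark.ml.feature.QuantileDiscretizer

// fit() computes approximate-quantile split points from the data
// and returns a Bucketizer parameterized with them
val discretizer = new QuantileDiscretizer()
  .setInputCol("oriData")
  .setOutputCol("quantileGroup")
  .setNumBuckets(5)

discretizer.fit(dataFrame).transform(dataFrame).show(10, truncate = false)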