Spark MLlib 教程
作者: 时海
线性回归

直接上代码:下面的示例随机生成 100 条满足 y = x1 + 0.5 * x3 的样本,以 x1、x3 作为特征训练线性回归模型,并在测试集上输出预测结果。

package org.apache.spark.examples.sql

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}

import scala.util.Random


/**
 * Minimal linear-regression walkthrough using the Spark ML DataFrame API.
 *
 * Generates synthetic rows in which `y = x1 + 0.5 * x3` (rounded to two decimals), fits a
 * [[LinearRegression]] model on the `x1` and `x3` columns, prints the training summary and
 * shows the predictions for the held-out split.
 */
object RegExample {

  import java.util.Locale

  /** Number of synthetic samples produced by [[createData]]. */
  private val SampleCount = 100

  /**
   * Entry point: builds a local Spark session, trains the model and prints the results.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession
      .builder
      .appName("Count")
      .master("local[4]")
      .getOrCreate()

    // Stop the session even when a step below throws, so the local Spark context is released.
    try {
      val sc = spark.sparkContext

      val schema = StructType(List(
        StructField("x1", DoubleType, nullable = false),
        StructField("x2", DoubleType, nullable = true),
        StructField("x3", DoubleType, nullable = true),
        StructField("y", DoubleType, nullable = true)
      ))

      // Each generated line is "x1,x2,x3,y"; parse it into a Row matching the schema above.
      val rows = createData.map { line =>
        val fields = line.split(",")
        Row(fields(0).toDouble, fields(1).toDouble, fields(2).toDouble, fields(3).toDouble)
      }

      val df = spark.createDataFrame(sc.makeRDD(rows), schema)

      // Random split with weights 0.9 / 0.1 (unseeded, so the split differs between runs).
      val Array(trainData, testData) = df.randomSplit(Array(0.9, 0.1))

      // Only x1 and x3 are used as features; x2 does not contribute to y in createData.
      val assembler = new VectorAssembler()
        .setInputCols(Array("x1", "x3"))
        .setOutputCol("features")

      // Train the model
      val trainFeatures = assembler.transform(trainData)
      val lr = new LinearRegression()
        .setFeaturesCol("features")
        .setLabelCol("y")
        .setFitIntercept(true)
      val model = lr.fit(trainFeatures)

      // Print all model parameters
      println(model.extractParamMap())
      println(s"系数: ${model.coefficients} ,截距: ${model.intercept}")

      // Evaluate the model on the training data
      val trainingSummary = model.summary
      println(s"numIterations: ${trainingSummary.totalIterations}")
      println(s"objectiveHistory: ${trainingSummary.objectiveHistory.toList}")
      trainingSummary.residuals.show()
      println(s"RMSE: ${trainingSummary.rootMeanSquaredError}")
      println(s"r2: ${trainingSummary.r2}")

      // Predict on the test split
      val testFeatures = assembler.transform(testData)
      val predictions = model.transform(testFeatures)
      predictions.select("y", "prediction").show(false)
    } finally {
      spark.stop()
    }
  }

  /**
   * Generates the synthetic data set.
   *
   * @return an array of 100 lines, each formatted as `x1,x2,x3,y`, where `x1`, `x2` and `x3`
   *         are random values rounded to two decimals and `y = x1 + 0.5 * x3`, also rounded
   *         to two decimals
   */
  def createData: Array[String] = {
    val rand = new Random(System.currentTimeMillis())

    Array.fill(SampleCount) {
      val x1 = roundToTwoDecimals(1 + rand.nextDouble() * 5)
      val x2 = roundToTwoDecimals(10 + rand.nextDouble() * 10)
      val x3 = roundToTwoDecimals(1 + rand.nextDouble() * 20)
      val y = roundToTwoDecimals(x1 + x3 * 0.5)

      s"$x1,$x2,$x3,$y"
    }
  }

  /**
   * Rounds `value` to two decimal places.
   *
   * Formats with `Locale.ROOT` so the decimal separator is always `.`; formatting with the
   * default locale can produce a comma, which would break both `toDouble` and the
   * comma-separated row format.
   */
  private def roundToTwoDecimals(value: Double): Double =
    "%.2f".formatLocal(Locale.ROOT, value).toDouble

}

待续

标签: var、todouble、trainingsummary、val、x3
一个创业中的苦逼程序员
  • 回复
隐藏