Spark MLlib 教程
直接上代码（使用 spark.ml 的 LinearRegression 在合成数据上做线性回归）：
package org.apache.spark.examples.sql

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}

import scala.util.Random

/**
 * Minimal Spark ML linear-regression walkthrough.
 *
 * Generates a synthetic dataset where `y = x1 + 0.5 * x3`. The column `x2` does not
 * influence `y` and is deliberately left out of the feature vector. The program fits a
 * [[LinearRegression]] (with intercept) on a random 90% split, prints the fitted parameters
 * and training metrics, and shows predictions on the remaining ~10%.
 */
object RegExample {

  /** Number of synthetic rows produced by [[createData]]. */
  private val NumRows = 100

  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession
      .builder
      .appName("Count")
      .master("local[4]")
      .getOrCreate()

    // Always release the session, even if training or evaluation throws.
    try run(spark)
    finally spark.stop()
  }

  /** Builds the DataFrame, trains the model, prints metrics, and shows test predictions. */
  private def run(spark: SparkSession): Unit = {
    val sc = spark.sparkContext

    val schema = StructType(List(
      StructField("x1", DoubleType, nullable = false),
      StructField("x2", DoubleType, nullable = true),
      StructField("x3", DoubleType, nullable = true),
      StructField("y", DoubleType, nullable = true)
    ))

    // Parse each "x1,x2,x3,y" line produced by createData into a Row matching `schema`.
    val rows = createData.map { line =>
      val fields = line.split(",")
      Row(fields(0).toDouble, fields(1).toDouble, fields(2).toDouble, fields(3).toDouble)
    }
    val df = spark.createDataFrame(sc.makeRDD(rows), schema)
    // Unseeded split: the train/test partition differs on every run.
    val Array(trainData, testData) = df.randomSplit(Array(0.9, 0.1))

    // Only x1 and x3 contribute to y, so x2 is intentionally excluded from the features.
    val assembler = new VectorAssembler()
      .setInputCols(Array("x1", "x3"))
      .setOutputCol("features")

    // Train the model.
    val trainFeatures = assembler.transform(trainData)
    val lr = new LinearRegression()
      .setFeaturesCol("features")
      .setLabelCol("y")
      .setFitIntercept(true)
    val model = lr.fit(trainFeatures)

    // Print every model parameter, then the learned coefficients and intercept.
    println(model.extractParamMap())
    println(s"系数: ${model.coefficients} ,截距: ${model.intercept}")

    // Evaluate the model on its training data.
    val trainingSummary = model.summary
    println(s"numIterations: ${trainingSummary.totalIterations}")
    println(s"objectiveHistory: ${trainingSummary.objectiveHistory.toList}")
    trainingSummary.residuals.show()
    println(s"RMSE: ${trainingSummary.rootMeanSquaredError}")
    println(s"r2: ${trainingSummary.r2}")

    // Predict on the held-out split.
    val testFeatures = assembler.transform(testData)
    val predict = model.transform(testFeatures)
    predict.select("y", "prediction").show(false)
  }

  /**
   * Generates [[NumRows]] synthetic samples, each formatted as `"x1,x2,x3,y"`.
   *
   * The inputs are `x1 = 1 + 5u`, `x2 = 10 + 10u` and `x3 = 1 + 20u`, where each `u` is a
   * fresh `nextDouble()` draw. The label is `y = x1 + 0.5 * x3`. Every value is rounded to
   * two decimals. The RNG is seeded from the current time, so each call returns different
   * data.
   *
   * @return one comma-separated line per sample
   */
  def createData: Array[String] = {
    val rand = new Random(System.currentTimeMillis())
    // Array.fill evaluates its element expression once per slot, in order, so the RNG is
    // consumed in the same sequence (x1, x2, x3 per row) as an index-based loop would.
    Array.fill(NumRows) {
      val x1 = round2(1 + rand.nextDouble() * 5)
      val x2 = round2(10 + rand.nextDouble() * 10)
      val x3 = round2(1 + rand.nextDouble() * 20)
      val y = round2(x1 + x3 * 0.5)
      s"$x1,$x2,$x3,$y"
    }
  }

  /**
   * Rounds `v` half-up to two decimal places.
   *
   * This avoids `formatted("%.2f").toDouble`, which formats with the JVM default locale.
   * Under a comma-decimal locale that produces e.g. "1,50", and `toDouble` then throws
   * `NumberFormatException`.
   */
  private def round2(v: Double): Double =
    BigDecimal(v).setScale(2, BigDecimal.RoundingMode.HALF_UP).toDouble
}
待续