Dataset: Qualitative_Bankruptcy Data Set
Dataset download link: Download
The dataset contains six features and one label, all of which are qualitative (categorical) values.
1. Industrial Risk: {P,A,N}
2. Management Risk: {P,A,N}
3. Financial Flexibility: {P,A,N}
4. Credibility: {P,A,N}
5. Competitiveness: {P,A,N}
6. Operating Risk: {P,A,N}
7. Class: {B,NB}
(P=Positive, A=Average, N=Negative, B=Bankruptcy, NB=Non-Bankruptcy)
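We train a random forest model that maps the first six features to the label, and use it to predict whether a company will go bankrupt.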
Sample rows from the downloaded data:
P,P,A,A,A,P,NB
N,N,A,A,A,N,NB
A,A,A,A,A,A,NB
P,P,P,P,P,P,NB
N,N,P,P,P,N,NB
So that Spark can build a DataFrame with named columns directly, we add a header row to the dataset:
irisk,mrisk,fina,cred,comp,orisk,label
P,P,A,A,A,P,NB
N,N,A,A,A,N,NB
A,A,A,A,A,A,NB
P,P,P,P,P,P,NB
N,N,P,P,P,N,NB
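Using the Spark random forest classifier involves the following steps:
data loading --> data preprocessing --> feature transformation --> feature selection --> model training --> model evaluation
Complete example code: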
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.sql.SparkSession

object RandomForestClassifierTest {

  def main(args: Array[String]): Unit = {

    // Create the SparkSession
    val spark = SparkSession
      .builder()
      .appName("RandomForestClassifierTest")
      .master("local[4]")
      .getOrCreate()

    // Read the CSV data; header=true means the first line holds the column names.
    // Note: Spark does not expand "~" the way a shell does; use an absolute path
    // if the file is not found.
    val oriData = spark.read.option("header", "true")
      .csv("~/Qualitative_Bankruptcy.data.txt")

    // Print the schema of the data:
    // root
    //  |-- irisk: string (nullable = true)
    //  |-- mrisk: string (nullable = true)
    //  |-- fina: string (nullable = true)
    //  |-- cred: string (nullable = true)
    //  |-- comp: string (nullable = true)
    //  |-- orisk: string (nullable = true)
    //  |-- label: string (nullable = true)
    oriData.printSchema()

    // Show the first 5 rows:
    // +-----+-----+----+----+----+-----+-----+
    // |irisk|mrisk|fina|cred|comp|orisk|label|
    // +-----+-----+----+----+----+-----+-----+
    // |P    |P    |A   |A   |A   |P    |NB   |
    // |N    |N    |A   |A   |A   |N    |NB   |
    // |A    |A    |A   |A   |A   |A    |NB   |
    // |P    |P    |P   |P   |P   |P    |NB   |
    // |N    |N    |P   |P   |P   |N    |NB   |
    // +-----+-----+----+----+----+-----+-----+
    oriData.show(5, false)

    // Feature preprocessing: convert each string column to a numeric index
    val indexedIrisk = new StringIndexer().setInputCol("irisk").setOutputCol("indexedIrisk").fit(oriData)
    val indexedMrisk = new StringIndexer().setInputCol("mrisk").setOutputCol("indexedMrisk").fit(oriData)
    val indexedFina = new StringIndexer().setInputCol("fina").setOutputCol("indexedFina").fit(oriData)
    val indexedCred = new StringIndexer().setInputCol("cred").setOutputCol("indexedCred").fit(oriData)
    val indexedComp = new StringIndexer().setInputCol("comp").setOutputCol("indexedComp").fit(oriData)
    val indexedOrisk = new StringIndexer().setInputCol("orisk").setOutputCol("indexedOrisk").fit(oriData)

    // Assemble all indexed features into a single Vector column
    val featuresAssembler = new VectorAssembler()
      .setInputCols(Array("indexedIrisk", "indexedMrisk", "indexedFina",
        "indexedCred", "indexedComp", "indexedOrisk"))
      .setOutputCol("features")

    // Label conversion: convert the string class to a numeric class
    val indexedLabel = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(oriData)

    // Chain the preprocessing stages into a pipeline so that they run in order
    val pipeline = new Pipeline().setStages(Array(indexedIrisk, indexedMrisk, indexedFina,
      indexedCred, indexedComp, indexedOrisk, featuresAssembler, indexedLabel))
    val preModel = pipeline.fit(oriData)
    val preData = preModel.transform(oriData)

    // Split the preprocessed data: 70% to train the random forest, 30% to test it
    val Array(trainData, testData) = preData.randomSplit(Array(0.7, 0.3))

    // Create the random forest classifier
    val rf = new RandomForestClassifier()
      .setLabelCol("indexedLabel")
      .setFeaturesCol("features")
      .setNumTrees(10)

    // Train the random forest model
    val rfModel = rf.fit(trainData)

    // Use the trained model to predict on the test samples
    val predictions = rfModel.transform(testData)

    // Create the evaluator; here we only look at the accuracy metric
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabel")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")

    // Evaluate the predictions:
    // accuracy=1.0
    val accuracy = evaluator.evaluate(predictions)
    println("accuracy=" + accuracy)

    // Print the trained random forest model
    println(rfModel.toDebugString)

    spark.stop()
  }
}
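To make the "string to number" step in the code above more concrete, here is a minimal plain-Scala sketch of what StringIndexer does with its default ordering: the most frequent value gets index 0.0, the next one 1.0, and so on. The object and function names (`StringIndexerSketch`, `fitIndex`) are hypothetical, and breaking ties alphabetically is an assumption made so the sketch is deterministic. It only mimics the idea and is not Spark's implementation.

```scala
object StringIndexerSketch {

  // Build a value -> index mapping ordered by descending frequency.
  // Ties are broken alphabetically (an assumption of this sketch).
  def fitIndex(values: Seq[String]): Map[String, Double] =
    values
      .groupBy(identity)
      .toSeq
      .map { case (value, occurrences) => (value, occurrences.size) }
      .sortBy { case (value, count) => (-count, value) }
      .map { case (value, _) => value }
      .zipWithIndex
      .map { case (value, index) => value -> index.toDouble }
      .toMap

  def main(args: Array[String]): Unit = {
    // The irisk column of the five sample rows shown earlier
    val irisk = Seq("P", "N", "A", "P", "N")
    val mapping = fitIndex(irisk)
    // Apply the mapping to the column, as the fitted indexer would
    println(irisk.map(mapping))
  }
}
```

Because the mapping depends on value frequencies in the data the indexer was fitted on, the indices computed from these five sample rows need not match the ones Spark derives from the full dataset.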
The resulting random forest model:
RandomForestClassificationModel (uid=rfc_254bd6468fa2) with 10 trees
  Tree 0 (weight 1.0):
    If (feature 4 in {1.0,2.0})
     If (feature 3 in {1.0,2.0})
      Predict: 0.0
     Else (feature 3 not in {1.0,2.0})
      If (feature 2 in {1.0,2.0})
       Predict: 0.0
      Else (feature 2 not in {1.0,2.0})
       Predict: 1.0
    Else (feature 4 not in {1.0,2.0})
     Predict: 1.0
  Tree 1 (weight 1.0):
    If (feature 4 in {1.0,2.0})
     If (feature 3 in {1.0,2.0})
      Predict: 0.0
     Else (feature 3 not in {1.0,2.0})
      If (feature 5 in {0.0,1.0})
       Predict: 0.0
      Else (feature 5 not in {0.0,1.0})
       Predict: 1.0
    Else (feature 4 not in {1.0,2.0})
     Predict: 1.0
  Tree 2 (weight 1.0):
    If (feature 4 in {1.0,2.0})
     If (feature 5 in {0.0,1.0})
      Predict: 0.0
     Else (feature 5 not in {0.0,1.0})
      If (feature 0 in {0.0,1.0})
       Predict: 0.0
      Else (feature 0 not in {0.0,1.0})
       If (feature 1 in {1.0})
        Predict: 0.0
       Else (feature 1 not in {1.0})
        If (feature 4 in {1.0})
         Predict: 0.0
        Else (feature 4 not in {1.0})
         Predict: 1.0
    Else (feature 4 not in {1.0,2.0})
     Predict: 1.0
  Tree 3 (weight 1.0):
    If (feature 4 in {1.0,2.0})
     If (feature 2 in {1.0,2.0})
      Predict: 0.0
     Else (feature 2 not in {1.0,2.0})
      If (feature 3 in {1.0,2.0})
       Predict: 0.0
      Else (feature 3 not in {1.0,2.0})
       Predict: 1.0
    Else (feature 4 not in {1.0,2.0})
     Predict: 1.0
  Tree 4 (weight 1.0):
    If (feature 3 in {1.0,2.0})
     If (feature 4 in {1.0,2.0})
      Predict: 0.0
     Else (feature 4 not in {1.0,2.0})
      Predict: 1.0
    Else (feature 3 not in {1.0,2.0})
     If (feature 1 in {1.0,2.0})
      If (feature 4 in {1.0})
       Predict: 0.0
      Else (feature 4 not in {1.0})
       Predict: 1.0
     Else (feature 1 not in {1.0,2.0})
      If (feature 4 in {2.0})
       If (feature 2 in {1.0})
        Predict: 0.0
       Else (feature 2 not in {1.0})
        Predict: 1.0
      Else (feature 4 not in {2.0})
       Predict: 1.0
  Tree 5 (weight 1.0):
    If (feature 3 in {1.0,2.0})
     If (feature 2 in {1.0,2.0})
      Predict: 0.0
     Else (feature 2 not in {1.0,2.0})
      If (feature 4 in {1.0,2.0})
       Predict: 0.0
      Else (feature 4 not in {1.0,2.0})
       Predict: 1.0
    Else (feature 3 not in {1.0,2.0})
     If (feature 0 in {1.0,2.0})
      If (feature 4 in {1.0})
       Predict: 0.0
      Else (feature 4 not in {1.0})
       Predict: 1.0
     Else (feature 0 not in {1.0,2.0})
      Predict: 1.0
  Tree 6 (weight 1.0):
    If (feature 4 in {1.0,2.0})
     If (feature 3 in {1.0,2.0})
      Predict: 0.0
     Else (feature 3 not in {1.0,2.0})
      If (feature 4 in {1.0})
       Predict: 0.0
      Else (feature 4 not in {1.0})
       If (feature 5 in {1.0})
        Predict: 0.0
       Else (feature 5 not in {1.0})
        Predict: 1.0
    Else (feature 4 not in {1.0,2.0})
     Predict: 1.0
  Tree 7 (weight 1.0):
    If (feature 2 in {1.0,2.0})
     If (feature 0 in {1.0,2.0})
      Predict: 0.0
     Else (feature 0 not in {1.0,2.0})
      If (feature 4 in {1.0,2.0})
       Predict: 0.0
      Else (feature 4 not in {1.0,2.0})
       Predict: 1.0
    Else (feature 2 not in {1.0,2.0})
     If (feature 4 in {1.0})
      Predict: 0.0
     Else (feature 4 not in {1.0})
      Predict: 1.0
  Tree 8 (weight 1.0):
    If (feature 2 in {1.0,2.0})
     If (feature 0 in {1.0,2.0})
      Predict: 0.0
     Else (feature 0 not in {1.0,2.0})
      If (feature 1 in {1.0,2.0})
       Predict: 0.0
      Else (feature 1 not in {1.0,2.0})
       If (feature 4 in {1.0,2.0})
        Predict: 0.0
       Else (feature 4 not in {1.0,2.0})
        Predict: 1.0
    Else (feature 2 not in {1.0,2.0})
     If (feature 3 in {1.0})
      If (feature 0 in {1.0,2.0})
       Predict: 0.0
      Else (feature 0 not in {1.0,2.0})
       Predict: 1.0
     Else (feature 3 not in {1.0})
      If (feature 3 in {2.0})
       If (feature 4 in {1.0})
        Predict: 0.0
       Else (feature 4 not in {1.0})
        Predict: 1.0
      Else (feature 3 not in {2.0})
       Predict: 1.0
  Tree 9 (weight 1.0):
    If (feature 2 in {1.0,2.0})
     Predict: 0.0
    Else (feature 2 not in {1.0,2.0})
     If (feature 4 in {1.0,2.0})
      If (feature 3 in {1.0,2.0})
       Predict: 0.0
      Else (feature 3 not in {1.0,2.0})
       Predict: 1.0
     Else (feature 4 not in {1.0,2.0})
      Predict: 1.0
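As a reading aid for the output above, the following sketch hand-transcribes "Tree 9" into a plain Scala function. The feature numbers follow the order of the VectorAssembler input columns (0=irisk, 1=mrisk, 2=fina, 3=cred, 4=comp, 5=orisk), and the values are the numeric indices produced by StringIndexer, not the original P/A/N letters. The names `Tree9Sketch` and `predictTree9` are hypothetical, and the function covers one tree only, not the whole forest.

```scala
object Tree9Sketch {

  // Feature positions, in the order passed to VectorAssembler
  val featureNames: Vector[String] =
    Vector("irisk", "mrisk", "fina", "cred", "comp", "orisk")

  // True when an indexed value belongs to the category set {1.0, 2.0}
  private def inOneOrTwo(value: Double): Boolean = value == 1.0 || value == 2.0

  // Hand transcription of "Tree 9" from the debug string:
  // takes the six indexed feature values and returns the predicted label index.
  def predictTree9(features: Array[Double]): Double = {
    require(features.length == featureNames.length, "expected 6 indexed features")
    val fina = features(2)
    val cred = features(3)
    val comp = features(4)
    if (inOneOrTwo(fina)) 0.0
    else if (inOneOrTwo(comp)) {
      if (inOneOrTwo(cred)) 0.0 else 1.0
    } else 1.0
  }

  def main(args: Array[String]): Unit = {
    // fina index is 1.0, so the first branch applies
    println(predictTree9(Array(0.0, 0.0, 1.0, 0.0, 0.0, 0.0)))
    // all indices are 0.0, so both "not in" branches apply
    println(predictTree9(Array(0.0, 0.0, 0.0, 0.0, 0.0, 0.0)))
  }
}
```

Which letter each index stands for, and which class 0.0 and 1.0 denote, depends on the fitted StringIndexer models, so the mapping has to be read from them rather than assumed.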