码迷,mamicode.com
首页 > 其他好文 > 详细

Spark Random Forest classifier 随机森林分类

时间:2020-03-04 12:28:12      阅读:105      评论:0      收藏:0      [点我收藏+]

标签:uil   表现力   采样   tree   oms   ram   home   port   call   

1、概述

随机森林是决策树的集合。随机森林是用于分类和回归的最成功的机器学习模型之一。他们结合了许多决策树,以减少过度拟合的风险。像决策树一样,随机森林处理分类特征,扩展到多类分类设置,不需要特征缩放,并且能够捕获非线性和特征交互。

spark.mllib支持使用连续和分类功能对二元和多类分类以及进行回归的随机森林。

基础算法

随机森林分别训练一组决策树,因此可以并行进行训练。该算法将随机性注入训练过程中,因此每个决策树都略有不同。合并来自每棵树的预测可以减少预测的方差,从而提高测试数据的性能。

训练

注入训练过程的随机性包括:

    在每次迭代中对原始数据集进行二次采样以获得不同的训练集(也称为自举)。
    考虑要在每个树节点上分割的要素的不同随机子集。

除了这些随机化之外,决策树训练的方式与单个决策树的训练方式相同。

参数

    numTrees:森林中的树木数量。
        增加树的数量将减少预测的方差,从而提高模型的测试时间准确性。
        训练时间在树木数量上大致呈线性增加。
    maxDepth:森林中每棵树的最大深度。
        深度的增加使模型更具表现力和功能。但是,深树需要更长的训练时间,也更容易过度拟合。
        通常,使用随机森林比使用单个决策树训练更深的树是可以接受的。与随机森林相比,一棵树更可能过度拟合(由于对森林中的多棵树进行平均而减少了方差)。

2、code

package com.home.spark.ml

import org.apache.spark.SparkConf
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession


object Ex_RandomForests {
  def main(args: Array[String]): Unit = {
    val conf: SparkConf = new SparkConf(true).setMaster("local[2]").setAppName("spark ml")
    val spark = SparkSession.builder().config(conf).getOrCreate()

    //rdd转换成df或者ds需要SparkSession实例的隐式转换
    //导入隐式转换,注意这里的spark不是包名,而是SparkSession的对象名
    import spark.implicits._

    // Load and parse the data file, converting it to a DataFrame.
//    val data = spark.read.format("libsvm").load("input/sample_libsvm_data.txt")

    val rawData = spark.sparkContext.textFile("input/iris.data.txt")
      .map(_.split(","))
      .map(a=>Iris(
        Vectors.dense(a(0).toDouble, a(1).toDouble, a(2).toDouble, a(3).toDouble),
        a(4))).toDF()

    rawData.createOrReplaceTempView("iris")
    val data = spark.sql("select * from iris")

    // Index labels, adding metadata to the label column.
    // Fit on whole dataset to include all labels in index.
    val labelIndexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("indexedLabel")
      .fit(data)
    // Automatically identify categorical features, and index them.
    // Set maxCategories so features with > 4 distinct values are treated as continuous.
    val featureIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexedFeatures")
      .setMaxCategories(4)
      .fit(data)

    // Split the data into training and test sets (30% held out for testing).
    val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

    // Train a RandomForest model.
    val rf = new RandomForestClassifier()
      .setLabelCol("indexedLabel")
      .setFeaturesCol("indexedFeatures")
      .setNumTrees(10)

    // Convert indexed labels back to original labels.
    val labelConverter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("predictedLabel")
      .setLabels(labelIndexer.labels)

    // Chain indexers and forest in a Pipeline.
    val pipeline = new Pipeline()
      .setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))

    // Train model. This also runs the indexers.
    val model = pipeline.fit(trainingData)

    // Make predictions.
    val predictions = model.transform(testData)

    // Select example rows to display.
    predictions.select("predictedLabel", "label", "features").show(30,false)

    // Select (prediction, true label) and compute test error.
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabel")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
    val accuracy = evaluator.evaluate(predictions)
    println("Test Error = " + (1.0 - accuracy))

    val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel]
    println("Learned classification forest model:\n" + rfModel.toDebugString)



  }

}

 

Spark Random Forest classifier 随机森林分类

标签:uil   表现力   采样   tree   oms   ram   home   port   call   

原文地址:https://www.cnblogs.com/asker009/p/12408593.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!