Object
org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
public class GradientBoostedTreesModel
:: Experimental ::
Represents a gradient boosted trees model.
param: algo algorithm for the ensemble model, either Classification or Regression
param: trees tree ensembles
param: treeWeights tree ensemble weights
| Constructor Summary | |
|---|---|
| GradientBoostedTreesModel(scala.Enumeration.Value algo, DecisionTreeModel[] trees, double[] treeWeights) | |
| Method Summary | | |
|---|---|---|
| scala.Enumeration.Value | algo() | |
| static RDD<scala.Tuple2<Object,Object>> | computeInitialPredictionAndError(RDD<LabeledPoint> data, double initTreeWeight, DecisionTreeModel initTree, Loss loss) | Compute the initial predictions and errors for a dataset for the first iteration of gradient boosting. |
| double[] | evaluateEachIteration(RDD<LabeledPoint> data, Loss loss) | Compute the error or loss for every iteration of gradient boosting. |
| static GradientBoostedTreesModel | load(SparkContext sc, String path) | |
| int | numTrees() | Get the number of trees in the ensemble. |
| JavaRDD<Double> | predict(JavaRDD<Vector> features) | Java-friendly version of TreeEnsembleModel.predict(org.apache.spark.mllib.linalg.Vector). |
| RDD<Object> | predict(RDD<Vector> features) | Predict values for the given data set. |
| double | predict(Vector features) | Predict the value for a single data point using the trained model. |
| void | save(SparkContext sc, String path) | Save this model to the given path. |
| String | toDebugString() | Print the full model to a string. |
| String | toString() | Print a summary of the model. |
| int | totalNumNodes() | Get the total number of nodes, summed over all trees in the ensemble. |
| DecisionTreeModel[] | trees() | |
| double[] | treeWeights() | |
| static RDD<scala.Tuple2<Object,Object>> | updatePredictionError(RDD<LabeledPoint> data, RDD<scala.Tuple2<Object,Object>> predictionAndError, double treeWeight, DecisionTreeModel tree, Loss loss) | Update a zipped predictionError RDD (as obtained with computeInitialPredictionAndError). |
| Methods inherited from class Object |
|---|
equals, getClass, hashCode, notify, notifyAll, wait, wait, wait |
| Constructor Detail |
|---|
public GradientBoostedTreesModel(scala.Enumeration.Value algo,
DecisionTreeModel[] trees,
double[] treeWeights)
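The constructor bundles the algorithm, the trees, and one weight per tree. The plain-Java sketch below mirrors that state without depending on Spark. EnsembleSketch, its Algo enum, and the use of ToDoubleFunction<double[]> as a stand-in for DecisionTreeModel are all hypothetical, and the equal-length check is an assumption about how trees and treeWeights relate, not a guarantee stated on this page.

```java
import java.util.List;
import java.util.Objects;
import java.util.function.ToDoubleFunction;

// Hypothetical, Spark-free stand-in for the state held by GradientBoostedTreesModel.
final class EnsembleSketch {
    enum Algo { CLASSIFICATION, REGRESSION }

    final Algo algo;
    // Each tree is modeled as a function from a feature vector to a prediction
    // (a stand-in for DecisionTreeModel).
    final List<ToDoubleFunction<double[]>> trees;
    final double[] treeWeights;

    EnsembleSketch(Algo algo, List<ToDoubleFunction<double[]>> trees, double[] treeWeights) {
        this.algo = Objects.requireNonNull(algo);
        this.trees = List.copyOf(trees);
        // Assumption: exactly one weight per tree, so the lengths must agree.
        if (this.trees.size() != treeWeights.length) {
            throw new IllegalArgumentException("trees and treeWeights must have the same length");
        }
        this.treeWeights = treeWeights.clone();
    }

    int numTrees() {
        return trees.size();
    }
}
```

Copying both inputs keeps the sketch immutable from the caller's side; whether MLlib copies them is not stated here.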
| Method Detail |
|---|
public static RDD<scala.Tuple2<Object,Object>> computeInitialPredictionAndError(RDD<LabeledPoint> data,
double initTreeWeight,
DecisionTreeModel initTree,
Loss loss)
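Compute the initial predictions and errors for a dataset for the first iteration of gradient boosting.
data - training data.
initTreeWeight - learning rate assigned to the first tree.
initTree - first DecisionTreeModel.
loss - evaluation metric.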
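As we understand it, each sample's initial prediction is the first tree's output scaled by initTreeWeight, and its error is the loss evaluated at that prediction. The Spark-free sketch below shows that per-sample logic with a List standing in for the RDD. The Point and PredError records, the DoubleBinaryOperator loss taking (prediction, label), and the squared-error example loss are illustrative assumptions, not MLlib types.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.DoubleBinaryOperator;
import java.util.function.ToDoubleFunction;

// Spark-free sketch of the per-sample logic of computeInitialPredictionAndError.
final class InitialPredictionSketch {
    // Hypothetical stand-in for LabeledPoint.
    record Point(double label, double[] features) {}

    // Hypothetical stand-in for one (prediction, error) element of the returned RDD.
    record PredError(double prediction, double error) {}

    // Example loss only: squared error, called as loss(prediction, label).
    static final DoubleBinaryOperator SQUARED_ERROR = (pred, label) -> {
        double diff = label - pred;
        return diff * diff;
    };

    static List<PredError> computeInitial(List<Point> data, double initTreeWeight,
                                          ToDoubleFunction<double[]> initTree,
                                          DoubleBinaryOperator loss) {
        List<PredError> out = new ArrayList<>(data.size());
        for (Point p : data) {
            // Scale the first tree's output by its weight (the learning rate).
            double pred = initTreeWeight * initTree.applyAsDouble(p.features());
            out.add(new PredError(pred, loss.applyAsDouble(pred, p.label())));
        }
        return out;
    }
}
```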
public static RDD<scala.Tuple2<Object,Object>> updatePredictionError(RDD<LabeledPoint> data,
RDD<scala.Tuple2<Object,Object>> predictionAndError,
double treeWeight,
DecisionTreeModel tree,
Loss loss)
Update a zipped predictionError RDD (as obtained with computeInitialPredictionAndError).
data - training data.
predictionAndError - predictionError RDD.
treeWeight - learning rate.
tree - tree with which the prediction and error should be updated.
loss - evaluation metric.
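updatePredictionError advances the running predictions by one tree: each sample's prediction gains treeWeight times the new tree's output, and the error is recomputed from the updated prediction. In this sketch a shared list index stands in for Spark's zip of data with predictionAndError (assumption: both lists hold the same samples in the same order). The records are hypothetical stand-ins for LabeledPoint and the (prediction, error) tuple.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.DoubleBinaryOperator;
import java.util.function.ToDoubleFunction;

// Spark-free sketch of the per-sample update in updatePredictionError.
final class UpdatePredictionSketch {
    record Point(double label, double[] features) {}
    record PredError(double prediction, double error) {}

    static List<PredError> update(List<Point> data, List<PredError> predictionAndError,
                                  double treeWeight, ToDoubleFunction<double[]> tree,
                                  DoubleBinaryOperator loss) {
        if (data.size() != predictionAndError.size()) {
            throw new IllegalArgumentException("data and predictionAndError must be aligned");
        }
        List<PredError> out = new ArrayList<>(data.size());
        for (int i = 0; i < data.size(); i++) {
            Point p = data.get(i);
            // Add the new tree's weighted contribution to the running prediction.
            double newPred = predictionAndError.get(i).prediction()
                    + treeWeight * tree.applyAsDouble(p.features());
            // The old error is discarded; it is recomputed from the new prediction.
            out.add(new PredError(newPred, loss.applyAsDouble(newPred, p.label())));
        }
        return out;
    }
}
```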
public static GradientBoostedTreesModel load(SparkContext sc,
String path)
public scala.Enumeration.Value algo()
public DecisionTreeModel[] trees()
public double[] treeWeights()
public void save(SparkContext sc,
String path)
Save this model to the given path.
This saves:
- human-readable (JSON) model metadata to path/metadata/
- Parquet formatted data to path/data/
The model may be loaded using Loader.load.
Specified by: save in interface Saveable
sc - Spark context used to save model data.
path - path specifying the directory in which to save this model. If the directory already exists, this method throws an exception.
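The save contract fixes a layout under path and refuses to overwrite an existing directory. The java.nio sketch below reproduces only that layout and precondition. SaveLayoutSketch and the placeholder metadata file name part-00000 are hypothetical, and it writes no Parquet data, so it is not a substitute for calling save.

```java
import java.io.IOException;
import java.nio.file.FileAlreadyExistsException;
import java.nio.file.Files;
import java.nio.file.Path;

// Hypothetical sketch of the on-disk layout described for save(); it does not
// produce Spark's real JSON metadata or Parquet data.
final class SaveLayoutSketch {
    static void save(Path path, String metadataJson) throws IOException {
        // Mirror the documented rule: fail if the target directory already exists.
        if (Files.exists(path)) {
            throw new FileAlreadyExistsException(path.toString());
        }
        // path/metadata/ holds human-readable JSON metadata.
        Path metadataDir = Files.createDirectories(path.resolve("metadata"));
        Files.writeString(metadataDir.resolve("part-00000"), metadataJson);
        // path/data/ is where Spark writes Parquet; left empty in this sketch.
        Files.createDirectories(path.resolve("data"));
    }
}
```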
public double[] evaluateEachIteration(RDD<LabeledPoint> data,
Loss loss)
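Compute the error or loss for every iteration of gradient boosting.
data - RDD of LabeledPoint.
loss - evaluation metric.
Returns: an array with index i holding the loss or error for the ensemble containing the first i+1 trees.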
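Conceptually, evaluateEachIteration replays the boosting sequence: it starts from the first tree's weighted prediction, adds one weighted tree at a time, and records the mean loss over the data after each step. The plain-Java sketch below follows that description. The remapping of 0/1 classification labels to -1/+1 reflects our understanding of how MLlib's gradient boosting treats binary labels and should be read as an assumption, as should the (prediction, label) argument order of the loss.

```java
import java.util.List;
import java.util.function.DoubleBinaryOperator;
import java.util.function.ToDoubleFunction;

// Spark-free sketch: element i of the result is the mean loss of the
// ensemble made of the first i + 1 trees.
final class EvaluateEachIterationSketch {
    record Point(double label, double[] features) {}

    static double[] evaluateEachIteration(List<Point> data, boolean classification,
                                          List<ToDoubleFunction<double[]>> trees,
                                          double[] treeWeights, DoubleBinaryOperator loss) {
        int n = data.size();
        double[] preds = new double[n]; // running ensemble prediction per sample
        double[] result = new double[trees.size()];
        for (int t = 0; t < trees.size(); t++) {
            double sumError = 0.0;
            for (int i = 0; i < n; i++) {
                Point p = data.get(i);
                // Assumption: 0/1 classification labels are remapped to -1/+1.
                double label = classification ? 2.0 * p.label() - 1.0 : p.label();
                preds[i] += treeWeights[t] * trees.get(t).applyAsDouble(p.features());
                sumError += loss.applyAsDouble(preds[i], label);
            }
            result[t] = sumError / n; // mean loss after t + 1 trees
        }
        return result;
    }
}
```

Keeping the running predictions makes each step cost one tree evaluation per sample, rather than re-evaluating the whole prefix of trees every iteration.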
public double predict(Vector features)
features - array representing a single data point
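For a single point, a gradient boosted ensemble's raw output is the weighted sum of its trees' predictions. In this sketch, regression returns that sum directly and binary classification maps a positive sum to 1.0 and anything else to 0.0. That threshold is our understanding of MLlib's behavior rather than something this page states, and the function-based trees are hypothetical stand-ins for DecisionTreeModel.

```java
import java.util.List;
import java.util.function.ToDoubleFunction;

// Spark-free sketch of single-point prediction for a gradient boosted ensemble.
final class PredictSketch {
    static double predict(double[] features, boolean classification,
                          List<ToDoubleFunction<double[]>> trees, double[] treeWeights) {
        double sum = 0.0;
        for (int t = 0; t < trees.size(); t++) {
            sum += treeWeights[t] * trees.get(t).applyAsDouble(features);
        }
        // Regression: the weighted sum itself. Binary classification: threshold
        // at 0 (assumption: trees were trained against -1/+1 labels).
        return classification ? (sum > 0.0 ? 1.0 : 0.0) : sum;
    }
}
```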
public RDD<Object> predict(RDD<Vector> features)
features - RDD representing data points to be predicted
public JavaRDD<Double> predict(JavaRDD<Vector> features)
Java-friendly version of TreeEnsembleModel.predict(org.apache.spark.mllib.linalg.Vector).
features - (undocumented)
public String toString()
Print a summary of the model.
Overrides: toString in class Object
public String toDebugString()
Print the full model to a string.
public int numTrees()
public int totalNumNodes()
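totalNumNodes is documented as the node count summed over all trees in the ensemble, while numTrees is simply the ensemble size. A minimal recursive count over a hypothetical binary Node record (not MLlib's own node class) illustrates the summation.

```java
import java.util.List;

// Hypothetical binary tree node; MLlib's DecisionTreeModel uses its own Node class.
final class NodeCountSketch {
    record Node(Node left, Node right) {
        static Node leaf() {
            return new Node(null, null);
        }
    }

    // Count this node plus every node below it.
    static int numNodes(Node n) {
        if (n == null) return 0;
        return 1 + numNodes(n.left()) + numNodes(n.right());
    }

    // Sum the per-tree node counts over the whole ensemble.
    static int totalNumNodes(List<Node> trees) {
        int total = 0;
        for (Node root : trees) total += numNodes(root);
        return total;
    }
}
```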