public class NaiveBayesModel extends ProbabilisticClassificationModel<Vector,NaiveBayesModel> implements MLWritable
Model produced by NaiveBayes
param: pi log of class priors, whose dimension is C (number of classes)
param: theta log of class conditional probabilities, whose dimension is C (number of classes)
by D (number of features)

| Modifier and Type | Method and Description |
|---|---|
protected static <T> T |
$(Param<T> param) |
static Params |
clear(Param<?> param) |
NaiveBayesModel |
copy(ParamMap extra)
Creates a copy of this instance with the same UID and some extra params.
|
protected static <T extends Params> |
copyValues(T to,
ParamMap extra) |
protected static <T extends Params> |
copyValues$default$2() |
protected static <T extends Params> |
defaultCopy(ParamMap extra) |
static java.lang.String |
explainParam(Param<?> param) |
static java.lang.String |
explainParams() |
static ParamMap |
extractParamMap() |
static ParamMap |
extractParamMap(ParamMap extra) |
static Param<java.lang.String> |
featuresCol() |
Param<java.lang.String> |
featuresCol()
Param for features column name.
|
protected static DataType |
featuresDataType() |
static <T> scala.Option<T> |
get(Param<T> param) |
static <T> scala.Option<T> |
getDefault(Param<T> param) |
static java.lang.String |
getFeaturesCol() |
java.lang.String |
getFeaturesCol() |
static java.lang.String |
getLabelCol() |
java.lang.String |
getLabelCol() |
static java.lang.String |
getModelType() |
java.lang.String |
getModelType() |
static <T> T |
getOrDefault(Param<T> param) |
static Param<java.lang.Object> |
getParam(java.lang.String paramName) |
static java.lang.String |
getPredictionCol() |
java.lang.String |
getPredictionCol() |
static java.lang.String |
getProbabilityCol() |
static java.lang.String |
getRawPredictionCol() |
java.lang.String |
getRawPredictionCol() |
static double |
getSmoothing() |
double |
getSmoothing() |
static double[] |
getThresholds() |
static <T> boolean |
hasDefault(Param<T> param) |
static boolean |
hasParam(java.lang.String paramName) |
static boolean |
hasParent() |
protected static void |
initializeLogIfNecessary(boolean isInterpreter) |
static boolean |
isDefined(Param<?> param) |
static boolean |
isSet(Param<?> param) |
protected static boolean |
isTraceEnabled() |
static Param<java.lang.String> |
labelCol() |
Param<java.lang.String> |
labelCol()
Param for label column name.
|
static NaiveBayesModel |
load(java.lang.String path) |
protected static org.slf4j.Logger |
log() |
protected static void |
logDebug(scala.Function0<java.lang.String> msg) |
protected static void |
logDebug(scala.Function0<java.lang.String> msg,
java.lang.Throwable throwable) |
protected static void |
logError(scala.Function0<java.lang.String> msg) |
protected static void |
logError(scala.Function0<java.lang.String> msg,
java.lang.Throwable throwable) |
protected static void |
logInfo(scala.Function0<java.lang.String> msg) |
protected static void |
logInfo(scala.Function0<java.lang.String> msg,
java.lang.Throwable throwable) |
protected static java.lang.String |
logName() |
protected static void |
logTrace(scala.Function0<java.lang.String> msg) |
protected static void |
logTrace(scala.Function0<java.lang.String> msg,
java.lang.Throwable throwable) |
protected static void |
logWarning(scala.Function0<java.lang.String> msg) |
protected static void |
logWarning(scala.Function0<java.lang.String> msg,
java.lang.Throwable throwable) |
static Param<java.lang.String> |
modelType() |
Param<java.lang.String> |
modelType()
The model type which is a string (case-sensitive).
|
int |
numClasses() |
int |
numFeatures() |
static Param<?>[] |
params() |
static void |
parent_$eq(Estimator<M> x$1) |
static Estimator<M> |
parent() |
Vector |
pi() |
protected static double |
predict(FeaturesType features) |
static Param<java.lang.String> |
predictionCol() |
Param<java.lang.String> |
predictionCol()
Param for prediction column name.
|
protected static Vector |
predictProbability(FeaturesType features) |
protected Vector |
predictRaw(Vector features) |
protected static double |
probability2prediction(Vector probability) |
static Param<java.lang.String> |
probabilityCol() |
protected static double |
raw2prediction(Vector rawPrediction) |
protected static Vector |
raw2probability(Vector rawPrediction) |
protected Vector |
raw2probabilityInPlace(Vector rawPrediction)
Estimate the probability of each class given the raw prediction,
doing the computation in-place.
|
static Param<java.lang.String> |
rawPredictionCol() |
Param<java.lang.String> |
rawPredictionCol()
Param for raw prediction (a.k.a. confidence) column name.
|
static MLReader<NaiveBayesModel> |
read() |
static void |
save(java.lang.String path) |
static <T> Params |
set(Param<T> param,
T value) |
protected static Params |
set(ParamPair<?> paramPair) |
protected static Params |
set(java.lang.String param,
java.lang.Object value) |
protected static <T> Params |
setDefault(Param<T> param,
T value) |
protected static Params |
setDefault(scala.collection.Seq<ParamPair<?>> paramPairs) |
static M |
setFeaturesCol(java.lang.String value) |
static M |
setParent(Estimator<M> parent) |
static M |
setPredictionCol(java.lang.String value) |
static M |
setProbabilityCol(java.lang.String value) |
static M |
setRawPredictionCol(java.lang.String value) |
static M |
setThresholds(double[] value) |
static DoubleParam |
smoothing() |
DoubleParam |
smoothing()
The smoothing parameter.
|
Matrix |
theta() |
static DoubleArrayParam |
thresholds() |
java.lang.String |
toString() |
static Dataset<Row> |
transform(Dataset<?> dataset) |
static Dataset<Row> |
transform(Dataset<?> dataset,
ParamMap paramMap) |
static Dataset<Row> |
transform(Dataset<?> dataset,
ParamPair<?> firstParamPair,
ParamPair<?>... otherParamPairs) |
static Dataset<Row> |
transform(Dataset<?> dataset,
ParamPair<?> firstParamPair,
scala.collection.Seq<ParamPair<?>> otherParamPairs) |
protected static Dataset<Row> |
transformImpl(Dataset<?> dataset) |
static StructType |
transformSchema(StructType schema) |
protected static StructType |
transformSchema(StructType schema,
boolean logging) |
java.lang.String |
uid()
An immutable unique ID for the object and its derivatives.
|
protected static StructType |
validateAndTransformSchema(StructType schema,
boolean fitting,
DataType featuresDataType) |
StructType |
validateAndTransformSchema(StructType schema,
boolean fitting,
DataType featuresDataType) |
StructType |
validateAndTransformSchema(StructType schema,
boolean fitting,
DataType featuresDataType)
Validates and transforms the input schema with the provided param map.
|
static void |
validateParams() |
MLWriter |
write()
Returns an MLWriter instance for this ML instance. |
normalizeToProbabilitiesInPlace, predictProbability, probability2prediction, raw2prediction, raw2probability, setProbabilityCol, setThresholds, transform
predict, setRawPredictionCol
featuresDataType, setFeaturesCol, setPredictionCol, transformImpl, transformSchema
transform, transform, transform
transformSchema
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
save
clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn, validateParams
public static MLReader<NaiveBayesModel> read()
public static NaiveBayesModel load(java.lang.String path)
public static Param<?>[] params()
public static void validateParams()
public static java.lang.String explainParam(Param<?> param)
public static java.lang.String explainParams()
public static final boolean isSet(Param<?> param)
public static final boolean isDefined(Param<?> param)
public static boolean hasParam(java.lang.String paramName)
public static Param<java.lang.Object> getParam(java.lang.String paramName)
protected static final Params set(java.lang.String param, java.lang.Object value)
public static final <T> scala.Option<T> get(Param<T> param)
public static final <T> T getOrDefault(Param<T> param)
protected static final <T> T $(Param<T> param)
public static final <T> scala.Option<T> getDefault(Param<T> param)
public static final <T> boolean hasDefault(Param<T> param)
public static final ParamMap extractParamMap()
protected static java.lang.String logName()
protected static org.slf4j.Logger log()
protected static void logInfo(scala.Function0<java.lang.String> msg)
protected static void logDebug(scala.Function0<java.lang.String> msg)
protected static void logTrace(scala.Function0<java.lang.String> msg)
protected static void logWarning(scala.Function0<java.lang.String> msg)
protected static void logError(scala.Function0<java.lang.String> msg)
protected static void logInfo(scala.Function0<java.lang.String> msg,
java.lang.Throwable throwable)
protected static void logDebug(scala.Function0<java.lang.String> msg,
java.lang.Throwable throwable)
protected static void logTrace(scala.Function0<java.lang.String> msg,
java.lang.Throwable throwable)
protected static void logWarning(scala.Function0<java.lang.String> msg,
java.lang.Throwable throwable)
protected static void logError(scala.Function0<java.lang.String> msg,
java.lang.Throwable throwable)
protected static boolean isTraceEnabled()
protected static void initializeLogIfNecessary(boolean isInterpreter)
protected static StructType transformSchema(StructType schema, boolean logging)
public static Dataset<Row> transform(Dataset<?> dataset, ParamPair<?> firstParamPair, scala.collection.Seq<ParamPair<?>> otherParamPairs)
public static Dataset<Row> transform(Dataset<?> dataset, ParamPair<?> firstParamPair, ParamPair<?>... otherParamPairs)
public static Estimator<M> parent()
public static void parent_$eq(Estimator<M> x$1)
public static M setParent(Estimator<M> parent)
public static boolean hasParent()
public static final Param<java.lang.String> labelCol()
public static final java.lang.String getLabelCol()
public static final Param<java.lang.String> featuresCol()
public static final java.lang.String getFeaturesCol()
public static final Param<java.lang.String> predictionCol()
public static final java.lang.String getPredictionCol()
public static M setFeaturesCol(java.lang.String value)
public static M setPredictionCol(java.lang.String value)
protected static DataType featuresDataType()
public static StructType transformSchema(StructType schema)
public static final Param<java.lang.String> rawPredictionCol()
public static final java.lang.String getRawPredictionCol()
public static M setRawPredictionCol(java.lang.String value)
protected static double predict(FeaturesType features)
public static final Param<java.lang.String> probabilityCol()
public static final java.lang.String getProbabilityCol()
public static final DoubleArrayParam thresholds()
public static double[] getThresholds()
protected static StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
public static M setProbabilityCol(java.lang.String value)
public static M setThresholds(double[] value)
protected static double raw2prediction(Vector rawPrediction)
protected static Vector predictProbability(FeaturesType features)
protected static double probability2prediction(Vector probability)
public static final DoubleParam smoothing()
public static final double getSmoothing()
public static final Param<java.lang.String> modelType()
public static final java.lang.String getModelType()
public static void save(java.lang.String path)
throws java.io.IOException
Throws: java.io.IOException
public java.lang.String uid()
Identifiable
uid in interface Identifiable
uid in class ProbabilisticClassificationModel<Vector,NaiveBayesModel>
public Vector pi()
public Matrix theta()
public int numFeatures()
numFeatures in class ProbabilisticClassificationModel<Vector,NaiveBayesModel>
public int numClasses()
numClasses in class ProbabilisticClassificationModel<Vector,NaiveBayesModel>
protected Vector predictRaw(Vector features)
predictRaw in class ProbabilisticClassificationModel<Vector,NaiveBayesModel>
protected Vector raw2probabilityInPlace(Vector rawPrediction)
ProbabilisticClassificationModel
This internal method is used to implement transform() and output probabilityCol.
raw2probabilityInPlace in class ProbabilisticClassificationModel<Vector,NaiveBayesModel>
Parameters: rawPrediction - (undocumented)
public NaiveBayesModel copy(ParamMap extra)
Params
copy in interface Params
copy in class ProbabilisticClassificationModel<Vector,NaiveBayesModel>
Parameters: extra - (undocumented)
defaultCopy()
public java.lang.String toString()
toString in interface Identifiable
toString in class ProbabilisticClassificationModel<Vector,NaiveBayesModel>
public MLWriter write()
MLWritable
Returns an MLWriter instance for this ML instance.
write in interface MLWritable
public DoubleParam smoothing()
public double getSmoothing()
public Param<java.lang.String> modelType()
public java.lang.String getModelType()
public StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
public Param<java.lang.String> rawPredictionCol()
public java.lang.String getRawPredictionCol()
public StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
Parameters:
schema - input schema
fitting - whether this is in fitting
featuresDataType - SQL DataType for FeaturesType. E.g., VectorUDT for vector features.
public Param<java.lang.String> labelCol()
public java.lang.String getLabelCol()
public Param<java.lang.String> featuresCol()
public java.lang.String getFeaturesCol()
public Param<java.lang.String> predictionCol()
public java.lang.String getPredictionCol()