I am using Spark MLlib with a logistic regression model for classification, following this guide: https://spark.apache.org/docs/2.1.0/ml-classification-regression.html#logistic-regression
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
// Load training data
Dataset<Row> training = spark.read().format("libsvm")
        .load("data/mllib/sample_libsvm_data.txt");

LogisticRegression lr = new LogisticRegression()
        .setMaxIter(10)
        .setRegParam(0.3)
        .setElasticNetParam(0.8);

// Fit the model
LogisticRegressionModel lrModel = lr.fit(training);

// Print the coefficients and intercept for logistic regression
System.out.println("Coefficients: "
        + lrModel.coefficients() + " Intercept: " + lrModel.intercept());

// We can also use the multinomial family for binary classification
LogisticRegression mlr = new LogisticRegression()
        .setMaxIter(10)
        .setRegParam(0.3)
        .setElasticNetParam(0.8)
        .setFamily("multinomial");

// Fit the model
LogisticRegressionModel mlrModel = mlr.fit(training);
If I pass a .csv file as input instead, I am not sure how this model identifies the labels and features. Can anyone explain?
Answer 0 (score: 2)
That is because you are loading data in libsvm format, where each line is `label index1:value1 index2:value2 ...`, so the label and features are implicit in the format itself. If you use a .csv file, you must specify the label and feature columns explicitly.
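For reference, a libsvm file looks like the sketch below: the label comes first, followed by sparse `index:value` feature pairs (the concrete numbers here are illustrative, not taken from `sample_libsvm_data.txt`):

```
<label> <index1>:<value1> <index2>:<value2> ...
0 1:0.5 3:1.2
1 2:0.7 3:0.1
```

Since every line carries its own label and feature vector, the `libsvm` reader can map them directly to the `label` and `features` columns that `LogisticRegression` expects by default.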
Answer 1 (score: 2)
I was finally able to fix it. I needed the `VectorAssembler` and `StringIndexer` transformers, whose `setInputCol`/`setInputCols` and `setOutputCol` methods let me specify which columns to use as the features and the label.
VectorAssembler assembler = new VectorAssembler()
        .setInputCols(new String[]{"Lead ID"})
        .setOutputCol("features");

Dataset<Row> dataset = sparkSession.read()
        .option("header", true)
        .option("inferSchema", "true")
        .csv("Book.csv");
dataset = new StringIndexer().setInputCol("Status").setOutputCol("label")
        .fit(dataset).transform(dataset);
dataset = assembler.transform(dataset);
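To tie the pieces together, here is a self-contained sketch of the whole CSV-style pipeline. The column names `Lead ID` and `Status` are taken from the snippet above; the in-memory rows and their values are made-up stand-ins so the example runs without a `Book.csv` file (in practice you would keep the `csv("Book.csv")` read instead):

```java
import java.util.Arrays;
import java.util.List;

import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class CsvLogisticRegressionSketch {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
                .appName("csv-lr-sketch")
                .master("local[*]")
                .getOrCreate();

        // Stand-in for: spark.read().option("header", true)
        //                    .option("inferSchema", "true").csv("Book.csv")
        StructType schema = new StructType()
                .add("Lead ID", DataTypes.DoubleType)
                .add("Status", DataTypes.StringType);
        List<Row> rows = Arrays.asList(
                RowFactory.create(1.0, "Open"),
                RowFactory.create(2.0, "Closed"),
                RowFactory.create(3.0, "Open"),
                RowFactory.create(4.0, "Closed"));
        Dataset<Row> dataset = spark.createDataFrame(rows, schema);

        // Turn the string "Status" column into a numeric "label" column.
        dataset = new StringIndexer()
                .setInputCol("Status").setOutputCol("label")
                .fit(dataset).transform(dataset);

        // Pack the numeric input columns into a single "features" vector column.
        dataset = new VectorAssembler()
                .setInputCols(new String[]{"Lead ID"})
                .setOutputCol("features")
                .transform(dataset);

        // LogisticRegression reads the "features" and "label" columns by default.
        LogisticRegressionModel model = new LogisticRegression()
                .setMaxIter(10)
                .fit(dataset);
        System.out.println("numFeatures = " + model.numFeatures());

        spark.stop();
    }
}
```

The key point is that `StringIndexer` and `VectorAssembler` only produce the conventionally named `label` and `features` columns; if you name them differently, tell the estimator via `setLabelCol`/`setFeaturesCol`.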