I'm using MulticlassClassificationEvaluator in PySpark to retrieve metrics such as F1 score or accuracy from a cross-validation run:
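from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

# train_df is a placeholder (not in the original post) for a DataFrame
# with "features" and "label" columns; avgMetrics exists only on the fitted model.
cross_result = CrossValidator(estimator=RandomForestClassifier(),
                              estimatorParamMaps=ParamGridBuilder().build(),
                              evaluator=MulticlassClassificationEvaluator(metricName='f1'),
                              numFolds=5,
                              parallelism=1).fit(train_df)  # parallelism must be >= 1
f1_score = cross_result.avgMetrics[0]

Now my question is: if avgMetrics holds only one value, why is it a list? Shouldn't it be a scalar? Am I missing something about this attribute?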
Posted on 2021-06-10 21:03:13
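Based on the source code, I realised that avgMetrics is a list holding one value per parameter combination defined in the ParamGrid: the metric averaged over all cross-validation folds for that combination. So: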
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.linalg import Vectors
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

# Assumes an active SparkSession named `spark`.
dataset = spark.createDataFrame(
    [(Vectors.dense([0.0]), 0.0),
     (Vectors.dense([0.6]), 1.0),
     (Vectors.dense([1.0]), 1.0)] * 10,
    ["features", "label"])
lr = LogisticRegression()
# Note that there are three values for maxIter: 0, 1 and 5
grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1, 5]).build()
evaluator = MulticlassClassificationEvaluator(metricName='accuracy')
cv = CrossValidator(
    estimator=lr,
    estimatorParamMaps=grid,
    evaluator=evaluator,
    parallelism=2
)
cvModel = cv.fit(dataset)
cvModel.avgMetrics[0]  # Average accuracy for maxIter = 0
cvModel.avgMetrics[1]  # Average accuracy for maxIter = 1
cvModel.avgMetrics[2]  # Average accuracy for maxIter = 5

https://stackoverflow.com/questions/67790646
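
To tie this back to the question, here is a plain-Python sketch (no Spark required) of why the list has exactly one entry when the grid is empty. It is not Spark's actual implementation: `build_grid` and `average_metrics` are hypothetical helpers that only mimic the shape of the computation, and the per-fold accuracies are made up for illustration.

```python
from itertools import product


def build_grid(param_values):
    """Expand {param_name: [values]} into one dict per combination."""
    names = list(param_values)
    return [dict(zip(names, combo)) for combo in product(*param_values.values())]


def average_metrics(fold_metrics):
    """Average per-fold metrics column-wise: one result per parameter map.

    fold_metrics[i][j] is the metric of parameter map j on fold i.
    """
    num_folds = len(fold_metrics)
    num_models = len(fold_metrics[0])
    return [
        sum(fold[j] for fold in fold_metrics) / num_folds
        for j in range(num_models)
    ]


# An empty grid still expands to a single (empty) parameter map ...
empty_grid = build_grid({})
# ... while three maxIter values expand to three parameter maps.
max_iter_grid = build_grid({"maxIter": [0, 1, 5]})

# Made-up accuracies for 2 folds x 3 parameter maps (illustrative only).
fold_metrics = [
    [0.50, 0.75, 1.00],
    [0.50, 0.25, 1.00],
]
avg_metrics = average_metrics(fold_metrics)
print(len(empty_grid), len(max_iter_grid), avg_metrics)
```

With an empty grid there is one parameter map (the estimator's defaults), so the averaged list has length 1 and index 0 is the only entry. With the three-value maxIter grid the list has length 3, in the same order as the grid.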