我有以下简单的功能:
def f1(y_true, y_pred):
return {"f1": 100 * sklearn.metrics.f1_score(y_true, y_pred)}根据scikit-learn文档,f1_score的参数可以有以下类型:
y_true:一维数组,或标签指示数组/稀疏矩阵y_pred:一维数组,或标签指示数组/稀疏矩阵输出类型为:
如何在此函数中添加类型提示,以便mypy不会抱怨?
我尝试了以下几种变体:
Array1D = NewType('Array1D', Union[np.ndarray, List[np.float64]])
def f1(y_true: Union[List[float], Array1D], y_pred: Union[List[float], Array1D]) -> Dict[str, Union[List[float], Array1D]]:
return {"f1": 100 * sklearn.metrics.f1_score(y_true, y_pred)}但这带来了错误。
发布于 2022-08-24 20:59:33
这是我使用的方法,以避免类似的类型问题。它利用了1.20中引入的numpy打字。ArrayLike类型涵盖了List[float],因此无需担心显式地覆盖它。
运行mypyv0.971和numpy v1.23.1,没有问题。
from typing import List, Dict
import numpy as np
import numpy.typing as npt
import sklearn.metrics
def f1(y_true: npt.ArrayLike, y_pred: npt.ArrayLike) -> Dict[str, npt.ArrayLike]:
return {"f1": 100 * sklearn.metrics.f1_score(y_true, y_pred)}
y_true_list: List[float] = [1, 0, 1, 0]
y_pred_list: List[float] = [1, 0, 1, 1]
y_true_np: npt.ArrayLike = np.array(y_true_list)
y_pred_np: npt.ArrayLike = np.array(y_pred_list)
assert f1(y_true_list, y_pred_list) == f1(y_true_np, y_pred_np)发布于 2022-08-23 15:47:20
而不是
Array1D = NewType("Array1D", Union[np.ndarray, List[np.float64]])你可以用
Array1D = Union[np.ndarray, List[np.float64]]https://stackoverflow.com/questions/73418424
复制相似问题