首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何将类型提示添加到scikit学习函数中?

如何将类型提示添加到scikit学习函数中?
EN

Stack Overflow用户
提问于 2022-08-19 14:37:53
回答 2查看 201关注 0票数 2

我有以下简单的功能:

代码语言:javascript
复制
def f1(y_true, y_pred):
    return {"f1": 100 * sklearn.metrics.f1_score(y_true, y_pred)}

根据scikit-learn文档,f1_score的参数可以有以下类型:

  • y_true:一维数组,或标签指示数组/稀疏矩阵
  • y_pred:一维数组,或标签指示数组/稀疏矩阵

输出类型为:

  • 浮点数或浮点数数组,shape = n_unique_labels

如何在此函数中添加类型提示,以便mypy不会抱怨?

我尝试了以下几种变体:

代码语言:javascript
复制
Array1D = NewType('Array1D', Union[np.ndarray, List[np.float64]])

def f1(y_true: Union[List[float], Array1D], y_pred: Union[List[float], Array1D]) -> Dict[str, Union[List[float], Array1D]]:
    return {"f1": 100 * sklearn.metrics.f1_score(y_true, y_pred)}

但这带来了错误。

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2022-08-24 20:59:33

这是我使用的方法,以避免类似的类型问题。它利用了1.20中引入的numpy打字ArrayLike类型涵盖了List[float],因此无需担心显式地覆盖它。

运行mypyv0.971和numpy v1.23.1,没有问题。

代码语言:javascript
复制
from typing import List, Dict
import numpy as np
import numpy.typing as npt
import sklearn.metrics


def f1(y_true: npt.ArrayLike, y_pred: npt.ArrayLike) -> Dict[str, npt.ArrayLike]:
    return {"f1": 100 * sklearn.metrics.f1_score(y_true, y_pred)}

y_true_list: List[float] = [1, 0, 1, 0]
y_pred_list: List[float] = [1, 0, 1, 1]
y_true_np: npt.ArrayLike = np.array(y_true_list)
y_pred_np: npt.ArrayLike = np.array(y_pred_list)

assert f1(y_true_list, y_pred_list) == f1(y_true_np, y_pred_np)
票数 2
EN

Stack Overflow用户

发布于 2022-08-23 15:47:20

而不是

代码语言:javascript
复制
Array1D = NewType("Array1D", Union[np.ndarray, List[np.float64]])

你可以用

代码语言:javascript
复制
Array1D = Union[np.ndarray, List[np.float64]]
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73418424

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档