首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >机器学习算法如何保留以前执行的学习?

机器学习算法如何保留以前执行的学习?
EN

Stack Overflow用户
提问于 2019-05-30 01:36:11
回答 1查看 231关注 0票数 4

我正在读一本关于机器学习的书,作者谈到了训练和测试拆分过程中的随机种子,作者说,在这段时间里,机器将看到你的整个数据集。

作者正在使用以下函数来划分Tran和Test split,

代码语言:javascript
复制
def split_train_test(data, test_ratio):
    shuffled_indices = np.random.permutation(len(data))
    test_set_size = int(len(data) * test_ratio)
    test_indices = shuffled_indices[:test_set_size]
    train_indices = shuffled_indices[test_set_size:]
    return data.iloc[train_indices], data.iloc[test_indices]

Usage of the function like this:
>>>train_set, test_set = split_train_test(housing, 0.2)
>>> len(train_set)
16512
>>> len(test_set)
4128

嗯,这是可行的,但它并不完美:如果您再次运行该程序,它将生成一个不同的测试集!随着时间的推移,你(或你的机器学习算法)将看到整个数据集,这是你想要避免的。

Sachin Rastogi:为什么以及如何影响我的模型性能?我知道我的模型精度在每一次运行中都会有所不同,因为训练集总是不同的。我的模型将如何在一段时间内看到整个数据集?

作者还提供了一些解决方案,

一种解决方案是在第一次运行时保存测试集,然后在后续运行中加载它。另一种选择是在调用np.random.permutation()之前设置随机数生成器的种子(例如,np.random.seed(42)),以便它始终生成相同的混洗索引。

但这两个解决方案都将在您下次获取更新的数据集时中断。一种常见的解决方案是使用每个实例的标识符来决定它是否应该进入测试集中(假设实例具有唯一且不可变的标识符)。

Sachin Rastogi:这会是一个很好的训练/测试部门吗?我认为不,Train和Test应该包含跨数据集的元素,以避免来自Train集的任何偏差。

作者举了一个例子,

您可以计算每个实例的标识符的散列,如果散列小于或等于最大散列值的20%,则将该实例放入测试集中。这可确保测试集在多次运行时保持一致,即使您刷新数据集也是如此。

新的测试集将包含20%的新实例,但它将不包含以前在训练集中的任何实例。

Sachin Rastogi:我无法理解这个解决方案。你能帮帮忙吗?

EN

回答 1

Stack Overflow用户

发布于 2020-10-21 06:24:13

对我来说,答案是:

  1. 这里的要点是,在训练模型之前,你最好把你的部分数据(它将构成你的测试集)放在一边。实际上,你想要实现的是能够很好地对看不见的例子进行泛化。通过运行您已经展示的代码,您将随着时间的推移获得不同的测试集;换句话说,您将始终在数据的不同子集上训练模型(可能还会在先前标记为测试数据的数据上进行训练)。这反过来将影响训练,并且-达到极限-将没有什么可以推广。

  1. 如果不添加新数据,这确实是一个满足先前要求(拥有稳定的测试集)的解决方案。

  1. 正如您问题的注释中所说,通过散列每个实例的标识符,您可以确保将旧实例始终分配给相同的子集。

代码语言:javascript
复制
- Instances that were put in the _training set_ before the update of the dataset will remain there (as their hash value won't change - and so their left-most bit - and it will remain higher than 0.2\*max\_hash\_value);
- Instances that were put in the _test set_ before the update of the dataset will remain there (as their hash value won't change and it will remain lower than 0.2\*max\_hash\_value).

更新后的测试集将包含20%的新实例以及与旧测试集关联的所有实例,使其保持稳定。

我还建议在这里查看作者的解释:https://github.com/ageron/handson-ml/issues/71

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56365869

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档