首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >KNN K近邻: train_test_split和knn.kneighbors

KNN K近邻: train_test_split和knn.kneighbors
EN

Stack Overflow用户
提问于 2020-03-05 10:45:02
回答 1查看 1.5K关注 0票数 0

我向那些在库学习中使用函数train_test_splitKNN (K-最近邻)算法(特别是knn.kneighbors函数)的经验丰富的人提出了以下一般性问题。

假设您有一个非常标准的熊猫dataframe (没有索引)中的样本,由以下列组成:

  • 某人的姓名
  • 特征X1
  • 特征X2
  • 特征X3
  • 目标Y1

因此,我们假设有100行通用数据。

当调用函数train_test_split时,以参数的形式传递具有特性的列(如df['X1‘、'X2’、‘X3’)和带有目标的列(如df'Y1'),作为回报,可以得到4个变量X_test、X_train、y_test、y_train以随机方式分裂。

好的。到目前一切尚好。

假设在此之后,使用一个算法KNN对测试数据进行预测。因此,您可以发出如下命令:

代码语言:javascript
复制
knn=KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train,y_train)
y_pred=knn.predict(X_test)

好的。很好。y_pred包含预测。

现在,这里有一个问题,您想知道谁是X_train数据点的“邻居”,这些数据点使得预测成为可能。

为此,有一个名为Knn.Kneighbors的函数,它返回k点的距离和坐标数组(如X1、X2、X3数组),这些点被认为是X_train集合的每个数据点的“邻居”。

代码语言:javascript
复制
neighbors=knn.kneighbors(X=X_test)

问题是:如何将邻居变量中返回的坐标表示的数据与原始数据集联系起来,以了解这些坐标属于谁(=>列‘姓名’)?

我的意思是:对于原始数据集的每一行,都有一个与之相关的“人的名字”。你不能把这个传递给X_train或X_test。因此,我是否可以将knn.kneighbors函数返回的邻居数组(现在是随机混合的,不引用原始数据)重新链接到原始数据集?

有什么简单的方法可以重新连接吗?最后,我想知道X_train数据中点的邻居名是什么,而不仅仅是函数knn.kneighbors返回的匿名坐标数组。

否则我就不得不在原始数据中把邻居们循环起来,知道他们属于谁.但我希望不要这么做。

感谢所有的提前。安德里亚

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-03-05 12:39:27

如果要设置knn.kneighbors(X=X_test),则函数return_distance=False的输出更具可读性。在这种情况下,结果数组中的每一行表示X_test中每个点(行)的最近邻居的X_test数的索引。

注意,这些索引对应于训练集X_train中的索引。如果您想将它们映射回原始数据框架中的Name列,我认为您必须使用熊猫索引。

我希望下面的例子是有意义的。

创建数据集:

代码语言:javascript
复制
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

np.random.seed(42)  # for repoducibility
df = pd.DataFrame(np.random.randn(20, 3),
                  columns=["X1", "X2", "X3"])
df["Name"] = df.index.values * 100  # assume the names are just pandas index * 100
Y = np.random.randint(0, 2, 20)  # targets

print(df)

    X1           X2         X3          Name
0   0.496714    -0.138264   0.647689    0
1   1.523030    -0.234153   -0.234137   100
2   1.579213    0.767435    -0.469474   200
3   0.542560    -0.463418   -0.465730   300
4   0.241962    -1.913280   -1.724918   400
5   -0.562288   -1.012831   0.314247    500
6   -0.908024   -1.412304   1.465649    600
7   -0.225776   0.067528    -1.424748   700
8   -0.544383   0.110923    -1.150994   800
9   0.375698    -0.600639   -0.291694   900
10  -0.601707   1.852278    -0.013497   1000
11  -1.057711   0.822545    -1.220844   1100
12  0.208864    -1.959670   -1.328186   1200
13  0.196861    0.738467    0.171368    1300
14  -0.115648   -0.301104   -1.478522   1400
15  -0.719844   -0.460639   1.057122    1500
16  0.343618    -1.763040   0.324084    1600
17  -0.385082   -0.676922   0.611676    1700
18  1.031000    0.931280    -0.839218   1800
19  -0.309212   0.331263    0.975545    1900

列车测试是否分裂:

代码语言:javascript
复制
X_train, X_test, y_train, y_test = train_test_split(df.iloc[:, :3],
                                                    Y,
                                                    random_state=24  # for reproducibility
                                                   )

请注意每个数据帧的索引:

代码语言:javascript
复制
print(X_train)

          X1        X2        X3
8   0.375698 -0.600639 -0.291694
14 -0.115648 -0.301104 -1.478522
16  0.343618 -1.763040  0.324084
7  -0.225776  0.067528 -1.424748
10 -0.601707  1.852278 -0.013497
12  0.208864 -1.959670 -1.328186
19 -0.309212  0.331263  0.975545
18  1.031000  0.931280 -0.839218
15 -0.719844 -0.460639  1.057122
11 -1.057711  0.822545 -1.220844
4   0.241962 -1.913280 -1.724918
1   1.523030 -0.234153 -0.234137
0   0.496714 -0.138264  0.647689
3   0.542560 -0.463418 -0.465730
2   1.579213  0.767435 -0.469474

print(X_test)

          X1        X2        X3
13  0.196861  0.738467  0.171368
6  -0.908024 -1.412304  1.465649
17 -0.385082 -0.676922  0.611676
5  -0.562288 -1.012831  0.314247
9   0.375698 -0.600639 -0.291694

既然我们已经通过设置随机种子来确保可重现性,那么让我们做一个改变,帮助我们理解knn.kneighbors(X=X_test)的结果。我将X_train中的第一行设置为与X_test的最后一行相同。如果这两个点是相同的,那么当我们查询X_test.loc[[9]] (或X_test.iloc[4, :])时,它应该作为最接近的点返回自己。

注意,索引8的第一行已被更改,等于X_test的最后一行。

代码语言:javascript
复制
X_train.loc[8]  = X_test.loc[9]
print(X_train)

          X1        X2        X3
8   0.375698 -0.600639 -0.291694
14 -0.115648 -0.301104 -1.478522
16  0.343618 -1.763040  0.324084
7  -0.225776  0.067528 -1.424748
10 -0.601707  1.852278 -0.013497
12  0.208864 -1.959670 -1.328186
19 -0.309212  0.331263  0.975545
18  1.031000  0.931280 -0.839218
15 -0.719844 -0.460639  1.057122
11 -1.057711  0.822545 -1.220844
4   0.241962 -1.913280 -1.724918
1   1.523030 -0.234153 -0.234137
0   0.496714 -0.138264  0.647689
3   0.542560 -0.463418 -0.465730
2   1.579213  0.767435 -0.469474

培训KNN模型:

代码语言:javascript
复制
knn = KNeighborsClassifier(n_neighbors=2)
knn.fit(X_train, y_train)

为了使事情变得简单,让我们得到一个点的近邻(同样的解释适用于多个点)。

获取特定点X_test.loc[[9]] = [ 0.375698 -0.600639 -0.291694]的两个最近邻居,我们在上面使用它来更改X_train):

代码语言:javascript
复制
nn_indices = knn.kneighbors(X=X_test.loc[[9]], return_distance=False)
print(nn_indices)
[[ 0 13]]

它们是:

代码语言:javascript
复制
print(X_train.iloc[np.squeeze(nn_indices)])

         X1        X2        X3
8  0.375698 -0.600639 -0.291694  < - Same point in X_train
3  0.542560 -0.463418 -0.465730  < - Second closest point in X_train

这意味着013X_train中的行最接近于点[ 0.375698 -0.600639 -0.291694]

为了将它们映射到名称,您可以使用:

代码语言:javascript
复制
print(df["Name"][np.squeeze(X_train.index.values[nn_indices])])

8    800
3    300
Name: Name, dtype: int64

如果您没有设置return_distance=False,您将注意到第一个距离值为零(到一个恰好是它本身的点的距离是零)。

代码语言:javascript
复制
nn_distances, nn_indices = knn.kneighbors(X=X_test.loc[[9]])
print(nn_distances)

[[0.         0.27741858]] 

您还可以使用n_neighbors参数来获取更接近的邻居。默认情况下,它将运行在拟合模型时使用的值。

编辑:

对于整个X_test,您可以这样做:

代码语言:javascript
复制
nn_indices = knn.kneighbors(X=X_test, return_distance=False)
pd.DataFrame(nn_indices).applymap(lambda x: df["Name"][X_train.index.values[x]])

    0       1
0   1900    0
1   1500    1600
2   1500    0
3   1500    1600
4   800     300
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/60543562

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档