首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >NearestNeighbors (Mahalanobis) -太多的争论?

NearestNeighbors (Mahalanobis) -太多的争论?
EN

Stack Overflow用户
提问于 2021-07-31 14:40:03
回答 1查看 278关注 0票数 2

我用的是scikit-learnNearestNeighbors和Mahalanobis距离。

代码语言:javascript
复制
from sklearn.neighbors import NearestNeighbors

nn = NearestNeighbors(
    algorithm='brute', 
    metric='mahalanobis', 
    metric_params={'V': np.cov(d1)}
).fit(d1)

# Indices of 3 d1 points closest to d2 points
indices = nn.kneighbors(d2, 3)[1]

d1d2都是由2元素数字列表组成的numpy数组。例如:

代码语言:javascript
复制
array([[61, 35],
       [61, 20],
       [53, 50],
       ...,
       [63, 70],
       [39, 90],
       [39, 90]])

我在过去几乎使用过这个精确的代码,但是今天我得到了以下错误:

代码语言:javascript
复制
--------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/var/folders/58/sc_58r5d5wgdg06t7fd2k2sr0000gn/T/ipykernel_77488/409650633.py in <module>
      6 
      7 # Indices of 3 d1 points closest to d2 points
----> 8 indices = nn.kneighbors(d2, 3)[1]
      9 
     10 # Drop duplicates

~/Library/Python/3.8/lib/python/site-packages/sklearn/neighbors/_base.py in kneighbors(self, X, n_neighbors, return_distance)
    703                 kwds = self.effective_metric_params_
    704 
--> 705             chunked_results = list(pairwise_distances_chunked(
    706                 X, self._fit_X, reduce_func=reduce_func,
    707                 metric=self.effective_metric_, n_jobs=n_jobs,

~/Library/Python/3.8/lib/python/site-packages/sklearn/metrics/pairwise.py in pairwise_distances_chunked(X, Y, reduce_func, metric, n_jobs, working_memory, **kwds)
   1621         else:
   1622             X_chunk = X[sl]
-> 1623         D_chunk = pairwise_distances(X_chunk, Y, metric=metric,
   1624                                      n_jobs=n_jobs, **kwds)
   1625         if ((X is Y or Y is None)

~/Library/Python/3.8/lib/python/site-packages/sklearn/utils/validation.py in inner_f(*args, **kwargs)
     61             extra_args = len(args) - len(all_args)
     62             if extra_args <= 0:
---> 63                 return f(*args, **kwargs)
     64 
     65             # extra_args > 0

~/Library/Python/3.8/lib/python/site-packages/sklearn/metrics/pairwise.py in pairwise_distances(X, Y, metric, n_jobs, force_all_finite, **kwds)
   1788         func = partial(distance.cdist, metric=metric, **kwds)
   1789 
-> 1790     return _parallel_pairwise(X, Y, func, n_jobs, **kwds)
   1791 
   1792 

~/Library/Python/3.8/lib/python/site-packages/sklearn/metrics/pairwise.py in _parallel_pairwise(X, Y, func, n_jobs, **kwds)
   1357 
   1358     if effective_n_jobs(n_jobs) == 1:
-> 1359         return func(X, Y, **kwds)
   1360 
   1361     # enforce a threading backend to prevent data communication overhead

~/Library/Python/3.8/lib/python/site-packages/scipy/spatial/distance.py in cdist(XA, XB, metric, out, **kwargs)
   2952         if metric_info is not None:
   2953             cdist_fn = metric_info.cdist_func
-> 2954             return cdist_fn(XA, XB, out=out, **kwargs)
   2955         elif mstr.startswith("test_"):
   2956             metric_info = _TEST_METRICS.get(mstr, None)

~/Library/Python/3.8/lib/python/site-packages/scipy/spatial/distance.py in __call__(self, XA, XB, out, **kwargs)
   1670         # get cdist wrapper
   1671         cdist_fn = getattr(_distance_wrap, f'cdist_{metric_name}_{typ}_wrap')
-> 1672         cdist_fn(XA, XB, dm, **kwargs)
   1673         return dm
   1674 

TypeError: cdist_mahalanobis_double_wrap() takes at most 4 arguments (5 given)

任何关于如何解决这一问题的建议都会受到广泛的赞赏!谢谢!

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-07-31 14:52:23

'V'更改为'VI',也许这个帮助:

代码语言:javascript
复制
from sklearn.neighbors import NearestNeighbors
import numpy as np

nn = NearestNeighbors(
    algorithm='brute', 
    metric='mahalanobis', 
    metric_params={'VI': np.cov(d1)}
).fit(d1)

# Indices of 3 d1 points closest to d2 points
indices = nn.kneighbors(d2, 3)[1]
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68603099

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档