首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >kNN与Python

kNN与Python
EN

Code Review用户
提问于 2018-01-30 21:04:52
回答 1查看 473关注 0票数 4

我正在编写一个k最近的实现来解决多类分类问题。

代码语言:javascript
复制
import heapq
import logging

import numpy as np

from scipy import spatial

logging.basicConfig()

class KNN(object):

    similarities = {
        1: lambda a, b: np.linalg.norm(a-b),
        2: lambda a, b: spatial.distance.cosine(a, b),
    }

    def __init__(self, k, similarity_func, loglevel=logging.DEBUG):
        self.k = k
        self.logger = logging.getLogger(type(self).__name__)
        self.logger.setLevel(loglevel)
        if similarity_func not in KNN.similarities:
            raise ValueError("Illegal similarity value {0}. Legal values are {1}".format(similarity_func, sorted(KNN.similarities.keys())))
        self.similarity_func = KNN.similarities[similarity_func]

    def train(self, X, y):
        self.training_X = X
        self.training_y = y
        self.num_classes = len(np.unique(y))
        self.logger.debug("There are %s classes", self.num_classes)
        return self

    def probs(self, X):
        class_probs = []
        for i, e in enumerate(X, 1):
            votes = np.zeros((self.num_classes,))
            self.logger.debug("Votes: %s", votes)
            if i % 100 == 0:
                self.logger.info("Example %s", i)
            distance = [(self.similarity_func(e, x), y) for x, y in zip(self.training_X, self.training_y)]
            for (_, label) in heapq.nsmallest(self.k, distance, lambda t: t[0]):
                votes[label] += 1
            class_probs.append(normalize(votes))
        return class_probs

    def predict(self, X):
        return np.argmax(self.probs(X))

我发现这个实现的predict是缓慢的™,并认为可以用numpy中的向量化操作来加速它,但我对numpy矢量化技术相当缺乏经验。

有人对我可以从predict获得的性能提升有一些建议吗?

EN

回答 1

Code Review用户

发布于 2018-02-01 01:24:54

我要发布一篇优化文章:

欧几里德距离不需要完全计算!

因为我只将它们用于排名,所以没有必要使用平方根。因此,可以使用以下方法:

代码语言:javascript
复制
def squared_euclidean(x, y):
    dist = np.array(x) - np.array(y)
    return np.dot(dist, dist)
票数 1
EN
页面原文内容由Code Review提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://codereview.stackexchange.com/questions/186391

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档