BK-tree implementation

Asked by a Code Review user on 2014-10-25 05:28:43 (1 answer, 1.1K views, score 4)

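I'm working on a system for searching for similar images. This involves being able to search by edit distance, so I implemented a special data structure called a BK-tree. Some notes: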

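The core idea here is to be able to search for items by edit distance, in this case the Hamming distance between discrete items. The hashes in this system are stored as 64-bit integers, but fundamentally they can be thought of, and are treated as, bit fields.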
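As a minimal illustration in plain Python (not the question's Cython), the Hamming distance between two 64-bit hashes is the number of set bits in their XOR. The hash values below are made up:

```python
MASK64 = (1 << 64) - 1  # hashes are treated as unsigned 64-bit bit fields


def hamming(a: int, b: int) -> int:
    """Count the bit positions at which two 64-bit hashes differ."""
    # XOR leaves a 1 exactly where the two bit fields differ; count those 1s.
    return bin((a ^ b) & MASK64).count("1")


# Two hypothetical hashes that differ only in the lowest two bits.
print(hamming(0b1011, 0b1000))  # 2
```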

Anyway, I have some test code that seems to work fairly well, and I'd really appreciate some input.

```cython
from libc.stdint cimport uint64_t

# Compute number of bits that are not common between `a` and `b`.
# return value is a plain integer
cdef uint64_t hamming(uint64_t a, uint64_t b):

    cdef uint64_t x
    cdef int tot

    tot = 0

    x = (a ^ b)
    while x > 0:
        tot += x & 1
        x >>= 1
    return tot

cdef class BkHammingNode(object):

    cdef uint64_t nodeHash
    cdef set nodeData
    cdef dict children

    def __init__(self, nodeHash, nodeData):
        self.nodeData = set((nodeData, ))
        self.children = {}
        self.nodeHash = nodeHash

    # Insert phash `nodeHash` into tree, with the associated data `nodeData`
    cpdef insert(self, uint64_t nodeHash, nodeData):

        # If the current node has the same hash as the data we're inserting,
        # add the data to the current node's data set
        if nodeHash == self.nodeHash:
            self.nodeData.add(nodeData)
            return

        # otherwise, calculate the edit distance between the new phash and the current node's hash,
        # and either recursively insert the data, or create a new child node for the phash
        distance = hamming(self.nodeHash, nodeHash)
        if not distance in self.children:
            self.children[distance] = BkHammingNode(nodeHash, nodeData)
        else:
            self.children[distance].insert(nodeHash, nodeData)

    # Remove node with hash `nodeHash` and accompanying data `nodeData` from the tree.
    # Returns list of children that must be re-inserted (or False if no children need to be updated),
    # number of nodes deleted, and number of nodes that were moved as a 3-tuple.
    cpdef remove(self, uint64_t nodeHash, nodeData):
        cdef uint64_t deleted = 0
        cdef uint64_t moved = 0

        # If the node we're on matches the hash we want to delete exactly:
        if nodeHash == self.nodeHash:

            # Remove the node data associated with the hash we want to remove
            self.nodeData.remove(nodeData)

            # If we've emptied out the node of data, return all our children so the parent can
            # graft the children into the tree in the appropriate place
            if not self.nodeData:
                # 1 deleted node, 0 moved nodes, return all children for reinsertion by parent
                # Parent will pop this node, and reinsert all its children where appropriate
                return list(self), 1, 0

            # node has data remaining, do not do any rebuilding
            return False, 1, 0


        selfDist = hamming(self.nodeHash, nodeHash)

        # Removing is basically searching with a distance of zero, and
        # then doing operations on the search result.
        # As such, only the child keyed by `selfDist` (the edit distance between `self.nodeHash`
        # and the target `nodeHash`) can contain the target, so recurse into that child only.
        # Rebuild children where needed
        if selfDist in self.children:
            moveChildren, childDeleted, childMoved = self.children[selfDist].remove(nodeHash, nodeData)
            deleted += childDeleted
            moved += childMoved

            # If the child returns children, it means the child no longer contains any unique data, so it
            # needs to be deleted. As such, pop it from the tree, and re-insert all its children as
            # direct descendants of the current node
            if moveChildren:
                self.children.pop(selfDist)
                for childHash, childData in moveChildren:
                    self.insert(childHash, childData)
                    moved += 1

        return False, deleted, moved

    # Get all child-nodes within an edit distance of `distance` from `baseHash`
    # returns a set containing the data of each matching node, and an integer representing
    # the number of nodes that were touched in the scan.
    # Return value is a 2-tuple
    cpdef getWithinDistance(self, uint64_t baseHash, int distance):
        cdef uint64_t selfDist

        selfDist = hamming(self.nodeHash, baseHash)

        ret = set()

        if selfDist <= distance:
            ret |= set(self.nodeData)

        touched = 1


        for key in self.children.keys():
            if key <= selfDist+distance and key >= selfDist-distance:
                new, tmpTouch = self.children[key].getWithinDistance(baseHash, distance)
                touched += tmpTouch
                ret |= new

        return ret, touched

    def __iter__(self):
        for child in self.children.values():
            for item in child:
                yield item
        for item in self.nodeData:
            yield (self.nodeHash, item)

class BkHammingTree(object):
    root = None

    def __init__(self):
        self.nodes = 0

    def insert(self, nodeHash, nodeData):
        if not self.root:
            self.root = BkHammingNode(nodeHash, nodeData)
        else:
            self.root.insert(nodeHash, nodeData)


        self.nodes += 1

    def remove(self, nodeHash, nodeData):
        if not self.root:
            raise ValueError("No tree built to remove from!")

        rootless, deleted, moved = self.root.remove(nodeHash, nodeData)

        # If the node we're deleting is the root node, we need to handle it properly
        # if it is, overwrite the root node with one of the values returned, and then
        # rebuild the entire tree by reinserting all the nodes
        if rootless:
            print("Tree root deleted! Rebuilding...")
            rootHash, rootData = rootless.pop()
            self.root = BkHammingNode(rootHash, rootData)
            for childHash, childData in rootless:
                self.root.insert(childHash, childData)

        self.nodes -= deleted

        return deleted, moved

    def getWithinDistance(self, baseHash, distance):
        if not self.root:
            return set()

        ret, touched = self.root.getWithinDistance(baseHash, distance)
        print("Touched %s tree nodes, or %1.3f%%" % (touched, touched/self.nodes * 100))
        print("Discovered %s match(es)" % len(ret))
        return ret

    def __iter__(self):
        for value in self.root:
            yield value
```

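This is written in Cython for performance reasons (a pure-Python version was very slow). It's now quite fast (a tree of 4.1M items builds in about 20 seconds), but I have a vague suspicion that I'm missing something, particularly since I still don't really understand metric spaces.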

I think the tree implementation is correct, but I wouldn't be surprised if I've missed something obvious.
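On the metric-space point: Hamming distance is a metric, so it obeys the triangle inequality. In a BK-tree, every item stored under a node's child key `k` is exactly distance `k` from that node. If the node is distance `s` from the query, each of those items is therefore at least `|k - s|` from the query, and the whole subtree can be skipped when `|k - s| > d`. Below is a minimal pure-Python sketch of that rule, cross-checked against a linear scan over random 64-bit hashes. The `BKTree` class and `brute_force` helper are hypothetical and are not the question's API.

```python
import random


def hamming(a: int, b: int) -> int:
    return bin(a ^ b).count("1")


class BKTree:
    """Minimal BK-tree over integer hashes (illustrative sketch)."""

    def __init__(self):
        self.root = None  # each node is [hash, set_of_data, {distance: child}]

    def insert(self, h, data):
        if self.root is None:
            self.root = [h, {data}, {}]
            return
        node = self.root
        while True:
            if h == node[0]:
                node[1].add(data)  # same hash: just attach the data
                return
            d = hamming(node[0], h)
            child = node[2].get(d)
            if child is None:
                node[2][d] = [h, {data}, {}]  # new child on edge d
                return
            node = child  # descend along edge d

    def search(self, q, max_dist):
        found = set()
        stack = [self.root] if self.root is not None else []
        while stack:
            node_hash, data, children = stack.pop()
            s = hamming(node_hash, q)
            if s <= max_dist:
                found |= data
            # Triangle inequality: items under key k are at least |k - s| from q.
            for k, child in children.items():
                if abs(k - s) <= max_dist:
                    stack.append(child)
        return found


def brute_force(items, q, max_dist):
    return {data for h, data in items if hamming(h, q) <= max_dist}


random.seed(0)
items = [(random.getrandbits(64), i) for i in range(2000)]
tree = BKTree()
for h, data in items:
    tree.insert(h, data)

for _ in range(20):
    # Queries within 8 bits of an existing hash, so small radii find matches.
    q = random.choice(items)[0] ^ random.getrandbits(8)
    for d in (0, 4, 8, 30):
        assert tree.search(q, d) == brute_force(items, q, d)
print("BK-tree search matches the linear scan")
```

The same comparison against a linear scan can be pointed at the Cython `BkHammingTree` to test it.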


1 Answer

Answered by a Code Review user on 2014-12-09 02:14:10

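I haven't really checked correctness, but here are some general style points. I haven't taken the time to properly understand the algorithm, so these are surface-level observations. If you'd like a deeper analysis of correctness, I suggest giving me a test harness (just something that does a pyximport and runs the code) so I can see how it's used.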

I don't have much to criticise in this code. It's mostly tidy and very readable.

First: docstrings. Instead of:

```cython
# Compute number of bits that are not common between `a` and `b`.
# return value is a plain integer
cdef uint64_t hamming(uint64_t a, uint64_t b):
```

use:

```cython
cdef uint64_t hamming(uint64_t a, uint64_t b):
    """
    Compute number of bits that are not common between `a` and `b`.
    return value is a plain integer
    """
```

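If you're only ever targeting Python 3.4, don't write `(object)` in the inheritance list. If you might also need 2.x compatibility, though, it's best to keep it.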
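A quick check in plain Python 3 (the class names are hypothetical) that the explicit `object` base is redundant there:

```python
# In Python 3 every class implicitly inherits from object, so these two
# declarations produce classes with the same bases.
class Explicit(object):
    pass


class Implicit:
    pass


print(Explicit.__bases__ == Implicit.__bases__)  # True
```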

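You have

```cython
    cdef set nodeData
    cdef dict children
```

It's worth noting that these typed attributes aren't much faster than bog-standard untyped ones, although they're fine as long as you know they will always hold that type. This is one reason I'd expect PyPy to give a bigger speed advantage here.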

Instead of `set((nodeData, ))`, just write `self.nodeData = {nodeData}`.
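The two spellings build the same one-element set. The set display simply skips the temporary tuple. The payload value below is hypothetical:

```python
node_data = "image-42.jpg"  # hypothetical payload

# Form used in the question: build a tuple, then convert it to a set.
old = set((node_data, ))
# Set display suggested in the review: same result, no temporary tuple.
new = {node_data}

print(old == new)  # True
```

One caveat: `{}` on its own is an empty dict, so `set()` is still needed for an empty set.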

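In `remove` you have `cdef uint64_t deleted = 0`, but you only add to it once. There's no point giving it a C type, especially since returning it converts it to a Python int anyway. The same concern applies, to a lesser degree, to `moved`.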

Instead of `self.children.pop(selfDist)`, write `del self.children[selfDist]`. Using `pop` implies that you need the return value.
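Both forms remove the entry and raise `KeyError` if it is missing. They differ only in whether a value comes back. A small illustration with a hypothetical distance-to-child map:

```python
children = {3: "child-a", 5: "child-b"}  # hypothetical distance -> node map

# pop() removes the key *and* returns the value; useful only if you need it.
removed = children.pop(3)

# del states the intent (just remove the entry) without producing a value.
del children[5]

print(removed, children)  # child-a {}
```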

You write:

```cython
        ret = set()

        if selfDist <= distance:
            ret |= set(self.nodeData)
```

In my opinion

```cython
        ret = set()

        if selfDist <= distance:
            ret = set(self.nodeData)
```

would be better. Failing that, try `ret.update(self.nodeData)`.
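Both suggestions give a fresh set containing the node's data, without the temporary `set(...)` that `|=` merges in. The data values below are hypothetical:

```python
node_data = {"a.jpg", "b.jpg"}  # hypothetical contents of self.nodeData

# First suggestion: start from a copy of the node's data.
ret = set(node_data)

# Alternative: update an existing set in place, with no temporary set.
ret2 = set()
ret2.update(node_data)

print(ret == ret2)  # True

# Both are independent copies, so merging child results into them later
# does not modify the node's own data set.
ret.add("c.jpg")
print("c.jpg" in node_data)  # False
```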

You don't need to call `.keys()` here:

```cython
        for key in self.children.keys():
```
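Iterating over a dict already yields its keys. As a further option beyond the review's point, the loop body indexes `self.children[key]` again, and `.items()` would hand over the child directly. This sketch uses a hypothetical map:

```python
children = {1: "n1", 4: "n4", 9: "n9"}  # hypothetical distance -> child map

# Iterating a dict yields its keys, so .keys() is redundant here.
print(list(children) == list(children.keys()))  # True

# If the body needs the child too, .items() avoids a second lookup.
for key, child in children.items():
    print(key, child)
```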

This:

```cython
            if key <= selfDist+distance and key >= selfDist-distance:
```

should just be

```cython
            if selfDist-distance <= key <= selfDist+distance:
```

or even

```cython
            if abs(key - selfDist) <= distance:
```
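With ordinary (unbounded) Python integers, the three conditions are interchangeable. This brute-force check uses hypothetical function names for the three forms:

```python
def original(key, self_dist, distance):
    return key <= self_dist + distance and key >= self_dist - distance


def chained(key, self_dist, distance):
    return self_dist - distance <= key <= self_dist + distance


def absolute(key, self_dist, distance):
    return abs(key - self_dist) <= distance


# Exhaustively compare the forms over the range of 64-bit Hamming distances.
for key in range(65):
    for self_dist in range(65):
        for distance in range(10):
            a = original(key, self_dist, distance)
            assert a == chained(key, self_dist, distance) == absolute(key, self_dist, distance)
print("all three forms agree")
```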

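I'm not sure whether this works with `selfDist` being a `uint64_t` and `key` being a Python int, but it probably will if you just remove the `cdef uint64_t`. (The mixed typing does matter for the original check. `selfDist-distance` combines a C `uint64_t` with a C `int`, so it is evaluated in unsigned arithmetic and wraps around to a huge number whenever `distance > selfDist`. That makes `key >= selfDist-distance` false for children that should be visited.)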
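To see why the C typing of `selfDist` matters, here is a plain-Python emulation of the unsigned wraparound. It uses `ctypes.c_uint64`, which does no overflow checking, as a stand-in for C `uint64_t` arithmetic. The values are illustrative only:

```python
import ctypes

self_dist, distance = 3, 5  # node is closer to the query than the radius

# Emulated C-level uint64_t result: wraps modulo 2**64.
lower_c = ctypes.c_uint64(self_dist - distance).value
# The same expression evaluated on Python ints.
lower_py = self_dist - distance

print(lower_c)   # 18446744073709551614
print(lower_py)  # -2

key = 1  # a child that should be visited, since |1 - 3| <= 5
print(key >= lower_c, key >= lower_py)  # False True
```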

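You should probably avoid class variables like `root = None`, and just set them in `__init__` instead. Class variables are basically globals. Having an unchanging global act as a fallback does work, but it goes against how I'd expect the language to be used.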
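This sketch shows the fallback behaviour the review describes. The class names are hypothetical: a class attribute is shared and only shadowed per instance on assignment, while an attribute set in `__init__` belongs to each instance from the start.

```python
class SharedRoot:
    root = None  # class attribute: one fallback value shared by every instance


class OwnRoot:
    def __init__(self):
        self.root = None  # instance attribute: each tree owns its own root
        self.nodes = 0


a, b = SharedRoot(), SharedRoot()
a.root = "node"  # creates an instance attribute that shadows the class one
print("root" in vars(a), "root" in vars(b))  # True False
print(b.root)  # None, found via the class-level fallback

t = OwnRoot()
print(vars(t))  # {'root': None, 'nodes': 0}
```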

Score: 2
Source: Code Review, https://codereview.stackexchange.com/questions/67895
