在StackOverflow和其他地方有很多声明说nth_element是O(n),并且它通常是用Introselect:元素实现的。
我想知道如何才能做到这一点。我看着维基百科对Introselect的解释,这让我更加困惑。算法如何在QSort和中间值中位数之间切换?
我在这里找到了Introsort文件:http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.14.5196&rep=rep1&type=pdf,但上面写着:
在本文中,我们集中讨论排序问题,并在后面的一节中简要地回到选择问题。
我试着阅读STL本身,以了解nth_element是如何实现的,但很快就会变得毛茸茸的。
有人能给我看一下Introselect是如何实现的伪代码吗?或者更好的是,除了STL之外的实际C++代码(当然:)
发布于 2015-03-19 14:05:03
你问了两个问题,一个是名义上的
nth_element是如何实现的?
你已经回答了:
在StackOverflow和其他地方有很多声明说nth_element是O(n),并且它通常是用Introselect实现的。
我也可以通过查看stdlib实现来确认这一点。(稍后将对此进行详细介绍)
还有一个你不明白答案的人:
算法如何在QSort和中间值中位数之间切换?
让我们看看我从stdlib中提取的伪代码:
nth_element(first, nth, last)
{
if (first == last || nth == last)
return;
introselect(first, nth, last, log2(last - first) * 2);
}
introselect(first, nth, last, depth_limit)
{
while (last - first > 3)
{
if (depth_limit == 0)
{
// [NOTE by editor] This should be median-of-medians instead.
// [NOTE by editor] See Azmisov's comment below
heap_select(first, nth + 1, last);
// Place the nth largest element in its final position.
iter_swap(first, nth);
return;
}
--depth_limit;
cut = unguarded_partition_pivot(first, last);
if (cut <= nth)
first = cut;
else
last = cut;
}
insertion_sort(first, last);
}在不详细介绍引用的函数heap_select和unguarded_partition_pivot的情况下,我们可以清楚地看到,nth_element给出了自选2 * log2(size)细分步骤(在最好的情况下,它是quickselect所需的两倍),直到heap_select启动并永久解决问题为止。
发布于 2015-03-19 13:42:17
免责声明:我不知道在任何标准库中如何实现std::nth_element。
如果您知道快速排序是如何工作的,您可以很容易地修改它,以完成此算法所需的操作。快速排序的基本思想是,在每一步中,将数组划分为两个部分,这样所有小于枢轴的元素都在左子数组中,所有等于或大于枢轴的元素都在右子数组中。(对快速排序的修改称为三元快速排序,将创建第三个子数组,其所有元素都等于枢轴。然后右边的子数组只包含严格大于枢轴的条目。)然后,通过递归地对左右子数组进行排序,从而进行快速排序。
如果您只想将n元素移到适当的位置,而不是递归到两个子数组中,那么您可以在每一步中判断您需要下降到左数组还是右子数组。(您知道这一点,因为排序数组中的第n个元素有索引n,因此它变成了比较索引的问题。)因此,除非您的快速排序遭受最坏的情况退化,否则在每一步中,您都会将剩余数组的大小大致减半。(你再也不看另一个子数组了。)因此,平均而言,在每一步中都要处理以下长度的数组:
每一步处理的数组的长度都是线性的。(循环一次,并根据每个元素与枢轴的比较,决定每个元素应该执行哪个子数组。)
您可以看到,在Θ(log(N))步骤之后,我们最终将到达一个单例数组并完成。如果你把N (1 + 1/2 + 1/4 +…)加起来,你会得到2N,或者,在一般情况下,因为我们不能希望枢轴总是正中位,这是Θ( N )的一个数量级。
发布于 2015-03-19 13:42:55
STL的代码(我认为是3.3版本)是这样的:
template <class _RandomAccessIter, class _Tp>
void __nth_element(_RandomAccessIter __first, _RandomAccessIter __nth,
_RandomAccessIter __last, _Tp*) {
while (__last - __first > 3) {
_RandomAccessIter __cut =
__unguarded_partition(__first, __last,
_Tp(__median(*__first,
*(__first + (__last - __first)/2),
*(__last - 1))));
if (__cut <= __nth)
__first = __cut;
else
__last = __cut;
}
__insertion_sort(__first, __last);
}让我们简化一下:
template <class Iter, class T>
void nth_element(Iter first, Iter nth, Iter last) {
while (last - first > 3) {
Iter cut =
unguarded_partition(first, last,
T(median(*first,
*(first + (last - first)/2),
*(last - 1))));
if (cut <= nth)
first = cut;
else
last = cut;
}
insertion_sort(first, last);
}我在这里所做的是删除双下划线和_Uppercase内容,这只是为了保护代码不受用户合法定义为宏的东西的影响。我还删除了最后一个参数,该参数仅用于模板类型的推断,并将迭代器类型重命名为简洁。
正如您现在应该看到的,它反复地对范围进行分区,直到剩下的范围中保留了少于4个元素,然后简单地对其排序。
那为什么是O(n)?首先,由于三个元素的最大值,最终的排序是O(1)。现在,剩下的是重复的分区。分区本身就是O(n)。但是,在这里,每一步都将需要在下一步中接触的元素数减半,所以如果加起来的话,O(n) + O(n/2) + O(n/4) + O(n/8)小于O(2n)。由于O(2n) = O(n),平均来说,线性复杂度是很高的。
https://stackoverflow.com/questions/29145520
复制相似问题