首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何为已安装模块内的迭代创建进度条

如何为已安装模块内的迭代创建进度条
EN

Stack Overflow用户
提问于 2021-12-31 13:03:41
回答 1查看 194关注 0票数 3

我的目标是为已安装模块内的迭代创建一个进度条。

要为用户定义的函数内的迭代创建进度条,我将tqdm.notebook.tqdm_notebook对象作为可迭代传递:

代码语言:javascript
复制
import time
import numpy as np
from tqdm.notebook import tqdm

def iterate(over):
    for x in over: # creating progress bar for this
        print(x, end='')
        time.sleep(0.5)

xs = np.arange(5)
tqdm_xs = tqdm(xs) # creating tqdm.notebook.tqdm_notebook object

iterate(tqdm_xs) # progress bar, as expected
iterate(xs) # no progress bar

它的作用是:

但是,当我试图对已安装模块内的for循环执行同样的操作时,这将失败。在Astropy的Photutils模块中,有一个for label in labels行(这里),我可以传递标签对象。

可重复的示例(主要基于 -安装了photutils后的工作:pip install photutils):

代码语言:javascript
复制
import photutils.datasets as phdat
import photutils.segmentation as phsegm
import astropy.convolution as conv
import astropy.stats as stats

data = phdat.make_100gaussians_image()
threshold = phsegm.detect_threshold(data, nsigma=2.)
sigma = 1.5
kernel = conv.Gaussian2DKernel(sigma, x_size=3, y_size=3)
kernel.normalize()
segm = phsegm.detect_sources(data, threshold, npixels=5, kernel=kernel)

这样做是可行的:

代码语言:javascript
复制
segm_deblend = phsegm.deblend_sources(data, segm, npixels=5, kernel=kernel,
                                      nlevels=32, contrast=0.001, labels = segm.labels)

试图传递tqdm.notebook.tqdm_notebook对象以创建进度条:

代码语言:javascript
复制
tqdm_segm_labels = tqdm(segm.labels)
segm_deblend = phsegm.deblend_sources(data, segm, npixels=5, kernel=kernel,
                                    nlevels=32, contrast=0.001, labels = tqdm_segm_labels)

我得到了一个AttributeError: 'int' object has no attribute '_comparable'。完整回溯:

代码语言:javascript
复制
0%
0/92 [00:00<?, ?it/s]
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-8-d101466650ae> in <module>()
      1 tqdm_segm_labels = tqdm(segm.labels)
      2 segm_deblend = phsegm.deblend_sources(data, segm, npixels=5, kernel=kernel,
----> 3                                     nlevels=32, contrast=0.001, labels = tqdm_segm_labels)

4 frames
/usr/local/lib/python3.7/dist-packages/astropy/utils/decorators.py in wrapper(*args, **kwargs)
    534                     warnings.warn(message, warning_type, stacklevel=2)
    535 
--> 536             return function(*args, **kwargs)
    537 
    538         return wrapper

/usr/local/lib/python3.7/dist-packages/photutils/segmentation/deblend.py in deblend_sources(data, segment_img, npixels, kernel, labels, nlevels, contrast, mode, connectivity, relabel)
    112         labels = segment_img.labels
    113     labels = np.atleast_1d(labels)
--> 114     segment_img.check_labels(labels)
    115 
    116     if kernel is not None:

/usr/local/lib/python3.7/dist-packages/photutils/segmentation/core.py in check_labels(self, labels)
    355 
    356         # check for positive label numbers
--> 357         idx = np.where(labels <= 0)[0]
    358         if idx.size > 0:
    359             bad_labels.update(labels[idx])

/usr/local/lib/python3.7/dist-packages/tqdm/utils.py in __le__(self, other)
     70 
     71     def __le__(self, other):
---> 72         return (self < other) or (self == other)
     73 
     74     def __eq__(self, other):

/usr/local/lib/python3.7/dist-packages/tqdm/utils.py in __lt__(self, other)
     67     """Assumes child has self._comparable attr/@property"""
     68     def __lt__(self, other):
---> 69         return self._comparable < other._comparable
     70 
     71     def __le__(self, other):

AttributeError: 'int' object has no attribute '_comparable'

一个解决方法就是修改Photutils并在其中使用tqdm (我在这叉子上做的,它可以工作),但是这看起来有点过分,我希望有一个更简单的方法来做到这一点。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-01-01 21:07:36

当然,通常情况下,没有办法直接修改一些您没有自己编写的现有代码(不管它是否“安装”不是问题)。

如果您认为它是真正有用或感兴趣的,您可以建议一个补丁来允许这个函数使用,例如,在每个循环上调用一个回调函数。如果它通常是一个缓慢的函数,它可能是有用的(我确实注意到在实现中有一些东西可以被修改以加速它,但这是另一回事)。

当然,您可以找到一些聪明的黑客来让它在这个特定的情况下工作,但是考虑到它是专门针对这个函数的实现细节设计的,它将是脆弱的。我找到了一些可能性。

最简单的似乎是这个愚蠢的伎俩:

创建一个ndarray子类(我称之为tqdm_array),该子类在Python中迭代时,在tqdm进度栏上返回一个迭代器,它封装数组本身:

代码语言:javascript
复制
class tqdm_array(np.ndarray):
    def __iter__(self):
        return iter(tqdm.tqdm(np.asarray(self)))

然后,当准备调用deblend_sources时,请将标签包装如下:

代码语言:javascript
复制
labels = np.array(segm_image.labels).view(tqdm_array)

然后把它传给deblend_sources(..., labels=labels, ...)

这将起作用,因为即使labels被NumPy代码迭代,它也将使用内部C代码直接在数组缓冲区上迭代(例如,对于像labels <= 0这样的操作。在大多数情况下,它不会调用Python __iter__方法,尽管可能会有异常.

但是,当遇到像for label in labels:这样的Python循环(在这个函数中碰巧只有一个)时,您将得到进度条。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70542600

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档