首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >MXNET多迭代器(结合使用.rec迭代器和NDArray迭代器)

MXNET多迭代器(结合使用.rec迭代器和NDArray迭代器)
EN

Stack Overflow用户
提问于 2017-08-22 09:17:10
回答 1查看 379关注 0票数 2

如何在MXNET中创建组合迭代器?例如,给定一个记录(.rec)迭代器,如果我想更改与每个图像对应的标签,则有两个选项: a)创建一个具有相同数据(图像)和新标签的新rec迭代器。b)使用原始.rec迭代器和NDArray迭代器创建多迭代器,以便多迭代器从原始.rec迭代器读取数据(图像)并从NDArray迭代器读取标签。选项(a)很繁琐。关于如何创建这样的多迭代器有什么建议吗?

EN

回答 1

Stack Overflow用户

发布于 2017-08-22 09:29:38

代码语言:javascript
复制
class MultiIter(mx.io.DataIter):  
    def __init__(self, iter_list):  
        self.iters = iter_list   
        self.batch_size = 1000  
    def next(self):  
        batches = [i.next() for i in self.iters]  
        return mx.io.DataBatch(data=[t for t in batches[0].data]+ [t for t in batches[1].data], label= [t for t in batches[0].label] + [t for t in batches[1].label],pad=0)  
    def reset(self):  
        for i in self.iters:  
            i.reset()  
    @property  
    def provide_data(self):  
        return [t for t in self.iters[0].provide_data] + [t for t in self.iters[1].provide_data] 
    @property  
    def provide_label(self):  
        return [t for t in self.iters[0].provide_label] + [t for t in self.iters[1].provide_label]

train = MultiIter([train1,train2])

其中train1和train2可以是任意两个DataIter。特别是,train1可以是.rec迭代器,train2可以是NDArray迭代器。如果train1或train2中的任何一个是NDArray迭代器,则使用组合迭代器调用predict方法时需要额外的参数"pad=0“。

MultiIter返回由两个迭代器组合而成的数据列表和标签列表。如果你只需要来自第一个迭代器的数据和来自第二个迭代器的标签,下面的代码就可以工作了。

代码语言:javascript
复制
class MultiIter(mx.io.DataIter):  
    def __init__(self, iter_list):  
        self.iters = iter_list   
        self.batch_size = 1000  
    def next(self):  
        batches = [i.next() for i in self.iters]  
        return mx.io.DataBatch(data=[t for t in batches[0].data], label= [t for t in batches[1].label],pad=0)  
    def reset(self):  
        for i in self.iters:  
            i.reset()  
    @property  
    def provide_data(self):  
        return [t for t in self.iters[0].provide_data] 
    @property  
    def provide_label(self):  
        return [t for t in self.iters[1].provide_label] 

train = MultiIter([train1,train2])
票数 4
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/45807556

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档