如何在MXNET中创建组合迭代器?例如,给定一个记录(.rec)迭代器,如果我想更改与每个图像对应的标签,则有两个选项: a)创建一个具有相同数据(图像)和新标签的新rec迭代器。b)使用原始.rec迭代器和NDArray迭代器创建多迭代器,以便多迭代器从原始.rec迭代器读取数据(图像)并从NDArray迭代器读取标签。选项(a)很繁琐。关于如何创建这样的多迭代器有什么建议吗?
发布于 2017-08-22 09:29:38
class MultiIter(mx.io.DataIter):
def __init__(self, iter_list):
self.iters = iter_list
self.batch_size = 1000
def next(self):
batches = [i.next() for i in self.iters]
return mx.io.DataBatch(data=[t for t in batches[0].data]+ [t for t in batches[1].data], label= [t for t in batches[0].label] + [t for t in batches[1].label],pad=0)
def reset(self):
for i in self.iters:
i.reset()
@property
def provide_data(self):
return [t for t in self.iters[0].provide_data] + [t for t in self.iters[1].provide_data]
@property
def provide_label(self):
return [t for t in self.iters[0].provide_label] + [t for t in self.iters[1].provide_label]
train = MultiIter([train1,train2])其中train1和train2可以是任意两个DataIter。特别是,train1可以是.rec迭代器,train2可以是NDArray迭代器。如果train1或train2中的任何一个是NDArray迭代器,则使用组合迭代器调用predict方法时需要额外的参数"pad=0“。
MultiIter返回由两个迭代器组合而成的数据列表和标签列表。如果你只需要来自第一个迭代器的数据和来自第二个迭代器的标签,下面的代码就可以工作了。
class MultiIter(mx.io.DataIter):
def __init__(self, iter_list):
self.iters = iter_list
self.batch_size = 1000
def next(self):
batches = [i.next() for i in self.iters]
return mx.io.DataBatch(data=[t for t in batches[0].data], label= [t for t in batches[1].label],pad=0)
def reset(self):
for i in self.iters:
i.reset()
@property
def provide_data(self):
return [t for t in self.iters[0].provide_data]
@property
def provide_label(self):
return [t for t in self.iters[1].provide_label]
train = MultiIter([train1,train2])https://stackoverflow.com/questions/45807556
复制相似问题