首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在磁共振图像上正确地使用from_tensor_slices?

如何在磁共振图像上正确地使用from_tensor_slices?
EN

Stack Overflow用户
提问于 2021-11-26 13:12:04
回答 1查看 198关注 0票数 1

我正在处理MRI图像,我想使用from_tensor_slices对路径进行预处理,但我不知道如何正确地使用它。下面是我的代码、问题消息和数据集的链接。

首先,我重新整理我的数据。484张图片和484张标签

代码语言:javascript
复制
image_data_path = './drive/MyDrive/Brain Tumour/Task01_BrainTumour/imagesTr/'
label_data_path = './drive/MyDrive/Brain Tumour/Task01_BrainTumour/labelsTr/'

image_paths = [image_data_path + name 
               for name in os.listdir(image_data_path) 
               if not name.startswith(".")]

label_paths = [label_data_path + name
               for name in os.listdir(label_data_path)
               if not name.startswith(".")]

image_paths = sorted(image_paths)
label_paths = sorted(label_paths)

然后是加载1个示例的函数(我使用nibabel加载nii文件)

代码语言:javascript
复制
def load_one_sample(image_path, label_path):

  image = nib.load(image_path).get_fdata()
  image = tf.convert_to_tensor(image, dtype = 'float32')
  label = nib.load(label_path).get_fdata()
  label = tf.convert_to_tensor(label, dtype = 'uint8')

  return image, label

接下来,我尝试使用from_tensor_slices

代码语言:javascript
复制
image_filenames = tf.constant(image_paths)
label_filenames = tf.constant(label_paths)

dataset = tf.data.Dataset.from_tensor_slices((image_filenames, label_filenames))

all_data = dataset.map(load_one_sample)

错误来了:TypeError: stat: path should be string, bytes, os.PathLike or integer, not Tensor

有什么不对的,我该怎么解决呢?

数据链:https://drive.google.com/drive/folders/1HqEgzS8BV2c7xYNrZdEAnrHk7osJJ--2 (任务1-脑瘤)

如果你需要更多的信息,请告诉我。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-11-26 16:12:07

nib.load不是TensorFlow函数。

如果您想使用tf.data管道中任何不是TensorFlow函数的东西,那么您必须使用tf.py_function包装它。

代码:

代码语言:javascript
复制
image_data_path = 'Task01_BrainTumour/imagesTr/'
label_data_path = 'Task01_BrainTumour/labelsTr/'

image_paths = [image_data_path + name 
               for name in os.listdir(image_data_path) 
               if not name.startswith(".")]
label_paths = [label_data_path + name
               for name in os.listdir(label_data_path)
               if not name.startswith(".")]

image_paths = sorted(image_paths)
label_paths = sorted(label_paths)

def load_one_sample(image_path, label_path):
  image = nib.load(image_path.numpy().decode()).get_fdata()
  image = tf.convert_to_tensor(image, dtype = 'float32')
  label = nib.load(label_path.numpy().decode()).get_fdata()
  label = tf.convert_to_tensor(label, dtype = 'uint8')
  return image, label

def wrapper_load(img_path, label_path):
  img, label = tf.py_function(func = load_one_sample, inp = [img_path, label_path], Tout = [tf.float32, tf.uint8])
  return img, label

dataset = tf.data.Dataset.from_tensor_slices((image_paths, label_paths)).map(wrapper_load)

错误不是由于from_tensor_slices函数造成的,而是由于nibs.load需要一个字符串而得到一个张量而产生的。

然而,一个更好的方法是创建to记录并使用它们来训练模型。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70124971

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档