首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >tensorflow中的奇特索引

tensorflow中的奇特索引
EN

Stack Overflow用户
提问于 2020-11-30 02:32:00
回答 1查看 83关注 0票数 1

我已经实现了一个具有自定义损失函数(Ax' - y)^2的3DCNN,其中x‘是来自CNN的3D输出的展平和裁剪向量,y是地面实况,A是一个线性运算符,它采用x并输出a y。因此,我需要一种方法来展平3D输出,并在计算损失之前使用花哨的索引进行裁剪。

这是我尝试过的:这是我试图复制的numpy代码,

代码语言:javascript
复制
def flatten_crop(img_vol, indices, vol_shape, N):
    """
    :param img_vol: shape (145, 59, 82, N)
    :param indices: shape (396929,)
    """
    nVx, nVy, nVz = vol_shape
    voxels = np.reshape(img_vol, (nVx * nVy * nVz, N), order='F')
    voxels = voxels[indices, :]
    return voxels

我尝试使用tf.nd_gather执行相同的操作,但我无法将其推广到任意批处理大小。以下是批处理大小为1(或单个3D输出)的tensorflow代码:

代码语言:javascript
复制
voxels = tf.transpose(tf.reshape(tf.transpose(y_pred), (1, 145 * 59 * 82)))   # to flatten and reshape using Fortran-like index order
voxels = tf.gather_nd(voxels, tf.stack([indices, tf.zeros(len(indices), dtype=tf.dtypes.int32)], axis=1))    # indexing
voxels = tf.reshape(voxels, (voxels.shape[0], 1))

目前,我的自定义损失函数中有这段代码,我希望能够泛化为任意批处理大小。此外,如果您有实现此功能的替代建议(例如自定义层,而不是与损失函数集成),我洗耳恭听!

谢谢。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-11-30 04:38:04

尝试以下代码:

代码语言:javascript
复制
import tensorflow as tf
y_pred = tf.random.uniform((10, 145, 59, 82))
indices = tf.random.uniform((396929,), 0, 145*59*82, dtype=tf.int32)
voxels = tf.reshape(y_pred, (-1, 145 * 59 * 82))   # to flatten and reshape using Fortran-like index order
voxels = tf.gather(voxels, indices, axis=-1)
voxels = tf.transpose(voxels)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/65063685

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档