我已经实现了一个具有自定义损失函数(Ax' - y)^2的3DCNN,其中x‘是来自CNN的3D输出的展平和裁剪向量,y是地面实况,A是一个线性运算符,它采用x并输出a y。因此,我需要一种方法来展平3D输出,并在计算损失之前使用花哨的索引进行裁剪。
这是我尝试过的:这是我试图复制的numpy代码,
def flatten_crop(img_vol, indices, vol_shape, N):
"""
:param img_vol: shape (145, 59, 82, N)
:param indices: shape (396929,)
"""
nVx, nVy, nVz = vol_shape
voxels = np.reshape(img_vol, (nVx * nVy * nVz, N), order='F')
voxels = voxels[indices, :]
return voxels我尝试使用tf.nd_gather执行相同的操作,但我无法将其推广到任意批处理大小。以下是批处理大小为1(或单个3D输出)的tensorflow代码:
voxels = tf.transpose(tf.reshape(tf.transpose(y_pred), (1, 145 * 59 * 82))) # to flatten and reshape using Fortran-like index order
voxels = tf.gather_nd(voxels, tf.stack([indices, tf.zeros(len(indices), dtype=tf.dtypes.int32)], axis=1)) # indexing
voxels = tf.reshape(voxels, (voxels.shape[0], 1))目前,我的自定义损失函数中有这段代码,我希望能够泛化为任意批处理大小。此外,如果您有实现此功能的替代建议(例如自定义层,而不是与损失函数集成),我洗耳恭听!
谢谢。
发布于 2020-11-30 04:38:04
尝试以下代码:
import tensorflow as tf
y_pred = tf.random.uniform((10, 145, 59, 82))
indices = tf.random.uniform((396929,), 0, 145*59*82, dtype=tf.int32)
voxels = tf.reshape(y_pred, (-1, 145 * 59 * 82)) # to flatten and reshape using Fortran-like index order
voxels = tf.gather(voxels, indices, axis=-1)
voxels = tf.transpose(voxels)https://stackoverflow.com/questions/65063685
复制相似问题