首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何以支持自动梯度的方式围绕其中心旋转PyTorch图像张量?

如何以支持自动梯度的方式围绕其中心旋转PyTorch图像张量?
EN

Stack Overflow用户
提问于 2020-10-05 01:29:50
回答 2查看 7K关注 0票数 5

我想随机旋转一个图像张量(B,C,H,W)围绕它的中心(我想是二维旋转?)。我想避免使用NumPy和Kornia,这样我基本上只需要从torch模块导入。我也没有使用torchvision.transforms,因为我需要它与autograd兼容。本质上,我试图为DeepDream这样的可视化技术创建一个自动分级兼容的torchvision.transforms.RandomRotation()版本(所以我需要尽可能地避免工件)。

代码语言:javascript
复制
import torch
import math
import random
import torchvision.transforms as transforms
from PIL import Image


# Load image
def preprocess_simple(image_name, image_size):
    Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    image = Image.open(image_name).convert('RGB')
    return Loader(image).unsqueeze(0)
    
# Save image   
def deprocess_simple(output_tensor, output_name):
    output_tensor.clamp_(0, 1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor.squeeze(0))
    image.save(output_name)


# Somehow rotate tensor around it's center
def rotate_tensor(tensor, radians):
    ...
    return rotated_tensor

# Get a random angle within a specified range 
r_degrees = 5
angle_range = list(range(-r_degrees, r_degrees))
n = random.randint(angle_range[0], angle_range[len(angle_range)-1])

# Convert angle from degrees to radians
ang_rad = angle * math.pi / 180


# test_tensor = preprocess_simple('path/to/file', (512,512))
test_tensor = torch.randn(1,3,512,512)


# Rotate input tensor somehow
output_tensor = rotate_tensor(test_tensor, ang_rad)


# Optionally use this to check rotated image
# deprocess_simple(output_tensor, 'rotated_image.jpg')

我正在尝试完成的一些示例输出:

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2020-10-05 04:17:44

因此,网格生成器和采样器是Spatial Transformer (JADERBERG,Max等)的子模块。这些子模块是不可训练的,它们允许您应用可学习的和不可学习的空间转换。在这里,我使用这两个子模块,并通过theta使用PyTorch的函数torch.nn.functional.affine_gridtorch.nn.functional.affine_sample (这些函数分别是生成器和采样器的实现)来旋转图像:

代码语言:javascript
复制
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

def get_rot_mat(theta):
    theta = torch.tensor(theta)
    return torch.tensor([[torch.cos(theta), -torch.sin(theta), 0],
                         [torch.sin(theta), torch.cos(theta), 0]])


def rot_img(x, theta, dtype):
    rot_mat = get_rot_mat(theta)[None, ...].type(dtype).repeat(x.shape[0],1,1)
    grid = F.affine_grid(rot_mat, x.size()).type(dtype)
    x = F.grid_sample(x, grid)
    return x


#Test:
dtype =  torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
#im should be a 4D tensor of shape B x C x H x W with type dtype, range [0,255]:
plt.imshow(im.squeeze(0).permute(1,2,0)/255) #To plot it im should be 1 x C x H x W
plt.figure()
#Rotation by np.pi/2 with autograd support:
rotated_im = rot_img(im, np.pi/2, dtype) # Rotate image by 90 degrees.
plt.imshow(rotated_im.squeeze(0).permute(1,2,0)/255)

在上面的示例中,假设我们的图像im是一只穿着裙子跳舞的猫:

rotated_im将是一只90度旋转的穿着裙子跳舞的猫:

这是我们用theta公式调用rot_imgnp.pi/4得到的结果:

最好的部分是它是可区分的w.r.t输入,并支持自动评分!万岁!

票数 12
EN

Stack Overflow用户

发布于 2020-10-05 04:22:22

这里有一个pytorch函数:

代码语言:javascript
复制
x = torch.tensor([[0, 1],
            [2, 3]])

x = torch.rot90(x, 1, [0, 1])
代码语言:javascript
复制
>> tensor([[1, 3],
           [0, 2]])

以下是文档:https://pytorch.org/docs/stable/generated/torch.rot90.html

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/64197754

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档