首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在Python语言中使用MXNet预训练的图像分类模型

在Python语言中使用MXNet预训练的图像分类模型
EN

Stack Overflow用户
提问于 2016-02-22 18:51:25
回答 1查看 710关注 0票数 1

我试图相信在Python3中为R描述的所有东西,但到目前为止,我还没有更进一步。

这里描述了R语言的教程:http://mxnet.readthedocs.org/en/latest/R-package/classifyRealImageWithPretrainedModel.html

我如何在Python中做同样的事情?使用以下模型:https://github.com/dmlc/mxnet-model-gallery/blob/master/imagenet-1k-inception-bn.md

致以良好的问候,凯文

EN

回答 1

Stack Overflow用户

发布于 2018-03-08 03:12:22

目前,在mxnet中使用Python可以做的事情比使用R多得多。我使用的是Gluon API,它使编写代码变得更加简单,并且允许加载预先训练好的模型。

您在本教程中使用的模型是Inception model。所有可用的预训练模型的列表可以在here上找到。

本教程中的其余操作是数据规范化和增强。您可以对新数据进行标准化,方法类似于API页面上的标准化:

代码语言:javascript
复制
image = image/255
normalized = mx.image.color_normalize(image,
                                      mean=mx.nd.array([0.485, 0.456, 0.406]),
                                      std=mx.nd.array([0.229, 0.224, 0.225]))

可能的增强列表可从here获得。

以下是为您提供的可运行示例。我只做了一次增强,如果你想做更多,你可以给mx.image.CreateAugmenter添加更多的参数:

代码语言:javascript
复制
%matplotlib inline
import mxnet as mx
from mxnet.gluon.model_zoo import vision
from matplotlib.pyplot import imshow

def plot_mx_array(array, clip=False):
    """
    Array expected to be 3 (channels) x heigh x width, and values are floats between 0 and 255.
    """
    assert array.shape[2] == 3, "RGB Channel should be last"
    if clip:
        array = array.clip(0,255)
    else:
        assert array.min().asscalar() >= 0, "Value in array is less than 0: found " + str(array.min().asscalar())
        assert array.max().asscalar() <= 255, "Value in array is greater than 255: found " + str(array.max().asscalar())
    array = array/255
    np_array = array.asnumpy()
    imshow(np_array)


inception_model = vision.inception_v3(pretrained=True)

with open("/Volumes/Unix/workspace/MxNet/2018-02-20T19-43-45/types_of_data_augmentation/output_4_0.png", 'rb') as open_file:
    encoded_image = open_file.read()
    example_image = mx.image.imdecode(encoded_image)
    example_image = example_image.astype("float32")
    plot_mx_array(example_image)


augmenters = mx.image.CreateAugmenter(data_shape=(1, 100, 100))

for augementer in augmenters:
    example_image = augementer(example_image)

plot_mx_array(example_image)

example_image = example_image / 255
normalized_image = mx.image.color_normalize(example_image,
                                      mean=mx.nd.array([0.485, 0.456, 0.406]),
                                      std=mx.nd.array([0.229, 0.224, 0.225]))
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/35551692

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档