首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >基于Kotlin的PyTorch与DJL实现的推理差异

基于Kotlin的PyTorch与DJL实现的推理差异
EN

Stack Overflow用户
提问于 2021-04-12 21:57:56
回答 1查看 125关注 0票数 0

我在PyTorch数据集上训练了一个17朵花模型,并通过PyTorch的跟踪将其转换为JIT模型。我已经测试了PyTorch模型和JIT转换模型的推理输出,结果是等价的。这使我相信,我的DJL框架的实现存在一个问题。

当我试图利用转换后的DJL模型进行推理时会出现一个问题,这对于DJL来说是必要的。我没有得到100%的匹配,这是我所期望的。

djl.ai的Kotlin实现非常简单,本质上遵循这里的指令。

下面是Kotlin代码的净化版本:

代码语言:javascript
复制
@Throws(IOException::class, ModelException::class, TranslateException::class)
internal fun main(args: Array<String>) {
    val artifactId = "ai.djl.localmodelzoo:torchscript_17flowers"
    val pipeline = Pipeline()
    pipeline.add(CenterCrop(224, 224))
        .add(Resize(224, 224))
        .add(ToTensor())
        .add(Normalize(floatArrayOf(0.485f, 0.456f, 0.406f), floatArrayOf(0.229f, 0.224f, 0.225f)))
    val translator = ImageClassificationTranslator.builder()
        .setPipeline(pipeline)
        .optSynsetArtifactName("synset.txt")
        .optApplySoftmax(true)
        .build();

    System.setProperty("ai.djl.repository.zoo.location","build/pytorch_models/torchscript_17flowers")

    val criteria = Criteria.builder()
        .setTypes(Image::class.java, Classifications::class.java) // defines input and output data type
        .optTranslator(translator)
        .optArtifactId(artifactId) // defines which model to load
        .optProgress(ProgressBar())
        .build()

    val model = ModelZoo.loadModel(criteria)

    // single image test
    var img = ImageFactory.getInstance().fromUrl("https://image.jpg");
    img.getWrappedImage()

    val predictor: Predictor<Image, Classifications> = model.newPredictor()
    val classifications: Classifications = predictor.predict(img)
    val best = classifications.best<Classifications.Classification>()
}

我的问题不是让事情运行那么多,而是让推理结果匹配。我的理解是,它们应该匹配,而且Kotlin应该工作得很好,因为DJL是为Java工作的。我很好奇,对于这个遇到的问题,是否有任何想法。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-04-18 18:58:39

这种差异很可能来自于图像预处理:

代码语言:javascript
复制
pipeline.add(CenterCrop(224, 224))
        .add(Resize(224, 224))
        .add(ToTensor())
        .add(Normalize(floatArrayOf(0.485f, 0.456f, 0.406f), floatArrayOf(0.229f, 0.224f, 0.225f)))

许多PyTorch的简历模型,他们不做一个中心作物。为了获得与python相同的结果,您必须确保处理它们的方式与python代码相同。

DJL图像操作与openCV操作符一致,如果在python中使用枕头,您可能会在结果中看到一些差异。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/67066218

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档