首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >用ML5训练的模型预测tfjs时的精度降低

用ML5训练的模型预测tfjs时的精度降低
EN

Stack Overflow用户
提问于 2019-08-29 06:56:22
回答 1查看 360关注 0票数 2

我使用的是tfjs 1.0.0 on Google Chrome \ 76.0.3809.132 (正式构建)(64位)

在我的项目中,我使用ML5来训练图像分类模型。我使用了特征抽取器进行迁移学习。我使用mobilenet_v1_0.25作为基本模型。我想把它集成起来,这样它就可以执行chrome扩展的预测。我不得不使用tfjs,因为我发现ML5不是从扩展的后台页面运行的。我使用tfjs加载由ML5训练的模型,然后开始预测。然而,与ML5本身相同的模型预测相比,tfjs的预测精度很低。

我尝试通过废弃ML5 ML5 Feature 源代码,在tfjs中复制来自tfjs的预测,但是从tfjs进行预测时,预测精度仍然大大降低。

我首先加载移动网络和定制模型,以建立一个联合模型。

代码语言:javascript
复制
load() {
    console.log("ML Data loading..");
    // ! ==========================================
    // ! This is a work around and will only work for default version and alpha values that were used while training model.
    this.mobilenet = await tf.loadLayersModel("https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json");
    const layer = this.mobilenet.getLayer('conv_pw_13_relu');
    // ! ==========================================

    this.mobilenetFeatures = await tf.model({ inputs: this.mobilenet.inputs, outputs: layer.output });
    this.customModel = await tf.loadLayersModel("./model.json");
    this.model.add(this.mobilenetFeatures);
    this.model.add(this.customModel);
}

然后,我将图像传递给一个在预测后得到顶级类的函数。

let result = this.getTopKClasses(this.predict(image), 5)

代码语言:javascript
复制
getTopKClasses(logits, topK) {
    const predictions = logits;
    const values = predictions.dataSync();
    predictions.dispose();
    let predictionList = [];
    for (let i = 0; i < values.length; i++) {
        predictionList.push({ value: values[i], index: i });
    }
    predictionList = predictionList
        .sort((a, b) => {
            return b.value - a.value;
        })
        .slice(0, topK);
    console.log(predictionList);
    let site = predictionList[0];
    let result = { type: 'custom', site: IMAGENET_CLASSES[site.index] }
    console.log('ML Result: Site: %s, Probability: %i%', result.site, (site.value * 100));
    if (site.value > ML_THRESHOLD) {
        return result;
    } else {
        return null;
    }
}

predict(image) {
    const preprocessed = this.imgToTensor(image, [224, 224])
    console.log(preprocessed);
    var result = this.model.predict(preprocessed);
    return result;
}

帮助者职能:

代码语言:javascript
复制
imgToTensor(input, size = null) {
    return tf.tidy(() => {
        let img = tf.browser.fromPixels(input);
        if (size) {
            img = tf.image.resizeBilinear(img, size);
        }
        const croppedImage = this.cropImage(img);
        const batchedImage = croppedImage.expandDims(0);
        return batchedImage.toFloat().div(tf.scalar(127)).sub(tf.scalar(1));
    });
}

cropImage(img) {
    const size = Math.min(img.shape[0], img.shape[1]);
    const centerHeight = img.shape[0] / 2;
    const beginHeight = centerHeight - (size / 2);
    const centerWidth = img.shape[1] / 2;
    const beginWidth = centerWidth - (size / 2);
    return img.slice([beginHeight, beginWidth, 0], [size, size, 3]);
};
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-09-09 09:54:18

为了从使用转移学习训练的模型(即在另一个预先训练的模型上进行训练)进行预测,首先需要从基本模型预测,然后通过将基模型预测张量传递到自定义模型预测输入来预测。

代码语言:javascript
复制
async load(options = {}) {

    // Loading the layer from the base model.
    this.mobilenet = await tf.loadLayersModel(`${BASE_URL}${this.config.version}_${this.config.alpha}_${IMAGE_SIZE}/model.json`);
    const layer = this.mobilenet.getLayer(this.config.layer);

    //Converting the base-model layer to a model.
    this.mobilenetFeatures = await tf.model({ inputs: this.mobilenet.inputs, outputs: layer.output });

    // Loading the custom model that was trained by us.
    this.customModel = await tf.loadLayersModel(CUSTOM_MODEL_FILE_URL);
}

现在要从这些模型中预测:

代码语言:javascript
复制
predict(image) {
    // Converting image to tensor
    const preprocessed = this.imgToTensor(image, [224, 224])

    // * Make predictions about the image firstly, from the Mobilenet (base) Model.
    const embeddings = this.mobilenetFeatures.predict(preprocessed);

    // * Filter predictions from Mobilenet Model using custom trained Model.
    const result = this.customModel.predict(embeddings);

    return result;
}
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/57704653

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档