首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何将tensorflow.js模型权值转换为放火张量,然后返回?

如何将tensorflow.js模型权值转换为放火张量,然后返回?
EN

Stack Overflow用户
提问于 2020-12-18 02:15:40
回答 1查看 924关注 0票数 3

我使用的是ml5.js,一个张力流am的包装器。我想在浏览器中训练一个神经网络,下载权重,在pyTorch中将它们作为张量处理,然后将它们加载到浏览器的tensorflowjs模型中。如何在这些格式( tfjs <-> pytorch )之间进行转换

浏览器模型有一个save()函数,它生成三个文件。特定于ml5.js (json)的元数据文件、描述模型体系结构(json)的拓扑文件和二进制权重文件(bin)。

代码语言:javascript
复制
// Browser
model.save()
代码语言:javascript
复制
// HTTP/Download
model_meta.json   (needed by ml5.js)
model.json        (needed by tfjs)
model.weights.bin (needed by tfjs)
代码语言:javascript
复制
# python backend
import json

with open('model.weights.bin', 'rb') as weights_file:
    with open('model.json', 'rb') as model_file:
        weights = weights_file.read()
        model = json.loads(model_file.read())
        ####
        pytorch_tensor = convert2tensor(weights, model) # whats in this function?
        ####
        # Do some processing in pytorch

        ####
        new_weights_bin = convert2bin(pytorch_tensor, model) # and in this?
        ####

这里是示例javascript代码来生成和加载浏览器中的3个文件。若要加载,请在对话框中同时选择所有3个文件。如果它们是正确的,弹出窗口将显示一个示例预测。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-12-19 05:54:28

我找到了一种将tfjs model.weights.bin转换为numpy的ndarrays的方法。将numpy数组转换为py呼救state_dict是很简单的,它是张量及其名称的字典。

理论

首先,应该理解模型的tfjs表示。model.json描述了模型。在python中,它可以作为字典来阅读。它有以下键:

  1. 模型体系结构被描述为密钥modelTopology下的另一个json/字典。
  2. 它还在键weightsManifest下有一个json/字典,它描述了包装在相应的model.weights.bin文件中的每个权重的类型/形状/位置。另外,权重清单允许多个.bin文件存储权重。

Tensorflow.js有一个附带的python包tensorflowjs,它附带了用于在tf.js二进制和numpy数组格式之间设置朗读权重的实用函数。

每个权重文件都被读取为“组”。组是包含键namedata的字典列表,它们引用权重名称和包含权重的numpy数组。还有可选的其他键。

代码语言:javascript
复制
group = [{'name': weight_name, 'data': np.ndarray}, ...]   # 1 *.bin file

应用程序

安装张力流。不幸的是,它还将安装tensorflow。

代码语言:javascript
复制
pip install tensorflowjs

使用这些功能。请注意,为了方便起见,我更改了签名。

代码语言:javascript
复制
from typing import Dict, ByteString
import torch
from tensorflowjs.read_weights import decode_weights
from tensorflowjs.write_weights import write_weights

def convert2tensor(weights: ByteString, model: Dict) -> Dict[str, torch.Tensor]:
    manifest = model['weightsManifest']
    # If flatten=False, returns a list of groups equal to the number of .bin files.
    # Use flatten=True to convert to a single group
    group = decode_weights(manifest, weights, flatten=True)
    # Convert dicts in tfjs group format into pytorch's state_dict format:
    # {name: str, data: ndarray} -> {name: tensor}
    state_dict = {d['name']: torch.from_numpy(d['data']) for d in group}
    return state_dict

def convert2bin(state_dict: Dict[str: np.ndarray], model: Dict, directory='./'):
    # convert state_dict to groups (list of 1 group)
    groups = [[{'name': key, 'data': value} for key, value in state_dict.items()]]
    # this library function will write to .bin file[s], but you can read it back
    # or change the function internals my copying them from source
    write_weights(groups, directory, write_manifest=False)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/65350949

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档