首页
学习
活动
专区
圈层
工具
发布
社区首页 > 问答首页 > 如何在 PyTorch 训练后缩小模型尺寸

如何在 PyTorch 训练后缩小模型尺寸
EN

Stack Overflow用户
提问于 2022-05-12 07:05:41
回答 1 · 查看 1.2K · 关注 0 · 票数 0

我已经创建了一个 PyTorch 模型,并希望缩小模型的大小。模型架构定义如下:

代码语言:python
复制
import torch
import torch.quantization
import torch.nn as nn
import copy
import os
import time
import numpy as np
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sys
import time
import codecs
import pickle
import torch
from torch.autograd import Variable
import torch.nn.utils.prune as prune

from config import Config
from loader import *
from utils import *
from model import BiLSTM_CRF

# Names of the two virtual tags the CRF places before the first and after the
# last word of a sentence (see the transition handling in BiLSTM_CRF).
START_TAG, STOP_TAG = '<START>', '<STOP>'


def init_embedding(input_embedding):
    """Initialize an embedding weight matrix in place.

    Values are drawn from U(-b, b) with ``b = sqrt(3 / input_embedding.size(1))``,
    i.e. the bound shrinks as the embedding dimension grows.

    Args:
        input_embedding: 2-D weight tensor/parameter, e.g. ``nn.Embedding.weight``.
    """
    bound = np.sqrt(3.0 / input_embedding.size(1))
    # ``nn.init.uniform`` (no trailing underscore) is a deprecated alias that
    # warns on every call; ``uniform_`` is the supported in-place initializer.
    nn.init.uniform_(input_embedding, -bound, bound)


def init_linear(input_linear):
    """Initialize a linear layer's parameters in place.

    The weight is drawn from U(-b, b) with
    ``b = sqrt(6 / (weight.size(0) + weight.size(1)))`` (for ``nn.Linear`` that is
    ``out_features + in_features``). The bias, when the layer has one, is set
    to zero.

    Args:
        input_linear: a module exposing ``weight`` (2-D) and ``bias`` (or ``None``).
    """
    weight = input_linear.weight
    bound = np.sqrt(6.0 / (weight.size(0) + weight.size(1)))
    # ``uniform_`` replaces the deprecated, warning-emitting ``nn.init.uniform``.
    nn.init.uniform_(weight, -bound, bound)
    if input_linear.bias is not None:
        nn.init.zeros_(input_linear.bias)


def init_lstm(input_lstm):
    """Initialize the weights and biases of an ``nn.LSTM`` in place.

    For every layer ``k`` (and, when bidirectional, the ``_reverse`` direction too):

    * ``weight_ih_l{k}`` and ``weight_hh_l{k}`` are drawn from U(-b, b) with
      ``b = sqrt(6 / (weight.size(0) / 4 + weight.size(1)))``. The ``/ 4``
      undoes the stacking of the four gate matrices along dim 0.
    * If the LSTM has biases, ``bias_ih_l{k}`` and ``bias_hh_l{k}`` are zeroed
      and the slice ``[hidden_size:2 * hidden_size]`` is set to 1. In the
      gate layout documented for ``nn.LSTM`` (input, forget, cell, output)
      that slice is the forget gate. Because both biases get 1, the effective
      forget-gate bias is 2.

    Args:
        input_lstm: the ``nn.LSTM`` module to initialize.
    """
    directions = [''] + (['_reverse'] if input_lstm.bidirectional else [])
    hidden = input_lstm.hidden_size

    # Visit all forward layers first, then all reverse layers, and "ih" before
    # "hh" within a layer, so the random stream is consumed in a fixed order.
    for suffix in directions:
        for layer in range(input_lstm.num_layers):
            for kind in ('ih', 'hh'):
                # getattr instead of eval() on a built string
                weight = getattr(input_lstm, 'weight_%s_l%d%s' % (kind, layer, suffix))
                bound = np.sqrt(6.0 / (weight.size(0) / 4 + weight.size(1)))
                nn.init.uniform_(weight, -bound, bound)

    # bias_* parameters only exist when the LSTM was built with bias=True, so
    # both directions' bias initialization must stay under this check.
    if not input_lstm.bias:
        return
    with torch.no_grad():
        for suffix in directions:
            for layer in range(input_lstm.num_layers):
                for kind in ('ih', 'hh'):
                    bias = getattr(input_lstm, 'bias_%s_l%d%s' % (kind, layer, suffix))
                    bias.zero_()
                    bias[hidden:2 * hidden] = 1


 def to_scalar(var):
     return var.view(-1).data.tolist()[0]


 def argmax(vec):
    _, idx = torch.max(vec, 1)
    return to_scalar(idx)


 def log_sum_exp(vec):
   # vec 2D: 1 * tagset_size
    max_score = vec[0, argmax(vec)]
    max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
    return max_score + \
     torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))


 class BiLSTM_CRF(nn.Module):

   def __init__(self, vocab_size, tag_to_ix, embedding_dim, hidden_dim, char_lstm_dim=25,
             char_to_ix=None, pre_word_embeds=None, char_embedding_dim=25, use_gpu=False,
             n_cap=None, cap_embedding_dim=None, use_crf=True, char_mode='CNN'):
    super(BiLSTM_CRF, self).__init__()
      self.use_gpu = use_gpu
      self.embedding_dim = embedding_dim  #100
      self.hidden_dim = hidden_dim        #200
      self.vocab_size = vocab_size
      self.tag_to_ix = tag_to_ix
      self.n_cap = n_cap
      self.cap_embedding_dim = cap_embedding_dim
      self.use_crf = use_crf
      self.tagset_size = len(tag_to_ix)
      self.out_channels = char_lstm_dim       #25
      self.char_mode = char_mode

      print('char_mode: %s, out_channels: %d, hidden_dim: %d, ' % (char_mode, char_lstm_dim, hidden_dim))

      if self.n_cap and self.cap_embedding_dim:
          self.cap_embeds = nn.Embedding(self.n_cap, self.cap_embedding_dim)
        # print("self.cap_embeds.weight------",self.cap_embeds.weight)
          init_embedding(self.cap_embeds.weight)

      if char_embedding_dim is not None:
          self.char_lstm_dim = char_lstm_dim
          self.char_embeds = nn.Embedding(len(char_to_ix), char_embedding_dim)
        # print("self.char_embeds.weight-------", self.char_embeds.weight)
          init_embedding(self.char_embeds.weight)
          if self.char_mode == 'LSTM':
              self.char_lstm = nn.LSTM(char_embedding_dim, char_lstm_dim, num_layers=1, bidirectional=True)
            init_lstm(self.char_lstm)
        if self.char_mode == 'CNN':
            self.char_cnn3 = nn.Conv2d(in_channels=1, out_channels=self.out_channels, kernel_size=(3, char_embedding_dim), padding=(2,0))

    self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
    if pre_word_embeds is not None:
        self.pre_word_embeds = True
        self.word_embeds.weight = nn.Parameter(torch.FloatTensor(pre_word_embeds))
    else:
        self.pre_word_embeds = False

    self.dropout = nn.Dropout(0.5)
    if self.n_cap and self.cap_embedding_dim:
        if self.char_mode == 'LSTM':
            self.lstm = nn.LSTM(embedding_dim+char_lstm_dim*2+cap_embedding_dim, hidden_dim, bidirectional=True)
        if self.char_mode == 'CNN':
            self.lstm = nn.LSTM(embedding_dim+self.out_channels+cap_embedding_dim, hidden_dim, bidirectional=True)
    else:
        if self.char_mode == 'LSTM':
            self.lstm = nn.LSTM(embedding_dim+char_lstm_dim*2, hidden_dim, bidirectional=True)
        if self.char_mode == 'CNN':
            self.lstm = nn.LSTM(embedding_dim+self.out_channels, hidden_dim, bidirectional=True)
    init_lstm(self.lstm)
    self.hw_trans = nn.Linear(self.out_channels, self.out_channels)
    self.hw_gate = nn.Linear(self.out_channels, self.out_channels)
    self.h2_h1 = nn.Linear(hidden_dim*2, hidden_dim)
    self.tanh = nn.Tanh()
    self.hidden2tag = nn.Linear(hidden_dim*2, self.tagset_size)
    init_linear(self.h2_h1)
    init_linear(self.hidden2tag)
    init_linear(self.hw_gate)
    init_linear(self.hw_trans)

    if self.use_crf:
        self.transitions = nn.Parameter(
            torch.zeros(self.tagset_size, self.tagset_size))
        self.transitions.data[tag_to_ix[START_TAG], :] = -10000
        self.transitions.data[:, tag_to_ix[STOP_TAG]] = -10000

def _score_sentence(self, feats, tags):
    # tags is ground_truth, a list of ints, length is len(sentence)
    # feats is a 2D tensor, len(sentence) * tagset_size
    r = torch.LongTensor(range(feats.size()[0]))
    if self.use_gpu:
        r = r.cuda()
        pad_start_tags = torch.cat([torch.cuda.LongTensor([self.tag_to_ix[START_TAG]]), tags])
        pad_stop_tags = torch.cat([tags, torch.cuda.LongTensor([self.tag_to_ix[STOP_TAG]])])
    else:
        pad_start_tags = torch.cat([torch.LongTensor([self.tag_to_ix[START_TAG]]), tags])
        pad_stop_tags = torch.cat([tags, torch.LongTensor([self.tag_to_ix[STOP_TAG]])])

    score = torch.sum(self.transitions[pad_stop_tags, pad_start_tags]) + torch.sum(feats[r, tags])

    return score

def _get_lstm_features(self, sentence, chars2, caps, chars2_length, d):

    if self.char_mode == 'LSTM':
        # self.char_lstm_hidden = self.init_lstm_hidden(dim=self.char_lstm_dim, bidirection=True, batchsize=chars2.size(0))
        chars_embeds = self.char_embeds(chars2).transpose(0, 1)
        packed = torch.nn.utils.rnn.pack_padded_sequence(chars_embeds, chars2_length)
        lstm_out, _ = self.char_lstm(packed)
        outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(lstm_out)
        outputs = outputs.transpose(0, 1)
        chars_embeds_temp = Variable(torch.FloatTensor(torch.zeros((outputs.size(0), outputs.size(2)))))
        if self.use_gpu:
            chars_embeds_temp = chars_embeds_temp.cuda()
        for i, index in enumerate(output_lengths):
            chars_embeds_temp[i] = torch.cat((outputs[i, index-1, :self.char_lstm_dim], outputs[i, 0, self.char_lstm_dim:]))
        chars_embeds = chars_embeds_temp.clone()
        for i in range(chars_embeds.size(0)):
            chars_embeds[d[i]] = chars_embeds_temp[i]

    if self.char_mode == 'CNN':
        chars_embeds = self.char_embeds(chars2).unsqueeze(1)
        chars_cnn_out3 = self.char_cnn3(chars_embeds)
        chars_embeds = nn.functional.max_pool2d(chars_cnn_out3, kernel_size=(chars_cnn_out3.size(2), 1)).view(chars_cnn_out3.size(0), self.out_channels)

    # t = self.hw_gate(chars_embeds)
    # g = nn.functional.sigmoid(t)
    # h = nn.functional.relu(self.hw_trans(chars_embeds))
    # chars_embeds = g * h + (1 - g) * chars_embeds

    embeds = self.word_embeds(sentence)
    if self.n_cap and self.cap_embedding_dim:
        cap_embedding = self.cap_embeds(caps)

    if self.n_cap and self.cap_embedding_dim:
        embeds = torch.cat((embeds, chars_embeds, cap_embedding), 1)
    else:
        embeds = torch.cat((embeds, chars_embeds), 1)

    embeds = embeds.unsqueeze(1)
    embeds = self.dropout(embeds)
    lstm_out, _ = self.lstm(embeds)
    lstm_out = lstm_out.view(len(sentence), self.hidden_dim*2)
    lstm_out = self.dropout(lstm_out)
    lstm_feats = self.hidden2tag(lstm_out)
    return lstm_feats

def _forward_alg(self, feats):
    # calculate in log domain
    # feats is len(sentence) * tagset_size
    # initialize alpha with a Tensor with values all equal to -10000.
    init_alphas = torch.Tensor(1, self.tagset_size).fill_(-10000.)
    init_alphas[0][self.tag_to_ix[START_TAG]] = 0.
    forward_var = autograd.Variable(init_alphas)
    if self.use_gpu:
        forward_var = forward_var.cuda()
    for feat in feats:
        emit_score = feat.view(-1, 1)
        tag_var = forward_var + self.transitions + emit_score
        max_tag_var, _ = torch.max(tag_var, dim=1)
        tag_var = tag_var - max_tag_var.view(-1, 1)
        forward_var = max_tag_var + torch.log(torch.sum(torch.exp(tag_var), dim=1)).view(1, -1) # ).view(1, -1)
    terminal_var = (forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]).view(1, -1)
    alpha = log_sum_exp(terminal_var)
    # Z(x)
    return alpha

def viterbi_decode(self, feats):
    backpointers = []
    # analogous to forward
    init_vvars = torch.Tensor(1, self.tagset_size).fill_(-10000.)
    init_vvars[0][self.tag_to_ix[START_TAG]] = 0
    forward_var = Variable(init_vvars)
    if self.use_gpu:
        forward_var = forward_var.cuda()
    for feat in feats:
        next_tag_var = forward_var.view(1, -1).expand(self.tagset_size, self.tagset_size) + self.transitions
        _, bptrs_t = torch.max(next_tag_var, dim=1)
        bptrs_t = bptrs_t.squeeze().data.cpu().numpy()
        next_tag_var = next_tag_var.data.cpu().numpy()
        viterbivars_t = next_tag_var[range(len(bptrs_t)), bptrs_t]
        viterbivars_t = Variable(torch.FloatTensor(viterbivars_t))
        if self.use_gpu:
            viterbivars_t = viterbivars_t.cuda()
        forward_var = viterbivars_t + feat
        backpointers.append(bptrs_t)

    terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]
    terminal_var.data[self.tag_to_ix[STOP_TAG]] = -10000.
    terminal_var.data[self.tag_to_ix[START_TAG]] = -10000.
    best_tag_id = argmax(terminal_var.unsqueeze(0))
    path_score = terminal_var[best_tag_id]
    best_path = [best_tag_id]
    for bptrs_t in reversed(backpointers):
        best_tag_id = bptrs_t[best_tag_id]
        best_path.append(best_tag_id)
    start = best_path.pop()
    assert start == self.tag_to_ix[START_TAG]
    best_path.reverse()
    return path_score, best_path

def neg_log_likelihood(self, sentence, tags, chars2, caps, chars2_length, d):
    # sentence, tags is a list of ints
    # features is a 2D tensor, len(sentence) * self.tagset_size
    feats = self._get_lstm_features(sentence, chars2, caps, chars2_length, d)

    if self.use_crf:
        forward_score = self._forward_alg(feats)
        gold_score = self._score_sentence(feats, tags)
        return forward_score - gold_score
    else:
        tags = Variable(tags)
        scores = nn.functional.cross_entropy(feats, tags)
        return scores


def forward(self, sentence, chars, caps, chars2_length, d):
    feats = self._get_lstm_features(sentence, chars, caps, chars2_length, d)
    # viterbi to get tag_seq
    if self.use_crf:
        score, tag_seq = self.viterbi_decode(feats)
    else:
        score, tag_seq = torch.max(feats, 1)
        tag_seq = list(tag_seq.cpu().data)

    return score, tag_seq

创建模型实例:

代码语言:python
复制
 model_fp32 = BiLSTM_CRF(vocab_size=len(word_to_id),
               tag_to_ix=tag_to_id,
               embedding_dim=parameters['word_dim'],
               hidden_dim=parameters['word_lstm_dim'],
               use_gpu=parameters['use_gpu'],
               char_to_ix=char_to_id,
               pre_word_embeds=word_embeds,
               use_crf=parameters['crf'],
               char_mode=parameters['char_mode'])

应用动态量化:

代码语言:python
复制
 model_int8 = torch.quantization.quantize_dynamic(
     model_fp32,  # the original model
    {nn.LSTM,nn.Linear},  # a set of layers to dynamically quantize
    dtype=torch.qint8)                    

用于检查量化效果(模型大小)的辅助函数:

代码语言:python
复制
def print_size_of_model(model, label=""):
    """Print and return the serialized size of ``model.state_dict()``.

    The state dict is written with ``torch.save`` to ``temp.p`` inside a fresh
    temporary directory. An existing ``temp.p`` in the current directory is
    therefore never overwritten or deleted, and the scratch file is removed
    even if saving fails.

    Args:
        model: any ``nn.Module``.
        label: name printed next to the size.
    Returns:
        Size in bytes (the printed value is ``size / 1e3``, i.e. kilobytes).
    """
    import tempfile  # only this helper needs it

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Same file name as before; only the directory changes.
        path = os.path.join(tmp_dir, "temp.p")
        torch.save(model.state_dict(), path)
        size = os.path.getsize(path)
    print("model: ",label,' \t','Size (KB):', size/1e3)
    return size

比较模型大小:

代码语言:python
复制
# Serialized sizes in bytes, then how many times smaller the quantized model is.
f = print_size_of_model(model_fp32, "model_fp32")
q = print_size_of_model(model_int8, "model_int8")
print("%.2f times smaller" % (f / q))

结果:

模型: model_fp32 大小(KB): 806494.996

模型: model_int8 大小(KB): 804532.412

缩小倍数: 1.00(几乎没有变化)

有什么办法可以大幅缩小模型的大小吗?

EN

回答 1

Stack Overflow用户

发布于 2022-05-15 21:44:31

根据问题中的结果部分以及约 200 万的 vocab_size 来看,对属性 word_embeds 进行量化似乎是合理的:预计仅量化这一个模块,就能使权重的内存占用显著下降。根据文档,动态量化(即上面代码片段中用于 nn.Linear 和 nn.LSTM 的方法)不支持 nn.Embedding(word_embeds 的类型),但静态量化可以处理这个问题。默认的 qconfig(在一些 PyTorch 示例中使用)似乎不适用于 nn.Embedding,不过在一个 issue 讨论中给出了如何量化 nn.Embedding 的提示。训练完成后执行:

代码语言:python
复制
from torch.quantization.qconfig import float_qparams_weight_only_qconfig

# Only the word-embedding table gets a qconfig: nn.Embedding is quantized with
# the weight-only float_qparams qconfig rather than the default one.
model_fp32.word_embeds.qconfig = float_qparams_weight_only_qconfig

# Eager-mode post-training flow: prepare (insert observers), then convert
# (swap in quantized modules). Both steps mutate model_fp32 in place.
for quant_step in (torch.quantization.prepare, torch.quantization.convert):
    quant_step(model_fp32, inplace=True)

在此之后,model_fp32 中的 word_embeds 将被量化为 torch.quint8 类型。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72211362

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档