Wav2Vec2ForCTC weights not initialized from the model checkpoint, and passing the 'sampling_rate' argument

Stack Overflow user
Asked 2022-03-06 22:16:19
1 answer, 456 views, 0 followers, 0 votes

My code more or less works (it listens and captures my voice), but it keeps giving me warnings! These are the messages I keep getting:

1)
UserWarning: positional arguments and argument "destination" are deprecated. nn.Module.state_dict will not accept them in the future. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.
warnings.warn(
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

Listening...

2)
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.

import torch
import speech_recognition as sr
import io
from pydub import AudioSegment
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

tokenizer = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
recognizer = sr.Recognizer()

while True:
    audio = recognizer.listen(source)
    data = io.BytesIO(audio.get_wav_data())
    clip = AudioSegment.from_file(data)
    tensor = torch.FloatTensor(clip.get_array_of_samples())

    inputs = tokenizer(tensor, sample_rate=16000, return_tensors="pt", padding="longest").input_values
    logits = model(inputs).logits
    tokens = torch.argmax(logits, dim=-1)
    text = tokenizer.batch_decode(tokens)

    print(str(text).lower())
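
Note that the call above passes sample_rate=16000, while warning 2) asks for sampling_rate. A minimal sketch of why a misspelled keyword may fail to silence such a warning, assuming the callee accepts extra keyword arguments through **kwargs. The function extract_features below is a hypothetical stand-in written for illustration, not the transformers implementation:

```python
import warnings


def extract_features(raw_speech, sampling_rate=None, **kwargs):
    """Hypothetical stand-in for a feature extractor call (not library code)."""
    if sampling_rate is None:
        # Mirrors the spirit of warning 2): no sampling rate was received.
        warnings.warn("It is strongly recommended to pass the `sampling_rate` argument.")
    # Unrecognised keywords end up in kwargs and are otherwise ignored.
    return {"input_values": list(raw_speech), "ignored": sorted(kwargs)}


def count_warnings(**call_kwargs):
    """Call the stand-in and return how many warnings it emitted."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        extract_features([0.0, 0.1], **call_kwargs)
    return len(caught)


print(count_warnings(sample_rate=16000))    # misspelled keyword is absorbed by **kwargs
print(count_warnings(sampling_rate=16000))  # recognised keyword reaches the parameter
```

Under this assumption, the first call still warns and the second does not, which matches the symptom described in the question.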

1 Answer

Stack Overflow user

Answered 2022-06-21 06:25:10

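import soundfile as sf
import torch  # required for return_tensors="pt"

# audio_file (path to the recording) and speech_tokenizer (e.g. the
# Wav2Vec2Processor from the question) are assumed to be defined already.
ds, samplerate = sf.read(audio_file)
# Pass the file's actual rate using the sampling_rate keyword.
input_values = speech_tokenizer(ds, return_tensors="pt", sampling_rate=samplerate, padding="longest")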

Pass sampling_rate to the speech tokenizer to resolve the sampling_rate warning, and make sure your "sampling_rate" is 16000.
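
To make sure the audio really is at 16000 Hz before it reaches the tokenizer, the rate can be read from the WAV header. The following is a minimal sketch using only the standard-library wave module; the helper names (wav_sampling_rate, check_rate, make_silent_wav) are hypothetical and introduced here for illustration. It assumes the audio is available as WAV bytes, as with audio.get_wav_data() in the question:

```python
import io
import wave

EXPECTED_RATE = 16000  # rate the answer says the audio should have


def wav_sampling_rate(wav_bytes: bytes) -> int:
    """Read the sampling rate from the header of in-memory WAV data."""
    with wave.open(io.BytesIO(wav_bytes), "rb") as wav_file:
        return wav_file.getframerate()


def check_rate(wav_bytes: bytes, expected: int = EXPECTED_RATE) -> int:
    """Return the rate if it matches `expected`, otherwise raise ValueError."""
    rate = wav_sampling_rate(wav_bytes)
    if rate != expected:
        raise ValueError(f"expected {expected} Hz audio, got {rate} Hz; resample first")
    return rate


def make_silent_wav(rate: int, n_frames: int = 160) -> bytes:
    """Build a short mono 16-bit silent WAV, used only to exercise the check."""
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)  # 2 bytes per sample = 16-bit PCM
        wav_file.setframerate(rate)
        wav_file.writeframes(b"\x00\x00" * n_frames)
    return buffer.getvalue()


print(check_rate(make_silent_wav(16000)))
```

In the question's loop, a check like this could run on audio.get_wav_data() before the samples are converted to a tensor, so a mismatched rate fails loudly instead of silently degrading the transcription.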

Votes: 0
Original page content provided by Stack Overflow. Translation support provided by the Tencent Cloud Xiaowei IT-domain engine.
Original link:

https://stackoverflow.com/questions/71374635
