首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >当insert语句包含'from transformers‘时无法访问tf.Dataset,例如"from transformers glue_processors“

当insert语句包含'from transformers‘时无法访问tf.Dataset,例如"from transformers glue_processors“
EN

Stack Overflow用户
提问于 2020-11-02 13:26:37
回答 1查看 22关注 0票数 0

当在Win10 (python run_tf_glue_Test.py)的命令提示符下运行以下代码(run_tf_glue_Test.py)时,代码在运行几分钟后停止,没有错误。'for line in datasets:‘语句之前的信息显示在控制台中,因此我假定生成了tf.Dataset。

修改并删除了原始run_tf_glue.py中的大部分主要函数(位于trnsformers/eamaples/ text -clasification/由git clone https://github.com/huggingface/transformers'下载,然后由pip安装),以简化和澄清出现此问题的原因。对原始代码进行了修改,在转换为字典格式的glue_convert_examples_to_features后将文本文件馈送到tf.Dataset函数。

这个问题发生在生成tf.Datase的模块中,但是相同的代码在协作式中工作得很好。

我是否在本地环境中遗漏了什么?

python环境:3.7.6

代码语言:javascript
复制
# coding=utf-8
""" Fine-tuning the library models for sequence classification."""


import logging,sys
import os
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Optional

import numpy as np,csv
import tensorflow as tf
import tensorflow_datasets as tfds
from transformers import glue_processors


def gen_dataset(subf):
    print('** in**')
    #tf.enable_eager_execution()
    logging.info('version(%s)',tf.__version__)
    categories = ['0', '1'] 
    category_to_id = {
        category: index for index, category in enumerate(categories)
    }
    origin='c:/tools/python37/transformers/'+subf+'.tsv'
    origin=origin.replace('validation','dev')
    logging.info('*** origin(%s)',origin)
    with open(origin,'r',encoding='utf-8') as fin:
        reader = csv.reader(fin,delimiter='\t')
        
        rows = [{
            'idx': index,
            'sentence': row[0],
            'label': category_to_id[row[1]],
        } for index, row in enumerate(reader) if row[1] in categories]
    datasets = tf.data.Dataset.from_generator(
        lambda: rows,{'idx': tf.int64, 'sentence': tf.string, 'label': tf.int64}
    )
    logging.info('** records(%d) type(%s)**',len(rows),type(datasets))
    for line in datasets:
        logging.info('*** line(%s)',line)
        print('** lines:',line)
        sys.exit(7)
    return(datasets,len(rows))

logger = logging.getLogger(__name__)

def main():
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    gen_dataset('data/train')
if __name__ == "__main__":
    main()
EN

回答 1

Stack Overflow用户

发布于 2020-11-03 09:25:55

原因是GPU使用导致的,因此在'import os‘语句后插入'os.environ"CUDA_VISIBLE_DEVICES"="-1"’,以指定使用cpu,然后问题才会消失。

很抱歉占用了你的时间。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/64640207

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档