首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在DistilBERT中进行交叉验证

如何在DistilBERT中进行交叉验证
EN

Stack Overflow用户
提问于 2021-08-16 12:13:28
回答 1查看 122关注 0票数 0

我已经创建了一个简单的模型来使用DistilBERT进行文本分类。问题是我不知道如何在训练时进行交叉验证。下面提供了我的代码实现。

有没有人可以帮我在培训的同时实现交叉验证?

提前谢谢你。

代码语言:javascript
复制
    #Split into Train-Test-Validation    
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.10, random_state = 0)
    X_val, X_test, y_val, y_test = train_test_split(X_test,y_test, test_size=0.10, random_state=42)
    
    
    #Encoding text for train data
    train_encoded = tokenizer(X_train, truncation=True, padding=True, return_tensors="tf")
    train_data = tf.data.Dataset.from_tensor_slices((dict(train_encoded), y_train))
    
    #Encoding text for validation data
    val_encoded = tokenizer(X_val, truncation=True, padding=True, return_tensors="tf")
    val_data = tf.data.Dataset.from_tensor_slices((dict(val_encoded), y_val))
    
    #Encoding text for testing data
    test_data = tf.data.Dataset.from_tensor_slices((dict(test_encoded), y_test))
    test_encoded = tokenizer(X_test, truncation=True, padding=True, return_tensors="tf")
    
    #Load distil bert model
    model = TFDistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)
    model.compile(optimizer=optimizer, loss=model.compute_loss, metrics=['accuracy'])
    model.fit(train_data.batch(16), epochs=10, batch_size=16)
EN

回答 1

Stack Overflow用户

发布于 2021-08-16 14:48:35

我建议使用K-折叠式验证作为交叉评估策略!

代码语言:javascript
复制
kf = KFold(n_splits=10, random_state=99, shuffle=True)
for train_index, test_index in kf.split(X, y):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
    X_val, X_test, y_val, y_test = train_test_split(X_test,y_test, test_size=0.10, random_state=42)
    #Encoding text for train data
    train_encoded = tokenizer(X_train, truncation=True, padding=True, return_tensors="tf")
    train_data = tf.data.Dataset.from_tensor_slices((dict(train_encoded), y_train))
    
    #Encoding text for validation data
    val_encoded = tokenizer(X_val, truncation=True, padding=True, return_tensors="tf")
    val_data = tf.data.Dataset.from_tensor_slices((dict(val_encoded), y_val))
    
    #Encoding text for testing data
    test_data = tf.data.Dataset.from_tensor_slices((dict(test_encoded), y_test))
    test_encoded = tokenizer(X_test, truncation=True, padding=True, return_tensors="tf")
    
    #Load distil bert model
    model = TFDistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)
    model.compile(optimizer=optimizer, loss=model.compute_loss, metrics=['accuracy'])
    model.fit(train_data.batch(16), epochs=10, batch_size=16)
    #Get your results and perform analysis

作为另一种方式,你可以用sklearn-api支持包装你的模型,然后享受交叉验证和sklearn提供的几十个其他实用程序!

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68802609

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档