- A+
所属分类:教程
最近在学习使用 SHAP 算法解释 BERT 模型的输出结果,然而在从 Huggingface 上导入模型和数据集的过程中出现了网络连接相关的错误,本文用于记录错误类型和解决错误的方法。
1 代码示例
SHAP 官方展示的代码如下:
import datasets
import numpy as np
import scipy as sp
import torch
import transformers
from transformers import BertTokenizer, BertForSequenceClassification
import shap
# load a BERT sentiment analysis model
tokenizer = transformers.DistilBertTokenizerFast.from_pretrained(
"distilbert-base-uncased"
)
model = transformers.DistilBertForSequenceClassification.from_pretrained(
"distilbert-base-uncased-finetuned-sst-2-english"
).cuda()
# define a prediction function
def f(x):
tv = torch.tensor(
[
tokenizer.encode(v, padding="max_length", max_length=512, truncation=True) for v in x
]
).cuda()
outputs = model(tv)
outputs = outputs[0].detach().cpu().numpy()
scores = (np.exp(outputs).T / np.exp(outputs).sum(-1)).T
val = sp.special.logit(scores[:, 1]) # use one vs rest logit units
return val
# build an explainer using a token masker
explainer = shap.Explainer(f, tokenizer)
# explain the model's predictions on IMDB reviews
imdb_train = datasets.load_dataset("imdb")["train"]
shap_values = explainer(imdb_train[:10], fixed_context=1, batch_size=2)
shap.plots.bar(shap_values.abs.sum(0))
2 报错详情
在安装好所有相关依赖库后,运行上述代码出现了如下错误:
大致意思是找不到 dataset_info.json 文件,也就是说这个文件没有被成功下载,于是尝试打开科学上网。在开启科学上网后,继续运行代码出现下列关于 SSLError 的错误:
说明网络连接仍然存在问题,无法访问到 Huggingface。在网上搜罗了各种方法后,终于找到了相应的解决方案,亲测有效。
3 解决方案
首先找到目前使用的深度学习环境中的 request.py 文件,例如在我的环境中该文件的路径为:
D:\Anaconda\envs\test\Lib\urllib\request.py
然后通过搜索 proxyServer 关键字定位到下图代码处:
将 else 块中的代码修改为下列代码:
proxies['http'] = 'http://%s' % proxyServer
proxies['https'] = 'http://%s' % proxyServer
proxies['ftp'] = 'http://%s' % proxyServer
修改完的代码如下所示:
重新运行代码(注意继续保持科学上网):
成功下载模型!
4 参考
[1] Welcome to the SHAP documentation — SHAP latest documentation