首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >"ValueError: max_evals=500太低了,不能解释排列“,shap回答我,我需要提供更多的数据(照片)吗?

"ValueError: max_evals=500太低了,不能解释排列“,shap回答我,我需要提供更多的数据(照片)吗?
EN

Stack Overflow用户
提问于 2022-09-08 11:40:24
回答 1查看 422关注 0票数 1

我想测试一种多类语义分割模型deeplab_v3plus的可解释性,以了解哪些特征对语义分类贡献最大。但是,在运行我的文件时,我有一个ValueError: max_evals=500 is too low,我很难理解其中的原因。

代码语言:javascript
复制
import glob
from PIL import Image

import torch
from torchvision import transforms
from torchvision.utils import make_grid
import torchvision.transforms.functional as tf

from deeplab import deeplab_v3plus

import shap

def test(args):
    # make a video prez
    
    model = deeplab_v3plus('resnet101', num_classes=args.nclass, output_stride=16, pretrained_backbone=True)
    model.load_state_dict(torch.load(args.seg_file,map_location=torch.device('cpu'))) # because no gpu available on sandbox environnement
    model = model.to(args.device)
    model.eval()
    explainer = shap.Explainer(model)
    with torch.no_grad():
        for i, file in enumerate(args.img_folder):
            img = img2tensor(file, args)
    
            pred = model(img)
            print(explainer(img))

if __name__ == '__main__':
    class Arguments:
        def __init__(self):
            self.device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
            self.seg_file = "Model_Woodscape.pth"
            self.img_folder = glob.glob("test_img/*.png")
            self.mean = [0.485, 0.456, 0.406]
            self.std = [0.229, 0.224, 0.225]
            self.h, self.w = 483, 640
            self.nclass = 10
            self.cmap = {
                1: [128, 64, 128],  # "road",
                2: [69, 76, 11],    # "lanemarks",
                3: [0, 255, 0],     # "curb",
                4: [220, 20, 60],   # "person",
                5: [255, 0, 0],     # "rider",
                6: [0, 0, 142],     # "vehicles",
                7: [119, 11, 32],   # "bicycle",
                8: [0, 0, 230],     # "motorcycle",
                9: [220, 220, 0],   # "traffic_sign",
                0: [0, 0, 0]        # "void"
            }

    args = Arguments()
    test(args)

但它会返回:

代码语言:javascript
复制
(dee_env) jovyan@jupyter:~/use-cases/Scene_understanding/Code_Woodscape/deeplab_v3+$ python test_shap.py 
BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
Traceback (most recent call last):
  File "/home/jovyan/use-cases/Scene_understanding/Code_Woodscape/deeplab_v3+/test_shap.py", line 85, in <module>
    test(args)
  File "/home/jovyan/use-cases/Scene_understanding/Code_Woodscape/deeplab_v3+/test_shap.py", line 37, in test
    print(explainer(img))
  File "/home/jovyan/use-cases/Scene_understanding/Code_Woodscape/deeplab_v3+/dee_env/lib/python3.9/site-packages/shap/explainers/_permutation.py", line 82, in __call__
    return super().__call__(
  File "/home/jovyan/use-cases/Scene_understanding/Code_Woodscape/deeplab_v3+/dee_env/lib/python3.9/site-packages/shap/explainers/_explainer.py", line 266, in __call__
    row_result = self.explain_row(
  File "/home/jovyan/use-cases/Scene_understanding/Code_Woodscape/deeplab_v3+/dee_env/lib/python3.9/site-packages/shap/explainers/_permutation.py", line 164, in explain_row
    raise ValueError(f"max_evals={max_evals} is too low for the Permutation explainer, it must be at least 2 * num_features + 1 = {2 * len(inds) + 1}!")
ValueError: max_evals=500 is too low for the Permutation explainer, it must be at least 2 * num_features + 1 = 1854721!

在源代码中,这似乎是因为我没有给出足够的参数。我的test_img/*文件夹中只有三张图片,这是为什么吗?

EN

回答 1

Stack Overflow用户

发布于 2022-09-26 15:47:54

我也有同样的问题。我找到的一个可能的解决方案似乎适用于我的情况,那就是替换这一行。

代码语言:javascript
复制
explainer = shap.Explainer(model)

用这条线

代码语言:javascript
复制
explainer = shap.explainers.Permutation(model, max_evals = 1854721)

默认情况下,shap.Explainer有算法=‘auto’。来自文档:shape.Explainer

默认情况下,“auto”选项试图在传递的模型和掩蔽符中做出最佳选择,但通过传递特定算法的名称,这种选择总是会被覆盖。

由于“置换”已被选中,您可以直接使用shap.explainers.Permutation并将max_evals设置为上面错误消息中建议的值。考虑到用例的数量很大,这可能需要很长时间。我建议使用更简单的模型来测试上面的解决方案。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73648498

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档