首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Halide JIT与Generator的差异

Halide JIT与Generator的差异
EN

Stack Overflow用户
提问于 2020-10-11 20:17:32
回答 1查看 144关注 0票数 0

在使用Halide时,我发现对于同一条管线,JIT方式和生成器(Generator)方式生成的伪代码完全不同。看起来我遗漏了什么,如果能给我一些提示,我将不胜感激。下面是我所做的:

一个简单的“膨胀”(dilate)管线定义如下:

代码语言:cpp
复制
/**
 * Builds a 3x3 "dilate" (running maximum) pipeline over a randomly filled
 * 1280x1024 8-bit image, realizes it through Halide's JIT, then prints the
 * loop nest and writes the lowered statement to "dilate_1_.html".
 *
 * @return 0 once the pipeline has been realized and the diagnostics emitted.
 */
int jit_main ()
{
    Target target = get_jit_target_from_environment ();
    const int width = 1280, height = 1024;
    Buffer <uint8_t> input (width, height);

    // Fill the input with pseudo-random byte values.
    for (int y = 0; y < height; y++)
        for (int x = 0; x < width; x++)
            input (x, y) = rand () & 0xff;

    Var x ("x_1"), y ("y_1");

    // Wrap the input so that reads at x-1/x+1/y-1/y+1 go through the
    // repeat_edge boundary condition.
    Func clamped ("clamped_1");
    clamped = BoundaryConditions::repeat_edge (input);

    // Horizontal pass: maximum of the three horizontally adjacent samples.
    Func max_x ("max_x_1");
    max_x (x, y) = max (clamped (x - 1, y), clamped (x, y), clamped (x + 1, y));

    // Vertical pass over the horizontal maxima, giving the 3x3 maximum.
    Func dilate ("dilate_1");
    dilate (x, y) = max (max_x (x, y - 1), max_x (x, y), max_x (x, y + 1));

    // NOTE(review): tick() is defined elsewhere; presumably it starts a timer
    // when given NULL and reports the elapsed time under the given label.
    tick (NULL);
    Buffer<uint8_t> out = dilate.realize (width, height, target);
    tick ("inline");

    dilate.print_loop_nest ();
    dilate.compile_to_lowered_stmt ("dilate_1_.html", {}, HTML);

    // The function is declared to return int; flowing off the end of a
    // value-returning function other than main() is undefined behaviour.
    return 0;
}

生成的伪代码如下所示(片段):

代码语言:javascript
复制
    produce dilate_1 {
        let t125 = ((dilate_1.min.1 * dilate_1.stride.1) + dilate_1.min.0)
        for (dilate_1.s0.y_1, dilate_1.min.1, dilate_1.extent.1) {
            let t128 = max(min(dilate_1.s0.y_1, 1024), 1)
            let t126 = max(min(dilate_1.s0.y_1, 1023), 0)
            let t127 = max(min(dilate_1.s0.y_1, 1022), -1)
            let t129 = ((dilate_1.s0.y_1 * dilate_1.stride.1) - t125)
            for (dilate_1.s0.x_1, dilate_1.min.0, dilate_1.extent.0) {
                dilate_1[(dilate_1.s0.x_1 + t129)] = max(b0[((max(min(dilate_1.s0.x_1, 1278), -1) + (t126 * 1280)) + 1)], max(b0[(max(min(dilate_1.s0.x_1, 1279), 0) + (t126 * 1280))], max(b0[((max(min(dilate_1.s0.x_1, 1280), 1) + (t126 * 1280)) + -1)], max(b0[((max(min(dilate_1.s0.x_1, 1280), 1) + (t127 * 1280)) + 1279)], max(b0[((max(min(dilate_1.s0.x_1, 1279), 0) + (t127 * 1280)) + 1280)], max(b0[((max(min(dilate_1.s0.x_1, 1278), -1) + (t127 * 1280)) + 1281)], max(b0[((max(min(dilate_1.s0.x_1, 1280), 1) + (t128 * 1280)) + -1281)], max(b0[((max(min(dilate_1.s0.x_1, 1279), 0) + (t128 * 1280)) + -1280)], b0[((max(min(dilate_1.s0.x_1, 1278), -1) + (t128 * 1280)) + -1279)]))))))))
            }
        }
    }

然后我定义了一个生成器:

代码语言:cpp
复制
/**
 * Halide generator describing a 3x3 maximum ("dilate") filter over a
 * two-dimensional 8-bit input. It is registered under the name "dilate_0".
 */
class Dilate0Generator : public Halide::Generator <Dilate0Generator>
{
public:
    // Two-dimensional 8-bit input image.
    Input<Buffer<uint8_t>>  input_0 {"input_0", 2};
    // Two-dimensional 8-bit result image.
    Output<Buffer<uint8_t>> dilate_0 {"dilate_0", 2};
    // Pure variables used by every stage of the pipeline.
    Var                     x {"x_0"};
    Var                     y {"y_0"};

    /** Defines the pipeline stages and prints the resulting loop nest. */
    void generate ()
    {
        // Route input reads through the repeat_edge boundary condition so
        // the +/-1 neighbour accesses below are covered by it.
        Func edge_clamped {"clamped_0"};
        edge_clamped = BoundaryConditions::repeat_edge (input_0);

        // Horizontal stage: maximum over the left, centre and right samples.
        Func row_max {"max_x_0"};
        row_max (x, y) = max (edge_clamped (x - 1, y),
                              edge_clamped (x, y),
                              edge_clamped (x + 1, y));

        // Vertical stage: maximum over the rows above, at and below y.
        dilate_0 (x, y) = max (row_max (x, y - 1),
                               row_max (x, y),
                               row_max (x, y + 1));

        dilate_0.print_loop_nest ();
    }
};
HALIDE_REGISTER_GENERATOR (Dilate0Generator, dilate_0)

它的伪代码是完全不同的(片段):

代码语言:javascript
复制
    produce dilate_0 {
        let dilate_0.s0.y_0.prologue = min(max((input_0.min.1 + 1), dilate_0.min.1), (dilate_0.extent.1 + dilate_0.min.1))
        let dilate_0.s0.y_0.epilogue$3 = min(max(max((input_0.min.1 + 1), dilate_0.min.1), ((input_0.extent.1 + input_0.min.1) + -1)), (dilate_0.extent.1 + dilate_0.min.1))
        let t166 = (dilate_0.s0.y_0.prologue - dilate_0.min.1)
        let t168 = ((input_0.min.1 * input_0.stride.1) + input_0.min.0)
        let t170 = ((dilate_0.min.1 * dilate_0.stride.1) + dilate_0.min.0)
        let t167 = (input_0.extent.1 + input_0.min.1)
        let t169 = (input_0.extent.0 + input_0.min.0)
        for (dilate_0.s0.y_0, dilate_0.min.1, t166) {
            let t171 = ((max(min((t167 + -1), dilate_0.s0.y_0), input_0.min.1) * input_0.stride.1) - t168)
            let t173 = ((max((min((dilate_0.s0.y_0 + 2), t167) + -1), input_0.min.1) * input_0.stride.1) - t168)
            let t174 = ((max((min(dilate_0.s0.y_0, t167) + -1), input_0.min.1) * input_0.stride.1) - t168)
            let t175 = ((dilate_0.s0.y_0 * dilate_0.stride.1) - t170)
            for (dilate_0.s0.x_0, dilate_0.min.0, dilate_0.extent.0) {
                dilate_0[(dilate_0.s0.x_0 + t175)] = (let t132 = max((min((dilate_0.s0.x_0 + 2), t169) + -1), input_0.min.0) in (let t133 = max(min((t169 + -1), dilate_0.s0.x_0), input_0.min.0) in (let t134 = max((min(dilate_0.s0.x_0, t169) + -1), input_0.min.0) in max(input_0[(t132 + t171)], max(input_0[(t133 + t171)], max(input_0[(t134 + t171)], max(input_0[(t134 + t173)], max(input_0[(t133 + t173)], max(input_0[(t132 + t173)], max(input_0[(t134 + t174)], max(input_0[(t133 + t174)], input_0[(t132 + t174)])))))))))))
            }
        }
        let t183 = (dilate_0.extent.0 + dilate_0.min.0)
        let t184 = (input_0.extent.0 + input_0.min.0)
        let t185 = max((input_0.min.0 + 1), dilate_0.min.0)
        let t178 = min(max((t184 + -1), t185), t183)
        let t177 = min(t183, t185)
        let t176 = (dilate_0.s0.y_0.epilogue$3 - dilate_0.s0.y_0.prologue)
        let t179 = ((input_0.min.1 * input_0.stride.1) + input_0.min.0)
        let t181 = ((dilate_0.min.1 * dilate_0.stride.1) + dilate_0.min.0)
        for (dilate_0.s0.y_0, dilate_0.s0.y_0.prologue, t176) {
            let t189 = (((dilate_0.s0.y_0 + 1) * input_0.stride.1) - t179)
            let t190 = (((dilate_0.s0.y_0 + -1) * input_0.stride.1) - t179)
            let t187 = ((dilate_0.s0.y_0 * input_0.stride.1) - t179)
            let t191 = ((dilate_0.s0.y_0 * dilate_0.stride.1) - t181)
            let t186 = (t177 - dilate_0.min.0)
            for (dilate_0.s0.x_0, dilate_0.min.0, t186) {
                dilate_0[(dilate_0.s0.x_0 + t191)] = (let t140 = max((min((dilate_0.s0.x_0 + 2), t184) + -1), input_0.min.0) in (let t141 = max(min((t184 + -1), dilate_0.s0.x_0), input_0.min.0) in (let t142 = max((min(dilate_0.s0.x_0, t184) + -1), input_0.min.0) in max(input_0[(t140 + t187)], max(input_0[(t141 + t187)], max(input_0[(t142 + t187)], max(input_0[(t142 + t189)], max(input_0[(t141 + t189)], max(input_0[(t140 + t189)], max(input_0[(t142 + t190)], max(input_0[(t141 + t190)], input_0[(t140 + t190)])))))))))))
            }
            let t194 = (((dilate_0.s0.y_0 + 1) * input_0.stride.1) - t179)
            let t195 = (((dilate_0.s0.y_0 + -1) * input_0.stride.1) - t179)
            let t193 = ((dilate_0.s0.y_0 * input_0.stride.1) - t179)
            let t196 = ((dilate_0.s0.y_0 * dilate_0.stride.1) - t181)
            let t192 = (t178 - t177)
            for (dilate_0.s0.x_0, t177, t192) {
                dilate_0[(dilate_0.s0.x_0 + t196)] = max(input_0[((dilate_0.s0.x_0 + t193) + 1)], max(input_0[(dilate_0.s0.x_0 + t193)], max(input_0[((dilate_0.s0.x_0 + t193) + -1)], max(input_0[((dilate_0.s0.x_0 + t194) + -1)], max(input_0[(dilate_0.s0.x_0 + t194)], max(input_0[((dilate_0.s0.x_0 + t194) + 1)], max(input_0[((dilate_0.s0.x_0 + t195) + -1)], max(input_0[(dilate_0.s0.x_0 + t195)], input_0[((dilate_0.s0.x_0 + t195) + 1)]))))))))
            }
            let t200 = (((dilate_0.s0.y_0 + 1) * input_0.stride.1) - t179)
            let t201 = (((dilate_0.s0.y_0 + -1) * input_0.stride.1) - t179)
            let t198 = ((dilate_0.s0.y_0 * input_0.stride.1) - t179)
            let t202 = ((dilate_0.s0.y_0 * dilate_0.stride.1) - t181)
            let t197 = (t183 - t178)
            for (dilate_0.s0.x_0, t178, t197) {
                dilate_0[(dilate_0.s0.x_0 + t202)] = (let t152 = max((min((dilate_0.s0.x_0 + 2), t184) + -1), input_0.min.0) in (let t153 = max(min((t184 + -1), dilate_0.s0.x_0), input_0.min.0) in (let t154 = max((min(dilate_0.s0.x_0, t184) + -1), input_0.min.0) in max(input_0[(t152 + t198)], max(input_0[(t153 + t198)], max(input_0[(t154 + t198)], max(input_0[(t154 + t200)], max(input_0[(t153 + t200)], max(input_0[(t152 + t200)], max(input_0[(t154 + t201)], max(input_0[(t153 + t201)], input_0[(t152 + t201)])))))))))))
            }
        }
        let t203 = ((dilate_0.extent.1 + dilate_0.min.1) - dilate_0.s0.y_0.epilogue$3)
        let t205 = ((input_0.min.1 * input_0.stride.1) + input_0.min.0)
        let t207 = ((dilate_0.min.1 * dilate_0.stride.1) + dilate_0.min.0)
        let t204 = (input_0.extent.1 + input_0.min.1)
        let t206 = (input_0.extent.0 + input_0.min.0)
        for (dilate_0.s0.y_0, dilate_0.s0.y_0.epilogue$3, t203) {
            let t208 = ((max(min((t204 + -1), dilate_0.s0.y_0), input_0.min.1) * input_0.stride.1) - t205)
            let t210 = ((max((min((dilate_0.s0.y_0 + 2), t204) + -1), input_0.min.1) * input_0.stride.1) - t205)
            let t211 = ((max((min(dilate_0.s0.y_0, t204) + -1), input_0.min.1) * input_0.stride.1) - t205)
            let t212 = ((dilate_0.s0.y_0 * dilate_0.stride.1) - t207)
            for (dilate_0.s0.x_0, dilate_0.min.0, dilate_0.extent.0) {
                dilate_0[(dilate_0.s0.x_0 + t212)] = (let t161 = max((min((dilate_0.s0.x_0 + 2), t206) + -1), input_0.min.0) in (let t162 = max(min((t206 + -1), dilate_0.s0.x_0), input_0.min.0) in (let t163 = max((min(dilate_0.s0.x_0, t206) + -1), input_0.min.0) in max(input_0[(t161 + t208)], max(input_0[(t162 + t208)], max(input_0[(t163 + t208)], max(input_0[(t163 + t210)], max(input_0[(t162 + t210)], max(input_0[(t161 + t210)], max(input_0[(t163 + t211)], max(input_0[(t162 + t211)], input_0[(t161 + t211)])))))))))))
            }
        }
    }

生成器版本的运行速度快了一个数量级,这并不奇怪,因为它的伪代码看起来优化得多。它甚至比一个现有的示例运行得更快。

我的新手问题是,为什么JIT不能创建相同的表示?非常感谢你的回答/想法/帮助/提示...

EN

回答 1

Stack Overflow用户

发布于 2020-10-13 07:07:28

两者之间的区别在于,在JIT情况下,输入的大小(以及边界条件的位置)在编译时是已知的。

但是,两者生成的代码应该是相似的。我认为,JIT情况下没有生成五个单独的分支(case),这是Halide中的一个bug。我已经在Halide的GitHub仓库中提交了一个issue:https://github.com/halide/Halide/issues/5353

编辑:感谢您发现了一个bug!已在https://github.com/halide/Halide/pull/5355中修复

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/64303915

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档