首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何派生已经从Halide::生成器派生的类?

如何派生已经从Halide::生成器派生的类?
EN

Stack Overflow用户
提问于 2021-11-22 12:26:13
回答 1查看 138关注 0票数 1

我想在Halide/C++中创建一个基于Halide::生成器的基本继承结构,以避免重复的代码。

其思想是拥有一个拥有纯虚拟函数的抽象基类生成器类。此外,每个派生类都应该有一个特定的输入参数,在基类中不可用。

在普通的C++中,这是非常简单的,但是由于Halide是一个在链接和编译之前“生成代码”的DSL,所以事情可能会变得有点混乱。

我当前的Halide实现都在一个文件中:

my_generators.cpp

代码语言:javascript
复制
#include "Halide.h"
#include <stdio.h>

using namespace Halide;

class Base : public Halide::Generator<Base> {
public:
    Input<Buffer<float>> input{"input", 2};

    Output<Buffer<float>> output{"brighter", 2};

    Var x, y;

    virtual Func process(Func input) = 0;

    virtual void generate() {
        output = process(input);
        output.vectorize(x, 16).parallel(y);
    }
};

class DerivedGain : public Base {
    public:
    Input<float> gain{"gain"};

    Func process (Func input) override{
        Func result("result");
        result(x,y) = input(x,y) * gain;
        return result;
    }
};

class DerivedOffset : public Base{
    public:
    Input<float> offset{"offset"};

    Func process (Func input) override{
        Func result("result");
        result(x,y) = input(x,y) + offset;
        return result;
    }
};

HALIDE_REGISTER_GENERATOR(DerivedGain, derived_gain)
HALIDE_REGISTER_GENERATOR(DerivedOffset, derived_offset)

为了编译它,我使用了以下CMakeLists文件:

代码语言:javascript
复制
cmake_minimum_required(VERSION 3.16)
project(HalideExample)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED YES)
set(CMAKE_CXX_EXTENSIONS NO)

find_package(Halide REQUIRED)

add_executable(my_generators my_generators.cpp)

target_include_directories(my_generators PUBLIC ${HALIDE_ROOT}/include)
target_link_libraries(my_generators PRIVATE Halide::Generator)

add_halide_library(derived_gain FROM my_generators)
add_halide_library(derived_offset FROM my_generators)

我一直在使用预先构建的版本Halide-13.0.1-x86-64-linux 这里

但是在编译过程中,它会启动一个错误,表明class Base正在实例化(我不需要这样做):

代码语言:javascript
复制
In file included from <path_to_project>/my_generators.cpp:2:
<path_to_halide>/include/Halide.h: In instantiation of ‘static std::unique_ptr<_Tp> Halide::Generator<T>::create(const Halide::GeneratorContext&) [with T = Base]’:
<path_to_halide>/include/Halide.h:26640:14:   required from ‘static std::unique_ptr<_Tp> Halide::Generator<T>::create(const Halide::GeneratorContext&, const string&, const string&) [with T = Base; std::string = std::__cxx11::basic_string<char>]’
<path_to_project>/my_generators.cpp:53:1:   required from here
<path_to_halide>/include/Halide.h:26631:37: error: invalid new-expression of abstract class type ‘Base’
26631 |         auto g = std::unique_ptr<T>(new T());
      |                                     ^~~~~~~
<path_to_project>/my_generators.cpp:7:7: note:   because the following virtual functions are pure within ‘Base’:
    7 | class Base : public Halide::Generator<Base> {
      |       ^~~~
<path_to_project>/my_generators.cpp:15:18: note:    ‘virtual Halide::NamesInterface::Func Base::process(Halide::NamesInterface::Func)’
   15 |     virtual Func process(Func input) = 0;

如果不是使用virtual函数,而是在基类中实现它,如下所示:

代码语言:javascript
复制
class Base : public Halide::Generator<Base> {
public:
    Input<Buffer<float>> input{"input", 2};

    Output<Buffer<float>> output{"brighter", 2};

    Var x, y;

    // Func process(Func input);
    Func process (Func input){
        Func result("result");
        result(x,y) = input(x,y);
        return result;
    }

    virtual void generate() {
        output = process(input);
        output.vectorize(x, 16).parallel(y);
    }
};

然后所有文件都编译,但是带有生成代码的对象和头文件有错误的函数签名(由于缺少增益/偏移参数而值得注意):

derived_gain.h:

代码语言:javascript
复制
int derived_gain(struct halide_buffer_t *_input_buffer, struct halide_buffer_t *_result_buffer);

derived_offset.h:

代码语言:javascript
复制
int derived_offset(struct halide_buffer_t *_input_buffer, struct halide_buffer_t *_result_buffer);

因此,我想知道我在类别定义中引入了哪一个错误,以及如何解决这个错误。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-11-22 19:09:15

可以将基类转换为模板:

代码语言:javascript
复制
template<class T>
class Base : public Halide::Generator<T> {

然后再导出InputOutput的名字.(我没有足够的C++专家来理解为什么这是必要的):

代码语言:javascript
复制
  // In class Base:
  template <typename T2>
  using Input = typename Halide::Generator<T>::template Input<T2>;

  template <typename T2>
  using Output = typename Halide::Generator<T>::template Output<T2>;

那么剩下的修改就是:

代码语言:javascript
复制
class DerivedGain : public Base<DerivedGain> { ... };
class DerivedOffset : public Base<DerivedOffset> { ... };

这似乎对我有用。

而且,您可能不需要在CMakeLists.txt中使用这一行(我不需要):

代码语言:javascript
复制
target_include_directories(my_generators PUBLIC ${HALIDE_ROOT}/include)

我们的包没有设置HALIDE_ROOT,而且链接到Halide::Generator已经正确地设置了包含路径。

票数 4
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70065723

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档