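The official TVM tutorial explains how to add a custom operator. (This is my translation of it.)
The previous article covered how ONNX operators are converted to Relay IR. This one uses the Conv2d operator to show how Relay IR operators are invoked during compilation.
Relay operator invocation
The get_relay_op mentioned above looks an operator up among all Relay IR operators; it is defined as get_relay_op in python/tvm/relay/frontend/common.py. Continuing with the convolution operator as the example, the operator conversion code described in the previous article contains the following statement:
for candidate in (_op, _op.nn, _op.image, _op.vision, _op.contrib):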
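The search that this loop performs can be tried out without TVM installed. The snippet below is only a minimal sketch of the idea, not the code in common.py: the namespaces are hypothetical stand-ins built with SimpleNamespace, the operator values are placeholder strings, and find_op is a made-up name for the lookup logic.

```python
from types import SimpleNamespace

# Hypothetical stand-ins for the Relay op namespaces (_op, _op.nn, ...).
# In TVM these are real Python modules; the strings here are placeholders.
nn = SimpleNamespace(conv2d="nn.conv2d", dense="nn.dense")
image = SimpleNamespace(resize2d="image.resize2d")
op = SimpleNamespace(add="add", nn=nn, image=image)


def find_op(op_name):
    """Return the first attribute called op_name among the candidate namespaces."""
    for candidate in (op, op.nn, op.image):
        found = getattr(candidate, op_name, None)
        if found is not None:
            return found
    raise NotImplementedError(f"Unable to map op_name {op_name} to relay")


print(find_op("conv2d"))  # resolved from the nn namespace
```

The order of the candidate tuple matters: the first namespace that exposes the name wins, which is why conv2d resolves to the entry under nn.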
For the conv2d operator, the implementation is found in _op.nn:
def conv2d(
The _make.conv2d called here is obtained through the PackedFunc registration below:
tvm._ffi._init_api("relay.op.nn._make", __name__)
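The function registered for conv2d is in src/relay/op/nn/convolution.cc:
TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d")
MakeConv is a template shared by all convolutions; it is instantiated as the matching function according to its template argument.
template <typename T>
Inside it, Op::Get(op_name) fetches the corresponding Relay operator, and Op::Get itself turns out to be a table lookup: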
// find operator by name
const Op& Op::Get(const String& name) {
  const OpRegEntry* reg = OpRegistry::Global()->Get(name);
  ICHECK(reg != nullptr) << "AttributeError: Operator " << name << " is not registered";
  return reg->op();
}
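The table lookup in Op::Get can be mirrored in a few lines of Python. This is an illustrative sketch only: the Op class, the _REGISTRY dictionary, and the register and get_op helpers are hypothetical names, not TVM APIs.

```python
class Op:
    """Minimal stand-in for a Relay operator: it only carries a name."""

    def __init__(self, name):
        self.name = name


# Hypothetical global table mapping operator names to Op objects.
_REGISTRY = {}


def register(name):
    """Create the operator on first use and return the stored instance."""
    if name not in _REGISTRY:
        _REGISTRY[name] = Op(name)
    return _REGISTRY[name]


def get_op(name):
    """Look an operator up by name, failing loudly if it was never registered."""
    op = _REGISTRY.get(name)
    if op is None:
        raise AttributeError(f"Operator {name} is not registered")
    return op


register("nn.conv2d")
print(get_op("nn.conv2d").name)
```

As in the C++ version, asking for a name that was never registered is an error rather than a silent miss.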
Operators are registered into OpRegistry::Global() through the C++ macro RELAY_REGISTER_OP("nn.conv2d"). The macro expands to:
static __attribute__((unused))::tvm::OpRegEntry& __make_Op230 =
The registration process:
RELAY_REGISTER_OP("nn.conv2d")
    .describe(R"code(2D convolution layer (e.g. spatial convolution over images).
This layer creates a convolution kernel that is convolved
with the layer input to produce a tensor of outputs.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
(batch_size, in_channels, height, width) if `layout` is `NCHW`.
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
- **out**: This depends on the `layout` parameter. Output is 4D array of shape
(batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
    .set_attrs_type<Conv2DAttrs>()
    .set_num_inputs(2)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(2)
    .add_type_rel("Conv2D", Conv2DRel<Conv2DAttrs>)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>);
The registration yields an OpRegEntry. The subsequent calls, such as set_name, build the corresponding Relay op through the OpRegEntry's get interface (which returns an OpNode).
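The chained registration calls follow a builder pattern: each setter fills in a field of the underlying node and returns the entry so that the next call can be chained. The Python sketch below imitates that shape under simplifying assumptions. The class and method names echo the C++ ones for readability, but the fields, default values, and signatures are hypothetical and are not TVM's actual definitions.

```python
class OpNode:
    """Hypothetical operator record filled in by the registration chain."""

    def __init__(self):
        self.name = ""
        self.description = ""
        self.num_inputs = -1
        self.arguments = []
        self.support_level = 10


class OpRegEntry:
    """Builder whose setters write into one OpNode and return self."""

    def __init__(self):
        self._op = OpNode()

    def get(self):
        return self._op

    def set_name(self, name):
        if not self._op.name:  # keep the name from the first registration
            self._op.name = name
        return self

    def describe(self, text):
        self._op.description = text
        return self

    def set_num_inputs(self, n):
        self._op.num_inputs = n
        return self

    def add_argument(self, name, type_name, description):
        self._op.arguments.append((name, type_name, description))
        return self

    def set_support_level(self, level):
        self._op.support_level = level
        return self


_ENTRIES = {}  # hypothetical global table: operator name -> OpRegEntry


def register_or_get(name):
    """Return the entry for name, creating it on first use."""
    if name not in _ENTRIES:
        _ENTRIES[name] = OpRegEntry()
    return _ENTRIES[name].set_name(name)


entry = (
    register_or_get("nn.conv2d")
    .describe("2D convolution layer")
    .set_num_inputs(2)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(2)
)
conv2d = entry.get()
print(conv2d.name, conv2d.num_inputs, [arg[0] for arg in conv2d.arguments])
```

Because register_or_get returns the existing entry for a name it has already seen, a second registration under the same name extends the same node instead of replacing it.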