接下来我们将向Dialect中添加一些操作
添加操作的步骤
向Dialect中添加操作的步骤一般遵循以下的模板:
- 定义属于这个Dialect的操作的基类
- 从操作的基类中派生出具体的操作类
- 提供操作的描述, 操作的属性, 操作数类型等基本信息
- 在Dialect中注册操作
// 提供Toy方言的操作的基类
def Toy_Op<string mnemonic, list<Trait> traits = []> : Op<Toy_Dialect, mnemonic, traits>;
// 定义Toy方言的操作
def ConstantOp : Toy_Op<"constant", [Pure]> {
let summary = "constant value";
let description = [{
Constant operation that produces a constant value.
}];
let arguments = (ins IntegerAttr:$value);
let results = (outs IntegerType:$result);
}
def PrintOp : Toy_Op<"print", [Pure]> {
let summary = "print value";
let description = [{
Print operation that prints a value to the console.
}];
let arguments = (ins IntegerAttr:$input);
let hasCustomAssemblyFormat = 1;
let assemblyFormat = "$input attr-dict `:` type($input)";
}
需要注意的是, 在定义操作名称的时候不要用_
分隔, 即PrintOp
不要写成Print_Op
, 否则TableGen自动生成的代码会有问题, 生成的操作类名称就是Op
, 和基类名称冲突(别问怎么知道的).
上面的代码定义了两个操作, 一个常量操作, 它不接受任何输入, 只固定地返回一个值, 而这个值是在编译之前就指定好的. 另一个是打印操作, 它接受一个整数输入, 并且不定义任何返回值. hasCustomAssemblyFormat
表示这个操作有自定义的汇编格式. MLIR允许我们为操作自定义生成的汇编格式, 但需要我们额外实现一个parse
函数告诉MLIR框架如何解析自定义格式, 我们留到稍后讲.
总之, 上面的代码已经实现了添加新操作的前三步, 非常简洁快速. 而第四步实际上我们早已做过, 也就是在ToyDialect.cpp
中添加Dialect的initialize
代码的时候.
using namespace mlir::toy;
void ToyDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "ToyOps.cpp.inc"
>();
}
到此实际上就走完了添加操作的所有基本流程. 如你所见, 使用ODS可以非常快速的定义一个操作, 然后生成相关的C++代码. 接下来大致解释一下在定义操作的时候用到的一些东西.
操作的描述
summary
和description
是操作的描述, 这两个字段会在生成的文档中使用. 字面意思, summary
是简短的描述, description
是详细的描述. 生成文档可以通过mlir-tblgen --gen-op-doc
命令实现.
操作的参数
在操作的arguments
字段中定义操作的参数, 这里的参数既可以是操作数, 也可以是属性. 简单回忆一下, 操作的属性指的是在编译期就已经知道的"预定义"固定值, 而操作数的值只有在运行时才能确定. 比如我们希望ConstantOp
始终返回一个固定的整数1
, 那这个操作的属性列表中就可以有一项{ value = 1 }
, 在TableGen中可以用IntegerAttr:$value
来捕获这个属性, 属性值被记作$value
. 而我们希望PrintOp
能接受不同的整数输入, 并将整数打印出来, 因此PrintOp
应该有一个操作数类似(%arg0: i32)
, 在TableGen中可以用IntegerAttr:$input
来捕获这个操作数, 操作数值被记作$input
.
// Use Examples
%0 = "toy.constant"() { value = 1 } : () -> i32 // 固定返回`1`
%1 = "toy.constant"() { value = 2 } : () -> i32 // 固定返回`2`
"toy.print" %0 : i32 // 打印结果为`1`