# Entry point: builds two <2, 3> tensors and prints the element-wise product
# of their transposes.
def main() {
  # Define a variable `a` with shape <2, 3>, initialized with the literal value.
  # The shape is inferred from the supplied literal.
  var a = [[1, 2, 3], [4, 5, 6]];

  # b is identical to a, the literal tensor is implicitly reshaped: defining new
  # variables is the way to reshape tensors (element count must match).
  var b<2, 3> = [1, 2, 3, 4, 5, 6];

  # transpose() and print() are the only builtins; the following will transpose
  # a and b and perform an element-wise multiplication before printing the result.
  print(transpose(a) * transpose(b));
}
# User defined generic function that operates on unknown shaped arguments.
# Returns the element-wise product of the transposes of `a` and `b`.
def multiply_transpose(a, b) {
  return transpose(a) * transpose(b);
}
# Entry point: exercises `multiply_transpose` with several argument shapes.
def main() {
  # Define a variable `a` with shape <2, 3>, initialized with the literal value.
  var a = [[1, 2, 3], [4, 5, 6]];
  # `b` has the same shape; the flat literal is reshaped by the declaration.
  var b<2, 3> = [1, 2, 3, 4, 5, 6];

  # This call will specialize `multiply_transpose` with <2, 3> for both
  # arguments and deduce a return type of <3, 2> in initialization of `c`.
  var c = multiply_transpose(a, b);

  # A second call to `multiply_transpose` with <2, 3> for both arguments will
  # reuse the previously specialized and inferred version and return <3, 2>.
  var d = multiply_transpose(b, a);

  # A new call with <3, 2> (instead of <2, 3>) for both dimensions will
  # trigger another specialization of `multiply_transpose`.
  # Both `c` and `d` are <3, 2>, so they are the operands that match this
  # description.
  var e = multiply_transpose(c, d);

  # Finally, calling into `multiply_transpose` with incompatible shapes
  # (<2, 3> for `a` and <3, 2> for `c`) will trigger a shape inference error.
  var f = multiply_transpose(a, c);
}
然后我们可以使用下面的命令来产生这个Toy语言程序的AST:
cd llvm-project/build/bin
./toyc-ch1 ../../mlir/test/Examples/Toy/Ch1/ast.toy --emit=ast
/// Dispatch codegen for the right expression subclass using RTTI. mlir::Value mlirGen(ExprAST &expr){ switch (expr.getKind()) { case toy::ExprAST::Expr_BinOp: return mlirGen(cast(expr)); case toy::ExprAST::Expr_Var: return mlirGen(cast(expr)); case toy::ExprAST::Expr_Literal: return mlirGen(cast(expr)); case toy::ExprAST::Expr_Call: return mlirGen(cast(expr)); case toy::ExprAST::Expr_Num: return mlirGen(cast(expr)); default: emitError(loc(expr.loc())) returnnullptr; } }
/// Emit a call expression. It emits specific operations for the `transpose` /// builtin. Other identifiers are assumed to be user-defined functions. mlir::Value mlirGen(CallExprAST &call){ llvm::StringRef callee = call.getCallee(); auto location = loc(call.loc());
// Codegen the operands first. SmallVector<:value style="color: #1c00cf;line-height: 26px;">4> operands; for (auto &expr : call.getArgs()) { auto
arg = mlirGen(*expr); if (!arg) returnnullptr; operands.push_back(arg); }
// Builtin calls have their custom operation, meaning this is a // straightforward emission. if (callee == "transpose") { if (call.getArgs().size() != 1) { emitError(location, "MLIR codegen encountered an error: toy.transpose " "does not accept multiple arguments"); returnnullptr; } return builder.create(location, operands[0]); }
// Otherwise this is a call to a user-defined function. Calls to // user-defined functions are mapped to a custom call that takes the callee // name as an attribute. return builder.create(location, callee, operands); }
// Provide a definition of the 'toy' dialect in the ODS framework so that we
// can define our operations.
def Toy_Dialect : Dialect {
  // The dialect name, used as the prefix of its operations.
  let name = "toy";

  // The C++ namespace used for the generated dialect code.
  let cppNamespace = "::mlir::toy";
}
// Provide a definition of the 'toy' dialect in the ODS framework so that we
// can define our operations.
def Toy_Dialect : Dialect {
  // The dialect name, used as the prefix of its operations.
  let name = "toy";

  // The C++ namespace used for the generated dialect code.
  let cppNamespace = "::mlir::toy";
}
// Base class for toy dialect operations. This operation inherits from the base
// `Op` class in OpBase.td, and provides:
//   * The parent dialect of the operation.
//   * The mnemonic for the operation, or the name without the dialect prefix.
//   * A list of traits for the operation.
// NOTE(review): older MLIR releases spell the trait list type `list<OpTrait>`;
// confirm against the MLIR version this article targets.
class Toy_Op<string mnemonic, list<Trait> traits = []> :
    Op<Toy_Dialect, mnemonic, traits>;
下面给出transpose Operation的定义感受一下:
def TransposeOp : Toy_Op<"transpose"> { let summary = "transpose operation";
let arguments = (ins F64Tensor:$input); let results = (outs F64Tensor);