Skip to content

Commit 3a909f9

Browse files
author
Ettore Tiotto
authored
Add KrnlPrintTensorOp and KrnlPrintOp to facilitate debugging. (llvm#1180)
Signed-off-by: Ettore Tiotto <[email protected]>
1 parent 3a45ee5 commit 3a909f9

20 files changed

+856
-287
lines changed

docs/DebuggingNumericalError.md

+18
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,21 @@ optional arguments:
8282
--rtol RTOL Relative tolerance for verification
8383
--atol ATOL Absolute tolerance for verification
8484
```
85+
86+
## Debugging the Code Generated for an Operator.
87+
88+
If you know, or suspect, that a particular ONNX MLIR operator produces an incorrect result, and want to narrow down the problem, we provide a couple of useful Krnl operators that allow printing (at runtime) the value of a tensor, or a value that has a primitive data type.
89+
90+
To print out the value of a tensor at a particular program point, inject the following code (where `X` is the tensor to be printed):
91+
92+
```code
93+
create.krnl.printTensor("Tensor X: ", X);
94+
```
95+
96+
Note: currently the content of the tensor is printed only when the tensor rank is less than four.
97+
98+
To print a message followed by one value, inject the following code (where `val` is the value to be printed and `valType` is its type):
99+
100+
```code
101+
create.krnl.printf("inputElem: ", val, valType);
102+
```

include/onnx-mlir/Runtime/OMTensor.h

+8
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,14 @@ int64_t omTensorGetOwning(const OMTensor *tensor);
284284
*/
285285
void omTensorSetOwning(OMTensor *tensor, int64_t owning);
286286

287+
/**
288+
* Print an OMTensor to stdout.
289+
*
290+
* @param msg, pointer to descriptive string
291+
* @param tensor, pointer to the OMTensor to print
292+
*/
293+
void omTensorPrint(const char *msg, const OMTensor *tensor);
294+
287295
#ifdef __cplusplus
288296
}
289297
#endif

include/onnx-mlir/Runtime/OnnxDataType.h

+5-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
#endif
2727

2828
enum OM_DATA_TYPE {
29-
#define OM_TYPE_METADATA_DEF(ENUM_NAME, ENUM_VAL, DTYPE_SIZE) \
29+
#define OM_TYPE_METADATA_DEF(ENUM_NAME, ENUM_VAL, DTYPE_SIZE, DTYPE_NAME) \
3030
ENUM_NAME = ENUM_VAL,
3131
#include "OnnxDataTypeMetaData.inc"
3232

@@ -38,6 +38,7 @@ typedef enum OM_DATA_TYPE OM_DATA_TYPE;
3838
#endif
3939

4040
extern const int OM_DATA_TYPE_SIZE[];
41+
extern const char *OM_DATA_TYPE_NAME[];
4142

4243
#ifdef __cplusplus
4344
// Note by design const map has no [] operator since [] creates a default
@@ -55,6 +56,9 @@ const std::map<std::string, OM_DATA_TYPE> OM_DATA_TYPE_CPP_TO_ONNX = {
5556
{"m", ONNX_TYPE_UINT64}, // uint64_t -> UINT64, unsigned long -> UINT64
5657
{"f", ONNX_TYPE_FLOAT}, // float -> FLOAT
5758
{"d", ONNX_TYPE_DOUBLE}, // double -> DOUBLE
59+
{"PKc", ONNX_TYPE_STRING}, // const char * -> STRING
60+
{"Cf", ONNX_TYPE_COMPLEX64}, // _Complex float -> COMPLEX64
61+
{"Cd", ONNX_TYPE_COMPLEX128}, // _Complex double -> COMPLEX128
5862
};
5963
#endif //__cplusplus
6064

include/onnx-mlir/Runtime/OnnxDataTypeMetaData.inc

+19-19
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,28 @@
33
*/
44

55
#if defined(OM_TYPE_METADATA_DEF)
6+
// clang-format off
67
// Data type metadata declared in the following format:
7-
// OM_TYPE_METADATA_DEF( dtype enum name, dtype enum value, dtype size)
8+
// OM_TYPE_METADATA_DEF(dtype enum name, dtype enum value, dtype size, dtype name)
89
// dtype enum values are standard ONNX data types defined in
910
// https://github.com/onnx/onnx/blob/main/onnx/onnx.proto#L484
10-
// clang-format off
11-
OM_TYPE_METADATA_DEF(ONNX_TYPE_UNDEFINED, 0, 0)
12-
OM_TYPE_METADATA_DEF(ONNX_TYPE_FLOAT, 1, sizeof(float))
13-
OM_TYPE_METADATA_DEF(ONNX_TYPE_UINT8, 2, sizeof(uint8_t))
14-
OM_TYPE_METADATA_DEF(ONNX_TYPE_INT8, 3, sizeof(int8_t))
15-
OM_TYPE_METADATA_DEF(ONNX_TYPE_UINT16, 4, sizeof(uint16_t))
16-
OM_TYPE_METADATA_DEF(ONNX_TYPE_INT16, 5, sizeof(int16_t))
17-
OM_TYPE_METADATA_DEF(ONNX_TYPE_INT32, 6, sizeof(int32_t))
18-
OM_TYPE_METADATA_DEF(ONNX_TYPE_INT64, 7, sizeof(int64_t))
19-
OM_TYPE_METADATA_DEF(ONNX_TYPE_STRING, 8, 0)
20-
OM_TYPE_METADATA_DEF(ONNX_TYPE_BOOL, 9, sizeof(bool))
21-
OM_TYPE_METADATA_DEF(ONNX_TYPE_FLOAT16, 10, 2)
22-
OM_TYPE_METADATA_DEF(ONNX_TYPE_DOUBLE, 11, sizeof(double))
23-
OM_TYPE_METADATA_DEF(ONNX_TYPE_UINT32, 12, sizeof(uint32_t))
24-
OM_TYPE_METADATA_DEF(ONNX_TYPE_UINT64, 13, sizeof(uint64_t))
25-
OM_TYPE_METADATA_DEF(ONNX_TYPE_COMPLEX64, 14, 8)
26-
OM_TYPE_METADATA_DEF(ONNX_TYPE_COMPLEX128, 15, 16)
27-
OM_TYPE_METADATA_DEF(ONNX_TYPE_BFLOAT16, 16, 2)
11+
OM_TYPE_METADATA_DEF(ONNX_TYPE_UNDEFINED, 0, 0, "undefined")
12+
OM_TYPE_METADATA_DEF(ONNX_TYPE_FLOAT, 1, sizeof(float), "float")
13+
OM_TYPE_METADATA_DEF(ONNX_TYPE_UINT8, 2, sizeof(uint8_t), "uint8_t")
14+
OM_TYPE_METADATA_DEF(ONNX_TYPE_INT8, 3, sizeof(int8_t), "int8_t")
15+
OM_TYPE_METADATA_DEF(ONNX_TYPE_UINT16, 4, sizeof(uint16_t), "uint16_t")
16+
OM_TYPE_METADATA_DEF(ONNX_TYPE_INT16, 5, sizeof(int16_t), "int16_t")
17+
OM_TYPE_METADATA_DEF(ONNX_TYPE_INT32, 6, sizeof(int32_t), "int32_t")
18+
OM_TYPE_METADATA_DEF(ONNX_TYPE_INT64, 7, sizeof(int64_t), "int64_t")
19+
OM_TYPE_METADATA_DEF(ONNX_TYPE_STRING, 8, 0, "const char *")
20+
OM_TYPE_METADATA_DEF(ONNX_TYPE_BOOL, 9, sizeof(bool), "_Bool")
21+
OM_TYPE_METADATA_DEF(ONNX_TYPE_FLOAT16, 10, 2, "")
22+
OM_TYPE_METADATA_DEF(ONNX_TYPE_DOUBLE, 11, sizeof(double), "double")
23+
OM_TYPE_METADATA_DEF(ONNX_TYPE_UINT32, 12, sizeof(uint32_t), "uint32_t")
24+
OM_TYPE_METADATA_DEF(ONNX_TYPE_UINT64, 13, sizeof(uint64_t), "uint64_t")
25+
OM_TYPE_METADATA_DEF(ONNX_TYPE_COMPLEX64, 14, 8, "_Complex float")
26+
OM_TYPE_METADATA_DEF(ONNX_TYPE_COMPLEX128, 15, 16, "_Complex double")
27+
OM_TYPE_METADATA_DEF(ONNX_TYPE_BFLOAT16, 16, 2, "")
2828
// clang-format on
2929
#else
3030
#error "Must define OM_TYPE_METADATA_DEF macro."

src/Conversion/KrnlToLLVM/CMakeLists.txt

+3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
add_onnx_mlir_library(OMKrnlToLLVM
44
KrnlToLLVM.cpp
5+
KrnlToLLVMHelper.cpp
6+
KrnlPrintTensor.cpp
7+
KrnlPrint.cpp
58
RuntimeAPI.cpp
69

710
LINK_LIBS PUBLIC
+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*/
4+
5+
//===------ KrnlPrint.cpp - Lower KrnlPrintOp -----------------------------===//
6+
//
7+
// Copyright 2022 The IBM Research Authors.
8+
//
9+
// =============================================================================
10+
//
11+
// This file lowers the KrnlPrintOp operator.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "mlir/Conversion/LLVMCommon/Pattern.h"
16+
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
17+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18+
#include "mlir/Dialect/StandardOps/IR/Ops.h"
19+
20+
#include "src/Conversion/KrnlToLLVM/KrnlPrint.hpp"
21+
#include "src/Conversion/KrnlToLLVM/KrnlToLLVM.hpp"
22+
#include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp"
23+
#include "src/Dialect/Krnl/KrnlHelper.hpp"
24+
#include "src/Dialect/Krnl/KrnlOps.hpp"
25+
26+
#include "llvm/Support/Debug.h"
27+
28+
#define DEBUG_TYPE "krnl_to_llvm"
29+
30+
using namespace mlir;
31+
32+
namespace onnx_mlir {
33+
34+
LogicalResult KrnlPrintOpLowering::matchAndRewrite(Operation *op,
35+
ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const {
36+
auto printOp = cast<KrnlPrintOp>(op);
37+
Location loc = printOp.getLoc();
38+
KrnlPrintOpAdaptor operandAdaptor(operands);
39+
40+
Value input = operandAdaptor.input();
41+
StringRef format = printOp.format();
42+
ModuleOp module = printOp->getParentOfType<ModuleOp>();
43+
44+
// Get a symbol reference to the runtime function to use, creating one if
45+
// necessary.
46+
auto printfFuncRef = getOrInsertPrintf(rewriter, module);
47+
48+
// Printf call.
49+
LLVM::GlobalOp formatSpec = getOrCreateGlobalString(format, loc, rewriter,
50+
module, static_cast<LLVMTypeConverter *>(getTypeConverter()));
51+
Value formatSpecPtr = getPtrToGlobalString(formatSpec, loc, rewriter);
52+
53+
if (input)
54+
rewriter.create<CallOp>(loc, printfFuncRef, ArrayRef<Type>({}),
55+
ArrayRef<Value>({formatSpecPtr, input}));
56+
else
57+
rewriter.create<CallOp>(loc, printfFuncRef, ArrayRef<Type>({}),
58+
ArrayRef<Value>({formatSpecPtr}));
59+
60+
rewriter.eraseOp(op);
61+
return success();
62+
}
63+
64+
FlatSymbolRefAttr KrnlPrintOpLowering::getOrInsertPrintf(
65+
PatternRewriter &rewriter, ModuleOp module) {
66+
// Insert the printf declaration if it is not already present.
67+
auto printfFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("printf");
68+
MLIRContext *ctx = rewriter.getContext();
69+
70+
if (!printfFunc) {
71+
OpBuilder::InsertionGuard guard(rewriter);
72+
rewriter.setInsertionPointToStart(module.getBody());
73+
auto voidType = LLVM::LLVMVoidType::get(ctx);
74+
Type i8Type = IntegerType::get(ctx, 8);
75+
Type i8PtrType = LLVM::LLVMPointerType::get(i8Type);
76+
printfFunc =
77+
rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(), "printf",
78+
LLVM::LLVMFunctionType::get(voidType, i8PtrType,
79+
/*isVarArg=*/true));
80+
}
81+
return SymbolRefAttr::get(ctx, "printf");
82+
}
83+
84+
} // namespace onnx_mlir
+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*/
4+
5+
//===------ KrnlPrint.hpp - Lower KrnlPrintOp -----------------------------===//
6+
//
7+
// Copyright 2022 The IBM Research Authors.
8+
//
9+
// =============================================================================
10+
//
11+
// This file declares the lowering class for the KrnlPrintOp operator.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#pragma once
16+
17+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18+
#include "mlir/IR/BuiltinTypes.h"
19+
#include "mlir/Pass/Pass.h"
20+
21+
#include "src/Conversion/KrnlToLLVM/RuntimeAPI.hpp"
22+
#include "src/Dialect/Krnl/KrnlOps.hpp"
23+
#include "src/Support/Common.hpp"
24+
25+
using namespace mlir;
26+
27+
namespace onnx_mlir {
28+
29+
class KrnlPrintOpLowering : public ConversionPattern {
30+
public:
31+
explicit KrnlPrintOpLowering(
32+
MLIRContext *context, TypeConverter &typeConverter)
33+
: ConversionPattern(
34+
typeConverter, KrnlPrintOp::getOperationName(), 1, context) {}
35+
36+
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
37+
ConversionPatternRewriter &rewriter) const override;
38+
39+
private:
40+
static FlatSymbolRefAttr getOrInsertPrintf(
41+
PatternRewriter &rewriter, ModuleOp module);
42+
};
43+
44+
} // namespace onnx_mlir
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*/
4+
5+
//===------ KrnlPrintTensor.cpp - Lower KrnlPrintTensorOp ----------------===//
6+
//
7+
// Copyright 2022 The IBM Research Authors.
8+
//
9+
// =============================================================================
10+
//
11+
// This file lowers the KrnlPrintTensorOp operator.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "mlir/Conversion/LLVMCommon/Pattern.h"
16+
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
17+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18+
#include "mlir/Dialect/StandardOps/IR/Ops.h"
19+
20+
#include "onnx/onnx_pb.h"
21+
22+
#include "src/Conversion/KrnlToLLVM/KrnlPrintTensor.hpp"
23+
#include "src/Conversion/KrnlToLLVM/KrnlToLLVM.hpp"
24+
#include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp"
25+
#include "src/Conversion/KrnlToLLVM/RuntimeAPI.hpp"
26+
#include "src/Dialect/Krnl/KrnlHelper.hpp"
27+
#include "src/Dialect/Krnl/KrnlOps.hpp"
28+
29+
#include "llvm/Support/Debug.h"
30+
31+
#define DEBUG_TYPE "krnl_to_llvm"
32+
33+
using namespace mlir;
34+
35+
namespace onnx_mlir {
36+
37+
LogicalResult KrnlPrintTensorOpLowering::matchAndRewrite(Operation *op,
38+
ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const {
39+
auto printTensorOp = cast<KrnlPrintTensorOp>(op);
40+
MLIRContext *context = printTensorOp.getContext();
41+
Location loc = printTensorOp.getLoc();
42+
KrnlPrintTensorOpAdaptor operandAdaptor(operands);
43+
44+
StringRef msg = printTensorOp.msg();
45+
Value input = operandAdaptor.input();
46+
assert(input.getType().isa<LLVM::LLVMStructType>() &&
47+
"expecting LLVMStructType");
48+
49+
ModuleOp module = printTensorOp->getParentOfType<ModuleOp>();
50+
const auto &apiRegistry = RuntimeAPIRegistry::build(module, rewriter);
51+
52+
// Get a symbol reference to the runtime function to use, creating one if
53+
// necessary.
54+
auto int64Ty = IntegerType::get(context, 64);
55+
auto memRefTy = input.getType().dyn_cast<LLVM::LLVMStructType>();
56+
auto memRefRank = onnx_mlir::getRankFromMemRefType(memRefTy);
57+
auto memRefRankVal = rewriter.create<LLVM::ConstantOp>(
58+
loc, int64Ty, rewriter.getI64IntegerAttr(memRefRank));
59+
Value omTensor = RuntimeAPI::callApi(rewriter, loc, apiRegistry,
60+
RuntimeAPI::API::CREATE_OMTENSOR, {memRefRankVal});
61+
62+
onnx_mlir::fillOMTensorWithMemRef(
63+
input, omTensor, false /*outOwning*/, rewriter, loc, apiRegistry, module);
64+
LLVM::GlobalOp globalStr = getOrCreateGlobalString(msg, loc, rewriter, module,
65+
static_cast<LLVMTypeConverter *>(getTypeConverter()));
66+
Value strPtr = getPtrToGlobalString(globalStr, loc, rewriter);
67+
68+
RuntimeAPI::callApi(rewriter, loc, apiRegistry,
69+
RuntimeAPI::API::PRINT_OMTENSOR, {strPtr, omTensor});
70+
71+
rewriter.eraseOp(op);
72+
return success();
73+
}
74+
75+
} // namespace onnx_mlir
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*/
4+
5+
//===------ KrnlPrintTensor.hpp - Lower KrnlPrintTensorOp -----------------===//
6+
//
7+
// Copyright 2022 The IBM Research Authors.
8+
//
9+
// =============================================================================
10+
//
11+
// This file declares the lowering class for the KrnlPrintTensorOp operator.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#pragma once
16+
17+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18+
#include "mlir/IR/BuiltinTypes.h"
19+
#include "mlir/Pass/Pass.h"
20+
21+
#include "src/Conversion/KrnlToLLVM/RuntimeAPI.hpp"
22+
#include "src/Dialect/Krnl/KrnlOps.hpp"
23+
#include "src/Support/Common.hpp"
24+
25+
using namespace mlir;
26+
27+
namespace onnx_mlir {
28+
29+
class KrnlPrintTensorOpLowering : public ConversionPattern {
30+
public:
31+
explicit KrnlPrintTensorOpLowering(
32+
MLIRContext *context, TypeConverter &typeConverter)
33+
: ConversionPattern(
34+
typeConverter, KrnlPrintTensorOp::getOperationName(), 1, context) {}
35+
36+
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
37+
ConversionPatternRewriter &rewriter) const override;
38+
};
39+
40+
} // namespace onnx_mlir

0 commit comments

Comments
 (0)