Skip to content

Commit 5baae4f

Browse files
committed
[mlir][utils] Update generate-test-checks.py (use SSA names)
This patch updates generate-test-checks.py to preserve original SSA names (capitalized) when generating LIT variable names for function arguments (i.e. for `CHECK-SAME` lines). This improves readability and helps maintain consistency between the input MLIR and the expected FileCheck/LIT output. For example, given the following function: ```mlir func.func @example( %input: memref<4x6x3xf32>, %filter: memref<1x3x8xf32>, %output: memref<4x2x8xf32>) { linalg.conv_1d_nwc_wcf {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%input, %filter : memref<4x6x3xf32>, memref<1x3x8xf32>) outs(%output : memref<4x2x8xf32>) return } ``` The generated output becomes: ```mlir // CHECK-LABEL: func.func @conv1d_nwc_4x2x8_memref( // CHECK-SAME: %[[INPUT:.*]]: memref<4x6x3xf32>, // CHECK-SAME: %[[FILTER:.*]]: memref<1x3x8xf32>, // CHECK-SAME: %[[OUTPUT:.*]]: memref<4x2x8xf32>) { // CHECK: linalg.conv_1d_nwc_wcf // CHECK: {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} // CHECK: ins(%[[INPUT]], %[[FILTER]] : memref<4x6x3xf32>, memref<1x3x8xf32>) // CHECK: outs(%[[OUTPUT]] : memref<4x2x8xf32>) // CHECK: return // CHECK: } ``` By contrast, the current version of the script would generate: ```mlir // CHECK-LABEL: func.func @conv1d_nwc_4x2x8_memref( // CHECK-SAME: %[[VAL_0:.*]]: memref<4x6x3xf32>, // CHECK-SAME: %[[VAL_1:.*]]: memref<1x3x8xf32>, // CHECK-SAME: %[[VAL_2:.*]]: memref<4x2x8xf32>) { // CHECK: linalg.conv_1d_nwc_wcf // CHECK: {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} // CHECK: ins(%[[VAL_0]], %[[VAL_1]] : memref<4x6x3xf32>, memref<1x3x8xf32>) // CHECK: outs(%[[VAL_2]] : memref<4x2x8xf32>) // CHECK: return // CHECK: } ```
1 parent ca3a5d3 commit 5baae4f

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

mlir/utils/generate-test-checks.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#!/usr/bin/env python3
22
"""A script to generate FileCheck statements for mlir unit tests.
3-
43
This script is a utility to add FileCheck patterns to an mlir file.
54
65
NOTE: The input .mlir is expected to be the output from the parser, not a
@@ -77,13 +76,16 @@ def generate_in_parent_scope(self, n):
7776
self.generate_in_parent_scope_left = n
7877

7978
# Generate a substitution name for the given ssa value name.
80-
def generate_name(self, source_variable_name):
79+
def generate_name(self, source_variable_name, use_ssa_name):
8180

8281
# Compute variable name
8382
variable_name = self.variable_names.pop(0) if len(self.variable_names) > 0 else ''
8483
if variable_name == '':
85-
variable_name = "VAL_" + str(self.name_counter)
86-
self.name_counter += 1
84+
if use_ssa_name:
85+
variable_name = source_variable_name.upper()
86+
else:
87+
variable_name = "VAL_" + str(self.name_counter)
88+
self.name_counter += 1
8789

8890
# Scope where variable name is saved
8991
scope = len(self.scopes) - 1
@@ -158,7 +160,7 @@ def get_num_ssa_results(input_line):
158160

159161

160162
# Process a line of input that has been split at each SSA identifier '%'.
161-
def process_line(line_chunks, variable_namer, strict_name_re=False):
163+
def process_line(line_chunks, variable_namer, use_ssa_name=False, strict_name_re=False):
162164
output_line = ""
163165

164166
# Process the rest that contained an SSA value name.
@@ -178,7 +180,7 @@ def process_line(line_chunks, variable_namer, strict_name_re=False):
178180
output_line += "%[[" + variable + "]]"
179181
else:
180182
# Otherwise, generate a new variable.
181-
variable = variable_namer.generate_name(ssa_name)
183+
variable = variable_namer.generate_name(ssa_name, use_ssa_name)
182184
if strict_name_re:
183185
# Use stricter regexp for the variable name, if requested.
184186
# Greedy matching may cause issues with the generic '.*'
@@ -415,9 +417,11 @@ def main():
415417
pad_depth = label_length if label_length < 21 else 4
416418
output_line += " " * pad_depth
417419

418-
# Process the rest of the line.
420+
# Process the rest of the line. Use the original SSA name to generate the LIT
421+
# variable names.
422+
use_ssa_names=True
419423
output_line += process_line(
420-
[argument], variable_namer, args.strict_name_re
424+
[argument], variable_namer, use_ssa_names, args.strict_name_re
421425
)
422426

423427
# Append the output line.

0 commit comments

Comments
 (0)