Skip to content

Commit 85f737e

Browse files
banach-spaceIanWood1
authored andcommitted
[mlir][utils] Update generate-test-checks.py (use SSA names) (llvm#136819)
This patch updates generate-test-checks.py to preserve original SSA names (capitalized) when generating LIT variable names for function arguments (i.e. for `CHECK-SAME` lines). This improves readability and helps maintain consistency between the input MLIR and the expected FileCheck/LIT output. For example, given the following function: ```mlir func.func @example( %input: memref<4x6x3xf32>, %filter: memref<1x3x8xf32>, %output: memref<4x2x8xf32>) { linalg.conv_1d_nwc_wcf {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%input, %filter : memref<4x6x3xf32>, memref<1x3x8xf32>) outs(%output : memref<4x2x8xf32>) return } ``` The generated output becomes: ```mlir // CHECK-LABEL: func.func @conv1d_nwc_4x2x8_memref( // CHECK-SAME: %[[INPUT:.*]]: memref<4x6x3xf32>, // CHECK-SAME: %[[FILTER:.*]]: memref<1x3x8xf32>, // CHECK-SAME: %[[OUTPUT:.*]]: memref<4x2x8xf32>) { // CHECK: linalg.conv_1d_nwc_wcf // CHECK: {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} // CHECK: ins(%[[INPUT]], %[[FILTER]] : memref<4x6x3xf32>, memref<1x3x8xf32>) // CHECK: outs(%[[OUTPUT]] : memref<4x2x8xf32>) // CHECK: return // CHECK: } ``` By contrast, the current version of the script would generate: ```mlir // CHECK-LABEL: func.func @conv1d_nwc_4x2x8_memref( // CHECK-SAME: %[[VAL_0:.*]]: memref<4x6x3xf32>, // CHECK-SAME: %[[VAL_1:.*]]: memref<1x3x8xf32>, // CHECK-SAME: %[[VAL_2:.*]]: memref<4x2x8xf32>) { // CHECK: linalg.conv_1d_nwc_wcf // CHECK: {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} // CHECK: ins(%[[VAL_0]], %[[VAL_1]] : memref<4x6x3xf32>, memref<1x3x8xf32>) // CHECK: outs(%[[VAL_2]] : memref<4x2x8xf32>) // CHECK: return // CHECK: } ```
1 parent bf4cd4c commit 85f737e

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

mlir/utils/generate-test-checks.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,20 @@ def generate_in_parent_scope(self, n):
7777
self.generate_in_parent_scope_left = n
7878

7979
# Generate a substitution name for the given ssa value name.
80-
def generate_name(self, source_variable_name):
80+
def generate_name(self, source_variable_name, use_ssa_name):
8181

8282
# Compute variable name
8383
variable_name = self.variable_names.pop(0) if len(self.variable_names) > 0 else ''
8484
if variable_name == '':
85-
variable_name = "VAL_" + str(self.name_counter)
86-
self.name_counter += 1
85+
# If `use_ssa_name` is set, use the MLIR SSA value name to generate
86+
# a FileCHeck substation string. As FileCheck requires these
87+
# strings to start with a character, skip MLIR variables starting
88+
# with a digit (e.g. `%0`).
89+
if use_ssa_name and source_variable_name[0].isalpha():
90+
variable_name = source_variable_name.upper()
91+
else:
92+
variable_name = "VAL_" + str(self.name_counter)
93+
self.name_counter += 1
8794

8895
# Scope where variable name is saved
8996
scope = len(self.scopes) - 1
@@ -158,7 +165,7 @@ def get_num_ssa_results(input_line):
158165

159166

160167
# Process a line of input that has been split at each SSA identifier '%'.
161-
def process_line(line_chunks, variable_namer, strict_name_re=False):
168+
def process_line(line_chunks, variable_namer, use_ssa_name=False, strict_name_re=False):
162169
output_line = ""
163170

164171
# Process the rest that contained an SSA value name.
@@ -178,7 +185,7 @@ def process_line(line_chunks, variable_namer, strict_name_re=False):
178185
output_line += "%[[" + variable + "]]"
179186
else:
180187
# Otherwise, generate a new variable.
181-
variable = variable_namer.generate_name(ssa_name)
188+
variable = variable_namer.generate_name(ssa_name, use_ssa_name)
182189
if strict_name_re:
183190
# Use stricter regexp for the variable name, if requested.
184191
# Greedy matching may cause issues with the generic '.*'
@@ -415,9 +422,11 @@ def main():
415422
pad_depth = label_length if label_length < 21 else 4
416423
output_line += " " * pad_depth
417424

418-
# Process the rest of the line.
425+
# Process the rest of the line. Use the original SSA name to generate the LIT
426+
# variable names.
427+
use_ssa_names = True
419428
output_line += process_line(
420-
[argument], variable_namer, args.strict_name_re
429+
[argument], variable_namer, use_ssa_names, args.strict_name_re
421430
)
422431

423432
# Append the output line.

0 commit comments

Comments
 (0)