Skip to content

Commit 711282e

Browse files
committed
Update base for Update on "[ET-VK][ez] Make squeeze insertion requirements more strict"
## Context Refactor the `SqueezeUnsqueezeInputs` pass to be more clear about its intention. For Llama models, input shapes to 4 bit linear will oftentimes have the shape `[1, seq_len, dim]`; under the current implementation of the pass, the input would be squeezed to `[seq_len, dim]` even though the squeeze is not necessary. The original intention of thispass was to squeeze inputs with shape `[batch_size, 1, dim]` to `[batch_size, dim]` before calling the 4-bit linear operator. ## Changes To avoid inserting unnecessary squeeze/unsqueezes, be more specific about when squeeze/unsqueeze should be added. I would like to consider refactoring this pass in the future, since the logic is currently a bit uninttuitive. Squeeze/unsqueeze is also inserted for gelu and relu, but this is to create a chain of unsqueeze/squeeze that will be eliminated by a later pass (see #8601 / D69673068). I think eventually it will be good to rewrite the pass to make shape management more explicit and self contained within the pass rather than inserting ops which are expected to be removed later on. Differential Revision: [D72480178](https://our.internmc.facebook.com/intern/diff/D72480178/) [ghstack-poisoned]
2 parents f1e2f1a + 6adff9c commit 711282e

39 files changed

+2195
-740
lines changed

Diff for: .ci/docker/ci_commit_pins/pytorch.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
7ae0ce6360b6e4f944906502d20da24c04debee5
1+
59d5cf083b4f860dea76fe8936076177f9367f10

Diff for: backends/arm/test/models/test_conformer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class TestConformer(unittest.TestCase):
3131
# .to_executorch step, i.e. after Arm partitioner.
3232
ops_after_partitioner = {
3333
"executorch_exir_dialects_edge__ops_aten_max_default": 1,
34-
"torch.ops.aten._assert_scalar.default": 10,
34+
"torch.ops.aten._assert_scalar.default": 7,
3535
"torch.ops.aten._local_scalar_dense.default": 1,
3636
}
3737

Diff for: backends/arm/test/models/test_llama.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import sys
1212
import unittest
1313

14+
import pytest
1415
import torch
1516

1617
from executorch.backends.arm.test import common, conftest
@@ -102,7 +103,7 @@ def test_llama_tosa_MI(self):
102103
llama_model, llama_inputs, llama_meta = self.prepare_model()
103104

104105
if llama_model is None and llama_inputs is None and llama_meta is None:
105-
return
106+
pytest.skip("Missing model and/or input files")
106107

107108
with torch.no_grad():
108109
(

Diff for: backends/xnnpack/operators/op_slice_copy.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ def define_node(
6969
output_shape = [output_shape[i] for i in PERM_NCHW_TO_NHWC]
7070
dim_of_slice = PERM_NHWC_TO_NCHW[dim_of_slice]
7171

72-
slice_begin_index = cast(int, node.args[2])
72+
slice_begin_index = 0
73+
if len(node.args) > 2 and node.args[2]:
74+
slice_begin_index = cast(int, node.args[2])
7375
if slice_begin_index < 0:
7476
slice_begin_index = input_shape[dim_of_slice] + slice_begin_index
7577

Diff for: backends/xnnpack/test/ops/test_slice_copy.py

+12
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,18 @@ def forward(self, x):
6969
# Note that two of the slices are optimized away as they are identity.
7070
self._test_slice_copy(ConvSlice(), inputs, 4, 2)
7171

72+
def test_fp32_slice_copy_default_start(self):
73+
"""
74+
XNNPACK supports default start in slice op.
75+
"""
76+
77+
class Slice(torch.nn.Module):
78+
def forward(self, x):
79+
return torch.ops.aten.slice.Tensor(x, 0, None, 2)
80+
81+
inputs = (torch.randn(5, 5),)
82+
self._test_slice_copy(Slice(), inputs, 1, 1)
83+
7284
def test_fp32_slice_copy_stride_non_1(self):
7385
"""
7486
XNNPACK does not support strided slicing.

Diff for: devtools/etdump/etdump_filter.cpp

+95
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/devtools/etdump/etdump_filter.h>
10+
11+
#include <executorch/runtime/core/error.h>
12+
13+
using ::executorch::runtime::DelegateDebugIntId;
14+
using ::executorch::runtime::Error;
15+
using ::executorch::runtime::kUnsetDelegateDebugIntId;
16+
17+
namespace executorch {
18+
namespace etdump {
19+
20+
ETDumpFilter::ETDumpFilter() = default;
21+
22+
Result<bool> ETDumpFilter::add_regex(string_view pattern) {
23+
auto regex = std::make_unique<re2::RE2>(pattern.data());
24+
if (!regex->ok()) {
25+
return Error::InvalidArgument; // Error during regex compilation
26+
}
27+
regex_patterns_.emplace_back(std::move(regex));
28+
return true;
29+
}
30+
31+
Result<bool> ETDumpFilter::set_debug_handle_range(size_t start, size_t end) {
32+
if (start >= end) {
33+
return Error::InvalidArgument; // Start is greater than end
34+
}
35+
if (start < 0 || end < 0) {
36+
return Error::InvalidArgument; // Start or end is negative
37+
}
38+
range_start_ = start;
39+
range_end_ = end;
40+
return true;
41+
}
42+
43+
Result<bool> ETDumpFilter::filter_name_(const char* name) {
44+
if (name == nullptr) {
45+
return Error::InvalidArgument;
46+
}
47+
if (regex_patterns_.empty()) {
48+
return true;
49+
}
50+
for (const auto& regex : regex_patterns_) {
51+
if (RE2::FullMatch(name, *regex)) {
52+
return true;
53+
}
54+
}
55+
return false;
56+
}
57+
58+
Result<bool> ETDumpFilter::filter_delegate_debug_index_(
59+
DelegateDebugIntId debug_handle) {
60+
if (debug_handle == kUnsetDelegateDebugIntId) {
61+
return Error::InvalidArgument; // Delegate debug index is unset
62+
}
63+
64+
if (range_start_ == 0 && range_end_ == 0) {
65+
return true;
66+
}
67+
68+
if (debug_handle < range_start_ || debug_handle >= range_end_) {
69+
return false;
70+
}
71+
72+
return true;
73+
}
74+
75+
Result<bool> ETDumpFilter::filter(
76+
const char* name,
77+
DelegateDebugIntId delegate_debug_index) {
78+
if ((name == nullptr) == (delegate_debug_index == kUnsetDelegateDebugIntId)) {
79+
return Error::InvalidArgument; // Name and delegate debug index should be
80+
// both set or unset
81+
}
82+
83+
if (name) {
84+
return filter_name_(name);
85+
} else {
86+
return filter_delegate_debug_index_(delegate_debug_index);
87+
}
88+
}
89+
90+
size_t ETDumpFilter::get_n_regex() const {
91+
return regex_patterns_.size();
92+
}
93+
94+
} // namespace etdump
95+
} // namespace executorch

Diff for: devtools/etdump/etdump_filter.h

+102
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <re2/re2.h>
12+
#include <memory>
13+
14+
#include <executorch/runtime/core/event_tracer.h>
15+
#include <executorch/runtime/core/result.h>
16+
#include <executorch/runtime/platform/platform.h>
17+
18+
namespace executorch::etdump {
19+
20+
using ::executorch::aten::string_view;
21+
using ::executorch::runtime::Result;
22+
23+
/**
24+
* ETDumpFilter is a class that filters intermediate output based on output's
25+
* name by full regex filtering, or delegate debug indices by range-based
26+
* filtering.
27+
*
28+
* Note that this filter supports up to MAX_REGEX_PATTERNS regex patterns with a
29+
* maximum length of MAX_PATTERN_LENGTH characters each.
30+
*/
31+
32+
class ETDumpFilter : public ::executorch::runtime::EventTracerFilterBase {
33+
public:
34+
ETDumpFilter();
35+
~ETDumpFilter() override = default;
36+
/**
37+
* Adds a regex pattern to the filter.
38+
*
39+
* @param[in] pattern A c string representing the regex pattern to be added.
40+
*
41+
* @return A Result<bool> indicating the success or failure of adding the
42+
* regex pattern.
43+
* - True if the pattern is successfully added.
44+
* - False if the pattern could not be added or if the maximum number
45+
* of patterns is exceeded.
46+
* - An error code if number of pattern has reached to cap, or any
47+
* error occurs during regex compilation.
48+
*/
49+
Result<bool> add_regex(string_view pattern);
50+
/**
51+
* Sets the range for the delegate debug index filtering as [start, end).
52+
* Note that this function will flush the existing range.
53+
*
54+
* @param[in] start The start of the range for filtering.
55+
* @param[in] end The end of the range for filtering.
56+
*
57+
* @return A Result<bool> indicating the success or failure of setting the
58+
* range.
59+
* - True if the range is successfully set.
60+
* - An error code if an error occurs.
61+
*/
62+
Result<bool> set_debug_handle_range(size_t start, size_t end);
63+
64+
/**
65+
* Filters events based on the given name or delegate debug index.
66+
*
67+
* Note that everytime only one of either the name or delegate_debug_index
68+
* should be passed in.
69+
*
70+
* @param[in] name A pointer to a string representing the `name` of the
71+
* event. If `delegate_debug_index` is not set to kUnsetDebugHandle, `name`
72+
* should be set to nullptr.
73+
*
74+
* @param[in] delegate_debug_index A DebugHandle representing the debug index
75+
* of the delegate. If `name` is not nullptr, this should be set to
76+
* kUnsetDebugHandle.
77+
*
78+
* @return A Result<bool> indicating whether the event matches the filter
79+
* criteria.
80+
* - True if the event matches the filter, or filter is unset.
81+
* - False if the event does not match or is unknown.
82+
* - An error code if an error occurs during filtering.
83+
*/
84+
Result<bool> filter(
85+
const char* name,
86+
::executorch::runtime::DelegateDebugIntId delegate_debug_index) override;
87+
88+
/**
89+
* Returns the number of regex patterns in the filter.
90+
*/
91+
size_t get_n_regex() const;
92+
93+
private:
94+
std::vector<std::unique_ptr<re2::RE2>> regex_patterns_;
95+
size_t range_start_ = 0;
96+
size_t range_end_ = 0;
97+
Result<bool> filter_name_(const char* name);
98+
Result<bool> filter_delegate_debug_index_(
99+
::executorch::runtime::DelegateDebugIntId delegate_debug_index);
100+
};
101+
102+
} // namespace executorch::etdump

Diff for: devtools/etdump/etdump_flatcc.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <executorch/devtools/etdump/etdump_schema_flatcc_builder.h>
1616
#include <executorch/devtools/etdump/etdump_schema_flatcc_reader.h>
1717
#include <executorch/devtools/etdump/utils.h>
18+
#include <executorch/runtime/core/error.h>
1819
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1920
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
2021
#include <executorch/runtime/platform/assert.h>
@@ -28,6 +29,7 @@ using ::executorch::runtime::ChainID;
2829
using ::executorch::runtime::DebugHandle;
2930
using ::executorch::runtime::DelegateDebugIdType;
3031
using ::executorch::runtime::DelegateDebugIntId;
32+
using ::executorch::runtime::Error;
3133
using ::executorch::runtime::EValue;
3234
using ::executorch::runtime::EventTracerEntry;
3335
using ::executorch::runtime::kUnsetDelegateDebugIntId;

Diff for: devtools/etdump/etdump_flatcc.h

-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#pragma once
1010

1111
#include <cstdint>
12-
#include <memory>
1312

1413
#include <executorch/devtools/etdump/data_sinks/buffer_data_sink.h>
1514
#include <executorch/devtools/etdump/data_sinks/data_sink_base.h>

Diff for: devtools/etdump/targets.bzl

+21
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,27 @@ def define_common_targets():
101101
for aten_mode in get_aten_mode_options():
102102
aten_suffix = "_aten" if aten_mode else ""
103103

104+
runtime.cxx_library(
105+
name = "etdump_filter" + aten_suffix,
106+
srcs = [
107+
"etdump_filter.cpp",
108+
],
109+
exported_headers = [
110+
"etdump_filter.h",
111+
],
112+
deps = [
113+
"//executorch/runtime/platform:platform",
114+
],
115+
exported_deps = [
116+
"fbsource//third-party/re2:re2",
117+
"//executorch/runtime/core:event_tracer" + aten_suffix,
118+
],
119+
visibility = [
120+
"//executorch/...",
121+
"@EXECUTORCH_CLIENTS",
122+
],
123+
)
124+
104125
runtime.cxx_library(
105126
name = "etdump_flatcc" + aten_suffix,
106127
srcs = [

0 commit comments

Comments
 (0)