Commit 262dfc0

dbort authored and facebook-github-bot committed
Move examples/models/... out of the torch namespace (#5318)
Summary: Pull Request resolved: #5318

The code under examples/... is a proxy for user code, and users should never declare code under the `torch::` or `executorch::` namespaces. Move this code under the `example::` namespace to make it more clear that users should use their own namespaces when writing code like this.

I made one non-mechanical change in llama_tiktoken.h: we should always use `enum class` to avoid polluting the parent namespace, and enum values should follow UpperSnakeCase.

Tests: Should be a no-op; CI passes.

Prints no output:
```
find examples/models -type f | xargs grep "namespace torch"
find examples/models -type f | xargs grep "namespace executorch"
```

Build llava_main:
```
bash .ci/scripts/test_llava.sh
```

Reviewed By: larryliu0820

Differential Revision: D62591348

fbshipit-source-id: 8da987245bf777ced16000f681b54652939cef47
1 parent 08ecd73 commit 262dfc0
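To make the convention concrete before the per-file diffs, here is a minimal sketch (not part of this commit; the file and class names are hypothetical) of how example/user code is expected to look after the change: it declares its own `example::` namespace and refers to ExecuTorch types by fully-qualified name rather than reopening `torch::` or `executorch::`.

```cpp
// Hypothetical user-side header, e.g. my_runner.h (illustration only).
#pragma once

#include <memory>
#include <string>

#include <executorch/extension/module/module.h>

namespace example { // user-owned namespace; never torch:: or executorch::

class MyRunner {
 public:
  explicit MyRunner(const std::string& model_path);

 private:
  // Fully-qualified ExecuTorch type; no using-declarations in a header.
  std::unique_ptr<::executorch::extension::Module> module_;
};

} // namespace example
```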

20 files changed: +162, -136 lines

examples/demo-apps/apple_ios/LLaMA/LLaMARunner/LLaMARunner/Exported/LLaMARunner.mm

Lines changed: 4 additions & 4 deletions
@@ -21,15 +21,15 @@ @interface LLaMARunner ()<ExecuTorchLogSink>
 @end
 
 @implementation LLaMARunner {
-  std::unique_ptr<Runner> _runner;
+  std::unique_ptr<example::Runner> _runner;
 }
 
 - (instancetype)initWithModelPath:(NSString*)modelPath
                     tokenizerPath:(NSString*)tokenizerPath {
   self = [super init];
   if (self) {
     [ExecuTorchLog.sharedLog addSink:self];
-    _runner = std::make_unique<Runner>(
+    _runner = std::make_unique<example::Runner>(
         modelPath.UTF8String, tokenizerPath.UTF8String);
   }
   return self;
@@ -109,15 +109,15 @@ @interface LLaVARunner ()<ExecuTorchLogSink>
 @end
 
 @implementation LLaVARunner {
-  std::unique_ptr<LlavaRunner> _runner;
+  std::unique_ptr<example::LlavaRunner> _runner;
 }
 
 - (instancetype)initWithModelPath:(NSString*)modelPath
                     tokenizerPath:(NSString*)tokenizerPath {
   self = [super init];
   if (self) {
     [ExecuTorchLog.sharedLog addSink:self];
-    _runner = std::make_unique<LlavaRunner>(
+    _runner = std::make_unique<example::LlavaRunner>(
         modelPath.UTF8String, tokenizerPath.UTF8String);
   }
   return self;

examples/mediatek/executor_runner/mtk_llama_executor_runner.cpp

Lines changed: 1 addition & 1 deletion
@@ -316,7 +316,7 @@ std::unique_ptr<Tokenizer> load_tokenizer() {
   if (FLAGS_tokenizer_type == "bpe") {
     tokenizer = std::make_unique<torch::executor::BPETokenizer>();
   } else if (FLAGS_tokenizer_type == "tiktoken") {
-    tokenizer = torch::executor::get_tiktoken_for_llama();
+    tokenizer = example::get_tiktoken_for_llama();
   }
   ET_CHECK_MSG(
       tokenizer, "Invalid tokenizer type: %s", FLAGS_tokenizer_type.c_str());

examples/models/flamingo/cross_attention/cross_attention_mask.cpp

Lines changed: 6 additions & 2 deletions
@@ -11,7 +11,11 @@
 #include <algorithm>
 #include <string>
 
-namespace torch::executor {
+namespace example {
+
+using ::executorch::aten::ScalarType;
+using ::executorch::aten::Tensor;
+using ::executorch::aten::TensorImpl;
 
 // Fowrward declaration needed for ARM compilers.
 int32_t safe_size_t_to_sizes_type(size_t value);
@@ -166,4 +170,4 @@ std::vector<executorch::extension::TensorPtr> cross_attention_mask(
   return cross_attention_masks;
 }
 
-} // namespace torch::executor
+} // namespace example

examples/models/flamingo/cross_attention/cross_attention_mask.h

Lines changed: 3 additions & 5 deletions
@@ -13,8 +13,7 @@
 #include <executorch/extension/tensor/tensor.h>
 #include <executorch/runtime/core/exec_aten/exec_aten.h>
 
-namespace torch {
-namespace executor {
+namespace example {
 
 /**
  * Computes the cross-attention mask for text + image inputs. Text tokens that
@@ -61,11 +60,10 @@ namespace executor {
  */
 std::vector<::executorch::extension::TensorPtr> cross_attention_mask(
     const std::vector<int>& tokens,
-    const std::vector<Tensor>& images,
+    const std::vector<::executorch::aten::Tensor>& images,
     size_t tile_size,
    size_t patch_size,
    int image_token_id,
    std::vector<std::vector<int>>& out);

-} // namespace executor
-} // namespace torch
+} // namespace example

examples/models/flamingo/cross_attention/cross_attention_mask_test.cpp

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ TEST(CrossAttentxnMaskTest, TestCrossAttentionMask) {
 
   std::vector<Tensor> images = {a, b, c};
   std::vector<std::vector<int>> mask_data;
-  auto output_masks = torch::executor::cross_attention_mask(
+  auto output_masks = example::cross_attention_mask(
       tokens,
       images,
       /*tile_size=*/1,

examples/models/llama2/main.cpp

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@ int32_t main(int32_t argc, char** argv) {
   }
 #endif
   // create llama runner
-  ::torch::executor::Runner runner(model_path, tokenizer_path, temperature);
+  example::Runner runner(model_path, tokenizer_path, temperature);
 
   // generate
   runner.generate(prompt, seq_len);

examples/models/llama2/runner/runner.cpp

Lines changed: 24 additions & 17 deletions
@@ -18,7 +18,14 @@
 #include <executorch/examples/models/llama2/tokenizer/llama_tiktoken.h>
 #include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
 
-namespace torch::executor {
+namespace example {
+
+using ::executorch::extension::Module;
+using ::executorch::runtime::Error;
+using ::executorch::runtime::Result;
+
+namespace llm = ::executorch::extension::llm;
+
 namespace {
 static constexpr auto kAppendEosToPrompt = "append_eos_to_prompt";
 static constexpr auto kEnableDynamicShape = "enable_dynamic_shape";
@@ -80,7 +87,7 @@ Error Runner::load() {
         "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
         tokenizer_path_.c_str());
     tokenizer_.reset();
-    tokenizer_ = std::make_unique<BPETokenizer>();
+    tokenizer_ = std::make_unique<llm::BPETokenizer>();
     tokenizer_->load(tokenizer_path_);
   }
 
@@ -119,17 +126,17 @@ Error Runner::load() {
       ET_LOG(Info, "eos_id = %" PRId64, value);
     }
   }
-  text_decoder_runner_ = std::make_unique<TextDecoderRunner>(
+  text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
      module_.get(),
      metadata_.at(kUseKVCache),
      metadata_.at(kVocabSize),
      temperature_);
-  text_prefiller_ = std::make_unique<TextPrefiller>(
+  text_prefiller_ = std::make_unique<llm::TextPrefiller>(
      text_decoder_runner_.get(),
      metadata_.at(kUseKVCache),
      metadata_.at(kEnableDynamicShape));

-  text_token_generator_ = std::make_unique<TextTokenGenerator>(
+  text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
      tokenizer_.get(),
      text_decoder_runner_.get(),
      metadata_.at(kUseKVCache),
@@ -143,26 +150,26 @@ Error Runner::generate(
     const std::string& prompt,
     int32_t seq_len,
     std::function<void(const std::string&)> token_callback,
-    std::function<void(const Stats&)> stats_callback,
+    std::function<void(const llm::Stats&)> stats_callback,
     bool echo) {
   // Prepare the inputs.
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
   if (!is_loaded()) {
-    stats_.model_load_start_ms = util::time_in_ms();
+    stats_.model_load_start_ms = llm::time_in_ms();
     ET_CHECK_OK_OR_RETURN_ERROR(load());
-    stats_.model_load_end_ms = util::time_in_ms();
+    stats_.model_load_end_ms = llm::time_in_ms();
   }
 
   ET_LOG(
       Info,
       "RSS after loading model: %f MiB (0 if unsupported)",
-      util::get_rss_bytes() / 1024.0 / 1024.0);
+      llm::get_rss_bytes() / 1024.0 / 1024.0);
 
   // Wrap the token_callback with print function
   std::function<void(const std::string&)> wrapped_callback =
       [token_callback](const std::string& piece) {
-        util::safe_printf(piece.c_str());
+        llm::safe_printf(piece.c_str());
         fflush(stdout);
         if (token_callback) {
           token_callback(piece);
@@ -171,7 +178,7 @@ Error Runner::generate(
   // First token time only measures the time it takes to encode the prompt and
   // return a response token.
 
-  stats_.inference_start_ms = util::time_in_ms();
+  stats_.inference_start_ms = llm::time_in_ms();
   shouldStop_ = false;
 
   // Set the sequence length to the max seq length if not provided
@@ -214,8 +221,8 @@ Error Runner::generate(
   }
   int64_t pos = 0;
   auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
-  stats_.first_token_ms = util::time_in_ms();
-  stats_.prompt_eval_end_ms = util::time_in_ms();
+  stats_.first_token_ms = llm::time_in_ms();
+  stats_.prompt_eval_end_ms = llm::time_in_ms();
   ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
   uint64_t cur_token = prefill_res.get();
 
@@ -224,19 +231,19 @@ Error Runner::generate(
   ET_LOG(
       Info,
       "RSS after prompt prefill: %f MiB (0 if unsupported)",
-      util::get_rss_bytes() / 1024.0 / 1024.0);
+      llm::get_rss_bytes() / 1024.0 / 1024.0);
 
   // start the main loop
   prompt_tokens.push_back(cur_token);
   int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
       prompt_tokens, num_prompt_tokens, seq_len, wrapped_callback));
 
-  stats_.inference_end_ms = util::time_in_ms();
+  stats_.inference_end_ms = llm::time_in_ms();
   printf("\n");
   ET_LOG(
       Info,
       "RSS after finishing text generation: %f MiB (0 if unsupported)",
-      util::get_rss_bytes() / 1024.0 / 1024.0);
+      llm::get_rss_bytes() / 1024.0 / 1024.0);
 
   if (num_prompt_tokens + num_generated_tokens == seq_len) {
     ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
@@ -259,4 +266,4 @@ void Runner::stop() {
     ET_LOG(Error, "Token generator is not loaded, cannot stop");
   }
 }
-} // namespace torch::executor
+} // namespace example
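The runner.cpp hunk above shows the .cpp-side idiom this commit adopts: a file-local namespace alias (`namespace llm = ::executorch::extension::llm;`) plus a few targeted using-declarations keep call sites short without reopening a library namespace. A minimal standalone sketch of the same idiom follows; the helper function name is hypothetical, and the headers are the ones already referenced in this diff.

```cpp
// Hypothetical user-side .cpp file illustrating the alias pattern above.
#include <memory>

#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
#include <executorch/extension/llm/tokenizer/tokenizer.h>

namespace example {

// Alias is local to this translation unit; nothing leaks to includers.
namespace llm = ::executorch::extension::llm;

std::unique_ptr<llm::Tokenizer> make_bpe_tokenizer() {
  // Call sites stay short (llm::BPETokenizer) while the code itself
  // remains in the user-owned example:: namespace.
  return std::make_unique<llm::BPETokenizer>();
}

} // namespace example
```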

examples/models/llama2/runner/runner.h

Lines changed: 14 additions & 12 deletions
@@ -24,8 +24,7 @@
 #include <executorch/extension/llm/tokenizer/tokenizer.h>
 #include <executorch/extension/module/module.h>
 
-namespace torch::executor {
-using Stats = ::executorch::llm::Stats;
+namespace example {
 
 class Runner {
  public:
@@ -35,12 +34,13 @@ class Runner {
       const float temperature = 0.8f);
 
   bool is_loaded() const;
-  Error load();
-  Error generate(
+  ::executorch::runtime::Error load();
+  ::executorch::runtime::Error generate(
       const std::string& prompt,
       int32_t seq_len = 128,
       std::function<void(const std::string&)> token_callback = {},
-      std::function<void(const Stats&)> stats_callback = {},
+      std::function<void(const ::executorch::extension::llm::Stats&)>
+          stats_callback = {},
       bool echo = true);
   void stop();
 
@@ -49,16 +49,18 @@ class Runner {
   bool shouldStop_{false};
 
   // model
-  std::unique_ptr<Module> module_;
+  std::unique_ptr<::executorch::extension::Module> module_;
   std::string tokenizer_path_;
-  std::unique_ptr<Tokenizer> tokenizer_;
+  std::unique_ptr<::executorch::extension::llm::Tokenizer> tokenizer_;
   std::unordered_map<std::string, int64_t> metadata_;
-  std::unique_ptr<TextDecoderRunner> text_decoder_runner_;
-  std::unique_ptr<TextPrefiller> text_prefiller_;
-  std::unique_ptr<TextTokenGenerator> text_token_generator_;
+  std::unique_ptr<::executorch::extension::llm::TextDecoderRunner>
+      text_decoder_runner_;
+  std::unique_ptr<::executorch::extension::llm::TextPrefiller> text_prefiller_;
+  std::unique_ptr<::executorch::extension::llm::TextTokenGenerator>
+      text_token_generator_;
 
   // stats
-  Stats stats_;
+  ::executorch::extension::llm::Stats stats_;
 };
 
-} // namespace torch::executor
+} // namespace example
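In the header, the removed `using Stats = ::executorch::llm::Stats;` alias would have leaked into every file that includes runner.h, so the public signatures now spell out the fully-qualified type instead. A hedged caller sketch under the new signature follows; the include path and wrapper function are assumptions, and the only `Stats` fields referenced are ones visible elsewhere in this diff.

```cpp
// Hypothetical caller of example::Runner after this change.
#include <string>

#include <executorch/examples/models/llama2/runner/runner.h> // assumed path

void run_once(example::Runner& runner, const std::string& prompt) {
  runner.generate(
      prompt,
      /*seq_len=*/128,
      /*token_callback=*/{},
      /*stats_callback=*/
      [](const ::executorch::extension::llm::Stats& stats) {
        // Timing fields such as inference_start_ms / inference_end_ms
        // (seen in runner.cpp above) can be inspected here.
        (void)stats;
      });
}
```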

examples/models/llama2/tokenizer/llama_tiktoken.cpp

Lines changed: 6 additions & 5 deletions
@@ -8,8 +8,10 @@
 
 #include <executorch/examples/models/llama2/tokenizer/llama_tiktoken.h>
 
-namespace torch {
-namespace executor {
+namespace example {
+
+using ::executorch::extension::llm::Tiktoken;
+
 namespace {
 static constexpr int32_t kSpecialTokensSize = 256;
 static constexpr size_t kBOSTokenIndex = 0;
@@ -72,7 +74,7 @@ _get_multimodal_special_tokens() {
 
 std::unique_ptr<std::vector<std::string>> _get_special_tokens(Version version) {
   switch (version) {
-    case MULTIMODAL:
+    case Version::Multimodal:
       return _get_multimodal_special_tokens();
     default:
       return _get_default_special_tokens();
@@ -86,5 +88,4 @@ std::unique_ptr<Tiktoken> get_tiktoken_for_llama(Version version) {
       _get_special_tokens(version), kBOSTokenIndex, kEOSTokenIndex);
 }
 
-} // namespace executor
-} // namespace torch
+} // namespace example

examples/models/llama2/tokenizer/llama_tiktoken.h

Lines changed: 7 additions & 8 deletions
@@ -10,15 +10,14 @@
 
 #include <executorch/extension/llm/tokenizer/tiktoken.h>
 
-namespace torch {
-namespace executor {
+namespace example {
 
-enum Version {
-  DEFAULT,
-  MULTIMODAL,
+enum class Version {
+  Default,
+  Multimodal,
 };
 
-std::unique_ptr<Tiktoken> get_tiktoken_for_llama(Version version = DEFAULT);
+std::unique_ptr<::executorch::extension::llm::Tiktoken> get_tiktoken_for_llama(
+    Version version = Version::Default);
 
-} // namespace executor
-} // namespace torch
+} // namespace example
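Because `Version` is now an `enum class`, its enumerators no longer spill into the enclosing namespace and callers must qualify them. A small hypothetical call-site sketch, based only on the declarations shown above:

```cpp
// Illustration of call sites after the enum class change (hypothetical file).
#include <executorch/examples/models/llama2/tokenizer/llama_tiktoken.h>

void load_tokenizers() {
  // Previously the unscoped enum exposed bare MULTIMODAL/DEFAULT names in
  // the surrounding namespace. With enum class, values are scoped:
  auto default_tok = example::get_tiktoken_for_llama(); // Version::Default
  auto mm_tok =
      example::get_tiktoken_for_llama(example::Version::Multimodal);
  (void)default_tok;
  (void)mm_tok;
}
```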

examples/models/llama2/tokenizer/test/test_tiktoken.cpp

Lines changed: 10 additions & 7 deletions
@@ -7,20 +7,25 @@
  */
 
 #include <executorch/examples/models/llama2/tokenizer/llama_tiktoken.h>
+
+#include <vector>
+
 #include <executorch/runtime/platform/runtime.h>
+
 #include <gtest/gtest.h>
-#include <vector>
 
 using namespace ::testing;
 
-namespace torch {
-namespace executor {
+using ::example::Version;
+using ::executorch::extension::llm::Tokenizer;
+using ::executorch::runtime::Error;
+using ::executorch::runtime::Result;
 
 class MultimodalTiktokenV5ExtensionTest : public Test {
  public:
  void SetUp() override {
-    torch::executor::runtime_init();
-    tokenizer_ = get_tiktoken_for_llama(MULTIMODAL);
+    executorch::runtime::runtime_init();
+    tokenizer_ = get_tiktoken_for_llama(Version::Multimodal);
    modelPath_ = std::getenv("RESOURCES_PATH") +
        std::string("/test_tiktoken_tokenizer.model");
  }
@@ -79,5 +84,3 @@ TEST_F(MultimodalTiktokenV5ExtensionTest, TokenizerDecodeCorrectly) {
     EXPECT_EQ(out.get(), expected[i]);
   }
 }
-} // namespace executor
-} // namespace torch

examples/models/llava/main.cpp

Lines changed: 1 addition & 1 deletion
@@ -80,7 +80,7 @@ int32_t main(int32_t argc, char** argv) {
   }
 #endif
   // create llama runner
-  torch::executor::LlavaRunner runner(model_path, tokenizer_path, temperature);
+  example::LlavaRunner runner(model_path, tokenizer_path, temperature);
 
   // read image and resize the longest edge to 336
   std::vector<uint8_t> image_data;
