main.cpp
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
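
// Minimal command-line driver for the ExecuTorch LLaMA example: parses the
// flags below, optionally configures the CPU threadpool, and runs generation.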
#include <gflags/gflags.h>

#include <executorch/examples/models/llama/runner/runner.h>
#include <executorch/runtime/platform/log.h>

#if defined(ET_USE_THREADPOOL)
#include <executorch/extension/threadpool/cpuinfo_utils.h>
#include <executorch/extension/threadpool/threadpool.h>
#endif

DEFINE_string(
    model_path,
    "llama2.pte",
    "Model serialized in flatbuffer format.");
DEFINE_string(
    tokenizer_path,
    "tokenizer.bin",
    "Path to the tokenizer file.");
DEFINE_string(prompt, "The answer to the ultimate question is", "Prompt.");
DEFINE_double(
    temperature,
    0.8f,
    "Sampling temperature. 0 = greedy argmax sampling (deterministic); lower values are more deterministic. Default: 0.8.");
DEFINE_int32(
    seq_len,
    128,
    "Total number of tokens to generate (prompt + output). If the number of prompt tokens plus seq_len exceeds the model's max_seq_len, the output is truncated to max_seq_len tokens.");
DEFINE_int32(
    cpu_threads,
    -1,
    "Number of CPU threads for inference. Defaults to -1, which uses a heuristic to derive the number of performant cores for the specific device.");
DEFINE_bool(warmup, false, "Whether to run a warmup pass before generation.");
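
// Example invocation (flag values are illustrative; the binary name depends
// on your build setup):
//
//   ./llama_main \
//       --model_path=llama2.pte \
//       --tokenizer_path=tokenizer.bin \
//       --prompt="The answer to the ultimate question is" \
//       --seq_len=128 --temperature=0.8 --cpu_threads=4 --warmup=true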

int main(int argc, char** argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);

  // Read the runtime configuration from the parsed command-line flags. The
  // Runner created below loads the program and tokenizer from these paths.
  const char* model_path = FLAGS_model_path.c_str();
  const char* tokenizer_path = FLAGS_tokenizer_path.c_str();
  const char* prompt = FLAGS_prompt.c_str();
  float temperature = FLAGS_temperature;
  int32_t seq_len = FLAGS_seq_len;
  int32_t cpu_threads = FLAGS_cpu_threads;
  bool warmup = FLAGS_warmup;

#if defined(ET_USE_THREADPOOL)
  // Size the threadpool: use the requested thread count, or detect the number
  // of performant cores when cpu_threads is -1.
  uint32_t num_performant_cores = cpu_threads == -1
      ? ::executorch::extension::cpuinfo::get_num_performant_cores()
      : static_cast<uint32_t>(cpu_threads);
  ET_LOG(
      Info, "Resetting threadpool with num threads = %u", num_performant_cores);
  if (num_performant_cores > 0) {
    ::executorch::extension::threadpool::get_threadpool()
        ->_unsafe_reset_threadpool(num_performant_cores);
  }
#endif

  // Create the llama runner from the model and tokenizer paths.
  // @lint-ignore CLANGTIDY facebook-hte-Deprecated
  example::Runner runner(model_path, tokenizer_path);

  if (warmup) {
    // @lint-ignore CLANGTIDY facebook-hte-Deprecated
    runner.warmup(prompt, /*max_new_tokens=*/seq_len);
  }

  // Generate tokens from the prompt.
  executorch::extension::llm::GenerationConfig config{
      .seq_len = seq_len, .temperature = temperature};
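  // Any GenerationConfig fields not listed in the designated initializer above
  // retain their default values.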
  // @lint-ignore CLANGTIDY facebook-hte-Deprecated
  runner.generate(prompt, config);

  return 0;
}