
Commit b4ec471

ggerganov authored and tinglou committed
tts : add OuteTTS support (ggml-org#10784)
* server : add "tokens" output ggml-ci
* server : output embeddings for all tokens when pooling = none ggml-ci
* server : be explicit about the pooling type in the tests ggml-ci
* server : do not normalize embeddings when there is no pooling ggml-ci
* llama : add OuteTTS support (wip)
* wip
* extract features
* first conv
* group norm
* resnet conv
* resnet
* attn
* pos net
* layer norm
* convnext
* head
* hann window
* fix n_embd + remove llama.cpp hacks
* compute hann window
* fft
* spectrum processing
* clean-up
* tts : receive input text and generate codes
* clip : fix new conv name
* tts : minor fix
* tts : add header + minor fixes ggml-ci
* tts : add matchematical constant ggml-ci
* tts : fix sampling + cut initial noise
* tts : fixes
* tts : update default samplers ggml-ci
* tts : text pre-processing
* tts : outetts-voc -> wavtokenizer-dec
* tts : remove hardcoded constants ggml-ci
* tts : fix tensor shapes
* llama : refactor wavtokenizer tensors ggml-ci
* cont ggml-ci
* cont [no ci]
* llama : update WavTokenizer to non-causal attn
* llama : handle no-vocab detokenization
* tts : add Python example for OuteTTS (wip)
* tts : extend python example to generate spectrogram ggml-ci
* server : fix rebase artifacts
* tts : enable "return_tokens" in Python example ggml-ci
* tts : minor fixes
* common : support HF download for vocoder
1 parent 6dcee72 commit b4ec471

Showing 19 changed files with 2,507 additions and 530 deletions.
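Several of the vocoder steps listed in the commit message (Hann window, FFT, spectrum processing) are standard DSP building blocks. As a point of reference for the first of them, here is a minimal sketch of a periodic Hann window; the function name and exact layout are illustrative and not necessarily what tts.cpp implements.

#include <cmath>
#include <vector>

// Minimal sketch (not necessarily the tts.cpp implementation): a periodic Hann
// window w[i] = 0.5 * (1 - cos(2*pi*i / N)), applied to each frame of the
// generated signal before the FFT when building the spectrogram.
static std::vector<float> hann_window(int n) {
    const float pi = 3.14159265358979f;
    std::vector<float> w(n);
    for (int i = 0; i < n; ++i) {
        w[i] = 0.5f * (1.0f - std::cos(2.0f * pi * (float) i / (float) n));
    }
    return w;
}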

common/arg.cpp

+44 −16

@@ -119,29 +119,33 @@ std::string common_arg::to_string() {
 // utils
 //
 
-static void common_params_handle_model_default(common_params & params) {
-    if (!params.hf_repo.empty()) {
+static void common_params_handle_model_default(
+        std::string & model,
+        std::string & model_url,
+        std::string & hf_repo,
+        std::string & hf_file) {
+    if (!hf_repo.empty()) {
         // short-hand to avoid specifying --hf-file -> default it to --model
-        if (params.hf_file.empty()) {
-            if (params.model.empty()) {
+        if (hf_file.empty()) {
+            if (model.empty()) {
                 throw std::invalid_argument("error: --hf-repo requires either --hf-file or --model\n");
             }
-            params.hf_file = params.model;
-        } else if (params.model.empty()) {
+            hf_file = model;
+        } else if (model.empty()) {
             // this is to avoid different repo having same file name, or same file name in different subdirs
-            std::string filename = params.hf_repo + "_" + params.hf_file;
+            std::string filename = hf_repo + "_" + hf_file;
             // to make sure we don't have any slashes in the filename
             string_replace_all(filename, "/", "_");
-            params.model = fs_get_cache_file(filename);
+            model = fs_get_cache_file(filename);
         }
-    } else if (!params.model_url.empty()) {
-        if (params.model.empty()) {
-            auto f = string_split<std::string>(params.model_url, '#').front();
+    } else if (!model_url.empty()) {
+        if (model.empty()) {
+            auto f = string_split<std::string>(model_url, '#').front();
             f = string_split<std::string>(f, '?').front();
-            params.model = fs_get_cache_file(string_split<std::string>(f, '/').back());
+            model = fs_get_cache_file(string_split<std::string>(f, '/').back());
         }
-    } else if (params.model.empty()) {
-        params.model = DEFAULT_MODEL_PATH;
+    } else if (model.empty()) {
+        model = DEFAULT_MODEL_PATH;
     }
 }
 
@@ -276,7 +280,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
         throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
     }
 
-    common_params_handle_model_default(params);
+    // TODO: refactor model params in a common struct
+    common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file);
+    common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file);
 
     if (params.escape) {
        string_process_escapes(params.prompt);
@@ -842,7 +848,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_sparam());
     add_opt(common_arg(
-        {"--sampling-seq"}, "SEQUENCE",
+        {"--sampling-seq", "--sampler-seq"}, "SEQUENCE",
         string_format("simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str()),
         [](common_params & params, const std::string & value) {
             params.sampling.samplers = common_sampler_types_from_chars(value);
@@ -1581,6 +1587,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.hf_file = value;
         }
     ).set_env("LLAMA_ARG_HF_FILE"));
+    add_opt(common_arg(
+        {"-hfrv", "--hf-repo-v"}, "REPO",
+        "Hugging Face model repository for the vocoder model (default: unused)",
+        [](common_params & params, const std::string & value) {
+            params.vocoder.hf_repo = value;
+        }
+    ).set_env("LLAMA_ARG_HF_REPO_V"));
+    add_opt(common_arg(
+        {"-hffv", "--hf-file-v"}, "FILE",
+        "Hugging Face model file for the vocoder model (default: unused)",
+        [](common_params & params, const std::string & value) {
+            params.vocoder.hf_file = value;
+        }
+    ).set_env("LLAMA_ARG_HF_FILE_V"));
     add_opt(common_arg(
         {"-hft", "--hf-token"}, "TOKEN",
         "Hugging Face access token (default: value from HF_TOKEN environment variable)",
@@ -2178,5 +2198,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODEL_DRAFT"));
 
+    add_opt(common_arg(
+        {"-mv", "--model-vocoder"}, "FNAME",
+        "vocoder model for audio generation (default: unused)",
+        [](common_params & params, const std::string & value) {
+            params.vocoder.model = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
+
     return ctx_arg;
 }
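Because the helper now takes the model fields by reference, the same default-resolution logic applies to the vocoder spec as well. A quick worked example of that logic with hypothetical repo/file values; string_replace_all and fs_get_cache_file are the common.h helpers that already appear in the diff above:

#include "common.h"  // string_replace_all, fs_get_cache_file

// Illustrative only (hypothetical --hf-repo-v / --hf-file-v values): how the
// helper above derives a cache path for the vocoder when no -mv path is given.
static std::string vocoder_cache_path_example() {
    std::string hf_repo = "OuteAI/WavTokenizer-GGUF"; // hypothetical --hf-repo-v
    std::string hf_file = "wavtokenizer-dec.gguf";    // hypothetical --hf-file-v

    std::string filename = hf_repo + "_" + hf_file;   // "OuteAI/WavTokenizer-GGUF_wavtokenizer-dec.gguf"
    string_replace_all(filename, "/", "_");           // no slashes allowed in a cache file name
    return fs_get_cache_file(filename);               // resolved under the llama.cpp cache directory
}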

common/common.cpp

+4 −3

@@ -1095,7 +1095,7 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p
 #define CURL_MAX_RETRY 3
 #define CURL_RETRY_DELAY_SECONDS 2
 
-static bool curl_perform_with_retry(const std::string& url, CURL* curl, int max_attempts, int retry_delay_seconds) {
+static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds) {
     int remaining_attempts = max_attempts;
 
     while (remaining_attempts > 0) {
@@ -1119,7 +1119,6 @@ static bool curl_perform_with_retry(const std::string& url, CURL* curl, int max_
 }
 
 static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
-
     // Initialize libcurl
     std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
     if (!curl) {
@@ -1192,11 +1191,13 @@ static bool common_download_file(const std::string & url, const std::string & pa
         std::string etag;
         std::string last_modified;
     };
+
     common_load_model_from_url_headers headers;
+
     {
         typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
         auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
-            common_load_model_from_url_headers *headers = (common_load_model_from_url_headers *) userdata;
+            common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
 
             static std::regex header_regex("([^:]+): (.*)\r\n");
             static std::regex etag_regex("ETag", std::regex_constants::icase);

common/common.h

+12 −1

@@ -80,6 +80,7 @@ enum llama_example {
     LLAMA_EXAMPLE_LLAVA,
     LLAMA_EXAMPLE_LOOKUP,
     LLAMA_EXAMPLE_PARALLEL,
+    LLAMA_EXAMPLE_TTS,
 
     LLAMA_EXAMPLE_COUNT,
 };
@@ -159,6 +160,7 @@ struct common_params_sampling {
 
 struct common_params_speculative {
     std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
+
     int32_t n_ctx = 0; // draft context size
     int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
     int32_t n_min = 5; // minimum number of draft tokens to use for speculative decoding
@@ -172,6 +174,14 @@ struct common_params_speculative {
     std::string model = ""; // draft model for speculative decoding // NOLINT
 };
 
+struct common_params_vocoder {
+    std::string hf_repo = ""; // HF repo // NOLINT
+    std::string hf_file = ""; // HF file // NOLINT
+
+    std::string model     = ""; // model path // NOLINT
+    std::string model_url = ""; // model url to download // NOLINT
+};
+
 struct common_params {
     int32_t n_predict = -1; // new tokens to predict
     int32_t n_ctx = 4096; // context size
@@ -214,8 +224,9 @@ struct common_params {
     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
     enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
 
-    struct common_params_sampling sampling;
+    struct common_params_sampling    sampling;
     struct common_params_speculative speculative;
+    struct common_params_vocoder     vocoder;
 
     std::string model = ""; // model path // NOLINT
     std::string model_alias = ""; // model alias // NOLINT
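The new common_params_vocoder mirrors the main-model fields (model, model_url, hf_repo, hf_file), which is what lets the TTS example reuse the existing initialization path twice. A hedged sketch of that pattern, assuming the common_init_from_params()/common_init_result API from common.h at this commit; the actual tts.cpp may structure this differently:

#include "common.h"

// Hedged sketch: initialize the text-to-codes model, then the vocoder, by
// swapping the vocoder fields into the same common_params and calling
// common_init_from_params() again.
static bool tts_init_models(common_params & params,
                            common_init_result & ttc,    // OuteTTS text-to-codes model
                            common_init_result & cts) {  // WavTokenizer codes-to-speech vocoder
    ttc = common_init_from_params(params);               // uses params.model / params.hf_repo / ...
    if (ttc.model == nullptr || ttc.context == nullptr) {
        return false;
    }

    params.model     = params.vocoder.model;             // the commit's TODO: refactor into a common struct
    params.model_url = params.vocoder.model_url;
    params.hf_repo   = params.vocoder.hf_repo;
    params.hf_file   = params.vocoder.hf_file;

    cts = common_init_from_params(params);
    return cts.model != nullptr && cts.context != nullptr;
}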

convert_hf_to_gguf.py

+52 −7

@@ -221,17 +221,17 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_context_length(n_ctx)
         logger.info(f"gguf: context length = {n_ctx}")
 
-        n_embd = self.find_hparam(["hidden_size", "n_embd"])
-        self.gguf_writer.add_embedding_length(n_embd)
-        logger.info(f"gguf: embedding length = {n_embd}")
+        if (n_embd := self.find_hparam(["hidden_size", "n_embd"], optional=True)) is not None:
+            self.gguf_writer.add_embedding_length(n_embd)
+            logger.info(f"gguf: embedding length = {n_embd}")
 
         if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None:
             self.gguf_writer.add_feed_forward_length(n_ff)
             logger.info(f"gguf: feed forward length = {n_ff}")
 
-        n_head = self.find_hparam(["num_attention_heads", "n_head"])
-        self.gguf_writer.add_head_count(n_head)
-        logger.info(f"gguf: head count = {n_head}")
+        if (n_head := self.find_hparam(["num_attention_heads", "n_head"], optional=True)) is not None:
+            self.gguf_writer.add_head_count(n_head)
+            logger.info(f"gguf: head count = {n_head}")
 
         if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
             self.gguf_writer.add_head_count_kv(n_head_kv)
@@ -296,7 +296,9 @@ def prepare_tensors(self):
                     break
 
             for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
-                data = data_torch.squeeze().numpy()
+                # TODO: why do we squeeze here?
+                # data = data_torch.squeeze().numpy()
+                data = data_torch.numpy()
 
                 # if data ends up empty, it means data_torch was a scalar tensor -> restore
                 if len(data.shape) == 0:
@@ -324,6 +326,8 @@ def prepare_tensors(self):
                             gguf.MODEL_TENSOR.TIME_MIX_W2,
                             gguf.MODEL_TENSOR.TIME_MIX_DECAY_W1,
                             gguf.MODEL_TENSOR.TIME_MIX_DECAY_W2,
+                            gguf.MODEL_TENSOR.POSNET_NORM1,
+                            gguf.MODEL_TENSOR.POSNET_NORM2,
                         )
                     )
                     or not new_name.endswith(".weight")
@@ -689,6 +693,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         return res
     # Marker: End get_vocab_base_pre
 
+    def _set_vocab_none(self) -> None:
+        self.gguf_writer.add_tokenizer_model("none")
+
     def _set_vocab_gpt2(self) -> None:
         tokens, toktypes, tokpre = self.get_vocab_base()
         self.gguf_writer.add_tokenizer_model("gpt2")
@@ -2027,6 +2034,44 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
                 yield name, data
 
 
+@Model.register("WavTokenizerDec")
+class WavTokenizerDecModel(Model):
+    model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        if \
+                name.endswith("codebook.cluster_size") or \
+                name.endswith("codebook.embed_avg") or \
+                name.endswith("codebook.inited"):
+            logger.debug(f"Skipping {name!r}")
+            return []
+
+        logger.info(f"{self.map_tensor_name(name)} -> {data_torch.shape}")
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def set_vocab(self):
+        self._set_vocab_none()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_vocab_size         (self.hparams["vocab_size"])
+        self.gguf_writer.add_features_length    (self.hparams["n_embd_features"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["n_ff"])
+        self.gguf_writer.add_group_norm_eps     (self.hparams["group_norm_epsilon"])
+        self.gguf_writer.add_group_norm_groups  (self.hparams["group_norm_groups"])
+
+        self.gguf_writer.add_posnet_embedding_length(self.hparams["posnet"]["n_embd"])
+        self.gguf_writer.add_posnet_block_count     (self.hparams["posnet"]["n_layer"])
+
+        self.gguf_writer.add_convnext_embedding_length(self.hparams["convnext"]["n_embd"])
+        self.gguf_writer.add_convnext_block_count     (self.hparams["convnext"]["n_layer"])
+
+        self.gguf_writer.add_causal_attention(False)
+
+
 @Model.register("Qwen2MoeForCausalLM")
 class Qwen2MoeModel(Model):
     model_arch = gguf.MODEL_ARCH.QWEN2MOE
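The converter above stores the WavTokenizer hyper-parameters as GGUF key/value pairs (vocab size, feature length, feed-forward length, group-norm settings, PosNet/ConvNeXt sizes, non-causal attention). As a hedged illustration of what ends up in the file, here is a sketch that reads one of them back with the gguf C API from ggml; the key string follows the usual "{arch}.{name}" convention and is an assumption, as is the file name:

#include <cstdio>

#include "ggml.h" // the gguf_* API is declared here at this point in the tree

// Hedged sketch: inspect a converted wavtokenizer-dec GGUF file. The key name
// below is an assumption based on the "{arch}.{name}" convention used by
// gguf-py; check gguf-py/gguf/constants.py for the authoritative names.
int main() {
    gguf_init_params iparams;
    iparams.no_alloc = true;     // only read metadata, do not allocate tensor data
    iparams.ctx      = nullptr;

    gguf_context * ctx = gguf_init_from_file("wavtokenizer-dec.gguf", iparams);
    if (ctx == nullptr) {
        fprintf(stderr, "failed to open gguf file\n");
        return 1;
    }

    const int kid = gguf_find_key(ctx, "wavtokenizer-dec.posnet.embedding_length");
    if (kid >= 0) {
        printf("posnet embedding length = %u\n", gguf_get_val_u32(ctx, kid));
    }

    gguf_free(ctx);
    return 0;
}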

examples/CMakeLists.txt

+1 −0

@@ -51,6 +51,7 @@ else()
     add_subdirectory(speculative)
     add_subdirectory(speculative-simple)
     add_subdirectory(tokenize)
+    add_subdirectory(tts)
     add_subdirectory(gen-docs)
     if (NOT GGML_BACKEND_DL)
         # these examples use the backends directly and cannot be built with dynamic loading

examples/llava/clip.cpp

+3 −3

@@ -896,7 +896,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             mlp_3 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_3, 1, 0, 2, 3));
             mlp_3 = ggml_reshape_4d(ctx0, mlp_3, n_patch, n_patch, mlp_3->ne[1], mlp_3->ne[2]);
             // stride = 1, padding = 1, bias is nullptr
-            block_1 = ggml_conv_depthwise_2d(ctx0, model.mm_model_block_1_block_0_0_w, mlp_3, 1, 1, 1, 1, 1, 1);
+            block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_1_block_0_0_w, mlp_3, 1, 1, 1, 1, 1, 1);
 
             // layer norm
             // // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
@@ -944,7 +944,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             // block_2
             {
                 // stride = 2
-                block_1 = ggml_conv_depthwise_2d(ctx0, model.mm_model_block_2_block_0_0_w, block_1, 2, 2, 1, 1, 1, 1);
+                block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_2_block_0_0_w, block_1, 2, 2, 1, 1, 1, 1);
 
                 // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1]
                 // layer norm
@@ -1005,7 +1005,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             // mlp_2 ne [24, 24, 2048, 1]
             mlp_2 = ggml_pool_2d(ctx0, mlp_2, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0);
             // weight ne = [3, 3, 2048, 1]
-            struct ggml_tensor * peg_0 = ggml_conv_depthwise_2d(ctx0, model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1);
+            struct ggml_tensor * peg_0 = ggml_conv_2d_dw(ctx0, model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1);
             peg_0 = ggml_cont(ctx0, ggml_permute(ctx0, peg_0, 1, 2, 0, 3));
             peg_0 = ggml_add(ctx0, peg_0, model.mm_model_peg_0_b);
             mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 2, 0, 3));
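The three clip.cpp edits are a pure rename: ggml_conv_depthwise_2d becomes ggml_conv_2d_dw with the same argument order (kernel, input, stride, padding, dilation). For reference, a declaration reconstructed from the call sites above and the ggml_conv_2d convention; ggml.h is authoritative:

// Reconstructed from the call sites above and the ggml_conv_2d convention;
// see ggml.h for the authoritative declaration.
struct ggml_tensor * ggml_conv_2d_dw(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,   // convolution kernel
        struct ggml_tensor  * b,   // input data
        int s0, int s1,            // stride  (width, height)
        int p0, int p1,            // padding (width, height)
        int d0, int d1);           // dilation (width, height)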

examples/tts/CMakeLists.txt

+5 −0

@@ -0,0 +1,5 @@
+set(TARGET llama-tts)
+add_executable(${TARGET} tts.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)

0 commit comments