# Stable Codec

This repository contains training and inference scripts for models in the Stable Codec series, starting with `stable-codec-speech-16k`, introduced in the paper **Scaling Transformers for Low-bitrate High-Quality Speech Coding**.

Paper: https://arxiv.org/abs/2411.19842

Sound demos: https://stability-ai.github.io/stable-codec-demo/

## Additional training

In addition to the training described in the paper, the released weights have also undergone 500k steps of finetuning with force-aligned phoneme data from LibriSpeech and the English portion of Multilingual LibriSpeech. This was done by using a CTC head to regress the phoneme categories from the pre-bottleneck latents. We found that this additional training significantly boosted the applicability of the codec tokens to downstream tasks such as TTS.
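
For intuition, here is a minimal sketch of that idea: a linear projection head over the pre-bottleneck latents, trained with CTC against phoneme targets. This is not the repository's implementation; the latent width and sequence lengths below are assumptions, while the 81-way output with blank index 80 matches the training configuration shown later in this README.

```python
import torch
import torch.nn as nn

latent_dim = 512    # assumed pre-bottleneck latent width (illustrative only)
num_classes = 81    # 80 phoneme categories + 1 CTC blank (blank index 80)

proj_head = nn.Linear(latent_dim, num_classes)
ctc = nn.CTCLoss(blank=80, zero_infinity=True)

# Latents taken before the FSQ bottleneck: (batch, time, latent_dim).
latents = torch.randn(4, 250, latent_dim)
log_probs = proj_head(latents).log_softmax(dim=-1).transpose(0, 1)  # (time, batch, classes)

# Padded phoneme-ID targets derived from force-aligned transcripts.
targets = torch.randint(0, 80, (4, 60))
input_lengths = torch.full((4,), 250)
target_lengths = torch.full((4,), 60)

loss = ctc(log_probs, targets, input_lengths, target_lengths)
```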

## Install

The model itself is defined in the [stable-audio-tools](https://github.com/Stability-AI/stable-audio-tools) package.

```bash
pip install -r requirements.txt
pip install -U flash-attn --no-build-isolation
```

**IMPORTANT NOTE:** This model currently has a hard requirement for FlashAttention due to its use of sliding-window attention. Inference quality without FlashAttention will likely be severely degraded.
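
As a quick sanity check (this snippet is illustrative and not part of the repository), you can confirm that the FlashAttention package imports before loading the model:

```python
# Illustrative check that flash-attn is importable; not part of this repository.
try:
    import flash_attn
    print(f"flash-attn {flash_attn.__version__} is available")
except ImportError:
    print("flash-attn is not installed; expect degraded inference quality")
```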

## Encoding and decoding

To encode audio or decode tokens, the `StableCodec` class provides a convenient wrapper for the model. It can be used with a local checkpoint and config as follows:

```python
import torchaudio

from model import StableCodec

model = StableCodec(
    model_config_path="<path-to-model-config>",
    ckpt_path="<path-to-checkpoint>",  # optional, can be `None`
)

audiopath = "audio.wav"

latents, tokens = model.encode(audiopath)
decoded_audio = model.decode(tokens)

torchaudio.save("decoded.wav", decoded_audio, model.sample_rate)
```

To download the model weights automatically from HuggingFace, simply provide the model name:

```python
model = StableCodec(
    pretrained_model="stabilityai/stable-codec-speech-16k"
)
```

### Posthoc bottleneck configuration

Most use cases will benefit from replacing the training-time FSQ bottleneck with a post-hoc FSQ bottleneck, as described in the paper. This allows the token dictionary size to be reduced to a level reasonable for modern language models. It is achieved by calling the `set_posthoc_bottleneck` function and passing a flag to the encode/decode calls:

```python
model.set_posthoc_bottleneck("2x15625_700bps")
latents, tokens = model.encode(audiopath, posthoc_bottleneck=True)
decoded_audio = model.decode(tokens, posthoc_bottleneck=True)
```

`set_posthoc_bottleneck` can take a string argument that selects one of several recommended preset settings for the bottleneck:

| Bottleneck Preset | Tokens per step | Dictionary size | Bits per second (bps) |
|-------------------|-----------------|-----------------|-----------------------|
| `1x46656_400bps`  | 1               | 46656           | 400                   |
| `2x15625_700bps`  | 2               | 15625           | 700                   |
| `4x729_1000bps`   | 4               | 729             | 1000                  |
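
The nominal bitrates in the preset names follow from tokens per step, bits per token, and the token rate. A rough check, assuming a 25 Hz token rate (an assumption made here for illustration; it is not stated in this README):

```python
import math

assumed_token_rate_hz = 25  # assumption for illustration; not stated in this README

for preset, tokens_per_step, dict_size in [
    ("1x46656_400bps", 1, 46656),
    ("2x15625_700bps", 2, 15625),
    ("4x729_1000bps", 4, 729),
]:
    bps = tokens_per_step * math.log2(dict_size) * assumed_token_rate_hz
    print(f"{preset}: ~{bps:.0f} bps")  # roughly matches the rounded figures above
```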

Alternatively, the bottleneck stages can be specified directly. The format for specifying them can be seen in the definition of the `StableCodec` class in `model.py`.

### Normalization

The model is trained with utterances normalized to -20 LUFS. The `encode` function applies this normalization by default, but it can be disabled by setting `normalize=False` when calling the function.
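
For reference, a minimal sketch of performing the same -20 LUFS normalization yourself and then bypassing the built-in step. The use of `pyloudnorm` here is an assumption for illustration, not necessarily what `encode` does internally:

```python
import soundfile as sf
import pyloudnorm as pyln

audio, sr = sf.read("audio.wav")
meter = pyln.Meter(sr)                                        # ITU-R BS.1770 loudness meter
loudness = meter.integrated_loudness(audio)
audio_norm = pyln.normalize.loudness(audio, loudness, -20.0)  # target -20 LUFS
sf.write("audio_norm.wav", audio_norm, sr)

# The input is already at the target level, so skip the built-in normalization.
latents, tokens = model.encode("audio_norm.wav", normalize=False)
```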

## Finetune

To finetune a model given its config and checkpoint, run the `train.py` script:

```bash
python train.py \
    --project "stable-codec" \
    --name "finetune" \
    --config-file "defaults.ini" \
    --save-dir "<ckpt-save-dir>" \
    --model-config "<path-to-config.json>" \
    --dataset-config "<dataset-config.json>" \
    --val-dataset-config "<dataset-config.json>" \
    --pretrained-ckpt-path "<pretrained-model-ckpt.ckpt>" \
    --ckpt-path "$CKPT_PATH" \
    --num-nodes $SLURM_JOB_NUM_NODES \
    --num-workers 16 --batch-size 10 --precision "16-mixed" \
    --checkpoint-every 10000 \
    --logger "wandb"
```

For dataset configuration, refer to the stable-audio-tools [dataset docs](https://github.com/Stability-AI/stable-audio-tools/blob/main/docs/datasets.md).
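
As a starting point, a dataset config for a local directory of audio typically looks like the sketch below. The exact schema is defined by stable-audio-tools, so treat the field names as a hedged example and verify them against the linked docs; the `id` and `path` values are placeholders:

```json
{
    "dataset_type": "audio_dir",
    "datasets": [
        {
            "id": "my_speech_data",
            "path": "/path/to/speech/audio"
        }
    ],
    "random_crop": true
}
```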

### Using CTC loss

To use [CTC loss](https://pytorch.org/docs/stable/generated/torch.nn.CTCLoss.html) during training, you have to enable it both in the training configuration file and in the training dataset configuration.

1. Modify the training configuration:
   - Enable the CTC projection head and set its hidden dimension:
     ```python
     config["model"]["use_proj_head"] = True
     config["model"]["proj_head_dim"] = 81
     ```
   - Enable CTC in the training part of the config:
     ```python
     config["training"]["use_ctc"] = True
     ```
   - And set its loss config:
     ```python
     config["training"]["loss_configs"]["ctc"] = {
         "blank_idx": 80,
         "decay": 1.0,
         "weights": {"ctc": 1.0}
     }
     ```
   - Optionally, you can enable computation of the phone error rate (PER) during validation:
     ```python
     config["training"]["eval_loss_configs"]["per"] = {}
     ```

2. Configure the dataset (only the WebDataset format is supported for CTC):
   - The dataset configuration should have one additional field set (see the [dataset docs](https://github.com/Stability-AI/stable-audio-tools/blob/main/docs/datasets.md) for other options):
     ```python
     config["force_align_text"] = True
     ```
   - The JSON metadata file for each sample should contain, besides other metadata, a force-aligned transcript under the `force_aligned_text` entry in the format below, where `transcript` is a list of word-level alignments whose `start` and `end` fields give each word's range **in seconds** (a sketch of packing such samples into a WebDataset shard follows after this list):
     ```json
     {
       "normalized_text": "and i feel",
       "force_aligned_text": {
         "transcript": [
           {
             "word": "and",
             "start": 0.2202,
             "end": 0.3403
           },
           {
             "word": "i",
             "start": 0.4604,
             "end": 0.4804
           },
           {
             "word": "feel",
             "start": 0.5204,
             "end": 0.7006
           }
         ]
       }
     }
     ```
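
As referenced above, here is a minimal sketch of packing one such sample into a WebDataset shard. The `flac`/`json` key layout is an assumption made for illustration; check the linked dataset docs for the layout stable-audio-tools actually expects:

```python
import json
import webdataset as wds

# Hypothetical sample metadata matching the format shown above.
sample_meta = {
    "normalized_text": "and i feel",
    "force_aligned_text": {
        "transcript": [
            {"word": "and", "start": 0.2202, "end": 0.3403},
            {"word": "i", "start": 0.4604, "end": 0.4804},
            {"word": "feel", "start": 0.5204, "end": 0.7006},
        ]
    },
}

# Write one audio file plus its JSON sidecar into a tar shard.
with wds.TarWriter("shard-000000.tar") as sink:
    with open("audio.flac", "rb") as f:
        audio_bytes = f.read()
    sink.write({
        "__key__": "sample_000000",
        "flac": audio_bytes,
        "json": json.dumps(sample_meta).encode("utf-8"),
    })
```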