
Audio decoding support: range-based core API #538


Merged: 28 commits into pytorch:main on Mar 12, 2025

Conversation

NicolasHug (Member) commented on Mar 6, 2025:

This PR adds the get_frames_by_pts_in_range_audio(start_seconds, stop_seconds=None) -> Tensor core API.

  • It returns a 2D tensor of shape (num_channels, num_samples).
    • We don't return pts or duration. We might eventually, and I'm open to it, but those are directly deducible from start_seconds and the sample rate, so I'm leaving them out.
    • We're not returning something of shape (e.g.) (num_frames, num_channels, num_samples_per_frame), and we never will, because audio frames generally contain a variable number of samples. I found that out the hard way.
  • It is a frame-based API. That's OK for a core API. The public decoder method will be sample-based, but that's out of scope for now.
  • It allows consecutive calls, but only if they don't require a backwards seek (that'll be implemented later). See the tests, and the sketch after this list.
  • stop_seconds is None by default, so that users can decode to the end of the file without knowing the duration. Setting stop_seconds=<some super high value> doesn't raise an error either. (Note that IMHO we should extend this to all range-based APIs: Extend SimpleVideoDecoder index-based APIs #150)
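
A minimal sketch of the intended call pattern, using the same core calls as the benchmark below (the file path and stream index here are placeholders):

from torchcodec.decoders import _core as core

decoder = core.create_from_file("audio.mp4", seek_mode="approximate")
core.add_audio_stream(decoder, stream_index=0)

# Consecutive calls are allowed as long as the ranges move forward in time.
first = core.get_frames_by_pts_in_range_audio(decoder, start_seconds=0, stop_seconds=2)
rest = core.get_frames_by_pts_in_range_audio(decoder, start_seconds=2, stop_seconds=None)
# Each result is a 2D tensor of shape (num_channels, num_samples).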

We've discussed a lot offline already, so I won't be writing down everything that went into the design decisions here. But once everything is done, I'll make sure to write down a note in the code that documents why audio decoding is implemented the way it is.


Preliminary benchmarks against torchaudio (built from source) are very promising, even for this non-optimized first version:

Duration: 13s (nasa_13013.mp4)
torchcodec: med = 8.01ms +- 1.03
torchaudio: med = 11.90ms +- 0.60

Duration: 13s (nasa_13013.mp4.audio.mp3)
torchcodec: med = 4.07ms +- 0.73
torchaudio: med = 7.22ms +- 0.63

Duration: 2m11s
torchcodec: med = 30.18ms +- 0.75
torchaudio: med = 45.17ms +- 1.85

Duration: 1h27m
torchcodec: med = 1060.43ms +- 23.78
torchaudio: med = 1746.49ms +- 22.55

Code:

from torchcodec.decoders import _core as core
from torchaudio.io import StreamReader
import torch
from time import perf_counter_ns


def bench(f, *args, num_exp=100, warmup=0, **kwargs):
    # Untimed warmup calls first:
    for _ in range(warmup):
        f(*args, **kwargs)

    times = []
    for _ in range(num_exp):
        start = perf_counter_ns()
        f(*args, **kwargs)
        end = perf_counter_ns()
        times.append(end - start)
    return torch.tensor(times).float()

def report_stats(times, unit="ms", prefix=""):
    # Convert from nanoseconds to the requested unit, then print median +- std.
    mul = {
        "ns": 1,
        "µs": 1e-3,
        "ms": 1e-6,
        "s": 1e-9,
    }[unit]
    times = times * mul
    std = times.std().item()
    med = times.median().item()
    print(f"{prefix}: {med = :.2f}{unit} +- {std:.2f}")
    return med



def codec(path, stream_index):
    # torchcodec: decode the entire audio stream with a single range call.
    decoder = core.create_from_file(path, seek_mode="approximate")
    core.add_audio_stream(decoder, stream_index=stream_index)

    core.get_frames_by_pts_in_range_audio(decoder, start_seconds=0, stop_seconds=None)

def audio(path, stream_index):
    # torchaudio: decode the entire audio stream chunk by chunk.
    reader = StreamReader(path)
    reader.add_audio_stream(frames_per_chunk=1024, stream_index=stream_index)
    for _ in reader.stream():
        pass

NUM_EXP = 30
WARMUP = 1
for path, stream_index, duration in (
    ("/home/nicolashug/dev/torchcodec/test/resources/nasa_13013.mp4", 4, "13s"),
    ("/home/nicolashug/dev/torchcodec/test/resources/nasa_13013.mp4.audio.mp3", 0, "13s"),
    ("/home/nicolashug/test_videos/long.mp4", 1, "2m11s"),
    ("/home/nicolashug/test_videos/output.mp4", 1, "1h27m"),
):
    print(f"Duration: {duration}")
    times = bench(codec, path, stream_index, num_exp=NUM_EXP, warmup=WARMUP)
    report_stats(times, prefix="torchcodec")

    times = bench(audio, path, stream_index, num_exp=NUM_EXP, warmup=WARMUP)
    report_stats(times, prefix="torchaudio")
    print()

There will be plenty of follow-ups, mainly:

  • enable backwards seeks
  • enable audio formats other than fltp
  • enable a user-defined sample_rate
  • expose a public method in AudioDecoder
  • maybe, maybe not: something like the normalize parameter of the torchaudio reader, which lets users choose between a float tensor in [-1, 1] and a tensor with the same dtype as the audio format. We'll figure that out later; we'll probably always return float tensors by default anyway. A rough sketch of the idea follows this list.
  • perf: try to pre-allocate the output tensor and save copies
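
For the normalize point above, a hedged sketch of what such an option might do (this is not the eventual API, just the underlying conversion):

import torch

def normalize_pcm(samples: torch.Tensor) -> torch.Tensor:
    # Map int16 PCM to float32 in [-1, 1]; pass float samples through unchanged.
    if samples.dtype == torch.int16:
        return samples.to(torch.float32) / 32768.0
    return samples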

@facebook-github-bot added the CLA Signed label on Mar 6, 2025
        tensors.push_back(frameOutput.data);
    } catch (const EndOfFileException& e) {
        reachedEOF = true;
    }
NicolasHug (Member, Author) commented on Mar 9, 2025:

Q about C++ best practices: I realize we're already doing this in a few places (like custom ops), but is it good practice to use exceptions for control flow? Maybe the reachedEOF flag from decodeAVFrame() could be a stateful attribute instead? (Not that I find statefulness appealing either!)

Contributor commented:

Since we shouldn't ordinarily decode past the end of a file, I think it makes sense for us to throw exceptions when we reach the end of a file. Here, we're not really using an exception for control flow per se. That is, we're not trying to read past the end of the file; we just have to handle the case that we might.

With that said, I do find it more natural when the "normal" stop conditions are explicitly part of the while loop's condition, as opposed to setting a boolean inside the loop. But since you're depending on the internal state of the decoder to know the last decoded frame info, I don't know if that's possible. When I implemented something similar, I ended up using a priming read to get around this problem.
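
For illustration, a minimal Python sketch of the priming-read pattern (the decoder interface here is a hypothetical stand-in, not the actual C++ API):

def read_all(read_frame, stop_pts):
    # Priming-read pattern: fetch once before the loop so the loop condition
    # can test the "normal" stop conditions directly.
    frames = []
    frame = read_frame()  # priming read
    while frame is not None and frame[0] < stop_pts:
        frames.append(frame)
        frame = read_frame()  # next read happens at the end of the body
    return frames

# Toy frame source yielding (pts, data) tuples:
it = iter([(0.0, "a"), (0.5, "b"), (1.0, "c")])
print(read_all(lambda: next(it, None), stop_pts=1.0))  # [(0.0, 'a'), (0.5, 'b')]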

Contributor commented:

Also, nit: for local variables used in a small space with a clear purpose, I prefer shorter names. So even stop as the boolean would make this easier for me to read.

NicolasHug (Member, Author) commented:

OK. I'll use finished instead of stop, because we already have local variables named stopPts and stopSeconds in this function.

        asset.duration_seconds
    )
    assert decoder.metadata.sample_rate == asset.sample_rate
    assert decoder.metadata.num_channels == asset.num_channels
NicolasHug (Member, Author) commented:

Changes here are mostly a drive-by. I'm extending this test because I removed test_audio_get_json_metadata, which was outdated.

@NicolasHug changed the title from "Audio decoding support - range API" to "Audio decoding support: range-based core API" on Mar 9, 2025
@NicolasHug marked this pull request as ready for review on March 9, 2025 at 13:22

    return get_frames_by_pts_in_range_audio(
        decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
    )
NicolasHug (Member, Author) commented:

This test is complex enough and I didn't want to obfuscate it further with pts-to-index conversions, so I created this stateless helper.

Contributor commented:

Shouldn't this helper get its reference from the test utils? Right now it's getting frames by decoding the file, which means we're not actually comparing against a reference - or am I missing something here?

NicolasHug (Member, Author) commented on Mar 12, 2025:

You're correct. This test compares a stateless decoder (treated as the ref) with a stateful decoder (which is what users interact with). So it still asserts what we need to assert.

Relying on the references would mean converting all the timestamps into indices, and as mentioned in my comment just above, I wanted to avoid complicating this test further.

Eventually we will update this test (i.e. when we enable backwards seeking), at which point we could just rely on the reference frames.

I was hoping this would make reviewing easier, but apparently it's bringing more confusion :p. If you prefer converting to indices straight away, let me know.

Contributor commented:

Ohhhh, okay, I had misunderstood what you meant about avoiding the conversions. It's fine to leave as-is, but let's put in a comment explaining that we're comparing a decoder which only seeks once against a decoder which seeks multiple times, along with a TODO to convert it to loading the reference frames by index.
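
Something along these lines could sit at the top of the helper (wording is mine, not from the PR):

# NOTE: We compare a stateless decoder (recreated per call, seeks only once)
# against the stateful decoder that users actually interact with (seeks
# multiple times across consecutive calls).
# TODO: once backwards seeking is supported, compare against reference
# frames loaded by index instead.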

    if frame_info.pts_seconds
    <= pts_seconds
    < frame_info.pts_seconds + frame_info.duration_seconds
)
Contributor commented:

I have to admit I don't completely understand what's going on here. I get that we have a generator that does a linear walk through the frames, searching for the first frame that meets our conditions, but how does passing that generator to next() get us back the one index we want?

Also, it may be better to do this once, in __post_init__(), and set up a mapping rather than doing it on every call.

NicolasHug (Member, Author) commented on Mar 12, 2025:

Unfortunately we can't build a mapping for this, because we'd be mapping continuous timestamps to integers. I'll add a comment that bisect might make things faster if needed (sketch at the end of this comment), although for such small arrays the tests run very fast.

next(it) just returns the very first entry in it, whether it is a list, a tuple, a generator, etc. And then we build the generator in such a way that the first frame in the generator is the frame we want. What's going on is exactly the same as the snippet below:

>>> gen = (i for i in range(10) if i > 5)
>>> next(gen)
6

The values in [0, 5] were never part of gen, they were only part of the output from range(). The first entry in gen is 6.

how does passing that generator to next() get us back the one index we want?

We iterate over (frame_index, frame_info) tuples, filter those that don't meet our condition, and then only store frame_index within the generator:

>>> list(i for (i, j) in zip((1, 2, 3), ("a", "b", "c")))
[1, 2, 3]
>>> list(j for (i, j) in zip((1, 2, 3), ("a", "b", "c")))
['a', 'b', 'c']
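
For reference, a sketch of the bisect approach mentioned above (assuming frames are sorted by pts_seconds; the helper name and frame-info shape are hypothetical):

import bisect

def frame_index_for_pts(frames, pts_seconds):
    # frames: list of frame infos sorted by pts_seconds (a stand-in for the
    # test utils' metadata). Find the frame whose interval contains the pts.
    starts = [f.pts_seconds for f in frames]
    i = bisect.bisect_right(starts, pts_seconds) - 1  # last frame starting <= pts
    if i >= 0 and pts_seconds < frames[i].pts_seconds + frames[i].duration_seconds:
        return i
    raise ValueError(f"no frame contains pts {pts_seconds}")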

Contributor commented:

Wow! I never thought about it that way: next() on a generator is effectively like lst[0] on a sequence.

@NicolasHug NicolasHug merged commit ff4abff into pytorch:main Mar 12, 2025
9 checks passed