
feat(training-utils): support device and dtype params in compute_density_for_timestep_sampling #10699


Merged
merged 4 commits into huggingface:main on Feb 1, 2025

Conversation

badayvedat
Contributor

What does this PR do?

The compute_density_for_timestep_sampling function defaults to CPU-based tensor operations, which limits flexibility when working with GPUs. The function also lacks support for a generator, which makes it difficult to ensure reproducibility in sampling.

Additional notes

No API-breaking changes; existing calls without device or generator params will default to CPU-based execution.
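For context, a minimal usage sketch, assuming the new device and generator keyword arguments described above sit alongside the existing weighting_scheme and batch_size parameters (the specific values are illustrative):

import torch
from diffusers.training_utils import compute_density_for_timestep_sampling

# Run the sampling on the GPU when one is available; previously this always happened on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# A seeded generator on the target device makes the sampled densities reproducible.
generator = torch.Generator(device=device).manual_seed(42)

u = compute_density_for_timestep_sampling(
    weighting_scheme="logit_normal",
    batch_size=4,
    logit_mean=0.0,
    logit_std=1.0,
    device=device,
    generator=generator,
)
# u has shape (batch_size,) with values in [0, 1], which training scripts scale to timestep indices.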


Who can review?

@a-r-r-o-w @sayakpaul

@a-r-r-o-w (Member) left a comment


Ah, beat me to it! 🤗 Had to make the same change here for training to be perfectly reproducible but forgot about updating here too :/

logit_mean: float = None,
logit_std: float = None,
mode_scale: float = None,
device: torch.device | str = "cpu",
Member


Could we use Union here instead?

Contributor Author


sure! and I know about that same change, because I was trying to minimize the duplication in finetrainers 😄
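For reference, a minimal sketch of the signature with the suggested Union annotation in place of the PEP 604 form torch.device | str, which only evaluates at runtime on Python 3.10+; the generator parameter is taken from the PR description and the body is elided:

from typing import Optional, Union

import torch

def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    logit_mean: float = None,
    logit_std: float = None,
    mode_scale: float = None,
    device: Union[torch.device, str] = "cpu",  # Union[...] instead of `torch.device | str`
    generator: Optional[torch.Generator] = None,
):
    ...  # sampling logic unchanged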

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul (Member) left a comment


Amazing, thanks! Would be great to have a PR opened to finetrainers to replace it.

@sayakpaul sayakpaul merged commit 9f28f1a into huggingface:main Feb 1, 2025
12 checks passed
@badayvedat badayvedat deleted the feat/add-device-tpype branch February 1, 2025 22:50
4 participants