
experiment(backend): autocast dtype in CustomLinear #7843

Open · wants to merge 2 commits into main
Conversation

@psychedelicious (Collaborator) commented Mar 26, 2025

Summary

This resolves an issue where specifying `float32` precision causes FLUX Fill to error.

I noticed that our other customized torch modules do some dtype casting themselves, so maybe this is a fine place to do this? Maybe this could break things...

See #7836

Related Issues / Discussions

Closes #7836

QA Instructions

Try various model combos. I don't know what I'm doing and this could be a Bad Idea™️.

To reproduce the problem in the linked issue, set precision: float32 in invokeai.yaml, then try to use FLUX Fill.
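For reference, a minimal `invokeai.yaml` excerpt for this repro (all other settings omitted):

```yaml
# invokeai.yaml -- only the setting relevant to this repro
precision: float32
```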

Merge Plan

n/a

Checklist

  • The PR has a short but descriptive title, suitable for a changelog
  • Tests added / updated (if applicable)
  • Documentation added / updated (if applicable)
  • Updated What's New copy (if doing a release after this PR)

@github-actions bot added the python (PRs that change python files) and backend (PRs that change backend files) labels on Mar 26, 2025
```diff
@@ -73,6 +74,10 @@ def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
     def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
         weight = cast_to_device(self.weight, input.device)
         bias = cast_to_device(self.bias, input.device)
 
+        weight = cast_to_dtype(weight, input.dtype)
```
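For context, here is a minimal, self-contained sketch of what the patched `_autocast_forward` ends up doing. The helper implementations below are simplified stand-ins for the repo's `cast_to_device` / `cast_to_dtype`, and the bias cast is an assumption for completeness — only the weight cast is visible in the excerpted hunk:

```python
from typing import Optional

import torch
import torch.nn.functional as F


def cast_to_device(t: Optional[torch.Tensor], device: torch.device) -> Optional[torch.Tensor]:
    # Simplified stand-in for the repo's cast_to_device helper.
    return t if t is None or t.device == device else t.to(device)


def cast_to_dtype(t: Optional[torch.Tensor], dtype: torch.dtype) -> Optional[torch.Tensor]:
    # Simplified stand-in for the repo's cast_to_dtype helper.
    return t if t is None or t.dtype == dtype else t.to(dtype)


class CustomLinearSketch(torch.nn.Linear):
    def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
        # Existing behavior: move parameters to the input's device.
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)

        # New in this PR: also match the input's dtype, so a float32 activation
        # can pass through a layer whose weights were loaded in a lower precision.
        weight = cast_to_dtype(weight, input.dtype)
        # The bias cast is assumed here to keep the example self-contained.
        bias = cast_to_dtype(bias, input.dtype) if bias is not None else None

        return F.linear(input, weight, bias)
```

The practical effect is that the dtype mismatch that previously surfaced as an error inside the linear layer is resolved by casting the weights to the activation's dtype at call time.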
A collaborator commented on this change:
This is probably fine, but some models may act weirdly due to potential precision loss if we provide inputs with less precision than the model 🤔 In an ideal world I'd think we'd want to ensure the precision of the inputs is compatible with the model before calling it
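For comparison, the caller-side alternative hinted at in that comment — normalizing input precision before invoking the module — might look roughly like this. This is a hypothetical helper for illustration, not part of this PR:

```python
import torch


def match_input_to_module_dtype(x: torch.Tensor, module: torch.nn.Module) -> torch.Tensor:
    # Hypothetical caller-side guard: cast the activation to the module's
    # parameter dtype before the forward call, instead of casting weights
    # inside the module as this PR does. This avoids touching the weights,
    # but can still lose precision on the activation itself.
    target = next(module.parameters()).dtype
    return x if x.dtype == target else x.to(target)
```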

Successfully merging this pull request may close these issues.

[bug]: Flux Fill inpainting does not work