Fix from_args_and_dict ProcessorMixin #38296


Merged

Conversation

yonigozlan
Member

What does this PR do?

Fix an issue where kwargs given to from_pretrained for processors would not be taken into account in the processor's __init__ directly, but would instead override the processor's attributes after the fact, resulting in unexpected behavior.

This might be slightly breaking: previously, the kwargs, and even the processor_dict items, would be applied even when they were not valid init arguments, but I guess this shouldn't be the case :)

@yonigozlan yonigozlan changed the title fix-from-args-and-dict-processormixin Fix from_args_and_dict ProcessorMixin May 22, 2025
@yonigozlan yonigozlan requested review from Cyrilvallez and ydshieh May 22, 2025 14:34

@ydshieh
Collaborator

ydshieh commented May 22, 2025

I am not against it, but my memory tells me we do the same thing elsewhere (i.e. kwargs are not used when initializing objects; values are set afterwards). For example, ImageProcessingMixin.from_dict:

        image_processor = cls(**image_processor_dict)

        # Update image_processor with kwargs if needed
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(image_processor, key):
                setattr(image_processor, key, value)
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)

So it's not clear to me:

  • if we want to apply the same changes to other classes
  • and I am lacking the context behind "resulting in unexpected behaviors" (I am sure there is one, but I don't have a concrete example in mind, so I am not able to judge whether we need a fix, or whether the current fix is the right solution)

And @zucchini-nlp could also be helpful on this PR I believe :-)

@yonigozlan
Copy link
Member Author

A concrete example, given by @Cyrilvallez, in the Gemma3 processor:

class Gemma3Processor(ProcessorMixin):
    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["chat_template", "image_seq_length"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor,
        tokenizer,
        chat_template=None,
        image_seq_length: int = 256,
        **kwargs,
    ):
        print(image_seq_length)  # prints 256 here even when image_seq_length=5 is passed to from_pretrained
        self.image_seq_length = image_seq_length
        self.image_token_id = tokenizer.image_token_id
        self.boi_token = tokenizer.boi_token
        self.image_token = tokenizer.image_token
        image_tokens_expanded = "".join([tokenizer.image_token] * image_seq_length)
        print(image_tokens_expanded)  # likewise built from the default length, not the user's value
        self.full_image_sequence = f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n"

        super().__init__(
            image_processor=image_processor,
            tokenizer=tokenizer,
            chat_template=chat_template,
            **kwargs,
        )

If we instantiate a processor like this:

processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it", image_seq_length=5)

image_seq_length in the __init__ will still be 256 and only be set to 5 afterwards; but since self.full_image_sequence is computed in the __init__ and depends on image_seq_length, it won't end up with the value we want.
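
To see the failure mode in isolation, here is a minimal standalone sketch (a toy Proc class, not the actual transformers code) reproducing the old setattr-after-init path:

class Proc:
    def __init__(self, image_seq_length=256):
        self.image_seq_length = image_seq_length
        # Derived attribute, computed once from image_seq_length:
        self.full_image_sequence = "<img>" * image_seq_length

# Old behavior: init with the saved/default value, then apply the kwarg afterwards.
proc = Proc()
setattr(proc, "image_seq_length", 5)

print(proc.image_seq_length)                    # 5, as requested
print(proc.full_image_sequence.count("<img>"))  # 256 -- stale, built from the default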

I also just changed the PR to completely remove the manual valid_kwargs in the processors. I haven't looked at how other Mixin classes handle this in the repo, but I'm happy to make the same changes if we think this is a good direction to take!

@zucchini-nlp
Member

Yeah, I remember this issue. IMO the case is specific to Gemma3 and would be better fixed by deleting self.full_image_sequence = f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n" and moving the logic to build the expanded sequence into __call__, as we do in most other processors. WDYT?
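
For reference, that alternative would look roughly like this (an illustrative sketch with placeholder token strings, not the actual Gemma3 code):

class Gemma3ProcessorSketch:
    def __init__(self, image_token="<img>", boi_token="<boi>", eoi_token="<eoi>", image_seq_length=256):
        self.image_token = image_token
        self.boi_token = boi_token
        self.eoi_token = eoi_token
        self.image_seq_length = image_seq_length
        # No derived self.full_image_sequence here anymore.

    def __call__(self, text):
        # Recomputed at call time, so a later change to image_seq_length
        # (e.g. via setattr after init) is always picked up.
        expanded = "".join([self.image_token] * self.image_seq_length)
        full_image_sequence = f"\n\n{self.boi_token}{expanded}{self.eoi_token}\n\n"
        return text.replace(self.boi_token, full_image_sequence)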

@yonigozlan
Member Author

This seems more aligned with what would be expected when adding kwargs to a .from_pretrained call, in my opinion. Plus, if it allows setting more constants once and for all in the init instead of recomputing them at each call of __call__, it seems like a win to me.
And I'd be glad to get rid of the valid_kwargs attributes in the processors :)

@zucchini-nlp
Member

zucchini-nlp commented May 22, 2025

Oh yeah, valid_kwargs was added for some BC iirc. I'm okay with updating; the only concern I have is that this will start raising errors when from_pretrained() gets an unused kwarg. Most of the processors we have don't accept arbitrary kwargs, I think.

Let's make sure all processors have **kwargs at init :)
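
That is, every processor __init__ would need a catch-all along these lines (an illustrative sketch mirroring the Gemma3 pattern above; SomeProcessorSketch is a placeholder, not a real class):

from transformers import ProcessorMixin

class SomeProcessorSketch(ProcessorMixin):
    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor, tokenizer, chat_template=None, **kwargs):
        # **kwargs absorbs any extra entries coming from a saved processor
        # config, so instantiation does not fail with a TypeError.
        super().__init__(
            image_processor=image_processor,
            tokenizer=tokenizer,
            chat_template=chat_template,
            **kwargs,
        )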

@yonigozlan
Member Author

the only concern I have is that this will start raising errors when from_pretrained() gets an unused kwarg

# validate both processor_dict and given kwargs
unused_kwargs, valid_kwargs = cls.validate_init_kwargs(
processor_config=processor_dict, valid_kwargs=accepted_args_and_kwargs
)
# remove args that are in processor_dict to avoid duplicate arguments
args_to_remove = [i for i, arg in enumerate(accepted_args_and_kwargs) if arg in processor_dict]
args = [arg for i, arg in enumerate(args) if i not in args_to_remove]
# instantiate processor with used (and valid) kwargs only
processor = cls(*args, **valid_kwargs)

Since here we would only be passing "valid_kwargs", based on the __init__ signature, when instantiating the processor, we wouldn't raise an error even if an invalid kwarg is passed to from_pretrained.
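
The filtering itself can be sketched like this (a hypothetical helper mirroring validate_init_kwargs, assuming validation is done against the __init__ signature):

import inspect

def split_kwargs_by_signature(cls, processor_dict, kwargs):
    # Parameters accepted explicitly by cls.__init__ (excluding self and **kwargs).
    params = inspect.signature(cls.__init__).parameters
    accepted = {
        name for name, p in params.items()
        if name != "self" and p.kind is not inspect.Parameter.VAR_KEYWORD
    }
    # User kwargs take precedence over values loaded from the checkpoint.
    merged = {**processor_dict, **kwargs}
    valid_kwargs = {k: v for k, v in merged.items() if k in accepted}
    unused_kwargs = {k: v for k, v in merged.items() if k not in accepted}
    return unused_kwargs, valid_kwargs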

@ydshieh
Collaborator

ydshieh commented May 22, 2025

Very nice example. And OK if you want to deal with the other mixins separately.

Member

@Cyrilvallez Cyrilvallez left a comment


Super nice to remove the valid_kwargs attribute as it was making the code unnecessarily complicated! 🤗 Glad that the dangerous behavior of setting the attrs after having initialized is removed as well 🤗

However, to my understanding, args should take precedence!

Comment on lines +991 to +993
# remove args that are in processor_dict to avoid duplicate arguments
args_to_remove = [i for i, arg in enumerate(accepted_args_and_kwargs) if arg in processor_dict]
args = [arg for i, arg in enumerate(args) if i not in args_to_remove]
Member


Here I think args should take precedence over processor_dict, no? As they are necessarily passed by the user directly, no? And processor_dict may come from the config, and not necessarily from the merged kwargs.

So I think we should never remove args, but instead remove the corresponding values from the processor_dict if they were saved.

Member Author


The "args" are given by the _get_arguments_from_pretrained, and are not really args given by the user, but attributes retrieved from a checkpoint, then reconstructed as an "args" list based on the attributes attribute of processors (again there seem to be some redundancy here that we could solve by inspecting the signature, I could have a look next :) ).

args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs)
processor_dict, kwargs = cls.get_processor_dict(pretrained_model_name_or_path, **kwargs)
return cls.from_args_and_dict(args, processor_dict, **kwargs)

So the only things the user can specify directly here are the kwargs, so I think they should take precedence, especially because of use cases such as this one:

processor = AutoProcessor.from_pretrained("checkpoint_path", image_processor=custom_image_processor)

where we want to get a processor from a checkpoint, but only modify one of its "attributes", e.g. image_processor, tokenizer or feature_extractor. This is currently not supported, and such kwargs are just silently ignored.

Member


Ha ok, alright then! Indeed, the TL;DR is: whatever the user explicitly passes should take precedence, of course!

Member

@Cyrilvallez Cyrilvallez left a comment


Alright, thanks for clarifying! LGTM!

@yonigozlan yonigozlan merged commit 21b10d9 into huggingface:main May 28, 2025
20 checks passed