Skip to content

Commit 948ffff

Browse files
authored
RWKV: raise informative exception when attempting to manipulate past_key_values (#28600)
1 parent 9efec11 commit 948ffff

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

src/transformers/models/rwkv/modeling_rwkv.py

+18
Original file line numberDiff line numberDiff line change
@@ -778,6 +778,24 @@ def get_output_embeddings(self):
778778
def set_output_embeddings(self, new_embeddings):
779779
self.head = new_embeddings
780780

781+
def generate(self, *args, **kwargs):
782+
# Thin wrapper to raise exceptions when trying to generate with methods that manipulate `past_key_values`.
783+
# RWKV is one of the few models that don't have it (it has `state` instead, which has different properties and
784+
# usage).
785+
try:
786+
gen_output = super().generate(*args, **kwargs)
787+
except AttributeError as exc:
788+
# Expected exception: "AttributeError: '(object name)' object has no attribute 'past_key_values'"
789+
if "past_key_values" in str(exc):
790+
raise AttributeError(
791+
"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`. RWKV "
792+
"doesn't have that attribute, try another generation strategy instead. For the available "
793+
"generation strategies, check this doc: https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
794+
)
795+
else:
796+
raise exc
797+
return gen_output
798+
781799
def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, **kwargs):
782800
# only last token for inputs_ids if the state is passed along.
783801
if state is not None:

0 commit comments

Comments
 (0)