diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
index 028550be9e7618..cd293399b5ac02 100644
--- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
+++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
@@ -915,13 +915,17 @@ def prepare_inputs_for_generation(
         if past_length > 0:
             position_ids = position_ids[:, past_length:]
 
-        if inputs_embeds is not None:
-            model_inputs = {"inputs_embeds": inputs_embeds[:, past_length:]}
+        if inputs_embeds is not None:  # Exception 1
+            input_ids = input_ids[:, -cache_position.shape[0] :]
+        elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
+            input_ids = input_ids[:, cache_position]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and cache_position[0] == 0:
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids[:, past_length:].contiguous()}
-
-        if cache_position is not None:
-            cache_position = cache_position[-position_ids.shape[1] :]
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         model_inputs.update(
             {
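
A minimal sketch (not part of this patch) of how the new `cache_position`-based slicing behaves across the prefill and decoding steps; the shapes and token values are made up for illustration, assuming a decoder-only generation loop.

```python
import torch

batch_size, prompt_len = 2, 5
input_ids = torch.arange(batch_size * prompt_len).reshape(batch_size, prompt_len)

# Prefill: cache_position spans the whole prompt, so the slice keeps every token.
cache_position = torch.arange(prompt_len)
assert torch.equal(input_ids[:, cache_position], input_ids)

# Decoding step: cache_position holds only the newly generated position, so the
# default branch (`input_ids.shape[1] != cache_position.shape[0]`) keeps just the last token.
new_token = torch.tensor([[100], [200]])
input_ids = torch.cat([input_ids, new_token], dim=-1)
cache_position = torch.tensor([prompt_len])
step_inputs = input_ids[:, cache_position]
assert step_inputs.shape == (batch_size, 1)
assert torch.equal(step_inputs, new_token)
```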