Skip to content

Commit

Permalink
fix failing test
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurZucker committed Aug 9, 2024
1 parent 185c1cd commit 209fccc
Showing 1 changed file with 10 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -915,13 +915,17 @@ def prepare_inputs_for_generation(
if past_length > 0:
position_ids = position_ids[:, past_length:]

if inputs_embeds is not None:
model_inputs = {"inputs_embeds": inputs_embeds[:, past_length:]}
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
model_inputs = {"input_ids": input_ids[:, past_length:].contiguous()}

if cache_position is not None:
cache_position = cache_position[-position_ids.shape[1] :]
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

model_inputs.update(
{
Expand Down

0 comments on commit 209fccc

Please sign in to comment.