
[WIP] Positive Constraint Decoding PR #1 #15299

Closed
wants to merge 18 commits into from

Conversation

cwkeam
Contributor

@cwkeam cwkeam commented Jan 23, 2022

Disjunctive Positive Constraint Decoding

@patrickvonplaten @LysandreJik @sgugger @patil-suraj @yjernite @thomwolf

Fixes #14081.

I apologize if this isn't the proper way to handle feature contributions, but this is an incomplete PR. I simply thought this was a good place to check in on the progress and direction of the implementation. We can keep adding commits to this PR until it's ready for the final merge, right?

Steps left:

  • Applying positive constraints disjunctively.
  • Writing tests

Here is an example of how one could use this functionality:

from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers.generation_beam_constraints import (
    PhrasalConstraint
)
device = "cuda"

model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

force_text = " big monsters"
force_text_2 = " crazy"
force_tokens = tokenizer.encode(force_text, return_tensors="pt").to(device)[0]
force_tokens_2 = tokenizer.encode(force_text_2, return_tensors="pt").to(device)[0]

constraints = [
    PhrasalConstraint(force_tokens),
    PhrasalConstraint(force_tokens_2)
]

input_text = ["The baby is crying because"] 

model_inputs = tokenizer(input_text, return_tensors="pt")

for key, value in model_inputs.items():
    model_inputs[key] = value.to(device)

k = model.generate(
    **model_inputs,
    constraints=constraints,
    num_beams=7,
    num_return_sequences=7,
    no_repeat_ngram_size=2
)

for out in k:
    print(tokenizer.decode(out))

Some example outputs:

The baby is crying because she's been told crazy big monsters are going to come and kill her.
The baby is crying because she's been told crazy big monsters are coming for her.
The baby is crying because she's been told crazy big monsters are going to come after her.

1. General Constraint Framework

Users can define their own constraints by inheriting from the Constraint interface class, and the framework is guaranteed to behave as intended because the Constraint class is strictly specified: an implementation that passes the self.test() method of this interface necessarily works as desired, while an incorrect implementation raises an error.

# https://github.com/cwkeam/transformers/blob/master/src/transformers/generation_beam_constraints.py#L16
class Constraint(ABC):
    r"""Abstract base class for all constraints that can be applied during generation.
    It must define how the constraint can be satisfied.

    All classes that inherit Constraint must follow the requirement that
    
    ```
completed = False
while not completed:
    _, completed, _ = constraint.update(constraint.advance())
    ```
    
    will always terminate (halt). 
    """
    def __init__(self):
        # test for the above condition
        self.test()

    def test(self):
        '''
        Tests whether this constraint has been properly defined.
        '''
        counter = 0
        completed = False
        while not completed:
            if counter == 1:
                self.reset()
            advance = self.advance()
            assert self.does_advance(advance)
            stepped, completed, reset = self.update(advance)
            counter += 1

            if counter > 10000:
                raise Exception("update() does not fulfill the constraint.")

        assert self.remaining() == 0        
        
    def advance(self):
        '''
        When called, returns the token that would take this constraint
        one step closer to being fulfilled.

        returns:
            token_ids(`torch.tensor`): A tensor of indexable token ids, not a plain integer.
        '''
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )

    def does_advance(self, token_id: int):
        """
        Reads in a token and returns whether it creates progress.
        """
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )

    def update(self, token_id: int):
        """
        Reads in a token and returns booleans indicating the progress it makes.
        Unlike `does_advance(self, token_id: int)`, this function updates the state of this object.

        This is not a test of whether a token would advance progress; it updates the constraint's
        state as if the token had been generated. This matters when token_id is not the desired
        token (see the else branch in PhrasalConstraint).

        Args:
            token_id(`int`):
                The id of a newly generated token in the beam search.
        returns:
            stepped(`boolean`):
                Whether this constraint has become one step closer to being fulfilled.
            completed(`boolean`):
                Whether this constraint has been completely fulfilled by this token being generated.
            reset (`boolean`):
                Whether this constraint has reset its progress by this token being generated.
        """
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )
    
    def reset(self):
        """
        Resets the state of this constraint to its initialization.
        We would call this in cases where the fulfillment of a constraint is interrupted by an unwanted token.
        """
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )

    def remaining(self):
        '''
        Returns the number of remaining `advance()` steps needed to complete this constraint.
        '''
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )

    def copy(self, stateful=False):
        '''
        Creates a new instance of this constraint.

        Args:
            stateful(`boolean`): Whether to copy not only the constraint but also its current state.
        Returns:
            constraint(`Constraint`): A new instance of the same constraint.
        '''
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )

For now, I've defined TokenConstraint, which forces the generation of a specific token, and PhrasalConstraint, which forces the generation of an unbroken sequence of tokens in the output. The example code above demonstrates the latter.

2. model.generate() Mixin

# https://github.com/cwkeam/transformers/blob/master/src/transformers/generation_utils.py#L780
 def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        max_length: Optional[int] = None,
        ...
        stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
        constraints: Optional[List[Constraint]] = None,
        output_attentions: Optional[bool] = None,
        ...
        **model_kwargs,
    ) 

Leads to:

#https://github.com/cwkeam/transformers/blob/master/src/transformers/generation_utils.py#L1077

# 6. determine generation mode
is_constraint_gen_mode = constraints is not None
is_greedy_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is False and constraints is None
is_sample_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is True and constraints is None
is_beam_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is False and constraints is None
is_beam_sample_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is True and constraints is None
is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) and constraints is None

Which ends up defining a ConstrainedBeamSearchScorer and initiates the beam search:

elif is_constraint_gen_mode:
  if num_return_sequences > num_beams:
      raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

  if stopping_criteria.max_length is None:
      raise ValueError("`max_length` needs to be a stopping_criteria for now.")

  # 10. prepare beam search scorer
  constrained_beam_scorer = ConstrainedBeamSearchScorer(
      constraints=constraints,
      batch_size=batch_size,
      ...,
  )
  # 11. interleave input_ids with `num_beams` additional sequences per batch
  input_ids, model_kwargs = self._expand_inputs_for_generation(
      input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
  )
  # 12. run beam search
  return self.constrained_beam_search(
      input_ids,
      constrained_beam_scorer=constrained_beam_scorer,
      ...
  )
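The mode dispatch in step 6 above can be summarized with a small standalone helper (the function name is hypothetical and slightly simplified, not part of the PR). The notable point is that passing `constraints` routes generation to the constrained beam search regardless of the other flags:

```python
# Hypothetical summary of the generation-mode flags in step 6 above.
# Slight simplification: group beam search additionally requires num_beams > 1.

def generation_mode(num_beams=1, num_beam_groups=1, do_sample=False, constraints=None):
    if constraints is not None:
        # Constraints take priority over every other mode.
        return "constrained_beam_search"
    if num_beam_groups > 1:
        return "group_beam_search"
    if num_beams == 1:
        return "sample" if do_sample else "greedy"
    return "beam_sample" if do_sample else "beam_search"


assert generation_mode(num_beams=7, constraints=["dummy"]) == "constrained_beam_search"
assert generation_mode(num_beams=7) == "beam_search"
```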

3. Future Steps

1. Disjunctive Constraints

This doesn't yet implement the disjunctive decoding described in issue #14081, but that can be added easily by defining a new Constraint subclass. I will follow up with another commit.
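As a rough illustration of what such a subclass could look like: a constraint that is fulfilled once any one of several alternative token sequences has been generated in full. This is a standalone simplification under assumed semantics, not the eventual PR design.

```python
# Illustrative simplification of a disjunctive constraint: satisfied when ANY
# of the alternative token sequences completes. Follows the same
# advance/does_advance/update/reset/remaining contract as above.

class DisjunctiveConstraint:
    def __init__(self, alternatives):
        # e.g. [[2334], [9068, 733]] for alternative tokenizations of a word
        self.alternatives = [list(a) for a in alternatives]
        self.reset()

    def reset(self):
        self.progress = [0] * len(self.alternatives)
        self.completed = False

    def advance(self):
        # Next token of the first still-incomplete alternative (greedy choice).
        for alt, p in zip(self.alternatives, self.progress):
            if p < len(alt):
                return alt[p]

    def does_advance(self, token_id):
        return any(
            p < len(alt) and alt[p] == token_id
            for alt, p in zip(self.alternatives, self.progress)
        )

    def update(self, token_id):
        stepped, reset = False, False
        if self.does_advance(token_id):
            stepped = True
            for i, (alt, p) in enumerate(zip(self.alternatives, self.progress)):
                if p < len(alt) and alt[p] == token_id:
                    self.progress[i] += 1
                    if self.progress[i] == len(alt):
                        self.completed = True
                else:
                    self.progress[i] = 0  # this branch is broken, restart it
        else:
            reset = True
            self.reset()
        return stepped, self.completed, reset

    def remaining(self):
        if self.completed:
            return 0
        return min(len(a) - p for a, p in zip(self.alternatives, self.progress))
```

Because every branch either advances or resets on each `update()`, the halting loop from the Constraint docstring still terminates, so an implementation in this style should pass `self.test()`.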

2. Tests

I was unsure how to approach testing this generation function, especially since it's almost identical to the existing approaches, with just one additional step that guides the generation.

@HuggingFaceDocBuilder

HuggingFaceDocBuilder commented Jan 23, 2022

The documentation is not available anymore as the PR was closed or merged.

@cwkeam cwkeam changed the title Positive Constraint Decoding PR #1 [WIP] Positive Constraint Decoding PR #1 Jan 29, 2022
@cwkeam cwkeam closed this Jan 29, 2022
@cwkeam
Contributor Author

cwkeam commented Jan 31, 2022

@patrickvonplaten Sorry for the confusion but I've closed this one and opened a new PR from a new branch with a lot more updates here #15416. Safe to leave this closed.

@cwkeam cwkeam closed this Jan 31, 2022
Successfully merging this pull request may close these issues.

[Feature Contribution] Disjunctive Positive Constraint Decoding (adding force_tokens to model.generate())