Conditional text generation

How to use

Install the necessary packages

Before using our models or running the generation script, install the necessary dependencies with the following command:

pip install torch transformers wandb tokenizers

If some of these packages have already been installed, they can be skipped.
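
To verify that the key packages are importable, you can optionally run:

python -c "import torch, transformers; print(torch.__version__, transformers.__version__)"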

Run the generation script

To execute the generation script, run the following command:

python ./generation_script.py --model MODEL --input INPUT --control_codes CONTROL_CODES --max_len MAX_LEN --temperature TEMPERATURE --top_k TOP_K --repetition_penalty REPETITION_PENALTY --top_p TOP_P --num_returned_sequences NUM_RETURNED_SEQUENCES

where:

  • MODEL is one of [ST, SEP, ST-10F, SEP-10F, D-ST, T-ST, ST-0], where:
    • ST is GPT-2 using control codes (12 layers);
    • SEP is GPT-2 using separators (12 layers);
    • ST-10F is GPT-2 using control codes with the first 10 layers frozen (12 layers);
    • SEP-10F is GPT-2 using separators with the first 10 layers frozen (12 layers);
    • D-ST is DistilGPT-2 using control codes (6 layers);
    • T-ST is Tiny-GPT-2 using control codes (2 layers);
    • ST-0 is GPT-2 using control codes and trained from scratch (12 layers).
  • INPUT is the sequence the model will start generating from. Note that empty inputs are only allowed with SEP and ST-0; for further details, please refer to the Tokenization section of the paper.
  • CONTROL_CODES is a list (possibly empty) of control codes that influence the text generation. For better results, it is suggested to use one (or more) of the following:

    ['kitchen', 'food', 'animal', 'furniture', 'indoor', 'accessory', 'person', 'vehicle', 'outdoor', 'sports', 'appliance', 'electronic']
  • MAX_LEN (int, optional, defaults to 16) – The maximum number of tokens (subword units, not whole words) the model will generate, including the input text.
  • TEMPERATURE (float, optional, defaults to 0.9) – The value used to module the next token probabilities.
  • TOP_K (int, optional, defaults to 30) – The number of highest probability vocabulary tokens to keep for top-k-filtering.
  • REPETITION_PENALTY (float, optional, defaults to 2.0) – The parameter for repetition penalty; 1.0 means no penalty.
  • TOP_P (float, optional, defaults to 0.7) – If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
  • NUM_RETURNED_SEQUENCES (int, optional, defaults to 3) – The number of independently computed returned sequences for each element in the batch.
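
For example, a possible invocation (the prompt and control codes here are illustrative, and the exact syntax for passing multiple control codes, shown space-separated below, may differ in the script):

python ./generation_script.py --model ST --input "A dog" --control_codes animal outdoor --max_len 30 --temperature 0.9 --top_k 30 --repetition_penalty 2.0 --top_p 0.7 --num_returned_sequences 3

These flags mirror the standard sampling parameters of the Hugging Face transformers library. The minimal sketch below is not the repository's actual code; the model name, prompt, and control-code handling are assumptions. It shows how such flags typically map onto GenerationMixin.generate():

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Assumption: control codes are prepended to the input text.
prompt = "animal outdoor A dog"
input_ids = tokenizer.encode(prompt, return_tensors="pt")

outputs = model.generate(
    input_ids,
    do_sample=True,          # sampling must be on for temperature/top-k/top-p to apply
    max_length=30,           # --max_len: total length, including the input tokens
    temperature=0.9,         # --temperature
    top_k=30,                # --top_k
    top_p=0.7,               # --top_p
    repetition_penalty=2.0,  # --repetition_penalty
    num_return_sequences=3,  # --num_returned_sequences
)
for seq in outputs:
    print(tokenizer.decode(seq, skip_special_tokens=True))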
