| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """ |
| | Example command with bag of words: |
| | python examples/run_pplm.py -B space --cond_text "The president" --length 100 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.01 --window_length 5 --kl_scale 0.01 --gm_scale 0.95 |
| | |
| | Example command with discriminator: |
| | python examples/run_pplm.py -D sentiment --class_label 3 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95 |
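| | |
| | Example command with a locally trained PerSoothe discriminator (illustrative values only; the classifier-head path below is a placeholder and must point to a real checkpoint): |
| | python examples/run_pplm.py -D 3_PerSoothe_med --fp /path/to/3_PerSoothe_classifier_head.pt --class_label 0 --cond_text "I feel awful" --length 50 --num_iterations 10 --stepsize 0.03 --kl_scale 0.01 --gm_scale 0.95 --sample --stop_eot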
| | """ |
| |
|
| | import argparse |
| | import json |
| | from operator import add |
| | from typing import List, Optional, Tuple, Union |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | from torch.autograd import Variable |
| | from tqdm import trange |
| | from transformers import GPT2Tokenizer |
| | from transformers.file_utils import cached_path |
| | from transformers.modeling_gpt2 import GPT2LMHeadModel |
| |
|
| | from pplm_classification_head import ClassificationHead |
| |
|
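| | # NLTK corpora (dictionary words, names, stopwords) are used by get_perplexity() to score only content words |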
| | import nltk |
| | nltk.download('words') |
| | nltk.download('stopwords') |
| | nltk.download('names') |
| | import nltk.corpus as corpus |
| | from nltk.corpus import words as words_corpus |
| |
|
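| | # PPLM loss types plus numerical constants: SMALL_CONST guards against log(0) and division by zero, BIG_CONST masks filtered logits |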
| | PPLM_BOW = 1 |
| | PPLM_DISCRIM = 2 |
| | PPLM_BOW_DISCRIM = 3 |
| | SMALL_CONST = 1e-15 |
| | BIG_CONST = 1e10 |
| |
|
| | QUIET = 0 |
| | REGULAR = 1 |
| | VERBOSE = 2 |
| | VERY_VERBOSE = 3 |
| | VERBOSITY_LEVELS = { |
| | 'quiet': QUIET, |
| | 'regular': REGULAR, |
| | 'verbose': VERBOSE, |
| | 'very_verbose': VERY_VERBOSE, |
| | } |
| |
|
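| | # Pre-built bag-of-words lists hosted by Hugging Face; --bag_of_words/-B accepts one of these keys or a local filepath |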
| | BAG_OF_WORDS_ARCHIVE_MAP = { |
| | 'legal': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/legal.txt", |
| | 'military': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/military.txt", |
| | 'monsters': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/monsters.txt", |
| | 'politics': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/politics.txt", |
| | 'positive_words': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/positive_words.txt", |
| | 'religion': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/religion.txt", |
| | 'science': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/science.txt", |
| | 'space': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/space.txt", |
| | 'technology': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/technology.txt", |
| | } |
| |
|
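| | # Attribute-classifier heads: entries with "url" are downloaded via cached_path, entries with "path" load local checkpoints, and entries with neither expect a checkpoint passed through --fp |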
| | DISCRIMINATOR_MODELS_PARAMS = { |
| | "clickbait": { |
| | "url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/clickbait_classifier_head.pt", |
| | "class_size": 2, |
| | "embed_size": 1024, |
| | "class_vocab": {"non_clickbait": 0, "clickbait": 1}, |
| | "default_class": 1, |
| | "pretrained_model": "gpt2-medium", |
| | }, |
| | "sentiment": { |
| | "url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/SST_classifier_head.pt", |
| | "class_size": 5, |
| | "embed_size": 1024, |
| | "class_vocab": {"very_positive": 2, "very_negative": 3}, |
| | "default_class": 3, |
| | "pretrained_model": "gpt2-medium", |
| | }, |
| | "3_PerSoothe": { |
| | "path": "/content/drive/Shareddrives/COS_IW04_ZL/COSIW04/Discriminators/3_class_opt_lowlr_medgpt/3_PerSoothe_classifier_head_epoch_10.pt", |
| | "class_size": 3, |
| | "embed_size": 1024, |
| | "class_vocab": {"soothes": 0, "neutral": 1, "worsens": 2}, |
| | "default_class": 2, |
| | "pretrained_model": "microsoft/DialoGPT-medium", |
| | }, |
| | "3_PerSoothe_eot": { |
| | "path": "/content/drive/Shareddrives/COS_IW04_ZL/COSIW04/Discriminators/3_class_opt_eot_lowlr_medgpt/3_PerSoothe_classifier_head_epoch_10.pt", |
| | "class_size": 3, |
| | "embed_size": 1024, |
| | "class_vocab": {"soothes": 0, "neutral": 1, "worsens": 2}, |
| | "default_class": 2, |
| | "pretrained_model": "microsoft/DialoGPT-medium", |
| | }, |
| | "3_PerSoothe_lrg": { |
| | "class_size": 3, |
| | "embed_size": 1280, |
| | "class_vocab": {"soothes": 0, "neutral": 1, "worsens": 2}, |
| | "default_class": 2, |
| | "pretrained_model": "microsoft/DialoGPT-large", |
| | }, |
| | "3_PerSoothe_med": { |
| | "class_size": 3, |
| | "embed_size": 1024, |
| | "class_vocab": {"soothes": 0, "neutral": 1, "worsens": 2}, |
| | "default_class": 2, |
| | "pretrained_model": "microsoft/DialoGPT-medium", |
| | }, |
| | "2_PerSoothe_lrg": { |
| | "class_size": 2, |
| | "embed_size": 1280, |
| | "class_vocab": {"soothes": 0, "neutral": 1}, |
| | "default_class": 2, |
| | "pretrained_model": "microsoft/DialoGPT-large", |
| | }, |
| | "2_PerSoothe_med": { |
| | "class_size": 2, |
| | "embed_size": 1024, |
| | "class_vocab": {"soothes": 0, "neutral": 1}, |
| | "default_class": 2, |
| | "pretrained_model": "microsoft/DialoGPT-medium", |
| | }, |
| | } |
| |
|
| |
|
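| | # Legacy helper: torch.autograd.Variable and its volatile flag are deprecated in modern PyTorch; a plain tensor with requires_grad set would behave the same |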
| | def to_var(x, requires_grad=False, volatile=False, device='cuda'): |
| | if torch.cuda.is_available() and device == 'cuda': |
| | x = x.cuda() |
| | elif device != 'cuda': |
| | x = x.to(device) |
| | return Variable(x, requires_grad=requires_grad, volatile=volatile) |
| |
|
| |
|
| | def top_k_filter(logits, k, probs=False): |
| | """ |
| | Masks everything but the k top entries as -infinity (-1e10). |
| | Used to mask logits such that e^-infinity -> 0 won't contribute to the |
| | sum of the denominator. |
| | """ |
| | if k == 0: |
| | return logits |
| | else: |
| | values = torch.topk(logits, k)[0] |
| | batch_mins = values[:, -1].view(-1, 1).expand_as(logits) |
| | if probs: |
| | return torch.where(logits < batch_mins, |
| | torch.ones_like(logits) * 0.0, logits) |
| | return torch.where(logits < batch_mins, |
| | torch.ones_like(logits) * -BIG_CONST, |
| | logits) |
| |
|
| |
|
| | def perturb_past( |
| | past, |
| | model, |
| | last, |
| | unpert_past=None, |
| | unpert_logits=None, |
| | accumulated_hidden=None, |
| | grad_norms=None, |
| | stepsize=0.01, |
| | one_hot_bows_vectors=None, |
| | classifier=None, |
| | class_label=None, |
| | loss_type=0, |
| | num_iterations=3, |
| | horizon_length=1, |
| | window_length=0, |
| | decay=False, |
| | gamma=1.5, |
| | kl_scale=0.01, |
| | device='cuda', |
| | verbosity_level=REGULAR |
| | ): |
| | |
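| | # One perturbation (delta-H) per layer of the past key/value cache, accumulated across iterations and initialized to zero |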
| | grad_accumulator = [ |
| | (np.zeros(p.shape).astype("float32")) |
| | for p in past |
| | ] |
| |
|
| | if accumulated_hidden is None: |
| | accumulated_hidden = 0 |
| |
|
| | if decay: |
| | decay_mask = torch.arange( |
| | 0., |
| | 1.0 + SMALL_CONST, |
| | 1.0 / (window_length) |
| | )[1:] |
| | else: |
| | decay_mask = 1.0 |
| |
|
| | |
| | |
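| | # Build a mask so that only window_length positions of the past receive gradient updates (with optional linear decay weights); window_length == 0 perturbs the whole past |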
| | _, _, _, curr_length, _ = past[0].shape |
| |
|
| | if curr_length > window_length and window_length > 0: |
| | ones_key_val_shape = ( |
| | tuple(past[0].shape[:-2]) |
| | + tuple([window_length]) |
| | + tuple(past[0].shape[-1:]) |
| | ) |
| |
|
| | zeros_key_val_shape = ( |
| | tuple(past[0].shape[:-2]) |
| | + tuple([curr_length - window_length]) |
| | + tuple(past[0].shape[-1:]) |
| | ) |
| |
|
| | ones_mask = torch.ones(ones_key_val_shape) |
| | ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3) |
| | ones_mask = ones_mask.permute(0, 1, 2, 4, 3) |
| |
|
| | window_mask = torch.cat( |
| | (ones_mask, torch.zeros(zeros_key_val_shape)), |
| | dim=-2 |
| | ).to(device) |
| | else: |
| | window_mask = torch.ones_like(past[0]).to(device) |
| |
|
| | |
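| | # Gradient-ascent loop: apply the current perturbation, compute the attribute and KL losses, and accumulate normalized gradients for num_iterations steps |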
| | loss_per_iter = [] |
| | new_accumulated_hidden = None |
| | for i in range(num_iterations): |
| | if verbosity_level >= VERBOSE: |
| | print("Iteration ", i + 1) |
| | curr_perturbation = [ |
| | to_var(torch.from_numpy(p_), requires_grad=True, device=device) |
| | for p_ in grad_accumulator |
| | ] |
| |
|
| | |
| | perturbed_past = list(map(add, past, curr_perturbation)) |
| | _, _, _, curr_length, _ = curr_perturbation[0].shape |
| | all_logits, _, all_hidden = model(last, past_key_values=perturbed_past) |
| | hidden = all_hidden[-1] |
| | new_accumulated_hidden = accumulated_hidden + torch.sum( |
| | hidden, |
| | dim=1 |
| | ).detach() |
| | |
| | logits = all_logits[:, -1, :] |
| | probs = F.softmax(logits, dim=-1) |
| |
|
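| | # BoW loss: -log of the total probability mass on the bag words; discriminator loss: cross-entropy of the classifier prediction from the averaged hidden states against class_label |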
| | loss = 0.0 |
| | loss_list = [] |
| | if loss_type == PPLM_BOW or loss_type == PPLM_BOW_DISCRIM: |
| | for one_hot_bow in one_hot_bows_vectors: |
| | bow_logits = torch.mm(probs, torch.t(one_hot_bow)) |
| | bow_loss = -torch.log(torch.sum(bow_logits)) |
| | loss += bow_loss |
| | loss_list.append(bow_loss) |
| | if verbosity_level >= VERY_VERBOSE: |
| | print(" pplm_bow_loss:", loss.data.cpu().numpy()) |
| |
|
| | if loss_type == PPLM_DISCRIM or loss_type == PPLM_BOW_DISCRIM: |
| | ce_loss = torch.nn.CrossEntropyLoss() |
| | |
| | curr_unpert_past = unpert_past |
| | curr_probs = torch.unsqueeze(probs, dim=1) |
| | wte = model.resize_token_embeddings() |
| | for _ in range(horizon_length): |
| | inputs_embeds = torch.matmul(curr_probs, wte.weight.data) |
| | _, curr_unpert_past, curr_all_hidden = model( |
| | past_key_values=curr_unpert_past, |
| | inputs_embeds=inputs_embeds |
| | ) |
| | curr_hidden = curr_all_hidden[-1] |
| | new_accumulated_hidden = new_accumulated_hidden + torch.sum( |
| | curr_hidden, dim=1) |
| |
|
| | prediction = classifier(new_accumulated_hidden / |
| | (curr_length + 1 + horizon_length)) |
| |
|
| | label = torch.tensor(prediction.shape[0] * [class_label], |
| | device=device, |
| | dtype=torch.long) |
| | discrim_loss = ce_loss(prediction, label) |
| | if verbosity_level >= VERY_VERBOSE: |
| | print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy()) |
| | loss += discrim_loss |
| | loss_list.append(discrim_loss) |
| |
|
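| | # KL penalty keeps the perturbed next-token distribution close to the unperturbed one (scaled by kl_scale) |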
| | kl_loss = 0.0 |
| | if kl_scale > 0.0: |
| | unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1) |
| | unpert_probs = ( |
| | unpert_probs + SMALL_CONST * |
| | (unpert_probs <= SMALL_CONST).float().to(device).detach() |
| | ) |
| | correction = SMALL_CONST * (probs <= SMALL_CONST).float().to( |
| | device).detach() |
| | corrected_probs = probs + correction.detach() |
| | kl_loss = kl_scale * ( |
| | (corrected_probs * (corrected_probs / unpert_probs).log()).sum() |
| | ) |
| | if verbosity_level >= VERY_VERBOSE: |
| | print(' kl_loss', kl_loss.data.cpu().numpy()) |
| | loss += kl_loss |
| |
|
| | loss_per_iter.append(loss.data.cpu().numpy()) |
| | if verbosity_level >= VERBOSE: |
| | print(' pplm_loss', (loss - kl_loss).data.cpu().numpy()) |
| |
|
| | |
| | loss.backward(retain_graph=True) |
| |
|
| | |
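| | # Normalize each gradient by its windowed norm raised to gamma and step by -stepsize; for the BoW-only objective the norms are tracked as a running max across steps |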
| | if grad_norms is not None and loss_type == PPLM_BOW: |
| | grad_norms = [ |
| | torch.max(grad_norms[index], torch.norm(p_.grad * window_mask)) |
| | for index, p_ in enumerate(curr_perturbation) |
| | ] |
| | else: |
| | grad_norms = [ |
| | (torch.norm(p_.grad * window_mask) + SMALL_CONST) |
| | for index, p_ in enumerate(curr_perturbation) |
| | ] |
| |
|
| | |
| | grad = [ |
| | -stepsize * |
| | (p_.grad * window_mask / grad_norms[ |
| | index] ** gamma).data.cpu().numpy() |
| | for index, p_ in enumerate(curr_perturbation) |
| | ] |
| |
|
| | |
| | grad_accumulator = list(map(add, grad, grad_accumulator)) |
| |
|
| | |
| | for p_ in curr_perturbation: |
| | p_.grad.data.zero_() |
| |
|
| | |
| | new_past = [] |
| | for p_ in past: |
| | new_past.append(p_.detach()) |
| | past = new_past |
| |
|
| | |
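| | # Add the accumulated perturbation to the detached past to obtain the perturbed past |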
| | grad_accumulator = [ |
| | to_var(torch.from_numpy(p_), requires_grad=True, device=device) |
| | for p_ in grad_accumulator |
| | ] |
| | pert_past = list(map(add, past, grad_accumulator)) |
| |
|
| | return pert_past, new_accumulated_hidden, grad_norms, loss_per_iter |
| |
|
| |
|
| | def get_classifier( |
| | name: Optional[str], |
| | class_label: Union[str, int], |
| | device: str, |
| | verbosity_level: int = REGULAR, |
| | fp: Optional[str] = None, |
| | is_deep: bool = False, |
| | is_deeper: bool = False |
| | ) -> Tuple[Optional[ClassificationHead], Optional[int]]: |
| | if name is None: |
| | return None, None |
| |
|
| | params = DISCRIMINATOR_MODELS_PARAMS[name] |
| | classifier = ClassificationHead( |
| | class_size=params['class_size'], |
| | embed_size=params['embed_size'], |
| | is_deep=is_deep, |
| | is_deeper=is_deeper |
| | ).to(device) |
| | if "url" in params: |
| | resolved_archive_file = cached_path(params["url"]) |
| | elif "path" in params: |
| | resolved_archive_file = params["path"] |
| | elif fp: |
| | resolved_archive_file = fp |
| | else: |
| | raise ValueError("Either url or path have to be specified " |
| | "in the discriminator model parameters") |
| | classifier.load_state_dict( |
| | torch.load(resolved_archive_file, map_location=device)) |
| | classifier.eval() |
| |
|
| | if isinstance(class_label, str): |
| | if class_label in params["class_vocab"]: |
| | label_id = params["class_vocab"][class_label] |
| | else: |
| | label_id = params["default_class"] |
| | if verbosity_level >= REGULAR: |
| | print("class_label {} not in class_vocab".format(class_label)) |
| | print("available values are: {}".format(params["class_vocab"])) |
| | print("using default class {}".format(label_id)) |
| |
|
| | elif isinstance(class_label, int): |
| | if class_label in set(params["class_vocab"].values()): |
| | label_id = class_label |
| | else: |
| | label_id = params["default_class"] |
| | if verbosity_level >= REGULAR: |
| | print("class_label {} not in class_vocab".format(class_label)) |
| | print("available values are: {}".format(params["class_vocab"])) |
| | print("using default class {}".format(label_id)) |
| |
|
| | else: |
| | label_id = params["default_class"] |
| |
|
| | return classifier, label_id |
| |
|
| |
|
| | def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) -> \ |
| | List[List[List[int]]]: |
| | bow_indices = [] |
| | for id_or_path in bag_of_words_ids_or_paths: |
| | if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP: |
| | filepath = cached_path(BAG_OF_WORDS_ARCHIVE_MAP[id_or_path]) |
| | else: |
| | filepath = id_or_path |
| | with open(filepath, "r") as f: |
| | words = f.read().strip().split("\n") |
| | bow_indices.append( |
| | [tokenizer.encode(word.strip(), |
| | add_prefix_space=True, |
| | add_special_tokens=False) |
| | for word in words]) |
| | return bow_indices |
| |
|
| |
|
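| | # Only single-token bag words are kept; each becomes a one-hot row over the tokenizer vocabulary |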
| | def build_bows_one_hot_vectors(bow_indices, tokenizer, device='cuda'): |
| | if bow_indices is None: |
| | return None |
| |
|
| | one_hot_bows_vectors = [] |
| | for single_bow in bow_indices: |
| | single_bow = list(filter(lambda x: len(x) <= 1, single_bow)) |
| | single_bow = torch.tensor(single_bow).to(device) |
| | num_words = single_bow.shape[0] |
| | one_hot_bow = torch.zeros(num_words, tokenizer.vocab_size).to(device) |
| | one_hot_bow.scatter_(1, single_bow, 1) |
| | one_hot_bows_vectors.append(one_hot_bow) |
| | return one_hot_bows_vectors |
| |
|
| |
|
| | def full_text_generation( |
| | model, |
| | tokenizer, |
| | context=None, |
| | num_samples=1, |
| | device="cuda", |
| | bag_of_words=None, |
| | discrim=None, |
| | class_label=None, |
| | length=100, |
| | stepsize=0.02, |
| | temperature=1.0, |
| | top_k=10, |
| | sample=True, |
| | num_iterations=3, |
| | grad_length=10000, |
| | horizon_length=1, |
| | window_length=0, |
| | decay=False, |
| | gamma=1.5, |
| | gm_scale=0.9, |
| | kl_scale=0.01, |
| | verbosity_level=REGULAR, |
| | fp=None, |
| | is_deep=False, |
| | is_deeper=False, |
| | stop_eot=False, |
| | **kwargs |
| | ): |
| | classifier, class_id = get_classifier( |
| | discrim, |
| | class_label, |
| | device, |
| | REGULAR, |
| | fp, |
| | is_deep, |
| | is_deeper |
| | ) |
| |
|
| | bow_indices = [] |
| | if bag_of_words: |
| | bow_indices = get_bag_of_words_indices(bag_of_words.split(";"), |
| | tokenizer) |
| |
|
| | if bag_of_words and classifier: |
| | loss_type = PPLM_BOW_DISCRIM |
| | if verbosity_level >= REGULAR: |
| | print("Both PPLM-BoW and PPLM-Discrim are on. " |
| | "This is not optimized.") |
| |
|
| | elif bag_of_words: |
| | loss_type = PPLM_BOW |
| | if verbosity_level >= REGULAR: |
| | print("Using PPLM-BoW") |
| |
|
| | elif classifier is not None: |
| | loss_type = PPLM_DISCRIM |
| | if verbosity_level >= REGULAR: |
| | print("Using PPLM-Discrim") |
| |
|
| | else: |
| | raise Exception("Specify either a bag of words or a discriminator") |
| |
|
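| | # Generate one unperturbed reference sample first, then num_samples perturbed samples |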
| | unpert_gen_tok_text, _, _, _ = generate_text_pplm( |
| | model=model, |
| | tokenizer=tokenizer, |
| | context=context, |
| | device=device, |
| | length=length, |
| | sample=sample, |
| | perturb=False, |
| | verbosity_level=verbosity_level, |
| | stop_eot=stop_eot |
| | ) |
| | if device == 'cuda': |
| | torch.cuda.empty_cache() |
| |
|
| | pert_gen_tok_texts = [] |
| | discrim_losses = [] |
| | losses_in_time = [] |
| | perplexities = [] |
| |
|
| | for i in range(num_samples): |
| | pert_gen_tok_text, discrim_loss, loss_in_time, perplexity = generate_text_pplm( |
| | model=model, |
| | tokenizer=tokenizer, |
| | context=context, |
| | device=device, |
| | perturb=True, |
| | bow_indices=bow_indices, |
| | classifier=classifier, |
| | class_label=class_id, |
| | loss_type=loss_type, |
| | length=length, |
| | stepsize=stepsize, |
| | temperature=temperature, |
| | top_k=top_k, |
| | sample=sample, |
| | num_iterations=num_iterations, |
| | grad_length=grad_length, |
| | horizon_length=horizon_length, |
| | window_length=window_length, |
| | decay=decay, |
| | gamma=gamma, |
| | gm_scale=gm_scale, |
| | kl_scale=kl_scale, |
| | verbosity_level=verbosity_level, |
| | stop_eot=stop_eot |
| | ) |
| | pert_gen_tok_texts.append(pert_gen_tok_text) |
| | if classifier is not None: |
| | discrim_losses.append(discrim_loss.data.cpu().numpy()) |
| | losses_in_time.append(loss_in_time) |
| | perplexities.append(perplexity) |
| |
|
| | if device == 'cuda': |
| | torch.cuda.empty_cache() |
| |
|
| | return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time, perplexities |
| |
|
| |
|
| | def generate_text_pplm( |
| | model, |
| | tokenizer, |
| | context=None, |
| | past=None, |
| | device="cuda", |
| | perturb=True, |
| | bow_indices=None, |
| | classifier=None, |
| | class_label=None, |
| | loss_type=0, |
| | length=100, |
| | stepsize=0.02, |
| | temperature=1.0, |
| | top_k=10, |
| | sample=True, |
| | num_iterations=3, |
| | grad_length=10000, |
| | horizon_length=1, |
| | window_length=0, |
| | decay=False, |
| | gamma=1.5, |
| | gm_scale=0.9, |
| | kl_scale=0.01, |
| | verbosity_level=REGULAR, |
| | stop_eot=False |
| | ): |
| | output_so_far = None |
| | if context: |
| | context_t = torch.tensor(context, device=device, dtype=torch.long) |
| | while len(context_t.shape) < 2: |
| | context_t = context_t.unsqueeze(0) |
| | output_so_far = context_t |
| |
|
| | |
| | one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer, |
| | device) |
| |
|
| | grad_norms = None |
| | last = None |
| | unpert_discrim_loss = 0 |
| | loss_in_time = [] |
| |
|
| | if verbosity_level >= VERBOSE: |
| | range_func = trange(length, ascii=True) |
| | else: |
| | range_func = range(length) |
| | |
| | pert_total_prob = 1 |
| | pert_times = 0 |
| | for i in range_func: |
| |
|
| | |
| | |
| |
|
| | |
| | if past is None and output_so_far is not None: |
| | last = output_so_far[:, -1:] |
| | if output_so_far.shape[1] > 1: |
| | _, past, _ = model(output_so_far[:, :-1]) |
| | |
| | unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far) |
| | unpert_last_hidden = unpert_all_hidden[-1] |
| |
|
| | |
| | if i >= grad_length: |
| | current_stepsize = stepsize * 0 |
| | else: |
| | current_stepsize = stepsize |
| |
|
| | |
| | if not perturb or num_iterations == 0: |
| | pert_past = past |
| |
|
| | else: |
| | accumulated_hidden = unpert_last_hidden[:, :-1, :] |
| | accumulated_hidden = torch.sum(accumulated_hidden, dim=1) |
| |
|
| | if past is not None: |
| | pert_past, _, grad_norms, loss_this_iter = perturb_past( |
| | past, |
| | model, |
| | last, |
| | unpert_past=unpert_past, |
| | unpert_logits=unpert_logits, |
| | accumulated_hidden=accumulated_hidden, |
| | grad_norms=grad_norms, |
| | stepsize=current_stepsize, |
| | one_hot_bows_vectors=one_hot_bows_vectors, |
| | classifier=classifier, |
| | class_label=class_label, |
| | loss_type=loss_type, |
| | num_iterations=num_iterations, |
| | horizon_length=horizon_length, |
| | window_length=window_length, |
| | decay=decay, |
| | gamma=gamma, |
| | kl_scale=kl_scale, |
| | device=device, |
| | verbosity_level=verbosity_level |
| | ) |
| | loss_in_time.append(loss_this_iter) |
| | else: |
| | pert_past = past |
| |
|
| | pert_logits, past, pert_all_hidden = model(last, past_key_values=pert_past) |
| | pert_logits = pert_logits[:, -1, :] / temperature |
| | pert_probs = F.softmax(pert_logits, dim=-1) |
| |
|
| | if classifier is not None: |
| | ce_loss = torch.nn.CrossEntropyLoss() |
| | prediction = classifier(torch.mean(unpert_last_hidden, dim=1)) |
| | label = torch.tensor([class_label], device=device, |
| | dtype=torch.long) |
| | unpert_discrim_loss = ce_loss(prediction, label) |
| | if verbosity_level >= VERBOSE: |
| | print( |
| | "unperturbed discrim loss", |
| | unpert_discrim_loss.data.cpu().numpy() |
| | ) |
| | else: |
| | unpert_discrim_loss = 0 |
| |
|
| | |
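| | # Fuse the perturbed and unperturbed distributions with a geometric mean controlled by gm_scale, then apply top-k filtering and renormalize |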
| | if perturb: |
| |
|
| | unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1) |
| |
|
| | pert_probs = ((pert_probs ** gm_scale) * ( |
| | unpert_probs ** (1 - gm_scale))) |
| | pert_probs = top_k_filter(pert_probs, k=top_k, |
| | probs=True) |
| |
|
| | |
| | if torch.sum(pert_probs) <= 1: |
| | pert_probs = pert_probs / torch.sum(pert_probs) |
| |
|
| | else: |
| | pert_logits = top_k_filter(pert_logits, k=top_k) |
| | pert_probs = F.softmax(pert_logits, dim=-1) |
| |
|
| | |
| | if sample: |
| | last = torch.multinomial(pert_probs, num_samples=1) |
| | pert_total_prob = pert_total_prob * pert_probs[0][last[0][0]] |
| | else: |
| | _, last = torch.topk(pert_probs, k=1, dim=-1) |
| |
|
| | |
| | output_so_far = ( |
| | last if output_so_far is None |
| | else torch.cat((output_so_far, last), dim=1) |
| | ) |
| | if verbosity_level >= REGULAR: |
| | print(tokenizer.decode(output_so_far.tolist()[0])) |
| | pert_times += 1 |
| | if last[0][0] == 50256 and stop_eot: |
| | break |
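| | # Perplexity of the generated continuation: geometric mean of 1/p(token) over the sampled tokens (only meaningful when sample=True) |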
| | perplexity = (1/pert_total_prob)**(1/pert_times) |
| | return output_so_far, unpert_discrim_loss, loss_in_time, perplexity |
| |
|
| | def get_perplexity( |
| | model, |
| | tokenizer, |
| | past=None, |
| | device="cuda", |
| | perturb=True, |
| | bow_indices=None, |
| | classifier=None, |
| | class_label=None, |
| | loss_type=0, |
| | length=100, |
| | stepsize=0.02, |
| | temperature=1.0, |
| | top_k=10, |
| | sample=True, |
| | num_iterations=3, |
| | grad_length=10000, |
| | horizon_length=1, |
| | window_length=0, |
| | decay=False, |
| | gamma=1.5, |
| | gm_scale=0.9, |
| | kl_scale=0.01, |
| | verbosity_level=REGULAR, |
| | stop_eot=False, |
| | test_text=None |
| | ): |
| | if test_text is None: |
| | print("No text to test") |
| | return |
| | test_text = torch.tensor(test_text, device=device, dtype=torch.long) |
| | while len(test_text.shape) < 2: |
| | test_text = test_text.unsqueeze(0) |
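| | # Score only the final utterance: the tokens between the second-to-last and the last <|endoftext|> token (id 50256) |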
| | eos_pos = (test_text == 50256).nonzero(as_tuple=True)[1] |
| | start = int(eos_pos[-2] + 1) |
| | end = int(eos_pos[-1]) |
| | pert_total_prob = 1 |
| | pert_times = 0 |
| | error_occurred = False |
| |
|
| | |
| | one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer, |
| | device) |
| |
|
| | grad_norms = None |
| | last = None |
| | unpert_discrim_loss = 0 |
| | loss_in_time = [] |
| |
|
| | for i in range(start, end): |
| | output_so_far = test_text[:, :i] |
| | cur_word = str(tokenizer.decode([test_text[0][i]])).lower().strip() |
| | last_word = str(tokenizer.decode([test_text[0][i-1]])).lower().strip() |
| | |
| | |
| | |
| |
|
| | |
| | if past is None and output_so_far is not None: |
| | last = output_so_far[:, -1:] |
| | _, past, _ = model(output_so_far[:, :-1]) |
| |
|
| | unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far) |
| | unpert_last_hidden = unpert_all_hidden[-1] |
| |
|
| | |
| | if i >= grad_length: |
| | current_stepsize = stepsize * 0 |
| | else: |
| | current_stepsize = stepsize |
| |
|
| | |
| | if not perturb or num_iterations == 0: |
| | pert_past = past |
| |
|
| | else: |
| | accumulated_hidden = unpert_last_hidden[:, :-1, :] |
| | accumulated_hidden = torch.sum(accumulated_hidden, dim=1) |
| |
|
| | if past is not None: |
| | pert_past, _, grad_norms, loss_this_iter = perturb_past( |
| | past, |
| | model, |
| | last, |
| | unpert_past=unpert_past, |
| | unpert_logits=unpert_logits, |
| | accumulated_hidden=accumulated_hidden, |
| | grad_norms=grad_norms, |
| | stepsize=current_stepsize, |
| | one_hot_bows_vectors=one_hot_bows_vectors, |
| | classifier=classifier, |
| | class_label=class_label, |
| | loss_type=loss_type, |
| | num_iterations=num_iterations, |
| | horizon_length=horizon_length, |
| | window_length=window_length, |
| | decay=decay, |
| | gamma=gamma, |
| | kl_scale=kl_scale, |
| | device=device, |
| | verbosity_level=verbosity_level |
| | ) |
| | loss_in_time.append(loss_this_iter) |
| | else: |
| | pert_past = past |
| |
|
| | pert_logits, past, pert_all_hidden = model(last, past_key_values=pert_past) |
| | pert_logits = pert_logits[:, -1, :] / temperature |
| | pert_probs = F.softmax(pert_logits, dim=-1) |
| |
|
| | if classifier is not None: |
| | ce_loss = torch.nn.CrossEntropyLoss() |
| | prediction = classifier(torch.mean(unpert_last_hidden, dim=1)) |
| | label = torch.tensor([class_label], device=device, |
| | dtype=torch.long) |
| | unpert_discrim_loss = ce_loss(prediction, label) |
| | if verbosity_level >= VERBOSE: |
| | print( |
| | "unperturbed discrim loss", |
| | unpert_discrim_loss.data.cpu().numpy() |
| | ) |
| | else: |
| | unpert_discrim_loss = 0 |
| |
|
| | |
| | if perturb: |
| |
|
| | unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1) |
| |
|
| | pert_probs = ((pert_probs ** gm_scale) * ( |
| | unpert_probs ** (1 - gm_scale))) |
| | pert_probs = top_k_filter(pert_probs, k=top_k, |
| | probs=True) |
| |
|
| | |
| | if torch.sum(pert_probs) <= 1: |
| | pert_probs = pert_probs / torch.sum(pert_probs) |
| |
|
| | else: |
| | pert_logits = top_k_filter(pert_logits, k=top_k) |
| | pert_probs = F.softmax(pert_logits, dim=-1) |
| |
|
| | |
| | if sample: |
| | last = torch.multinomial(pert_probs, num_samples=1) |
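| | # Only score content words: skip tokens that are not dictionary words or that are names or stopwords; the probability accumulated is that of the reference token, not the sampled one |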
| | if (cur_word not in words_corpus.words()) or cur_word in corpus.names.words() or cur_word in corpus.stopwords.words(): |
| | pass |
| | else: |
| | if pert_probs[0][test_text[0][i]] != 0: |
| | pert_total_prob = pert_total_prob * pert_probs[0][test_text[0][i]] |
| | pert_times += 1 |
| | else: |
| | error_occurred = True |
| | else: |
| | _, last = torch.topk(pert_probs, k=1, dim=-1) |
| |
|
| | |
| | |
| | output_so_far = ( |
| | last if output_so_far is None |
| | else torch.cat((output_so_far, last), dim=1) |
| | ) |
| | if last[0][0] == 50256 and stop_eot: |
| | break |
| | if pert_times != 0: |
| | perplexity = (1/pert_total_prob)**(1/pert_times) |
| | else: |
| | perplexity = -2 if error_occurred else -1 |
| | return perplexity |
| |
|
| |
|
| | def set_generic_model_params(discrim_weights, discrim_meta): |
| | if discrim_weights is None: |
| | raise ValueError('When using a generic discriminator, ' |
| | 'discrim_weights need to be specified') |
| | if discrim_meta is None: |
| | raise ValueError('When using a generic discriminator, ' |
| | 'discrim_meta need to be specified') |
| |
|
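| | # A minimal sketch of the expected meta JSON, mirroring the entries above (illustrative values): |
| | # {"class_size": 2, "embed_size": 1024, "class_vocab": {"neg": 0, "pos": 1}, "default_class": 0, "pretrained_model": "gpt2-medium"} |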
| | with open(discrim_meta, 'r') as discrim_meta_file: |
| | meta = json.load(discrim_meta_file) |
| | meta['path'] = discrim_weights |
| | DISCRIMINATOR_MODELS_PARAMS['generic'] = meta |
| |
|
| |
|
| | def run_pplm_example( |
| | pretrained_model="gpt2-medium", |
| | cond_text="", |
| | uncond=False, |
| | num_samples=1, |
| | bag_of_words=None, |
| | discrim=None, |
| | discrim_weights=None, |
| | discrim_meta=None, |
| | class_label=-1, |
| | length=100, |
| | stepsize=0.02, |
| | temperature=1.0, |
| | top_k=10, |
| | sample=True, |
| | num_iterations=3, |
| | grad_length=10000, |
| | horizon_length=1, |
| | window_length=0, |
| | decay=False, |
| | gamma=1.5, |
| | gm_scale=0.9, |
| | kl_scale=0.01, |
| | seed=0, |
| | no_cuda=False, |
| | colorama=False, |
| | verbosity='regular', |
| | fp=None, |
| | model_fp=None, |
| | calc_perplexity=False, |
| | is_deep=False, |
| | is_deeper=False, |
| | stop_eot=False |
| | ): |
| | |
| | torch.manual_seed(seed) |
| | np.random.seed(seed) |
| |
|
| | |
| | verbosity_level = VERBOSITY_LEVELS.get(verbosity.lower(), REGULAR) |
| |
|
| | |
| | device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu" |
| |
|
| | if discrim == 'generic': |
| | set_generic_model_params(discrim_weights, discrim_meta) |
| |
|
| | if discrim is not None: |
| | discriminator_pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim][ |
| | "pretrained_model" |
| | ] |
| | if pretrained_model != discriminator_pretrained_model: |
| | pretrained_model = discriminator_pretrained_model |
| | if verbosity_level >= REGULAR: |
| | print("discrim = {}, pretrained_model set " |
| | "to discriminator's = {}".format(discrim, pretrained_model)) |
| |
|
| | |
| | model = GPT2LMHeadModel.from_pretrained( |
| | pretrained_model, |
| | output_hidden_states=True |
| | ) |
| | if model_fp: |
| | try: |
| | model.load_state_dict(torch.load(model_fp, map_location=device)) |
| | except Exception as e: |
| | print("Can't load local model:", e) |
| | model.to(device) |
| | model.eval() |
| |
|
| | |
| | tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model) |
| |
|
| | |
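| | # Freeze the language model weights; only the past key/value activations are perturbed |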
| | for param in model.parameters(): |
| | param.requires_grad = False |
| |
|
| | |
| | if uncond: |
| | tokenized_cond_text = tokenizer.encode( |
| | [tokenizer.bos_token], |
| | add_special_tokens=False |
| | ) |
| | else: |
| | raw_text = cond_text |
| | while not raw_text: |
| | print("Did you forget to add `--cond_text`? ") |
| | raw_text = input("Model prompt >>> ") |
| | tokenized_cond_text = tokenizer.encode( |
| | tokenizer.bos_token + raw_text, |
| | add_special_tokens=False |
| | ) |
| | |
| | print("= Prefix of sentence =") |
| | print(tokenizer.decode(tokenized_cond_text)) |
| | print() |
| |
|
| | |
| |
|
| | |
| | |
| | unpert_gen_tok_text, pert_gen_tok_texts, _, _, perplexities = full_text_generation( |
| | model=model, |
| | tokenizer=tokenizer, |
| | context=tokenized_cond_text, |
| | device=device, |
| | num_samples=num_samples, |
| | bag_of_words=bag_of_words, |
| | discrim=discrim, |
| | class_label=class_label, |
| | length=length, |
| | stepsize=stepsize, |
| | temperature=temperature, |
| | top_k=top_k, |
| | sample=sample, |
| | num_iterations=num_iterations, |
| | grad_length=grad_length, |
| | horizon_length=horizon_length, |
| | window_length=window_length, |
| | decay=decay, |
| | gamma=gamma, |
| | gm_scale=gm_scale, |
| | kl_scale=kl_scale, |
| | verbosity_level=verbosity_level, |
| | fp=fp, |
| | is_deep=is_deep, |
| | is_deeper=is_deeper, |
| | stop_eot=stop_eot |
| | ) |
| |
|
| | |
| | unpert_gen_text = tokenizer.decode(unpert_gen_tok_text.tolist()[0]) |
| |
|
| | if verbosity_level >= REGULAR: |
| | print("=" * 80) |
| | print("= Unperturbed generated text =") |
| | print(unpert_gen_text) |
| | print() |
| |
|
| | generated_texts = [] |
| |
|
| | bow_word_ids = set() |
| | if bag_of_words and colorama: |
| | bow_indices = get_bag_of_words_indices(bag_of_words.split(";"), |
| | tokenizer) |
| | for single_bow_list in bow_indices: |
| | |
| | filtered = list(filter(lambda x: len(x) <= 1, single_bow_list)) |
| | |
| | bow_word_ids.update(w[0] for w in filtered) |
| |
|
| | |
| | for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts): |
| | try: |
| | |
| | if colorama: |
| | import colorama |
| |
|
| | pert_gen_text = '' |
| | for word_id in pert_gen_tok_text.tolist()[0]: |
| | if word_id in bow_word_ids: |
| | pert_gen_text += '{}{}{}'.format( |
| | colorama.Fore.RED, |
| | tokenizer.decode([word_id]), |
| | colorama.Style.RESET_ALL |
| | ) |
| | else: |
| | pert_gen_text += tokenizer.decode([word_id]) |
| | else: |
| | pert_gen_text = tokenizer.decode(pert_gen_tok_text.tolist()[0]) |
| |
|
| | print("= Perturbed generated text {} =".format(i + 1)) |
| | print(pert_gen_text) |
| | if calc_perplexity: |
| | print("Perplexity:", perplexities[i]) |
| | print() |
| | except Exception: |
| | pass |
| |
|
| | |
| | generated_texts.append( |
| | (tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text) |
| | ) |
| |
|
| | return |
| |
|
| |
|
| | if __name__ == '__main__': |
| | parser = argparse.ArgumentParser() |
| | parser.add_argument( |
| | "--pretrained_model", |
| | "-M", |
| | type=str, |
| | default="gpt2-medium", |
| | help="pretrained model name or path to local checkpoint", |
| | ) |
| | parser.add_argument( |
| | "--cond_text", type=str, default="The lake", |
| | help="Prefix texts to condition on" |
| | ) |
| | parser.add_argument( |
| | "--uncond", action="store_true", |
| | help="Generate from end-of-text as prefix" |
| | ) |
| | parser.add_argument( |
| | "--num_samples", |
| | type=int, |
| | default=1, |
| | help="Number of samples to generate from the modified latents", |
| | ) |
| | parser.add_argument( |
| | "--bag_of_words", |
| | "-B", |
| | type=str, |
| | default=None, |
| | help="Bags of words used for PPLM-BoW. " |
| | "Either a BOW id (see list in code) or a filepath. " |
| | "Multiple BoWs separated by ;", |
| | ) |
| | parser.add_argument( |
| | "--discrim", |
| | "-D", |
| | type=str, |
| | default=None, |
| | choices=("clickbait", "sentiment", "toxicity", "generic", "3_PerSoothe", |
| | "3_PerSoothe_eot", "3_PerSoothe_lrg", "3_PerSoothe_med", "2_PerSoothe_lrg", "2_PerSoothe_med"), |
| | help="Discriminator to use", |
| | ) |
| | parser.add_argument('--discrim_weights', type=str, default=None, |
| | help='Weights for the generic discriminator') |
| | parser.add_argument('--discrim_meta', type=str, default=None, |
| | help='Meta information for the generic discriminator') |
| | parser.add_argument( |
| | "--class_label", |
| | type=int, |
| | default=-1, |
| | help="Class label used for the discriminator", |
| | ) |
| | parser.add_argument("--length", type=int, default=100) |
| | parser.add_argument("--stepsize", type=float, default=0.02) |
| | parser.add_argument("--temperature", type=float, default=1.0) |
| | parser.add_argument("--top_k", type=int, default=10) |
| | parser.add_argument( |
| | "--sample", action="store_true", |
| | help="Generate from end-of-text as prefix" |
| | ) |
| | parser.add_argument("--num_iterations", type=int, default=3) |
| | parser.add_argument("--grad_length", type=int, default=10000) |
| | parser.add_argument( |
| | "--window_length", |
| | type=int, |
| | default=0, |
| | help="Length of past which is being optimized; " |
| | "0 corresponds to infinite window length", |
| | ) |
| | parser.add_argument( |
| | "--horizon_length", |
| | type=int, |
| | default=1, |
| | help="Length of future to optimize over", |
| | ) |
| | parser.add_argument("--decay", action="store_true", |
| | help="whether to decay or not") |
| | parser.add_argument("--gamma", type=float, default=1.5) |
| | parser.add_argument("--gm_scale", type=float, default=0.9) |
| | parser.add_argument("--kl_scale", type=float, default=0.01) |
| | parser.add_argument("--seed", type=int, default=0) |
| | parser.add_argument("--no_cuda", action="store_true", help="no cuda") |
| | parser.add_argument("--colorama", action="store_true", |
| | help="colors keywords") |
| | parser.add_argument("--verbosity", type=str, default="very_verbose", |
| | choices=( |
| | "quiet", "regular", "verbose", "very_verbose"), |
| | help="verbosiry level") |
| | parser.add_argument("--fp", type=str, default="") |
| | parser.add_argument("--model_fp", type=str, default="") |
| | parser.add_argument("--calc_perplexity", action="store_true", help="calculate perplexity") |
| | parser.add_argument("--is_deep", action="store_true", |
| | help="whether to use deep classifier") |
| | parser.add_argument("--is_deeper", action="store_true", |
| | help="whether to use deep classifier") |
| | parser.add_argument("--stop_eot", action="store_true", |
| | help="whether to stop at eot token") |
| |
|
| | args = parser.parse_args() |
| | run_pplm_example(**vars(args)) |
| |
|