Commit 3017363c authored by friebolin's avatar friebolin
Browse files

Add backtranslation code

parent bafe7145
Loading
Loading
Loading
Loading

Code/backtranslate.py

0 → 100644
+97 −0
Original line number Diff line number Diff line
""" Backtranslation: Translate original sentences using Fairseq (https://github.com/facebookresearch/fairseq/blob/main/examples/translation/README.md)
to create 5 paraphrases """

import numpy as np
import pandas as pd
import json
import torch
import tqdm as notebook_tqdm
import os 

# Sampling temperature for the EN->DE nucleus-sampling step below:
# higher values produce more diverse (but noisier) paraphrases.
# temperature = 0.8  # previously used setting, kept for reference
temperature = 1.2

def load_data_set(file_name):
    """Load and return the JSON content of *file_name*.

    Args:
        file_name: Path to a text file containing a single JSON document.

    Returns:
        The deserialized JSON object (typically a list of datapoint dicts).
    """
    # json.load streams from the file handle directly instead of reading the
    # whole file into a string first; encoding is pinned so behavior does not
    # depend on the platform default.
    with open(file_name, "r", encoding="utf-8") as file:
        return json.load(file)

def load_data_sets(data_dir):
    """Load the three training splits found under *data_dir*.

    Returns:
        A 3-tuple: (semeval, companies, relocar) training data, in that order.
    """
    split_files = ("semeval_train.txt", "companies_train.txt", "relocar_train.txt")
    return tuple(
        load_data_set(os.path.join(data_dir, split)) for split in split_files
    )

# Directory containing the *_train.txt JSON data files.
data_dir = "./data"
semeval_train_data, companies_train_data, relocar_train_data = load_data_sets(data_dir)


# Load Fairseq transformers trained on WMT'19 via torch.hub.
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model', tokenizer='moses', bpe='fastbpe')
de2en = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.de-en.single_model', tokenizer='moses', bpe='fastbpe')

# Disable dropout for deterministic inference in BOTH translation directions.
en2de.eval()
de2en.eval()  # bug fix: this call was missing, so de2en ran with dropout active

# Move both models to the GPU for faster translation.
en2de.cuda()
de2en.cuda()


# Define helper functions
def extract_target(dp):
    """Return the lower-cased target phrase of a datapoint.

    Args:
        dp: Datapoint mapping with a token list under "sentence" and a
            (start, end) slice pair under "pos".

    Returns:
        The target tokens joined by single spaces, lower-cased.
    """
    span_start, span_end = dp["pos"][0], dp["pos"][1]
    target_tokens = dp["sentence"][span_start:span_end]
    return " ".join(target_tokens).lower()

def join_tokens(dp):
    """Join a token list into a single sentence string.

    Tokens are joined with spaces, then the space inserted before common
    punctuation is removed and stray backtick quote-openers are dropped.

    Args:
        dp: Iterable of token strings.

    Returns:
        The detokenized sentence as one string.
    """
    # Applied in order; each pair is (pattern, replacement).
    fixups = (
        (' ,', ','),
        (' .', '.'),
        (' !', '!'),
        (' ?', '?'),
        (' :', ':'),
        (' ;', ';'),
        (" '", "'"),
        (' "', '"'),
        ('` ', ''),
    )
    joined = " ".join(dp)
    for pattern, replacement in fixups:
        joined = joined.replace(pattern, replacement)
    return joined

# Apply backtranslation to each training set and persist the results.
datasets = [semeval_train_data, companies_train_data, relocar_train_data]
# NOTE(review): the second output name says "semeval_org" although it is
# paired with the *companies* training data — confirm the naming is intended.
path_names = [f"semeval_loc_with_paraphrases_temp{temperature}", f"semeval_org_with_paraphrases_temp{temperature}", f"relocar_with_paraphrases_temp{temperature}"]


def paraphrase(sent, temperature):
    """Generate backtranslated English paraphrases of *sent*.

    The sentence is translated EN->DE with nucleus sampling (top-p = 0.5) at
    the given temperature; each sampled German hypothesis is then round-tripped
    DE->EN->DE->EN, so every output has passed through translation four times.

    Args:
        sent: English source sentence.
        temperature: Sampling temperature for the initial EN->DE generation.

    Returns:
        List of paraphrased English sentences, one per sampled hypothesis.
    """
    encoded = en2de.encode(sent)
    outputs = en2de.generate(encoded, sampling=True, topp=0.5, temperature=temperature)
    german_samples = [en2de.decode(x['tokens']) for x in outputs]

    multi_paraphrases = []
    for sentence in german_samples:
        en_1 = de2en.translate(sentence)   # DE -> EN
        de_2 = en2de.translate(en_1)       # EN -> DE
        back = de2en.translate(de_2)       # DE -> EN (final paraphrase)
        multi_paraphrases.append(back)

    return multi_paraphrases


# Hoisted `paraphrase` above the loop: the original redefined it on every
# iteration, and its inner result variable shadowed the function's own name.
for dataset, name in zip(datasets, path_names):
    # Preprocess original sentences into a DataFrame (one row per datapoint).
    data = pd.DataFrame.from_dict(dataset, orient='columns')

    # Extract targets & add as new column.
    data["targets"] = [extract_target(row) for _, row in data.iterrows()]

    # Join each tokenized sentence back into a single plain-text sentence.
    data["joined_sents"] = [join_tokens(sent) for sent in data["sentence"]]

    # Paraphrase the joined original sentences.
    data["paraphrases"] = [paraphrase(sent, temperature) for sent in data["joined_sents"]]

    # Persist alongside the input data (no extension, matching the original).
    data_path = os.path.join(data_dir, name)
    data.to_csv(data_path)
    print(f"Done: paraphrases generated and saved for {name}.")