# Commit 2157877f (author: schaper, parent 00d145c5)
# Add SBERT_Model class and evaluation of zero-shot model
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers import SentenceTransformer, util
from scipy.stats import spearmanr, pearsonr
from sklearn import metrics

dataset = load_dataset('paws', 'labeled_final', split='test')
class SBERT_Model:
    def __init__(self, name, filepath, dataset):
        self.name = name
        self.filepath = filepath
        self.dataset = dataset
        self.model = SentenceTransformer(filepath)
        self.sentences1 = dataset.head(100)['sentence1'].tolist()
        self.sentences2 = dataset.head(100)['sentence2'].tolist()
        self.labels = dataset.head(100)['label'].tolist()
        self.embeddings1 = self.get_embeddings(self.sentences1)
        self.embeddings2 = self.get_embeddings(self.sentences2)
        self.cosine_scores = self.get_cosine_scores()
        self.preds = self.get_preds()

dataset.set_format(type='pandas')
    def get_embeddings(self, sentences):
        return self.model.encode(sentences, convert_to_tensor=True)

    def get_preds(self):
        """This method extracts the scores for similarity between sentence pairs
        Cosine_scores have scores for similarity between all of the sentences,
        but we only need similarity between each sentence pair.
        """
        preds = []
        for i in range(len(self.cosine_scores[0])):
            preds.append(self.cosine_scores[i][i])
        return preds

    def get_cosine_scores(self):
        return util.pytorch_cos_sim(self.embeddings1, self.embeddings2)

    def get_pearson(self):
        return pearsonr(self.labels, self.preds)[0]

foo = dataset[:]
    def get_spearman(self):
        return spearmanr(self.labels, self.preds)[0]

model = SentenceTransformer('models/stsb-bert-large/')
    def get_MSE(self):
        return metrics.mean_squared_error(self.labels, self.preds)

    def print_statistics(self):
        print(f"{self.name}: MSE:{self.get_MSE()}; Pearson:{self.get_pearson()}; Spearman:{self.get_spearman()}")


# Load the PAWS 'labeled_final' test split for evaluation.
dataset = load_dataset('paws', 'labeled_final', split='test')
dataset.set_format(type='pandas')
# Slicing the pandas-formatted dataset materializes it as a DataFrame.
dataset = dataset[:]

# predict dataset instances with model and evaluate with confusion matrix
# calculate spearman rank correlation between cosine-similarity of the sentence embeddings and the gold labels
# (or if too hard pearson correlation)
# Zero-shot baseline vs. fine-tuned model; the model folders suggest an
# NLI-trained and an STSb-trained checkpoint -- NOTE(review): confirm contents.
zero_model = SBERT_Model("Zero Model", 'models/nli-bert-large/', dataset)
ft_model = SBERT_Model("Fine-tuned Model", 'models/stsb-bert-large/', dataset)

# use code from evaluation_stsbenchmark.py
zero_model.print_statistics()
ft_model.print_statistics()