Commit f29a0a6f authored by schaper's avatar schaper
Browse files

Add visualize_preds method for PAWS evaluation

parent d67dfb73
Loading
Loading
Loading
Loading
+16 −3
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@ import os
import copy
import spacy


class SBERT_Model:
    def __init__(self, name, filepath, dataset, probing=False):
        """
@@ -114,6 +115,18 @@ class SBERT_Model:
        print(f"Plot saved to {plot_path}. Time it took to generate: {t2-t1:.2f} seconds")
        print(f'{legend_path} stores the information about the data points on the plot.')

    def visualize_preds(self):
        print(f"Average cosine similarity: {np.average(self.preds)}; Standard deviation: {np.std(self.preds)}")
        fig, ax = plt.subplots(1, 1, figsize=(8, 5), tight_layout=True)
        ax.hist(self.preds, bins=20)
        plt.xlim([0, 1])
        ax.set_title(
            f"{self.name} - histogram " + r"$\mu$" + f"={round(np.average(self.preds), 3)}, " + r"$\sigma$" f"={round(np.std(self.preds), 3)}")

        plt.ylabel('Frequency')
        plt.xlabel('Cosine similarity between sentence pair')
        plt.savefig(f'{self.name}_histogram.png', bbox_inches='tight')

    def get_word_importance(self,sentence_pair, mask="[MASK]"):
        nlp = spacy.load("en_core_web_sm")
        sentence1 = nlp(sentence_pair[0])
@@ -223,7 +236,7 @@ class SBERT_Model:
        """
        preds = []
        for i in range(len(self.cosine_scores[0])):
            preds.append(self.cosine_scores[i][i])
            preds.append(float(self.cosine_scores[i][i]))
        return preds

    def get_cosine_scores(self, emb1, emb2):