Commit 03e916e3 authored by Aileen Reichelt's avatar Aileen Reichelt
Browse files

Merge remote-tracking branch 'origin/master'

parents 093b9126 a5160f34
Loading
Loading
Loading
Loading
+207 −26
Original line number Diff line number Diff line
from sentence_transformers import SentenceTransformer, util

from scipy.stats import spearmanr, pearsonr
from sklearn import metrics
from sklearn.decomposition import PCA
from sklearn.cluster import AgglomerativeClustering
import umap
import matplotlib.pyplot as plt
from random import randrange
import numpy as np

import time
from datetime import datetime
import os
import copy
import spacy


class SBERT_Model:
    """Wraps a SentenceTransformer model together with an evaluation dataset.

    On construction, both sentence columns are encoded, the pairwise
    cosine-score matrix is computed, and per-pair predictions (the matrix
    diagonal) are extracted.
    """

    def __init__(self, name, filepath, dataset, probing=False):
        # NOTE(review): `probing` is accepted but not used in the visible
        # code — kept for interface compatibility.
        self.name = name  # used later for plot titles and output filenames
        self.model = SentenceTransformer(filepath)
        self.sentences1 = dataset['sentence1'].tolist()
        self.sentences2 = dataset['sentence2'].tolist()
        self.labels = dataset['label'].tolist()
        self.embeddings1 = self.get_embeddings(self.sentences1)
        self.embeddings2 = self.get_embeddings(self.sentences2)
        # Single computation; a merge had left a duplicate call to the old
        # zero-argument get_cosine_scores() signature here.
        self.cosine_scores = self.get_cosine_scores(self.embeddings1, self.embeddings2)
        self.preds = self.get_preds()

    def sentence_clusters(self):
        """Print clusters of sentences with high mutual similarity.

        All sentences from both columns are embedded, normalised to unit
        length, and grouped with agglomerative clustering; each cluster is
        printed with a 1-based index.
        """
        corpus = self.sentences1 + self.sentences2
        corpus_embeddings = self.get_embeddings(corpus)

        # Normalise the embeddings to unit length so Euclidean distances
        # relate monotonically to cosine similarity.
        corpus_embeddings = corpus_embeddings / np.linalg.norm(corpus_embeddings, axis=1, keepdims=True)

        # Agglomerative clustering; the number of clusters is determined by
        # the distance threshold rather than fixed up front.
        clustering_model = AgglomerativeClustering(n_clusters=None, distance_threshold=1.5)
        clustering_model.fit(corpus_embeddings)
        cluster_assignment = clustering_model.labels_

        clustered_sentences = {}
        for sentence_id, cluster_id in enumerate(cluster_assignment):
            clustered_sentences.setdefault(cluster_id, []).append(corpus[sentence_id])

        for i, cluster in clustered_sentences.items():
            print("Cluster ", i + 1)
            print(cluster)
            print("")

    def visualize_embeddings(self, sentences, highlights=None, figsize=(80, 200)):
        """Project sentence embeddings to 2D with UMAP and save a scatter plot.

        Parameters
        ----------
        sentences : list of str
            Sentences visualized as blue points.
        highlights : list of str or str, optional
            Sentences visualized as large red points.
        figsize : tuple, optional
            Size of the final image. The default (80, 200) renders at
            8000x20000 pixels and is appropriate for the stsb test dataset
            (2757 data points).

        A legend file mapping point indices back to sentences is written next
        to the plot in the "outputs" folder.
        An example of how the method can be used is provided in
        sentence_similarity.py. Due to long loading times of the model, it is
        recommended to use this method in the interactive mode:
        python -i sentence_similarity.py
        """
        t1 = time.time()
        # Avoid a mutable default argument and accept a single string.
        if highlights is None:
            highlights = []
        elif isinstance(highlights, str):
            highlights = [highlights]

        embs = self.get_embeddings(sentences)
        umap_model = umap.UMAP(n_neighbors=15, n_components=2, min_dist=0.5,
                               spread=2, metric='cosine', random_state=42).fit(embs)
        umap_data_transformed = umap_model.transform(embs)

        # Honour the figsize parameter (it was previously ignored).
        plt.figure(figsize=figsize)
        folder = "outputs"
        os.makedirs(folder, exist_ok=True)  # output folder may not exist yet
        plot_name = f"plot-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.png"
        plot_path = os.path.join(folder, plot_name)
        legend_name = f'legend-{plot_name[:-4]}.txt'
        legend_path = os.path.join(folder, legend_name)
        with open(legend_path, 'w') as legend_file:
            for i, (x, y) in enumerate(umap_data_transformed):
                name = sentences[i]
                legend_file.write(f"{i}: {name} ({x:.2f},{y:.2f})\n")
                # Points are labelled only with their index; the legend file
                # maps indices to the full sentences.
                label = i
                plt.plot(x, y, 'bo')
                plt.text(x, y, label)
            if highlights:
                highlights_embs = self.get_embeddings(highlights)
                highlights_transformed = umap_model.transform(highlights_embs)
                for i, (x, y) in enumerate(highlights_transformed):
                    name = highlights[i]
                    label = f"{i}: {name}\n({x:.2f},{y:.2f})"
                    legend_file.write(f"Highlights: {i}: {name} ({x:.2f},{y:.2f})\n")
                    plt.plot(x, y, 'ro', markersize=30)
                    plt.text(x, y, label)
        plt.savefig(plot_path)
        t2 = time.time()
        print(f"Plot saved to {plot_path}. Time it took to generate: {t2-t1:.2f} seconds")
        print(f'{legend_path} stores the information about the data points on the plot.')

    def visualize_preds(self):
        """Plot and save a histogram of the predicted cosine similarities."""
        mean_pred = np.average(self.preds)
        std_pred = np.std(self.preds)
        print(f"Average cosine similarity: {mean_pred}; Standard deviation: {std_pred}")

        fig, ax = plt.subplots(1, 1, figsize=(8, 5), tight_layout=True)
        ax.hist(self.preds, bins=20)
        plt.xlim([0, 1])
        # Title shows the distribution's mean and standard deviation in TeX.
        title = (f"{self.name} - histogram " + r"$\mu$"
                 + f"={round(mean_pred, 3)}, " + r"$\sigma$"
                 + f"={round(std_pred, 3)}")
        ax.set_title(title)
        plt.ylabel('Frequency')
        plt.xlabel('Cosine similarity between sentence pair')
        plt.savefig(f'{self.name}_histogram.png', bbox_inches='tight')

    def get_word_importance(self, sentence_pair, mask="[MASK]"):
        """Estimate per-token importance by masking one token at a time.

        Each token of either sentence is replaced with `mask`; the absolute
        deviation of the resulting pair similarity from the unmasked
        similarity is taken as that token's importance.

        Parameters
        ----------
        sentence_pair : sequence of two str
        mask : str
            Token substituted for the masked word.

        Returns
        -------
        list of str
            HTML <div> fragments (one per token, wrapped in per-sentence
            flex containers) with the deviation encoded as the
            background-colour alpha.
        """
        nlp = spacy.load("en_core_web_sm")
        sentence1 = nlp(sentence_pair[0])
        sentence2 = nlp(sentence_pair[1])
        unmasked_embs = self.get_embeddings([sentence_pair[0], sentence_pair[1]])
        unmasked_similarity = self.get_cosine_scores(unmasked_embs[0], unmasked_embs[1])

        def masked_variants(tokens):
            # Yield (masked_word, masked_sentence) with one token masked per variant.
            for i in range(len(tokens)):
                words = [mask if y == i else tokens[y].text for y in range(len(tokens))]
                yield tokens[i].text, " ".join(words)

        masked_pairs = []
        for masked_word, masked_sentence in masked_variants(sentence1):
            masked_pairs.append({"masked word": masked_word,
                                 "sentence1": masked_sentence,
                                 "sentence2": sentence_pair[1]})
        for masked_word, masked_sentence in masked_variants(sentence2):
            masked_pairs.append({"masked word": masked_word,
                                 "sentence1": sentence_pair[0],
                                 "sentence2": masked_sentence})

        scores = []
        for masked_pair in masked_pairs:
            embs = self.get_embeddings([masked_pair["sentence1"], masked_pair["sentence2"]])
            deviation = float(abs(unmasked_similarity - self.get_cosine_scores(embs[0], embs[1])))
            scores.append((masked_pair, deviation))
        # TODO: Find out more about the words that cause high deviation from
        # the unmasked similarity (e.g. sort scores and inspect the top ones).

        html_output = ['<div class="sentence">\n']
        for i, (masked_pair, deviation) in enumerate(scores):
            html_output.append(
                f'<div style="background-color: rgba(0, 160, 252, {deviation:.2f});">'
                f'{masked_pair["masked word"]}<br>{deviation:.2f}</div>\n')
            if i == len(sentence1) - 1:
                # Close the first sentence's container and open the second's.
                html_output.append('</div>\n<div class="sentence">\n')
        html_output.append('</div>\n')
        return html_output

    def get_html_word_importance(self, sentence_pairs, mask="[MASK]"):
        top = """<!doctype html>
<html lang="en">
<head>
  <meta charset="utf-8">
  <title>SBERT Word Importance</title>
  <meta name="description" content="Visualizing SBERT sentence pairs importance of words">
<style>
.sentence{
    display:flex;
}
.sentence > div{
    border: 1px solid gray;
    
    min-width: 100px;
    min-height: 20px;
    padding: 20px;
    align-items: center;
    text-align: center;
}

</style>
</head>

<body>
    <h1>SBERT Word Importance Visualization</h1>
"""

        bottom = "</body>\n</html>"

        folder = "outputs"
        html_name = f"heatmap-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.html"
        html_path = os.path.join(folder,html_name)

        with open(html_path, 'w') as reader:
            reader.write(top)
            reader.write(f'<h2>{self.name}</h2>\n')
            reader.write(f'<h3>Token used for masking:{mask}</h3>\n')
            for i, sentence_pair in enumerate(sentence_pairs):
                print(i)
                reader.write(f'<div id="{i}" class="sentence_pair">{i+1}.\n')
                sentence_divs = self.get_word_importance(sentence_pair, mask=mask)
                for sentence_div in sentence_divs:
                    reader.write(sentence_div)
                reader.write('</div>\n')
            reader.write(bottom)
            
    def get_embeddings(self, sentences):
        return self.model.encode(sentences, convert_to_tensor=True)
@@ -55,11 +236,11 @@ class SBERT_Model:
        """
        preds = []
        for i in range(len(self.cosine_scores[0])):
            preds.append(self.cosine_scores[i][i])
            preds.append(float(self.cosine_scores[i][i]))
        return preds

    def get_cosine_scores(self, emb1, emb2):
        """Pairwise cosine-similarity matrix between two embedding batches.

        A merge had left both the old zero-argument and the new two-argument
        definitions here; only the parameterised one (which every call site
        in this class uses) is kept.
        """
        return util.pytorch_cos_sim(emb1, emb2)

    def get_pearson(self):
        return pearsonr(self.labels, self.preds)[0]
+5 −6
Original line number Diff line number Diff line
import pandas as pd

class STSB_Dataset():

class STSB_Dataset:

    def __init__(self, path):

@@ -11,7 +12,6 @@ class STSB_Dataset():
        self.data = []
        self.labels = []


        for line in dataset:
            datapoint = line.split('\t')
            score = float(datapoint[4])
@@ -22,7 +22,6 @@ class STSB_Dataset():
            self.data.append([st1, st2])
            self.labels.append(score)

            
        # Store as a list of dicts, conforming to self.spec()
        self._examples = [{
          'sentence1': dp[0],
+39305 −0

File added.

Preview size limit exceeded, changes collapsed.

+39305 −0

File added.

Preview size limit exceeded, changes collapsed.

+2760 −0

File added.

Preview size limit exceeded, changes collapsed.

Loading