Commit 6ed4931e authored by schaper's avatar schaper
Browse files

Update SBERT_Model.py with data frame as parameter

parent a1e1fd8c
Loading
Loading
Loading
Loading
+7 −4
Original line number Diff line number Diff line
@@ -8,14 +8,17 @@ from random import randrange

class SBERT_Model:
    def __init__(self, name, filepath, dataset, probing=False):
        """
        param: dataset must be pandas dataframe
        """
        self.name = name
        self.filepath = filepath
        self.dataset = dataset
        self.model = SentenceTransformer(filepath)
        self.sentences1 = [s['sentence1'] for s in self.dataset._examples]
        self.sentences2 = [s['sentence2'] for s in self.dataset._examples]
        self.sentences1 = dataset['sentence1'].tolist()
        self.sentences2 = dataset['sentence2'].tolist()
    
        self.labels = [s['label'] for s in self.dataset._examples]
        self.labels = dataset['label'].tolist()
        self.embeddings1 = self.get_embeddings(self.sentences1)
        self.embeddings2 = self.get_embeddings(self.sentences2)
        self.cosine_scores = self.get_cosine_scores()