Commit deedca33 authored by Aileen Reichelt's avatar Aileen Reichelt
Browse files

Implemented dataset separation into short, medium, length

Also implemented output of some length statistics.
parent 5bd3787a
Loading
Loading
Loading
Loading
+41 −12
Original line number Diff line number Diff line
from SBERT_Model import SBERT_Model
from datasets import load_dataset
import pandas
from STSB_Dataset import STSB_Dataset
import math

dataset = load_dataset('csv', data_files='data/stsbenchmark/sts-test.csv', split='test')
dataset.set_format(type='pandas')
print(dataset.head())

short_sentences = pandas.DataFrame
average_length_sentences = pandas.DataFrame
long_sentences = pandas.DataFrame
def add_length_column(dataset):
    sentence_1_lengths = dataset['sentence1'].apply(lambda sentence: len(sentence.split(" ")))
    sentence_2_lengths = dataset['sentence2'].apply(lambda sentence: len(sentence.split(" ")))
    dataset['total length'] = sentence_1_lengths
    dataset['total length'] += sentence_2_lengths
    return dataset

datasets = [short_sentences, average_length_sentences, long_sentences]

def get_length_distribution(dataset) -> dict:
    sentences = dataset['sentence1'].to_numpy() + dataset['sentence2'].to_numpy()
    lengths = dict()
    for sentence in sentences:
        sentence_length = len(sentence.split(" "))
        if sentence_length not in lengths:
            lengths[sentence_length] = 0
        lengths[sentence_length] += 1
    return lengths


if __name__ == "__main__":
    complete_stsb_dataset = STSB_Dataset('data/stsbenchmark/sts-test.csv').as_dataframe
    dataset_with_lengths = add_length_column(complete_stsb_dataset)

    average_length = math.ceil(dataset_with_lengths['total length'].mean())
    standard_deviation = math.ceil(dataset_with_lengths['total length'].std())
    print(f"Average length: {average_length}, standard deviation: {standard_deviation}")

    for length, occurrence in sorted(get_length_distribution(dataset_with_lengths).items()):
        print(f"Length: {length} | Occurrence: {occurrence}")

    short_sentences = dataset_with_lengths.loc[dataset_with_lengths['total length'] <= (average_length - standard_deviation)]
    medium_length_sentences = dataset_with_lengths.loc[
        (dataset_with_lengths['total length'] > (average_length - standard_deviation)) &
        (dataset_with_lengths['total length'] < (average_length + standard_deviation))]
    long_sentences = dataset_with_lengths.loc[dataset_with_lengths['total length'] >= (average_length + standard_deviation)]

    datasets = [short_sentences, medium_length_sentences, long_sentences]

    for d in datasets:
        ft_model = SBERT_Model("Fine-tuned Model", 'models/stsb-bert-large/', d)