Commit 4e988dcf authored by podrazka's avatar podrazka
Browse files

Specify data type to fix the problem with fine-tuning

parent fc9579b8
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -101,9 +101,9 @@ dev_samples = []
train_dataset = load_dataset('paws', 'labeled_final', split='train')
train_dataset.set_format(type='pandas')
train_dataset = train_dataset[:]

for index, row in train_dataset.iterrows():
    train_samples.append(InputExample(texts=[row['sentence1'], row['sentence2']], label=row['label']))
    x = torch.FloatTensor([row['label']])
    train_samples.append(InputExample(texts=[row['sentence1'], row['sentence2']], label=x))

dev_dataset = load_dataset('paws', 'labeled_final', split='validation')
dev_dataset.set_format(type='pandas')