Commit 586ba564 authored by friebolin's avatar friebolin
Browse files

Update inference

parent ef32f86f
Loading
Loading
Loading
Loading
+14 −12
Original line number Diff line number Diff line
import argparse
import torch
import preprocess
import train
import models
from transformers import BertTokenizer, RobertaTokenizer, BertModel, RobertaModel, RobertaPreTrainedModel, RobertaConfig,  BertConfig, BertPreTrainedModel, PreTrainedModel, AutoConfig, AutoModel, AutoTokenizer
import re 
import models
import train
from torch.utils.data import DataLoader, RandomSampler


# Get user input
print("Enter a sentence: ")
sentence = input()
sentence = sentence.split()
@@ -16,22 +18,21 @@ target_pos = input()
print("Enter the label: 0 for literal, 1 for non-literal")
label = int(input())


data_sample = {"sentence": sentence, "pos": target_pos, "label": label}
print(data_sample)


filepath = "./saved_models/bert_baseline.pt"
model=models.BertForWordClassification.from_pretrained("bert-base-uncased")
#tokenizer=AutoTokenizer.from_pretrained(args.architecture)
# Load model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model.load_state_dict(torch.load(filepath))
model.eval()		#loads saved model
model=models.BertForWordClassification.from_pretrained("bert-base-uncased")
model_path = "saved_models/bert_baseline.pth"
model = torch.load(model_path, map_location=device)

train_dataset = [{"sentence": ["Yet", "how", "many", "times", "has", "America", "sided", "with", "Israeli", "aggression", "against", "the", "people", "of", "Palestine?"], "pos": [5, 6], "label": 1}]
train_sampler = RandomSampler(train_dataset)
train_dataloader=DataLoader(train_dataset, sampler=train_sampler, batch_size=1)
tokenizer=AutoTokenizer.from_pretrained("bert-base-uncased")

train_sampler = RandomSampler(data_sample)
train_dataloader=DataLoader(data_sample, sampler=train_sampler, batch_size=1)

for batch in train_dataloader:
	inputs = {'input_ids': batch[0],
@@ -44,4 +45,5 @@ for batch in train_dataloader:
	start_positions=batch[3]
	end_positions=batch[4]
	outputs=model(**inputs)

	print("Outputs: ", outputs)