Commit 7d1ec45c authored by kreuzer's avatar kreuzer
Browse files

Update main_CrossEntropy.py

parent 6c4b0f7a
Loading
Loading
Loading
Loading
+60 −0
Original line number Diff line number Diff line
import copy
import time

import numpy as np
import torch
from torch import nn

from models import *
from structures import *

# Load the preprocessed dataset (expects .train / .validation / .test splits).
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
dataset = torch.load("dataset.data")

# --- hyperparameters ---
epochs = 2
batch_size = 20
learning_rate = 0.001

# Batches are plain lists of samples; the identity collate_fn leaves them
# unstacked so the model can handle variable-length inputs itself.
train_dataloader = torch.utils.data.DataLoader(
    dataset.train,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=lambda batch: batch,
)

m = SummarisationModelWithCrossEntropyLoss()

since = time.time()        # wall-clock start of the training phase
val_rouge_history = []     # one validation ROUGE score per epoch

best_rouge = 0.0
best_model_wts = copy.deepcopy(m.state_dict())

for epoch in range(epochs):

    print()
    print('Epoch {}/{}'.format(epoch + 1, epochs))
    print('-' * 10)

    # --- train phase ---
    train_loss = m.training_epoch(train_dataloader)
    print('Train Loss: {:.4f}'.format(train_loss))

    # --- validation phase ---
    val_rouge = m.validation(dataset.validation)
    val_rouge_history.append(val_rouge)
    print('Validation Rouge Score: {:.4f}'.format(val_rouge))

    # Snapshot the weights (deep copy) whenever validation ROUGE improves,
    # so the best-scoring epoch can be restored after training.
    if val_rouge > best_rouge:
        best_rouge = val_rouge
        best_model_wts = copy.deepcopy(m.state_dict())

# --- after training completed ---
time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
# BUG FIX: format spec was '{:4f}' (width 4, default 6-decimal precision);
# use '{:.4f}' to match the 4-decimal formatting used everywhere else here.
print('Best val rouge: {:.4f}'.format(best_rouge))
# TODO: write val_rouge_history to a file for later plotting/analysis.

# restore the weights of the best-scoring epoch before testing
m.load_state_dict(best_model_wts)

# --- testing ---
since = time.time()  # reuse the timer for the test phase

# m.test returns the three ROUGE variants; report each plus their mean.
rouge_1, rouge_2, rouge_l = m.test(dataset.test)
mean_rouge = (rouge_1 + rouge_2 + rouge_l) / 3.0
print('Test rouge_1: {:.4f} rouge_2: {:.4f} rouge_l: {:.4f} mean: {:.4f}'.format(
    rouge_1, rouge_2, rouge_l, mean_rouge))

# --- after testing completed ---
time_elapsed = time.time() - since
print('Testing complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))