Commit 3836f437 authored by wu

Update models.py, main_Critic.py

parent 7d1ec45c

main_Critic.py

0 → 100644
+62 −0
# imports
import copy
import time

import torch

from models import ActorOnlySummarisationModel, Critic  # both defined in models.py
import dataset  # assumed: local module exposing train/validation/test splits

# hyperparameters
epochs = 20
batch_size = 20
learning_rate = 0.001

# identity collate_fn keeps each batch as a plain list of datapoints
train_dataloader = torch.utils.data.DataLoader(dataset.train, batch_size=batch_size, shuffle=True, collate_fn=lambda x: x)
val_dataloader = torch.utils.data.DataLoader(dataset.validation, batch_size=batch_size, shuffle=False, collate_fn=lambda x: x)

# load state dict from training
model_actor_only = ActorOnlySummarisationModel()
model_actor_only.load_state_dict(torch.load('model_actor_only_wts.pth'))
model_actor_only.eval()

m = Critic(model_actor_only)


since = time.time()
val_loss_history = []
best_loss = float('inf')

best_model_wts = copy.deepcopy(m.state_dict())

for e in range(epochs):

    print('Epoch {}/{}'.format(e, epochs - 1))
    print('-' * 10)

    # train phase
    epoch_loss = m.training_epoch(train_dataloader)
    print('Train Loss: {:.4f}'.format(epoch_loss))
    
    # validation phase (test() runs in eval mode without gradients)
    val_epoch_loss = m.test(val_dataloader)
    val_loss_history.append(val_epoch_loss)    
    print('Validation Loss: {:.4f}'.format(val_epoch_loss))

    # epoch completed, deep copy the best model so far
    if val_epoch_loss < best_loss:
        best_loss = val_epoch_loss
        best_model_wts = copy.deepcopy(m.state_dict())

# after training completed
time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
print('Best val loss: {:.4f}'.format(best_loss))
# TODO: write val_loss_history to a file
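
# a minimal sketch of that TODO (the file name is an assumption):
with open('val_loss_history.txt', 'w') as f:
    f.writelines('{}\n'.format(l) for l in val_loss_history)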

# load best model weights
m.load_state_dict(best_model_wts)
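
# optionally persist the best critic weights as well (file name is an assumption):
torch.save(best_model_wts, 'critic_best_wts.pth')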


# testing
since = time.time()
test_loss = m.test(dataset.test)
print('Test Loss: {:.4f}'.format(test_loss))

# after testing completed
time_elapsed = time.time() - since
print('Testing complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
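
Note: both this script and the diff below rely on Critic owning its own optimizer and loss (the new methods reference self.optimizer and self.loss_fn, which the shown hunks never create). A minimal sketch of that constructor wiring, with the learning rate taken from this script and everything else an assumption rather than part of this commit:

    class Critic(nn.Module):
        def __init__(self, actor, learning_rate=0.001):
            super().__init__()
            self.actor = actor  # pretrained ActorOnlySummarisationModel
            # layer_1 / layer_2 / further layers omitted here; see models.py
            self.loss_fn = nn.MSELoss()
            # must be created after all layers so self.parameters() sees them
            self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)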
models.py

+112 −30
@@ -221,18 +221,64 @@ class Critic(nn.Module):
             nn.functional.relu(self.layer_2(
             utils.gaussian(self.layer_1(double_document)))))))
     
-    def _train(self, dataset, epochs=200, batch_size=20, learning_rate=0.001, shuffle=True, pos_samples=0.5): # move to main
+    # def _train(self, dataset, epochs=200, batch_size=20, learning_rate=0.001, shuffle=True, pos_samples=0.5): # move to main
 
-        optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
-        loss_fn = nn.MSELoss()
+    #     optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
+    #     loss_fn = nn.MSELoss()
 
-        for _ in range(epochs):
+    #     for _ in range(epochs):
 
-            training_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
+    #         training_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
 
-            for batch in training_dataloader:
+    #         for batch in training_dataloader:
 
-                optimizer.zero_grad()
+    #             optimizer.zero_grad()
 
+    #                 for datapoint in batch:
+
+    #                     r = np.random.random()
+    #                     if r > pos_samples:
+
+    #                         k = np.random.choice(len(datapoint.p_searchspace))
+    #                         sample = datapoint.sent_vecs.masked_select(datapoint.p_searchspace[k].bool())  # not padded sent embeddings
+
+    #                         score = self.__call__(sample, datapoint.gold_sent_vecs)
+    #                         loss = loss_fn(score, datapoint.top_rouge[k])
+
+    #                     else:
+    #                         if len(datapoint.sent_vecs) >= 3:
+    #                             narray = np.random.choice(len(datapoint.sent_vecs), 3, replace=False)
+    #                             narray.sort()
+    #                             sample = datapoint.sent_vecs[narray]
+    #                         else:
+    #                             continue  # handle len(sent_vecs) < 3
+
+    #                         score = self.__call__(sample, datapoint.gold_sent_vecs)
+    #                         loss = loss_fn(score, utils.rouge(raw_document[narray]), raw_summary))
+    #                         # compute the rouge score for the negative sample => precomputing and storing it externally would be better?
+
+    #                     loss.backward()
+
+    #             optimizer.step()
+
+    #         # eval
+    #         # test with rouge
+
+    def training_epoch(self, dataloader, learning_rate=None):
+
+        # optionally override the optimizer's learning rate for this epoch
+        if learning_rate is not None:
+            for g in self.optimizer.param_groups:
+                g['lr'] = learning_rate
+
+        self.train()
+        pos_samples = 0.5
+
+        epoch_loss = 0.0
+
+        for batch in dataloader:
+
+            self.optimizer.zero_grad()
+
+            for datapoint in batch:

@@ -243,7 +289,7 @@ class Critic(nn.Module):
                     sample = datapoint.sent_vecs.masked_select(datapoint.p_searchspace[k].bool())  # not padded sent embeddings
 
                     score = self.__call__(sample, datapoint.gold_sent_vecs)
-                            loss = loss_fn(score, datapoint.top_rouge[k])
+                    loss = self.loss_fn(score, datapoint.top_rouge[k])
 
                 else:
                     if len(datapoint.sent_vecs) >= 3:
@@ -254,15 +300,51 @@ class Critic(nn.Module):
                         continue  # handle len(sent_vecs) < 3
 
                     score = self.__call__(sample, datapoint.gold_sent_vecs)
-                            loss = loss_fn(score, utils.rouge(raw_document[narray]), raw_summary)) 
+                    loss = self.loss_fn(score, rouge(datapoint.raw_document[narray], datapoint.raw_summary))
                     # compute the rouge score for the negative sample => precomputing and storing it externally would be better?
 
                 epoch_loss += loss.item()
 
                 loss.backward()
 
-                optimizer.step()
+            self.optimizer.step()
 
+        return epoch_loss / len(dataloader.dataset)
+
+    def test(self, dataloader):
+        self.eval()
+        pos_samples = 0.5
+
+        epoch_loss = 0.0
+        with torch.no_grad():
+
+            for batch in dataloader:
+
-            # eval
+                for datapoint in batch:
+
-            # test with rouge
+                    r = np.random.random()
+                    if r > pos_samples:
+
+                        k = np.random.choice(len(datapoint.p_searchspace))
+                        sample = datapoint.sent_vecs.masked_select(datapoint.p_searchspace[k].bool())  # not padded sent embeddings
+
+                        score = self.__call__(sample, datapoint.gold_sent_vecs)
+                        loss = self.loss_fn(score, datapoint.top_rouge[k])
+
+                    else:
+                        if len(datapoint.sent_vecs) >= 3:
+                            narray = np.random.choice(len(datapoint.sent_vecs), 3, replace=False)
+                            narray.sort()
+                            sample = datapoint.sent_vecs[narray]
+                        else:
+                            continue  # handle len(sent_vecs) < 3
+
+                        score = self.__call__(sample, datapoint.gold_sent_vecs)
+                        loss = self.loss_fn(score, rouge(datapoint.raw_document[narray], datapoint.raw_summary))
+                        # compute the rouge score for the negative sample => precomputing and storing it externally would be better?
+
+                    epoch_loss += loss.item()
+
+        return epoch_loss / len(dataloader.dataset)
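
The comment in the negative-sample branch suggests precomputing the ROUGE scores instead of evaluating them inside the training loop. A minimal sketch of that idea, assuming the rouge(candidate, reference) function used above and an iterable dataset of datapoints (the helper name and the neg_rouge attribute are assumptions, not part of this commit):

    import itertools

    def precompute_negative_rouge(dataset, n_sentences=3, rouge_fn=rouge):
        """Cache ROUGE scores for every n-sentence candidate of each datapoint."""
        for datapoint in dataset:
            datapoint.neg_rouge = {}
            for idx in itertools.combinations(range(len(datapoint.sent_vecs)), n_sentences):
                candidate = [datapoint.raw_document[i] for i in idx]
                datapoint.neg_rouge[idx] = rouge_fn(candidate, datapoint.raw_summary)

During training, a sampled index triple can then be looked up via datapoint.neg_rouge[tuple(narray)] instead of being recomputed. Since the number of triples grows cubically with document length, caching only a fixed random subset per datapoint may be the more practical variant.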