Commit c2fae59f authored by kreuzer's avatar kreuzer
Browse files

Aktualisieren models.py

parent 50c35621
Loading
Loading
Loading
Loading
+48 −3
Original line number Diff line number Diff line
@@ -206,11 +206,56 @@ class ActorCriticSummarisationModel(SummarisationModel):
        


    def forward()
    def training_epoch(self, dataloader, learning_rate=None): # def scheduler? or global variable?

        if learning_rate != None:
            for g in self.optimizer.param_groups:
                g['lr'] = learning_rate 
        
        self.train()

        epoch_loss = 0.0
        epoch_rouge = 0.0
            
        for batch in dataloader:

            self.optimizer.zero_grad()

    def train()
            for datapoint in batch:
                # check if dp to gpu is OK 
                datapoint = datapoint.to(self.device) # device definiert in main_ActorOnly.py

    def test()
                try:    # Prevent breakdown for inapt datapoints
                    # documents with empty content!
                    if len(datapoint.raw_document) == 0 or len(datapoint.raw_summary) == 0: 
                        print("Warning! This datapoint has an empty document or an empty summary")
                        continue

                    _, probs = self.__call__(datapoint.document)

                    o = datapoint.p_searchspace @ torch.log(probs) + datapoint.n_searchspace @ torch.log(1 - probs)

                    idx_sample = torch.argmax(o)

                    V = self.critic(datapoint.sent_vecs.masked_select(datapoint.p_searchspace[idx_sample].bool()), datapoint.gold_sent_vecs)

                    loss_actor = - V.detach() * o[idx_sample] # backward nur für Actor

                    loss_critic = self.critic.loss_fn(V, datapoint.top_rouge[idx_sample])

                    loss_actor.backward()
                    loss_critic.backward()

                    epoch_loss += loss.item()
                    epoch_rouge += datapoint.top_rouge[idx_sample]
                
                except Exception as e:
                    traceback.print_exception(*sys.exc_info())
                    continue
            
            self.optimizer.step()
        
        return epoch_loss / len(dataloader.dataset), epoch_rouge / len(dataloader.dataset)