def training_epoch(self, dataloader, learning_rate=None):
    """Run one actor-critic training epoch over *dataloader*.

    For each datapoint the actor produces sentence-selection probabilities,
    the highest-scoring extraction under the (positive/negative) search
    spaces is picked, and the critic scores the selected sentences against
    the gold summary vectors. Actor and critic losses are backpropagated
    separately; the optimizer steps once per batch.

    Args:
        dataloader: iterable of batches of datapoints. Each datapoint is
            assumed to expose ``raw_document``, ``raw_summary``,
            ``document``, ``p_searchspace``, ``n_searchspace``,
            ``sent_vecs``, ``gold_sent_vecs`` and ``top_rouge`` and to
            support ``.to(device)`` — TODO confirm against the dataset class.
        learning_rate: if given, overrides the learning rate of every
            optimizer parameter group for this epoch.

    Returns:
        Tuple ``(mean_loss, mean_rouge)`` averaged over
        ``len(dataloader.dataset)``.
    """
    # TODO: replace this manual override with a proper LR scheduler?
    if learning_rate is not None:
        for group in self.optimizer.param_groups:
            group['lr'] = learning_rate

    self.train()
    epoch_loss = 0.0
    epoch_rouge = 0.0

    for batch in dataloader:
        self.optimizer.zero_grad()
        for datapoint in batch:
            # Move the datapoint to the model's device (device is defined
            # in main_ActorOnly.py) — check whether per-datapoint transfer
            # to the GPU is acceptable performance-wise.
            datapoint = datapoint.to(self.device)
            try:
                # Prevent breakdown for inapt datapoints:
                # documents with empty content!
                if len(datapoint.raw_document) == 0 or len(datapoint.raw_summary) == 0:
                    print("Warning! This datapoint has an empty document or an empty summary")
                    continue

                # Actor: per-sentence selection probabilities.
                _, probs = self.__call__(datapoint.document)

                # Log-likelihood of each candidate extraction in the search
                # space: selected sentences contribute log(p), unselected
                # ones log(1 - p).
                o = (datapoint.p_searchspace @ torch.log(probs)
                     + datapoint.n_searchspace @ torch.log(1 - probs))
                idx_sample = torch.argmax(o)

                # Critic value of the chosen extraction vs. the gold summary.
                V = self.critic(
                    datapoint.sent_vecs.masked_select(datapoint.p_searchspace[idx_sample].bool()),
                    datapoint.gold_sent_vecs,
                )

                # Detach V so the actor loss backpropagates only through
                # the actor (backward only for the actor).
                loss_actor = - V.detach() * o[idx_sample]
                loss_critic = self.critic.loss_fn(V, datapoint.top_rouge[idx_sample])
                loss_actor.backward()
                loss_critic.backward()

                # Fix: the original accumulated an undefined name `loss`
                # (NameError, silently swallowed by the except below);
                # accumulate both computed losses instead.
                epoch_loss += loss_actor.item() + loss_critic.item()
                epoch_rouge += datapoint.top_rouge[idx_sample]
            except Exception:
                # Best-effort training loop: log the failure and keep going
                # with the next datapoint.
                traceback.print_exception(*sys.exc_info())
                continue
        self.optimizer.step()

    return epoch_loss / len(dataloader.dataset), epoch_rouge / len(dataloader.dataset)