Loading joeynmt/batch.py +11 −10 Original line number Diff line number Diff line Loading @@ -21,21 +21,21 @@ class Batch: :param pad_index: :param use_cuda: """ self.src, self.src_lengths = torch_batch.src self.src, self.src_length = torch_batch.src self.src_mask = (self.src != pad_index).unsqueeze(1) self.nseqs = self.src.size(0) self.trg_input = None self.trg = None self.trg_mask = None self.trg_lengths = None self.trg_length = None self.ntokens = None self.use_cuda = use_cuda if hasattr(torch_batch, "trg"): trg, trg_lengths = torch_batch.trg trg, trg_length = torch_batch.trg # trg_input is used for teacher forcing, last one is cut off self.trg_input = trg[:, :-1] self.trg_lengths = trg_lengths self.trg_length = trg_length # trg is used for loss computation, shifted by one since BOS self.trg = trg[:, 1:] # we exclude the padded areas from the loss computation Loading @@ -53,40 +53,41 @@ class Batch: """ self.src = self.src.cuda() self.src_mask = self.src_mask.cuda() self.src_length = self.src_length.cuda() if self.trg_input is not None: self.trg_input = self.trg_input.cuda() self.trg = self.trg.cuda() self.trg_mask = self.trg_mask.cuda() def sort_by_src_lengths(self): def sort_by_src_length(self): """ Sort by src length (descending) and return index to revert sort :return: """ _, perm_index = self.src_lengths.sort(0, descending=True) _, perm_index = self.src_length.sort(0, descending=True) rev_index = [0]*perm_index.size(0) for new_pos, old_pos in enumerate(perm_index.cpu().numpy()): rev_index[old_pos] = new_pos sorted_src_lengths = self.src_lengths[perm_index] sorted_src_length = self.src_length[perm_index] sorted_src = self.src[perm_index] sorted_src_mask = self.src_mask[perm_index] if self.trg_input is not None: sorted_trg_input = self.trg_input[perm_index] sorted_trg_lengths = self.trg_lengths[perm_index] sorted_trg_length = self.trg_length[perm_index] sorted_trg_mask = self.trg_mask[perm_index] sorted_trg = self.trg[perm_index] self.src = sorted_src self.src_lengths = sorted_src_lengths self.src_length = sorted_src_length self.src_mask = sorted_src_mask if self.trg_input is not None: self.trg_input = sorted_trg_input self.trg_mask = sorted_trg_mask self.trg_lengths = sorted_trg_lengths self.trg_length = sorted_trg_length self.trg = sorted_trg if self.use_cuda: Loading joeynmt/model.py +97 −92 Original line number Diff line number Diff line Loading @@ -2,9 +2,6 @@ """ Module to represents whole models """ import numpy as np import torch.nn as nn from torch import Tensor import torch.nn.functional as F Loading @@ -14,9 +11,7 @@ from joeynmt.embeddings import Embeddings from joeynmt.encoders import Encoder, RecurrentEncoder, TransformerEncoder from joeynmt.decoders import Decoder, RecurrentDecoder, TransformerDecoder from joeynmt.constants import PAD_TOKEN, EOS_TOKEN, BOS_TOKEN from joeynmt.search import beam_search, greedy from joeynmt.vocabulary import Vocabulary from joeynmt.batch import Batch from joeynmt.helpers import ConfigurationError Loading Loading @@ -53,11 +48,62 @@ class Model(nn.Module): self.bos_index = self.trg_vocab.stoi[BOS_TOKEN] self.pad_index = self.trg_vocab.stoi[PAD_TOKEN] self.eos_index = self.trg_vocab.stoi[EOS_TOKEN] self.loss_function = None # set by the TrainManager def forward(self, return_type: str = None, **kwargs) \ -> (Tensor, Tensor, Tensor, Tensor): """ Interface for multi-gpu :param return_type: one of {"loss", "encode", "decode"} """ if return_type is None: raise ValueError("Please specify return_type: " "{`loss`, `encode`, `decode`}.") elif return_type == "loss": assert self.loss_function is not None out, _, _, _ = self._encode_decode( src=kwargs["src"], trg_input=kwargs["trg_input"], src_mask=kwargs["src_mask"], src_length=kwargs["src_length"], trg_mask=kwargs["trg_mask"]) # compute log probs log_probs = F.log_softmax(out, dim=-1) # compute batch loss batch_loss = self.loss_function(log_probs, kwargs["trg"]) # return batch loss = sum over all elements in batch that are not pad return batch_loss, None, None, None elif return_type == "encode": encoder_output, encoder_hidden = self._encode( src=kwargs["src"], src_length=kwargs["src_length"], src_mask=kwargs["src_mask"]) # return encoder outputs return encoder_output, encoder_hidden, None, None elif return_type == "decode": outputs, hidden, att_probs, att_vectors = self._decode( trg_input=kwargs["trg_input"], encoder_output=kwargs["encoder_output"], encoder_hidden=kwargs["encoder_hidden"], src_mask=kwargs["src_mask"], unroll_steps=kwargs["unroll_steps"], decoder_hidden=kwargs["decoder_hidden"], att_vector=kwargs.get("att_vector", None), trg_mask=kwargs.get("trg_mask", None)) # return decoder outputs return outputs, hidden, att_probs, att_vectors # pylint: disable=arguments-differ def forward(self, src: Tensor, trg_input: Tensor, src_mask: Tensor, src_lengths: Tensor, trg_mask: Tensor = None) -> ( Tensor, Tensor, Tensor, Tensor): def _encode_decode(self, src: Tensor, trg_input: Tensor, src_mask: Tensor, src_length: Tensor, trg_mask: Tensor = None) \ -> (Tensor, Tensor, Tensor, Tensor): """ First encodes the source sentence. Then produces the target one word at a time. Loading @@ -65,21 +111,21 @@ class Model(nn.Module): :param src: source input :param trg_input: target input :param src_mask: source mask :param src_lengths: length of source inputs :param src_length: length of source inputs :param trg_mask: target mask :return: decoder outputs """ encoder_output, encoder_hidden = self.encode(src=src, src_length=src_lengths, encoder_output, encoder_hidden = self._encode(src=src, src_length=src_length, src_mask=src_mask) unroll_steps = trg_input.size(1) return self.decode(encoder_output=encoder_output, return self._decode(encoder_output=encoder_output, encoder_hidden=encoder_hidden, src_mask=src_mask, trg_input=trg_input, unroll_steps=unroll_steps, trg_mask=trg_mask) def encode(self, src: Tensor, src_length: Tensor, src_mask: Tensor) \ def _encode(self, src: Tensor, src_length: Tensor, src_mask: Tensor) \ -> (Tensor, Tensor): """ Encodes the source sentence. Loading @@ -91,10 +137,10 @@ class Model(nn.Module): """ return self.encoder(self.src_embed(src), src_length, src_mask) def decode(self, encoder_output: Tensor, encoder_hidden: Tensor, def _decode(self, encoder_output: Tensor, encoder_hidden: Tensor, src_mask: Tensor, trg_input: Tensor, unroll_steps: int, decoder_hidden: Tensor = None, trg_mask: Tensor = None) \ att_vector: Tensor = None, trg_mask: Tensor = None) \ -> (Tensor, Tensor, Tensor, Tensor): """ Decode, given an encoded source sentence. Loading @@ -105,6 +151,7 @@ class Model(nn.Module): :param trg_input: target inputs :param unroll_steps: number of steps to unrol the decoder for :param decoder_hidden: decoder hidden state (optional) :param att_vector: previous attention vector (optional) :param trg_mask: mask for target steps :return: decoder outputs (outputs, hidden, att_probs, att_vectors) """ Loading @@ -114,74 +161,32 @@ class Model(nn.Module): src_mask=src_mask, unroll_steps=unroll_steps, hidden=decoder_hidden, prev_att_vector=att_vector, trg_mask=trg_mask) def get_loss_for_batch(self, batch: Batch, loss_function: nn.Module) \ -> Tensor: """ Compute non-normalized loss and number of tokens for a batch :param batch: batch to compute loss for :param loss_function: loss function, computes for input and target a scalar loss for the complete batch :return: batch_loss: sum of losses over non-pad elements in the batch """ # pylint: disable=unused-variable out, hidden, att_probs, _ = self.forward( src=batch.src, trg_input=batch.trg_input, src_mask=batch.src_mask, src_lengths=batch.src_lengths, trg_mask=batch.trg_mask) # compute log probs log_probs = F.log_softmax(out, dim=-1) # compute batch loss batch_loss = loss_function(log_probs, batch.trg) # return batch loss = sum over all elements in batch that are not pad return batch_loss def run_batch(self, batch: Batch, max_output_length: int, beam_size: int, beam_alpha: float) -> (np.array, np.array): """ Get outputs and attentions scores for a given batch :param batch: batch to generate hypotheses for :param max_output_length: maximum length of hypotheses :param beam_size: size of the beam for beam search, if 0 use greedy :param beam_alpha: alpha value for beam search :return: stacked_output: hypotheses for batch, stacked_attention_scores: attention scores for batch """ encoder_output, encoder_hidden = self.encode( batch.src, batch.src_lengths, batch.src_mask) # if maximum output length is not globally specified, adapt to src len if max_output_length is None: max_output_length = int(max(batch.src_lengths.cpu().numpy()) * 1.5) # greedy decoding if beam_size < 2: stacked_output, stacked_attention_scores = greedy( encoder_hidden=encoder_hidden, encoder_output=encoder_output, eos_index=self.eos_index, src_mask=batch.src_mask, embed=self.trg_embed, bos_index=self.bos_index, decoder=self.decoder, max_output_length=max_output_length) # batch, time, max_src_length else: # beam size stacked_output, stacked_attention_scores = \ beam_search( size=beam_size, encoder_output=encoder_output, encoder_hidden=encoder_hidden, src_mask=batch.src_mask, embed=self.trg_embed, max_output_length=max_output_length, alpha=beam_alpha, eos_index=self.eos_index, pad_index=self.pad_index, bos_index=self.bos_index, decoder=self.decoder) return stacked_output, stacked_attention_scores # def get_loss_for_batch(self, batch: Batch, loss_function: nn.Module) \ # -> Tensor: # """ # Compute non-normalized loss and number of tokens for a batch # # :param batch: batch to compute loss for # :param loss_function: loss function, computes for input and target # a scalar loss for the complete batch # :return: batch_loss: sum of losses over non-pad elements in the batch # """ # # pylint: disable=unused-variable # out, hidden, att_probs, _ = self.encode_decode( # src=batch.src, trg_input=batch.trg_input, # src_mask=batch.src_mask, src_length=batch.src_length, # trg_mask=batch.trg_mask) # # # compute log probs # log_probs = F.log_softmax(out, dim=-1) # # # compute batch loss # batch_loss = loss_function(log_probs, batch.trg) # # return batch loss = sum over all elements in batch that are not pad # return batch_loss def __repr__(self) -> str: """ Loading joeynmt/prediction.py +16 −13 Original line number Diff line number Diff line Loading @@ -16,6 +16,7 @@ from joeynmt.helpers import bpe_postprocess, load_config, make_logger,\ get_latest_checkpoint, load_checkpoint, store_attention_plots from joeynmt.metrics import bleu, chrf, token_accuracy, sequence_accuracy from joeynmt.model import build_model, Model from joeynmt.search import run_batch from joeynmt.batch import Batch from joeynmt.data import load_data, make_data_iter, MonoDataset from joeynmt.constants import UNK_TOKEN, PAD_TOKEN, EOS_TOKEN Loading @@ -30,7 +31,7 @@ def validate_on_data(model: Model, data: Dataset, batch_size: int, use_cuda: bool, max_output_length: int, level: str, eval_metric: Optional[str], loss_function: torch.nn.Module = None, compute_loss: bool = False, beam_size: int = 1, beam_alpha: int = -1, batch_type: str = "sentence", postprocess: bool = True, Loading @@ -41,7 +42,7 @@ def validate_on_data(model: Model, data: Dataset, List[str], List[List[str]], List[np.array]): """ Generate translations for the given data. If `loss_function` is not None and references are given, If `compute_loss` is True and references are given, also compute the loss. :param model: model module Loading @@ -52,7 +53,7 @@ def validate_on_data(model: Model, data: Dataset, :param max_output_length: maximum length for generated hypotheses :param level: segmentation level, one of "char", "bpe", "word" :param eval_metric: evaluation metric, e.g. "bleu" :param loss_function: loss function that computes a scalar loss :param compute_loss: whether to computes a scalar loss for given inputs and targets :param beam_size: beam size for validation. If <2 then greedy decoding (default). Loading Loading @@ -99,20 +100,22 @@ def validate_on_data(model: Model, data: Dataset, batch = Batch(valid_batch, pad_index, use_cuda=use_cuda) # sort batch now by src length and keep track of order sort_reverse_index = batch.sort_by_src_lengths() sort_reverse_index = batch.sort_by_src_length() # run as during training with teacher forcing if loss_function is not None and batch.trg is not None: batch_loss = model.get_loss_for_batch( batch, loss_function=loss_function) if compute_loss and batch.trg is not None: batch_loss, _, _, _ = model( return_type="loss", src=batch.src, trg=batch.trg, trg_input=batch.trg_input, trg_mask=batch.trg_mask, src_mask=batch.src_mask, src_length=batch.src_length) total_loss += batch_loss total_ntokens += batch.ntokens total_nseqs += batch.nseqs # run as during inference to produce translations output, attention_scores = model.run_batch( batch=batch, beam_size=beam_size, beam_alpha=beam_alpha, max_output_length=max_output_length) output, attention_scores = run_batch( model=model, batch=batch, beam_size=beam_size, beam_alpha=beam_alpha, max_output_length=max_output_length) # sort outputs back to original order all_outputs.extend(output[sort_reverse_index]) Loading @@ -122,7 +125,7 @@ def validate_on_data(model: Model, data: Dataset, assert len(all_outputs) == len(data) if loss_function is not None and total_ntokens > 0: if compute_loss and total_ntokens > 0: # total validation loss valid_loss = total_loss # exponent of token-level negative log prob Loading Loading @@ -272,7 +275,7 @@ def test(cfg_file, model, data=data_set, batch_size=batch_size, batch_type=batch_type, level=level, max_output_length=max_output_length, eval_metric=eval_metric, use_cuda=use_cuda, loss_function=None, beam_size=beam_size, use_cuda=use_cuda, compute_loss=False, beam_size=beam_size, beam_alpha=beam_alpha, postprocess=postprocess, bpe_type=bpe_type, sacrebleu=sacrebleu) #pylint: enable=unused-variable Loading Loading @@ -355,7 +358,7 @@ def translate(cfg_file, ckpt: str, output_path: str = None) -> None: model, data=test_data, batch_size=batch_size, batch_type=batch_type, level=level, max_output_length=max_output_length, eval_metric="", use_cuda=use_cuda, loss_function=None, beam_size=beam_size, use_cuda=use_cuda, compute_loss=False, beam_size=beam_size, beam_alpha=beam_alpha, postprocess=postprocess, bpe_type=bpe_type, sacrebleu=sacrebleu) return hypotheses Loading joeynmt/search.py +108 −76 File changed.Preview size limit exceeded, changes collapsed. Show changes joeynmt/training.py +14 −11 Original line number Diff line number Diff line Loading @@ -68,7 +68,7 @@ class TrainManager: # objective self.label_smoothing = train_config.get("label_smoothing", 0.0) self.loss = XentLoss(pad_index=self.pad_index, self.model.loss_function = XentLoss(pad_index=self.pad_index, smoothing=self.label_smoothing) self.normalization = train_config.get("normalization", "batch") if self.normalization not in ["batch", "tokens", "none"]: Loading Loading @@ -156,7 +156,7 @@ class TrainManager: self.use_cuda = train_config["use_cuda"] if self.use_cuda: self.model.cuda() self.loss.cuda() #self.loss.cuda() # initialize accumalted batch loss (needed for batch_multiplier) #self.norm_batch_loss_accumulated = 0 Loading Loading @@ -396,9 +396,8 @@ class TrainManager: epoch_no + 1, epoch_loss) else: logger.info('Training ended after %3d epochs.', epoch_no + 1) logger.info('Best validation result (greedy) at step ' '%8d: %6.2f %s.', self.best_ckpt_iteration, self.best_ckpt_score, logger.info('Best validation result (greedy) at step %8d: %6.2f %s.', self.best_ckpt_iteration, self.best_ckpt_score, self.early_stopping_metric) self.tb_writer.close() # close Tensorboard writer Loading @@ -414,8 +413,12 @@ class TrainManager: self.model.train() # get loss batch_loss = self.model.get_loss_for_batch( batch=batch, loss_function=self.loss) #batch_loss = self.model.get_loss_for_batch( # batch=batch, loss_function=self.loss) batch_loss, _, _, _ = self.model( return_type="loss", src=batch.src, trg=batch.trg, trg_input=batch.trg_input, src_mask=batch.src_mask, src_length=batch.src_length, trg_mask=batch.trg_mask) # normalize batch loss if self.normalization == "batch": Loading Loading @@ -456,7 +459,7 @@ class TrainManager: level=self.level, model=self.model, use_cuda=self.use_cuda, max_output_length=self.max_output_length, loss_function=self.loss, compute_loss=True, beam_size=1, # greedy validations batch_type=self.eval_batch_type, postprocess=True, # always remove BPE for validation Loading Loading
joeynmt/batch.py +11 −10 Original line number Diff line number Diff line Loading @@ -21,21 +21,21 @@ class Batch: :param pad_index: :param use_cuda: """ self.src, self.src_lengths = torch_batch.src self.src, self.src_length = torch_batch.src self.src_mask = (self.src != pad_index).unsqueeze(1) self.nseqs = self.src.size(0) self.trg_input = None self.trg = None self.trg_mask = None self.trg_lengths = None self.trg_length = None self.ntokens = None self.use_cuda = use_cuda if hasattr(torch_batch, "trg"): trg, trg_lengths = torch_batch.trg trg, trg_length = torch_batch.trg # trg_input is used for teacher forcing, last one is cut off self.trg_input = trg[:, :-1] self.trg_lengths = trg_lengths self.trg_length = trg_length # trg is used for loss computation, shifted by one since BOS self.trg = trg[:, 1:] # we exclude the padded areas from the loss computation Loading @@ -53,40 +53,41 @@ class Batch: """ self.src = self.src.cuda() self.src_mask = self.src_mask.cuda() self.src_length = self.src_length.cuda() if self.trg_input is not None: self.trg_input = self.trg_input.cuda() self.trg = self.trg.cuda() self.trg_mask = self.trg_mask.cuda() def sort_by_src_lengths(self): def sort_by_src_length(self): """ Sort by src length (descending) and return index to revert sort :return: """ _, perm_index = self.src_lengths.sort(0, descending=True) _, perm_index = self.src_length.sort(0, descending=True) rev_index = [0]*perm_index.size(0) for new_pos, old_pos in enumerate(perm_index.cpu().numpy()): rev_index[old_pos] = new_pos sorted_src_lengths = self.src_lengths[perm_index] sorted_src_length = self.src_length[perm_index] sorted_src = self.src[perm_index] sorted_src_mask = self.src_mask[perm_index] if self.trg_input is not None: sorted_trg_input = self.trg_input[perm_index] sorted_trg_lengths = self.trg_lengths[perm_index] sorted_trg_length = self.trg_length[perm_index] sorted_trg_mask = self.trg_mask[perm_index] sorted_trg = self.trg[perm_index] self.src = sorted_src self.src_lengths = sorted_src_lengths self.src_length = sorted_src_length self.src_mask = sorted_src_mask if self.trg_input is not None: self.trg_input = sorted_trg_input self.trg_mask = sorted_trg_mask self.trg_lengths = sorted_trg_lengths self.trg_length = sorted_trg_length self.trg = sorted_trg if self.use_cuda: Loading
joeynmt/model.py +97 −92 Original line number Diff line number Diff line Loading @@ -2,9 +2,6 @@ """ Module to represents whole models """ import numpy as np import torch.nn as nn from torch import Tensor import torch.nn.functional as F Loading @@ -14,9 +11,7 @@ from joeynmt.embeddings import Embeddings from joeynmt.encoders import Encoder, RecurrentEncoder, TransformerEncoder from joeynmt.decoders import Decoder, RecurrentDecoder, TransformerDecoder from joeynmt.constants import PAD_TOKEN, EOS_TOKEN, BOS_TOKEN from joeynmt.search import beam_search, greedy from joeynmt.vocabulary import Vocabulary from joeynmt.batch import Batch from joeynmt.helpers import ConfigurationError Loading Loading @@ -53,11 +48,62 @@ class Model(nn.Module): self.bos_index = self.trg_vocab.stoi[BOS_TOKEN] self.pad_index = self.trg_vocab.stoi[PAD_TOKEN] self.eos_index = self.trg_vocab.stoi[EOS_TOKEN] self.loss_function = None # set by the TrainManager def forward(self, return_type: str = None, **kwargs) \ -> (Tensor, Tensor, Tensor, Tensor): """ Interface for multi-gpu :param return_type: one of {"loss", "encode", "decode"} """ if return_type is None: raise ValueError("Please specify return_type: " "{`loss`, `encode`, `decode`}.") elif return_type == "loss": assert self.loss_function is not None out, _, _, _ = self._encode_decode( src=kwargs["src"], trg_input=kwargs["trg_input"], src_mask=kwargs["src_mask"], src_length=kwargs["src_length"], trg_mask=kwargs["trg_mask"]) # compute log probs log_probs = F.log_softmax(out, dim=-1) # compute batch loss batch_loss = self.loss_function(log_probs, kwargs["trg"]) # return batch loss = sum over all elements in batch that are not pad return batch_loss, None, None, None elif return_type == "encode": encoder_output, encoder_hidden = self._encode( src=kwargs["src"], src_length=kwargs["src_length"], src_mask=kwargs["src_mask"]) # return encoder outputs return encoder_output, encoder_hidden, None, None elif return_type == "decode": outputs, hidden, att_probs, att_vectors = self._decode( trg_input=kwargs["trg_input"], encoder_output=kwargs["encoder_output"], encoder_hidden=kwargs["encoder_hidden"], src_mask=kwargs["src_mask"], unroll_steps=kwargs["unroll_steps"], decoder_hidden=kwargs["decoder_hidden"], att_vector=kwargs.get("att_vector", None), trg_mask=kwargs.get("trg_mask", None)) # return decoder outputs return outputs, hidden, att_probs, att_vectors # pylint: disable=arguments-differ def forward(self, src: Tensor, trg_input: Tensor, src_mask: Tensor, src_lengths: Tensor, trg_mask: Tensor = None) -> ( Tensor, Tensor, Tensor, Tensor): def _encode_decode(self, src: Tensor, trg_input: Tensor, src_mask: Tensor, src_length: Tensor, trg_mask: Tensor = None) \ -> (Tensor, Tensor, Tensor, Tensor): """ First encodes the source sentence. Then produces the target one word at a time. Loading @@ -65,21 +111,21 @@ class Model(nn.Module): :param src: source input :param trg_input: target input :param src_mask: source mask :param src_lengths: length of source inputs :param src_length: length of source inputs :param trg_mask: target mask :return: decoder outputs """ encoder_output, encoder_hidden = self.encode(src=src, src_length=src_lengths, encoder_output, encoder_hidden = self._encode(src=src, src_length=src_length, src_mask=src_mask) unroll_steps = trg_input.size(1) return self.decode(encoder_output=encoder_output, return self._decode(encoder_output=encoder_output, encoder_hidden=encoder_hidden, src_mask=src_mask, trg_input=trg_input, unroll_steps=unroll_steps, trg_mask=trg_mask) def encode(self, src: Tensor, src_length: Tensor, src_mask: Tensor) \ def _encode(self, src: Tensor, src_length: Tensor, src_mask: Tensor) \ -> (Tensor, Tensor): """ Encodes the source sentence. Loading @@ -91,10 +137,10 @@ class Model(nn.Module): """ return self.encoder(self.src_embed(src), src_length, src_mask) def decode(self, encoder_output: Tensor, encoder_hidden: Tensor, def _decode(self, encoder_output: Tensor, encoder_hidden: Tensor, src_mask: Tensor, trg_input: Tensor, unroll_steps: int, decoder_hidden: Tensor = None, trg_mask: Tensor = None) \ att_vector: Tensor = None, trg_mask: Tensor = None) \ -> (Tensor, Tensor, Tensor, Tensor): """ Decode, given an encoded source sentence. Loading @@ -105,6 +151,7 @@ class Model(nn.Module): :param trg_input: target inputs :param unroll_steps: number of steps to unrol the decoder for :param decoder_hidden: decoder hidden state (optional) :param att_vector: previous attention vector (optional) :param trg_mask: mask for target steps :return: decoder outputs (outputs, hidden, att_probs, att_vectors) """ Loading @@ -114,74 +161,32 @@ class Model(nn.Module): src_mask=src_mask, unroll_steps=unroll_steps, hidden=decoder_hidden, prev_att_vector=att_vector, trg_mask=trg_mask) def get_loss_for_batch(self, batch: Batch, loss_function: nn.Module) \ -> Tensor: """ Compute non-normalized loss and number of tokens for a batch :param batch: batch to compute loss for :param loss_function: loss function, computes for input and target a scalar loss for the complete batch :return: batch_loss: sum of losses over non-pad elements in the batch """ # pylint: disable=unused-variable out, hidden, att_probs, _ = self.forward( src=batch.src, trg_input=batch.trg_input, src_mask=batch.src_mask, src_lengths=batch.src_lengths, trg_mask=batch.trg_mask) # compute log probs log_probs = F.log_softmax(out, dim=-1) # compute batch loss batch_loss = loss_function(log_probs, batch.trg) # return batch loss = sum over all elements in batch that are not pad return batch_loss def run_batch(self, batch: Batch, max_output_length: int, beam_size: int, beam_alpha: float) -> (np.array, np.array): """ Get outputs and attentions scores for a given batch :param batch: batch to generate hypotheses for :param max_output_length: maximum length of hypotheses :param beam_size: size of the beam for beam search, if 0 use greedy :param beam_alpha: alpha value for beam search :return: stacked_output: hypotheses for batch, stacked_attention_scores: attention scores for batch """ encoder_output, encoder_hidden = self.encode( batch.src, batch.src_lengths, batch.src_mask) # if maximum output length is not globally specified, adapt to src len if max_output_length is None: max_output_length = int(max(batch.src_lengths.cpu().numpy()) * 1.5) # greedy decoding if beam_size < 2: stacked_output, stacked_attention_scores = greedy( encoder_hidden=encoder_hidden, encoder_output=encoder_output, eos_index=self.eos_index, src_mask=batch.src_mask, embed=self.trg_embed, bos_index=self.bos_index, decoder=self.decoder, max_output_length=max_output_length) # batch, time, max_src_length else: # beam size stacked_output, stacked_attention_scores = \ beam_search( size=beam_size, encoder_output=encoder_output, encoder_hidden=encoder_hidden, src_mask=batch.src_mask, embed=self.trg_embed, max_output_length=max_output_length, alpha=beam_alpha, eos_index=self.eos_index, pad_index=self.pad_index, bos_index=self.bos_index, decoder=self.decoder) return stacked_output, stacked_attention_scores # def get_loss_for_batch(self, batch: Batch, loss_function: nn.Module) \ # -> Tensor: # """ # Compute non-normalized loss and number of tokens for a batch # # :param batch: batch to compute loss for # :param loss_function: loss function, computes for input and target # a scalar loss for the complete batch # :return: batch_loss: sum of losses over non-pad elements in the batch # """ # # pylint: disable=unused-variable # out, hidden, att_probs, _ = self.encode_decode( # src=batch.src, trg_input=batch.trg_input, # src_mask=batch.src_mask, src_length=batch.src_length, # trg_mask=batch.trg_mask) # # # compute log probs # log_probs = F.log_softmax(out, dim=-1) # # # compute batch loss # batch_loss = loss_function(log_probs, batch.trg) # # return batch loss = sum over all elements in batch that are not pad # return batch_loss def __repr__(self) -> str: """ Loading
joeynmt/prediction.py +16 −13 Original line number Diff line number Diff line Loading @@ -16,6 +16,7 @@ from joeynmt.helpers import bpe_postprocess, load_config, make_logger,\ get_latest_checkpoint, load_checkpoint, store_attention_plots from joeynmt.metrics import bleu, chrf, token_accuracy, sequence_accuracy from joeynmt.model import build_model, Model from joeynmt.search import run_batch from joeynmt.batch import Batch from joeynmt.data import load_data, make_data_iter, MonoDataset from joeynmt.constants import UNK_TOKEN, PAD_TOKEN, EOS_TOKEN Loading @@ -30,7 +31,7 @@ def validate_on_data(model: Model, data: Dataset, batch_size: int, use_cuda: bool, max_output_length: int, level: str, eval_metric: Optional[str], loss_function: torch.nn.Module = None, compute_loss: bool = False, beam_size: int = 1, beam_alpha: int = -1, batch_type: str = "sentence", postprocess: bool = True, Loading @@ -41,7 +42,7 @@ def validate_on_data(model: Model, data: Dataset, List[str], List[List[str]], List[np.array]): """ Generate translations for the given data. If `loss_function` is not None and references are given, If `compute_loss` is True and references are given, also compute the loss. :param model: model module Loading @@ -52,7 +53,7 @@ def validate_on_data(model: Model, data: Dataset, :param max_output_length: maximum length for generated hypotheses :param level: segmentation level, one of "char", "bpe", "word" :param eval_metric: evaluation metric, e.g. "bleu" :param loss_function: loss function that computes a scalar loss :param compute_loss: whether to computes a scalar loss for given inputs and targets :param beam_size: beam size for validation. If <2 then greedy decoding (default). Loading Loading @@ -99,20 +100,22 @@ def validate_on_data(model: Model, data: Dataset, batch = Batch(valid_batch, pad_index, use_cuda=use_cuda) # sort batch now by src length and keep track of order sort_reverse_index = batch.sort_by_src_lengths() sort_reverse_index = batch.sort_by_src_length() # run as during training with teacher forcing if loss_function is not None and batch.trg is not None: batch_loss = model.get_loss_for_batch( batch, loss_function=loss_function) if compute_loss and batch.trg is not None: batch_loss, _, _, _ = model( return_type="loss", src=batch.src, trg=batch.trg, trg_input=batch.trg_input, trg_mask=batch.trg_mask, src_mask=batch.src_mask, src_length=batch.src_length) total_loss += batch_loss total_ntokens += batch.ntokens total_nseqs += batch.nseqs # run as during inference to produce translations output, attention_scores = model.run_batch( batch=batch, beam_size=beam_size, beam_alpha=beam_alpha, max_output_length=max_output_length) output, attention_scores = run_batch( model=model, batch=batch, beam_size=beam_size, beam_alpha=beam_alpha, max_output_length=max_output_length) # sort outputs back to original order all_outputs.extend(output[sort_reverse_index]) Loading @@ -122,7 +125,7 @@ def validate_on_data(model: Model, data: Dataset, assert len(all_outputs) == len(data) if loss_function is not None and total_ntokens > 0: if compute_loss and total_ntokens > 0: # total validation loss valid_loss = total_loss # exponent of token-level negative log prob Loading Loading @@ -272,7 +275,7 @@ def test(cfg_file, model, data=data_set, batch_size=batch_size, batch_type=batch_type, level=level, max_output_length=max_output_length, eval_metric=eval_metric, use_cuda=use_cuda, loss_function=None, beam_size=beam_size, use_cuda=use_cuda, compute_loss=False, beam_size=beam_size, beam_alpha=beam_alpha, postprocess=postprocess, bpe_type=bpe_type, sacrebleu=sacrebleu) #pylint: enable=unused-variable Loading Loading @@ -355,7 +358,7 @@ def translate(cfg_file, ckpt: str, output_path: str = None) -> None: model, data=test_data, batch_size=batch_size, batch_type=batch_type, level=level, max_output_length=max_output_length, eval_metric="", use_cuda=use_cuda, loss_function=None, beam_size=beam_size, use_cuda=use_cuda, compute_loss=False, beam_size=beam_size, beam_alpha=beam_alpha, postprocess=postprocess, bpe_type=bpe_type, sacrebleu=sacrebleu) return hypotheses Loading
joeynmt/training.py +14 −11 Original line number Diff line number Diff line Loading @@ -68,7 +68,7 @@ class TrainManager: # objective self.label_smoothing = train_config.get("label_smoothing", 0.0) self.loss = XentLoss(pad_index=self.pad_index, self.model.loss_function = XentLoss(pad_index=self.pad_index, smoothing=self.label_smoothing) self.normalization = train_config.get("normalization", "batch") if self.normalization not in ["batch", "tokens", "none"]: Loading Loading @@ -156,7 +156,7 @@ class TrainManager: self.use_cuda = train_config["use_cuda"] if self.use_cuda: self.model.cuda() self.loss.cuda() #self.loss.cuda() # initialize accumalted batch loss (needed for batch_multiplier) #self.norm_batch_loss_accumulated = 0 Loading Loading @@ -396,9 +396,8 @@ class TrainManager: epoch_no + 1, epoch_loss) else: logger.info('Training ended after %3d epochs.', epoch_no + 1) logger.info('Best validation result (greedy) at step ' '%8d: %6.2f %s.', self.best_ckpt_iteration, self.best_ckpt_score, logger.info('Best validation result (greedy) at step %8d: %6.2f %s.', self.best_ckpt_iteration, self.best_ckpt_score, self.early_stopping_metric) self.tb_writer.close() # close Tensorboard writer Loading @@ -414,8 +413,12 @@ class TrainManager: self.model.train() # get loss batch_loss = self.model.get_loss_for_batch( batch=batch, loss_function=self.loss) #batch_loss = self.model.get_loss_for_batch( # batch=batch, loss_function=self.loss) batch_loss, _, _, _ = self.model( return_type="loss", src=batch.src, trg=batch.trg, trg_input=batch.trg_input, src_mask=batch.src_mask, src_length=batch.src_length, trg_mask=batch.trg_mask) # normalize batch loss if self.normalization == "batch": Loading Loading @@ -456,7 +459,7 @@ class TrainManager: level=self.level, model=self.model, use_cuda=self.use_cuda, max_output_length=self.max_output_length, loss_function=self.loss, compute_loss=True, beam_size=1, # greedy validations batch_type=self.eval_batch_type, postprocess=True, # always remove BPE for validation Loading