Commit 74936156 authored by Mayumi Ohta
Browse files

don't call 'model.forward()', use 'model()' interface instead.

parent 04907cb7
Loading
Loading
Loading
Loading
+11 −10
Original line number Diff line number Diff line
@@ -21,21 +21,21 @@ class Batch:
        :param pad_index:
        :param use_cuda:
        """
        self.src, self.src_lengths = torch_batch.src
        self.src, self.src_length = torch_batch.src
        self.src_mask = (self.src != pad_index).unsqueeze(1)
        self.nseqs = self.src.size(0)
        self.trg_input = None
        self.trg = None
        self.trg_mask = None
        self.trg_lengths = None
        self.trg_length = None
        self.ntokens = None
        self.use_cuda = use_cuda

        if hasattr(torch_batch, "trg"):
            trg, trg_lengths = torch_batch.trg
            trg, trg_length = torch_batch.trg
            # trg_input is used for teacher forcing, last one is cut off
            self.trg_input = trg[:, :-1]
            self.trg_lengths = trg_lengths
            self.trg_length = trg_length
            # trg is used for loss computation, shifted by one since BOS
            self.trg = trg[:, 1:]
            # we exclude the padded areas from the loss computation
@@ -53,40 +53,41 @@ class Batch:
        """
        self.src = self.src.cuda()
        self.src_mask = self.src_mask.cuda()
        self.src_length = self.src_length.cuda()

        if self.trg_input is not None:
            self.trg_input = self.trg_input.cuda()
            self.trg = self.trg.cuda()
            self.trg_mask = self.trg_mask.cuda()

    def sort_by_src_lengths(self):
    def sort_by_src_length(self):
        """
        Sort by src length (descending) and return index to revert sort

        :return:
        """
        _, perm_index = self.src_lengths.sort(0, descending=True)
        _, perm_index = self.src_length.sort(0, descending=True)
        rev_index = [0]*perm_index.size(0)
        for new_pos, old_pos in enumerate(perm_index.cpu().numpy()):
            rev_index[old_pos] = new_pos

        sorted_src_lengths = self.src_lengths[perm_index]
        sorted_src_length = self.src_length[perm_index]
        sorted_src = self.src[perm_index]
        sorted_src_mask = self.src_mask[perm_index]
        if self.trg_input is not None:
            sorted_trg_input = self.trg_input[perm_index]
            sorted_trg_lengths = self.trg_lengths[perm_index]
            sorted_trg_length = self.trg_length[perm_index]
            sorted_trg_mask = self.trg_mask[perm_index]
            sorted_trg = self.trg[perm_index]

        self.src = sorted_src
        self.src_lengths = sorted_src_lengths
        self.src_length = sorted_src_length
        self.src_mask = sorted_src_mask

        if self.trg_input is not None:
            self.trg_input = sorted_trg_input
            self.trg_mask = sorted_trg_mask
            self.trg_lengths = sorted_trg_lengths
            self.trg_length = sorted_trg_length
            self.trg = sorted_trg

        if self.use_cuda:
+97 −92
Original line number Diff line number Diff line
@@ -2,9 +2,6 @@
"""
Module to represent whole models
"""

import numpy as np

import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
@@ -14,9 +11,7 @@ from joeynmt.embeddings import Embeddings
from joeynmt.encoders import Encoder, RecurrentEncoder, TransformerEncoder
from joeynmt.decoders import Decoder, RecurrentDecoder, TransformerDecoder
from joeynmt.constants import PAD_TOKEN, EOS_TOKEN, BOS_TOKEN
from joeynmt.search import beam_search, greedy
from joeynmt.vocabulary import Vocabulary
from joeynmt.batch import Batch
from joeynmt.helpers import ConfigurationError


@@ -53,11 +48,62 @@ class Model(nn.Module):
        self.bos_index = self.trg_vocab.stoi[BOS_TOKEN]
        self.pad_index = self.trg_vocab.stoi[PAD_TOKEN]
        self.eos_index = self.trg_vocab.stoi[EOS_TOKEN]
        self.loss_function = None # set by the TrainManager

    def forward(self, return_type: str = None, **kwargs) \
            -> (Tensor, Tensor, Tensor, Tensor):
        """
        Interface for multi-gpu: single entry point that dispatches to loss
        computation, encoding or decoding, so that callers go through
        `model(...)` instead of calling the sub-routines directly.

        :param return_type: one of {"loss", "encode", "decode"}
        :param kwargs: keyword arguments for the selected routine:
            - "loss": `src`, `trg_input`, `src_mask`, `src_length`,
              `trg_mask`, `trg`
            - "encode": `src`, `src_length`, `src_mask`
            - "decode": `trg_input`, `encoder_output`, `encoder_hidden`,
              `src_mask`, `unroll_steps`, `decoder_hidden` and, optionally,
              `att_vector`, `trg_mask`
        :return: always a 4-tuple, padded with `None` where needed:
            - "loss": (batch_loss, None, None, None)
            - "encode": (encoder_output, encoder_hidden, None, None)
            - "decode": (outputs, hidden, att_probs, att_vectors)
        :raises ValueError: if `return_type` is missing or not one of the
            supported values
        :raises KeyError: if a required keyword argument is missing
        """
        if return_type is None:
            raise ValueError("Please specify return_type: "
                             "{`loss`, `encode`, `decode`}.")

        if return_type == "loss":
            # the loss function is attached to the model from outside
            # (set by the TrainManager), so it may still be unset here
            assert self.loss_function is not None, \
                "loss_function must be set before computing the loss"

            out, _, _, _ = self._encode_decode(
                src=kwargs["src"],
                trg_input=kwargs["trg_input"],
                src_mask=kwargs["src_mask"],
                src_length=kwargs["src_length"],
                trg_mask=kwargs["trg_mask"])

            # compute log probs
            log_probs = F.log_softmax(out, dim=-1)

            # compute batch loss
            batch_loss = self.loss_function(log_probs, kwargs["trg"])

            # return batch loss = sum over all elements in batch that are not pad
            return batch_loss, None, None, None

        if return_type == "encode":
            encoder_output, encoder_hidden = self._encode(
                src=kwargs["src"],
                src_length=kwargs["src_length"],
                src_mask=kwargs["src_mask"])
            # return encoder outputs
            return encoder_output, encoder_hidden, None, None

        if return_type == "decode":
            outputs, hidden, att_probs, att_vectors = self._decode(
                trg_input=kwargs["trg_input"],
                encoder_output=kwargs["encoder_output"],
                encoder_hidden=kwargs["encoder_hidden"],
                src_mask=kwargs["src_mask"],
                unroll_steps=kwargs["unroll_steps"],
                decoder_hidden=kwargs["decoder_hidden"],
                att_vector=kwargs.get("att_vector", None),
                trg_mask=kwargs.get("trg_mask", None))
            # return decoder outputs
            return outputs, hidden, att_probs, att_vectors

        # Without this, an unknown value silently returned None and callers
        # that unpack four values failed with an unrelated error.
        raise ValueError(
            "Unknown return_type `{}`. Please specify one of "
            "{{`loss`, `encode`, `decode`}}.".format(return_type))

    # pylint: disable=arguments-differ
    def forward(self, src: Tensor, trg_input: Tensor, src_mask: Tensor,
                src_lengths: Tensor, trg_mask: Tensor = None) -> (
        Tensor, Tensor, Tensor, Tensor):
    def _encode_decode(self, src: Tensor, trg_input: Tensor, src_mask: Tensor,
                       src_length: Tensor, trg_mask: Tensor = None) \
            -> (Tensor, Tensor, Tensor, Tensor):
        """
        First encodes the source sentence.
        Then produces the target one word at a time.
@@ -65,21 +111,21 @@ class Model(nn.Module):
        :param src: source input
        :param trg_input: target input
        :param src_mask: source mask
        :param src_lengths: length of source inputs
        :param src_length: length of source inputs
        :param trg_mask: target mask
        :return: decoder outputs
        """
        encoder_output, encoder_hidden = self.encode(src=src,
                                                     src_length=src_lengths,
        encoder_output, encoder_hidden = self._encode(src=src,
                                                      src_length=src_length,
                                                      src_mask=src_mask)
        unroll_steps = trg_input.size(1)
        return self.decode(encoder_output=encoder_output,
        return self._decode(encoder_output=encoder_output,
                            encoder_hidden=encoder_hidden,
                            src_mask=src_mask, trg_input=trg_input,
                            unroll_steps=unroll_steps,
                            trg_mask=trg_mask)

    def encode(self, src: Tensor, src_length: Tensor, src_mask: Tensor) \
    def _encode(self, src: Tensor, src_length: Tensor, src_mask: Tensor) \
            -> (Tensor, Tensor):
        """
        Encodes the source sentence.
@@ -91,10 +137,10 @@ class Model(nn.Module):
        """
        return self.encoder(self.src_embed(src), src_length, src_mask)

    def decode(self, encoder_output: Tensor, encoder_hidden: Tensor,
    def _decode(self, encoder_output: Tensor, encoder_hidden: Tensor,
                src_mask: Tensor, trg_input: Tensor,
                unroll_steps: int, decoder_hidden: Tensor = None,
               trg_mask: Tensor = None) \
                att_vector: Tensor = None, trg_mask: Tensor = None) \
            -> (Tensor, Tensor, Tensor, Tensor):
        """
        Decode, given an encoded source sentence.
@@ -105,6 +151,7 @@ class Model(nn.Module):
        :param trg_input: target inputs
        :param unroll_steps: number of steps to unroll the decoder for
        :param decoder_hidden: decoder hidden state (optional)
        :param att_vector: previous attention vector (optional)
        :param trg_mask: mask for target steps
        :return: decoder outputs (outputs, hidden, att_probs, att_vectors)
        """
@@ -114,74 +161,32 @@ class Model(nn.Module):
                            src_mask=src_mask,
                            unroll_steps=unroll_steps,
                            hidden=decoder_hidden,
                            prev_att_vector=att_vector,
                            trg_mask=trg_mask)

    def get_loss_for_batch(self, batch: Batch, loss_function: nn.Module) \
            -> Tensor:
        """
        Compute non-normalized loss and number of tokens for a batch

        :param batch: batch to compute loss for
        :param loss_function: loss function, computes for input and target
            a scalar loss for the complete batch
        :return: batch_loss: sum of losses over non-pad elements in the batch
        """
        # pylint: disable=unused-variable
        out, hidden, att_probs, _ = self.forward(
            src=batch.src, trg_input=batch.trg_input,
            src_mask=batch.src_mask, src_lengths=batch.src_lengths,
            trg_mask=batch.trg_mask)

        # compute log probs
        log_probs = F.log_softmax(out, dim=-1)

        # compute batch loss
        batch_loss = loss_function(log_probs, batch.trg)
        # return batch loss = sum over all elements in batch that are not pad
        return batch_loss

    def run_batch(self, batch: Batch, max_output_length: int, beam_size: int,
                  beam_alpha: float) -> (np.array, np.array):
        """
        Get outputs and attentions scores for a given batch

        :param batch: batch to generate hypotheses for
        :param max_output_length: maximum length of hypotheses
        :param beam_size: size of the beam for beam search, if 0 use greedy
        :param beam_alpha: alpha value for beam search
        :return: stacked_output: hypotheses for batch,
            stacked_attention_scores: attention scores for batch
        """
        encoder_output, encoder_hidden = self.encode(
            batch.src, batch.src_lengths,
            batch.src_mask)

        # if maximum output length is not globally specified, adapt to src len
        if max_output_length is None:
            max_output_length = int(max(batch.src_lengths.cpu().numpy()) * 1.5)

        # greedy decoding
        if beam_size < 2:
            stacked_output, stacked_attention_scores = greedy(
                    encoder_hidden=encoder_hidden,
                    encoder_output=encoder_output, eos_index=self.eos_index,
                    src_mask=batch.src_mask, embed=self.trg_embed,
                    bos_index=self.bos_index, decoder=self.decoder,
                    max_output_length=max_output_length)
            # batch, time, max_src_length
        else:  # beam size
            stacked_output, stacked_attention_scores = \
                    beam_search(
                        size=beam_size, encoder_output=encoder_output,
                        encoder_hidden=encoder_hidden,
                        src_mask=batch.src_mask, embed=self.trg_embed,
                        max_output_length=max_output_length,
                        alpha=beam_alpha, eos_index=self.eos_index,
                        pad_index=self.pad_index,
                        bos_index=self.bos_index,
                        decoder=self.decoder)

        return stacked_output, stacked_attention_scores
#    def get_loss_for_batch(self, batch: Batch, loss_function: nn.Module) \
#            -> Tensor:
#        """
#        Compute non-normalized loss and number of tokens for a batch
#
#        :param batch: batch to compute loss for
#        :param loss_function: loss function, computes for input and target
#            a scalar loss for the complete batch
#        :return: batch_loss: sum of losses over non-pad elements in the batch
#        """
#        # pylint: disable=unused-variable
#        out, hidden, att_probs, _ = self.encode_decode(
#            src=batch.src, trg_input=batch.trg_input,
#            src_mask=batch.src_mask, src_length=batch.src_length,
#            trg_mask=batch.trg_mask)
#
#        # compute log probs
#        log_probs = F.log_softmax(out, dim=-1)
#
#        # compute batch loss
#        batch_loss = loss_function(log_probs, batch.trg)
#        # return batch loss = sum over all elements in batch that are not pad
#        return batch_loss

    def __repr__(self) -> str:
        """
+16 −13
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@ from joeynmt.helpers import bpe_postprocess, load_config, make_logger,\
    get_latest_checkpoint, load_checkpoint, store_attention_plots
from joeynmt.metrics import bleu, chrf, token_accuracy, sequence_accuracy
from joeynmt.model import build_model, Model
from joeynmt.search import run_batch
from joeynmt.batch import Batch
from joeynmt.data import load_data, make_data_iter, MonoDataset
from joeynmt.constants import UNK_TOKEN, PAD_TOKEN, EOS_TOKEN
@@ -30,7 +31,7 @@ def validate_on_data(model: Model, data: Dataset,
                     batch_size: int,
                     use_cuda: bool, max_output_length: int,
                     level: str, eval_metric: Optional[str],
                     loss_function: torch.nn.Module = None,
                     compute_loss: bool = False,
                     beam_size: int = 1, beam_alpha: int = -1,
                     batch_type: str = "sentence",
                     postprocess: bool = True,
@@ -41,7 +42,7 @@ def validate_on_data(model: Model, data: Dataset,
            List[str], List[List[str]], List[np.array]):
    """
    Generate translations for the given data.
    If `loss_function` is not None and references are given,
    If `compute_loss` is True and references are given,
    also compute the loss.

    :param model: model module
@@ -52,7 +53,7 @@ def validate_on_data(model: Model, data: Dataset,
    :param max_output_length: maximum length for generated hypotheses
    :param level: segmentation level, one of "char", "bpe", "word"
    :param eval_metric: evaluation metric, e.g. "bleu"
    :param loss_function: loss function that computes a scalar loss
    :param compute_loss: whether to compute a scalar loss
        for given inputs and targets
    :param beam_size: beam size for validation.
        If <2 then greedy decoding (default).
@@ -99,20 +100,22 @@ def validate_on_data(model: Model, data: Dataset,

            batch = Batch(valid_batch, pad_index, use_cuda=use_cuda)
            # sort batch now by src length and keep track of order
            sort_reverse_index = batch.sort_by_src_lengths()
            sort_reverse_index = batch.sort_by_src_length()

            # run as during training with teacher forcing
            if loss_function is not None and batch.trg is not None:
                batch_loss = model.get_loss_for_batch(
                    batch, loss_function=loss_function)
            if compute_loss and batch.trg is not None:
                batch_loss, _, _, _ = model(
                    return_type="loss", src=batch.src, trg=batch.trg,
                    trg_input=batch.trg_input, trg_mask=batch.trg_mask,
                    src_mask=batch.src_mask, src_length=batch.src_length)
                total_loss += batch_loss
                total_ntokens += batch.ntokens
                total_nseqs += batch.nseqs

            # run as during inference to produce translations
            output, attention_scores = model.run_batch(
                batch=batch, beam_size=beam_size, beam_alpha=beam_alpha,
                max_output_length=max_output_length)
            output, attention_scores = run_batch(
                model=model, batch=batch, beam_size=beam_size,
                beam_alpha=beam_alpha, max_output_length=max_output_length)

            # sort outputs back to original order
            all_outputs.extend(output[sort_reverse_index])
@@ -122,7 +125,7 @@ def validate_on_data(model: Model, data: Dataset,

        assert len(all_outputs) == len(data)

        if loss_function is not None and total_ntokens > 0:
        if compute_loss and total_ntokens > 0:
            # total validation loss
            valid_loss = total_loss
            # exponent of token-level negative log prob
@@ -272,7 +275,7 @@ def test(cfg_file,
            model, data=data_set, batch_size=batch_size,
            batch_type=batch_type, level=level,
            max_output_length=max_output_length, eval_metric=eval_metric,
            use_cuda=use_cuda, loss_function=None, beam_size=beam_size,
            use_cuda=use_cuda, compute_loss=False, beam_size=beam_size,
            beam_alpha=beam_alpha, postprocess=postprocess,
            bpe_type=bpe_type, sacrebleu=sacrebleu)
        #pylint: enable=unused-variable
@@ -355,7 +358,7 @@ def translate(cfg_file, ckpt: str, output_path: str = None) -> None:
            model, data=test_data, batch_size=batch_size,
            batch_type=batch_type, level=level,
            max_output_length=max_output_length, eval_metric="",
            use_cuda=use_cuda, loss_function=None, beam_size=beam_size,
            use_cuda=use_cuda, compute_loss=False, beam_size=beam_size,
            beam_alpha=beam_alpha, postprocess=postprocess,
            bpe_type=bpe_type, sacrebleu=sacrebleu)
        return hypotheses
+108 −76

File changed.

Preview size limit exceeded, changes collapsed.

+14 −11
Original line number Diff line number Diff line
@@ -68,7 +68,7 @@ class TrainManager:

        # objective
        self.label_smoothing = train_config.get("label_smoothing", 0.0)
        self.loss = XentLoss(pad_index=self.pad_index,
        self.model.loss_function = XentLoss(pad_index=self.pad_index,
                                            smoothing=self.label_smoothing)
        self.normalization = train_config.get("normalization", "batch")
        if self.normalization not in ["batch", "tokens", "none"]:
@@ -156,7 +156,7 @@ class TrainManager:
        self.use_cuda = train_config["use_cuda"]
        if self.use_cuda:
            self.model.cuda()
            self.loss.cuda()
            #self.loss.cuda()

        # initialize accumulated batch loss (needed for batch_multiplier)
        #self.norm_batch_loss_accumulated = 0
@@ -396,9 +396,8 @@ class TrainManager:
                        epoch_no + 1, epoch_loss)
        else:
            logger.info('Training ended after %3d epochs.', epoch_no + 1)
        logger.info('Best validation result (greedy) at step '
                         '%8d: %6.2f %s.', self.best_ckpt_iteration,
                         self.best_ckpt_score,
        logger.info('Best validation result (greedy) at step %8d: %6.2f %s.',
                    self.best_ckpt_iteration, self.best_ckpt_score,
                    self.early_stopping_metric)

        self.tb_writer.close()  # close Tensorboard writer
@@ -414,8 +413,12 @@ class TrainManager:
        self.model.train()

        # get loss
        batch_loss = self.model.get_loss_for_batch(
            batch=batch, loss_function=self.loss)
        #batch_loss = self.model.get_loss_for_batch(
        #    batch=batch, loss_function=self.loss)
        batch_loss, _, _, _ = self.model(
            return_type="loss", src=batch.src, trg=batch.trg,
            trg_input=batch.trg_input, src_mask=batch.src_mask,
            src_length=batch.src_length, trg_mask=batch.trg_mask)

        # normalize batch loss
        if self.normalization == "batch":
@@ -456,7 +459,7 @@ class TrainManager:
                level=self.level, model=self.model,
                use_cuda=self.use_cuda,
                max_output_length=self.max_output_length,
                loss_function=self.loss,
                compute_loss=True,
                beam_size=1,  # greedy validations
                batch_type=self.eval_batch_type,
                postprocess=True,   # always remove BPE for validation
Loading