Commit aed3fa4c authored by Mayumi Ohta

don't pass logging object as func arg

parent 98c4d81c
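
The pattern applied across all four files below is the standard library one: each module creates its own logger with logging.getLogger(__name__) at import time, handlers are attached exactly once to the root logger, and records propagate upward, so no function or class needs a logger argument anymore. A minimal, self-contained sketch of that pattern; configure_logging and do_work are illustrative stand-ins for the commit's make_logger and its refactored call sites:

import logging

# module-level logger, named after the module (e.g. "joeynmt.data")
logger = logging.getLogger(__name__)

def configure_logging(log_file: str = None) -> None:
    """Attach handlers to the root logger exactly once."""
    root = logging.getLogger("")
    if len(root.handlers) == 0:  # idempotent: a second call is a no-op
        root.setLevel(logging.DEBUG)
        formatter = logging.Formatter(
            '%(asctime)s - %(levelname)s - %(name)s - %(message)s')
        if log_file is not None:
            fh = logging.FileHandler(log_file)
            fh.setFormatter(formatter)
            root.addHandler(fh)
        sh = logging.StreamHandler()
        sh.setLevel(logging.INFO)
        sh.setFormatter(formatter)
        root.addHandler(sh)

def do_work():
    # reaches the root handlers via propagation; no logger parameter needed
    logger.info("working...")
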
joeynmt/data.py +8 −0
@@ -7,6 +7,7 @@ import random
import os
import os.path
from typing import Optional
+import logging

from torchtext.datasets import TranslationDataset
from torchtext import data
@@ -15,6 +16,8 @@ from torchtext.data import Dataset, Iterator, Field
from joeynmt.constants import UNK_TOKEN, EOS_TOKEN, BOS_TOKEN, PAD_TOKEN
from joeynmt.vocabulary import build_vocab, Vocabulary

+logger = logging.getLogger(__name__)
+

def load_data(data_cfg: dict) -> (Dataset, Dataset, Optional[Dataset],
                                  Vocabulary, Vocabulary):
@@ -51,6 +54,7 @@ def load_data(data_cfg: dict) -> (Dataset, Dataset, Optional[Dataset],

    tok_fun = lambda s: list(s) if level == "char" else s.split()

    logger.info("loading training data...")
    src_field = data.Field(init_token=None, eos_token=EOS_TOKEN,
                           pad_token=PAD_TOKEN, tokenize=tok_fun,
                           batch_first=True, lower=lowercase,
@@ -80,6 +84,7 @@ def load_data(data_cfg: dict) -> (Dataset, Dataset, Optional[Dataset],
    src_vocab_file = data_cfg.get("src_vocab", None)
    trg_vocab_file = data_cfg.get("trg_vocab", None)

    logger.info("building vocabulary...")
    src_vocab = build_vocab(field="src", min_freq=src_min_freq,
                            max_size=src_max_size,
                            dataset=train_data, vocab_file=src_vocab_file)
@@ -96,11 +101,13 @@ def load_data(data_cfg: dict) -> (Dataset, Dataset, Optional[Dataset],
            random_state=random.getstate())
        train_data = keep

    logger.info("loading dev data...")
    dev_data = TranslationDataset(path=dev_path,
                                  exts=("." + src_lang, "." + trg_lang),
                                  fields=(src_field, trg_field))
    test_data = None
    if test_path is not None:
        logger.info("loading test data...")
        # check if target exists
        if os.path.isfile(test_path + "." + trg_lang):
            test_data = TranslationDataset(
@@ -112,6 +119,7 @@ def load_data(data_cfg: dict) -> (Dataset, Dataset, Optional[Dataset],
                                    field=src_field)
    src_field.vocab = src_vocab
    trg_field.vocab = trg_vocab
    logger.info("data loaded.")
    return train_data, dev_data, test_data, src_vocab, trg_vocab


joeynmt/helpers.py +34 −28
@@ -10,7 +10,7 @@ import errno
import shutil
import random
import logging
-from logging import Logger
+#from logging import Logger
from typing import Callable, Optional, List
import numpy as np

@@ -46,16 +46,19 @@ def make_model_dir(model_dir: str, overwrite=False) -> str:
    return model_dir


-def make_logger(log_file: str = None) -> Logger:
+def make_logger(log_file: str = None) -> None:
    """
    Create a logger for logging the training/testing process.

    :param log_file: path to file where log is stored as well
    :return: logger object
    """
-    logger = logging.getLogger(__name__)
+    logger = logging.getLogger("") # root logger

+    # add handlers only once.
+    if len(logger.handlers) == 0:
+        logger.setLevel(level=logging.DEBUG)
-    formatter = logging.Formatter('%(asctime)s %(message)s')
+        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s')

+        if log_file is not None:
+            fh = logging.FileHandler(log_file)
@@ -67,23 +70,25 @@ def make_logger(log_file: str = None) -> Logger:
+        sh.setLevel(logging.INFO)
+        sh.setFormatter(formatter)

    logging.getLogger("").addHandler(sh)
        #logging.getLogger("").addHandler(sh)
        logger.addHandler(sh)
        logger.info("Hello! This is Joey-NMT.")
    return logger
        #return logger


-def log_cfg(cfg: dict, logger: Logger, prefix: str = "cfg") -> None:
+def log_cfg(cfg: dict, prefix: str = "cfg") -> None:
    """
    Write configuration to log.

    :param cfg: configuration to log
-    :param logger: logger that defines where log is written to
+    #:param logger: logger that defines where log is written to
    :param prefix: prefix for logging
    """
+    logger = logging.getLogger(__name__)
    for k, v in cfg.items():
        if isinstance(v, dict):
            p = '.'.join([prefix, k])
-            log_cfg(v, logger, prefix=p)
+            log_cfg(v, prefix=p)
        else:
            p = '.'.join([prefix, k])
            logger.info("{:34s} : {}".format(p, v))
@@ -124,8 +129,8 @@ def set_seed(seed: int) -> None:


def log_data_info(train_data: Dataset, valid_data: Dataset, test_data: Dataset,
-                  src_vocab: Vocabulary, trg_vocab: Vocabulary,
-                  logging_function: Callable[[str], None]) -> None:
+                  src_vocab: Vocabulary, trg_vocab: Vocabulary) -> None:
+                  #logging_function: Callable[[str], None]
    """
    Log statistics of data and vocabulary.

@@ -134,24 +139,25 @@ def log_data_info(train_data: Dataset, valid_data: Dataset, test_data: Dataset,
    :param test_data:
    :param src_vocab:
    :param trg_vocab:
-    :param logging_function:
+    #:param logging_function:
    """
-    logging_function(
+    logger = logging.getLogger(__name__)
+    logger.info(
        "Data set sizes: \n\ttrain %d,\n\tvalid %d,\n\ttest %d",
            len(train_data), len(valid_data),
            len(test_data) if test_data is not None else 0)

    logging_function("First training example:\n\t[SRC] %s\n\t[TRG] %s",
    logger.info("First training example:\n\t[SRC] %s\n\t[TRG] %s",
        " ".join(vars(train_data[0])['src']),
        " ".join(vars(train_data[0])['trg']))

    logging_function("First 10 words (src): %s", " ".join(
    logger.info("First 10 words (src): %s", " ".join(
        '(%d) %s' % (i, t) for i, t in enumerate(src_vocab.itos[:10])))
    logging_function("First 10 words (trg): %s", " ".join(
    logger.info("First 10 words (trg): %s", " ".join(
        '(%d) %s' % (i, t) for i, t in enumerate(trg_vocab.itos[:10])))

    logging_function("Number of Src words (types): %d", len(src_vocab))
    logging_function("Number of Trg words (types): %d", len(trg_vocab))
    logger.info("Number of Src words (types): %d", len(src_vocab))
    logger.info("Number of Trg words (types): %d", len(trg_vocab))


def load_config(path="configs/default.yaml") -> dict:
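
With make_logger reduced to configuring the root logger, every module-level logging.getLogger(__name__) logger inherits its handlers through propagation, and the %(name)s field in the new format string records which module emitted each line. A quick sketch of the intended usage (the file names "run.log" and "other.log" are illustrative):

import logging
from joeynmt.helpers import make_logger

make_logger("run.log")                         # attaches root handlers once
logging.getLogger("joeynmt.data").info("hi")   # propagates to file and console
make_logger("other.log")                       # no-op: root already has handlers
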
joeynmt/prediction.py +16 −13
@@ -5,7 +5,8 @@ This modules holds methods for generating predictions from a model.
import os
import sys
from typing import List, Optional
-from logging import Logger
+import logging
+#from logging import Logger
import numpy as np

import torch
@@ -20,10 +21,12 @@ from joeynmt.data import load_data, make_data_iter, MonoDataset
from joeynmt.constants import UNK_TOKEN, PAD_TOKEN, EOS_TOKEN
from joeynmt.vocabulary import Vocabulary

+logger = logging.getLogger(__name__)
+

# pylint: disable=too-many-arguments,too-many-locals,no-member
def validate_on_data(model: Model, data: Dataset,
-                     logger: Logger,
+                     #logger: Logger, # don't pass logger
                     batch_size: int,
                     use_cuda: bool, max_output_length: int,
                     level: str, eval_metric: Optional[str],
@@ -40,7 +43,7 @@ def validate_on_data(model: Model, data: Dataset,
    also compute the loss.

    :param model: model module
-    :param logger: logger
+    #:param logger: logger # don't pass logger
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param use_cuda: if True, use CUDA
@@ -170,8 +173,8 @@ def validate_on_data(model: Model, data: Dataset,
def test(cfg_file,
         ckpt: str,
         output_path: str = None,
-         save_attention: bool = False,
-         logger: Logger = None) -> None:
+         save_attention: bool = False) -> None:
+         #logger: Logger = None # don't pass logger
    """
    Main test function. Handles loading a model from checkpoint, generating
    translations and storing them and attention plots.
@@ -180,14 +183,13 @@ def test(cfg_file,
    :param ckpt: path to checkpoint to load
    :param output_path: path to output
    :param save_attention: whether to save the computed attention weights
-    :param logger: log output to this logger (creates new logger if not set)
    """

-    if logger is None:
-        logger = make_logger()
-
    cfg = load_config(cfg_file)

+    if len(logger.handlers) == 0:
+        make_logger(f'{cfg["training"]["model_dir"]}/test.log')
+
    if "test" not in cfg["data"].keys():
        raise ValueError("Test data must be specified in config.")

@@ -247,7 +249,7 @@ def test(cfg_file,
            batch_type=batch_type, level=level,
            max_output_length=max_output_length, eval_metric=eval_metric,
            use_cuda=use_cuda, loss_function=None, beam_size=beam_size,
-            beam_alpha=beam_alpha, logger=logger, postprocess=postprocess)
+            beam_alpha=beam_alpha, postprocess=postprocess)
        #pylint: enable=unused-variable

        if "trg" in data_set.fields:
@@ -317,8 +319,6 @@ def translate(cfg_file, ckpt: str, output_path: str = None) -> None:

        return test_data

-    logger = make_logger()
-
    def _translate_data(test_data):
        """ Translates given dataset, using parameters from outer scope. """
        # pylint: disable=unused-variable
@@ -328,11 +328,14 @@ def translate(cfg_file, ckpt: str, output_path: str = None) -> None:
            batch_type=batch_type, level=level,
            max_output_length=max_output_length, eval_metric="",
            use_cuda=use_cuda, loss_function=None, beam_size=beam_size,
-            beam_alpha=beam_alpha, logger=logger, postprocess=postprocess)
+            beam_alpha=beam_alpha, postprocess=postprocess)
        return hypotheses

    cfg = load_config(cfg_file)

+    #logger = make_logger()
+    make_logger(f'{cfg["training"]["model_dir"]}/translation.log')
+
    # when checkpoint is not specified, take oldest from model dir
    if ckpt is None:
        model_dir = cfg["training"]["model_dir"]
joeynmt/training.py +41 −33
@@ -8,6 +8,7 @@ import argparse
import time
import shutil
from typing import List
+import logging
import os
import queue

@@ -33,6 +34,8 @@ from joeynmt.builders import build_optimizer, build_scheduler, \
    build_gradient_clipper
from joeynmt.prediction import test

+logger = logging.getLogger(__name__)
+

# pylint: disable=too-many-instance-attributes
class TrainManager:
@@ -49,10 +52,9 @@ class TrainManager:
        train_config = config["training"]

        # files for logging and storing
-        self.model_dir = make_model_dir(train_config["model_dir"],
-                                        overwrite=train_config.get(
-                                            "overwrite", False))
-        self.logger = make_logger("{}/train.log".format(self.model_dir))
+        self.model_dir = train_config["model_dir"]
+        assert os.path.exists(self.model_dir)
+        #self.logger = make_logger("{}/train.log".format(self.model_dir))
        self.logging_freq = train_config.get("logging_freq", 100)
        self.valid_report_file = "{}/validations.txt".format(self.model_dir)
        self.tb_writer = SummaryWriter(
@@ -163,7 +165,7 @@ class TrainManager:
        # model parameters
        if "load_model" in train_config.keys():
            model_load_path = train_config["load_model"]
            self.logger.info("Loading model from %s", model_load_path)
            logger.info("Loading model from %s", model_load_path)
            reset_best_ckpt = train_config.get("reset_best_ckpt", False)
            reset_scheduler = train_config.get("reset_scheduler", False)
            reset_optimizer = train_config.get("reset_optimizer", False)
@@ -200,7 +202,7 @@ class TrainManager:
            try:
                os.remove(to_delete)
            except FileNotFoundError:
                self.logger.warning("Wanted to delete old checkpoint %s but "
                logger.warning("Wanted to delete old checkpoint %s but "
                                    "file does not exist.", to_delete)

        self.ckpt_queue.put(model_path)
@@ -240,7 +242,7 @@ class TrainManager:
        if not reset_optimizer:
            self.optimizer.load_state_dict(model_checkpoint["optimizer_state"])
        else:
            self.logger.info("Reset optimizer.")
            logger.info("Reset optimizer.")

        if not reset_scheduler:
            if model_checkpoint["scheduler_state"] is not None and \
@@ -248,7 +250,7 @@ class TrainManager:
                self.scheduler.load_state_dict(
                    model_checkpoint["scheduler_state"])
        else:
            self.logger.info("Reset scheduler.")
            logger.info("Reset scheduler.")

        # restore counts
        self.steps = model_checkpoint["steps"]
@@ -258,7 +260,7 @@ class TrainManager:
            self.best_ckpt_score = model_checkpoint["best_ckpt_score"]
            self.best_ckpt_iteration = model_checkpoint["best_ckpt_iteration"]
        else:
            self.logger.info("Reset tracking of the best checkpoint.")
            logger.info("Reset tracking of the best checkpoint.")

        # move parameters to cuda
        if self.use_cuda:
@@ -286,7 +288,7 @@ class TrainManager:
            train_data) % (self.batch_multiplier * self.batch_size)

        for epoch_no in range(self.epochs):
            self.logger.info("EPOCH %d", epoch_no + 1)
            logger.info("EPOCH %d", epoch_no + 1)

            if self.scheduler is not None and self.scheduler_step_at == "epoch":
                self.scheduler.step(epoch=epoch_no)
@@ -348,7 +350,7 @@ class TrainManager:
                if self.steps % self.logging_freq == 0 and update:
                    elapsed = time.time() - start - total_valid_duration
                    elapsed_tokens = self.total_tokens - start_tokens
-                    self.logger.info(
+                    logger.info(
                        "Epoch %3d Step: %8d Batch Loss: %12.6f "
                        "Tokens per Sec: %8.0f, Lr: %.6f",
                        epoch_no + 1, self.steps, batch_loss,
@@ -366,7 +368,7 @@ class TrainManager:
                        valid_sources_raw, valid_references, valid_hypotheses, \
                        valid_hypotheses_raw, valid_attention_scores = \
                        validate_on_data(
-                            logger=self.logger,
+                            #logger=self.logger, # don't pass logger
                            batch_size=self.eval_batch_size,
                            data=valid_data,
                            eval_metric=self.eval_metric,
@@ -397,11 +399,11 @@ class TrainManager:
                    if self.is_best(ckpt_score):
                        self.best_ckpt_score = ckpt_score
                        self.best_ckpt_iteration = self.steps
-                        self.logger.info(
+                        logger.info(
                            'Hooray! New best validation result [%s]!',
                            self.early_stopping_metric)
                        if self.ckpt_queue.maxsize > 0:
                            self.logger.info("Saving new checkpoint.")
                            logger.info("Saving new checkpoint.")
                            new_best = True
                            self._save_checkpoint()

@@ -425,7 +427,7 @@ class TrainManager:

                    valid_duration = time.time() - valid_start_time
                    total_valid_duration += valid_duration
-                    self.logger.info(
+                    logger.info(
                        'Validation result (greedy) at epoch %3d, '
                        'step %8d: %s: %6.2f, loss: %8.4f, ppl: %8.4f, '
                        'duration: %.4fs', epoch_no + 1, self.steps,
@@ -449,16 +451,16 @@ class TrainManager:
                if self.stop:
                    break
            if self.stop:
-                self.logger.info(
+                logger.info(
                    'Training ended since minimum lr %f was reached.',
                    self.learning_rate_min)
                break

-            self.logger.info('Epoch %3d: total training loss %.2f',
+            logger.info('Epoch %3d: total training loss %.2f',
                             epoch_no + 1, epoch_loss)
        else:
-            self.logger.info('Training ended after %3d epochs.', epoch_no + 1)
-        self.logger.info('Best validation result (greedy) at step '
+            logger.info('Training ended after %3d epochs.', epoch_no + 1)
+        logger.info('Best validation result (greedy) at step '
                         '%8d: %6.2f %s.', self.best_ckpt_iteration,
                         self.best_ckpt_score,
                         self.early_stopping_metric)
@@ -559,10 +561,10 @@ class TrainManager:
        model_parameters = filter(lambda p: p.requires_grad,
                                  self.model.parameters())
        n_params = sum([np.prod(p.size()) for p in model_parameters])
        self.logger.info("Total params: %d", n_params)
        logger.info("Total params: %d", n_params)
        trainable_params = [n for (n, p) in self.model.named_parameters()
                            if p.requires_grad]
        self.logger.info("Trainable parameters: %s", sorted(trainable_params))
        logger.debug("Trainable parameters: %s", sorted(trainable_params))
        assert trainable_params

    def _log_examples(self, sources: List[str], hypotheses: List[str],
@@ -585,18 +587,18 @@ class TrainManager:
            if p >= len(sources):
                continue

            self.logger.info("Example #%d", p)
            logger.info("Example #%d", p)

            if sources_raw is not None:
                self.logger.debug("\tRaw source:     %s", sources_raw[p])
                logger.debug("\tRaw source:     %s", sources_raw[p])
            if references_raw is not None:
                self.logger.debug("\tRaw reference:  %s", references_raw[p])
                logger.debug("\tRaw reference:  %s", references_raw[p])
            if hypotheses_raw is not None:
                self.logger.debug("\tRaw hypothesis: %s", hypotheses_raw[p])
                logger.debug("\tRaw hypothesis: %s", hypotheses_raw[p])

            self.logger.info("\tSource:     %s", sources[p])
            self.logger.info("\tReference:  %s", references[p])
            self.logger.info("\tHypothesis: %s", hypotheses[p])
            logger.info("\tSource:     %s", sources[p])
            logger.info("\tReference:  %s", references[p])
            logger.info("\tHypothesis: %s", hypotheses[p])

    def _store_outputs(self, hypotheses: List[str]) -> None:
        """
@@ -619,6 +621,11 @@ def train(cfg_file: str) -> None:
    """
    cfg = load_config(cfg_file)

+    # make logger
+    model_dir = make_model_dir(cfg["training"]["model_dir"],
+                   overwrite=cfg["training"].get("overwrite", False))
+    make_logger(f"{model_dir}/train.log")
+
    # set the random seed
    set_seed(seed=cfg["training"].get("random_seed", 42))

@@ -636,13 +643,14 @@ def train(cfg_file: str) -> None:
    shutil.copy2(cfg_file, trainer.model_dir + "/config.yaml")

    # log all entries of config
-    log_cfg(cfg, trainer.logger)
+    log_cfg(cfg) #,logger

    log_data_info(train_data=train_data, valid_data=dev_data,
-                  test_data=test_data, src_vocab=src_vocab, trg_vocab=trg_vocab,
-                  logging_function=trainer.logger.info)
+                  test_data=test_data, src_vocab=src_vocab, trg_vocab=trg_vocab)
+                  #logging_function=logger.info)

-    trainer.logger.info(str(model))
+    #trainer.logger.info(str(model))
+    logger.info(str(model))

    # store the vocabs
    src_vocab_file = "{}/src_vocab.txt".format(cfg["training"]["model_dir"])
@@ -658,7 +666,7 @@ def train(cfg_file: str) -> None:
    ckpt = "{}/{}.ckpt".format(trainer.model_dir, trainer.best_ckpt_iteration)
    output_name = "{:08d}.hyps".format(trainer.best_ckpt_iteration)
    output_path = os.path.join(trainer.model_dir, output_name)
-    test(cfg_file, ckpt=ckpt, output_path=output_path, logger=trainer.logger)
+    test(cfg_file, ckpt=ckpt, output_path=output_path) #, logger=trainer.logger)


if __name__ == "__main__":