Commit 04907cb7 authored by Mayumi Ohta's avatar Mayumi Ohta
Browse files

load datasets only once

parent a3f0bcac
Loading
Loading
Loading
Loading
+33 −27
Original line number Diff line number Diff line
@@ -19,8 +19,8 @@ from joeynmt.vocabulary import build_vocab, Vocabulary
logger = logging.getLogger(__name__)


def load_data(data_cfg: dict) -> (Dataset, Dataset, Optional[Dataset],
                                  Vocabulary, Vocabulary):
def load_data(data_cfg: dict, datasets: list = ["train", "dev", "test"])\
        -> (Dataset, Dataset, Optional[Dataset], Vocabulary, Vocabulary):
    """
    Load train, dev and optionally test data as specified in configuration.
    Vocabularies are created from the training set with a limit of `voc_limit`
@@ -35,6 +35,7 @@ def load_data(data_cfg: dict) -> (Dataset, Dataset, Optional[Dataset],

    :param data_cfg: configuration dictionary for data
        ("data" part of configuration file)
    :param datasets: list of dataset names to load
    :return:
        - train_data: training dataset
        - dev_data: development dataset
@@ -45,8 +46,8 @@ def load_data(data_cfg: dict) -> (Dataset, Dataset, Optional[Dataset],
    # load data from files
    src_lang = data_cfg["src"]
    trg_lang = data_cfg["trg"]
    train_path = data_cfg["train"]
    dev_path = data_cfg["dev"]
    train_path = data_cfg.get("train", None)
    dev_path = data_cfg.get("dev", None)
    test_path = data_cfg.get("test", None)
    level = data_cfg["level"]
    lowercase = data_cfg["lowercase"]
@@ -54,7 +55,6 @@ def load_data(data_cfg: dict) -> (Dataset, Dataset, Optional[Dataset],

    tok_fun = lambda s: list(s) if level == "char" else s.split()

    logger.info("loading training data...")
    src_field = data.Field(init_token=None, eos_token=EOS_TOKEN,
                           pad_token=PAD_TOKEN, tokenize=tok_fun,
                           batch_first=True, lower=lowercase,
@@ -67,6 +67,9 @@ def load_data(data_cfg: dict) -> (Dataset, Dataset, Optional[Dataset],
                           batch_first=True, lower=lowercase,
                           include_lengths=True)

    train_data = None
    if "train" in datasets and train_path is not None:
        logger.info("loading training data...")
        train_data = TranslationDataset(path=train_path,
                                        exts=("." + src_lang, "." + trg_lang),
                                        fields=(src_field, trg_field),
@@ -76,6 +79,15 @@ def load_data(data_cfg: dict) -> (Dataset, Dataset, Optional[Dataset],
                                        and len(vars(x)['trg'])
                                        <= max_sent_length)

        random_train_subset = data_cfg.get("random_train_subset", -1)
        if random_train_subset > -1:
            # select this many training examples randomly and discard the rest
            keep_ratio = random_train_subset / len(train_data)
            keep, _ = train_data.split(
                split_ratio=[keep_ratio, 1 - keep_ratio],
                random_state=random.getstate())
            train_data = keep

    src_max_size = data_cfg.get("src_voc_limit", sys.maxsize)
    src_min_freq = data_cfg.get("src_voc_min_freq", 1)
    trg_max_size = data_cfg.get("trg_voc_limit", sys.maxsize)
@@ -92,21 +104,15 @@ def load_data(data_cfg: dict) -> (Dataset, Dataset, Optional[Dataset],
                            max_size=trg_max_size,
                            dataset=train_data, vocab_file=trg_vocab_file)

    random_train_subset = data_cfg.get("random_train_subset", -1)
    if random_train_subset > -1:
        # select this many training examples randomly and discard the rest
        keep_ratio = random_train_subset / len(train_data)
        keep, _ = train_data.split(
            split_ratio=[keep_ratio, 1 - keep_ratio],
            random_state=random.getstate())
        train_data = keep

    dev_data = None
    if "dev" in datasets and dev_path is not None:
        logger.info("loading dev data...")
        dev_data = TranslationDataset(path=dev_path,
                                      exts=("." + src_lang, "." + trg_lang),
                                      fields=(src_field, trg_field))

    test_data = None
    if test_path is not None:
    if "test" in datasets and test_path is not None:
        logger.info("loading test data...")
        # check if target exists
        if os.path.isfile(test_path + "." + trg_lang):
+13 −5
Original line number Diff line number Diff line
@@ -180,7 +180,8 @@ def validate_on_data(model: Model, data: Dataset,
def test(cfg_file,
         ckpt: str,
         output_path: str = None,
         save_attention: bool = False) -> None:
         save_attention: bool = False,
         datasets: dict = None) -> None:
         #logger: Logger = None # don't pass logger
    """
    Main test function. Handles loading a model from checkpoint, generating
@@ -222,10 +223,14 @@ def test(cfg_file,
    max_output_length = cfg["training"].get("max_output_length", None)

    # load the data
    if datasets is None:
        _, dev_data, test_data, src_vocab, trg_vocab = load_data(
        data_cfg=cfg["data"])

            data_cfg=cfg["data"], datasets=["dev", "test"])
        data_to_predict = {"dev": dev_data, "test": test_data}
    else:  # avoid loading the data again
        data_to_predict = {"dev": datasets["dev"], "test": datasets["test"]}
        src_vocab = datasets["src_vocab"]
        trg_vocab = datasets["trg_vocab"]

    # load model state from disk
    model_checkpoint = load_checkpoint(ckpt, use_cuda=use_cuda)
@@ -257,7 +262,10 @@ def test(cfg_file,
        sacrebleu = {"remove_whitespace": True, "tokenize": "13a"}

    for data_set_name, data_set in data_to_predict.items():
        if data_set is None:
            continue

        logger.info(f"Decoding on {data_set_name} set...")
        #pylint: disable=unused-variable
        score, loss, ppl, sources, sources_raw, references, hypotheses, \
        hypotheses_raw, attention_scores = validate_on_data(
+3 −1
Original line number Diff line number Diff line
@@ -667,7 +667,9 @@ def train(cfg_file: str) -> None:
    ckpt = "{}/{}.ckpt".format(trainer.model_dir, trainer.best_ckpt_iteration)
    output_name = "{:08d}.hyps".format(trainer.best_ckpt_iteration)
    output_path = os.path.join(trainer.model_dir, output_name)
    test(cfg_file, ckpt=ckpt, output_path=output_path) #, logger=trainer.logger)
    datasets_to_test = {"dev": dev_data, "test": test_data,
                        "src_vocab": src_vocab, "trg_vocab": trg_vocab}
    test(cfg_file, ckpt=ckpt, output_path=output_path, datasets=datasets_to_test)


if __name__ == "__main__":