Commit 01dec5ef authored by Mayumi Ohta
Browse files

support sentencepiece

parent cd38e156
Loading
Loading
Loading
Loading
+9 −2
Original line number Diff line number Diff line
@@ -175,14 +175,21 @@ def load_config(path="configs/default.yaml") -> dict:
    return cfg


def bpe_postprocess(string, bpe_type="subword-nmt") -> str:
    """
    Post-processor for BPE output. Recombines BPE-split tokens.

    :param string: BPE-segmented string to recombine
    :param bpe_type: one of {"sentencepiece", "subword-nmt"}
    :return: post-processed string
    """
    if bpe_type == "sentencepiece":
        # sentencepiece marks word boundaries with U+2581 ("▁"):
        # drop the inter-token spaces, then turn the markers back into spaces
        return string.replace(" ", "").replace("▁", " ").strip()
    if bpe_type == "subword-nmt":
        # subword-nmt marks non-final subwords with a trailing "@@"
        return string.replace("@@ ", "").strip()
    # unknown bpe_type: leave the string untouched apart from edge whitespace
    return string.strip()


def store_attention_plots(attentions: np.array, targets: List[List[str]],
+31 −6
Original line number Diff line number Diff line
@@ -4,32 +4,57 @@ This module holds various MT evaluation metrics.
"""

import sacrebleu
from typing import List


def chrf(hypotheses, references, remove_whitespace=True):
    """
    Character F-score from sacrebleu.

    :param hypotheses: list of hypotheses (strings)
    :param references: list of references (strings)
    :param remove_whitespace: (bool) whether whitespace characters are
        ignored when computing character n-grams
    :return: chrF score (float)
    """
    # sacrebleu expects a list of reference streams, hence [references]
    # NOTE(review): keyword signature matches sacrebleu 1.x — confirm against
    # the installed sacrebleu version
    return sacrebleu.corpus_chrf(hypotheses=hypotheses,
                                 references=[references],
                                 remove_whitespace=remove_whitespace).score


def bleu(hypotheses, references, tokenize="13a"):
    """
    Corpus BLEU from sacrebleu, tokenized according to `tokenize`.

    :param hypotheses: list of hypotheses (strings)
    :param references: list of references (strings)
    :param tokenize: one of {'none', '13a', 'intl', 'zh', 'ja-mecab'}
    :return: BLEU score (float)
    """
    # sacrebleu expects a list of reference streams, hence [references]
    # NOTE(review): keyword signature matches sacrebleu 1.x — confirm against
    # the installed sacrebleu version
    return sacrebleu.corpus_bleu(sys_stream=hypotheses,
                                 ref_streams=[references],
                                 tokenize=tokenize).score


def token_accuracy(hypotheses: List[List[str]],
                   references: List[List[str]]) -> float:
    """
    Compute the accuracy of hypothesis tokens: correct tokens / all tokens.
    Tokens are correct if they appear in the same position in the reference.

    :param hypotheses: list of tokenized hypotheses (List[List[str]])
    :param references: list of tokenized references (List[List[str]])
    :return: token accuracy in percent (float); 0.0 if there are no tokens
    """
    assert len(hypotheses) == len(references)
    correct_tokens = 0
    all_tokens = 0
    for hyp, ref in zip(hypotheses, references):
        all_tokens += len(hyp)
        # only the first min(len(hyp), len(ref)) positions are compared;
        # surplus hypothesis tokens still count in the denominator
        for h_i, r_i in zip(hyp, ref):
            if h_i == r_i:
                correct_tokens += 1
    return (correct_tokens / all_tokens) * 100 if all_tokens > 0 else 0.0


def _token_accuracy(hypotheses, references, level="word"):
    """
    Compute the accuracy of hypothesis tokens: correct tokens / all tokens
    Tokens are correct if they appear in the same position in the reference.
+43 −13
Original line number Diff line number Diff line
@@ -33,8 +33,10 @@ def validate_on_data(model: Model, data: Dataset,
                     loss_function: torch.nn.Module = None,
                     beam_size: int = 1, beam_alpha: int = -1,
                     batch_type: str = "sentence",
                     postprocess: bool = True
                     ) \
                     postprocess: bool = True,
                     bpe_type: str = "subword-nmt",
                     sacrebleu: dict = {"remove_whitespace": True,
                                        "tokenize": "13a"}) \
        -> (float, float, float, List[str], List[List[str]], List[str],
            List[str], List[List[str]], List[np.array]):
    """
@@ -58,6 +60,8 @@ def validate_on_data(model: Model, data: Dataset,
        disabled if set to -1 (default).
    :param batch_type: validation batch type (sentence or token)
    :param postprocess: if True, remove BPE segmentation from translations
    :param bpe_type: bpe type, one of {"subword-nmt", "sentencepiece"}
    :param sacrebleu: sacrebleu options

    :return:
        - current_valid_score: current validation score [eval_metric],
@@ -139,10 +143,10 @@ def validate_on_data(model: Model, data: Dataset,

        # post-process
        if level == "bpe" and postprocess:
            valid_sources = [bpe_postprocess(s) for s in valid_sources]
            valid_references = [bpe_postprocess(v)
            valid_sources = [bpe_postprocess(s, bpe_type=bpe_type) for s in valid_sources]
            valid_references = [bpe_postprocess(v, bpe_type=bpe_type)
                                for v in valid_references]
            valid_hypotheses = [bpe_postprocess(v) for
            valid_hypotheses = [bpe_postprocess(v, bpe_type=bpe_type) for
                                v in valid_hypotheses]

        # if references are given, evaluate against them
@@ -152,12 +156,15 @@ def validate_on_data(model: Model, data: Dataset,
            current_valid_score = 0
            if eval_metric.lower() == 'bleu':
                # this version does not use any tokenization
                current_valid_score = bleu(valid_hypotheses, valid_references)
                current_valid_score = bleu(
                    valid_hypotheses, valid_references, tokenize=sacrebleu["tokenize"])
            elif eval_metric.lower() == 'chrf':
                current_valid_score = chrf(valid_hypotheses, valid_references)
                current_valid_score = chrf(valid_hypotheses, valid_references,
                                           remove_whitespace=sacrebleu["remove_whitespace"])
            elif eval_metric.lower() == 'token_accuracy':
                current_valid_score = token_accuracy(
                    valid_hypotheses, valid_references, level=level)
                current_valid_score = token_accuracy( # supply List[List[str]] before join!
                    [t for t in decoded_valid], [t for t in data.trg])
            #        valid_hypotheses, valid_references, level=level)
            elif eval_metric.lower() == 'sequence_accuracy':
                current_valid_score = sequence_accuracy(
                    valid_hypotheses, valid_references)
@@ -235,10 +242,19 @@ def test(cfg_file,
        beam_size = cfg["testing"].get("beam_size", 1)
        beam_alpha = cfg["testing"].get("alpha", -1)
        postprocess = cfg["testing"].get("postprocess", True)
        bpe_type = cfg["testing"].get("bpe_type", "subword-nmt")
        sacrebleu = {"remove_whitespace": True, "tokenize": "13a"}
        if "sacrebleu" in cfg["testing"].keys():
            sacrebleu["remove_whitespace"] = cfg["testing"]["sacrebleu"]\
                .get("remove_whitespace", True)
            sacrebleu["tokenize"] = cfg["testing"]["sacrebleu"]\
                .get("tokenize", "13a")
    else:
        beam_size = 1
        beam_alpha = -1
        postprocess = True
        bpe_type = "subword-nmt"
        sacrebleu = {"remove_whitespace": True, "tokenize": "13a"}

    for data_set_name, data_set in data_to_predict.items():

@@ -249,15 +265,19 @@ def test(cfg_file,
            batch_type=batch_type, level=level,
            max_output_length=max_output_length, eval_metric=eval_metric,
            use_cuda=use_cuda, loss_function=None, beam_size=beam_size,
            beam_alpha=beam_alpha, postprocess=postprocess)
            beam_alpha=beam_alpha, postprocess=postprocess,
            bpe_type=bpe_type, sacrebleu=sacrebleu)
        #pylint: enable=unused-variable

        if "trg" in data_set.fields:
            decoding_description = "Greedy decoding" if beam_size < 2 else \
                "Beam search decoding with beam size = {} and alpha = {}".\
                    format(beam_size, beam_alpha)
            logger.info("%4s %s: %6.2f [%s]",
                        data_set_name, eval_metric, score, decoding_description)
            tokenizer_info = f"[{sacrebleu['tokenize']}]" \
                if eval_metric == "bleu" else ""
            logger.info("%4s %s%s: %6.2f [%s]",
                        data_set_name, eval_metric, tokenizer_info,
                        score, decoding_description)
        else:
            logger.info("No references given for %s -> no evaluation.",
                        data_set_name)
@@ -328,7 +348,8 @@ def translate(cfg_file, ckpt: str, output_path: str = None) -> None:
            batch_type=batch_type, level=level,
            max_output_length=max_output_length, eval_metric="",
            use_cuda=use_cuda, loss_function=None, beam_size=beam_size,
            beam_alpha=beam_alpha, postprocess=postprocess)
            beam_alpha=beam_alpha, postprocess=postprocess,
            bpe_type=bpe_type, sacrebleu=sacrebleu)
        return hypotheses

    cfg = load_config(cfg_file)
@@ -385,10 +406,19 @@ def translate(cfg_file, ckpt: str, output_path: str = None) -> None:
        beam_size = cfg["testing"].get("beam_size", 1)
        beam_alpha = cfg["testing"].get("alpha", -1)
        postprocess = cfg["testing"].get("postprocess", True)
        bpe_type = cfg["testing"].get("bpe_type", "subword-nmt")
        sacrebleu = {"remove_whitespace": True, "tokenize": "13a"}
        if "sacrebleu" in cfg["testing"].keys():
            sacrebleu["remove_whitespace"] = cfg["testing"]["sacrebleu"]\
                .get("remove_whitespace", True)
            sacrebleu["tokenize"] = cfg["testing"]["sacrebleu"]\
                .get("tokenize", "13a")
    else:
        beam_size = 1
        beam_alpha = -1
        postprocess = True
        bpe_type = "subword-nmt"
        sacrebleu = {"remove_whitespace": True, "tokenize": "13a"}

    if not sys.stdin.isatty():
        # input file given
+13 −1
Original line number Diff line number Diff line
@@ -115,6 +115,16 @@ class TrainManager:
                "Invalid setting for 'early_stopping_metric', "
                "valid options: 'loss', 'ppl', 'eval_metric'.")

        # eval options
        test_config = config["testing"]
        self.bpe_type = test_config.get("bpe_type", "subword-nmt")
        self.sacrebleu = {"remove_whitespace": True, "tokenize": "13a"}
        if "sacrebleu" in config["testing"].keys():
            self.sacrebleu["remove_whitespace"] = test_config["sacrebleu"] \
                .get("remove_whitespace", True)
            self.sacrebleu["tokenize"] = test_config["sacrebleu"] \
                .get("tokenize", "13a")

        # learning rate scheduling
        self.scheduler, self.scheduler_step_at = build_scheduler(
            config=train_config,
@@ -378,7 +388,9 @@ class TrainManager:
                            loss_function=self.loss,
                            beam_size=1,  # greedy validations
                            batch_type=self.eval_batch_type,
                            postprocess=True # always remove BPE for validation
                            postprocess=True,   # always remove BPE for validation
                            bpe_type=self.bpe_type, # "subword-nmt" or "sentencepiece"
                            sacrebleu=self.sacrebleu    # sacrebleu options
                        )

                    self.tb_writer.add_scalar("valid/valid_loss",
+46 −6
Original line number Diff line number Diff line
import unittest
from test.unit.test_helpers import TensorTestCase

from joeynmt.metrics import chrf, bleu, token_accuracy


class TestMetrics(TensorTestCase):
    """Unit tests for the sacrebleu-backed metrics and token accuracy."""

    def test_chrf_without_whitespace(self):
        # with remove_whitespace=True, whitespace placement must not
        # influence the score
        hyp1 = ["t est"]
        ref1 = ["tez t"]
        score1 = chrf(hyp1, ref1, remove_whitespace=True)
        hyp2 = ["test"]
        ref2 = ["tezt"]
        score2 = chrf(hyp2, ref2, remove_whitespace=True)
        self.assertAlmostEqual(score1, score2)
        self.assertAlmostEqual(score1, 0.271, places=3)

    def test_chrf_with_whitespace(self):
        hyp = ["これはテストです。"]
        ref = ["これは テストです。"]
        score = chrf(hyp, ref, remove_whitespace=False)
        self.assertAlmostEqual(score, 0.558, places=3)

    def test_bleu_13a(self):
        hyp = ["this is a test."]
        ref = ["this is a tezt."]
        score = bleu(hyp, ref, tokenize="13a")
        self.assertAlmostEqual(score, 42.729, places=3)

    def test_bleu_ja_mecab(self):
        # the ja-mecab tokenizer needs an optional extra; skip if missing
        try:
            hyp = ["これはテストです。"]
            ref = ["あれがテストです。"]
            score = bleu(hyp, ref, tokenize="ja-mecab")
            self.assertAlmostEqual(score, 39.764, places=3)
        except ModuleNotFoundError as e:
            raise unittest.SkipTest(f"{e} Skip.")

    def test_token_acc_level_char(self):
        # if len(hyp) > len(ref): surplus hyp tokens count as wrong
        hyp = [list("tests")]
        ref = [list("tezt")]
        score = token_accuracy(hyp, ref)
        self.assertEqual(score, 60.0)

        # if len(hyp) < len(ref): missing tokens are not penalized
        hyp = [list("test")]
        ref = [list("tezts")]
        score = token_accuracy(hyp, ref)
        self.assertEqual(score, 75.0)