Commit 0d635085 authored by Mayumi Ohta

fix typo

parent fe2fbbe3
+0 −1
@@ -88,7 +88,6 @@ class Model(nn.Module):
             return encoder_output, encoder_hidden, None, None
 
         elif return_type == "decode":
-
             outputs, hidden, att_probs, att_vectors = self._decode(
                 trg_input=kwargs["trg_input"],
                 encoder_output=kwargs["encoder_output"],
+11 −3
@@ -204,7 +204,10 @@ def test(cfg_file,
     cfg = load_config(cfg_file)
 
     if len(logger.handlers) == 0:
-        make_logger(f'{cfg["training"]["model_dir"]}/test.log')
+        log_file = None
+        if os.path.exists(cfg["training"]["model_dir"]):
+            log_file = f'{cfg["training"]["model_dir"]}/test.log'
+        make_logger(log_file)
 
     if "test" not in cfg["data"].keys():
         raise ValueError("Test data must be specified in config.")
@@ -282,7 +285,9 @@ def test(cfg_file,
         if data_set is None:
             continue
 
-        logger.info(f"Decoding on {data_set_name} set...")
+        dataset_filepath = cfg["data"][data_set_name] + "." + cfg["data"]["trg"]
+        logger.info(f"Decoding on {data_set_name} set ({dataset_filepath})...")
 
         #pylint: disable=unused-variable
         score, loss, ppl, sources, sources_raw, references, hypotheses, \
         hypotheses_raw, attention_scores = validate_on_data(
@@ -380,7 +385,10 @@ def translate(cfg_file, ckpt: str, output_path: str = None) -> None:
     cfg = load_config(cfg_file)
 
     #logger = make_logger()
-    make_logger(f'{cfg["training"]["model_dir"]}/translation.log')
+    log_file = None
+    if os.path.exists(cfg["training"]["model_dir"]):
+        log_file = f'{cfg["training"]["model_dir"]}/translation.log'
+    make_logger(log_file)
 
     # when checkpoint is not specified, take oldest from model dir
     if ckpt is None:
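
Both hunks above apply the same guard: the log file under model_dir is only used when that directory already exists, and make_logger otherwise receives None. A minimal sketch of a logger factory with such an optional-file argument (the console fallback here is an assumption for illustration, not necessarily what joeynmt's make_logger does):

    import logging

    def make_logger(log_file: str = None) -> logging.Logger:
        # Sketch: attach a FileHandler only when a path was given,
        # otherwise fall back to console output so a missing
        # model_dir is not fatal.
        logger = logging.getLogger("sketch")
        logger.setLevel(logging.INFO)
        if log_file is not None:
            logger.addHandler(logging.FileHandler(log_file))
        else:
            logger.addHandler(logging.StreamHandler())
        return logger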
+8 −8
@@ -179,10 +179,6 @@ class TrainManager:
             self.model, self.optimizer = amp.initialize(self.model, self.optimizer,
                                                         opt_level=fp16_opt_level)
 
-        # multi-gpu training (should be after apex fp16 initialization)
-        if self.n_gpu > 1:
-            self.model = _DataParallel(self.model)
-
         # initialize accumulated batch loss (needed for batch_multiplier)
         #self.norm_batch_loss_accumulated = 0
         # initialize training statistics
@@ -209,6 +205,10 @@ class TrainManager:
                                       reset_scheduler=reset_scheduler,
                                       reset_optimizer=reset_optimizer)
 
+        # multi-gpu training (should be after apex fp16 initialization)
+        if self.n_gpu > 1:
+            self.model = _DataParallel(self.model)
+
     def _save_checkpoint(self) -> None:
         """
         Save the model's current parameters and the training state to a
@@ -306,7 +306,7 @@ class TrainManager:
 
         # fp16
         if self.fp16 and model_checkpoint.get("amp_state", None) is not None:
-            amp.load_state_dict(checkpoint['amp_state'])
+            amp.load_state_dict(model_checkpoint['amp_state'])
 
     # pylint: disable=unnecessary-comprehension
     # pylint: disable=too-many-branches
@@ -335,14 +335,14 @@ class TrainManager:
         #
         #         # gradient accumulation:
         #         # loss.backward() inside _train_step()
-        #         epoch_loss += self._train_step(inputs)
+        #         batch_loss += self._train_step(inputs)
         #
         #         if (i + 1) % self.batch_multiplier == 0:
         #             self.optimizer.step()     # update!
        #             self.model.zero_grad()    # reset gradients
         #             self.steps += 1           # increment counter
         #
-        #             epoch_loss += batch_loss  # add batch loss
+        #             epoch_loss += batch_loss  # accumulate batch loss
         #             batch_loss = 0            # reset batch loss
         #
         #     # leftovers are just ignored.
@@ -411,7 +411,7 @@ class TrainManager:
                         elapsed = time.time() - start - total_valid_duration
                         elapsed_tokens = self.total_tokens - start_tokens
                         logger.info(
-                            "Epoch %3d Step: %8d Batch Loss: %12.6f "
+                            "Epoch %3d, Step: %8d, Batch Loss: %12.6f, "
                             "Tokens per Sec: %8.0f, Lr: %.6f",
                             epoch_no + 1, self.steps, batch_loss,
                             elapsed_tokens / elapsed,
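
Two fixes above are worth noting: the _DataParallel wrapping moves so that it happens after init_from_checkpoint restores the parameters, and the amp state is now read from model_checkpoint instead of the undefined name checkpoint, which would have raised a NameError. The ordering matters because nn.DataParallel prefixes every state-dict key with "module.". A short, self-contained illustration in plain PyTorch (names are illustrative only):

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)

    # Restore checkpoint weights *before* wrapping: the wrapper renames
    # "weight" to "module.weight", so loading a plain checkpoint into a
    # wrapped model fails with mismatched keys.
    state = {"weight": torch.zeros(2, 4), "bias": torch.zeros(2)}
    model.load_state_dict(state)

    # Only wrap for multi-GPU once the parameters are in place.
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)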
+7 −7
@@ -69,9 +69,9 @@ class TestData(TensorTestCase):
             b = Batch(torch_batch=b, pad_index=self.pad_index)
             if total_samples == 0:
                 self.assertTensorEqual(b.src, expected_src0)
-                self.assertTensorEqual(b.src_lengths, expected_src0_len)
+                self.assertTensorEqual(b.src_length, expected_src0_len)
                 self.assertTensorEqual(b.trg, expected_trg0)
-                self.assertTensorEqual(b.trg_lengths, expected_trg0_len)
+                self.assertTensorEqual(b.trg_length, expected_trg0_len)
             total_samples += b.nseqs
             self.assertLessEqual(b.nseqs, batch_size)
         self.assertEqual(total_samples, len(self.train_data))
@@ -117,18 +117,18 @@ class TestData(TensorTestCase):
 
             # test the sorting by src length
             self.assertEqual(type(b), Batch)
-            before_sort = b.src_lengths
-            b.sort_by_src_lengths()
-            after_sort = b.src_lengths
+            before_sort = b.src_length
+            b.sort_by_src_length()
+            after_sort = b.src_length
             self.assertTensorEqual(torch.sort(before_sort, descending=True)[0],
                                    after_sort)
             self.assertEqual(type(b), Batch)
 
             if total_samples == 0:
                 self.assertTensorEqual(b.src, expected_src0)
-                self.assertTensorEqual(b.src_lengths, expected_src0_len)
+                self.assertTensorEqual(b.src_length, expected_src0_len)
                 self.assertTensorEqual(b.trg, expected_trg0)
-                self.assertTensorEqual(b.trg_lengths, expected_trg0_len)
+                self.assertTensorEqual(b.trg_length, expected_trg0_len)
             total_samples += b.nseqs
             self.assertLessEqual(b.nseqs, batch_size)
         self.assertEqual(total_samples, len(self.dev_data))
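
These test updates track a rename of the Batch API from plural to singular: src_lengths becomes src_length, trg_lengths becomes trg_length, and sort_by_src_lengths() becomes sort_by_src_length(). The property the test asserts can be sketched without the Batch class itself (a minimal, self-contained example):

    import torch

    # After sorting a batch by source length in descending order, the
    # length tensor must equal its own descending sort.
    src_length = torch.tensor([3, 5, 2])
    sorted_lengths, perm = torch.sort(src_length, descending=True)
    assert torch.equal(sorted_lengths, src_length[perm])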
+1 −1
@@ -34,7 +34,7 @@ class TestMetrics(TensorTestCase):
             ref = ["あれがテストです。"]
             score = bleu(hyp, ref, tokenize="ja-mecab")
             self.assertAlmostEqual(score, 39.764, places=3)
-        except ModuleNotFoundError as e:
+        except Exception as e:
             raise unittest.SkipTest(f"{e} Skip.")
 
     def test_token_acc_level_char(self):
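
Widening the except clause from ModuleNotFoundError to Exception lets the Japanese BLEU test skip whenever the ja-mecab tokenizer is unusable for any reason, for example when MeCab imports fine but its dictionary is missing, which raises a different error type. A self-contained sketch of the pattern (the module name is hypothetical):

    import unittest

    class TestOptionalBackend(unittest.TestCase):
        def test_needs_optional_backend(self):
            # Any failure from the optional dependency skips the test
            # instead of failing the whole suite.
            try:
                import some_optional_backend  # assumed to be absent
            except Exception as e:
                raise unittest.SkipTest(f"{e} Skip.")

    if __name__ == "__main__":
        unittest.main()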