Commit a3f0bcac authored by Mayumi Ohta's avatar Mayumi Ohta
Browse files

simplify gradient accumulation logic

parent 6ac79a69
Loading
Loading
Loading
Loading
+170 −181
Original line number Diff line number Diff line
@@ -12,7 +12,7 @@ import logging
import os
import queue

import math
#import math
import numpy as np

import torch
@@ -147,7 +147,7 @@ class TrainManager:
                                                self.batch_type)

        self.batch_multiplier = train_config.get("batch_multiplier", 1)
        self.current_batch_multiplier = self.batch_multiplier
        #self.current_batch_multiplier = self.batch_multiplier

        # generation
        self.max_output_length = train_config.get("max_output_length", None)
@@ -159,7 +159,7 @@ class TrainManager:
            self.loss.cuda()

        # initialize accumalted batch loss (needed for batch_multiplier)
        self.norm_batch_loss_accumulated = 0
        #self.norm_batch_loss_accumulated = 0
        # initialize training statistics
        self.steps = 0
        # stop training if this flag is True by reaching learning rate minimum
@@ -292,10 +292,29 @@ class TrainManager:
                                    batch_type=self.batch_type,
                                    train=True, shuffle=self.shuffle)

        # For last batch in epoch batch_multiplier needs to be adjusted
        # to fit the number of leftover training examples
        leftover_batch_size = len(
            train_data) % (self.batch_multiplier * self.batch_size)
        #################################################################
        # simplify accumulation logic:
        #################################################################
        # for epoch in range(epochs):
        #     self.model.zero_grad()
        #     epoch_loss = 0.0
        #     batch_loss = 0.0
        #     for i, batch in enumerate(iter(train_iter)):
        #
        #         # - gradients accumulated automatically!
        #         # - loss.backward() inside _train_step()
        #         epoch_loss += self._train_step(inputs)
        #
        #         if (i + 1) % self.batch_multiplier == 0:
        #             self.optimizer.step()     # update!
        #             self.model.zero_grad()    # reset gradients
        #             self.steps += 1           # increment counter
        #
        #             epoch_loss += batch_loss  # add batch loss
        #             batch_loss = 0            # reset batch loss
        #
        #     # leftovers are just ignored.
        #################################################################

        for epoch_no in range(self.epochs):
            logger.info("EPOCH %d", epoch_no + 1)
@@ -309,55 +328,41 @@ class TrainManager:
            start = time.time()
            total_valid_duration = 0
            start_tokens = self.total_tokens
            self.current_batch_multiplier = self.batch_multiplier
            self.optimizer.zero_grad()
            count = self.current_batch_multiplier - 1
            self.model.zero_grad()
            epoch_loss = 0
            batch_loss = 0

            for i, batch in enumerate(iter(train_iter)):
                # reactivate training
                self.model.train()
                # create a Batch object from torchtext batch
                batch = Batch(batch, self.pad_index, use_cuda=self.use_cuda)

                # only update every batch_multiplier batches
                # see https://medium.com/@davidlmorton/
                # increasing-mini-batch-size-without-increasing-
                # memory-6794e10db672

                # Set current_batch_mutliplier to fit
                # number of leftover examples for last batch in epoch
                # Only works if batch_type == sentence
                if self.batch_type == "sentence":
                    if self.batch_multiplier > 1 and i == len(train_iter) - \
                            math.ceil(leftover_batch_size / self.batch_size):
                        self.current_batch_multiplier = math.ceil(
                            leftover_batch_size / self.batch_size)
                        count = self.current_batch_multiplier - 1

                update = count == 0
                # print(count, update, self.steps)
                batch_loss = self._train_batch(
                    batch, update=update, count=count)

                # Only save finaly computed batch_loss of full batch
                if update:
                    self.tb_writer.add_scalar("train/train_batch_loss",
                                              batch_loss, self.steps)
                # get batch loss
                batch_loss += self._train_step(batch)

                count = self.batch_multiplier if update else count
                count -= 1
                # update!
                if (i + 1) % self.batch_multiplier == 0:
                    # clip gradients (in-place)
                    if self.clip_grad_fun is not None:
                        self.clip_grad_fun(params=self.model.parameters())

                # Only add complete batch_loss of full mini-batch to epoch_loss
                if update:
                    epoch_loss += batch_loss.detach().cpu().numpy()
                    # make gradient step
                    self.optimizer.step()

                if self.scheduler is not None and \
                        self.scheduler_step_at == "step" and update:
                    # decay lr
                    if self.scheduler is not None \
                            and self.scheduler_step_at == "step":
                        self.scheduler.step()

                    # reset gradients
                    self.model.zero_grad()

                    # increment step counter
                    self.steps += 1

                    # log learning progress
                if self.steps % self.logging_freq == 0 and update:
                    if self.steps % self.logging_freq == 0:
                        self.tb_writer.add_scalar("train/train_batch_loss",
                                                  batch_loss, self.steps)
                        elapsed = time.time() - start - total_valid_duration
                        elapsed_tokens = self.total_tokens - start_tokens
                        logger.info(
@@ -370,8 +375,74 @@ class TrainManager:
                        total_valid_duration = 0
                        start_tokens = self.total_tokens

                    # Only add complete loss of full mini-batch to epoch_loss
                    epoch_loss += batch_loss    # accumulate epoch_loss
                    batch_loss = 0              # rest batch_loss

                    # validate on the entire dev set
                if self.steps % self.validation_freq == 0 and update:
                    if self.steps % self.validation_freq == 0:
                        valid_duration = self._validate(valid_data, epoch_no)
                        total_valid_duration += valid_duration

                if self.stop:
                    break
            if self.stop:
                logger.info(
                    'Training ended since minimum lr %f was reached.',
                    self.learning_rate_min)
                break

            logger.info('Epoch %3d: total training loss %.2f',
                             epoch_no + 1, epoch_loss)
        else:
            logger.info('Training ended after %3d epochs.', epoch_no + 1)
        logger.info('Best validation result (greedy) at step '
                         '%8d: %6.2f %s.', self.best_ckpt_iteration,
                         self.best_ckpt_score,
                         self.early_stopping_metric)

        self.tb_writer.close()  # close Tensorboard writer

    def _train_step(self, batch: Batch) -> Tensor:
        """
        Train the model on one batch: Compute the loss, make a gradient step.

        :param batch: training batch
        :return: loss for batch (sum)
        """
        # reactivate training
        self.model.train()

        # get loss
        batch_loss = self.model.get_loss_for_batch(
            batch=batch, loss_function=self.loss)

        # normalize batch loss
        if self.normalization == "batch":
            normalizer = batch.nseqs
        elif self.normalization == "tokens":
            normalizer = batch.ntokens
        elif self.normalization == "none":
            normalizer = 1
        else:
            raise NotImplementedError(
                "Only normalize by 'batch' or 'tokens' "
                "or summation of loss 'none' implemented")

        norm_batch_loss = batch_loss / normalizer

        if self.batch_multiplier > 1:
            norm_batch_loss = norm_batch_loss / self.batch_multiplier

        # accumulate gradients
        norm_batch_loss.backward()

        # increment token counter
        self.total_tokens += batch.ntokens

        return norm_batch_loss.item()

    def _validate(self, valid_data, epoch_no):
        valid_start_time = time.time()

        valid_score, valid_loss, valid_ppl, valid_sources, \
@@ -393,12 +464,9 @@ class TrainManager:
                sacrebleu=self.sacrebleu    # sacrebleu options
            )

                    self.tb_writer.add_scalar("valid/valid_loss",
                                              valid_loss, self.steps)
                    self.tb_writer.add_scalar("valid/valid_score",
                                              valid_score, self.steps)
                    self.tb_writer.add_scalar("valid/valid_ppl",
                                              valid_ppl, self.steps)
        self.tb_writer.add_scalar("valid/valid_loss", valid_loss, self.steps)
        self.tb_writer.add_scalar("valid/valid_score", valid_score, self.steps)
        self.tb_writer.add_scalar("valid/valid_ppl", valid_ppl, self.steps)

        if self.early_stopping_metric == "loss":
            ckpt_score = valid_loss
@@ -411,8 +479,7 @@ class TrainManager:
        if self.is_best(ckpt_score):
            self.best_ckpt_score = ckpt_score
            self.best_ckpt_iteration = self.steps
                        logger.info(
                            'Hooray! New best validation result [%s]!',
            logger.info('Hooray! New best validation result [%s]!',
                        self.early_stopping_metric)
            if self.ckpt_queue.maxsize > 0:
                logger.info("Saving new checkpoint.")
@@ -438,7 +505,6 @@ class TrainManager:
        )

        valid_duration = time.time() - valid_start_time
                    total_valid_duration += valid_duration
        logger.info(
            'Validation result (greedy) at epoch %3d, '
            'step %8d: %s: %6.2f, loss: %8.4f, ppl: %8.4f, '
@@ -460,84 +526,7 @@ class TrainManager:
                    self.model_dir, self.steps),
                tb_writer=self.tb_writer, steps=self.steps)

                if self.stop:
                    break
            if self.stop:
                logger.info(
                    'Training ended since minimum lr %f was reached.',
                    self.learning_rate_min)
                break

            logger.info('Epoch %3d: total training loss %.2f',
                             epoch_no + 1, epoch_loss)
        else:
            logger.info('Training ended after %3d epochs.', epoch_no + 1)
        logger.info('Best validation result (greedy) at step '
                         '%8d: %6.2f %s.', self.best_ckpt_iteration,
                         self.best_ckpt_score,
                         self.early_stopping_metric)

        self.tb_writer.close()  # close Tensorboard writer

    def _train_batch(self, batch: Batch, update: bool = True,
                     count: int = 1) -> Tensor:
        """
        Train the model on one batch: Compute the loss, make a gradient step.

        :param batch: training batch
        :param update: if False, only store gradient. if True also make update
        :param count: number of portions (batch_size) left before update
        :return: loss for batch (sum)
        """
        batch_loss = self.model.get_loss_for_batch(
            batch=batch, loss_function=self.loss)

        # normalize batch loss
        if self.normalization == "batch":
            normalizer = batch.nseqs
        elif self.normalization == "tokens":
            normalizer = batch.ntokens
        elif self.normalization == "none":
            normalizer = 1
        else:
            raise NotImplementedError(
                "Only normalize by 'batch' or 'tokens' "
                "or summation of loss 'none' implemented")

        norm_batch_loss = batch_loss / normalizer

        if update:
            if self.current_batch_multiplier > 1:
                norm_batch_loss = self.norm_batch_loss_accumulated + \
                    norm_batch_loss
                norm_batch_loss = norm_batch_loss / \
                    self.current_batch_multiplier if \
                    self.normalization != "none" else \
                    norm_batch_loss

            norm_batch_loss.backward()

            if self.clip_grad_fun is not None:
                # clip gradients (in-place)
                self.clip_grad_fun(params=self.model.parameters())

            # make gradient step
            self.optimizer.step()
            self.optimizer.zero_grad()

            # increment step counter
            self.steps += 1

        else:
            if count == self.current_batch_multiplier - 1:
                self.norm_batch_loss_accumulated = norm_batch_loss
            else:
                # accumulate loss of current batch_size * batch_multiplier loss
                self.norm_batch_loss_accumulated += norm_batch_loss
        # increment token counter
        self.total_tokens += batch.ntokens

        return norm_batch_loss
        return valid_duration

    def _add_report(self, valid_score: float, valid_ppl: float,
                    valid_loss: float, eval_metric: str,