Commit 739fab74 authored by hubert's avatar hubert
Browse files

mark knnmt code

parent 3a91accc
Loading
Loading
Loading
Loading
+9 −8
Original line number | Diff line number | Diff line
@@ -182,7 +182,7 @@ def _main(cfg: DictConfig, output_file):
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x
    ## knn saving code
    ## knn saving code - mostly taken from https://github.com/urvashik/knnmt/blob/master/fairseq_cli/generate.py
    if cfg.task.save_knn_dstore:
        print('keytype being saved:', cfg.task.knn_keytype)
        if cfg.task.knn_start > -1:
@@ -263,6 +263,7 @@ def _main(cfg: DictConfig, output_file):
        if "net_input" not in sample:
            continue
        ## For processing in parallel
        ### taken from https://github.com/urvashik/knnmt/blob/master/fairseq_cli/generate.py
        if cfg.task.save_knn_dstore and to_skip > 0:
            num_samples = sample['target'].shape[0]
            if to_skip - num_samples > 0:
@@ -270,7 +271,7 @@ def _main(cfg: DictConfig, output_file):
                target_tokens = utils.strip_pad(sample['target'], tgt_dict.pad()).int().cpu()
                start_pos += len(target_tokens)
                continue

        ###
            for i, sample_id in enumerate(sample['id'].tolist()):
                if to_skip > 0:
                    to_skip -= 1
@@ -339,7 +340,7 @@ def _main(cfg: DictConfig, output_file):
                target_tokens = (
                    utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).int().cpu()
                )
            ## knn saving code
            ## knn saving code - taken from https://github.com/urvashik/knnmt/blob/master/fairseq_cli/generate.py
            if cfg.task.save_knn_dstore:
                hypo = hypos[i][0]
                num_items = len(hypo['tokens'])
@@ -395,7 +396,7 @@ def _main(cfg: DictConfig, output_file):
            if cfg.generation.score_reference:
                continue

            ## error analysis knnmt: save knns, vals and probs
            ## error analysis knnmt: save knns, vals and probs - taken from https://github.com/urvashik/knnmt/blob/master/fairseq_cli/generate.py
            if cfg.task.knnmt and cfg.task.save_knns:
                to_save_objects.append(
                        {
@@ -543,7 +544,7 @@ def _main(cfg: DictConfig, output_file):
                        scorer.add_string(target_str, detok_hypo_str)
                    else:
                        scorer.add(target_tokens, hypo_tokens)

        ### # taken from https://github.com/urvashik/knnmt/blob/master/fairseq_cli/generate.py
            if cfg.task.knn_start > -1 and knn_num_samples_proc == cfg.knn_proc:
                break
            if cfg.task.save_knn_subset and total_saved >= cfg.task.save_knn_subset_num:
@@ -559,7 +560,7 @@ def _main(cfg: DictConfig, output_file):
            for fidx in range(len([cfg.task.trained_index])):
                faiss_indices[fidx].add_with_ids(keys, addids)
            #print(f"loop time {time.time()-knn_start_loop}s")

        ###
        #print(idx)
        #if idx == 0:
        #    break
@@ -575,7 +576,7 @@ def _main(cfg: DictConfig, output_file):
    if cfg.task.knn_q2gpu:
        index_ivf.quantizer = quantizer
        del quantizer_gpu

    ### taken from https://github.com/urvashik/knnmt/blob/master/fairseq_cli/generate.py
    if cfg.task.save_knn_dstore:
        if cfg.task.knn_start > -1:
            dstore_keys = dstore_keys[:total_saved]
@@ -599,7 +600,7 @@ def _main(cfg: DictConfig, output_file):

    if cfg.task.knnmt and cfg.task.save_knns:
        pickle.dump(to_save_objects, open(cfg.task.save_knns_filename, "wb"))

    ###
 
    logger.info("NOTE: hypothesis and token scores are output in base 2")
    logger.info(