fairseq/fairseq_cli/generate.py  +9 −8

@@ -182,7 +182,7 @@ def _main(cfg: DictConfig, output_file):
         if tokenizer is not None:
             x = tokenizer.decode(x)
         return x
-    ## knn saving code
+    ## knn saving code - mostly taken from https://github.com/urvashik/knnmt/blob/master/fairseq_cli/generate.py
     if cfg.task.save_knn_dstore:
         print('keytype being saved:', cfg.task.knn_keytype)
         if cfg.task.knn_start > -1:
@@ -263,6 +263,7 @@ def _main(cfg: DictConfig, output_file):
         if "net_input" not in sample:
             continue
         ## For processing in parallel
+        ### taken from https://github.com/urvashik/knnmt/blob/master/fairseq_cli/generate.py
         if cfg.task.save_knn_dstore and to_skip > 0:
             num_samples = sample['target'].shape[0]
             if to_skip - num_samples > 0:
@@ -270,7 +271,7 @@ def _main(cfg: DictConfig, output_file):
                 target_tokens = utils.strip_pad(sample['target'], tgt_dict.pad()).int().cpu()
                 start_pos += len(target_tokens)
                 continue
-        ###
+
         for i, sample_id in enumerate(sample['id'].tolist()):
             if to_skip > 0:
                 to_skip -= 1
@@ -339,7 +340,7 @@ def _main(cfg: DictConfig, output_file):
            target_tokens = (
                utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).int().cpu()
            )
-            ## knn saving code
+            ## knn saving code - taken from https://github.com/urvashik/knnmt/blob/master/fairseq_cli/generate.py
            if cfg.task.save_knn_dstore:
                hypo = hypos[i][0]
                num_items = len(hypo['tokens'])
@@ -395,7 +396,7 @@ def _main(cfg: DictConfig, output_file):
            if cfg.generation.score_reference:
                continue
-            ## error analysis knnmt: save knns, vals and probs
+            ## error analysis knnmt: save knns, vals and probs - taken from https://github.com/urvashik/knnmt/blob/master/fairseq_cli/generate.py
            if cfg.task.knnmt and cfg.task.save_knns:
                to_save_objects.append(
                    {
@@ -543,7 +544,7 @@ def _main(cfg: DictConfig, output_file):
                    scorer.add_string(target_str, detok_hypo_str)
                else:
                    scorer.add(target_tokens, hypo_tokens)
-        ###
+        # taken from https://github.com/urvashik/knnmt/blob/master/fairseq_cli/generate.py
        if cfg.task.knn_start > -1 and knn_num_samples_proc == cfg.knn_proc:
            break
        if cfg.task.save_knn_subset and total_saved >= cfg.task.save_knn_subset_num:
@@ -559,7 +560,7 @@ def _main(cfg: DictConfig, output_file):
            for fidx in range(len([cfg.task.trained_index])):
                faiss_indices[fidx].add_with_ids(keys, addids)
            #print(f"loop time {time.time()-knn_start_loop}s")
-        ###
+
        #print(idx)
        #if idx == 0:
        #    break
@@ -575,7 +576,7 @@ def _main(cfg: DictConfig, output_file):
        if cfg.task.knn_q2gpu:
            index_ivf.quantizer = quantizer
            del quantizer_gpu
-
+    ### taken from https://github.com/urvashik/knnmt/blob/master/fairseq_cli/generate.py
    if cfg.task.save_knn_dstore:
        if cfg.task.knn_start > -1:
            dstore_keys = dstore_keys[:total_saved]
@@ -599,7 +600,7 @@ def _main(cfg: DictConfig, output_file):
    if cfg.task.knnmt and cfg.task.save_knns:
        pickle.dump(to_save_objects, open(cfg.task.save_knns_filename, "wb"))
-    ###
+
    logger.info("NOTE: hypothesis and token scores are output in base 2")
    logger.info(
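For context, here is a minimal, self-contained sketch of the datastore-saving pattern the changed code above works with: writing kNN-MT keys and values into fixed-size memmapped arrays, registering the keys in a FAISS index via add_with_ids, and trimming the arrays to the number of entries actually written (as in dstore_keys[:total_saved]). This is not the fairseq/knnmt code itself; the array names, file names, sizes, and the synthetic batch data are illustrative assumptions.

import numpy as np
import faiss

dstore_size = 1000   # assumed upper bound on stored target tokens
dim = 1024           # assumed decoder hidden size (key dimension)

# Memmapped key/value arrays, analogous to dstore_keys / dstore_vals above.
dstore_keys = np.memmap("keys.npy", dtype=np.float16, mode="w+",
                        shape=(dstore_size, dim))
dstore_vals = np.memmap("vals.npy", dtype=np.int64, mode="w+",
                        shape=(dstore_size, 1))

# A flat L2 index wrapped in IndexIDMap so add_with_ids() maps each key
# vector to its datastore position, mirroring faiss_indices[fidx].add_with_ids(keys, addids).
index = faiss.IndexIDMap(faiss.IndexFlatL2(dim))

rng = np.random.default_rng(0)
total_saved = 0
for _ in range(4):  # stand-in for the generation loop over batches
    n = 250
    keys = rng.standard_normal((n, dim), dtype=np.float32)  # would be decoder hidden states
    vals = rng.integers(0, 32000, size=n)                   # would be target token ids

    dstore_keys[total_saved:total_saved + n] = keys.astype(np.float16)
    dstore_vals[total_saved:total_saved + n] = vals.reshape(-1, 1)

    ids = np.arange(total_saved, total_saved + n, dtype=np.int64)
    index.add_with_ids(keys, ids)  # FAISS expects float32 keys and int64 ids
    total_saved += n

# As in the diff, a partial run keeps only the entries actually written.
dstore_keys = dstore_keys[:total_saved]
dstore_vals = dstore_vals[:total_saved]
faiss.write_index(index, "knn.index")

Storing keys on disk as float16 while feeding float32 copies to FAISS is a common space-saving choice in kNN-MT datastores; whether it is appropriate here depends on the key dimensionality and available disk, so treat it as one possible configuration rather than the repository's fixed behavior.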