Loading attention_visualization.py +7 −16 Original line number Diff line number Diff line Loading @@ -3,15 +3,14 @@ from transformers import BertConfig, BertTokenizer, BertModel import seaborn as sns import matplotlib.pyplot as plt model_type = 'sts-bert-large' config = BertConfig.from_pretrained(model_type) config = BertConfig.from_pretrained('./models/stsb-bert-large/0_BERT') config.output_attentions = True model = BertModel.from_pretrained(model_type, config=config).to('cpu') tokenizer = BertTokenizer.from_pretrained(model_type) model = BertModel.from_pretrained('./models/stsb-bert-large/0_BERT', config=config).to('cpu') tokenizer = BertTokenizer.from_pretrained('./models/stsb-bert-large/0_BERT') text = 'A dog standing in the water' text = 'US drone strike kills 10 in Pakistan' tok = tokenizer.tokenize(text) pos = 2 pos = 3 ids = torch.tensor(tokenizer.convert_tokens_to_ids(tok)).unsqueeze(0).to('cpu') with torch.no_grad(): Loading @@ -24,17 +23,9 @@ seqlen = len(attentions) attentions_pos = attentions[pos] cols = 2 rows = int(heads / cols) fig, axes = plt.subplots(rows, cols, figsize=(14, 30)) axes = axes.flat print(f'Attention weights for token {tok[pos]}') for i, att in enumerate(attentions_pos): sns.heatmap(att, vmin=0, vmax=1, ax=axes[i], xticklabels=tok) axes[i].set_title(f'head {i}') axes[i].set_ylabel('layers') avg_attention = attentions_pos.mean(dim=0) sns.heatmap(avg_attention, vmin=0, vmax=1, xticklabels=tok) plt.show() Loading
attention_visualization.py +7 −16 Original line number Diff line number Diff line Loading @@ -3,15 +3,14 @@ from transformers import BertConfig, BertTokenizer, BertModel import seaborn as sns import matplotlib.pyplot as plt model_type = 'sts-bert-large' config = BertConfig.from_pretrained(model_type) config = BertConfig.from_pretrained('./models/stsb-bert-large/0_BERT') config.output_attentions = True model = BertModel.from_pretrained(model_type, config=config).to('cpu') tokenizer = BertTokenizer.from_pretrained(model_type) model = BertModel.from_pretrained('./models/stsb-bert-large/0_BERT', config=config).to('cpu') tokenizer = BertTokenizer.from_pretrained('./models/stsb-bert-large/0_BERT') text = 'A dog standing in the water' text = 'US drone strike kills 10 in Pakistan' tok = tokenizer.tokenize(text) pos = 2 pos = 3 ids = torch.tensor(tokenizer.convert_tokens_to_ids(tok)).unsqueeze(0).to('cpu') with torch.no_grad(): Loading @@ -24,17 +23,9 @@ seqlen = len(attentions) attentions_pos = attentions[pos] cols = 2 rows = int(heads / cols) fig, axes = plt.subplots(rows, cols, figsize=(14, 30)) axes = axes.flat print(f'Attention weights for token {tok[pos]}') for i, att in enumerate(attentions_pos): sns.heatmap(att, vmin=0, vmax=1, ax=axes[i], xticklabels=tok) axes[i].set_title(f'head {i}') axes[i].set_ylabel('layers') avg_attention = attentions_pos.mean(dim=0) sns.heatmap(avg_attention, vmin=0, vmax=1, xticklabels=tok) plt.show()