Commit 093b9126 authored by Aileen Reichelt's avatar Aileen Reichelt
Browse files

Adapt attention visualization to our model and purpose

parent bc0008f5
Loading
Loading
Loading
Loading
+7 −16
Original line number Diff line number Diff line
@@ -3,15 +3,14 @@ from transformers import BertConfig, BertTokenizer, BertModel
import seaborn as sns
import matplotlib.pyplot as plt

# Load the fine-tuned STSb BERT model with attention outputs enabled so the
# forward pass returns per-layer, per-head attention matrices.
# (The earlier model_type = 'sts-bert-large' lines were stale pre-diff residue:
# each load was immediately re-done from the local path, and 'sts-bert-large'
# is not a resolvable model id/path, so the stale calls would crash the script.)
config = BertConfig.from_pretrained('./models/stsb-bert-large/0_BERT')
config.output_attentions = True
model = BertModel.from_pretrained('./models/stsb-bert-large/0_BERT', config=config).to('cpu')
tokenizer = BertTokenizer.from_pretrained('./models/stsb-bert-large/0_BERT')

# Sentence to visualize and the token index whose attention we inspect.
text = 'US drone strike kills 10 in Pakistan'
tok = tokenizer.tokenize(text)
pos = 3  # index into `tok` — the token whose attention weights get plotted

# Token ids as a (1, seq_len) batch on CPU for the model's forward pass.
ids = torch.tensor(tokenizer.convert_tokens_to_ids(tok)).unsqueeze(0).to('cpu')
with torch.no_grad():
@@ -24,17 +23,9 @@ seqlen = len(attentions)

# Attention weights for the chosen token position.
# NOTE(review): shape of `attentions`/`attentions_pos` is set upstream (not
# fully visible here) — presumably indexed so each element is one head's
# layers-by-tokens matrix; verify against the forward pass above.
attentions_pos = attentions[pos]

# Grid layout: one subplot per attention head, two columns.
cols = 2
rows = int(heads / cols)  # NOTE(review): `heads` is defined outside this view — confirm it exists

fig, axes = plt.subplots(rows, cols, figsize=(14, 30))
axes = axes.flat  # flat iterator supports integer indexing below
print(f'Attention weights for token {tok[pos]}')

# One heatmap per head; y-axis labelled as layers, x-axis as input tokens.
for i, att in enumerate(attentions_pos):
    sns.heatmap(att, vmin=0, vmax=1, ax=axes[i], xticklabels=tok)
    axes[i].set_title(f'head {i}')
    axes[i].set_ylabel('layers')

# Head-averaged attention for the same token.
# NOTE(review): no ax= is passed, so seaborn draws into the current axes —
# i.e. the last subplot of the grid, overwriting the final head's heatmap.
# If a separate figure was intended, a plt.figure() call is missing — confirm.
avg_attention = attentions_pos.mean(dim=0)
sns.heatmap(avg_attention, vmin=0, vmax=1, xticklabels=tok)

plt.show()