Commit c0950c14 authored by kuehner's avatar kuehner
Browse files

Upload New File

parent 0a1f3b2e
Loading
Loading
Loading
Loading
+134 −0
Original line number Diff line number Diff line
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Mar 25 10:37:54 2021

Created on Thu Mar 25 10:37:54 2021

This python file contains the evaluation of our study results.
The systems results are compared with our gold annotations.
For this evaluation, we use accuracy, precision, recall and f1
provided by scikitlearn.

Please make sure that this file is stored in the same directory as 
"human_needs_assigner.py", all the necessary variables and functions
from it are imported
@author: SWP-Group
"""
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import ConfusionMatrixDisplay, classification_report
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from human_needs_assigner import *

def cm_analysis(y_true, y_pred, filename, labels, ymap=None, figsize=(10,10)):
    """
    Generate matrix plot of confusion matrix with annotations.
    The plot image is saved.
    
    Parameters:
    ----------
      y_true:    list
          true label of the data, with shape (nsamples,)
      y_pred:    list
          prediction of the data, with shape (nsamples,)
      filename:  str
          filename of figure file to save
      labels:    list
          string array, name the order of class labels in the confusion matrix.
                 use `clf.classes_` if using scikit-learn models.
                 with shape (nclass,).
      ymap:      dict
          any -> string, length == nclass.
                 if not None, map the labels & ys to more understandable strings.
                 Important: original y_true, y_pred and labels must align.
      figsize:   the size of the figure plotted.
    """
    if ymap is not None:
        y_pred = [ymap[yi] for yi in y_pred]
        y_true = [ymap[yi] for yi in y_true]
        labels = [ymap[yi] for yi in labels]
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    cm_sum = np.sum(cm, axis=1, keepdims=True)
    cm_perc = cm / cm_sum.astype(float) * 100
    annot = np.empty_like(cm).astype(str)
    nrows, ncols = cm.shape
    for i in range(nrows):
        for j in range(ncols):
            c = cm[i, j]
            p = cm_perc[i, j]
            if i == j:
                s = cm_sum[i]
                annot[i, j] = '%.1f%%\n%d/%d' % (p, c, s)
            elif c == 0:
                annot[i, j] = ''
            else:
                annot[i, j] = '%.1f%%\n%d' % (p, c)
    cm = pd.DataFrame(cm, index=labels, columns=labels)
    cm.index.name = 'True'
    cm.columns.name = 'Predicted'
    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(cm, annot=annot, fmt='', ax=ax)
    plt.savefig(filename)



# Accuracies
acc_maslow = accuracy_score(maslow_gold, maslow_needs)
acc_reiss = accuracy_score(reiss_gold, reiss_needs)
print("Accuracies:")
print("The system accuracy for Maslow categories is: {:.2f}%".format(acc_maslow*100))
print("The system accuracy for Reiss motives is: {:.2f}%".format(acc_reiss*100))

print("------------------------------------------------------------------------")

print("Precision")
prec_maslow_macro = precision_score(maslow_gold, maslow_needs, average="macro")
prec_reiss_macro = precision_score(reiss_gold, reiss_needs, average="macro")
prec_maslow_micro = precision_score(maslow_gold, maslow_needs, average="micro")
prec_reiss_micro = precision_score(reiss_gold, reiss_needs, average="micro")
print("The macro-averaged precision for Maslow predictions is: {:.2f}%".format(prec_maslow_macro*100))
print("The macro-precision for Reiss predictions is: {:.2f}%".format(prec_reiss_macro*100))
print("The micro-averaged precision for Maslow predictions is: {:.2f}%".format(prec_maslow_micro*100))
print("The micro-avergaedprecision for Reiss predictions is: {:.2f}%".format(prec_reiss_micro*100))

print("------------------------------------------------------------------------")

print("Recall:")
recall_maslow_macro = recall_score(maslow_gold, maslow_needs, average="macro")
recall_reiss_macro = recall_score(reiss_gold, reiss_needs, average='macro')
recall_maslow_micro = recall_score(maslow_gold, maslow_needs, average="micro")
recall_reiss_micro = recall_score(reiss_gold, reiss_needs, average='micro')
print("The macro-averaged recall for Maslow predictions is: {:.2f}%".format(recall_maslow_macro*100))
print("The macro-averaged recall for Reiss predictions is: {:.2f}%".format(recall_reiss_macro*100))
print("The micro-averaged recall for Maslow predictions is: {:.2f}%".format(recall_maslow_micro*100))
print("The micro-averaged recall for Reiss predictions is: {:.2f}%".format(recall_reiss_micro*100))

print("-------------------------------------------------------------------------")

print("F1 Score:")
f1_maslow_macro = f1_score(maslow_gold, maslow_needs, average="macro")
f1_reiss_macro = f1_score(reiss_gold, reiss_needs, average='macro')
f1_maslow_micro = f1_score(maslow_gold, maslow_needs, average="micro")
f1_reiss_micro = f1_score(reiss_gold, reiss_needs, average='micro')
print("The macro-averaged f1-score for Maslow predictions is: {:.2f}%".format(f1_maslow_macro*100))
print("The macro-averaged f1-score for Reiss predictions is: {:.2f}%".format(f1_reiss_macro*100))
print("The micro-averaged f1-score for Maslow predictions is: {:.2f}%".format(f1_maslow_micro*100))
print("The micro-averaged f1-score for Reiss predictions is: {:.2f}%".format(f1_reiss_micro*100))

print("-------------------------------------------------------------------------")


print("Classification report Maslow:")
maslow_report = classification_report(maslow_gold, maslow_needs)
reiss_report = classification_report(reiss_gold, reiss_needs)
print(maslow_report)
print("Classification report Reiss:")
print(reiss_report)

print("------------------------------------------------------------------------")
cm_analysis(maslow_gold, maslow_needs, "maslow.png", maslow_human_needs)
cm_analysis(reiss_gold, reiss_needs, "reiss.png", reiss_motives)