Commit 25ad02ba authored by jakob-forstmann
Browse files

solve task 6+7

parent 255c8a98
Loading
Loading
Loading
Loading
+25 −16
Original line number Diff line number Diff line
@@ -3,7 +3,8 @@ import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn.datasets import make_moons, make_circles, make_classification
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import DecisionTreeClassifier,plot_tree
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

def generate_datasets(random_state):
@@ -52,16 +53,18 @@ def set_axis(ax,x,y):
    ax.set_xticks(())
    ax.set_yticks(())
    
def plot_results(optimal_classifiers,default_classifiers,*datasets):
def plot_results(classifiers,*datasets):
    figure = plt.figure(figsize=(27, 11)) 
    position_idx =1
    plotted_datasets_name = ["x_train","x_test"]
    plotted_datasets_name = [ "X_train with maxium depth","X_test with maxium depth",
                              "X_train with default values","X_test with default values",
                              "X_train with a random forest","X_test with a random forest"]
    dataset_count = 0
    for x_train,x_test,labels_train,labels_test in zip(*datasets):
    for pair_of_cls,x_train,x_test,labels_train,labels_test in zip(classifiers,*datasets):
        xx,yy = create_meshgrid(x_train)
        cm = plt.cm.RdBu
        cm_bright = ListedColormap(['#FF0000', '#0000FF'])
        ax = plt.subplot(3,5,position_idx)
        ax = plt.subplot(3,7,position_idx)
        if position_idx == 1:
            ax.set_title("Input data")
        # Plot the training points
@@ -72,22 +75,25 @@ def plot_results(optimal_classifiers,default_classifiers,*datasets):
                edgecolors='k')
        set_axis(ax,xx,yy)
        position_idx+=1
        opt_cls = optimal_classifiers[dataset_count]
        cls = default_classifiers[dataset_count]
        for cls in [opt_cls,cls]:    
            for name,data,labels in zip(plotted_datasets_name,[x_train,x_test],[labels_train,labels_test]):
                ax = plt.subplot(3,5,position_idx)
        cls_name_idx = 0
        for cls in pair_of_cls:
            for data,labels in zip([x_train,x_test],[labels_train,labels_test]):
                ax = plt.subplot(3,7,position_idx)
                Z= cls.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:, 1]
                Z = Z.reshape(xx.shape)
                ax.contourf(xx, yy, Z, cmap=cm, alpha=.8)
                ax.scatter(data[:, 0], data[:, 1], c=labels, cmap=cm_bright,edgecolors='k')
                set_axis(ax,xx,yy)
                if dataset_count ==0:
                name = plotted_datasets_name[cls_name_idx]
                if cls_name_idx in (0,1):
                    name+=" "+str(cls.tree_.max_depth)
                # first or second column or entire first row 
                if cls_name_idx in(0,1) or dataset_count==0:
                    ax.set_title(name)
                #ax.text(xx.max() - .3, yy.min() + .3, ('%.2f' % score).lstrip('0'),
                #    size=15, horizontalalignment='right')
                position_idx += 1
                cls_name_idx+=1
        dataset_count+=1

if __name__ =="__main__":
    datasets = get_train_test_datasets()
    optimal_classifiers = [ DecisionTreeClassifier(max_depth=5),DecisionTreeClassifier(max_depth=4),DecisionTreeClassifier(max_depth=2)]
@@ -96,10 +102,13 @@ if __name__ =="__main__":
    print("default classifiers")
    default_classifiers = [DecisionTreeClassifier() for _ in range(0,3)]
    default_classifiers = train_classifiers(default_classifiers,datasets)
    combined_classifiers = list(zip(optimal_classifiers,default_classifiers))
    plot_results(optimal_classifiers,default_classifiers,*datasets)
    random_forest = [RandomForestClassifier() for _ in range(0,3)]
    print("random forest classifier")
    train_classifiers(random_forest,datasets)
    combined_classifiers =list(zip(optimal_classifiers,default_classifiers,random_forest))
    plot_results(combined_classifiers,*datasets)
    #plot_tree(optimal_classifiers[0],filled=True)
    plt.tight_layout()
    plt.show()