Quelques visuels pour évaluer le modèle

Je vous présente quelques graphiques qui donnent un aperçu du modèle et permettent d’évaluer sa qualité, sa précision, etc.

Confusion Matrix

Ce graphique permet de voir le taux de faux positifs et de faux négatifs.

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

import joblib
import pickle


def _read_pickle(path):
    """Return the object deserialized from the pickle file at *path*."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Restore the persisted model, then the test features and test labels.
loaded_rf = joblib.load('random_forest_model.pkl')
X_test = _read_pickle('X_test.pkl')
y_test = _read_pickle('y_test.pkl')

# Model predictions on X_test; the confusion matrix, classification report
# and overall metrics are all computed from this array.
prediction = loaded_rf.predict(X_test)
# Count how often each true label was predicted as each label.
cm = confusion_matrix(y_test, prediction)

# Draw the counts as an annotated heatmap: integer cell text, no colour bar.
plt.figure(figsize=(8, 6))
ax = sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
ax.set_title('Confusion Matrix')
ax.set_xlabel('Predicted Label')
ax.set_ylabel('True Label')
plt.show()

Receiver Operating Characteristic (ROC)

from sklearn.metrics import roc_curve, auc

# Plot the ROC curve when the model can output class probabilities.
# roc_curve only supports binary targets, so a two-class model is assumed.
if hasattr(loaded_rf, "predict_proba"):
    # Column 1 of predict_proba is the score for loaded_rf.classes_[1].
    # Passing that class as pos_label keeps the scores and the labels aligned
    # and avoids the error roc_curve raises when the labels are not 0/1 or
    # -1/1 (for example string labels). For 0/1 labels nothing changes.
    positive_class = loaded_rf.classes_[1]
    proba = loaded_rf.predict_proba(X_test)[:, 1]
    fpr, tpr, _ = roc_curve(y_test, proba, pos_label=positive_class)
    roc_auc = auc(fpr, tpr)

    # Plot ROC curve, with the diagonal as the reference line.
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    # Slight headroom above 1.0 so the curve is not drawn on the frame.
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc='lower right')
    plt.show()

Precision recall curve

from sklearn.metrics import precision_recall_curve, average_precision_score

# Plot the precision-recall curve when the model can output class
# probabilities. A two-class model is assumed (column 1 is used as the score).
if hasattr(loaded_rf, "predict_proba"):
    # Column 1 of predict_proba is the score for loaded_rf.classes_[1].
    # Passing that class as pos_label keeps the scores and the labels aligned
    # when the labels are not 0/1 or -1/1; for 0/1 labels nothing changes.
    positive_class = loaded_rf.classes_[1]
    proba = loaded_rf.predict_proba(X_test)[:, 1]
    precision, recall, _ = precision_recall_curve(
        y_test, proba, pos_label=positive_class
    )
    avg_precision = average_precision_score(
        y_test, proba, pos_label=positive_class
    )

    # Plot precision-recall curve; the legend reports the average precision.
    plt.figure(figsize=(8, 6))
    plt.plot(recall, precision, color='blue', lw=2, label=f'Precision-Recall curve (AP = {avg_precision:.2f})')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve')
    plt.legend(loc='lower left')
    plt.show()

Classification report

from sklearn.metrics import classification_report
import pandas as pd

# Get classification report as a dictionary, then transpose it so that each
# class / average is a row and each metric is a column.
report = classification_report(y_test, prediction, output_dict=True)
report_df = pd.DataFrame(report).transpose()

# Select what to plot by label, not by position: after the transpose a
# positional slice such as iloc[:-1] would discard the final 'weighted avg'
# row. 'support' is a sample count rather than a score, so it is left out of
# the colour scale. 'accuracy' is a single scalar that pandas repeats in every
# column, so it would be misleading under precision/recall; errors='ignore'
# covers reports that contain no 'accuracy' entry.
report_scores = (
    report_df
    .drop(columns=['support'])
    .drop(index=['accuracy'], errors='ignore')
)

# Plot classification report
plt.figure(figsize=(10, 6))
sns.heatmap(report_scores, annot=True, cmap='Blues')
plt.title('Classification Report')
plt.tight_layout()
plt.show()

Feature importance

# Only models that expose feature_importances_ can be charted here.
if hasattr(loaded_rf, "feature_importances_"):
    # Pair every X_test column with its importance, most important first.
    feature_importance = pd.DataFrame({
        'feature': X_test.columns,
        'importance': loaded_rf.feature_importances_,
    })
    feature_importance = feature_importance.sort_values(
        by='importance', ascending=False
    )

    # Horizontal bars: one per feature, bar length = importance.
    plt.figure(figsize=(10, 8))
    ax = sns.barplot(data=feature_importance, x='importance', y='feature')
    ax.set_title('Feature Importance')
    plt.tight_layout()
    plt.show()

Overall Performance Metrics

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Headline scores on the test set. Precision, recall and F1 use the
# 'weighted' average across classes.
weighted = {'average': 'weighted'}
accuracy = accuracy_score(y_test, prediction)
precision = precision_score(y_test, prediction, **weighted)
recall = recall_score(y_test, prediction, **weighted)
f1 = f1_score(y_test, prediction, **weighted)

# One row per metric, in display order.
metrics = pd.DataFrame(
    [
        ('Accuracy', accuracy),
        ('Precision', precision),
        ('Recall', recall),
        ('F1 Score', f1),
    ],
    columns=['Metric', 'Value'],
)

# Horizontal bar chart on a fixed 0-1 scale.
plt.figure(figsize=(10, 6))
ax = sns.barplot(data=metrics, x='Value', y='Metric', hue='Metric', palette='viridis')
ax.set_title('Model Performance Metrics')
ax.set_xlim(0, 1)
# Print each score, to four decimals, just past the end of its bar.
for row, score in enumerate(metrics['Value']):
    ax.text(score + 0.01, row, f'{score:.4f}', va='center')
plt.tight_layout()
plt.show()

Retour en haut