You signed in with another tab or window.Reload to refresh your session.You signed out in another tab or window.Reload to refresh your session.You switched accounts on another tab or window.Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: dev/_downloads/plot_confusion_matrix.ipynb
+1-1Lines changed: 1 addition & 1 deletion
Original file line number
Diff line number
Diff line change
@@ -24,7 +24,7 @@
24
24
"execution_count":null,
25
25
"cell_type":"code",
26
26
"source": [
27
-
"print(__doc__)\n\nimport itertools\nimport numpy as np\nimport matplotlib.pyplot as plt\n\nfrom sklearn import svm, datasets\nfrom sklearn.model_selection import train_test_split\nfrom sklearn.metrics import confusion_matrix\n\n# import some data to play with\niris = datasets.load_iris()\nX = iris.data\ny = iris.target\nclass_names = iris.target_names\n\n# Split the data into a training set and a test set\nX_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)\n\n# Run classifier, using a model that is too regularized (C too low) to see\n# the impact on the results\nclassifier = svm.SVC(kernel='linear', C=0.01)\ny_pred = classifier.fit(X_train, y_train).predict(X_test)\n\n\ndef plot_confusion_matrix(cm, classes,\n normalize=False,\n title='Confusion matrix',\n cmap=plt.cm.Blues):\n \"\"\"\n This function prints and plots the confusion matrix.\n Normalization can be applied by setting `normalize=True`.\n \"\"\"\n plt.imshow(cm, interpolation='nearest', cmap=cmap)\n plt.title(title)\n plt.colorbar()\n tick_marks = np.arange(len(classes))\n plt.xticks(tick_marks, classes, rotation=45)\n plt.yticks(tick_marks, classes)\n\n if normalize:\n cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n print(\"Normalized confusion matrix\")\n else:\n print('Confusion matrix, without normalization')\n\n print(cm)\n\n thresh = cm.max() / 2.\n for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n plt.text(j, i, cm[i, j],\n horizontalalignment=\"center\",\n color=\"white\" if cm[i, j] > thresh else \"black\")\n\n plt.tight_layout()\n plt.ylabel('True label')\n plt.xlabel('Predicted label')\n\n# Compute confusion matrix\ncnf_matrix = confusion_matrix(y_test, y_pred)\nnp.set_printoptions(precision=2)\n\n# Plot non-normalized confusion matrix\nplt.figure()\nplot_confusion_matrix(cnf_matrix, classes=class_names,\n title='Confusion matrix, without normalization')\n\n# Plot normalized confusion matrix\nplt.figure()\nplot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,\n title='Normalized confusion matrix')\n\nplt.show()"
27
+
"print(__doc__)\n\nimport itertools\nimport numpy as np\nimport matplotlib.pyplot as plt\n\nfrom sklearn import svm, datasets\nfrom sklearn.model_selection import train_test_split\nfrom sklearn.metrics import confusion_matrix\n\n# import some data to play with\niris = datasets.load_iris()\nX = iris.data\ny = iris.target\nclass_names = iris.target_names\n\n# Split the data into a training set and a test set\nX_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)\n\n# Run classifier, using a model that is too regularized (C too low) to see\n# the impact on the results\nclassifier = svm.SVC(kernel='linear', C=0.01)\ny_pred = classifier.fit(X_train, y_train).predict(X_test)\n\n\ndef plot_confusion_matrix(cm, classes,\n normalize=False,\n title='Confusion matrix',\n cmap=plt.cm.Blues):\n \"\"\"\n This function prints and plots the confusion matrix.\n Normalization can be applied by setting `normalize=True`.\n \"\"\"\n if normalize:\n cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n print(\"Normalized confusion matrix\")\n else:\n print('Confusion matrix, without normalization')\n\n print(cm)\n\n plt.imshow(cm, interpolation='nearest', cmap=cmap)\n plt.title(title)\n plt.colorbar()\n tick_marks = np.arange(len(classes))\n plt.xticks(tick_marks, classes, rotation=45)\n plt.yticks(tick_marks, classes)\n\n fmt = '.2f' if normalize else 'd'\n thresh = cm.max() / 2.\n for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n plt.text(j, i, format(cm[i, j], fmt),\n horizontalalignment=\"center\",\n color=\"white\" if cm[i, j] > thresh else \"black\")\n\n plt.tight_layout()\n plt.ylabel('True label')\n plt.xlabel('Predicted label')\n\n# Compute confusion matrix\ncnf_matrix = confusion_matrix(y_test, y_pred)\nnp.set_printoptions(precision=2)\n\n# Plot non-normalized confusion matrix\nplt.figure()\nplot_confusion_matrix(cnf_matrix, classes=class_names,\n title='Confusion matrix, without normalization')\n\n# Plot normalized confusion matrix\nplt.figure()\nplot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,\n title='Normalized confusion matrix')\n\nplt.show()"