From 9d23727eb0433b76771d517ae91db56b5f87b0d2 Mon Sep 17 00:00:00 2001
From: anastasiafilia <95204313+anastasiafilia@users.noreply.github.com>
Date: Sun, 2 Feb 2025 23:16:08 +1100
Subject: [PATCH 1/3] Add files via upload
---
Coding test - Anastasia .ipynb | 425 +++++++++++++++++++++++++++++++++
1 file changed, 425 insertions(+)
create mode 100644 Coding test - Anastasia .ipynb
diff --git a/Coding test - Anastasia .ipynb b/Coding test - Anastasia .ipynb
new file mode 100644
index 0000000..f3974e0
--- /dev/null
+++ b/Coding test - Anastasia .ipynb
@@ -0,0 +1,425 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "acb117e5",
+ "metadata": {},
+ "source": [
+ "# Ore Proximity Model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6120f7f0",
+ "metadata": {},
+ "source": [
+ "A consultant is using metal abundance changes to predict proximity to an orebody. Samples are classified as proximal (A) or distal (B) based on Euclidean distance to a wireframe model of the orebody.\n",
+ "\n",
+ "1. Can we use the same geochemical data and labels to generate a predictive model for future drill holes which can label samples on whether they are in class A or class B?\n",
+ "2. More data has been acquired since the geochemist completed her work - can we predict labels onto these data points (labelled “?”).\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b04ccd8c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "from sklearn.model_selection import train_test_split, RandomizedSearchCV \n",
+ "from sklearn.ensemble import RandomForestClassifier\n",
+ "from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay\n",
+ "from sklearn.preprocessing import StandardScaler, LabelEncoder \n",
+ "import seaborn as sns\n",
+ "from scipy.stats import randint"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9e051ac5",
+ "metadata": {},
+ "source": [
+ "## Step 1: Data - QA/QC"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "28e738e3",
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "data = pd.read_csv('data/data_for_distribution.csv')\n",
+ "data.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "89a97680",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "data.info()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9701bd91",
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "# Check how many NaN values are in each column\n",
+ "nan_values = data.isna().sum()\n",
+ "\n",
+ "# Print the number of NaN values per column\n",
+ "print(\"Number of NaN values per column:\")\n",
+ "print(nan_values)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "adf9369f",
+ "metadata": {},
+ "source": [
+ "Looking at the data, I saw that there are some there are some QA/QC issues.\n",
+ "These included\n",
+ "- Au data type is object\n",
+ "- missing values\n",
+ "- values below the detection limit (<0.005)\n",
+ "- unsuitable values (-999)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "943cad16",
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "elements = ['As', 'Au', 'Pb', 'Fe', 'Mo', 'Cu', 'S', 'Zn']\n",
+ "\n",
+ "# Replace '<0.005' with half of the detection limit and '-999' with NaN\n",
+ "data[elements] = data[elements].replace({'<0.005': 0.005 / 2, '-999': np.nan})\n",
+ "\n",
+ "# Convert the 'Au' column to float64\n",
+ "data['Au'] = pd.to_numeric(data['Au'], errors='coerce')\n",
+ "\n",
+ "# Drop NaN values\n",
+ "data_cleaned = data.dropna(subset=elements)\n",
+ "\n",
+ "data_cleaned.to_csv(\"./data/cleaned_data.csv\", index=False)\n",
+ "data_cleaned.head()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6c5c78d1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Some checks:\n",
+ "\n",
+ "# Check the data type of 'Au' column\n",
+ "print(data['Au'].dtype)\n",
+ "\n",
+ "nan_check = data_cleaned.isna().sum()\n",
+ "\n",
+ "# Print NaN values if any exist\n",
+ "if nan_check.any():\n",
+ " print(\"NaN values in cleaned data:\")\n",
+ " print(nan_check[nan_check > 0], \"\\n\")\n",
+ "\n",
+ "for element in elements:\n",
+ " # Check for values equal to '-999'\n",
+ " if (data_cleaned[element] == -999).any():\n",
+ " print(f\"Warning: '{element}' contains '-999' values.\")\n",
+ " \n",
+ " # Check for values equal to '<0.005'\n",
+ " if (data_cleaned[element] == '<0.005').any():\n",
+ " print(f\"Warning: '{element}' contains '<0.005' values.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "53c5522b",
+ "metadata": {},
+ "source": [
+ "We can do other data checks but I am assuming these are the only issues and moving on"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2d2aef0f",
+ "metadata": {},
+ "source": [
+ "## Step 2: Modelling\n",
+ "### Aim 1: Generate a predictive model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "411d810e",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "# Check the distribution of classes\n",
+ "class_counts = data_cleaned['Class'].value_counts()\n",
+ "print(class_counts)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e945df09",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "# Mostly from https://www.datacamp.com/tutorial/random-forests-classifier-python\n",
+ "\n",
+ "# Separate labeled and unlabeled data\n",
+ "labeled_data = data_cleaned[data_cleaned['Class'].isin(['A', 'B'])]\n",
+ "unlabeled_data = data_cleaned[data_cleaned['Class'] == '?']\n",
+ "\n",
+ "# Features and target\n",
+ "X = labeled_data[elements]\n",
+ "y = labeled_data['Class']\n",
+ "\n",
+ "# Encode labels\n",
+ "le = LabelEncoder()\n",
+ "y_encoded = le.fit_transform(y)\n",
+ "\n",
+ "# Standardize features\n",
+ "scaler = StandardScaler()\n",
+ "X_scaled = scaler.fit_transform(X)\n",
+ "\n",
+ "# Train-test split\n",
+ "X_train, X_test, y_train, y_test = train_test_split(X_scaled, y_encoded, test_size=0.2, random_state=42)\n",
+ "\n",
+ "\n",
+ "#Hyperparameter tuning \n",
+ "param_dist = {'n_estimators': randint(50,500),\n",
+ " 'max_depth': randint(1,20)}\n",
+ "\n",
+ "rf = RandomForestClassifier()\n",
+ "\n",
+ "rand_search = RandomizedSearchCV(rf, \n",
+ " param_distributions = param_dist, \n",
+ " n_iter=5, \n",
+ " cv=5)\n",
+ "rand_search.fit(X_train, y_train)\n",
+ "\n",
+ "best_rf = rand_search.best_estimator_\n",
+ "\n",
+ "# Print the best hyperparameters\n",
+ "print('Best hyperparameters:', rand_search.best_params_)\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "27ad8b29",
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "# Evaluate model\n",
+ "y_pred = best_rf.predict(X_test)\n",
+ "print(\"Classification Report:\")\n",
+ "print(classification_report(y_test, y_pred))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "53d68eca",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(\"Confusion Matrix:\")\n",
+ "cm = confusion_matrix(y_test, y_pred)\n",
+ "ConfusionMatrixDisplay(confusion_matrix=cm).plot();\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "05ea49ab",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "# #optional extra\n",
+ "# # Checking to see how the data was scaled\n",
+ "\n",
+ "# # Visualize histograms before scaling for 'Fe' (example)\n",
+ "# plt.figure(figsize=(12, 6))\n",
+ "# plt.subplot(1, 2, 1)\n",
+ "# plt.hist(X['Fe'], bins=50, color='blue', alpha=0.7)\n",
+ "# plt.title('Before Scaling: Fe')\n",
+ "\n",
+ "# # Visualize histograms after scaling for 'Fe'\n",
+ "# plt.subplot(1, 2, 2)\n",
+ "# plt.hist(X_scaled[:, X.columns.get_loc('Fe')], bins=50, color='green', alpha=0.7)\n",
+ "# plt.title('After Scaling: Fe')\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "705d2ffd",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "# #optional extra\n",
+ "# Feature Importance\n",
+ "# feature_importances = best_rf.feature_importances_\n",
+ "# feature_importance_df = pd.DataFrame({'Element': elements, 'Importance': feature_importances})\n",
+ "# feature_importance_df = feature_importance_df.sort_values(by='Importance', ascending=False)\n",
+ "\n",
+ "# plt.figure(figsize=(10, 6))\n",
+ "# sns.barplot(x='Importance', y='Element', data=feature_importance_df, color='steelblue')\n",
+ "# plt.title('Feature Importance for Predicting Class A vs Class B')\n",
+ "# plt.tight_layout()\n",
+ "# plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "989eba46",
+ "metadata": {},
+ "source": [
+ "### Aim 2: Predict labels "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "41f2f5e1",
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "unlabeled_data_copy = unlabeled_data.copy()\n",
+ "\n",
+ "# Predict unlabeled data \n",
+ "X_new = unlabeled_data[elements]\n",
+ "X_new_scaled = scaler.transform(X_new)\n",
+ "unlabeled_data.loc[:, 'Predicted_Class'] = le.inverse_transform(best_rf.predict(X_new_scaled))\n",
+ "\n",
+ "print(\"Predictions on Unlabeled Data:\")\n",
+ "print(unlabeled_data[['Unique_ID', 'holeid', 'Predicted_Class']].head(70))\n",
+ "unlabeled_data.to_csv('./data/predictions_on_unlabeled_data.csv', index=False)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "23730f2b",
+ "metadata": {},
+ "source": [
+ "Results available on: predictions_on_unlabeled_data.csv"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "182b0d3e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Filter the rows where the original 'Class' column is '?'\n",
+ "unlabeled_data_question = unlabeled_data[unlabeled_data['Class'] == '?']\n",
+ "\n",
+ "# Count how many were predicted to be 'A' and 'B'\n",
+ "class_counts_from_question = unlabeled_data_question['Predicted_Class'].value_counts()\n",
+ "\n",
+ "# Print out the prediction counts for 'A' and 'B' from '?'\n",
+ "print(f\"Predictions for '?' class:\")\n",
+ "print(class_counts_from_question)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0eb795e5",
+ "metadata": {},
+ "source": [
+ "### Observations"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e8bcc086",
+ "metadata": {},
+ "source": [
+ "- The model performs well overall, with high accuracy and strong performance for Class A.\n",
+ "- Class B predictions are less accurate.\n",
+ "- The imbalance in the dataset (more Class A samples) is probably why the model performs better for Class A \n",
+ "- Pb has the highest feature importance (makes sense as this is a lead deposit)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e8dc4640",
+ "metadata": {},
+ "source": [
+ "### Improvements: \n",
+ "- Handle missing values more effectively (instead of dropping).\n",
+ "- Perform additional model validation (e.g., cross-validation, ROC curves).\n",
+ "- Experiment with other models (e.g., SVM, Gradient Boosting) for comparison."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7cabd73e",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
From 283af0b2c68df9843d0dbf266e06d86186ed55ca Mon Sep 17 00:00:00 2001
From: anastasiafilia <95204313+anastasiafilia@users.noreply.github.com>
Date: Sun, 2 Feb 2025 23:32:03 +1100
Subject: [PATCH 2/3] Delete Coding test - Anastasia .ipynb
---
Coding test - Anastasia .ipynb | 425 ---------------------------------
1 file changed, 425 deletions(-)
delete mode 100644 Coding test - Anastasia .ipynb
diff --git a/Coding test - Anastasia .ipynb b/Coding test - Anastasia .ipynb
deleted file mode 100644
index f3974e0..0000000
--- a/Coding test - Anastasia .ipynb
+++ /dev/null
@@ -1,425 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "id": "acb117e5",
- "metadata": {},
- "source": [
- "# Ore Proximity Model"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "6120f7f0",
- "metadata": {},
- "source": [
- "A consultant is using metal abundance changes to predict proximity to an orebody. Samples are classified as proximal (A) or distal (B) based on Euclidean distance to a wireframe model of the orebody.\n",
- "\n",
- "1. Can we use the same geochemical data and labels to generate a predictive model for future drill holes which can label samples on whether they are in class A or class B?\n",
- "2. More data has been acquired since the geochemist completed her work - can we predict labels onto these data points (labelled “?”).\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b04ccd8c",
- "metadata": {},
- "outputs": [],
- "source": [
- "import pandas as pd\n",
- "import numpy as np\n",
- "import matplotlib.pyplot as plt\n",
- "from sklearn.model_selection import train_test_split, RandomizedSearchCV \n",
- "from sklearn.ensemble import RandomForestClassifier\n",
- "from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay\n",
- "from sklearn.preprocessing import StandardScaler, LabelEncoder \n",
- "import seaborn as sns\n",
- "from scipy.stats import randint"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "9e051ac5",
- "metadata": {},
- "source": [
- "## Step 1: Data - QA/QC"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "28e738e3",
- "metadata": {
- "scrolled": false
- },
- "outputs": [],
- "source": [
- "data = pd.read_csv('data/data_for_distribution.csv')\n",
- "data.head()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "89a97680",
- "metadata": {
- "scrolled": true
- },
- "outputs": [],
- "source": [
- "data.info()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "9701bd91",
- "metadata": {
- "scrolled": false
- },
- "outputs": [],
- "source": [
- "# Check how many NaN values are in each column\n",
- "nan_values = data.isna().sum()\n",
- "\n",
- "# Print the number of NaN values per column\n",
- "print(\"Number of NaN values per column:\")\n",
- "print(nan_values)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "adf9369f",
- "metadata": {},
- "source": [
- "Looking at the data, I saw that there are some there are some QA/QC issues.\n",
- "These included\n",
- "- Au data type is object\n",
- "- missing values\n",
- "- values below the detection limit (<0.005)\n",
- "- unsuitable values (-999)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "943cad16",
- "metadata": {
- "scrolled": false
- },
- "outputs": [],
- "source": [
- "elements = ['As', 'Au', 'Pb', 'Fe', 'Mo', 'Cu', 'S', 'Zn']\n",
- "\n",
- "# Replace '<0.005' with half of the detection limit and '-999' with NaN\n",
- "data[elements] = data[elements].replace({'<0.005': 0.005 / 2, '-999': np.nan})\n",
- "\n",
- "# Convert the 'Au' column to float64\n",
- "data['Au'] = pd.to_numeric(data['Au'], errors='coerce')\n",
- "\n",
- "# Drop NaN values\n",
- "data_cleaned = data.dropna(subset=elements)\n",
- "\n",
- "data_cleaned.to_csv(\"./data/cleaned_data.csv\", index=False)\n",
- "data_cleaned.head()\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "6c5c78d1",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Some checks:\n",
- "\n",
- "# Check the data type of 'Au' column\n",
- "print(data['Au'].dtype)\n",
- "\n",
- "nan_check = data_cleaned.isna().sum()\n",
- "\n",
- "# Print NaN values if any exist\n",
- "if nan_check.any():\n",
- " print(\"NaN values in cleaned data:\")\n",
- " print(nan_check[nan_check > 0], \"\\n\")\n",
- "\n",
- "for element in elements:\n",
- " # Check for values equal to '-999'\n",
- " if (data_cleaned[element] == -999).any():\n",
- " print(f\"Warning: '{element}' contains '-999' values.\")\n",
- " \n",
- " # Check for values equal to '<0.005'\n",
- " if (data_cleaned[element] == '<0.005').any():\n",
- " print(f\"Warning: '{element}' contains '<0.005' values.\")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "53c5522b",
- "metadata": {},
- "source": [
- "We can do other data checks but I am assuming these are the only issues and moving on"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "2d2aef0f",
- "metadata": {},
- "source": [
- "## Step 2: Modelling\n",
- "### Aim 1: Generate a predictive model"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "411d810e",
- "metadata": {
- "scrolled": true
- },
- "outputs": [],
- "source": [
- "# Check the distribution of classes\n",
- "class_counts = data_cleaned['Class'].value_counts()\n",
- "print(class_counts)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "e945df09",
- "metadata": {
- "scrolled": true
- },
- "outputs": [],
- "source": [
- "# Mostly from https://www.datacamp.com/tutorial/random-forests-classifier-python\n",
- "\n",
- "# Separate labeled and unlabeled data\n",
- "labeled_data = data_cleaned[data_cleaned['Class'].isin(['A', 'B'])]\n",
- "unlabeled_data = data_cleaned[data_cleaned['Class'] == '?']\n",
- "\n",
- "# Features and target\n",
- "X = labeled_data[elements]\n",
- "y = labeled_data['Class']\n",
- "\n",
- "# Encode labels\n",
- "le = LabelEncoder()\n",
- "y_encoded = le.fit_transform(y)\n",
- "\n",
- "# Standardize features\n",
- "scaler = StandardScaler()\n",
- "X_scaled = scaler.fit_transform(X)\n",
- "\n",
- "# Train-test split\n",
- "X_train, X_test, y_train, y_test = train_test_split(X_scaled, y_encoded, test_size=0.2, random_state=42)\n",
- "\n",
- "\n",
- "#Hyperparameter tuning \n",
- "param_dist = {'n_estimators': randint(50,500),\n",
- " 'max_depth': randint(1,20)}\n",
- "\n",
- "rf = RandomForestClassifier()\n",
- "\n",
- "rand_search = RandomizedSearchCV(rf, \n",
- " param_distributions = param_dist, \n",
- " n_iter=5, \n",
- " cv=5)\n",
- "rand_search.fit(X_train, y_train)\n",
- "\n",
- "best_rf = rand_search.best_estimator_\n",
- "\n",
- "# Print the best hyperparameters\n",
- "print('Best hyperparameters:', rand_search.best_params_)\n",
- "\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "27ad8b29",
- "metadata": {
- "scrolled": false
- },
- "outputs": [],
- "source": [
- "# Evaluate model\n",
- "y_pred = best_rf.predict(X_test)\n",
- "print(\"Classification Report:\")\n",
- "print(classification_report(y_test, y_pred))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "53d68eca",
- "metadata": {},
- "outputs": [],
- "source": [
- "print(\"Confusion Matrix:\")\n",
- "cm = confusion_matrix(y_test, y_pred)\n",
- "ConfusionMatrixDisplay(confusion_matrix=cm).plot();\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "05ea49ab",
- "metadata": {
- "scrolled": true
- },
- "outputs": [],
- "source": [
- "# #optional extra\n",
- "# # Checking to see how the data was scaled\n",
- "\n",
- "# # Visualize histograms before scaling for 'Fe' (example)\n",
- "# plt.figure(figsize=(12, 6))\n",
- "# plt.subplot(1, 2, 1)\n",
- "# plt.hist(X['Fe'], bins=50, color='blue', alpha=0.7)\n",
- "# plt.title('Before Scaling: Fe')\n",
- "\n",
- "# # Visualize histograms after scaling for 'Fe'\n",
- "# plt.subplot(1, 2, 2)\n",
- "# plt.hist(X_scaled[:, X.columns.get_loc('Fe')], bins=50, color='green', alpha=0.7)\n",
- "# plt.title('After Scaling: Fe')\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "705d2ffd",
- "metadata": {
- "scrolled": true
- },
- "outputs": [],
- "source": [
- "# #optional extra\n",
- "# Feature Importance\n",
- "# feature_importances = best_rf.feature_importances_\n",
- "# feature_importance_df = pd.DataFrame({'Element': elements, 'Importance': feature_importances})\n",
- "# feature_importance_df = feature_importance_df.sort_values(by='Importance', ascending=False)\n",
- "\n",
- "# plt.figure(figsize=(10, 6))\n",
- "# sns.barplot(x='Importance', y='Element', data=feature_importance_df, color='steelblue')\n",
- "# plt.title('Feature Importance for Predicting Class A vs Class B')\n",
- "# plt.tight_layout()\n",
- "# plt.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "989eba46",
- "metadata": {},
- "source": [
- "### Aim 2: Predict labels "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "41f2f5e1",
- "metadata": {
- "scrolled": false
- },
- "outputs": [],
- "source": [
- "unlabeled_data_copy = unlabeled_data.copy()\n",
- "\n",
- "# Predict unlabeled data \n",
- "X_new = unlabeled_data[elements]\n",
- "X_new_scaled = scaler.transform(X_new)\n",
- "unlabeled_data.loc[:, 'Predicted_Class'] = le.inverse_transform(best_rf.predict(X_new_scaled))\n",
- "\n",
- "print(\"Predictions on Unlabeled Data:\")\n",
- "print(unlabeled_data[['Unique_ID', 'holeid', 'Predicted_Class']].head(70))\n",
- "unlabeled_data.to_csv('./data/predictions_on_unlabeled_data.csv', index=False)\n"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "23730f2b",
- "metadata": {},
- "source": [
- "Results available on: predictions_on_unlabeled_data.csv"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "182b0d3e",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Filter the rows where the original 'Class' column is '?'\n",
- "unlabeled_data_question = unlabeled_data[unlabeled_data['Class'] == '?']\n",
- "\n",
- "# Count how many were predicted to be 'A' and 'B'\n",
- "class_counts_from_question = unlabeled_data_question['Predicted_Class'].value_counts()\n",
- "\n",
- "# Print out the prediction counts for 'A' and 'B' from '?'\n",
- "print(f\"Predictions for '?' class:\")\n",
- "print(class_counts_from_question)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "0eb795e5",
- "metadata": {},
- "source": [
- "### Observations"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "e8bcc086",
- "metadata": {},
- "source": [
- "- The model performs well overall, with high accuracy and strong performance for Class A.\n",
- "- Class B predictions are less accurate.\n",
- "- The imbalance in the dataset (more Class A samples) is probably why the model performs better for Class A \n",
- "- Pb has the highest feature importance (makes sense as this is a lead deposit)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "e8dc4640",
- "metadata": {},
- "source": [
- "### Improvements: \n",
- "- Handle missing values more effectively (instead of dropping).\n",
- "- Perform additional model validation (e.g., cross-validation, ROC curves).\n",
- "- Experiment with other models (e.g., SVM, Gradient Boosting) for comparison."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "7cabd73e",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.7"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
From 60cba53bc93b816170edec790bdacbdd58d94658 Mon Sep 17 00:00:00 2001
From: anastasiafilia <95204313+anastasiafilia@users.noreply.github.com>
Date: Sun, 2 Feb 2025 23:32:20 +1100
Subject: [PATCH 3/3] Add files via upload
---
Coding test - Anastasia.ipynb | 849 ++++++++++++++++++++++++++++++++++
1 file changed, 849 insertions(+)
create mode 100644 Coding test - Anastasia.ipynb
diff --git a/Coding test - Anastasia.ipynb b/Coding test - Anastasia.ipynb
new file mode 100644
index 0000000..c1f5dce
--- /dev/null
+++ b/Coding test - Anastasia.ipynb
@@ -0,0 +1,849 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "acb117e5",
+ "metadata": {},
+ "source": [
+ "# Ore Proximity Model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6120f7f0",
+ "metadata": {},
+ "source": [
+ "A consultant is using metal abundance changes to predict proximity to an orebody. Samples are classified as proximal (A) or distal (B) based on Euclidean distance to a wireframe model of the orebody.\n",
+ "\n",
+ "1. Can we use the same geochemical data and labels to generate a predictive model for future drill holes which can label samples on whether they are in class A or class B?\n",
+ "2. More data has been acquired since the geochemist completed her work - can we predict labels onto these data points (labelled “?”).\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "b04ccd8c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "from sklearn.model_selection import train_test_split, RandomizedSearchCV \n",
+ "from sklearn.ensemble import RandomForestClassifier\n",
+ "from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay\n",
+ "from sklearn.preprocessing import StandardScaler, LabelEncoder \n",
+ "import seaborn as sns\n",
+ "from scipy.stats import randint"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9e051ac5",
+ "metadata": {},
+ "source": [
+ "## Step 1: Data - QA/QC"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "28e738e3",
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Unique_ID | \n",
+ " holeid | \n",
+ " from | \n",
+ " to | \n",
+ " As | \n",
+ " Au | \n",
+ " Pb | \n",
+ " Fe | \n",
+ " Mo | \n",
+ " Cu | \n",
+ " S | \n",
+ " Zn | \n",
+ " Class | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " A04812 | \n",
+ " SOLVE003 | \n",
+ " 561 | \n",
+ " 571.0 | \n",
+ " NaN | \n",
+ " 0.066 | \n",
+ " 1031.00 | \n",
+ " 61380.0 | \n",
+ " 138.2000 | \n",
+ " 3.600 | \n",
+ " 3586.0000 | \n",
+ " 43.6000 | \n",
+ " A | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " A03356 | \n",
+ " SOLVE003 | \n",
+ " 571 | \n",
+ " 581.0 | \n",
+ " NaN | \n",
+ " 0.152 | \n",
+ " 1982.00 | \n",
+ " 50860.0 | \n",
+ " 75.4000 | \n",
+ " 4.800 | \n",
+ " 1822.0000 | \n",
+ " 36.4000 | \n",
+ " A | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " A04764 | \n",
+ " SOLVE003 | \n",
+ " 581 | \n",
+ " 591.0 | \n",
+ " NaN | \n",
+ " 0.068 | \n",
+ " 1064.80 | \n",
+ " 57940.0 | \n",
+ " 29.2000 | \n",
+ " 3.000 | \n",
+ " 740.4000 | \n",
+ " 36.6000 | \n",
+ " A | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " A04626 | \n",
+ " SOLVE003 | \n",
+ " 591 | \n",
+ " 601.0 | \n",
+ " NaN | \n",
+ " 0.074 | \n",
+ " 891.60 | \n",
+ " 48620.0 | \n",
+ " 63.0000 | \n",
+ " 4.200 | \n",
+ " 820.8000 | \n",
+ " 39.6000 | \n",
+ " A | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " A05579 | \n",
+ " SOLVE003 | \n",
+ " 601 | \n",
+ " 611.0 | \n",
+ " NaN | \n",
+ " 0.043125 | \n",
+ " 801.25 | \n",
+ " 51025.0 | \n",
+ " 56.0625 | \n",
+ " 4.875 | \n",
+ " 745.6875 | \n",
+ " 32.3125 | \n",
+ " A | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Unique_ID holeid from to As Au Pb Fe Mo \\\n",
+ "0 A04812 SOLVE003 561 571.0 NaN 0.066 1031.00 61380.0 138.2000 \n",
+ "1 A03356 SOLVE003 571 581.0 NaN 0.152 1982.00 50860.0 75.4000 \n",
+ "2 A04764 SOLVE003 581 591.0 NaN 0.068 1064.80 57940.0 29.2000 \n",
+ "3 A04626 SOLVE003 591 601.0 NaN 0.074 891.60 48620.0 63.0000 \n",
+ "4 A05579 SOLVE003 601 611.0 NaN 0.043125 801.25 51025.0 56.0625 \n",
+ "\n",
+ " Cu S Zn Class \n",
+ "0 3.600 3586.0000 43.6000 A \n",
+ "1 4.800 1822.0000 36.4000 A \n",
+ "2 3.000 740.4000 36.6000 A \n",
+ "3 4.200 820.8000 39.6000 A \n",
+ "4 4.875 745.6875 32.3125 A "
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "data = pd.read_csv('data/data_for_distribution.csv')\n",
+ "data.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "89a97680",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "RangeIndex: 4771 entries, 0 to 4770\n",
+ "Data columns (total 13 columns):\n",
+ " # Column Non-Null Count Dtype \n",
+ "--- ------ -------------- ----- \n",
+ " 0 Unique_ID 4771 non-null object \n",
+ " 1 holeid 4771 non-null object \n",
+ " 2 from 4771 non-null int64 \n",
+ " 3 to 4771 non-null float64\n",
+ " 4 As 3268 non-null float64\n",
+ " 5 Au 4765 non-null object \n",
+ " 6 Pb 4756 non-null float64\n",
+ " 7 Fe 4709 non-null float64\n",
+ " 8 Mo 4741 non-null float64\n",
+ " 9 Cu 4746 non-null float64\n",
+ " 10 S 4761 non-null float64\n",
+ " 11 Zn 4762 non-null float64\n",
+ " 12 Class 4771 non-null object \n",
+ "dtypes: float64(8), int64(1), object(4)\n",
+ "memory usage: 484.7+ KB\n"
+ ]
+ }
+ ],
+ "source": [
+ "data.info()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "9701bd91",
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of NaN values per column:\n",
+ "Unique_ID 0\n",
+ "holeid 0\n",
+ "from 0\n",
+ "to 0\n",
+ "As 1503\n",
+ "Au 6\n",
+ "Pb 15\n",
+ "Fe 62\n",
+ "Mo 30\n",
+ "Cu 25\n",
+ "S 10\n",
+ "Zn 9\n",
+ "Class 0\n",
+ "dtype: int64\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Check how many NaN values are in each column\n",
+ "nan_values = data.isna().sum()\n",
+ "\n",
+ "# Print the number of NaN values per column\n",
+ "print(\"Number of NaN values per column:\")\n",
+ "print(nan_values)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "adf9369f",
+ "metadata": {},
+ "source": [
+ "Looking at the data, I saw that there are some there are some QA/QC issues.\n",
+ "These included\n",
+ "- Au data type is object\n",
+ "- missing values\n",
+ "- values below the detection limit (<0.005)\n",
+ "- unsuitable values (-999)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "943cad16",
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Unique_ID | \n",
+ " holeid | \n",
+ " from | \n",
+ " to | \n",
+ " As | \n",
+ " Au | \n",
+ " Pb | \n",
+ " Fe | \n",
+ " Mo | \n",
+ " Cu | \n",
+ " S | \n",
+ " Zn | \n",
+ " Class | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 257 | \n",
+ " A03720 | \n",
+ " SOLVE026 | \n",
+ " 731 | \n",
+ " 741.0 | \n",
+ " 5.0 | \n",
+ " 0.122 | \n",
+ " 1881.2 | \n",
+ " 35440.0 | \n",
+ " 133.0 | \n",
+ " 6.0 | \n",
+ " 2770.0 | \n",
+ " 43.4 | \n",
+ " A | \n",
+ "
\n",
+ " \n",
+ " | 1043 | \n",
+ " A02183 | \n",
+ " SOLVE047 | \n",
+ " 351 | \n",
+ " 361.0 | \n",
+ " 18.4 | \n",
+ " 0.290 | \n",
+ " 2150.0 | \n",
+ " 65700.0 | \n",
+ " 32.6 | \n",
+ " 146.4 | \n",
+ " 9854.0 | \n",
+ " 362.0 | \n",
+ " A | \n",
+ "
\n",
+ " \n",
+ " | 1057 | \n",
+ " A04574 | \n",
+ " SOLVE052 | \n",
+ " 411 | \n",
+ " 421.0 | \n",
+ " 6.6 | \n",
+ " 0.076 | \n",
+ " 1636.0 | \n",
+ " 57380.0 | \n",
+ " 57.4 | \n",
+ " 12.2 | \n",
+ " 1808.0 | \n",
+ " 37.8 | \n",
+ " A | \n",
+ "
\n",
+ " \n",
+ " | 1058 | \n",
+ " A04149 | \n",
+ " SOLVE052 | \n",
+ " 421 | \n",
+ " 431.0 | \n",
+ " 15.8 | \n",
+ " 0.097 | \n",
+ " 2375.6 | \n",
+ " 53980.0 | \n",
+ " 50.4 | \n",
+ " 13.2 | \n",
+ " 2878.0 | \n",
+ " 35.6 | \n",
+ " A | \n",
+ "
\n",
+ " \n",
+ " | 1059 | \n",
+ " A05461 | \n",
+ " SOLVE052 | \n",
+ " 431 | \n",
+ " 441.0 | \n",
+ " 18.6 | \n",
+ " 0.046 | \n",
+ " 1514.6 | \n",
+ " 54920.0 | \n",
+ " 18.6 | \n",
+ " 13.8 | \n",
+ " 2394.0 | \n",
+ " 30.4 | \n",
+ " A | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Unique_ID holeid from to As Au Pb Fe Mo \\\n",
+ "257 A03720 SOLVE026 731 741.0 5.0 0.122 1881.2 35440.0 133.0 \n",
+ "1043 A02183 SOLVE047 351 361.0 18.4 0.290 2150.0 65700.0 32.6 \n",
+ "1057 A04574 SOLVE052 411 421.0 6.6 0.076 1636.0 57380.0 57.4 \n",
+ "1058 A04149 SOLVE052 421 431.0 15.8 0.097 2375.6 53980.0 50.4 \n",
+ "1059 A05461 SOLVE052 431 441.0 18.6 0.046 1514.6 54920.0 18.6 \n",
+ "\n",
+ " Cu S Zn Class \n",
+ "257 6.0 2770.0 43.4 A \n",
+ "1043 146.4 9854.0 362.0 A \n",
+ "1057 12.2 1808.0 37.8 A \n",
+ "1058 13.2 2878.0 35.6 A \n",
+ "1059 13.8 2394.0 30.4 A "
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "elements = ['As', 'Au', 'Pb', 'Fe', 'Mo', 'Cu', 'S', 'Zn']\n",
+ "\n",
+ "# Replace '<0.005' with half of the detection limit and '-999' with NaN\n",
+ "data[elements] = data[elements].replace({'<0.005': 0.005 / 2, '-999': np.nan})\n",
+ "\n",
+ "# Convert the 'Au' column to float64\n",
+ "data['Au'] = pd.to_numeric(data['Au'], errors='coerce')\n",
+ "\n",
+ "# Drop NaN values\n",
+ "data_cleaned = data.dropna(subset=elements)\n",
+ "\n",
+ "data_cleaned.to_csv(\"./data/cleaned_data.csv\", index=False)\n",
+ "data_cleaned.head()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "6c5c78d1",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "float64\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Some checks:\n",
+ "\n",
+ "# Check the data type of 'Au' column\n",
+ "print(data['Au'].dtype)\n",
+ "\n",
+ "nan_check = data_cleaned.isna().sum()\n",
+ "\n",
+ "# Print NaN values if any exist\n",
+ "if nan_check.any():\n",
+ " print(\"NaN values in cleaned data:\")\n",
+ " print(nan_check[nan_check > 0], \"\\n\")\n",
+ "\n",
+ "for element in elements:\n",
+ " # Check for values equal to '-999'\n",
+ " if (data_cleaned[element] == -999).any():\n",
+ " print(f\"Warning: '{element}' contains '-999' values.\")\n",
+ " \n",
+ " # Check for values equal to '<0.005'\n",
+ " if (data_cleaned[element] == '<0.005').any():\n",
+ " print(f\"Warning: '{element}' contains '<0.005' values.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "53c5522b",
+ "metadata": {},
+ "source": [
+ "We can do other data checks but I am assuming these are the only issues and moving on"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2d2aef0f",
+ "metadata": {},
+ "source": [
+ "## Step 2: Modelling\n",
+ "### Aim 1: Generate a predictive model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "411d810e",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Class\n",
+ "A 1759\n",
+ "B 743\n",
+ "? 685\n",
+ "Name: count, dtype: int64\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Check the distribution of classes\n",
+ "class_counts = data_cleaned['Class'].value_counts()\n",
+ "print(class_counts)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "e945df09",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Best hyperparameters: {'max_depth': 17, 'n_estimators': 466}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Mostly from https://www.datacamp.com/tutorial/random-forests-classifier-python\n",
+ "\n",
+ "# Separate labeled and unlabeled data\n",
+ "labeled_data = data_cleaned[data_cleaned['Class'].isin(['A', 'B'])]\n",
+ "unlabeled_data = data_cleaned[data_cleaned['Class'] == '?']\n",
+ "\n",
+ "# Features and target\n",
+ "X = labeled_data[elements]\n",
+ "y = labeled_data['Class']\n",
+ "\n",
+ "# Encode labels\n",
+ "le = LabelEncoder()\n",
+ "y_encoded = le.fit_transform(y)\n",
+ "\n",
+ "# Standardize features\n",
+ "scaler = StandardScaler()\n",
+ "X_scaled = scaler.fit_transform(X)\n",
+ "\n",
+ "# Train-test split\n",
+ "X_train, X_test, y_train, y_test = train_test_split(X_scaled, y_encoded, test_size=0.2, random_state=42)\n",
+ "\n",
+ "\n",
+ "#Hyperparameter tuning \n",
+ "param_dist = {'n_estimators': randint(50,500),\n",
+ " 'max_depth': randint(1,20)}\n",
+ "\n",
+ "rf = RandomForestClassifier()\n",
+ "\n",
+ "rand_search = RandomizedSearchCV(rf, \n",
+ " param_distributions = param_dist, \n",
+ " n_iter=5, \n",
+ " cv=5)\n",
+ "rand_search.fit(X_train, y_train)\n",
+ "\n",
+ "best_rf = rand_search.best_estimator_\n",
+ "\n",
+ "# Print the best hyperparameters\n",
+ "print('Best hyperparameters:', rand_search.best_params_)\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "27ad8b29",
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Classification Report:\n",
+ " precision recall f1-score support\n",
+ "\n",
+ " 0 0.93 0.92 0.92 363\n",
+ " 1 0.80 0.80 0.80 138\n",
+ "\n",
+ " accuracy 0.89 501\n",
+ " macro avg 0.86 0.86 0.86 501\n",
+ "weighted avg 0.89 0.89 0.89 501\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Evaluate model\n",
+ "y_pred = best_rf.predict(X_test)\n",
+ "print(\"Classification Report:\")\n",
+ "print(classification_report(y_test, y_pred))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "53d68eca",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Confusion Matrix:\n"
+ ]
+ },
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAATgAAAEGCAYAAADxD4m3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAaJ0lEQVR4nO3deZRV1Zn38e+PYpJJQAYLQXFAbTARfQlq2xrHQGJ3o7ZmYdtp3sREk+hrWhPTmhWH1kW3nURN3tegwSGS1miTdiKJExJtNMuBoREFRYkQQBBkcAARqKrn/eOewotW3boH7uXee+r3Weusumffc/Z+CuRx77PP2UcRgZlZFnWodABmZuXiBGdmmeUEZ2aZ5QRnZpnlBGdmmdWx0gHk69e3LoYO6VTpMCyF1+d3q3QIlsJHbGJrbNGu1DHmxO6xbn1jUcfOmb/l8YgYuyvt7YqqSnBDh3TixceHVDoMS2HMoJGVDsFSeCFm7HIda9c38sLjg4s6tlP9n/rtcoO7oKoSnJnVgqAxmiodRFGc4MwslQCaqI0HBJzgzCy1JtyDM7MMCoJtHqKaWRYF0Oghqpllla/BmVkmBdBYI6sQOcGZWWq1cQXOCc7MUgrC1+DMLJsiYFtt5DcnODNLSzSyS4+z7jZOcGaWSgBN7sGZWVa5B2dmmZS70dcJzswyKIBtURtr5TrBmVkqgWiskcXAneDMLLWm8BDVzDLI1+DMLMNEo6/BmVkW5Vb0rY0EVxtRmlnViBBbo66orRBJXSW9KOklSQsk/UtS3lfSdElvJD/75J1zhaTFkhZJGtNWrE5wZpZaEypqa8MW4KSIOBwYCYyVdDRwOTAjIoYBM5J9JA0HxgMjgLHAJEkFs6gTnJmlkptk6FDUVrCenI3JbqdkC2AcMCUpnwKcnnweB9wXEVsiYgmwGBhdqA0nODNLKTfJUMwG9JM0O287f4eapDpJ84A1wPSIeAEYGBGrAJKfA5LD9wGW552+IilrlScZzCyVlJMMayNiVKt1RTQCIyX1Bh6UdFiBuloa8xZ87N8JzsxSayzxjb4R8a6kp8ldW1stqT4iVkmqJ9e7g1yPbUjeaYOBlYXq9RDVzFIJxLboWNRWiKT+Sc8NSXsApwCvAdOACclhE4CHk8/TgPGSukjaHxgGvFioDffgzCyV5kmGEqgHpiQzoR2AqRHxO0nPAVMlnQcsA84GiIgFkqYCC4EG4MJkiNsqJzgzSyVQSYaoETEfOKKF8nXAya2cMxGYWGwbTnBmllqtPMngBGdmqUTgZ1HNLJtykwyFH8OqFk5wZpaaF7w0s0wK5AUvzSy73IMzs0zKvRfVCc7MMslvtjezjMq9NtCzqGaWQRHyENXMsss3+ppZJuXWg/M1ODPLJL820MwyKnebiHtwZpZBfhbVzDLNyyWZWSbllkvyENXMMsrX4Mwsk3KriXiIamYZlHtUywmuXdj6kfjumQexbWsHGhvguNPe4x8ve5spP9qb5x7fEwl699vG9366jL32buDt5Z35xucPZfABWwA49H9t4jv/vqLCv0X71X/QVi772TL6DGggmuCRu/fioTv6c8CIzVx8/Qo6d22isUHcfMVgFs3rVulwq4R7cABIGgv8DKgDbo+I68vZXiV06hL86Dd/Yo/uTTRsg0tPH8bnTnqfs761hgnffxuAh27vx9037b09kdXvt4VbnlxUybAt0dggJl87iMUvd2OP7o3c/NjrzJ3Zk6//cCV33ziQ2U/14nMnvc95P1zJ9886qNLhVo12/yRD8q7DnwOnknsj9SxJ0yJiYbnarAQJ9ujeBEDDNtG4TUjQvWfT9mM+2twB1cZ/D+3O+jWdWL+mEwCbN9WxfHFX+tVvIwK698y9crN7r0bWr+5UyTCrimdRc0YDiyPiTQBJ9wHjyL20NVMaG+GiMYewcmln/uZ/r+XQIz8E4JfX782Tv+lL916N/Oi/Fm8//u1lnfn2qQfTrWcTE/55FZ85alOlQrc8Awdv5cDDNvPa3G7cetU+/Ou9b/KNq1YhBZf87bBKh1dVamWIWs4o9wGW5+2vSMp2IOl8SbMlzX5nXcGXVFetujq45clF3DNnIYvmdWPpa10B+Orlb3PPnIWcdOYGpt3ZH4C+A7Zx96yFTJr+Ohdc8xbXf3s/Nn1QG/+xZFnXbo1ceftSbr1qEB9urOOvJ6zjF1cP4h9GDecX1+zDpTcub7uSdqL5nQzFbJVWzn9ZLf128amCiMkRMSoiRvXfqzYe/2hNjz0bOfyYjcx6qucO5SeesYFnH9kTgM5dgl59c4l82Gc3M2joVt56s8tuj9U+VtcxuPL2pfzhgT788dHeAJx69vrtf2czf7snB4/8sIIRVpcAGqJDUVullTOCFcCQvP3BwMoytlcR766rY+N7ucS8ZbOY+0xPhhy0hbfe7Lz9mOcf35MhB23Zfnxj0lFd9efOvLWkM3vvu3W3x23NgktvWM7yN7rywOT+20vXre7EZ4/JXToY+VcbWbnE/xPK1xQditoKkTRE0lOSXpW0QNJ3kvJrJL0laV6yfSnvnCskLZa0SNKYtuIs5zW4WcAwSfsDbwHjgb8vY3sVsX51J37ynX1pahJNTXD837zL0ae+z7VfH8qKP3WhQwcYsM9WLk5mUF9+vge/+vHe1HWEug7BxdevoFef2hyaZ8GI0Zs45ewNvLmwK5Om52a2f/lv9fz0ssF869qV1NUFW7d04KeXDa5wpFWkdMPPBuC7ETFXUk9gjqTpyXc3RcRP8g+WNJxcHhkBDAKelHRwRLT6D6hsCS4iGiRdBDxO7jaROyNiQbnaq5QDhn/EpOmvf6r8qtuXtnj8cae9x3GnvVfmqKxYC17swZhBh7f43UVjD97N0dSGUi14GRGrgFXJ5w8kvUoL1+nzjAPui4gtwBJJi8lNZj7X2gllHSRHxCMRcXBEHBgRE8vZlpntPikmGfo1TyIm2/kt1SdpKHAE8EJSdJGk+ZLulNQnKStq4jKfn2Qws1RSLni5NiJGFTpAUg/gfuCfIuJ9SbcA1yVNXQfcAHyNIicu8znBmVkqgWhoKs3gT1Incsntnoh4ACAiVud9fxvwu2Q39cRl5edxzazmNKGitkIkCbgDeDUibswrr8877AzgleTzNGC8pC7J5OUw4MVCbbgHZ2bpRMnWgzsW+ArwsqR5SdkPgHMkjcy1xFLgAoCIWCBpKrmnoRqACwvNoIITnJmlVKqXzkTEs7R8Xe2RAudMBIqesHSCM7PUquExrGI4wZlZKoFoLNEkQ7k5wZlZau1+PTgzy6Yo3SRD2TnBmVlq4QRnZtlUHWu9FcMJzsxScw/OzDIpAhqbnODMLKM8i2pmmRR4iGpmmeVJBjPLsCi4Clv1cIIzs9Q8RDWzTMrNovpZVDPLKA9RzSyzPEQ1s0wK5ARnZtlVIyNUJzgzSykg/KiWmWWVh6hmllk1P4sq6f9RYKgdEReXJSIzq2pZeRZ19m6LwsxqRwC1nuAiYkr+vqTuEbGp/CGZWbWrlSFqm89bSDpG0kLg1WT/cEmTyh6ZmVUpEU3FbZVWzANlPwXGAOsAIuIl4PgyxmRm1S6K3CqsqFnUiFgu7ZCNG8sTjplVvaidSYZienDLJf0lEJI6S/oeyXDVzNqpEvTgJA2R9JSkVyUtkPSdpLyvpOmS3kh+9sk75wpJiyUtkjSmrTCLSXDfBC4E9gHeAkYm+2bWbqnIraAG4LsR8RfA0cCFkoYDlwMzImIYMCPZJ/luPDACGAtMklRXqIE2h6gRsRY4t63jzKwdadr1KiJiFbAq+fyBpFfJdaTGASckh00Bngb+OSm/LyK2AEskLQZGA8+11kYxs6gHSPqtpHckrZH0sKQDdv7XMrOa1nwfXDEb9JM0O287v6UqJQ0FjgBeAAYmya85CQ5IDtsHWJ532oqkrFXFTDL8Gvg5cEayPx64FziqiHPNLINS3Ae3NiJGFTpAUg/gfuCfIuL9T0xo7nBoS6EUqruYa3CKiP+IiIZku7utSs0s40p0m4ikTuSS2z0R8UBSvFpSffJ9PbAmKV8BDMk7fTCwslD9rSa4ZCajL/CUpMslDZW0n6TvA79vO3Qzy6zih6itUq6rdgfwakTcmPfVNGBC8nkC8HBe+XhJXSTtDwwDXizURqEh6hxyObg5ygvyfz3guoLRm1lmqTRjuGOBrwAvS5qXlP0AuB6YKuk8YBlwNkBELJA0FVhIbgb2wogoeE9uoWdR99/l8M0se0JQgsewIuJZWr+X5ORWzpkITCy2jaKeZJB0GDAc6JrX0K+KbcTMMqZGrsK3meAkXU3unpThwCPAF4FnASc4s/aqRhJcMbOoZ5HrLr4dEV8FDge6lDUqM6tuGXrYfnNENElqkNSL3JStb/Q1a6+ysOBlntmSegO3kZtZ3UgbU7Nmlm0lmkUtu2KeRf128vFWSY8BvSJifnnDMrOqVusJTtKRhb6LiLnlCcnMql0WenA3FPgugJNKHAuvz+/GmEEjS12tldHGLx9d6RAshaYnni9NRbV+DS4iTtydgZhZjaiSGdJi+MXPZpaeE5yZZZVKsODl7uAEZ2bp1UgPrpgVfSXpHyRdlezvK2l0+UMzs2qkKH6rtGIe1ZoEHAOck+x/QG6FXzNrr0qwHtzuUMwQ9aiIOFLS/wBExAZJncscl5lVsyronRWjmAS3LXk1VwBI6k9J3qljZrWqGoafxSgmwf1f4EFggKSJ5FYX+WFZozKz6hUZmkWNiHskzSG3ZJKA0yPCb7Y3a8+y0oOTtC/wIfDb/LKIWFbOwMysimUlwZF7g1bzy2e6AvsDi4ARZYzLzKpYZq7BRcRn8veTVUYuaOVwM7OqkfpJhoiYK+lz5QjGzGpEVnpwki7N2+0AHAm8U7aIzKy6ZWkWFeiZ97mB3DW5+8sTjpnVhCz04JIbfHtExGW7KR4zq3IiA5MMkjpGREOhpcvNrJ2qkQRX6GH75jdnzZM0TdJXJJ3ZvO2O4MysCpVwNRFJd0paI+mVvLJrJL0laV6yfSnvuyskLZa0SNKYtuov5hpcX2AduXcwNN8PF8ADRZxrZllUukmGu4CbgV99ovymiPhJfoGk4cB4cvfgDgKelHRwRDS2VnmhBDcgmUF9hY8TW7Ma6aCaWTmU6hpcRMyUNLTIw8cB90XEFmCJpMXAaOC51k4oNEStA3okW8+8z82bmbVXUeQG/STNztvOL7KFiyTNT4awfZKyfYDlecesSMpaVagHtyoiri0yGDNrL9K9VWttRIxK2cItwHVJK9eRe4Xp19hxFJkfTasKJbjKL8dpZlWpnLeJRMTq7e1ItwG/S3ZXAEPyDh0MrCxUV6Eh6sk7G6CZZVzxQ9TUJNXn7Z5Bbh4AYBowXlIXSfsDw/j4bo8WFXrx8/qdC8/Msq5Uj2pJuhc4gdy1uhXA1cAJkkaSS5FLSRb3iIgFkqYCC8k9VXVhoRlU8GsDzSytEr7ZPiLOaaH4jgLHTwQmFlu/E5yZpSJq5wK9E5yZpVcjd8I6wZlZajX/sL2ZWauc4MwskzK24KWZ2Y7cgzOzrPI1ODPLLic4M8sq9+DMLJuCUi54WVZOcGaWSiZeOmNm1ionODPLKkVtZDgnODNLp4SriZSbE5yZpeZrcGaWWX5Uy8yyyz04M8ukIt9aXw2c4MwsPSc4M8si3+hrZpmmptrIcE5wZpaO74Nrn/oP2splP1tGnwENRBM8cvdePHRHf35w61IGH7gFgO69Gtn0fh3fPvWQCkfbfl1xztMcO/zPbNi4B1/59y8DcOLhf+K8sXPYb+AGvnHTmby2vD8Avbp9xMSvTufQfdfw6IuHcOP9f1XJ0KtGu79NRNKdwF8DayLisHK1U00aG8Tkawex+OVu7NG9kZsfe525M3vyr98cuv2Y869ayaYPOlQuSOORFw7m/mdGcOW5T20ve/Ptvvzgl1/gsi/P3OHYrQ113PbIKA6o38AB9X4X+nY10oMr57+0u4CxZay/6qxf04nFL3cDYPOmOpYv7kq/+m15RwTH/+27PPVQn8oEaAC89OYg3v+w6w5lf17dh2Vren/q2I+2dmL+knq2NtTtpuhqg6K4rdLK1oOLiJmShpar/mo3cPBWDjxsM6/N7ba97LCjNrHhnY6sXNKlgpGZ7aIAauRh+4qPlSSdL2m2pNnb2FLpcEqia7dGrrx9KbdeNYgPN378f/4TT3+Xpx/qXbnAzEpETcVtbdYj3SlpjaRX8sr6Spou6Y3kZ5+8766QtFjSIklj2qq/4gkuIiZHxKiIGNWJ2u/Z1HUMrrx9KX94oA9/fLT39vIOdcGxX3qP/57Wu9VzzWpB831wJRqi3sWnL2VdDsyIiGHAjGQfScOB8cCI5JxJkgpeO6h4gsuW4NIblrP8ja48MLn/Dt8cedwHLF/chbWrOlcoNrMSiSh+a7OqmAl8cvZmHDAl+TwFOD2v/L6I2BIRS4DFwOhC9fs2kRIaMXoTp5y9gTcXdmXS9EUA/PLf6pn1h158fpyHp9Ximn98kiMOXEXvHh/x4DV3c8ejo3j/wy5c8nd/pHePzfz4/Ed54629uPTW0wD4r6vuoXuXbXTs2Mhxn1nKJbecxtLV7XuiKMUEQj9Js/P2J0fE5DbOGRgRqwAiYpWkAUn5PsDzecetSMpaVc7bRO4FTiD3C64Aro6IO8rVXjVY8GIPxgw6vMXvbrhk390cjbXmml+d0mL5zJf3b7H8rGvPLWc4tan4BLc2IkaVqFWljaScs6jnlKtuM6usMt8CslpSfdJ7qwfWJOUrgCF5xw0GVhaqyNfgzCydABqjuG3nTAMmJJ8nAA/nlY+X1EXS/sAw4MVCFfkanJmlVqoeXEuXsoDrgamSzgOWAWcDRMQCSVOBhUADcGFENBaq3wnOzNIr0Y2+BS5lndzK8ROBicXW7wRnZqlVw2NYxXCCM7N0vFySmWWVAO38BMJu5QRnZqn5zfZmlk0eoppZdhX3nGk1cIIzs9Q8i2pm2eUenJllUngW1cyyrDbymxOcmaXn20TMLLuc4MwskwJo7y9+NrNsEuEhqpllWFNtdOGc4MwsHQ9RzSzLPEQ1s+xygjOzbPLD9maWVc1v1aoBTnBmlpqvwZlZdjnBmVkmBdDkBGdmmeRJBjPLMic4M8ukABpr41EGJzgzSykgSpPgJC0FPgAagYaIGCWpL/CfwFBgKfDliNiwM/V3KEmUZta+RBS3FefEiBgZEaOS/cuBGRExDJiR7O8UJzgzS6d5FrWYbeeMA6Ykn6cAp+9sRU5wZpZe6XpwATwhaY6k85OygRGxKtdMrAIG7GyYvgZnZukVP/zsJ2l23v7kiJict39sRKyUNACYLum1ksWIE5yZpRUBjY3FHr0279paC1XFyuTnGkkPAqOB1ZLqI2KVpHpgzc6G6iGqmaVXgiGqpO6SejZ/Br4AvAJMAyYkh00AHt7ZMN2DM7P0SnOj70DgQUmQy0W/jojHJM0Cpko6D1gGnL2zDTjBmVlKuzRD+nEtEW8Ch7dQvg44eZcbwAnOzNIKiBLd6FtuTnBmlp4f1TKzTIrwawPNLMO8moiZZVW4B2dm2eQFL80sq7xkuZllVQBR/KNaFeUEZ2bpROkWvCw3JzgzSy08RDWzzKqRHpyiimZDJL0D/LnScZRBP2BtpYOwVLL6d7ZfRPTflQokPUbuz6cYayNi7K60tyuqKsFllaTZhdbEsurjv7Ns8HpwZpZZTnBmlllOcLvH5LYPsSrjv7MM8DU4M8ss9+DMLLOc4Mwss5zgykjSWEmLJC2WdHml47G2SbpT0hpJr1Q6Ftt1TnBlIqkO+DnwRWA4cI6k4ZWNyopwF1CxG1OttJzgymc0sDgi3oyIrcB9wLgKx2RtiIiZwPpKx2Gl4QRXPvsAy/P2VyRlZrabOMGVj1oo8z05ZruRE1z5rACG5O0PBlZWKBazdskJrnxmAcMk7S+pMzAemFbhmMzaFSe4MomIBuAi4HHgVWBqRCyobFTWFkn3As8Bh0haIem8SsdkO8+PaplZZrkHZ2aZ5QRnZpnlBGdmmeUEZ2aZ5QRnZpnlBFdDJDVKmifpFUm/kdRtF+q6S9JZyefbCy0EIOkESX+5E20slfSpty+1Vv6JYzambOsaSd9LG6NlmxNcbdkcESMj4jBgK/DN/C+TFUxSi4ivR8TCAoecAKROcGaV5gRXu54BDkp6V09J+jXwsqQ6ST+WNEvSfEkXACjnZkkLJf0eGNBckaSnJY1KPo+VNFfSS5JmSBpKLpFekvQej5PUX9L9SRuzJB2bnLuXpCck/Y+kX9Dy87g7kPSQpDmSFkg6/xPf3ZDEMkNS/6TsQEmPJec8I+nQkvxpWib5zfY1SFJHcuvMPZYUjQYOi4glSZJ4LyI+J6kL8EdJTwBHAIcAnwEGAguBOz9Rb3/gNuD4pK6+EbFe0q3Axoj4SXLcr4GbIuJZSfuSe1rjL4CrgWcj4lpJpwE7JKxWfC1pYw9glqT7I2Id0B2YGxHflXRVUvdF5F4G882IeEPSUcAk4KSd+GO0dsAJrrbsIWle8vkZ4A5yQ8cXI2JJUv4F4LPN19eAPYFhwPHAvRHRCKyU9IcW6j8amNlcV0S0ti7aKcBwaXsHrZeknkkbZybn/l7ShiJ+p4slnZF8HpLEug5oAv4zKb8beEBSj+T3/U1e212KaMPaKSe42rI5IkbmFyT/0DflFwH/JyIe/8RxX6Lt5ZpUxDGQu7RxTERsbiGWop/9k3QCuWR5TER8KOlpoGsrh0fS7ruf/DMwa42vwWXP48C3JHUCkHSwpO7ATGB8co2uHjixhXOfAz4vaf/k3L5J+QdAz7zjniA3XCQ5bmTycSZwblL2RaBPG7HuCWxIktuh5HqQzToAzb3Qvyc39H0fWCLp7KQNSTq8jTasHXOCy57byV1fm5u8OOUX5HrqDwJvAC8DtwD//ckTI+IdctfNHpD0Eh8PEX8LnNE8yQBcDIxKJjEW8vFs7r8Ax0uaS26ovKyNWB8DOkqaD1wHPJ/33SZghKQ55K6xXZuUnwucl8S3AC8DbwV4NREzyyz34Mwss5zgzCyznODMLLOc4Mwss5zgzCyznODMLLOc4Mwss/4/CgDWTh54qtgAAAAASUVORK5CYII=\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "print(\"Confusion Matrix:\")\n",
+ "cm = confusion_matrix(y_test, y_pred)\n",
+ "ConfusionMatrixDisplay(confusion_matrix=cm).plot();\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "05ea49ab",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "# #optional extra\n",
+ "# # Checking to see how the data was scaled\n",
+ "\n",
+ "# # Visualize histograms before scaling for 'Fe' (example)\n",
+ "# plt.figure(figsize=(12, 6))\n",
+ "# plt.subplot(1, 2, 1)\n",
+ "# plt.hist(X['Fe'], bins=50, color='blue', alpha=0.7)\n",
+ "# plt.title('Before Scaling: Fe')\n",
+ "\n",
+ "# # Visualize histograms after scaling for 'Fe'\n",
+ "# plt.subplot(1, 2, 2)\n",
+ "# plt.hist(X_scaled[:, X.columns.get_loc('Fe')], bins=50, color='green', alpha=0.7)\n",
+ "# plt.title('After Scaling: Fe')\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "705d2ffd",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "# #optional extra\n",
+ "# Feature Importance\n",
+ "# feature_importances = best_rf.feature_importances_\n",
+ "# feature_importance_df = pd.DataFrame({'Element': elements, 'Importance': feature_importances})\n",
+ "# feature_importance_df = feature_importance_df.sort_values(by='Importance', ascending=False)\n",
+ "\n",
+ "# plt.figure(figsize=(10, 6))\n",
+ "# sns.barplot(x='Importance', y='Element', data=feature_importance_df, color='steelblue')\n",
+ "# plt.title('Feature Importance for Predicting Class A vs Class B')\n",
+ "# plt.tight_layout()\n",
+ "# plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "989eba46",
+ "metadata": {},
+ "source": [
+ "### Aim 2: Predict labels "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "41f2f5e1",
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Predictions on Unlabeled Data:\n",
+ " Unique_ID holeid Predicted_Class\n",
+ "4004 A04920 SOLVE225W1 A\n",
+ "4005 A05729 SOLVE225W1 A\n",
+ "4006 A05270 SOLVE225W1 A\n",
+ "4007 A05634 SOLVE225W1 A\n",
+ "4008 A04689 SOLVE225W1 A\n",
+ "... ... ... ...\n",
+ "4072 A06469 SOLVE227 B\n",
+ "4073 A06256 SOLVE227 B\n",
+ "4074 A05766 SOLVE227 B\n",
+ "4075 A08273 SOLVE227 B\n",
+ "4076 A06155 SOLVE227 B\n",
+ "\n",
+ "[70 rows x 3 columns]\n"
+ ]
+ }
+ ],
+ "source": [
+ "unlabeled_data_copy = unlabeled_data.copy()\n",
+ "\n",
+ "# Predict unlabeled data \n",
+ "X_new = unlabeled_data[elements]\n",
+ "X_new_scaled = scaler.transform(X_new)\n",
+ "unlabeled_data.loc[:, 'Predicted_Class'] = le.inverse_transform(best_rf.predict(X_new_scaled))\n",
+ "\n",
+ "print(\"Predictions on Unlabeled Data:\")\n",
+ "print(unlabeled_data[['Unique_ID', 'holeid', 'Predicted_Class']].head(70))\n",
+ "unlabeled_data.to_csv('./data/predictions_on_unlabeled_data.csv', index=False)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "23730f2b",
+ "metadata": {},
+ "source": [
+ "Results available on: predictions_on_unlabeled_data.csv"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "182b0d3e",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Predictions for '?' class:\n",
+ "Predicted_Class\n",
+ "B 390\n",
+ "A 295\n",
+ "Name: count, dtype: int64\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Filter the rows where the original 'Class' column is '?'\n",
+ "unlabeled_data_question = unlabeled_data[unlabeled_data['Class'] == '?']\n",
+ "\n",
+ "# Count how many were predicted to be 'A' and 'B'\n",
+ "class_counts_from_question = unlabeled_data_question['Predicted_Class'].value_counts()\n",
+ "\n",
+ "# Print out the prediction counts for 'A' and 'B' from '?'\n",
+ "print(f\"Predictions for '?' class:\")\n",
+ "print(class_counts_from_question)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0eb795e5",
+ "metadata": {},
+ "source": [
+ "### Observations"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e8bcc086",
+ "metadata": {},
+ "source": [
+ "- The model performs well overall, with high accuracy and strong performance for Class A.\n",
+ "- Class B predictions are less accurate.\n",
+ "- The imbalance in the dataset (more Class A samples) is probably why the model performs better for Class A \n",
+ "- Pb has the highest feature importance (makes sense as this is a lead deposit)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e8dc4640",
+ "metadata": {},
+ "source": [
+ "### Improvements: \n",
+ "- Handle missing values more effectively (instead of dropping).\n",
+ "- Perform additional model validation (e.g., cross-validation, ROC curves).\n",
+ "- Experiment with other models (e.g., SVM, Gradient Boosting) for comparison."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}