From 9d23727eb0433b76771d517ae91db56b5f87b0d2 Mon Sep 17 00:00:00 2001 From: anastasiafilia <95204313+anastasiafilia@users.noreply.github.com> Date: Sun, 2 Feb 2025 23:16:08 +1100 Subject: [PATCH 1/3] Add files via upload --- Coding test - Anastasia .ipynb | 425 +++++++++++++++++++++++++++++++++ 1 file changed, 425 insertions(+) create mode 100644 Coding test - Anastasia .ipynb diff --git a/Coding test - Anastasia .ipynb b/Coding test - Anastasia .ipynb new file mode 100644 index 0000000..f3974e0 --- /dev/null +++ b/Coding test - Anastasia .ipynb @@ -0,0 +1,425 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "acb117e5", + "metadata": {}, + "source": [ + "# Ore Proximity Model" + ] + }, + { + "cell_type": "markdown", + "id": "6120f7f0", + "metadata": {}, + "source": [ + "A consultant is using metal abundance changes to predict proximity to an orebody. Samples are classified as proximal (A) or distal (B) based on Euclidean distance to a wireframe model of the orebody.\n", + "\n", + "1. Can we use the same geochemical data and labels to generate a predictive model for future drill holes which can label samples on whether they are in class A or class B?\n", + "2. More data has been acquired since the geochemist completed her work - can we predict labels onto these data points (labelled “?”).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b04ccd8c", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from sklearn.model_selection import train_test_split, RandomizedSearchCV \n", + "from sklearn.ensemble import RandomForestClassifier\n", + "from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay\n", + "from sklearn.preprocessing import StandardScaler, LabelEncoder \n", + "import seaborn as sns\n", + "from scipy.stats import randint" + ] + }, + { + "cell_type": "markdown", + "id": "9e051ac5", + "metadata": {}, + "source": [ + "## Step 1: Data - QA/QC" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28e738e3", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "data = pd.read_csv('data/data_for_distribution.csv')\n", + "data.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "89a97680", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "data.info()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9701bd91", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "# Check how many NaN values are in each column\n", + "nan_values = data.isna().sum()\n", + "\n", + "# Print the number of NaN values per column\n", + "print(\"Number of NaN values per column:\")\n", + "print(nan_values)" + ] + }, + { + "cell_type": "markdown", + "id": "adf9369f", + "metadata": {}, + "source": [ + "Looking at the data, I saw that there are some there are some QA/QC issues.\n", + "These included\n", + "- Au data type is object\n", + "- missing values\n", + "- values below the detection limit (<0.005)\n", + "- unsuitable values (-999)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "943cad16", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "elements = ['As', 'Au', 'Pb', 'Fe', 'Mo', 'Cu', 'S', 'Zn']\n", + "\n", + "# Replace '<0.005' with half of the detection limit and '-999' with NaN\n", + "data[elements] = data[elements].replace({'<0.005': 0.005 / 2, '-999': np.nan})\n", + "\n", + "# Convert the 'Au' column to float64\n", + "data['Au'] = pd.to_numeric(data['Au'], errors='coerce')\n", + "\n", + "# Drop NaN values\n", + "data_cleaned = data.dropna(subset=elements)\n", + "\n", + "data_cleaned.to_csv(\"./data/cleaned_data.csv\", index=False)\n", + "data_cleaned.head()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6c5c78d1", + "metadata": {}, + "outputs": [], + "source": [ + "# Some checks:\n", + "\n", + "# Check the data type of 'Au' column\n", + "print(data['Au'].dtype)\n", + "\n", + "nan_check = data_cleaned.isna().sum()\n", + "\n", + "# Print NaN values if any exist\n", + "if nan_check.any():\n", + " print(\"NaN values in cleaned data:\")\n", + " print(nan_check[nan_check > 0], \"\\n\")\n", + "\n", + "for element in elements:\n", + " # Check for values equal to '-999'\n", + " if (data_cleaned[element] == -999).any():\n", + " print(f\"Warning: '{element}' contains '-999' values.\")\n", + " \n", + " # Check for values equal to '<0.005'\n", + " if (data_cleaned[element] == '<0.005').any():\n", + " print(f\"Warning: '{element}' contains '<0.005' values.\")" + ] + }, + { + "cell_type": "markdown", + "id": "53c5522b", + "metadata": {}, + "source": [ + "We can do other data checks but I am assuming these are the only issues and moving on" + ] + }, + { + "cell_type": "markdown", + "id": "2d2aef0f", + "metadata": {}, + "source": [ + "## Step 2: Modelling\n", + "### Aim 1: Generate a predictive model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "411d810e", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# Check the distribution of classes\n", + "class_counts = data_cleaned['Class'].value_counts()\n", + "print(class_counts)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e945df09", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# Mostly from https://www.datacamp.com/tutorial/random-forests-classifier-python\n", + "\n", + "# Separate labeled and unlabeled data\n", + "labeled_data = data_cleaned[data_cleaned['Class'].isin(['A', 'B'])]\n", + "unlabeled_data = data_cleaned[data_cleaned['Class'] == '?']\n", + "\n", + "# Features and target\n", + "X = labeled_data[elements]\n", + "y = labeled_data['Class']\n", + "\n", + "# Encode labels\n", + "le = LabelEncoder()\n", + "y_encoded = le.fit_transform(y)\n", + "\n", + "# Standardize features\n", + "scaler = StandardScaler()\n", + "X_scaled = scaler.fit_transform(X)\n", + "\n", + "# Train-test split\n", + "X_train, X_test, y_train, y_test = train_test_split(X_scaled, y_encoded, test_size=0.2, random_state=42)\n", + "\n", + "\n", + "#Hyperparameter tuning \n", + "param_dist = {'n_estimators': randint(50,500),\n", + " 'max_depth': randint(1,20)}\n", + "\n", + "rf = RandomForestClassifier()\n", + "\n", + "rand_search = RandomizedSearchCV(rf, \n", + " param_distributions = param_dist, \n", + " n_iter=5, \n", + " cv=5)\n", + "rand_search.fit(X_train, y_train)\n", + "\n", + "best_rf = rand_search.best_estimator_\n", + "\n", + "# Print the best hyperparameters\n", + "print('Best hyperparameters:', rand_search.best_params_)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27ad8b29", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "# Evaluate model\n", + "y_pred = best_rf.predict(X_test)\n", + "print(\"Classification Report:\")\n", + "print(classification_report(y_test, y_pred))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53d68eca", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Confusion Matrix:\")\n", + "cm = confusion_matrix(y_test, y_pred)\n", + "ConfusionMatrixDisplay(confusion_matrix=cm).plot();\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "05ea49ab", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# #optional extra\n", + "# # Checking to see how the data was scaled\n", + "\n", + "# # Visualize histograms before scaling for 'Fe' (example)\n", + "# plt.figure(figsize=(12, 6))\n", + "# plt.subplot(1, 2, 1)\n", + "# plt.hist(X['Fe'], bins=50, color='blue', alpha=0.7)\n", + "# plt.title('Before Scaling: Fe')\n", + "\n", + "# # Visualize histograms after scaling for 'Fe'\n", + "# plt.subplot(1, 2, 2)\n", + "# plt.hist(X_scaled[:, X.columns.get_loc('Fe')], bins=50, color='green', alpha=0.7)\n", + "# plt.title('After Scaling: Fe')\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "705d2ffd", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# #optional extra\n", + "# Feature Importance\n", + "# feature_importances = best_rf.feature_importances_\n", + "# feature_importance_df = pd.DataFrame({'Element': elements, 'Importance': feature_importances})\n", + "# feature_importance_df = feature_importance_df.sort_values(by='Importance', ascending=False)\n", + "\n", + "# plt.figure(figsize=(10, 6))\n", + "# sns.barplot(x='Importance', y='Element', data=feature_importance_df, color='steelblue')\n", + "# plt.title('Feature Importance for Predicting Class A vs Class B')\n", + "# plt.tight_layout()\n", + "# plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "989eba46", + "metadata": {}, + "source": [ + "### Aim 2: Predict labels " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41f2f5e1", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "unlabeled_data_copy = unlabeled_data.copy()\n", + "\n", + "# Predict unlabeled data \n", + "X_new = unlabeled_data[elements]\n", + "X_new_scaled = scaler.transform(X_new)\n", + "unlabeled_data.loc[:, 'Predicted_Class'] = le.inverse_transform(best_rf.predict(X_new_scaled))\n", + "\n", + "print(\"Predictions on Unlabeled Data:\")\n", + "print(unlabeled_data[['Unique_ID', 'holeid', 'Predicted_Class']].head(70))\n", + "unlabeled_data.to_csv('./data/predictions_on_unlabeled_data.csv', index=False)\n" + ] + }, + { + "cell_type": "markdown", + "id": "23730f2b", + "metadata": {}, + "source": [ + "Results available on: predictions_on_unlabeled_data.csv" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "182b0d3e", + "metadata": {}, + "outputs": [], + "source": [ + "# Filter the rows where the original 'Class' column is '?'\n", + "unlabeled_data_question = unlabeled_data[unlabeled_data['Class'] == '?']\n", + "\n", + "# Count how many were predicted to be 'A' and 'B'\n", + "class_counts_from_question = unlabeled_data_question['Predicted_Class'].value_counts()\n", + "\n", + "# Print out the prediction counts for 'A' and 'B' from '?'\n", + "print(f\"Predictions for '?' class:\")\n", + "print(class_counts_from_question)" + ] + }, + { + "cell_type": "markdown", + "id": "0eb795e5", + "metadata": {}, + "source": [ + "### Observations" + ] + }, + { + "cell_type": "markdown", + "id": "e8bcc086", + "metadata": {}, + "source": [ + "- The model performs well overall, with high accuracy and strong performance for Class A.\n", + "- Class B predictions are less accurate.\n", + "- The imbalance in the dataset (more Class A samples) is probably why the model performs better for Class A \n", + "- Pb has the highest feature importance (makes sense as this is a lead deposit)" + ] + }, + { + "cell_type": "markdown", + "id": "e8dc4640", + "metadata": {}, + "source": [ + "### Improvements: \n", + "- Handle missing values more effectively (instead of dropping).\n", + "- Perform additional model validation (e.g., cross-validation, ROC curves).\n", + "- Experiment with other models (e.g., SVM, Gradient Boosting) for comparison." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7cabd73e", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 283af0b2c68df9843d0dbf266e06d86186ed55ca Mon Sep 17 00:00:00 2001 From: anastasiafilia <95204313+anastasiafilia@users.noreply.github.com> Date: Sun, 2 Feb 2025 23:32:03 +1100 Subject: [PATCH 2/3] Delete Coding test - Anastasia .ipynb --- Coding test - Anastasia .ipynb | 425 --------------------------------- 1 file changed, 425 deletions(-) delete mode 100644 Coding test - Anastasia .ipynb diff --git a/Coding test - Anastasia .ipynb b/Coding test - Anastasia .ipynb deleted file mode 100644 index f3974e0..0000000 --- a/Coding test - Anastasia .ipynb +++ /dev/null @@ -1,425 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "acb117e5", - "metadata": {}, - "source": [ - "# Ore Proximity Model" - ] - }, - { - "cell_type": "markdown", - "id": "6120f7f0", - "metadata": {}, - "source": [ - "A consultant is using metal abundance changes to predict proximity to an orebody. Samples are classified as proximal (A) or distal (B) based on Euclidean distance to a wireframe model of the orebody.\n", - "\n", - "1. Can we use the same geochemical data and labels to generate a predictive model for future drill holes which can label samples on whether they are in class A or class B?\n", - "2. More data has been acquired since the geochemist completed her work - can we predict labels onto these data points (labelled “?”).\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b04ccd8c", - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd\n", - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "from sklearn.model_selection import train_test_split, RandomizedSearchCV \n", - "from sklearn.ensemble import RandomForestClassifier\n", - "from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay\n", - "from sklearn.preprocessing import StandardScaler, LabelEncoder \n", - "import seaborn as sns\n", - "from scipy.stats import randint" - ] - }, - { - "cell_type": "markdown", - "id": "9e051ac5", - "metadata": {}, - "source": [ - "## Step 1: Data - QA/QC" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "28e738e3", - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "data = pd.read_csv('data/data_for_distribution.csv')\n", - "data.head()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "89a97680", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "data.info()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9701bd91", - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "# Check how many NaN values are in each column\n", - "nan_values = data.isna().sum()\n", - "\n", - "# Print the number of NaN values per column\n", - "print(\"Number of NaN values per column:\")\n", - "print(nan_values)" - ] - }, - { - "cell_type": "markdown", - "id": "adf9369f", - "metadata": {}, - "source": [ - "Looking at the data, I saw that there are some there are some QA/QC issues.\n", - "These included\n", - "- Au data type is object\n", - "- missing values\n", - "- values below the detection limit (<0.005)\n", - "- unsuitable values (-999)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "943cad16", - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "elements = ['As', 'Au', 'Pb', 'Fe', 'Mo', 'Cu', 'S', 'Zn']\n", - "\n", - "# Replace '<0.005' with half of the detection limit and '-999' with NaN\n", - "data[elements] = data[elements].replace({'<0.005': 0.005 / 2, '-999': np.nan})\n", - "\n", - "# Convert the 'Au' column to float64\n", - "data['Au'] = pd.to_numeric(data['Au'], errors='coerce')\n", - "\n", - "# Drop NaN values\n", - "data_cleaned = data.dropna(subset=elements)\n", - "\n", - "data_cleaned.to_csv(\"./data/cleaned_data.csv\", index=False)\n", - "data_cleaned.head()\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6c5c78d1", - "metadata": {}, - "outputs": [], - "source": [ - "# Some checks:\n", - "\n", - "# Check the data type of 'Au' column\n", - "print(data['Au'].dtype)\n", - "\n", - "nan_check = data_cleaned.isna().sum()\n", - "\n", - "# Print NaN values if any exist\n", - "if nan_check.any():\n", - " print(\"NaN values in cleaned data:\")\n", - " print(nan_check[nan_check > 0], \"\\n\")\n", - "\n", - "for element in elements:\n", - " # Check for values equal to '-999'\n", - " if (data_cleaned[element] == -999).any():\n", - " print(f\"Warning: '{element}' contains '-999' values.\")\n", - " \n", - " # Check for values equal to '<0.005'\n", - " if (data_cleaned[element] == '<0.005').any():\n", - " print(f\"Warning: '{element}' contains '<0.005' values.\")" - ] - }, - { - "cell_type": "markdown", - "id": "53c5522b", - "metadata": {}, - "source": [ - "We can do other data checks but I am assuming these are the only issues and moving on" - ] - }, - { - "cell_type": "markdown", - "id": "2d2aef0f", - "metadata": {}, - "source": [ - "## Step 2: Modelling\n", - "### Aim 1: Generate a predictive model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "411d810e", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "# Check the distribution of classes\n", - "class_counts = data_cleaned['Class'].value_counts()\n", - "print(class_counts)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e945df09", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "# Mostly from https://www.datacamp.com/tutorial/random-forests-classifier-python\n", - "\n", - "# Separate labeled and unlabeled data\n", - "labeled_data = data_cleaned[data_cleaned['Class'].isin(['A', 'B'])]\n", - "unlabeled_data = data_cleaned[data_cleaned['Class'] == '?']\n", - "\n", - "# Features and target\n", - "X = labeled_data[elements]\n", - "y = labeled_data['Class']\n", - "\n", - "# Encode labels\n", - "le = LabelEncoder()\n", - "y_encoded = le.fit_transform(y)\n", - "\n", - "# Standardize features\n", - "scaler = StandardScaler()\n", - "X_scaled = scaler.fit_transform(X)\n", - "\n", - "# Train-test split\n", - "X_train, X_test, y_train, y_test = train_test_split(X_scaled, y_encoded, test_size=0.2, random_state=42)\n", - "\n", - "\n", - "#Hyperparameter tuning \n", - "param_dist = {'n_estimators': randint(50,500),\n", - " 'max_depth': randint(1,20)}\n", - "\n", - "rf = RandomForestClassifier()\n", - "\n", - "rand_search = RandomizedSearchCV(rf, \n", - " param_distributions = param_dist, \n", - " n_iter=5, \n", - " cv=5)\n", - "rand_search.fit(X_train, y_train)\n", - "\n", - "best_rf = rand_search.best_estimator_\n", - "\n", - "# Print the best hyperparameters\n", - "print('Best hyperparameters:', rand_search.best_params_)\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "27ad8b29", - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "# Evaluate model\n", - "y_pred = best_rf.predict(X_test)\n", - "print(\"Classification Report:\")\n", - "print(classification_report(y_test, y_pred))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "53d68eca", - "metadata": {}, - "outputs": [], - "source": [ - "print(\"Confusion Matrix:\")\n", - "cm = confusion_matrix(y_test, y_pred)\n", - "ConfusionMatrixDisplay(confusion_matrix=cm).plot();\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "05ea49ab", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "# #optional extra\n", - "# # Checking to see how the data was scaled\n", - "\n", - "# # Visualize histograms before scaling for 'Fe' (example)\n", - "# plt.figure(figsize=(12, 6))\n", - "# plt.subplot(1, 2, 1)\n", - "# plt.hist(X['Fe'], bins=50, color='blue', alpha=0.7)\n", - "# plt.title('Before Scaling: Fe')\n", - "\n", - "# # Visualize histograms after scaling for 'Fe'\n", - "# plt.subplot(1, 2, 2)\n", - "# plt.hist(X_scaled[:, X.columns.get_loc('Fe')], bins=50, color='green', alpha=0.7)\n", - "# plt.title('After Scaling: Fe')\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "705d2ffd", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "# #optional extra\n", - "# Feature Importance\n", - "# feature_importances = best_rf.feature_importances_\n", - "# feature_importance_df = pd.DataFrame({'Element': elements, 'Importance': feature_importances})\n", - "# feature_importance_df = feature_importance_df.sort_values(by='Importance', ascending=False)\n", - "\n", - "# plt.figure(figsize=(10, 6))\n", - "# sns.barplot(x='Importance', y='Element', data=feature_importance_df, color='steelblue')\n", - "# plt.title('Feature Importance for Predicting Class A vs Class B')\n", - "# plt.tight_layout()\n", - "# plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "989eba46", - "metadata": {}, - "source": [ - "### Aim 2: Predict labels " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "41f2f5e1", - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "unlabeled_data_copy = unlabeled_data.copy()\n", - "\n", - "# Predict unlabeled data \n", - "X_new = unlabeled_data[elements]\n", - "X_new_scaled = scaler.transform(X_new)\n", - "unlabeled_data.loc[:, 'Predicted_Class'] = le.inverse_transform(best_rf.predict(X_new_scaled))\n", - "\n", - "print(\"Predictions on Unlabeled Data:\")\n", - "print(unlabeled_data[['Unique_ID', 'holeid', 'Predicted_Class']].head(70))\n", - "unlabeled_data.to_csv('./data/predictions_on_unlabeled_data.csv', index=False)\n" - ] - }, - { - "cell_type": "markdown", - "id": "23730f2b", - "metadata": {}, - "source": [ - "Results available on: predictions_on_unlabeled_data.csv" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "182b0d3e", - "metadata": {}, - "outputs": [], - "source": [ - "# Filter the rows where the original 'Class' column is '?'\n", - "unlabeled_data_question = unlabeled_data[unlabeled_data['Class'] == '?']\n", - "\n", - "# Count how many were predicted to be 'A' and 'B'\n", - "class_counts_from_question = unlabeled_data_question['Predicted_Class'].value_counts()\n", - "\n", - "# Print out the prediction counts for 'A' and 'B' from '?'\n", - "print(f\"Predictions for '?' class:\")\n", - "print(class_counts_from_question)" - ] - }, - { - "cell_type": "markdown", - "id": "0eb795e5", - "metadata": {}, - "source": [ - "### Observations" - ] - }, - { - "cell_type": "markdown", - "id": "e8bcc086", - "metadata": {}, - "source": [ - "- The model performs well overall, with high accuracy and strong performance for Class A.\n", - "- Class B predictions are less accurate.\n", - "- The imbalance in the dataset (more Class A samples) is probably why the model performs better for Class A \n", - "- Pb has the highest feature importance (makes sense as this is a lead deposit)" - ] - }, - { - "cell_type": "markdown", - "id": "e8dc4640", - "metadata": {}, - "source": [ - "### Improvements: \n", - "- Handle missing values more effectively (instead of dropping).\n", - "- Perform additional model validation (e.g., cross-validation, ROC curves).\n", - "- Experiment with other models (e.g., SVM, Gradient Boosting) for comparison." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7cabd73e", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From 60cba53bc93b816170edec790bdacbdd58d94658 Mon Sep 17 00:00:00 2001 From: anastasiafilia <95204313+anastasiafilia@users.noreply.github.com> Date: Sun, 2 Feb 2025 23:32:20 +1100 Subject: [PATCH 3/3] Add files via upload --- Coding test - Anastasia.ipynb | 849 ++++++++++++++++++++++++++++++++++ 1 file changed, 849 insertions(+) create mode 100644 Coding test - Anastasia.ipynb diff --git a/Coding test - Anastasia.ipynb b/Coding test - Anastasia.ipynb new file mode 100644 index 0000000..c1f5dce --- /dev/null +++ b/Coding test - Anastasia.ipynb @@ -0,0 +1,849 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "acb117e5", + "metadata": {}, + "source": [ + "# Ore Proximity Model" + ] + }, + { + "cell_type": "markdown", + "id": "6120f7f0", + "metadata": {}, + "source": [ + "A consultant is using metal abundance changes to predict proximity to an orebody. Samples are classified as proximal (A) or distal (B) based on Euclidean distance to a wireframe model of the orebody.\n", + "\n", + "1. Can we use the same geochemical data and labels to generate a predictive model for future drill holes which can label samples on whether they are in class A or class B?\n", + "2. More data has been acquired since the geochemist completed her work - can we predict labels onto these data points (labelled “?”).\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b04ccd8c", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from sklearn.model_selection import train_test_split, RandomizedSearchCV \n", + "from sklearn.ensemble import RandomForestClassifier\n", + "from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay\n", + "from sklearn.preprocessing import StandardScaler, LabelEncoder \n", + "import seaborn as sns\n", + "from scipy.stats import randint" + ] + }, + { + "cell_type": "markdown", + "id": "9e051ac5", + "metadata": {}, + "source": [ + "## Step 1: Data - QA/QC" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "28e738e3", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Unique_IDholeidfromtoAsAuPbFeMoCuSZnClass
0A04812SOLVE003561571.0NaN0.0661031.0061380.0138.20003.6003586.000043.6000A
1A03356SOLVE003571581.0NaN0.1521982.0050860.075.40004.8001822.000036.4000A
2A04764SOLVE003581591.0NaN0.0681064.8057940.029.20003.000740.400036.6000A
3A04626SOLVE003591601.0NaN0.074891.6048620.063.00004.200820.800039.6000A
4A05579SOLVE003601611.0NaN0.043125801.2551025.056.06254.875745.687532.3125A
\n", + "
" + ], + "text/plain": [ + " Unique_ID holeid from to As Au Pb Fe Mo \\\n", + "0 A04812 SOLVE003 561 571.0 NaN 0.066 1031.00 61380.0 138.2000 \n", + "1 A03356 SOLVE003 571 581.0 NaN 0.152 1982.00 50860.0 75.4000 \n", + "2 A04764 SOLVE003 581 591.0 NaN 0.068 1064.80 57940.0 29.2000 \n", + "3 A04626 SOLVE003 591 601.0 NaN 0.074 891.60 48620.0 63.0000 \n", + "4 A05579 SOLVE003 601 611.0 NaN 0.043125 801.25 51025.0 56.0625 \n", + "\n", + " Cu S Zn Class \n", + "0 3.600 3586.0000 43.6000 A \n", + "1 4.800 1822.0000 36.4000 A \n", + "2 3.000 740.4000 36.6000 A \n", + "3 4.200 820.8000 39.6000 A \n", + "4 4.875 745.6875 32.3125 A " + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data = pd.read_csv('data/data_for_distribution.csv')\n", + "data.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "89a97680", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "RangeIndex: 4771 entries, 0 to 4770\n", + "Data columns (total 13 columns):\n", + " # Column Non-Null Count Dtype \n", + "--- ------ -------------- ----- \n", + " 0 Unique_ID 4771 non-null object \n", + " 1 holeid 4771 non-null object \n", + " 2 from 4771 non-null int64 \n", + " 3 to 4771 non-null float64\n", + " 4 As 3268 non-null float64\n", + " 5 Au 4765 non-null object \n", + " 6 Pb 4756 non-null float64\n", + " 7 Fe 4709 non-null float64\n", + " 8 Mo 4741 non-null float64\n", + " 9 Cu 4746 non-null float64\n", + " 10 S 4761 non-null float64\n", + " 11 Zn 4762 non-null float64\n", + " 12 Class 4771 non-null object \n", + "dtypes: float64(8), int64(1), object(4)\n", + "memory usage: 484.7+ KB\n" + ] + } + ], + "source": [ + "data.info()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "9701bd91", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of NaN values per column:\n", + "Unique_ID 0\n", + "holeid 0\n", + "from 0\n", + "to 0\n", + "As 1503\n", + "Au 6\n", + "Pb 15\n", + "Fe 62\n", + "Mo 30\n", + "Cu 25\n", + "S 10\n", + "Zn 9\n", + "Class 0\n", + "dtype: int64\n" + ] + } + ], + "source": [ + "# Check how many NaN values are in each column\n", + "nan_values = data.isna().sum()\n", + "\n", + "# Print the number of NaN values per column\n", + "print(\"Number of NaN values per column:\")\n", + "print(nan_values)" + ] + }, + { + "cell_type": "markdown", + "id": "adf9369f", + "metadata": {}, + "source": [ + "Looking at the data, I saw that there are some there are some QA/QC issues.\n", + "These included\n", + "- Au data type is object\n", + "- missing values\n", + "- values below the detection limit (<0.005)\n", + "- unsuitable values (-999)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "943cad16", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Unique_IDholeidfromtoAsAuPbFeMoCuSZnClass
257A03720SOLVE026731741.05.00.1221881.235440.0133.06.02770.043.4A
1043A02183SOLVE047351361.018.40.2902150.065700.032.6146.49854.0362.0A
1057A04574SOLVE052411421.06.60.0761636.057380.057.412.21808.037.8A
1058A04149SOLVE052421431.015.80.0972375.653980.050.413.22878.035.6A
1059A05461SOLVE052431441.018.60.0461514.654920.018.613.82394.030.4A
\n", + "
" + ], + "text/plain": [ + " Unique_ID holeid from to As Au Pb Fe Mo \\\n", + "257 A03720 SOLVE026 731 741.0 5.0 0.122 1881.2 35440.0 133.0 \n", + "1043 A02183 SOLVE047 351 361.0 18.4 0.290 2150.0 65700.0 32.6 \n", + "1057 A04574 SOLVE052 411 421.0 6.6 0.076 1636.0 57380.0 57.4 \n", + "1058 A04149 SOLVE052 421 431.0 15.8 0.097 2375.6 53980.0 50.4 \n", + "1059 A05461 SOLVE052 431 441.0 18.6 0.046 1514.6 54920.0 18.6 \n", + "\n", + " Cu S Zn Class \n", + "257 6.0 2770.0 43.4 A \n", + "1043 146.4 9854.0 362.0 A \n", + "1057 12.2 1808.0 37.8 A \n", + "1058 13.2 2878.0 35.6 A \n", + "1059 13.8 2394.0 30.4 A " + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "elements = ['As', 'Au', 'Pb', 'Fe', 'Mo', 'Cu', 'S', 'Zn']\n", + "\n", + "# Replace '<0.005' with half of the detection limit and '-999' with NaN\n", + "data[elements] = data[elements].replace({'<0.005': 0.005 / 2, '-999': np.nan})\n", + "\n", + "# Convert the 'Au' column to float64\n", + "data['Au'] = pd.to_numeric(data['Au'], errors='coerce')\n", + "\n", + "# Drop NaN values\n", + "data_cleaned = data.dropna(subset=elements)\n", + "\n", + "data_cleaned.to_csv(\"./data/cleaned_data.csv\", index=False)\n", + "data_cleaned.head()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "6c5c78d1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "float64\n" + ] + } + ], + "source": [ + "# Some checks:\n", + "\n", + "# Check the data type of 'Au' column\n", + "print(data['Au'].dtype)\n", + "\n", + "nan_check = data_cleaned.isna().sum()\n", + "\n", + "# Print NaN values if any exist\n", + "if nan_check.any():\n", + " print(\"NaN values in cleaned data:\")\n", + " print(nan_check[nan_check > 0], \"\\n\")\n", + "\n", + "for element in elements:\n", + " # Check for values equal to '-999'\n", + " if (data_cleaned[element] == -999).any():\n", + " print(f\"Warning: '{element}' contains '-999' values.\")\n", + " \n", + " # Check for values equal to '<0.005'\n", + " if (data_cleaned[element] == '<0.005').any():\n", + " print(f\"Warning: '{element}' contains '<0.005' values.\")" + ] + }, + { + "cell_type": "markdown", + "id": "53c5522b", + "metadata": {}, + "source": [ + "We can do other data checks but I am assuming these are the only issues and moving on" + ] + }, + { + "cell_type": "markdown", + "id": "2d2aef0f", + "metadata": {}, + "source": [ + "## Step 2: Modelling\n", + "### Aim 1: Generate a predictive model" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "411d810e", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Class\n", + "A 1759\n", + "B 743\n", + "? 685\n", + "Name: count, dtype: int64\n" + ] + } + ], + "source": [ + "# Check the distribution of classes\n", + "class_counts = data_cleaned['Class'].value_counts()\n", + "print(class_counts)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "e945df09", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Best hyperparameters: {'max_depth': 17, 'n_estimators': 466}\n" + ] + } + ], + "source": [ + "# Mostly from https://www.datacamp.com/tutorial/random-forests-classifier-python\n", + "\n", + "# Separate labeled and unlabeled data\n", + "labeled_data = data_cleaned[data_cleaned['Class'].isin(['A', 'B'])]\n", + "unlabeled_data = data_cleaned[data_cleaned['Class'] == '?']\n", + "\n", + "# Features and target\n", + "X = labeled_data[elements]\n", + "y = labeled_data['Class']\n", + "\n", + "# Encode labels\n", + "le = LabelEncoder()\n", + "y_encoded = le.fit_transform(y)\n", + "\n", + "# Standardize features\n", + "scaler = StandardScaler()\n", + "X_scaled = scaler.fit_transform(X)\n", + "\n", + "# Train-test split\n", + "X_train, X_test, y_train, y_test = train_test_split(X_scaled, y_encoded, test_size=0.2, random_state=42)\n", + "\n", + "\n", + "#Hyperparameter tuning \n", + "param_dist = {'n_estimators': randint(50,500),\n", + " 'max_depth': randint(1,20)}\n", + "\n", + "rf = RandomForestClassifier()\n", + "\n", + "rand_search = RandomizedSearchCV(rf, \n", + " param_distributions = param_dist, \n", + " n_iter=5, \n", + " cv=5)\n", + "rand_search.fit(X_train, y_train)\n", + "\n", + "best_rf = rand_search.best_estimator_\n", + "\n", + "# Print the best hyperparameters\n", + "print('Best hyperparameters:', rand_search.best_params_)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "27ad8b29", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Classification Report:\n", + " precision recall f1-score support\n", + "\n", + " 0 0.93 0.92 0.92 363\n", + " 1 0.80 0.80 0.80 138\n", + "\n", + " accuracy 0.89 501\n", + " macro avg 0.86 0.86 0.86 501\n", + "weighted avg 0.89 0.89 0.89 501\n", + "\n" + ] + } + ], + "source": [ + "# Evaluate model\n", + "y_pred = best_rf.predict(X_test)\n", + "print(\"Classification Report:\")\n", + "print(classification_report(y_test, y_pred))" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "53d68eca", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Confusion Matrix:\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAATgAAAEGCAYAAADxD4m3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAaJ0lEQVR4nO3deZRV1Zn38e+PYpJJQAYLQXFAbTARfQlq2xrHQGJ3o7ZmYdtp3sREk+hrWhPTmhWH1kW3nURN3tegwSGS1miTdiKJExJtNMuBoREFRYkQQBBkcAARqKrn/eOewotW3boH7uXee+r3Weusumffc/Z+CuRx77PP2UcRgZlZFnWodABmZuXiBGdmmeUEZ2aZ5QRnZpnlBGdmmdWx0gHk69e3LoYO6VTpMCyF1+d3q3QIlsJHbGJrbNGu1DHmxO6xbn1jUcfOmb/l8YgYuyvt7YqqSnBDh3TixceHVDoMS2HMoJGVDsFSeCFm7HIda9c38sLjg4s6tlP9n/rtcoO7oKoSnJnVgqAxmiodRFGc4MwslQCaqI0HBJzgzCy1JtyDM7MMCoJtHqKaWRYF0Oghqpllla/BmVkmBdBYI6sQOcGZWWq1cQXOCc7MUgrC1+DMLJsiYFtt5DcnODNLSzSyS4+z7jZOcGaWSgBN7sGZWVa5B2dmmZS70dcJzswyKIBtURtr5TrBmVkqgWiskcXAneDMLLWm8BDVzDLI1+DMLMNEo6/BmVkW5Vb0rY0EVxtRmlnViBBbo66orRBJXSW9KOklSQsk/UtS3lfSdElvJD/75J1zhaTFkhZJGtNWrE5wZpZaEypqa8MW4KSIOBwYCYyVdDRwOTAjIoYBM5J9JA0HxgMjgLHAJEkFs6gTnJmlkptk6FDUVrCenI3JbqdkC2AcMCUpnwKcnnweB9wXEVsiYgmwGBhdqA0nODNLKTfJUMwG9JM0O287f4eapDpJ84A1wPSIeAEYGBGrAJKfA5LD9wGW552+IilrlScZzCyVlJMMayNiVKt1RTQCIyX1Bh6UdFiBuloa8xZ87N8JzsxSayzxjb4R8a6kp8ldW1stqT4iVkmqJ9e7g1yPbUjeaYOBlYXq9RDVzFIJxLboWNRWiKT+Sc8NSXsApwCvAdOACclhE4CHk8/TgPGSukjaHxgGvFioDffgzCyV5kmGEqgHpiQzoR2AqRHxO0nPAVMlnQcsA84GiIgFkqYCC4EG4MJkiNsqJzgzSyVQSYaoETEfOKKF8nXAya2cMxGYWGwbTnBmllqtPMngBGdmqUTgZ1HNLJtykwyFH8OqFk5wZpaaF7w0s0wK5AUvzSy73IMzs0zKvRfVCc7MMslvtjezjMq9NtCzqGaWQRHyENXMsss3+ppZJuXWg/M1ODPLJL820MwyKnebiHtwZpZBfhbVzDLNyyWZWSbllkvyENXMMsrX4Mwsk3KriXiIamYZlHtUywmuXdj6kfjumQexbWsHGhvguNPe4x8ve5spP9qb5x7fEwl699vG9366jL32buDt5Z35xucPZfABWwA49H9t4jv/vqLCv0X71X/QVi772TL6DGggmuCRu/fioTv6c8CIzVx8/Qo6d22isUHcfMVgFs3rVulwq4R7cABIGgv8DKgDbo+I68vZXiV06hL86Dd/Yo/uTTRsg0tPH8bnTnqfs761hgnffxuAh27vx9037b09kdXvt4VbnlxUybAt0dggJl87iMUvd2OP7o3c/NjrzJ3Zk6//cCV33ziQ2U/14nMnvc95P1zJ9886qNLhVo12/yRD8q7DnwOnknsj9SxJ0yJiYbnarAQJ9ujeBEDDNtG4TUjQvWfT9mM+2twB1cZ/D+3O+jWdWL+mEwCbN9WxfHFX+tVvIwK698y9crN7r0bWr+5UyTCrimdRc0YDiyPiTQBJ9wHjyL20NVMaG+GiMYewcmln/uZ/r+XQIz8E4JfX782Tv+lL916N/Oi/Fm8//u1lnfn2qQfTrWcTE/55FZ85alOlQrc8Awdv5cDDNvPa3G7cetU+/Ou9b/KNq1YhBZf87bBKh1dVamWIWs4o9wGW5+2vSMp2IOl8SbMlzX5nXcGXVFetujq45clF3DNnIYvmdWPpa10B+Orlb3PPnIWcdOYGpt3ZH4C+A7Zx96yFTJr+Ohdc8xbXf3s/Nn1QG/+xZFnXbo1ceftSbr1qEB9urOOvJ6zjF1cP4h9GDecX1+zDpTcub7uSdqL5nQzFbJVWzn9ZLf128amCiMkRMSoiRvXfqzYe/2hNjz0bOfyYjcx6qucO5SeesYFnH9kTgM5dgl59c4l82Gc3M2joVt56s8tuj9U+VtcxuPL2pfzhgT788dHeAJx69vrtf2czf7snB4/8sIIRVpcAGqJDUVullTOCFcCQvP3BwMoytlcR766rY+N7ucS8ZbOY+0xPhhy0hbfe7Lz9mOcf35MhB23Zfnxj0lFd9efOvLWkM3vvu3W3x23NgktvWM7yN7rywOT+20vXre7EZ4/JXToY+VcbWbnE/xPK1xQditoKkTRE0lOSXpW0QNJ3kvJrJL0laV6yfSnvnCskLZa0SNKYtuIs5zW4WcAwSfsDbwHjgb8vY3sVsX51J37ynX1pahJNTXD837zL0ae+z7VfH8qKP3WhQwcYsM9WLk5mUF9+vge/+vHe1HWEug7BxdevoFef2hyaZ8GI0Zs45ewNvLmwK5Om52a2f/lv9fz0ssF869qV1NUFW7d04KeXDa5wpFWkdMPPBuC7ETFXUk9gjqTpyXc3RcRP8g+WNJxcHhkBDAKelHRwRLT6D6hsCS4iGiRdBDxO7jaROyNiQbnaq5QDhn/EpOmvf6r8qtuXtnj8cae9x3GnvVfmqKxYC17swZhBh7f43UVjD97N0dSGUi14GRGrgFXJ5w8kvUoL1+nzjAPui4gtwBJJi8lNZj7X2gllHSRHxCMRcXBEHBgRE8vZlpntPikmGfo1TyIm2/kt1SdpKHAE8EJSdJGk+ZLulNQnKStq4jKfn2Qws1RSLni5NiJGFTpAUg/gfuCfIuJ9SbcA1yVNXQfcAHyNIicu8znBmVkqgWhoKs3gT1Incsntnoh4ACAiVud9fxvwu2Q39cRl5edxzazmNKGitkIkCbgDeDUibswrr8877AzgleTzNGC8pC7J5OUw4MVCbbgHZ2bpRMnWgzsW+ArwsqR5SdkPgHMkjcy1xFLgAoCIWCBpKrmnoRqACwvNoIITnJmlVKqXzkTEs7R8Xe2RAudMBIqesHSCM7PUquExrGI4wZlZKoFoLNEkQ7k5wZlZau1+PTgzy6Yo3SRD2TnBmVlq4QRnZtlUHWu9FcMJzsxScw/OzDIpAhqbnODMLKM8i2pmmRR4iGpmmeVJBjPLsCi4Clv1cIIzs9Q8RDWzTMrNovpZVDPLKA9RzSyzPEQ1s0wK5ARnZtlVIyNUJzgzSykg/KiWmWWVh6hmllk1P4sq6f9RYKgdEReXJSIzq2pZeRZ19m6LwsxqRwC1nuAiYkr+vqTuEbGp/CGZWbWrlSFqm89bSDpG0kLg1WT/cEmTyh6ZmVUpEU3FbZVWzANlPwXGAOsAIuIl4PgyxmRm1S6K3CqsqFnUiFgu7ZCNG8sTjplVvaidSYZienDLJf0lEJI6S/oeyXDVzNqpEvTgJA2R9JSkVyUtkPSdpLyvpOmS3kh+9sk75wpJiyUtkjSmrTCLSXDfBC4E9gHeAkYm+2bWbqnIraAG4LsR8RfA0cCFkoYDlwMzImIYMCPZJ/luPDACGAtMklRXqIE2h6gRsRY4t63jzKwdadr1KiJiFbAq+fyBpFfJdaTGASckh00Bngb+OSm/LyK2AEskLQZGA8+11kYxs6gHSPqtpHckrZH0sKQDdv7XMrOa1nwfXDEb9JM0O287v6UqJQ0FjgBeAAYmya85CQ5IDtsHWJ532oqkrFXFTDL8Gvg5cEayPx64FziqiHPNLINS3Ae3NiJGFTpAUg/gfuCfIuL9T0xo7nBoS6EUqruYa3CKiP+IiIZku7utSs0s40p0m4ikTuSS2z0R8UBSvFpSffJ9PbAmKV8BDMk7fTCwslD9rSa4ZCajL/CUpMslDZW0n6TvA79vO3Qzy6zih6itUq6rdgfwakTcmPfVNGBC8nkC8HBe+XhJXSTtDwwDXizURqEh6hxyObg5ygvyfz3guoLRm1lmqTRjuGOBrwAvS5qXlP0AuB6YKuk8YBlwNkBELJA0FVhIbgb2wogoeE9uoWdR99/l8M0se0JQgsewIuJZWr+X5ORWzpkITCy2jaKeZJB0GDAc6JrX0K+KbcTMMqZGrsK3meAkXU3unpThwCPAF4FnASc4s/aqRhJcMbOoZ5HrLr4dEV8FDge6lDUqM6tuGXrYfnNENElqkNSL3JStb/Q1a6+ysOBlntmSegO3kZtZ3UgbU7Nmlm0lmkUtu2KeRf128vFWSY8BvSJifnnDMrOqVusJTtKRhb6LiLnlCcnMql0WenA3FPgugJNKHAuvz+/GmEEjS12tldHGLx9d6RAshaYnni9NRbV+DS4iTtydgZhZjaiSGdJi+MXPZpaeE5yZZZVKsODl7uAEZ2bp1UgPrpgVfSXpHyRdlezvK2l0+UMzs2qkKH6rtGIe1ZoEHAOck+x/QG6FXzNrr0qwHtzuUMwQ9aiIOFLS/wBExAZJncscl5lVsyronRWjmAS3LXk1VwBI6k9J3qljZrWqGoafxSgmwf1f4EFggKSJ5FYX+WFZozKz6hUZmkWNiHskzSG3ZJKA0yPCb7Y3a8+y0oOTtC/wIfDb/LKIWFbOwMysimUlwZF7g1bzy2e6AvsDi4ARZYzLzKpYZq7BRcRn8veTVUYuaOVwM7OqkfpJhoiYK+lz5QjGzGpEVnpwki7N2+0AHAm8U7aIzKy6ZWkWFeiZ97mB3DW5+8sTjpnVhCz04JIbfHtExGW7KR4zq3IiA5MMkjpGREOhpcvNrJ2qkQRX6GH75jdnzZM0TdJXJJ3ZvO2O4MysCpVwNRFJd0paI+mVvLJrJL0laV6yfSnvuyskLZa0SNKYtuov5hpcX2AduXcwNN8PF8ADRZxrZllUukmGu4CbgV99ovymiPhJfoGk4cB4cvfgDgKelHRwRDS2VnmhBDcgmUF9hY8TW7Ma6aCaWTmU6hpcRMyUNLTIw8cB90XEFmCJpMXAaOC51k4oNEStA3okW8+8z82bmbVXUeQG/STNztvOL7KFiyTNT4awfZKyfYDlecesSMpaVagHtyoiri0yGDNrL9K9VWttRIxK2cItwHVJK9eRe4Xp19hxFJkfTasKJbjKL8dpZlWpnLeJRMTq7e1ItwG/S3ZXAEPyDh0MrCxUV6Eh6sk7G6CZZVzxQ9TUJNXn7Z5Bbh4AYBowXlIXSfsDw/j4bo8WFXrx8/qdC8/Msq5Uj2pJuhc4gdy1uhXA1cAJkkaSS5FLSRb3iIgFkqYCC8k9VXVhoRlU8GsDzSytEr7ZPiLOaaH4jgLHTwQmFlu/E5yZpSJq5wK9E5yZpVcjd8I6wZlZajX/sL2ZWauc4MwskzK24KWZ2Y7cgzOzrPI1ODPLLic4M8sq9+DMLJuCUi54WVZOcGaWSiZeOmNm1ionODPLKkVtZDgnODNLp4SriZSbE5yZpeZrcGaWWX5Uy8yyyz04M8ukIt9aXw2c4MwsPSc4M8si3+hrZpmmptrIcE5wZpaO74Nrn/oP2splP1tGnwENRBM8cvdePHRHf35w61IGH7gFgO69Gtn0fh3fPvWQCkfbfl1xztMcO/zPbNi4B1/59y8DcOLhf+K8sXPYb+AGvnHTmby2vD8Avbp9xMSvTufQfdfw6IuHcOP9f1XJ0KtGu79NRNKdwF8DayLisHK1U00aG8Tkawex+OVu7NG9kZsfe525M3vyr98cuv2Y869ayaYPOlQuSOORFw7m/mdGcOW5T20ve/Ptvvzgl1/gsi/P3OHYrQ113PbIKA6o38AB9X4X+nY10oMr57+0u4CxZay/6qxf04nFL3cDYPOmOpYv7kq/+m15RwTH/+27PPVQn8oEaAC89OYg3v+w6w5lf17dh2Vren/q2I+2dmL+knq2NtTtpuhqg6K4rdLK1oOLiJmShpar/mo3cPBWDjxsM6/N7ba97LCjNrHhnY6sXNKlgpGZ7aIAauRh+4qPlSSdL2m2pNnb2FLpcEqia7dGrrx9KbdeNYgPN378f/4TT3+Xpx/qXbnAzEpETcVtbdYj3SlpjaRX8sr6Spou6Y3kZ5+8766QtFjSIklj2qq/4gkuIiZHxKiIGNWJ2u/Z1HUMrrx9KX94oA9/fLT39vIOdcGxX3qP/57Wu9VzzWpB831wJRqi3sWnL2VdDsyIiGHAjGQfScOB8cCI5JxJkgpeO6h4gsuW4NIblrP8ja48MLn/Dt8cedwHLF/chbWrOlcoNrMSiSh+a7OqmAl8cvZmHDAl+TwFOD2v/L6I2BIRS4DFwOhC9fs2kRIaMXoTp5y9gTcXdmXS9EUA/PLf6pn1h158fpyHp9Ximn98kiMOXEXvHh/x4DV3c8ejo3j/wy5c8nd/pHePzfz4/Ed54629uPTW0wD4r6vuoXuXbXTs2Mhxn1nKJbecxtLV7XuiKMUEQj9Js/P2J0fE5DbOGRgRqwAiYpWkAUn5PsDzecetSMpaVc7bRO4FTiD3C64Aro6IO8rVXjVY8GIPxgw6vMXvbrhk390cjbXmml+d0mL5zJf3b7H8rGvPLWc4tan4BLc2IkaVqFWljaScs6jnlKtuM6usMt8CslpSfdJ7qwfWJOUrgCF5xw0GVhaqyNfgzCydABqjuG3nTAMmJJ8nAA/nlY+X1EXS/sAw4MVCFfkanJmlVqoeXEuXsoDrgamSzgOWAWcDRMQCSVOBhUADcGFENBaq3wnOzNIr0Y2+BS5lndzK8ROBicXW7wRnZqlVw2NYxXCCM7N0vFySmWWVAO38BMJu5QRnZqn5zfZmlk0eoppZdhX3nGk1cIIzs9Q8i2pm2eUenJllUngW1cyyrDbymxOcmaXn20TMLLuc4MwskwJo7y9+NrNsEuEhqpllWFNtdOGc4MwsHQ9RzSzLPEQ1s+xygjOzbPLD9maWVc1v1aoBTnBmlpqvwZlZdjnBmVkmBdDkBGdmmeRJBjPLMic4M8ukABpr41EGJzgzSykgSpPgJC0FPgAagYaIGCWpL/CfwFBgKfDliNiwM/V3KEmUZta+RBS3FefEiBgZEaOS/cuBGRExDJiR7O8UJzgzS6d5FrWYbeeMA6Ykn6cAp+9sRU5wZpZe6XpwATwhaY6k85OygRGxKtdMrAIG7GyYvgZnZukVP/zsJ2l23v7kiJict39sRKyUNACYLum1ksWIE5yZpRUBjY3FHr0279paC1XFyuTnGkkPAqOB1ZLqI2KVpHpgzc6G6iGqmaVXgiGqpO6SejZ/Br4AvAJMAyYkh00AHt7ZMN2DM7P0SnOj70DgQUmQy0W/jojHJM0Cpko6D1gGnL2zDTjBmVlKuzRD+nEtEW8Ch7dQvg44eZcbwAnOzNIKiBLd6FtuTnBmlp4f1TKzTIrwawPNLMO8moiZZVW4B2dm2eQFL80sq7xkuZllVQBR/KNaFeUEZ2bpROkWvCw3JzgzSy08RDWzzKqRHpyiimZDJL0D/LnScZRBP2BtpYOwVLL6d7ZfRPTflQokPUbuz6cYayNi7K60tyuqKsFllaTZhdbEsurjv7Ns8HpwZpZZTnBmlllOcLvH5LYPsSrjv7MM8DU4M8ss9+DMLLOc4Mwss5zgykjSWEmLJC2WdHml47G2SbpT0hpJr1Q6Ftt1TnBlIqkO+DnwRWA4cI6k4ZWNyopwF1CxG1OttJzgymc0sDgi3oyIrcB9wLgKx2RtiIiZwPpKx2Gl4QRXPvsAy/P2VyRlZrabOMGVj1oo8z05ZruRE1z5rACG5O0PBlZWKBazdskJrnxmAcMk7S+pMzAemFbhmMzaFSe4MomIBuAi4HHgVWBqRCyobFTWFkn3As8Bh0haIem8SsdkO8+PaplZZrkHZ2aZ5QRnZpnlBGdmmeUEZ2aZ5QRnZpnlBFdDJDVKmifpFUm/kdRtF+q6S9JZyefbCy0EIOkESX+5E20slfSpty+1Vv6JYzambOsaSd9LG6NlmxNcbdkcESMj4jBgK/DN/C+TFUxSi4ivR8TCAoecAKROcGaV5gRXu54BDkp6V09J+jXwsqQ6ST+WNEvSfEkXACjnZkkLJf0eGNBckaSnJY1KPo+VNFfSS5JmSBpKLpFekvQej5PUX9L9SRuzJB2bnLuXpCck/Y+kX9Dy87g7kPSQpDmSFkg6/xPf3ZDEMkNS/6TsQEmPJec8I+nQkvxpWib5zfY1SFJHcuvMPZYUjQYOi4glSZJ4LyI+J6kL8EdJTwBHAIcAnwEGAguBOz9Rb3/gNuD4pK6+EbFe0q3Axoj4SXLcr4GbIuJZSfuSe1rjL4CrgWcj4lpJpwE7JKxWfC1pYw9glqT7I2Id0B2YGxHflXRVUvdF5F4G882IeEPSUcAk4KSd+GO0dsAJrrbsIWle8vkZ4A5yQ8cXI2JJUv4F4LPN19eAPYFhwPHAvRHRCKyU9IcW6j8amNlcV0S0ti7aKcBwaXsHrZeknkkbZybn/l7ShiJ+p4slnZF8HpLEug5oAv4zKb8beEBSj+T3/U1e212KaMPaKSe42rI5IkbmFyT/0DflFwH/JyIe/8RxX6Lt5ZpUxDGQu7RxTERsbiGWop/9k3QCuWR5TER8KOlpoGsrh0fS7ruf/DMwa42vwWXP48C3JHUCkHSwpO7ATGB8co2uHjixhXOfAz4vaf/k3L5J+QdAz7zjniA3XCQ5bmTycSZwblL2RaBPG7HuCWxIktuh5HqQzToAzb3Qvyc39H0fWCLp7KQNSTq8jTasHXOCy57byV1fm5u8OOUX5HrqDwJvAC8DtwD//ckTI+IdctfNHpD0Eh8PEX8LnNE8yQBcDIxKJjEW8vFs7r8Ax0uaS26ovKyNWB8DOkqaD1wHPJ/33SZghKQ55K6xXZuUnwucl8S3AC8DbwV4NREzyyz34Mwss5zgzCyznODMLLOc4Mwss5zgzCyznODMLLOc4Mwss/4/CgDWTh54qtgAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "print(\"Confusion Matrix:\")\n", + "cm = confusion_matrix(y_test, y_pred)\n", + "ConfusionMatrixDisplay(confusion_matrix=cm).plot();\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "05ea49ab", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# #optional extra\n", + "# # Checking to see how the data was scaled\n", + "\n", + "# # Visualize histograms before scaling for 'Fe' (example)\n", + "# plt.figure(figsize=(12, 6))\n", + "# plt.subplot(1, 2, 1)\n", + "# plt.hist(X['Fe'], bins=50, color='blue', alpha=0.7)\n", + "# plt.title('Before Scaling: Fe')\n", + "\n", + "# # Visualize histograms after scaling for 'Fe'\n", + "# plt.subplot(1, 2, 2)\n", + "# plt.hist(X_scaled[:, X.columns.get_loc('Fe')], bins=50, color='green', alpha=0.7)\n", + "# plt.title('After Scaling: Fe')\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "705d2ffd", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# #optional extra\n", + "# Feature Importance\n", + "# feature_importances = best_rf.feature_importances_\n", + "# feature_importance_df = pd.DataFrame({'Element': elements, 'Importance': feature_importances})\n", + "# feature_importance_df = feature_importance_df.sort_values(by='Importance', ascending=False)\n", + "\n", + "# plt.figure(figsize=(10, 6))\n", + "# sns.barplot(x='Importance', y='Element', data=feature_importance_df, color='steelblue')\n", + "# plt.title('Feature Importance for Predicting Class A vs Class B')\n", + "# plt.tight_layout()\n", + "# plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "989eba46", + "metadata": {}, + "source": [ + "### Aim 2: Predict labels " + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "41f2f5e1", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Predictions on Unlabeled Data:\n", + " Unique_ID holeid Predicted_Class\n", + "4004 A04920 SOLVE225W1 A\n", + "4005 A05729 SOLVE225W1 A\n", + "4006 A05270 SOLVE225W1 A\n", + "4007 A05634 SOLVE225W1 A\n", + "4008 A04689 SOLVE225W1 A\n", + "... ... ... ...\n", + "4072 A06469 SOLVE227 B\n", + "4073 A06256 SOLVE227 B\n", + "4074 A05766 SOLVE227 B\n", + "4075 A08273 SOLVE227 B\n", + "4076 A06155 SOLVE227 B\n", + "\n", + "[70 rows x 3 columns]\n" + ] + } + ], + "source": [ + "unlabeled_data_copy = unlabeled_data.copy()\n", + "\n", + "# Predict unlabeled data \n", + "X_new = unlabeled_data[elements]\n", + "X_new_scaled = scaler.transform(X_new)\n", + "unlabeled_data.loc[:, 'Predicted_Class'] = le.inverse_transform(best_rf.predict(X_new_scaled))\n", + "\n", + "print(\"Predictions on Unlabeled Data:\")\n", + "print(unlabeled_data[['Unique_ID', 'holeid', 'Predicted_Class']].head(70))\n", + "unlabeled_data.to_csv('./data/predictions_on_unlabeled_data.csv', index=False)\n" + ] + }, + { + "cell_type": "markdown", + "id": "23730f2b", + "metadata": {}, + "source": [ + "Results available on: predictions_on_unlabeled_data.csv" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "182b0d3e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Predictions for '?' class:\n", + "Predicted_Class\n", + "B 390\n", + "A 295\n", + "Name: count, dtype: int64\n" + ] + } + ], + "source": [ + "# Filter the rows where the original 'Class' column is '?'\n", + "unlabeled_data_question = unlabeled_data[unlabeled_data['Class'] == '?']\n", + "\n", + "# Count how many were predicted to be 'A' and 'B'\n", + "class_counts_from_question = unlabeled_data_question['Predicted_Class'].value_counts()\n", + "\n", + "# Print out the prediction counts for 'A' and 'B' from '?'\n", + "print(f\"Predictions for '?' class:\")\n", + "print(class_counts_from_question)" + ] + }, + { + "cell_type": "markdown", + "id": "0eb795e5", + "metadata": {}, + "source": [ + "### Observations" + ] + }, + { + "cell_type": "markdown", + "id": "e8bcc086", + "metadata": {}, + "source": [ + "- The model performs well overall, with high accuracy and strong performance for Class A.\n", + "- Class B predictions are less accurate.\n", + "- The imbalance in the dataset (more Class A samples) is probably why the model performs better for Class A \n", + "- Pb has the highest feature importance (makes sense as this is a lead deposit)" + ] + }, + { + "cell_type": "markdown", + "id": "e8dc4640", + "metadata": {}, + "source": [ + "### Improvements: \n", + "- Handle missing values more effectively (instead of dropping).\n", + "- Perform additional model validation (e.g., cross-validation, ROC curves).\n", + "- Experiment with other models (e.g., SVM, Gradient Boosting) for comparison." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}