From 0edefed5edec0fd6e851fca1b41909556b4d0ff3 Mon Sep 17 00:00:00 2001 From: juliensiems Date: Wed, 17 Apr 2024 12:46:02 +0100 Subject: [PATCH 1/3] Fix shape function plotting. --- mothernet/evaluation/plot_shape_function.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/mothernet/evaluation/plot_shape_function.py b/mothernet/evaluation/plot_shape_function.py index 6e306b0e..f486eed7 100644 --- a/mothernet/evaluation/plot_shape_function.py +++ b/mothernet/evaluation/plot_shape_function.py @@ -9,8 +9,13 @@ def plot_shape_function(bin_edges: np.ndarray, w: np.ndarray): sharex=True, sharey=True) for class_idx in range(num_classes): for feature_idx in range(num_features): - axs[class_idx][feature_idx].plot( - bin_edges[feature_idx], w[feature_idx][0:-1][:, class_idx] - w[feature_idx][0:-1][:, class_idx].mean()) + # Add one more edge to the beginning + bin_edge = np.concatenate( + [[bin_edges[feature_idx][0] - (bin_edges[feature_idx][1] - bin_edges[feature_idx][0])], + bin_edges[feature_idx]]) + # Compute the midpoints of the bin edges + axs[class_idx][feature_idx].step( + bin_edge, w[feature_idx][:, class_idx] - w[feature_idx][:, class_idx].mean(), where='pre') if class_idx == 0: axs[class_idx][feature_idx].set_title(f'Feature {feature_idx}') if feature_idx == 0: From 038d16086d6905d5b71151655418de574f6e6084 Mon Sep 17 00:00:00 2001 From: juliensiems Date: Wed, 17 Apr 2024 12:50:21 +0100 Subject: [PATCH 2/3] Fix merge conflict. --- mothernet/evaluation/plot_shape_function.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mothernet/evaluation/plot_shape_function.py b/mothernet/evaluation/plot_shape_function.py index 04ba004c..80ff8097 100644 --- a/mothernet/evaluation/plot_shape_function.py +++ b/mothernet/evaluation/plot_shape_function.py @@ -12,7 +12,7 @@ def plot_shape_function(bin_edges: np.ndarray, w: np.ndarray, feature_names=None class_range = [0] columns = min(int(np.sqrt(num_features)), 6) rows = int(np.ceil(num_features / columns)) - fig, axs = plt.subplots(rows, columns, figsize=(2*rows, 2*columns), + fig, axs = plt.subplots(columns, rows, figsize=(2*rows, 2*columns), sharey=True) for class_idx in class_range: for feature_idx in range(num_features): @@ -26,7 +26,7 @@ def plot_shape_function(bin_edges: np.ndarray, w: np.ndarray, feature_names=None bin_edges[feature_idx]]) # Compute the midpoints of the bin edges ax.step( - bin_edge, w[feature_idx][:, class_idx] - w[feature_idx][:, class_idx].mean(), where='pre')>>>>>>> develop + bin_edge, w[feature_idx][:, class_idx] - w[feature_idx][:, class_idx].mean(), where='pre') if class_idx == 0: if feature_names is None: ax.set_title(f'Feature {feature_idx}') From a0987897f4f3bfa5dedec5ff30e09f9b55311a5b Mon Sep 17 00:00:00 2001 From: juliensiems Date: Wed, 17 Apr 2024 19:45:00 +0100 Subject: [PATCH 3/3] Small fix to the shape function plotting. LGTM --- mothernet/evaluation/plot_shape_function.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/mothernet/evaluation/plot_shape_function.py b/mothernet/evaluation/plot_shape_function.py index a0b97089..cfe13f11 100644 --- a/mothernet/evaluation/plot_shape_function.py +++ b/mothernet/evaluation/plot_shape_function.py @@ -15,23 +15,20 @@ def plot_shape_function(bin_edges: np.ndarray, w: np.ndarray, feature_names=None fig, axs = plt.subplots(rows, columns, figsize=(4*columns, 2*rows), sharey=True) feature_range = feature_subset if feature_subset is not None else range(num_features) - for class_idx in class_range: + for col_idx, class_idx in enumerate(class_range): for ax_idx, feature_idx in enumerate(feature_range): - weights_normalized = w[feature_idx][0:-1][:, class_idx] - w[feature_idx][0:-1].mean(axis=-1) if num_classes > 2: ax = axs[class_idx][ax_idx] else: ax = axs.ravel()[ax_idx] - ax.step(bin_edges[feature_idx], weights_normalized) if class_idx == 0 or num_classes == 2: ax = axs.ravel()[feature_idx] bin_edge = np.concatenate( [[bin_edges[feature_idx][0] - (bin_edges[feature_idx][1] - bin_edges[feature_idx][0])], bin_edges[feature_idx]]) - # Compute the midpoints of the bin edges - ax.step( - bin_edge, w[feature_idx][:, class_idx] - w[feature_idx][:, class_idx].mean(), where='pre') - if class_idx == 0: + weights_normalized = w[feature_idx][:, class_idx] - w[feature_idx].mean(axis=-1) + ax.step(bin_edge, weights_normalized, where='pre') + if col_idx == 0: if feature_names is None: ax.set_title(f'Feature {feature_idx}') else: