diff --git a/mothernet/evaluation/plot_shape_function.py b/mothernet/evaluation/plot_shape_function.py index 80a54cb0..cfe13f11 100644 --- a/mothernet/evaluation/plot_shape_function.py +++ b/mothernet/evaluation/plot_shape_function.py @@ -15,15 +15,20 @@ def plot_shape_function(bin_edges: np.ndarray, w: np.ndarray, feature_names=None fig, axs = plt.subplots(rows, columns, figsize=(4*columns, 2*rows), sharey=True) feature_range = feature_subset if feature_subset is not None else range(num_features) - for class_idx in class_range: + for col_idx, class_idx in enumerate(class_range): for ax_idx, feature_idx in enumerate(feature_range): - weights_normalized = w[feature_idx][0:-1][:, class_idx] - w[feature_idx][0:-1].mean(axis=-1) if num_classes > 2: ax = axs[class_idx][ax_idx] else: ax = axs.ravel()[ax_idx] - ax.step(bin_edges[feature_idx], weights_normalized) if class_idx == 0 or num_classes == 2: + ax = axs.ravel()[feature_idx] + bin_edge = np.concatenate( + [[bin_edges[feature_idx][0] - (bin_edges[feature_idx][1] - bin_edges[feature_idx][0])], + bin_edges[feature_idx]]) + weights_normalized = w[feature_idx][:, class_idx] - w[feature_idx].mean(axis=-1) + ax.step(bin_edge, weights_normalized, where='pre') + if col_idx == 0: if feature_names is None: ax.set_title(f'Feature {feature_idx}') else: