From a95cf908b70fbdfe33da23807e4813511dd8af75 Mon Sep 17 00:00:00 2001 From: Umberto Lupo Date: Tue, 14 May 2024 11:51:03 +0200 Subject: [PATCH] Fix #1 - Multiply by a boolean to add nothing instead of an empty slice - Use dataclass.replace to properly replace fields in reshape step --- diffpass/base.py | 24 ++++++++++++------------ nbs/base.ipynb | 24 ++++++++++++------------ 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/diffpass/base.py b/diffpass/base.py index baa14d6..2bf0500 100644 --- a/diffpass/base.py +++ b/diffpass/base.py @@ -8,7 +8,7 @@ # Stdlib imports from copy import deepcopy from typing import Optional, Any, Sequence, Union -from dataclasses import fields, dataclass +from dataclasses import fields, dataclass, replace # Progress bars from tqdm import tqdm @@ -621,6 +621,7 @@ def fit_bootstrap( break # Reshape results according to number of iterations performed + reshaped_fields = {} for field_name in available_fields: results_this_field = getattr(results, field_name) n_optimized_results_this_field = ( @@ -628,23 +629,22 @@ def fit_bootstrap( if can_optimize else field_to_length_so_far[field_name] ) + n_unoptimized_results_this_field = ( + len(results_this_field) - n_optimized_results_this_field + ) assert not n_optimized_results_this_field % n_iters_with_optimization n_in_each_optimized_iter = ( n_optimized_results_this_field // n_iters_with_optimization ) - setattr( - results, - field_name, - [ - results_this_field[ - j - * n_in_each_optimized_iter : (j + 1) - * n_in_each_optimized_iter - ] - for j in range(n_iters_with_optimization) + reshaped_fields[field_name] = [ + results_this_field[ + j * n_in_each_optimized_iter : (j + 1) * n_in_each_optimized_iter ] - + [results_this_field[n_optimized_results_this_field:]], + for j in range(n_iters_with_optimization) + ] + [results_this_field[n_optimized_results_this_field:]] * bool( + n_unoptimized_results_this_field ) + results = replace(results, **reshaped_fields) return results diff --git a/nbs/base.ipynb b/nbs/base.ipynb index 32ccc11..a6f30e2 100644 --- a/nbs/base.ipynb +++ b/nbs/base.ipynb @@ -52,7 +52,7 @@ "# Stdlib imports\n", "from copy import deepcopy\n", "from typing import Optional, Any, Sequence, Union\n", - "from dataclasses import fields, dataclass\n", + "from dataclasses import fields, dataclass, replace\n", "\n", "# Progress bars\n", "from tqdm import tqdm\n", @@ -689,6 +689,7 @@ " break\n", "\n", " # Reshape results according to number of iterations performed\n", + " reshaped_fields = {}\n", " for field_name in available_fields:\n", " results_this_field = getattr(results, field_name)\n", " n_optimized_results_this_field = (\n", @@ -696,24 +697,23 @@ " if can_optimize\n", " else field_to_length_so_far[field_name]\n", " )\n", + " n_unoptimized_results_this_field = (\n", + " len(results_this_field) - n_optimized_results_this_field\n", + " )\n", "\n", " assert not n_optimized_results_this_field % n_iters_with_optimization\n", " n_in_each_optimized_iter = (\n", " n_optimized_results_this_field // n_iters_with_optimization\n", " )\n", - " setattr(\n", - " results,\n", - " field_name,\n", - " [\n", - " results_this_field[\n", - " j\n", - " * n_in_each_optimized_iter : (j + 1)\n", - " * n_in_each_optimized_iter\n", - " ]\n", - " for j in range(n_iters_with_optimization)\n", + " reshaped_fields[field_name] = [\n", + " results_this_field[\n", + " j * n_in_each_optimized_iter : (j + 1) * n_in_each_optimized_iter\n", " ]\n", - " + [results_this_field[n_optimized_results_this_field:]],\n", + " for j in range(n_iters_with_optimization)\n", + " ] + [results_this_field[n_optimized_results_this_field:]] * bool(\n", + " n_unoptimized_results_this_field\n", " )\n", + " results = replace(results, **reshaped_fields)\n", "\n", " return results" ]