From a82c15aa878ac9d2dbb8d223a4dd70e315d04e0d Mon Sep 17 00:00:00 2001 From: Moritz Date: Thu, 6 Nov 2025 11:46:51 +0100 Subject: [PATCH 1/2] fix parameter initialization --- src/squlearn/qnn/base_qnn.py | 56 +++++++++++++++++------------------- 1 file changed, 26 insertions(+), 30 deletions(-) diff --git a/src/squlearn/qnn/base_qnn.py b/src/squlearn/qnn/base_qnn.py index fb763338..f3f4203b 100644 --- a/src/squlearn/qnn/base_qnn.py +++ b/src/squlearn/qnn/base_qnn.py @@ -193,37 +193,7 @@ def fit(self, X, y, weights: np.ndarray = None) -> None: Labels weights: Weights for each data point """ - self.encoding_circuit._check_feature_consistency(X) - num_features = extract_num_features(X) - self._is_lowlevel_qnn_initialized = False - self._initialize_lowlevel_qnn(num_features) - - if self.param_ini is None or len(self.param_ini) != self._qnn.num_parameters: - self._param = self.encoding_circuit.generate_initial_parameters( - seed=self.parameter_seed, num_features=num_features - ) - else: - self._param = self.param_ini.copy() - - if ( - self.param_op_ini is None - or len(self.param_op_ini) != self._qnn.num_parameters_observable - ): - if isinstance(self.operator, list): - self._param_op = np.concatenate( - [ - operator.generate_initial_parameters(seed=self.parameter_seed + i + 1) - for i, operator in enumerate(self.operator) - ] - ) - else: - self._param_op = self.operator.generate_initial_parameters( - seed=self.parameter_seed + 1 - ) - else: - self._param_op = self.param_op_ini.copy() - self._is_fitted = False self._fit(X, y, weights) @@ -387,6 +357,32 @@ def _initialize_lowlevel_qnn(self, num_features: int | None = None) -> None: caching=self.caching, primitive=self.primitive, ) + + + if self.param_ini is None or len(self.param_ini) != self.encoding_circuit.num_parameters: + self._param = self.encoding_circuit.generate_initial_parameters( + seed=self.parameter_seed, num_features=num_features + ) + else: + self._param = self.param_ini.copy() + + if ( + self.param_op_ini is None + or len(self.param_op_ini) != self._qnn.num_parameters_observable + ): + if isinstance(self.operator, list): + self._param_op = np.concatenate( + [ + operator.generate_initial_parameters(seed=self.parameter_seed + i + 1) + for i, operator in enumerate(self.operator) + ] + ) + else: + self._param_op = self.operator.generate_initial_parameters( + seed=self.parameter_seed + 1 + ) + else: + self._param_op = self.param_op_ini.copy() self._is_lowlevel_qnn_initialized = True def _validate_input(self, X, y, incremental, reset): From 647913990f99a70e6ae795e614837bef6446356b Mon Sep 17 00:00:00 2001 From: Moritz Date: Thu, 6 Nov 2025 17:05:05 +0100 Subject: [PATCH 2/2] return if fitted --- src/squlearn/qnn/base_qnn.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/squlearn/qnn/base_qnn.py b/src/squlearn/qnn/base_qnn.py index f3f4203b..6f68ae6f 100644 --- a/src/squlearn/qnn/base_qnn.py +++ b/src/squlearn/qnn/base_qnn.py @@ -357,7 +357,10 @@ def _initialize_lowlevel_qnn(self, num_features: int | None = None) -> None: caching=self.caching, primitive=self.primitive, ) - + + if self._is_fitted: + self._is_lowlevel_qnn_initialized = True + return if self.param_ini is None or len(self.param_ini) != self.encoding_circuit.num_parameters: self._param = self.encoding_circuit.generate_initial_parameters(