diff --git a/src/squlearn/qnn/base_qnn.py b/src/squlearn/qnn/base_qnn.py index fb763338..6f68ae6f 100644 --- a/src/squlearn/qnn/base_qnn.py +++ b/src/squlearn/qnn/base_qnn.py @@ -193,37 +193,7 @@ def fit(self, X, y, weights: np.ndarray = None) -> None: Labels weights: Weights for each data point """ - self.encoding_circuit._check_feature_consistency(X) - num_features = extract_num_features(X) - self._is_lowlevel_qnn_initialized = False - self._initialize_lowlevel_qnn(num_features) - - if self.param_ini is None or len(self.param_ini) != self._qnn.num_parameters: - self._param = self.encoding_circuit.generate_initial_parameters( - seed=self.parameter_seed, num_features=num_features - ) - else: - self._param = self.param_ini.copy() - - if ( - self.param_op_ini is None - or len(self.param_op_ini) != self._qnn.num_parameters_observable - ): - if isinstance(self.operator, list): - self._param_op = np.concatenate( - [ - operator.generate_initial_parameters(seed=self.parameter_seed + i + 1) - for i, operator in enumerate(self.operator) - ] - ) - else: - self._param_op = self.operator.generate_initial_parameters( - seed=self.parameter_seed + 1 - ) - else: - self._param_op = self.param_op_ini.copy() - self._is_fitted = False self._fit(X, y, weights) @@ -387,6 +357,35 @@ def _initialize_lowlevel_qnn(self, num_features: int | None = None) -> None: caching=self.caching, primitive=self.primitive, ) + + if self._is_fitted: + self._is_lowlevel_qnn_initialized = True + return + + if self.param_ini is None or len(self.param_ini) != self.encoding_circuit.num_parameters: + self._param = self.encoding_circuit.generate_initial_parameters( + seed=self.parameter_seed, num_features=num_features + ) + else: + self._param = self.param_ini.copy() + + if ( + self.param_op_ini is None + or len(self.param_op_ini) != self._qnn.num_parameters_observable + ): + if isinstance(self.operator, list): + self._param_op = np.concatenate( + [ + operator.generate_initial_parameters(seed=self.parameter_seed + i + 1) + for i, operator in enumerate(self.operator) + ] + ) + else: + self._param_op = self.operator.generate_initial_parameters( + seed=self.parameter_seed + 1 + ) + else: + self._param_op = self.param_op_ini.copy() self._is_lowlevel_qnn_initialized = True def _validate_input(self, X, y, incremental, reset):