diff --git a/jetstream/analysis.py b/jetstream/analysis.py
index e453f0cb..922520fc 100644
--- a/jetstream/analysis.py
+++ b/jetstream/analysis.py
@@ -12,9 +12,9 @@
 import dask
 import dask.delayed
 import pytz
-from dask.distributed import Client, LocalCluster
+from dask.distributed import Client, LocalCluster, as_completed
 from dask.graph_manipulation import bind
-from google.api_core.exceptions import BadRequest
+from google.api_core.exceptions import BadRequest, GoogleAPICallError
 from google.cloud import bigquery
 from google.cloud.bigquery.job import WriteDisposition
 from google.cloud.exceptions import Conflict
@@ -59,6 +59,24 @@
 _dask_cluster = None
 
 
+@dask.delayed
+def _successful_metrics_dict(
+    metric_table_results: list[str],
+    all_metrics_by_ds: dict[str, list[str]],
+) -> dict[str, list[str]]:
+    """Return only the data-source entries whose metric table was successfully computed."""
+    try:
+        return {
+            name: metrics
+            for (name, metrics), result in zip(
+                all_metrics_by_ds.items(), metric_table_results, strict=True
+            )
+            if result
+        }
+    except Exception:
+        return {}
+
+
 @attr.s(auto_attribs=True)
 class Analysis:
     """Wrapper for analysing experiments."""
@@ -244,6 +262,12 @@ def publish_view(
         metrics_dict: dict[str, list[str]] | None = None,
     ):
         assert self.config.experiment.normandy_slug is not None
+        if metrics_dict is not None and not metrics_dict:
+            logger.warning(
+                f"all metrics queries failed for {window_period.value} {analysis_basis}; "
+                "skipping publish view..."
+            )
+            return
         normalized_slug = bq_normalize_name(self.config.experiment.normandy_slug)
         view_name = "_".join([normalized_slug, window_period.table_suffix])
         wildcard_expr = normalized_slug
@@ -348,7 +372,17 @@ def publish_view(
         )
 
         logger.debug(f"View ({view_name}) SQL: {sql}")
-        self.bigquery.execute(sql)
+        try:
+            self.bigquery.execute(sql)
+        except GoogleAPICallError as e:
+            logger.exception(
+                str(e),
+                extra={
+                    "experiment": self.config.experiment.normandy_slug,
+                    "analysis_basis": analysis_basis,
+                },
+            )
+            raise
 
     @dask.delayed
     def calculate_metrics(
@@ -428,9 +462,21 @@ def calculate_metrics(
             )
 
             logger.info(metrics_sql)
-            results = self.bigquery.execute(
-                metrics_sql, res_table_name, experiment_slug=self.config.experiment.normandy_slug
-            )
+            try:
+                results = self.bigquery.execute(
+                    metrics_sql,
+                    res_table_name,
+                    experiment_slug=self.config.experiment.normandy_slug,
+                )
+            except GoogleAPICallError as e:
+                logger.exception(
+                    str(e),
+                    extra={
+                        "experiment": self.config.experiment.normandy_slug,
+                        "analysis_basis": analysis_basis,
+                    },
+                )
+                raise
             logger.info(
                 f"Metric query cost: {results.slot_millis * COST_PER_SLOT_MS}",
             )
@@ -451,7 +497,7 @@ def calculate_metric_for_ds(
     ) -> str:
         """
         Calculate individual metric for a specific experiment.
-        Returns the BigQuery table results are written to.
+        Returns the BigQuery table results are written to, or empty str on failure.
         """
         window = len(time_limits.analysis_windows)
         last_analysis_window = time_limits.analysis_windows[-1]
@@ -517,18 +563,18 @@ def calculate_metric_for_ds(
                     f"{results.slot_millis * COST_PER_SLOT_MS}"
                 )
                 self._write_sql_output(res_table_name, metrics_sql)
-            except ValueError as e:
+            except (ValueError, GoogleAPICallError) as e:
                 for metric in metrics:
                     # log an exception for each failed metric because this is how we track errors
                     logger.exception(
                         str(e),
-                        exc_info=e,
                         extra={
                             "experiment": self.config.experiment.normandy_slug,
                             "metric": metric.name,
                             "analysis_basis": analysis_basis,
                         },
                     )
+                return ""
 
         return res_table_name
 
@@ -545,58 +591,88 @@ def calculate_statistics(
         """
         Run statistics on metric.
         """
-        return (
-            Summary.from_config(metric, analysis_length_dates, period)
-            .run(
-                segment_data,
-                self.config.experiment,
-                analysis_basis,
-                segment,
+        if segment_data is None:
+            return StatisticResultCollection.model_validate([])
+        try:
+            return (
+                Summary.from_config(metric, analysis_length_dates, period)
+                .run(
+                    segment_data,
+                    self.config.experiment,
+                    analysis_basis,
+                    segment,
+                )
+                .set_segment(segment)
+                .set_analysis_basis(analysis_basis)
             )
-            .set_segment(segment)
-            .set_analysis_basis(analysis_basis)
-        )
+        except Exception as e:
+            logger.exception(
+                str(e),
+                extra={
+                    "experiment": self.config.experiment.normandy_slug,
+                    "metric": metric.metric.name,
+                    "statistic": metric.statistic.name(),
+                    "analysis_basis": analysis_basis,
+                    "segment": segment,
+                },
+            )
+            return StatisticResultCollection.model_validate([])
 
     @dask.delayed
     def counts(
         self, segment_data: DataFrame, segment: str, analysis_basis: AnalysisBasis
     ) -> StatisticResultCollection:
         """Count and missing count statistics."""
-        metric = "identity"
-        counts = (
-            Count()
-            .transform(
-                segment_data,
-                metric,
-                "*",
-                self.config.experiment.normandy_slug,
-                analysis_basis,
-                segment,
+        if segment_data is None:
+            return StatisticResultCollection.model_validate([])
+        try:
+            metric = "identity"
+            counts = (
+                Count()
+                .transform(
+                    segment_data,
+                    metric,
+                    "*",
+                    self.config.experiment.normandy_slug,
+                    analysis_basis,
+                    segment,
+                )
+                .set_segment(segment)
+                .set_analysis_basis(analysis_basis)
             )
-            .set_segment(segment)
-            .set_analysis_basis(analysis_basis)
-        )
 
-        other_counts = [
-            StatisticResult(
-                metric=metric,
-                statistic="count",
-                parameter=None,
-                branch=b.slug,
-                comparison=None,
-                comparison_to_branch=None,
-                ci_width=None,
-                point=0,
-                lower=None,
-                upper=None,
-                segment=segment,
-                analysis_basis=analysis_basis,
-            )
-            for b in self.config.experiment.branches
-            if b.slug not in {c.branch for c in counts}
-        ]
+            other_counts = [
+                StatisticResult(
+                    metric=metric,
+                    statistic="count",
+                    parameter=None,
+                    branch=b.slug,
+                    comparison=None,
+                    comparison_to_branch=None,
+                    ci_width=None,
+                    point=0,
+                    lower=None,
+                    upper=None,
+                    segment=segment,
+                    analysis_basis=analysis_basis,
+                )
+                for b in self.config.experiment.branches
+                if b.slug not in {c.branch for c in counts}
+            ]
 
-        return StatisticResultCollection.model_validate(counts.root + other_counts)
+            return StatisticResultCollection.model_validate(counts.root + other_counts)
+        except Exception as e:
+            logger.exception(
+                str(e),
+                extra={
+                    "experiment": self.config.experiment.normandy_slug,
+                    "metric": "identity",
+                    "statistic": "count",
+                    "analysis_basis": analysis_basis,
+                    "segment": segment,
+                },
+            )
+            return StatisticResultCollection.model_validate([])
 
     @dask.delayed
     def subset_metric_table(
@@ -609,14 +685,27 @@ def subset_metric_table(
         discrete_metrics: bool = False,
-    ) -> DataFrame:
+    ) -> DataFrame | None:
         """Pulls the metric data for this segment/analysis basis"""
-
+        if not metric_table_name:
+            return None
         query = self._create_subset_metric_table_query(
             metric_table_name, segment, summary, analysis_basis, period, discrete_metrics
         )
 
         logger.debug(f"subset_metric_table: {metric_table_name}, {summary.metric.name}\n{query}")
 
-        results: DataFrame = self.bigquery.execute(query).to_dataframe()
+        try:
+            results: DataFrame = self.bigquery.execute(query).to_dataframe()
+        except GoogleAPICallError as e:
+            logger.exception(
+                str(e),
+                extra={
+                    "experiment": self.config.experiment.normandy_slug,
+                    "metric": summary.metric.name,
+                    "analysis_basis": analysis_basis,
+                    "segment": segment,
+                },
+            )
+            return None
 
         return results
 
@@ -1096,7 +1185,21 @@ def save_statistics(
             # logger.error(f"Expected schema: {StatisticResult.bq_schema}")
             # logger.error(f"Data received: {segment_results}")
             ve = ValueError(error_msg)
+            logger.exception(
+                str(ve),
+                extra={
+                    "experiment": self.config.experiment.normandy_slug,
+                },
+            )
             raise ve from e
+        except Exception as e:
+            logger.exception(
+                str(e),
+                extra={
+                    "experiment": self.config.experiment.normandy_slug,
+                },
+            )
+            raise
 
     def run(
         self,
@@ -1325,6 +1428,15 @@ def run(
                                 segment_data, segment, analysis_basis
                             ).model_dump(warnings=False)
 
+                        # done with analysis_basis: publish metrics view
+                        # bind ensures publish_view runs after the metric table is written
+                        results.append(
+                            bind(
+                                self.publish_view(period, analysis_basis=analysis_basis.value),
+                                metrics_results,
+                            )
+                        )
+
                     else:
                         # convert metric configurations to mozanalysis metrics
                         summary_metrics: list[Summary] = [
@@ -1489,17 +1601,15 @@ def run(
                                     period,
                                 ).model_dump(warnings=False)
 
-                    # done with analysis_basis: publish metric view
-                    results.append(
-                        bind(
+                        # done with analysis_basis: publish metrics view for successful metrics only
+                        filtered_dict = _successful_metrics_dict(metrics_results, all_metrics_by_ds)
+                        results.append(
                             self.publish_view(
                                 period,
                                 analysis_basis=analysis_basis.value,
-                                metrics_dict=all_metrics_by_ds,
-                            ),
-                            [metrics_results],
+                                metrics_dict=filtered_dict,
+                            )
                         )
-                    )
 
                 # done with period: save statistics results to table
                 result = self.save_statistics(
@@ -1517,8 +1627,20 @@ def run(
                     )
                 )
 
+            # submit all tasks, and log errors for failed tasks
             result_futures = client.compute(results)
-            client.gather(result_futures)  # block until futures have finished
+            for future in as_completed(result_futures):
+                if future.status != "error":
+                    continue
+                try:
+                    future.result()
+                except Exception:
+                    logger.exception(
+                        "A task failed during analysis with an unexpected exception",
+                        extra={
+                            "experiment": self.config.experiment.normandy_slug,
+                        },
+                    )
 
     def enrollments_query(self, time_limits: TimeLimits, use_glean_ids: bool = False) -> str:
         """Returns the enrollments SQL query."""
diff --git a/jetstream/tests/test_analysis.py b/jetstream/tests/test_analysis.py
index 2ed460d6..88d00613 100644
--- a/jetstream/tests/test_analysis.py
+++ b/jetstream/tests/test_analysis.py
@@ -8,11 +8,13 @@
 import pytest
 import pytz
 import toml
+from google.api_core.exceptions import GoogleAPICallError
 from metric_config_parser import segment
 from metric_config_parser.analysis import AnalysisSpec
 from metric_config_parser.data_source import DataSource
 from metric_config_parser.experiment import Branch, BucketConfig, Experiment
 from metric_config_parser.metric import AnalysisPeriod, Summary
+from mozanalysis.experiment import TimeLimits
 from mozilla_nimbus_schemas.experimenter_apis.experiments import RandomizationUnit
 from mozilla_nimbus_schemas.jetstream import AnalysisBasis
 
@@ -1181,6 +1183,279 @@ def test_create_subset_metric_table_query_complete_univariate(experiments):
     assert expected_query == actual_query
 
 
+def test_run_continues_after_task_failure(experiments, monkeypatch, caplog):
+    import threading
+
+    exp = experiments[0]
+    config = AnalysisSpec.default_for_experiment(exp, ConfigLoader.configs).resolve(
+        exp, ConfigLoader.configs
+    )
+    analysis = Analysis("test", "test", config)
+
+    monkeypatch.setattr("jetstream.analysis.Analysis.ensure_enrollments", Mock())
+    monkeypatch.setattr("jetstream.analysis._dask_cluster", None)
+
+    # Use threads (processes=False) so monkeypatches are visible inside dask workers.
+    original_local_cluster = jetstream.analysis.LocalCluster
+    monkeypatch.setattr(
+        "jetstream.analysis.LocalCluster",
+        lambda **kwargs: original_local_cluster(**{**kwargs, "processes": False, "n_workers": 1}),
+    )
+
+    mock_bq = MagicMock()
+    monkeypatch.setattr("jetstream.analysis.BigQueryClient", Mock(return_value=mock_bq))
+
+    # Raise on the first _table_name call that runs inside a dask worker thread.
+    # _table_name is also called from the main thread during graph construction
+    # (with the same arguments), so the thread check prevents a premature failure
+    # before any task has been submitted.
+    has_failed = threading.Event()
+    main_thread = threading.main_thread()
+    original_table_name = Analysis._table_name
+
+    def patched_table_name(
+        self, window_period, window_index, analysis_basis=None, metric=None, statistics=False
+    ):
+        if (
+            metric is not None
+            and not statistics
+            and not has_failed.is_set()
+            and threading.current_thread() is not main_thread
+        ):
+            has_failed.set()
+            raise RuntimeError(f"simulated failure for data source {metric}")
+        return original_table_name(
+            self,
+            window_period,
+            window_index,
+            analysis_basis=analysis_basis,
+            metric=metric,
+            statistics=statistics,
+        )
+
+    monkeypatch.setattr("jetstream.analysis.Analysis._table_name", patched_table_name)
+
+    with caplog.at_level(logging.ERROR):
+        analysis.run(
+            current_date=dt.datetime(2020, 1, 10, tzinfo=pytz.utc),
+            dry_run=True,
+            discrete_metrics=True,
+        )
+
+    assert "simulated failure for data source" in caplog.text
+    assert "A task failed during analysis" in caplog.text
+    # publish_view ran for periods where all data sources succeeded
+    assert mock_bq.execute.called
+
+
+@pytest.mark.parametrize("ErrorType", [GoogleAPICallError, ValueError])
+def test_calculate_metric_for_ds_returns_empty_on_specific_errors(
+    experiments,
+    monkeypatch,
+    caplog,
+    ErrorType,
+):
+    """calculate_metric_for_ds returns '' (not raises) on GoogleAPICallError or ValueError."""
+    config = AnalysisSpec.default_for_experiment(experiments[0], ConfigLoader.configs).resolve(
+        experiments[0], ConfigLoader.configs
+    )
+
+    mock_bq = MagicMock()
+    mock_bq.execute.side_effect = ErrorType("simulated error")
+    monkeypatch.setattr("jetstream.analysis.BigQueryClient", Mock(return_value=mock_bq))
+
+    analysis = Analysis("test", "test", config)
+
+    mock_exp = MagicMock()
+    mock_exp.build_metrics_query.return_value = "SELECT 1"
+
+    time_limits = TimeLimits.for_single_analysis_window(
+        last_date_full_data="2020-01-09",
+        analysis_start_days=0,
+        analysis_length_dates=7,
+        first_enrollment_date="2019-12-01",
+        num_dates_enrollment=8,
+    )
+
+    mock_metric = MagicMock()
+    mock_metric.data_source.name = "test_ds"
+
+    with caplog.at_level(logging.ERROR):
+        result = analysis.calculate_metric_for_ds(
+            mock_exp,
+            time_limits,
+            AnalysisPeriod.WEEK,
+            AnalysisBasis.ENROLLMENTS,
+            [mock_metric],
+            False,
+        ).compute(scheduler="synchronous")
+
+    assert result == ""
+    assert "simulated error" in caplog.text
+
+
+def test_calculate_metric_for_ds_raises_for_other_errors(experiments, monkeypatch):
+    """calculate_metric_for_ds re-raises errors other than ValueError and GoogleAPICallError."""
+    config = AnalysisSpec.default_for_experiment(experiments[0], ConfigLoader.configs).resolve(
+        experiments[0], ConfigLoader.configs
+    )
+
+    mock_bq = MagicMock()
+    mock_bq.execute.side_effect = HighPopulationException("simulated error")
+    monkeypatch.setattr("jetstream.analysis.BigQueryClient", Mock(return_value=mock_bq))
+
+    analysis = Analysis("test", "test", config)
+
+    mock_exp = MagicMock()
+    mock_exp.build_metrics_query.return_value = "SELECT 1"
+
+    time_limits = TimeLimits.for_single_analysis_window(
+        last_date_full_data="2020-01-09",
+        analysis_start_days=0,
+        analysis_length_dates=7,
+        first_enrollment_date="2019-12-01",
+        num_dates_enrollment=8,
+    )
+
+    mock_metric = MagicMock()
+    mock_metric.data_source.name = "test_ds"
+
+    with pytest.raises(HighPopulationException):
+        analysis.calculate_metric_for_ds(
+            mock_exp,
+            time_limits,
+            AnalysisPeriod.WEEK,
+            AnalysisBasis.ENROLLMENTS,
+            [mock_metric],
+            False,
+        ).compute(scheduler="synchronous")
+
+
+def test_subset_metric_table_returns_none_for_empty_name(experiments, monkeypatch):
+    """subset_metric_table returns None without calling BQ when metric_table_name is ''."""
+    mock_bq = MagicMock()
+    monkeypatch.setattr("jetstream.analysis.BigQueryClient", Mock(return_value=mock_bq))
+
+    analysis = _empty_analysis(experiments)
+
+    summary = MagicMock()
+    summary.metric.name = "test_metric"
+    summary.statistic.params = {}
+    summary.metric.depends_on = None
+
+    result = analysis.subset_metric_table(
+        "", "all", summary, AnalysisBasis.ENROLLMENTS, AnalysisPeriod.WEEK, True
+    ).compute(scheduler="synchronous")
+
+    assert result is None
+    mock_bq.execute.assert_not_called()
+
+
+def test_subset_metric_table_returns_none_on_google_api_error(experiments, monkeypatch, caplog):
+    """subset_metric_table returns None (not raises) on GoogleAPICallError."""
+    mock_bq = MagicMock()
+    mock_bq.execute.side_effect = GoogleAPICallError("simulated subset error")
+    monkeypatch.setattr("jetstream.analysis.BigQueryClient", Mock(return_value=mock_bq))
+    monkeypatch.setattr(
+        "jetstream.analysis.Analysis._create_subset_metric_table_query",
+        Mock(return_value="SELECT 1"),
+    )
+
+    analysis = _empty_analysis(experiments)
+
+    summary = MagicMock()
+    summary.metric.name = "test_metric"
+
+    with caplog.at_level(logging.ERROR):
+        result = analysis.subset_metric_table(
+            "some_table", "all", summary, AnalysisBasis.ENROLLMENTS, AnalysisPeriod.WEEK, True
+        ).compute(scheduler="synchronous")
+
+    assert result is None
+    assert "simulated subset error" in caplog.text
+
+
+def test_counts_returns_empty_for_none_segment_data(experiments):
+    """counts returns an empty StatisticResultCollection when segment_data is None."""
+    result = (
+        _empty_analysis(experiments)
+        .counts(None, "all", AnalysisBasis.ENROLLMENTS)
+        .compute(scheduler="synchronous")
+    )
+
+    assert result.root == []
+
+
+def test_calculate_statistics_returns_empty_for_none_segment_data(experiments):
+    """calculate_statistics returns an empty StatisticResultCollection when segment_data is None."""
+    summary = MagicMock()
+
+    result = (
+        _empty_analysis(experiments)
+        .calculate_statistics(
+            summary, None, "all", AnalysisBasis.ENROLLMENTS, 7, AnalysisPeriod.WEEK
+        )
+        .compute(scheduler="synchronous")
+    )
+
+    assert result.root == []
+
+
+def test_run_continues_after_google_api_error(experiments, monkeypatch, caplog):
+    """Ensure that a GoogleAPICallError from one data source does not prevent
+    save_statistics from being called, and no downstream BQ calls are made with
+    empty string as a table name."""
+    exp = experiments[0]
+    config = AnalysisSpec.default_for_experiment(exp, ConfigLoader.configs).resolve(
+        exp, ConfigLoader.configs
+    )
+    analysis = Analysis("test", "test", config)
+
+    monkeypatch.setattr("jetstream.analysis.Analysis.ensure_enrollments", Mock())
+    monkeypatch.setattr("jetstream.analysis._dask_cluster", None)
+
+    original_local_cluster = jetstream.analysis.LocalCluster
+    monkeypatch.setattr(
+        "jetstream.analysis.LocalCluster",
+        lambda **kwargs: original_local_cluster(**{**kwargs, "processes": False, "n_workers": 1}),
+    )
+
+    mock_bq = MagicMock()
+
+    def bq_execute_side_effect(*args, **kwargs):
+        # Metric-table write calls have 2 positional args (sql, table_name).
+        # Raise for the search_clients data source only.
+        if len(args) >= 2 and "search_clients" in str(args[1]):
+            raise GoogleAPICallError("simulated failure for search_clients data source")
+        result = MagicMock()
+        result.slot_millis = 0
+        return result
+
+    mock_bq.execute.side_effect = bq_execute_side_effect
+    monkeypatch.setattr("jetstream.analysis.BigQueryClient", Mock(return_value=mock_bq))
+
+    with caplog.at_level(logging.ERROR):
+        analysis.run(
+            current_date=dt.datetime(2020, 1, 10, tzinfo=pytz.utc),
+            dry_run=False,
+            discrete_metrics=True,
+        )
+
+    # The legitimate failure was logged.
+    assert "simulated failure for search_clients data source" in caplog.text
+
+    # save_statistics ran for the surviving periods (load_table_from_json is its BQ call).
+    assert mock_bq.load_table_from_json.called
+
+    # No BQ execute call was made with an empty-identifier SQL pattern
+    for call in mock_bq.execute.call_args_list:
+        sql = str(call.args[0]) if call.args else ""
+        assert "``" not in sql, f"BQ execute called with empty identifier in SQL: {sql!r}"
+
+    # No NoneType errors from counts receiving a None sentinel.
+    assert "NoneType" not in caplog.text
+
+
 def test_metric_slugs_adds_depends_on_metrics(experiments, monkeypatch):
     config = AnalysisSpec.default_for_experiment(experiments[0], ConfigLoader.configs).resolve(
         experiments[0], ConfigLoader.configs
@@ -1220,6 +1495,7 @@ def test_metric_slugs_adds_depends_on_metrics(experiments, monkeypatch):
     monkeypatch.setattr("jetstream.analysis.bind", lambda x, deps: x)
     monkeypatch.setattr("jetstream.analysis.LocalCluster", MagicMock())
     monkeypatch.setattr("jetstream.analysis.Client", MagicMock())
+    monkeypatch.setattr("jetstream.analysis.as_completed", Mock(return_value=[]))
 
     metric_slugs = ["ratio_metric"]
     Analysis("test", "test", config).run(
@@ -1395,6 +1671,7 @@ def capturing_bind(thing, deps):
     monkeypatch.setattr("jetstream.analysis.Analysis.publish_view", MagicMock())
     monkeypatch.setattr("jetstream.analysis.LocalCluster", MagicMock())
     monkeypatch.setattr("jetstream.analysis.Client", MagicMock())
+    monkeypatch.setattr("jetstream.analysis.as_completed", Mock(return_value=[]))
 
     Analysis("test", "test", config).run(
         current_date=dt.datetime(2020, 1, 1, tzinfo=pytz.utc),