diff --git a/.circleci/config.yml b/.circleci/config.yml index 70da2131..d66486f0 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -26,15 +26,12 @@ test_settings: &test_settings chmod +x codecov ./codecov -F "$(basename $PWD | sed s/[^a-z]/_/g)" - - #################### # Jobs: see https://circleci.com/docs/2.0/jobs-steps/ #################### version: 2 jobs: - py310: <<: *test_settings docker: @@ -44,12 +41,12 @@ jobs: docker: - image: cimg/python:3.10 steps: - - checkout - - run: - name: Run linting - command: | - pip install ruff - ruff check src/ tests/ + - checkout + - run: + name: Run linting + command: | + pip install -r requirements.txt + ruff check src/ tests/ # Runs when the repository is tagged for release; see the workflows section # below for trigger logic. @@ -133,7 +130,7 @@ workflows: - deploy: filters: tags: - only: /[0-9]{4}.[0-9]{1,2}.[0-9]+/ # Calver: YYYY.M.MINOR + only: /[0-9]{4}.[0-9]{1,2}.[0-9]+/ # Calver: YYYY.M.MINOR branches: # Ignore all branches; this workflow should only run for tags. ignore: /.*/ diff --git a/requirements-dev.in b/requirements-dev.in index 07231a72..7e84a38a 100644 --- a/requirements-dev.in +++ b/requirements-dev.in @@ -226,7 +226,7 @@ more-itertools==10.4.0 # -c requirements.txt # jaraco-classes # jaraco-functools -mozilla-metric-config-parser==2024.7.1 +mozilla-metric-config-parser==2024.8.1 # via # -c requirements.txt # mozanalysis diff --git a/requirements-dev.txt b/requirements-dev.txt index 257f49a0..aad64973 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -953,9 +953,9 @@ more-itertools==10.4.0 \ # -r requirements-dev.in # jaraco-classes # jaraco-functools -mozilla-metric-config-parser==2024.7.1 \ - --hash=sha256:07ba32624cb9a38662bdff259dd4ec385bdc8ca4500b461d8a9025a736a25c6b \ - --hash=sha256:ad5169b4cec7b0fa013b3136584b7ee3d9b8a6a807d2eb71e5d9f269517f9d22 +mozilla-metric-config-parser==2024.8.1 \ + --hash=sha256:49fb6e67367809e3750108246e50dc1e68778564c2c158ba26bd4835b0b9c89c \ + --hash=sha256:cadc4ba9fb8399be0b857abb4bdd1f12ff1943159463fec0d128b71b2e72b554 # via -r requirements-dev.in mozilla-nimbus-schemas==2023.10.3 \ --hash=sha256:8771344a63b0d197dbebd8d7955ce8034d0f5063a13899f4da7d8a99803a76da \ diff --git a/requirements.in b/requirements.in index 0bfe022f..7eb0b03f 100644 --- a/requirements.in +++ b/requirements.in @@ -124,7 +124,7 @@ more-itertools==10.4.0 # via # jaraco-classes # jaraco-functools -mozilla-metric-config-parser==2024.7.1 +mozilla-metric-config-parser==2024.8.1 # via mozanalysis mozilla-nimbus-schemas==2023.10.3 # via mozilla-metric-config-parser diff --git a/requirements.txt b/requirements.txt index ca4c694b..ba439d6e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -924,9 +924,9 @@ more-itertools==10.4.0 \ # -r requirements.in # jaraco-classes # jaraco-functools -mozilla-metric-config-parser==2024.7.1 \ - --hash=sha256:07ba32624cb9a38662bdff259dd4ec385bdc8ca4500b461d8a9025a736a25c6b \ - --hash=sha256:ad5169b4cec7b0fa013b3136584b7ee3d9b8a6a807d2eb71e5d9f269517f9d22 +mozilla-metric-config-parser==2024.8.1 \ + --hash=sha256:49fb6e67367809e3750108246e50dc1e68778564c2c158ba26bd4835b0b9c89c \ + --hash=sha256:cadc4ba9fb8399be0b857abb4bdd1f12ff1943159463fec0d128b71b2e72b554 # via -r requirements.in mozilla-nimbus-schemas==2023.10.3 \ --hash=sha256:8771344a63b0d197dbebd8d7955ce8034d0f5063a13899f4da7d8a99803a76da \ diff --git a/src/mozanalysis/config.py b/src/mozanalysis/config.py index fb447d05..aeebb254 100644 --- a/src/mozanalysis/config.py +++ b/src/mozanalysis/config.py @@ -6,6 +6,9 @@ from metric_config_parser.config import ConfigCollection +from mozanalysis.metrics import DataSource, Metric +from mozanalysis.segments import Segment, SegmentDataSource + METRIC_HUB_JETSTREAM_REPO = "https://github.com/mozilla/metric-hub/tree/main/jetstream" @@ -68,12 +71,11 @@ def check_configs_for_app(self, app_name: str) -> bool: return True return False - def get_metric(self, metric_slug: str, app_name: str): + def get_metric(self, metric_slug: str, app_name: str) -> Metric: """Load a metric definition for the given app. Returns a :class:`mozanalysis.metrics.Metric` instance. """ - from mozanalysis.metrics import Metric metric_definition = self.configs.get_metric_definition(metric_slug, app_name) if metric_definition is None: @@ -100,12 +102,11 @@ def get_metric(self, metric_slug: str, app_name: str): app_name=app_name, ) - def get_data_source(self, data_source_slug: str, app_name: str): + def get_data_source(self, data_source_slug: str, app_name: str) -> DataSource: """Load a data source definition for the given app. Returns a :class:`mozanalysis.metrics.DataSource` instance. """ - from mozanalysis.metrics import DataSource data_source_definition = self.configs.get_data_source_definition( data_source_slug, app_name @@ -120,26 +121,13 @@ def get_data_source(self, data_source_slug: str, app_name: str): f"Could not find application {app_name}, so data source {data_source_slug} could not be resolved" # noqa:E501 ) - return DataSource( - name=data_source_definition.name, - from_expr=data_source_definition.from_expression, - client_id_column=data_source_definition.client_id_column, - submission_date_column=data_source_definition.submission_date_column, - experiments_column_type=( - None - if data_source_definition.experiments_column_type == "none" - else data_source_definition.experiments_column_type - ), - default_dataset=data_source_definition.default_dataset, - app_name=app_name, - ) + return DataSource.from_mcp_data_source(data_source_definition, app_name) - def get_segment(self, segment_slug: str, app_name: str): + def get_segment(self, segment_slug: str, app_name: str) -> Segment: """Load a segment definition for the given app. Returns a :class:`mozanalysis.segments.Segment` instance. """ - from mozanalysis.segments import Segment segment_definition = self.configs.get_segment_definition(segment_slug, app_name) if segment_definition is None: @@ -165,12 +153,13 @@ def get_segment(self, segment_slug: str, app_name: str): app_name=app_name, ) - def get_segment_data_source(self, data_source_slug: str, app_name: str): + def get_segment_data_source( + self, data_source_slug: str, app_name: str + ) -> SegmentDataSource: """Load a segment data source definition for the given app. Returns a :class:`mozanalysis.segments.SegmentDataSource` instance. """ - from mozanalysis.segments import SegmentDataSource data_source_definition = self.configs.get_segment_data_source_definition( data_source_slug, app_name @@ -196,7 +185,9 @@ def get_segment_data_source(self, data_source_slug: str, app_name: str): app_name=app_name, ) - def get_outcome_metric(self, metric_slug: str, outcome_slug: str, app_name: str): + def get_outcome_metric( + self, metric_slug: str, outcome_slug: str, app_name: str + ) -> Metric: """Load a metric definition from an outcome defined for the given app. Parametrized metrics are not supported, since they may not be defined outside @@ -230,18 +221,21 @@ class MinimalConfiguration: summaries = metric_definition.resolve(outcome_spec, conf, self.configs) metric = summaries[0].metric + if metric.data_source is None: + raise ValueError(f"Unable to resolve DataSource for Metric {metric.name}") + return Metric( name=metric.name, select_expr=metric.select_expression, friendly_name=metric.friendly_name, description=metric.description, - data_source=metric.data_source, + data_source=DataSource.from_mcp_data_source(metric.data_source, app_name), bigger_is_better=metric.bigger_is_better, ) def get_outcome_data_source( self, data_source_slug: str, outcome_slug: str, app_name: str - ): + ) -> DataSource: """Load a data source definition from an outcome defined for the given app. Returns a :class:`mozanalysis.metrics.DataSource` instance. @@ -266,17 +260,8 @@ def get_outcome_data_source( + f" in outcome {outcome_slug}" ) - return DataSource( - name=data_source_definition.name, - from_expr=data_source_definition.from_expression, - client_id_column=data_source_definition.client_id_column, - submission_date_column=data_source_definition.submission_date_column, - experiments_column_type=( - None - if data_source_definition.experiments_column_type == "none" - else data_source_definition.experiments_column_type - ), - default_dataset=data_source_definition.default_dataset, + return DataSource.from_mcp_data_source( + data_source_definition, app_name=app_name ) diff --git a/src/mozanalysis/experiment.py b/src/mozanalysis/experiment.py index dd05f94a..34b3b001 100644 --- a/src/mozanalysis/experiment.py +++ b/src/mozanalysis/experiment.py @@ -5,21 +5,23 @@ import logging from enum import Enum -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import attr +from metric_config_parser import AnalysisUnit +from typing_extensions import assert_never from mozanalysis import APPS from mozanalysis.bq import BigQueryContext, sanitize_table_name_for_bq from mozanalysis.config import ConfigLoader from mozanalysis.metrics import AnalysisBasis, DataSource, Metric +from mozanalysis.segments import Segment, SegmentDataSource +from mozanalysis.types import IncompatibleAnalysisUnit from mozanalysis.utils import add_days, date_sub, hash_ish if TYPE_CHECKING: from pandas import DataFrame - from mozanalysis.segments import Segment, SegmentDataSource - logger = logging.getLogger(__name__) @@ -114,6 +116,12 @@ class Experiment: app_id (str, optional): For a Glean app, the name of the BigQuery dataset derived from its app ID, like `org_mozilla_firefox`. app_name (str, optional): The Glean app name, like `fenix`. + analysis_unit (AnalysisUnit, optional): the "unit" of analysis, + which defines an experimental unit. For example: `CLIENT` + for mobile experiments or `GROUP` for desktop experiments. Is used + as the join key when building queries and sub-unit level data is + aggregated up to that level. Defaults to `AnalysisUnit.CLIENT` + unless specified Attributes: experiment_slug (str): Name of the study, used to identify @@ -129,11 +137,16 @@ class Experiment: before UTC midnight. """ - experiment_slug = attr.ib() + experiment_slug = attr.ib(type=str, validator=attr.validators.instance_of(str)) start_date = attr.ib() num_dates_enrollment = attr.ib(default=None) app_id = attr.ib(default=None) app_name = attr.ib(default=None) + analysis_unit = attr.ib( + type=AnalysisUnit, + default=AnalysisUnit.CLIENT, + validator=attr.validators.instance_of(AnalysisUnit), + ) def get_app_name(self): """ @@ -465,7 +478,9 @@ def build_enrollments_query( sample_size = sample_size or 100 enrollments_query = custom_enrollments_query or self._build_enrollments_query( - time_limits, enrollments_query_type, sample_size + time_limits, + enrollments_query_type, + sample_size, ) if exposure_signal: @@ -474,10 +489,14 @@ def build_enrollments_query( ) else: exposure_query = custom_exposure_query or self._build_exposure_query( - time_limits, enrollments_query_type + time_limits, + enrollments_query_type, ) - segments_query = self._build_segments_query(segment_list, time_limits) + segments_query = self._build_segments_query( + segment_list, + time_limits, + ) return f""" WITH raw_enrollments AS ({enrollments_query}), @@ -486,10 +505,10 @@ def build_enrollments_query( SELECT se.*, - e.* EXCEPT (client_id, branch) + e.* EXCEPT ({self.analysis_unit.value}, branch) FROM segmented_enrollments se LEFT JOIN exposures e - USING (client_id, branch) + USING ({self.analysis_unit.value}, branch) """ def build_metrics_query( @@ -563,9 +582,7 @@ def build_metrics_query( FROM `{enrollments_table}` e CROSS JOIN analysis_windows aw ), - exposures AS ( - {exposure_query} - ), + exposures AS ({exposure_query}), enrollments AS ( SELECT e.* EXCEPT (exposure_date, num_exposure_events), @@ -573,7 +590,7 @@ def build_metrics_query( x.num_exposure_events FROM exposures x RIGHT JOIN raw_enrollments e - USING (client_id, branch) + USING ({id_column}, branch) ) SELECT enrollments.*, @@ -586,6 +603,7 @@ def build_metrics_query( metrics_columns=",\n ".join(metrics_columns), metrics_joins="\n".join(metrics_joins), enrollments_table=enrollments_table, + id_column=self.analysis_unit.value, ) @staticmethod @@ -611,16 +629,27 @@ def _build_enrollments_query( ) -> str: """Return SQL to query a list of enrollments and their branches""" if enrollments_query_type == EnrollmentsQueryType.NORMANDY: - return self._build_enrollments_query_normandy(time_limits, sample_size) + return self._build_enrollments_query_normandy( + time_limits, + sample_size, + ) elif enrollments_query_type == EnrollmentsQueryType.GLEAN_EVENT: if not self.app_id: raise ValueError( "App ID must be defined for building Glean enrollments query" ) + if not self.analysis_unit == AnalysisUnit.CLIENT: + raise IncompatibleAnalysisUnit( + "Glean enrollments currently only support client_id analysis units" + ) return self._build_enrollments_query_glean_event( time_limits, self.app_id, sample_size ) elif enrollments_query_type == EnrollmentsQueryType.FENIX_FALLBACK: + if not self.analysis_unit == AnalysisUnit.CLIENT: + raise IncompatibleAnalysisUnit( + "Fenix fallback enrollments currently only support client_id analysis units" # noqa: E501 + ) return self._build_enrollments_query_fenix_baseline( time_limits, sample_size ) @@ -629,12 +658,18 @@ def _build_enrollments_query( raise ValueError( "App ID must be defined for building Cirrus enrollments query" ) + if not self.analysis_unit == AnalysisUnit.CLIENT: + raise IncompatibleAnalysisUnit( + "Cirrus enrollments currently only support client_id analysis units" + ) return self._build_enrollments_query_cirrus(time_limits, self.app_id) else: - raise ValueError + assert_never(enrollments_query_type) def _build_exposure_query( - self, time_limits: TimeLimits, exposure_query_type: EnrollmentsQueryType + self, + time_limits: TimeLimits, + exposure_query_type: EnrollmentsQueryType, ) -> str: """Return SQL to query a list of exposures and their branches""" if exposure_query_type == EnrollmentsQueryType.NORMANDY: @@ -644,8 +679,16 @@ def _build_exposure_query( raise ValueError( "App ID must be defined for building Glean exposures query" ) + if not self.analysis_unit == AnalysisUnit.CLIENT: + raise IncompatibleAnalysisUnit( + "Glean exposures currently only support client_id analysis units" + ) return self._build_exposure_query_glean_event(time_limits, self.app_id) elif exposure_query_type == EnrollmentsQueryType.FENIX_FALLBACK: + if not self.analysis_unit == AnalysisUnit.CLIENT: + raise IncompatibleAnalysisUnit( + "Fenix fallback exposures currently only support client_id analysis units" # noqa: E501 + ) return self._build_exposure_query_glean_event( time_limits, "org_mozilla_firefox" ) @@ -654,6 +697,10 @@ def _build_exposure_query( raise ValueError( "App ID must be defined for building Cirrus exposures query" ) + if not self.analysis_unit == AnalysisUnit.CLIENT: + raise IncompatibleAnalysisUnit( + "Cirrus exposures currently only support client_id analysis units" + ) return self._build_exposure_query_glean_event( time_limits, self.app_id, @@ -661,15 +708,21 @@ def _build_exposure_query( event_category="cirrus_events", ) else: - raise ValueError + assert_never(exposure_query_type) def _build_enrollments_query_normandy( - self, time_limits: TimeLimits, sample_size: int = 100 + self, + time_limits: TimeLimits, + sample_size: int = 100, ) -> str: """Return SQL to query enrollments for a normandy experiment""" + if (self.analysis_unit == AnalysisUnit.PROFILE_GROUP) and (sample_size < 100): + raise ValueError( + "Downsampling is not yet supported for group-level experiments" + ) return f""" SELECT - e.client_id, + e.{self.analysis_unit.value}, `mozfun.map.get_key`(e.event_map_values, 'branch') AS branch, MIN(e.submission_date) AS enrollment_date, @@ -683,7 +736,7 @@ def _build_enrollments_query_normandy( BETWEEN '{time_limits.first_enrollment_date}' AND '{time_limits.last_enrollment_date}' AND e.event_string_value = '{self.experiment_slug}' AND e.sample_id < {sample_size} - GROUP BY e.client_id, branch + GROUP BY e.{self.analysis_unit.value}, branch """ # noqa:E501 def _build_enrollments_query_fenix_baseline( @@ -699,6 +752,7 @@ def _build_enrollments_query_fenix_baseline( """ # Try to ignore users who enrolled early - but only consider a # 7 day window + return """ SELECT b.client_info.client_id AS client_id, @@ -741,6 +795,7 @@ def _build_enrollments_query_glean_event( ``ping_info.experiments`` to get a list of who is in what branch and when they enrolled. """ + return f""" SELECT events.client_info.client_id AS client_id, mozfun.map.get_key( @@ -774,6 +829,7 @@ def _build_enrollments_query_cirrus( ``ping_info.experiments`` to get a list of who is in what branch and when they enrolled. """ + return f""" SELECT mozfun.map.get_key(e.extra, "user_id") AS client_id, @@ -800,14 +856,14 @@ def _build_exposure_query_normandy(self, time_limits: TimeLimits) -> str: """Return SQL to query exposures for a normandy experiment""" return f""" SELECT - e.client_id, + e.{self.analysis_unit.value}, e.branch, min(e.submission_date) AS exposure_date, COUNT(e.submission_date) AS num_exposure_events FROM raw_enrollments re LEFT JOIN ( SELECT - client_id, + {self.analysis_unit.value}, `mozfun.map.get_key`(event_map_values, 'branchSlug') AS branch, submission_date FROM @@ -819,10 +875,10 @@ def _build_exposure_query_normandy(self, time_limits: TimeLimits) -> str: BETWEEN '{time_limits.first_enrollment_date}' AND '{time_limits.last_enrollment_date}' AND event_string_value = '{self.experiment_slug}' ) e - ON re.client_id = e.client_id AND + ON re.{self.analysis_unit.value} = e.{self.analysis_unit.value} AND re.branch = e.branch AND e.submission_date >= re.enrollment_date - GROUP BY e.client_id, e.branch + GROUP BY e.{self.analysis_unit.value}, e.branch """ # noqa: E501 def _build_exposure_query_glean_event( @@ -865,20 +921,21 @@ def _build_exposure_query_glean_event( def _build_metrics_query_bits( self, - metric_list: list[Metric], + metric_list: list[Metric | str], time_limits: TimeLimits, analysis_basis=AnalysisBasis.ENROLLMENTS, exposure_signal=None, ) -> tuple[list[str], list[str]]: """Return lists of SQL fragments corresponding to metrics.""" - metrics = [] + metrics: list[Metric] = [] for metric in metric_list: if isinstance(metric, str): metrics.append(ConfigLoader.get_metric(metric, self.get_app_name())) else: metrics.append(metric) - ds_metrics = self._partition_by_data_source(metrics) + ds_metrics = self._partition_metrics_by_data_source(metrics) + ds_metrics = cast(dict[DataSource, list[Metric]], ds_metrics) ds_metrics = { ds: metrics + ds.get_sanity_metrics(self.experiment_slug) for ds, metrics in ds_metrics.items() @@ -894,13 +951,15 @@ def _build_metrics_query_bits( self.experiment_slug, self.app_id, analysis_basis, + self.analysis_unit, exposure_signal, ) + metrics_joins.append( f""" LEFT JOIN ( - {query_for_metrics} - ) ds_{i} USING (client_id, branch, analysis_window_start, analysis_window_end) - """ + {query_for_metrics} + ) ds_{i} USING ({self.analysis_unit.value}, branch, analysis_window_start, analysis_window_end) + """ # noqa: E501 ) for m in ds_metrics[ds]: @@ -908,19 +967,30 @@ def _build_metrics_query_bits( return metrics_columns, metrics_joins - def _partition_by_data_source( - self, metric_or_segment_list: list[Metric] | list[Segment] - ) -> dict[DataSource | SegmentDataSource, list[Metric | Segment]]: + def _partition_segments_by_data_source( + self, segment_list: list[Segment] + ) -> dict[SegmentDataSource, list[Segment]]: + """Return a dict mapping segment data sources to segment lists.""" + data_sources = {s.data_source for s in segment_list} + + return { + ds: [s for s in segment_list if s.data_source == ds] for ds in data_sources + } + + def _partition_metrics_by_data_source( + self, metric_list: list[Metric] + ) -> dict[DataSource, list[Metric]]: """Return a dict mapping data sources to metric/segment lists.""" - data_sources = {m.data_source for m in metric_or_segment_list} + data_sources = {m.data_source for m in metric_list} return { - ds: [m for m in metric_or_segment_list if m.data_source == ds] - for ds in data_sources + ds: [m for m in metric_list if m.data_source == ds] for ds in data_sources } def _build_segments_query( - self, segment_list: list[Segment], time_limits: TimeLimits + self, + segment_list: list[Segment], + time_limits: TimeLimits, ) -> str: """Build a query adding segment columns to the enrollments view. @@ -935,7 +1005,7 @@ def _build_segments_query( # arrive with "how segments work" as their first question. segments_columns, segments_joins = self._build_segments_query_bits( - segment_list or [], time_limits + cast(list[Segment | str], segment_list) or [], time_limits ) return """ @@ -950,19 +1020,21 @@ def _build_segments_query( ) def _build_segments_query_bits( - self, segment_list: list[Segment], time_limits: TimeLimits + self, + segment_list: list[Segment | str], + time_limits: TimeLimits, ) -> tuple[list[str], list[str]]: """Return lists of SQL fragments corresponding to segments.""" # resolve segment slugs - segments = [] + segments: list[Segment] = [] for segment in segment_list: if isinstance(segment, str): segments.append(ConfigLoader.get_segment(segment, self.get_app_name())) else: segments.append(segment) - ds_segments = self._partition_by_data_source(segments) + ds_segments = self._partition_segments_by_data_source(segments) segments_columns = [] segments_joins = [] @@ -974,7 +1046,7 @@ def _build_segments_query_bits( segments_joins.append( f""" LEFT JOIN ( {query_for_segments} - ) ds_{i} USING (client_id, branch) + ) ds_{i} USING ({self.analysis_unit.value}, branch) """ ) @@ -1027,7 +1099,7 @@ class TimeLimits: first_date_data_required = attr.ib(type=str) last_date_data_required = attr.ib(type=str) - analysis_windows = attr.ib() # type: tuple[AnalysisWindow] + analysis_windows = attr.ib() # type: tuple[AnalysisWindow,...] @classmethod def for_single_analysis_window( @@ -1244,7 +1316,8 @@ class TimeSeriesResult: """ fully_qualified_table_name = attr.ib(type=str) - analysis_windows = attr.ib(type=list) + analysis_windows = attr.ib(type=tuple[AnalysisWindow, ...]) + analysis_unit = attr.ib(type=AnalysisUnit, default=AnalysisUnit.CLIENT) def get(self, bq_context: BigQueryContext, analysis_window) -> DataFrame: """Get the DataFrame for a specific analysis window. @@ -1353,8 +1426,11 @@ def _build_analysis_window_subset_query( This method returns SQL to query this table to obtain results in "the standard format" for a single analysis window. """ + except_clause = ( + f"{self.analysis_unit.value}, analysis_window_start, analysis_window_end" + ) return f""" - SELECT * EXCEPT (client_id, analysis_window_start, analysis_window_end) + SELECT * EXCEPT ({except_clause}) FROM {self.fully_qualified_table_name} WHERE analysis_window_start = {analysis_window.start} AND analysis_window_end = {analysis_window.end} diff --git a/src/mozanalysis/metrics.py b/src/mozanalysis/metrics.py index ad5bf189..bfaf551c 100644 --- a/src/mozanalysis/metrics.py +++ b/src/mozanalysis/metrics.py @@ -6,7 +6,14 @@ from enum import Enum from typing import TYPE_CHECKING +from metric_config_parser import AnalysisUnit +from typing_extensions import assert_never + +from mozanalysis.types import IncompatibleAnalysisUnit + if TYPE_CHECKING: + from metric_config_parser.data_source import DataSource as ParserDataSource + from mozanalysis.experiment import TimeLimits import logging @@ -23,6 +30,29 @@ class AnalysisBasis(Enum): EXPOSURES = "exposures" +# attr.s converters aren't compatible with mypy, define our own +# see: https://mypy.readthedocs.io/en/stable/additional_features.html#id1 +def client_id_column_converter(client_id_column: str | None) -> str: + if client_id_column is None: + return AnalysisUnit.CLIENT.value + else: + return client_id_column + + +def group_id_column_converter(group_id_column: str | None) -> str: + if group_id_column is None: + return AnalysisUnit.PROFILE_GROUP.value + else: + return group_id_column + + +def submission_date_column_converter(submission_date_column: str | None) -> str: + if submission_date_column is None: + return "submission_date" + else: + return submission_date_column + + @attr.s(frozen=True, slots=True) class DataSource: """Represents a table or view, from which Metrics may be defined. @@ -64,11 +94,27 @@ class DataSource: name = attr.ib(validator=attr.validators.instance_of(str)) _from_expr = attr.ib(validator=attr.validators.instance_of(str)) - experiments_column_type = attr.ib(default="simple", type=str) - client_id_column = attr.ib(default="client_id", type=str) - submission_date_column = attr.ib(default="submission_date", type=str) + experiments_column_type = attr.ib(default="simple", type=str | None) + client_id_column = attr.ib( + default=AnalysisUnit.CLIENT.value, + type=str, + validator=[attr.validators.instance_of(str), attr.validators.min_len(1)], + converter=client_id_column_converter, + ) + submission_date_column = attr.ib( + default="submission_date", + type=str, + validator=[attr.validators.instance_of(str), attr.validators.min_len(1)], + converter=submission_date_column_converter, + ) default_dataset = attr.ib(default=None, type=str | None) app_name = attr.ib(default=None, type=str | None) + group_id_column = attr.ib( + default=AnalysisUnit.PROFILE_GROUP.value, + type=str, + validator=[attr.validators.instance_of(str), attr.validators.min_len(1)], + converter=group_id_column_converter, + ) EXPERIMENT_COLUMN_TYPES = (None, "simple", "native", "glean") @@ -140,7 +186,8 @@ def build_query( time_limits: TimeLimits, experiment_slug: str, from_expr_dataset: str | None = None, - analysis_basis: str = AnalysisBasis.ENROLLMENTS, + analysis_basis: AnalysisBasis = AnalysisBasis.ENROLLMENTS, + analysis_unit: AnalysisUnit = AnalysisUnit.CLIENT, exposure_signal=None, ) -> str: """Return a nearly-self contained SQL query. @@ -148,9 +195,15 @@ def build_query( This query does not define ``enrollments`` but otherwise could be executed to query all metrics from this data source. """ - return """ - SELECT - e.client_id, + if analysis_unit == AnalysisUnit.CLIENT: + ds_id = self.client_id_column + elif analysis_unit == AnalysisUnit.PROFILE_GROUP: + ds_id = self.group_id_column + else: + assert_never(analysis_unit) + + return """SELECT + e.{id_column}, e.branch, e.analysis_window_start, e.analysis_window_end, @@ -159,21 +212,21 @@ def build_query( {metrics} FROM enrollments e LEFT JOIN {from_expr} ds - ON ds.{client_id} = e.client_id + ON ds.{ds_id} = e.{id_column} AND ds.{submission_date} BETWEEN '{fddr}' AND '{lddr}' AND ds.{submission_date} BETWEEN DATE_ADD(e.{date}, interval e.analysis_window_start day) AND DATE_ADD(e.{date}, interval e.analysis_window_end day) {ignore_pre_enroll_first_day} GROUP BY - e.client_id, + e.{id_column}, e.branch, e.num_exposure_events, e.exposure_date, e.analysis_window_start, e.analysis_window_end""".format( - client_id=self.client_id_column or "client_id", - submission_date=self.submission_date_column or "submission_date", + ds_id=ds_id, + submission_date=self.submission_date_column, from_expr=self.from_expr_for(from_expr_dataset), fddr=time_limits.first_date_data_required, lddr=time_limits.last_date_data_required, @@ -181,13 +234,16 @@ def build_query( f"{m.select_expr.format(experiment_slug=experiment_slug)} AS {m.name}" for m in metric_list ), - date="exposure_date" - if analysis_basis == AnalysisBasis.EXPOSURES and exposure_signal is None - else "enrollment_date", + date=( + "exposure_date" + if analysis_basis == AnalysisBasis.EXPOSURES and exposure_signal is None + else "enrollment_date" + ), ignore_pre_enroll_first_day=self.experiments_column_expr.format( - submission_date=self.submission_date_column or "submission_date", + submission_date=self.submission_date_column, experiment_slug=experiment_slug, ), + id_column=analysis_unit.value, ) def build_query_targets( @@ -198,6 +254,7 @@ def build_query_targets( analysis_length: int, from_expr_dataset: str | None = None, continuous_enrollment: bool = False, + analysis_unit: AnalysisUnit = AnalysisUnit.CLIENT, ) -> str: """Return a nearly-self contained SQL query that constructs the metrics query for targeting historical data without @@ -206,6 +263,11 @@ def build_query_targets( This query does not define ``targets`` but otherwise could be executed to query all metrics from this data source. """ + if analysis_unit != AnalysisUnit.CLIENT: + raise IncompatibleAnalysisUnit( + "`build_query_targets` currently only supports client_id analysis" + ) + return """ SELECT t.client_id, @@ -222,28 +284,23 @@ def build_query_targets( t.enrollment_date, t.analysis_window_start, t.analysis_window_end""".format( - client_id=self.client_id_column or "client_id", + client_id=self.client_id_column, from_expr=self.from_expr_for(from_expr_dataset), metrics=",\n ".join( f"{m.select_expr.format(experiment_name=experiment_name)} AS {m.name}" for m in metric_list ), - date_clause=""" - AND ds.{submission_date} BETWEEN '{fddr}' AND '{lddr}' - AND ds.{submission_date} BETWEEN + date_clause=( + f""" + AND ds.{self.submission_date_column} BETWEEN '{time_limits.first_date_data_required}' AND '{time_limits.last_date_data_required}' + AND ds.{self.submission_date_column} BETWEEN DATE_ADD(t.enrollment_date, interval t.analysis_window_start day) AND - DATE_ADD(t.enrollment_date, interval t.analysis_window_end day)""".format( - submission_date=self.submission_date_column or "submission_date", - fddr=time_limits.first_date_data_required, - lddr=time_limits.last_date_data_required, - ) - if not continuous_enrollment - else """AND ds.{submission_date} BETWEEN + DATE_ADD(t.enrollment_date, interval t.analysis_window_end day)""" # noqa: E501 + if not continuous_enrollment + else f"""AND ds.{self.submission_date_column} BETWEEN t.enrollment_date AND DATE_ADD(t.enrollment_date, interval {analysis_length} day) - """.format( - submission_date=self.submission_date_column or "submission_date", - analysis_length=analysis_length, + """ ), ) @@ -320,6 +377,30 @@ def get_sanity_metrics(self, experiment_slug: str) -> list[Metric]: else: raise ValueError + @classmethod + def from_mcp_data_source( + cls, + parser_data_source: ParserDataSource, + app_name: str | None = None, + group_id_column: str | None = AnalysisUnit.PROFILE_GROUP.value, + ) -> DataSource: + """metric-config-parser DataSource objects do not have an `app_name` + and do not, yet, have a group_id_column""" + return cls( + name=parser_data_source.name, + from_expr=parser_data_source.from_expression, + client_id_column=parser_data_source.client_id_column, + submission_date_column=parser_data_source.submission_date_column, + experiments_column_type=( + None + if parser_data_source.experiments_column_type == "none" + else parser_data_source.experiments_column_type + ), + default_dataset=parser_data_source.default_dataset, + app_name=app_name, + group_id_column=group_id_column, + ) + @attr.s(frozen=True, slots=True) class Metric: @@ -341,9 +422,11 @@ class Metric: used for validation """ - name = attr.ib(type=str) - data_source = attr.ib(type=DataSource) - select_expr = attr.ib(type=str) + name = attr.ib(type=str, validator=attr.validators.instance_of(str)) + data_source = attr.ib( + type=DataSource, validator=attr.validators.instance_of(DataSource) + ) + select_expr = attr.ib(type=str, validator=attr.validators.instance_of(str)) friendly_name = attr.ib(type=str | None, default=None) description = attr.ib(type=str | None, default=None) bigger_is_better = attr.ib(type=bool, default=True) diff --git a/src/mozanalysis/types.py b/src/mozanalysis/types.py index 2b7f425e..96ffe934 100644 --- a/src/mozanalysis/types.py +++ b/src/mozanalysis/types.py @@ -19,3 +19,7 @@ class Uplift(str, Enum): EstimatesByBranch = dict[BranchLabel, Estimates] CompareBranchesOutput = dict[ComparativeOption, EstimatesByBranch] + + +class IncompatibleAnalysisUnit(ValueError): + pass diff --git a/tests/test_experiment.py b/tests/test_experiment.py index aad5e91a..0c0e153e 100644 --- a/tests/test_experiment.py +++ b/tests/test_experiment.py @@ -1,3 +1,5 @@ +from textwrap import dedent + import pytest from helpers.cheap_lint import sql_lint # local helper file from helpers.config_loader_lists import ( @@ -8,11 +10,13 @@ klar_android_metrics, klar_ios_metrics, ) +from metric_config_parser import AnalysisUnit from mozanalysis.config import ApplicationNotFound, ConfigLoader from mozanalysis.experiment import ( AnalysisWindow, EnrollmentsQueryType, Experiment, + IncompatibleAnalysisUnit, TimeLimits, ) from mozanalysis.exposure import ExposureSignal @@ -277,8 +281,11 @@ def test_analysis_window_validates_end(): AnalysisWindow(5, 4) -def test_query_not_detectably_malformed(): - exp = Experiment("slug", "2019-01-01", 8) +@pytest.mark.parametrize( + "analysis_unit", [AnalysisUnit.CLIENT, AnalysisUnit.PROFILE_GROUP] +) +def test_query_not_detectably_malformed(analysis_unit: AnalysisUnit): + exp = Experiment("slug", "2019-01-01", 8, analysis_unit=analysis_unit) tl = TimeLimits.for_ts( first_enrollment_date="2019-01-01", @@ -296,6 +303,11 @@ def test_query_not_detectably_malformed(): sql_lint(enrollments_sql) assert "sample_id < None" not in enrollments_sql + if analysis_unit == AnalysisUnit.CLIENT: + assert "client_id" in enrollments_sql + elif analysis_unit == AnalysisUnit.PROFILE_GROUP: + assert "profile_group_id" in enrollments_sql + metrics_sql = exp.build_metrics_query( metric_list=[], time_limits=tl, @@ -304,9 +316,17 @@ def test_query_not_detectably_malformed(): sql_lint(metrics_sql) + if analysis_unit == AnalysisUnit.CLIENT: + assert "client_id" in metrics_sql + elif analysis_unit == AnalysisUnit.PROFILE_GROUP: + assert "profile_group_id" in metrics_sql -def test_megaquery_not_detectably_malformed(): - exp = Experiment("slug", "2019-01-01", 8) + +@pytest.mark.parametrize( + "analysis_unit", [AnalysisUnit.CLIENT, AnalysisUnit.PROFILE_GROUP] +) +def test_megaquery_not_detectably_malformed(analysis_unit: AnalysisUnit): + exp = Experiment("slug", "2019-01-01", 8, analysis_unit=analysis_unit) tl = TimeLimits.for_ts( first_enrollment_date="2019-01-01", @@ -321,6 +341,11 @@ def test_megaquery_not_detectably_malformed(): sql_lint(enrollments_sql) + if analysis_unit == AnalysisUnit.CLIENT: + assert "client_id" in enrollments_sql + elif analysis_unit == AnalysisUnit.PROFILE_GROUP: + assert "profile_group_id" in enrollments_sql + metrics_sql = exp.build_metrics_query( metric_list=desktop_metrics, time_limits=tl, @@ -329,9 +354,19 @@ def test_megaquery_not_detectably_malformed(): sql_lint(metrics_sql) + if analysis_unit == AnalysisUnit.CLIENT: + assert "client_id" in metrics_sql + elif analysis_unit == AnalysisUnit.PROFILE_GROUP: + assert "profile_group_id" in metrics_sql -def test_segments_megaquery_not_detectably_malformed(): - exp = Experiment("slug", "2019-01-01", 8) + +@pytest.mark.parametrize( + "analysis_unit", [AnalysisUnit.CLIENT, AnalysisUnit.PROFILE_GROUP] +) +def test_segments_megaquery_not_detectably_malformed( + analysis_unit: AnalysisUnit, +): + exp = Experiment("slug", "2019-01-01", 8, analysis_unit=analysis_unit) tl = TimeLimits.for_ts( first_enrollment_date="2019-01-01", @@ -848,3 +883,580 @@ def test_resolve_missing_column_names(): ) assert "None" not in metric_sql + + +def test_enrollments_query_explicit_client_id(): + exp = Experiment("slug", "2019-01-01", 8) + + tl = TimeLimits.for_ts( + first_enrollment_date="2019-01-01", + last_date_full_data="2019-03-01", + time_series_period="weekly", + num_dates_enrollment=8, + ) + + enrollments_sql = exp.build_enrollments_query( + time_limits=tl, enrollments_query_type=EnrollmentsQueryType.NORMANDY + ) + + sql_lint(enrollments_sql) + + expected = """ + WITH raw_enrollments AS ( +SELECT + e.client_id, + `mozfun.map.get_key`(e.event_map_values, 'branch') + AS branch, + MIN(e.submission_date) AS enrollment_date, + COUNT(e.submission_date) AS num_enrollment_events +FROM + `moz-fx-data-shared-prod.telemetry.events` e +WHERE + e.event_category = 'normandy' + AND e.event_method = 'enroll' + AND e.submission_date + BETWEEN '2019-01-01' AND '2019-01-08' + AND e.event_string_value = 'slug' + AND e.sample_id < 100 +GROUP BY e.client_id, branch + ), + segmented_enrollments AS ( +SELECT + raw_enrollments.*, + +FROM raw_enrollments + +), + exposures AS ( +SELECT + e.client_id, + e.branch, + min(e.submission_date) AS exposure_date, + COUNT(e.submission_date) AS num_exposure_events +FROM raw_enrollments re +LEFT JOIN ( + SELECT + client_id, + `mozfun.map.get_key`(event_map_values, 'branchSlug') AS branch, + submission_date + FROM + `moz-fx-data-shared-prod.telemetry.events` + WHERE + event_category = 'normandy' + AND (event_method = 'exposure' OR event_method = 'expose') + AND submission_date + BETWEEN '2019-01-01' AND '2019-01-08' + AND event_string_value = 'slug' +) e +ON re.client_id = e.client_id AND + re.branch = e.branch AND + e.submission_date >= re.enrollment_date +GROUP BY e.client_id, e.branch + ) + + SELECT + se.*, + e.* EXCEPT (client_id, branch) + FROM segmented_enrollments se + LEFT JOIN exposures e + USING (client_id, branch) +""" + + assert dedent(enrollments_sql) == expected + + metrics_sql = exp.build_metrics_query( + metric_list=[ + metric for metric in desktop_metrics if metric.name == "active_hours" + ], + time_limits=tl, + enrollments_table="enrollments", + ) + + sql_lint(metrics_sql) + + +def test_metrics_query_explicit_client_id(): + exp = Experiment("slug", "2019-01-01", 8) + + tl = TimeLimits.for_ts( + first_enrollment_date="2019-01-01", + last_date_full_data="2019-03-01", + time_series_period="weekly", + num_dates_enrollment=8, + ) + + enrollments_sql = exp.build_enrollments_query( + time_limits=tl, enrollments_query_type=EnrollmentsQueryType.NORMANDY + ) + + sql_lint(enrollments_sql) + + metrics_sql = exp.build_metrics_query( + metric_list=[ + metric for metric in desktop_metrics if metric.name == "active_hours" + ], + time_limits=tl, + enrollments_table="enrollments", + ) + + sql_lint(metrics_sql) + + expected = """ +WITH analysis_windows AS ( + (SELECT 0 AS analysis_window_start, 6 AS analysis_window_end) +UNION ALL +(SELECT 7 AS analysis_window_start, 13 AS analysis_window_end) +UNION ALL +(SELECT 14 AS analysis_window_start, 20 AS analysis_window_end) +UNION ALL +(SELECT 21 AS analysis_window_start, 27 AS analysis_window_end) +UNION ALL +(SELECT 28 AS analysis_window_start, 34 AS analysis_window_end) +UNION ALL +(SELECT 35 AS analysis_window_start, 41 AS analysis_window_end) +UNION ALL +(SELECT 42 AS analysis_window_start, 48 AS analysis_window_end) +), +raw_enrollments AS ( + -- needed by "exposures" sub query + SELECT + e.*, + aw.* + FROM `enrollments` e + CROSS JOIN analysis_windows aw +), +exposures AS ( + SELECT + * + FROM raw_enrollments e + ), +enrollments AS ( + SELECT + e.* EXCEPT (exposure_date, num_exposure_events), + x.exposure_date, + x.num_exposure_events + FROM exposures x + RIGHT JOIN raw_enrollments e + USING (client_id, branch) +) +SELECT + enrollments.*, + ds_0.active_hours +FROM enrollments + LEFT JOIN ( + SELECT + e.client_id, + e.branch, + e.analysis_window_start, + e.analysis_window_end, + e.num_exposure_events, + e.exposure_date, + COALESCE(SUM(active_hours_sum), 0) AS active_hours +FROM enrollments e + LEFT JOIN mozdata.telemetry.clients_daily ds + ON ds.client_id = e.client_id + AND ds.submission_date BETWEEN '2019-01-01' AND '2019-02-25' + AND ds.submission_date BETWEEN + DATE_ADD(e.enrollment_date, interval e.analysis_window_start day) + AND DATE_ADD(e.enrollment_date, interval e.analysis_window_end day) + +GROUP BY + e.client_id, + e.branch, + e.num_exposure_events, + e.exposure_date, + e.analysis_window_start, + e.analysis_window_end + ) ds_0 USING (client_id, branch, analysis_window_start, analysis_window_end)""" + + assert expected == dedent(metrics_sql.rstrip()) + + +def test_enrollments_query_explicit_group_id(): + exp = Experiment("slug", "2019-01-01", 8, analysis_unit=AnalysisUnit.PROFILE_GROUP) + + tl = TimeLimits.for_ts( + first_enrollment_date="2019-01-01", + last_date_full_data="2019-03-01", + time_series_period="weekly", + num_dates_enrollment=8, + ) + + enrollments_sql = exp.build_enrollments_query( + time_limits=tl, enrollments_query_type=EnrollmentsQueryType.NORMANDY + ) + + sql_lint(enrollments_sql) + + expected = """ + WITH raw_enrollments AS ( +SELECT + e.profile_group_id, + `mozfun.map.get_key`(e.event_map_values, 'branch') + AS branch, + MIN(e.submission_date) AS enrollment_date, + COUNT(e.submission_date) AS num_enrollment_events +FROM + `moz-fx-data-shared-prod.telemetry.events` e +WHERE + e.event_category = 'normandy' + AND e.event_method = 'enroll' + AND e.submission_date + BETWEEN '2019-01-01' AND '2019-01-08' + AND e.event_string_value = 'slug' + AND e.sample_id < 100 +GROUP BY e.profile_group_id, branch + ), + segmented_enrollments AS ( +SELECT + raw_enrollments.*, + +FROM raw_enrollments + +), + exposures AS ( +SELECT + e.profile_group_id, + e.branch, + min(e.submission_date) AS exposure_date, + COUNT(e.submission_date) AS num_exposure_events +FROM raw_enrollments re +LEFT JOIN ( + SELECT + profile_group_id, + `mozfun.map.get_key`(event_map_values, 'branchSlug') AS branch, + submission_date + FROM + `moz-fx-data-shared-prod.telemetry.events` + WHERE + event_category = 'normandy' + AND (event_method = 'exposure' OR event_method = 'expose') + AND submission_date + BETWEEN '2019-01-01' AND '2019-01-08' + AND event_string_value = 'slug' +) e +ON re.profile_group_id = e.profile_group_id AND + re.branch = e.branch AND + e.submission_date >= re.enrollment_date +GROUP BY e.profile_group_id, e.branch + ) + + SELECT + se.*, + e.* EXCEPT (profile_group_id, branch) + FROM segmented_enrollments se + LEFT JOIN exposures e + USING (profile_group_id, branch) +""" + + assert dedent(enrollments_sql) == expected + + +def test_metrics_query_explicit_group_id(): + exp = Experiment("slug", "2019-01-01", 8, analysis_unit=AnalysisUnit.PROFILE_GROUP) + + tl = TimeLimits.for_ts( + first_enrollment_date="2019-01-01", + last_date_full_data="2019-03-01", + time_series_period="weekly", + num_dates_enrollment=8, + ) + + enrollments_sql = exp.build_enrollments_query( + time_limits=tl, enrollments_query_type=EnrollmentsQueryType.NORMANDY + ) + + sql_lint(enrollments_sql) + + metrics_sql = exp.build_metrics_query( + metric_list=[ + metric for metric in desktop_metrics if metric.name == "active_hours" + ], + time_limits=tl, + enrollments_table="enrollments", + ) + + sql_lint(metrics_sql) + + expected = """ +WITH analysis_windows AS ( + (SELECT 0 AS analysis_window_start, 6 AS analysis_window_end) +UNION ALL +(SELECT 7 AS analysis_window_start, 13 AS analysis_window_end) +UNION ALL +(SELECT 14 AS analysis_window_start, 20 AS analysis_window_end) +UNION ALL +(SELECT 21 AS analysis_window_start, 27 AS analysis_window_end) +UNION ALL +(SELECT 28 AS analysis_window_start, 34 AS analysis_window_end) +UNION ALL +(SELECT 35 AS analysis_window_start, 41 AS analysis_window_end) +UNION ALL +(SELECT 42 AS analysis_window_start, 48 AS analysis_window_end) +), +raw_enrollments AS ( + -- needed by "exposures" sub query + SELECT + e.*, + aw.* + FROM `enrollments` e + CROSS JOIN analysis_windows aw +), +exposures AS ( + SELECT + * + FROM raw_enrollments e + ), +enrollments AS ( + SELECT + e.* EXCEPT (exposure_date, num_exposure_events), + x.exposure_date, + x.num_exposure_events + FROM exposures x + RIGHT JOIN raw_enrollments e + USING (profile_group_id, branch) +) +SELECT + enrollments.*, + ds_0.active_hours +FROM enrollments + LEFT JOIN ( + SELECT + e.profile_group_id, + e.branch, + e.analysis_window_start, + e.analysis_window_end, + e.num_exposure_events, + e.exposure_date, + COALESCE(SUM(active_hours_sum), 0) AS active_hours +FROM enrollments e + LEFT JOIN mozdata.telemetry.clients_daily ds + ON ds.profile_group_id = e.profile_group_id + AND ds.submission_date BETWEEN '2019-01-01' AND '2019-02-25' + AND ds.submission_date BETWEEN + DATE_ADD(e.enrollment_date, interval e.analysis_window_start day) + AND DATE_ADD(e.enrollment_date, interval e.analysis_window_end day) + +GROUP BY + e.profile_group_id, + e.branch, + e.num_exposure_events, + e.exposure_date, + e.analysis_window_start, + e.analysis_window_end + ) ds_0 USING (profile_group_id, branch, analysis_window_start, analysis_window_end)""" # noqa: E501 + + assert expected == dedent(metrics_sql.rstrip()) + + +def test_glean_group_id_incompatible(): + exp = Experiment( + "slug", + "2019-01-01", + 8, + analysis_unit=AnalysisUnit.PROFILE_GROUP, + app_id="test_app", + ) + + tl = TimeLimits.for_ts( + first_enrollment_date="2019-01-01", + last_date_full_data="2019-03-01", + time_series_period="weekly", + num_dates_enrollment=8, + ) + + with pytest.raises(IncompatibleAnalysisUnit): + exp.build_enrollments_query( + time_limits=tl, enrollments_query_type=EnrollmentsQueryType.GLEAN_EVENT + ) + + +def test_glean_group_id_incompatible_exposures(): + exp = Experiment( + "slug", + "2019-01-01", + 8, + analysis_unit=AnalysisUnit.PROFILE_GROUP, + app_id="test_app", + ) + + tl = TimeLimits.for_ts( + first_enrollment_date="2019-01-01", + last_date_full_data="2019-03-01", + time_series_period="weekly", + num_dates_enrollment=8, + ) + + with pytest.raises(IncompatibleAnalysisUnit): + exp._build_exposure_query( + time_limits=tl, exposure_query_type=EnrollmentsQueryType.GLEAN_EVENT + ) + + +def test_glean_missing_app_id(): + exp = Experiment("slug", "2019-01-01", 8, analysis_unit=AnalysisUnit.PROFILE_GROUP) + + tl = TimeLimits.for_ts( + first_enrollment_date="2019-01-01", + last_date_full_data="2019-03-01", + time_series_period="weekly", + num_dates_enrollment=8, + ) + + with pytest.raises( + ValueError, match="App ID must be defined for building Glean enrollments query" + ): + exp.build_enrollments_query( + time_limits=tl, enrollments_query_type=EnrollmentsQueryType.GLEAN_EVENT + ) + + +def test_glean_exposures_missing_app_id(): + exp = Experiment("slug", "2019-01-01", 8, analysis_unit=AnalysisUnit.PROFILE_GROUP) + + tl = TimeLimits.for_ts( + first_enrollment_date="2019-01-01", + last_date_full_data="2019-03-01", + time_series_period="weekly", + num_dates_enrollment=8, + ) + + with pytest.raises( + ValueError, match="App ID must be defined for building Glean exposures query" + ): + exp._build_exposure_query( + time_limits=tl, exposure_query_type=EnrollmentsQueryType.GLEAN_EVENT + ) + + +def test_cirrus_group_id_incompatible(): + exp = Experiment( + "slug", + "2019-01-01", + 8, + analysis_unit=AnalysisUnit.PROFILE_GROUP, + app_id="test_app", + ) + + tl = TimeLimits.for_ts( + first_enrollment_date="2019-01-01", + last_date_full_data="2019-03-01", + time_series_period="weekly", + num_dates_enrollment=8, + ) + + with pytest.raises(IncompatibleAnalysisUnit): + exp.build_enrollments_query( + time_limits=tl, enrollments_query_type=EnrollmentsQueryType.CIRRUS + ) + + +def test_cirrus_group_id_incompatible_exposures(): + exp = Experiment( + "slug", + "2019-01-01", + 8, + analysis_unit=AnalysisUnit.PROFILE_GROUP, + app_id="test_app", + ) + + tl = TimeLimits.for_ts( + first_enrollment_date="2019-01-01", + last_date_full_data="2019-03-01", + time_series_period="weekly", + num_dates_enrollment=8, + ) + + with pytest.raises(IncompatibleAnalysisUnit): + exp._build_exposure_query( + time_limits=tl, exposure_query_type=EnrollmentsQueryType.CIRRUS + ) + + +def test_cirrus_missing_app_id(): + exp = Experiment("slug", "2019-01-01", 8, analysis_unit=AnalysisUnit.PROFILE_GROUP) + + tl = TimeLimits.for_ts( + first_enrollment_date="2019-01-01", + last_date_full_data="2019-03-01", + time_series_period="weekly", + num_dates_enrollment=8, + ) + + with pytest.raises( + ValueError, match="App ID must be defined for building Cirrus enrollments query" + ): + exp.build_enrollments_query( + time_limits=tl, enrollments_query_type=EnrollmentsQueryType.CIRRUS + ) + + +def test_cirrus_missing_app_id_exposures(): + exp = Experiment("slug", "2019-01-01", 8, analysis_unit=AnalysisUnit.PROFILE_GROUP) + + tl = TimeLimits.for_ts( + first_enrollment_date="2019-01-01", + last_date_full_data="2019-03-01", + time_series_period="weekly", + num_dates_enrollment=8, + ) + + with pytest.raises( + ValueError, match="App ID must be defined for building Cirrus exposures query" + ): + exp._build_exposure_query( + time_limits=tl, exposure_query_type=EnrollmentsQueryType.CIRRUS + ) + + +def test_fenix_group_id_incompatible(): + exp = Experiment("slug", "2019-01-01", 8, analysis_unit=AnalysisUnit.PROFILE_GROUP) + + tl = TimeLimits.for_ts( + first_enrollment_date="2019-01-01", + last_date_full_data="2019-03-01", + time_series_period="weekly", + num_dates_enrollment=8, + ) + + with pytest.raises(IncompatibleAnalysisUnit): + exp.build_enrollments_query( + time_limits=tl, enrollments_query_type=EnrollmentsQueryType.FENIX_FALLBACK + ) + + +def test_fenix_group_id_incompatible_exposures(): + exp = Experiment("slug", "2019-01-01", 8, analysis_unit=AnalysisUnit.PROFILE_GROUP) + + tl = TimeLimits.for_ts( + first_enrollment_date="2019-01-01", + last_date_full_data="2019-03-01", + time_series_period="weekly", + num_dates_enrollment=8, + ) + + with pytest.raises(IncompatibleAnalysisUnit): + exp._build_exposure_query( + time_limits=tl, exposure_query_type=EnrollmentsQueryType.FENIX_FALLBACK + ) + + +def test_group_id_no_downsampling(): + exp = Experiment("slug", "2019-01-01", 8, analysis_unit=AnalysisUnit.PROFILE_GROUP) + + tl = TimeLimits.for_ts( + first_enrollment_date="2019-01-01", + last_date_full_data="2019-03-01", + time_series_period="weekly", + num_dates_enrollment=8, + ) + + with pytest.raises( + ValueError, + match="Downsampling is not yet supported for group-level experiments", + ): + exp.build_enrollments_query( + time_limits=tl, + enrollments_query_type=EnrollmentsQueryType.NORMANDY, + sample_size=99, + )