Skip to content
Merged
3 changes: 3 additions & 0 deletions datapipe/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ class StepStatus:
class ComputeInput:
dt: DataTable
join_type: Literal["inner", "full"] = "full"
# Filtered join optimization: mapping from idx columns to dt columns
# Example: {"user_id": "id"} means filter dt by dt.id IN (idx.user_id)
join_keys: Optional[Dict[str, str]] = None


class ComputeStep:
Expand Down
240 changes: 211 additions & 29 deletions datapipe/meta/sql_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,23 +730,35 @@ def build_changed_idx_sql_v1(
order_by: Optional[List[str]] = None,
order: Literal["asc", "desc"] = "asc",
run_config: Optional[RunConfig] = None, # TODO remove
additional_columns: Optional[List[str]] = None,
) -> Tuple[Iterable[str], Any]:
"""
Args:
additional_columns: Дополнительные колонки для включения в результат (для filtered join)
"""
if additional_columns is None:
additional_columns = []

# Полный список колонок для SELECT (transform_keys + additional_columns)
all_select_keys = list(transform_keys) + additional_columns

all_input_keys_counts: Dict[str, int] = {}
for col in itertools.chain(*[inp.dt.primary_schema for inp in input_dts]):
all_input_keys_counts[col.name] = all_input_keys_counts.get(col.name, 0) + 1

inp_ctes = []
for inp in input_dts:
# Используем all_select_keys для включения дополнительных колонок
keys, cte = inp.dt.meta_table.get_agg_cte(
transform_keys=transform_keys,
transform_keys=all_select_keys,
filters_idx=filters_idx,
run_config=run_config,
)
inp_ctes.append(ComputeInputCTE(cte=cte, keys=keys, join_type=inp.join_type))

agg_of_aggs = _make_agg_of_agg(
ds=ds,
transform_keys=transform_keys,
transform_keys=all_select_keys,
ctes=inp_ctes,
agg_col="update_ts",
)
Expand All @@ -771,12 +783,14 @@ def build_changed_idx_sql_v1(
else: # len(transform_keys) > 1:
join_onclause_sql = sa.and_(*[agg_of_aggs.c[key] == out.c[key] for key in transform_keys])

# Важно: Включаем все колонки (transform_keys + additional_columns)
sql = (
sa.select(
# Нам нужно выбирать хотя бы что-то, чтобы не было ошибки при
# пустом transform_keys
sa.literal(1).label("_datapipe_dummy"),
*[sa.func.coalesce(agg_of_aggs.c[key], out.c[key]).label(key) for key in transform_keys],
*[sa.func.coalesce(agg_of_aggs.c[key], out.c[key]).label(key) if key in transform_keys
else agg_of_aggs.c[key].label(key) for key in all_select_keys if key in agg_of_aggs.c],
)
.select_from(agg_of_aggs)
.outerjoin(
Expand Down Expand Up @@ -811,7 +825,7 @@ def build_changed_idx_sql_v1(
*[sa.asc(sa.column(k)) for k in order_by],
out.c.priority.desc().nullslast(),
)
return (transform_keys, sql)
return (all_select_keys, sql)


# Обратная совместимость: алиас для старой версии
Expand Down Expand Up @@ -851,13 +865,22 @@ def build_changed_idx_sql_v2(
order_by: Optional[List[str]] = None,
order: Literal["asc", "desc"] = "asc",
run_config: Optional[RunConfig] = None,
additional_columns: Optional[List[str]] = None,
) -> Tuple[Iterable[str], Any]:
"""
Новая версия build_changed_idx_sql, использующая offset'ы для оптимизации.

Вместо FULL OUTER JOIN всех входных таблиц, выбираем только записи с
update_ts > offset для каждой входной таблицы, затем объединяем через UNION.

Args:
additional_columns: Дополнительные колонки для включения в результат (для filtered join)
"""
if additional_columns is None:
additional_columns = []

# Полный список колонок для SELECT (transform_keys + additional_columns)
all_select_keys = list(transform_keys) + additional_columns

# 1. Получить все offset'ы одним запросом для избежания N+1
offsets = offset_table.get_offsets_for_transformation(transformation_id)
Expand All @@ -867,43 +890,194 @@ def build_changed_idx_sql_v2(
offsets[inp.dt.name] = 0.0

# 2. Построить CTE для каждой входной таблицы с фильтром по offset
# Для таблиц с join_keys нужен обратный JOIN к основной таблице
changed_ctes = []

# Сначала находим "основную" таблицу - первую без join_keys
primary_inp = None
for inp in input_dts:
if not inp.join_keys:
primary_inp = inp
break

for inp in input_dts:
tbl = inp.dt.meta_table.sql_table
keys = [k for k in transform_keys if k in inp.dt.primary_keys]

if len(keys) == 0:
# Разделяем ключи на те, что есть в meta table, и те, что нужны из data table
meta_cols = [c.name for c in tbl.columns]
keys_in_meta = [k for k in all_select_keys if k in meta_cols]
keys_in_data_only = [k for k in all_select_keys if k not in meta_cols]

if len(keys_in_meta) == 0:
continue

transform_key_cols: List[Any] = [sa.column(k) for k in keys]
offset = offsets[inp.dt.name]

# SELECT transform_keys FROM input_meta WHERE update_ts > offset OR delete_ts > offset
# Включаем как обновленные, так и удаленные записи
changed_sql: Any = sa.select(*transform_key_cols).select_from(tbl).where(
sa.or_(
tbl.c.update_ts > offset,
sa.and_(
tbl.c.delete_ts.isnot(None),
tbl.c.delete_ts > offset
# ОБРАТНЫЙ JOIN для справочных таблиц с join_keys
# Когда изменяется справочная таблица, нужно найти все записи основной таблицы,
# которые на нее ссылаются
if inp.join_keys and primary_inp and hasattr(primary_inp.dt.table_store, 'data_table'):
# Справочная таблица изменилась - нужен обратный JOIN к основной
primary_data_tbl = primary_inp.dt.table_store.data_table

# Строим SELECT для всех колонок из all_select_keys основной таблицы
primary_data_cols = [c.name for c in primary_data_tbl.columns]
select_cols = [
primary_data_tbl.c[k] if k in primary_data_cols else sa.literal(None).label(k)
for k in all_select_keys
]

# Обратный JOIN: primary_table.join_key = reference_table.id
# Например: posts.user_id = profiles.id
# inp.join_keys = {'user_id': 'id'} означает:
# 'user_id' - колонка в основной таблице (posts)
# 'id' - колонка в справочной таблице (profiles)
join_conditions = []
for primary_col, ref_col in inp.join_keys.items():
if primary_col in primary_data_cols and ref_col in meta_cols:
join_conditions.append(primary_data_tbl.c[primary_col] == tbl.c[ref_col])

if len(join_conditions) == 0:
# Не можем построить JOIN - пропускаем эту таблицу
continue

join_condition = sa.and_(*join_conditions) if len(join_conditions) > 1 else join_conditions[0]

# SELECT primary_cols FROM reference_meta
# JOIN primary_data ON primary.join_key = reference.id
# WHERE reference.update_ts > offset
changed_sql = sa.select(*select_cols).select_from(
tbl.join(primary_data_tbl, join_condition)
).where(
sa.or_(
tbl.c.update_ts > offset,
sa.and_(
tbl.c.delete_ts.isnot(None),
tbl.c.delete_ts > offset
)
)
)
)

# Применить filters_idx и run_config
changed_sql = sql_apply_filters_idx_to_subquery(changed_sql, keys, filters_idx)
changed_sql = sql_apply_runconfig_filter(changed_sql, tbl, inp.dt.primary_keys, run_config)
# Применить filters и group by
changed_sql = sql_apply_filters_idx_to_subquery(changed_sql, all_select_keys, filters_idx)
# run_config фильтры применяются к справочной таблице
changed_sql = sql_apply_runconfig_filter(changed_sql, tbl, inp.dt.primary_keys, run_config)

if len(select_cols) > 0:
changed_sql = changed_sql.group_by(*select_cols)

changed_ctes.append(changed_sql.cte(name=f"{inp.dt.name}_changes"))
continue

# Если все ключи есть в meta table - используем простой запрос
if len(keys_in_data_only) == 0:
select_cols = [sa.column(k) for k in keys_in_meta]

# SELECT keys FROM input_meta WHERE update_ts > offset OR delete_ts > offset
changed_sql = sa.select(*select_cols).select_from(tbl).where(
sa.or_(
tbl.c.update_ts > offset,
sa.and_(
tbl.c.delete_ts.isnot(None),
tbl.c.delete_ts > offset
)
)
)

# Применить filters_idx и run_config
changed_sql = sql_apply_filters_idx_to_subquery(changed_sql, keys_in_meta, filters_idx)
changed_sql = sql_apply_runconfig_filter(changed_sql, tbl, inp.dt.primary_keys, run_config)

if len(select_cols) > 0:
changed_sql = changed_sql.group_by(*select_cols)
else:
# Есть колонки только в data table - нужен JOIN с data table
# Проверяем что у table_store есть data_table (для TableStoreDB)
if not hasattr(inp.dt.table_store, 'data_table'):
# Fallback: если нет data_table, используем только meta keys
select_cols = [sa.column(k) for k in keys_in_meta]
changed_sql = sa.select(*select_cols).select_from(tbl).where(
sa.or_(
tbl.c.update_ts > offset,
sa.and_(
tbl.c.delete_ts.isnot(None),
tbl.c.delete_ts > offset
)
)
)
changed_sql = sql_apply_filters_idx_to_subquery(changed_sql, keys_in_meta, filters_idx)
changed_sql = sql_apply_runconfig_filter(changed_sql, tbl, inp.dt.primary_keys, run_config)
if len(select_cols) > 0:
changed_sql = changed_sql.group_by(*select_cols)
else:
# JOIN meta table с data table для получения дополнительных колонок
data_tbl = inp.dt.table_store.data_table

# Проверяем какие дополнительные колонки действительно есть в data table
data_cols_available = [c.name for c in data_tbl.columns]
keys_in_data_available = [k for k in keys_in_data_only if k in data_cols_available]

if len(keys_in_data_available) == 0:
# Fallback: если нужных колонок нет в data table, используем только meta keys
select_cols = [sa.column(k) for k in keys_in_meta]
changed_sql = sa.select(*select_cols).select_from(tbl).where(
sa.or_(
tbl.c.update_ts > offset,
sa.and_(
tbl.c.delete_ts.isnot(None),
tbl.c.delete_ts > offset
)
)
)
changed_sql = sql_apply_filters_idx_to_subquery(changed_sql, keys_in_meta, filters_idx)
changed_sql = sql_apply_runconfig_filter(changed_sql, tbl, inp.dt.primary_keys, run_config)
if len(select_cols) > 0:
changed_sql = changed_sql.group_by(*select_cols)
changed_ctes.append(changed_sql.cte(name=f"{inp.dt.name}_changes"))
continue

# SELECT meta_keys, data_keys FROM meta JOIN data ON primary_keys
# WHERE update_ts > offset OR delete_ts > offset
select_cols = [tbl.c[k] for k in keys_in_meta] + [data_tbl.c[k] for k in keys_in_data_available]

# Строим JOIN condition по primary keys
if len(inp.dt.primary_keys) == 1:
join_condition = tbl.c[inp.dt.primary_keys[0]] == data_tbl.c[inp.dt.primary_keys[0]]
else:
join_condition = sa.and_(*[
tbl.c[pk] == data_tbl.c[pk] for pk in inp.dt.primary_keys
])

changed_sql = sa.select(*select_cols).select_from(
tbl.join(data_tbl, join_condition)
).where(
sa.or_(
tbl.c.update_ts > offset,
sa.and_(
tbl.c.delete_ts.isnot(None),
tbl.c.delete_ts > offset
)
)
)

# Применить filters_idx и run_config
all_keys = keys_in_meta + keys_in_data_available
changed_sql = sql_apply_filters_idx_to_subquery(changed_sql, all_keys, filters_idx)
changed_sql = sql_apply_runconfig_filter(changed_sql, tbl, inp.dt.primary_keys, run_config)

if len(transform_key_cols) > 0:
changed_sql = changed_sql.group_by(*transform_key_cols)
if len(select_cols) > 0:
changed_sql = changed_sql.group_by(*select_cols)

changed_ctes.append(changed_sql.cte(name=f"{inp.dt.name}_changes"))

# 3. Получить записи с ошибками из TransformMetaTable
# Важно: error_records должен иметь все колонки из all_select_keys для UNION
# Для additional_columns используем NULL, так как их нет в transform meta table
tr_tbl = meta_table.sql_table
error_records_sql: Any = sa.select(
*[sa.column(k) for k in transform_keys]
).select_from(tr_tbl).where(
error_select_cols: List[Any] = [sa.column(k) for k in transform_keys] + [
sa.literal(None).label(k) for k in additional_columns
]
error_records_sql: Any = sa.select(*error_select_cols).select_from(tr_tbl).where(
sa.or_(
tr_tbl.c.is_success != True, # noqa
tr_tbl.c.process_ts.is_(None)
Expand All @@ -922,15 +1096,22 @@ def build_changed_idx_sql_v2(
# 4. Объединить все изменения и ошибки через UNION
if len(changed_ctes) == 0:
# Если нет входных таблиц с изменениями, используем только ошибки
union_sql: Any = sa.select(*[error_records_cte.c[k] for k in transform_keys]).select_from(error_records_cte)
union_sql: Any = sa.select(*[error_records_cte.c[k] for k in all_select_keys]).select_from(error_records_cte)
else:
# UNION всех изменений и ошибок
# Важно: UNION должен включать все колонки из all_select_keys
# Для отсутствующих колонок используем NULL
union_parts = []
for cte in changed_ctes:
union_parts.append(sa.select(*[cte.c[k] for k in transform_keys if k in cte.c]).select_from(cte))
# Для каждой колонки из all_select_keys: берем из CTE если есть, иначе NULL
select_cols = [
cte.c[k] if k in cte.c else sa.literal(None).label(k)
for k in all_select_keys
]
union_parts.append(sa.select(*select_cols).select_from(cte))

union_parts.append(
sa.select(*[error_records_cte.c[k] for k in transform_keys]).select_from(error_records_cte)
sa.select(*[error_records_cte.c[k] for k in all_select_keys]).select_from(error_records_cte)
)

union_sql = sa.union(*union_parts)
Expand All @@ -947,10 +1128,11 @@ def build_changed_idx_sql_v2(
join_onclause_sql = sa.and_(*[union_cte.c[key] == tr_tbl.c[key] for key in transform_keys])

# Используем `out` для консистентности с v1
# Важно: Включаем все колонки (transform_keys + additional_columns)
out = (
sa.select(
sa.literal(1).label("_datapipe_dummy"),
*[union_cte.c[k] for k in transform_keys]
*[union_cte.c[k] for k in all_select_keys if k in union_cte.c]
)
.select_from(union_cte)
.outerjoin(tr_tbl, onclause=join_onclause_sql)
Expand All @@ -973,7 +1155,7 @@ def build_changed_idx_sql_v2(
tr_tbl.c.priority.desc().nullslast(),
)

return (transform_keys, out)
return (all_select_keys, out)


TRANSFORM_INPUT_OFFSET_SCHEMA: DataSchema = [
Expand Down
Loading