Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Remove pickled data from xcom table.

Revision ID: eed27faa34e3
Revises: 9fc3fc5de720
Create Date: 2024-11-18 18:41:50.849514

"""

from __future__ import annotations

import sqlalchemy as sa
from alembic import op
from sqlalchemy import text
from sqlalchemy.dialects.mysql import LONGBLOB

from airflow.migrations.db_types import TIMESTAMP, StringID

revision = "eed27faa34e3"
down_revision = "9fc3fc5de720"
branch_labels = None
depends_on = None
airflow_version = "3.0.0"


def upgrade():
"""Apply Remove pickled data from xcom table."""
# Summary of the change:
# 1. Create an archived table (`_xcom_archive`) to store the current "pickled" data in the xcom table
# 2. Extract and archive the pickled data using the condition
# 3. Delete the pickled data from the xcom table so that we can update the column type
# 4. Update the XCom.value column type to JSON from LargeBinary/LongBlob

conn = op.get_bind()
dialect = conn.dialect.name

# Create an archived table to store the current data
op.create_table(
"_xcom_archive",
sa.Column("dag_run_id", sa.Integer(), nullable=False, primary_key=True),
sa.Column("task_id", StringID(length=250), nullable=False, primary_key=True),
sa.Column("map_index", sa.Integer(), nullable=False, server_default=sa.text("-1"), primary_key=True),
sa.Column("key", StringID(length=512), nullable=False, primary_key=True),
sa.Column("dag_id", StringID(length=250), nullable=False),
sa.Column("run_id", StringID(length=250), nullable=False),
sa.Column("value", sa.LargeBinary().with_variant(LONGBLOB, "mysql"), nullable=True),
sa.Column("timestamp", TIMESTAMP(), nullable=False),
sa.PrimaryKeyConstraint("dag_run_id", "task_id", "map_index", "key"),
if_not_exists=True,
)

# Condition to detect pickled data for different databases
condition_templates = {
"postgresql": "get_byte(value, 0) = 128",
"mysql": "HEX(SUBSTRING(value, 1, 1)) = '80'",
"sqlite": "substr(value, 1, 1) = char(128)",
}

condition = condition_templates.get(dialect)
if not condition:
raise RuntimeError(f"Unsupported dialect: {dialect}")

# Key is a reserved keyword in MySQL, so we need to quote it
quoted_key = conn.dialect.identifier_preparer.quote("key")

# Archive pickled data using the condition
conn.execute(
text(
f"""
INSERT INTO _xcom_archive (dag_run_id, task_id, map_index, {quoted_key}, dag_id, run_id, value, timestamp)
SELECT dag_run_id, task_id, map_index, {quoted_key}, dag_id, run_id, value, timestamp
FROM xcom
WHERE value IS NOT NULL AND {condition}
"""
)
)

# Delete the pickled data from the xcom table so that we can update the column type
conn.execute(text(f"DELETE FROM xcom WHERE value IS NOT NULL AND {condition}"))

# Update the value column type to JSON
if dialect == "postgresql":
op.execute(
"""
ALTER TABLE xcom
ALTER COLUMN value TYPE JSONB
USING CASE
WHEN value IS NOT NULL THEN CAST(CONVERT_FROM(value, 'UTF8') AS JSONB)
ELSE NULL
END
"""
)
elif dialect == "mysql":
op.add_column("xcom", sa.Column("value_json", sa.JSON(), nullable=True))
op.execute("UPDATE xcom SET value_json = CAST(value AS CHAR CHARACTER SET utf8mb4)")
op.drop_column("xcom", "value")
op.alter_column("xcom", "value_json", existing_type=sa.JSON(), new_column_name="value")
elif dialect == "sqlite":
# Rename the existing `value` column to `value_old`
with op.batch_alter_table("xcom", schema=None) as batch_op:
batch_op.alter_column("value", new_column_name="value_old")

# Add the new `value` column with JSON type
with op.batch_alter_table("xcom", schema=None) as batch_op:
batch_op.add_column(sa.Column("value", sa.JSON(), nullable=True))

# Migrate data from `value_old` to `value`
conn.execute(
text(
"""
UPDATE xcom
SET value = json(CAST(value_old AS TEXT))
WHERE value_old IS NOT NULL
"""
)
)

# Drop the old `value_old` column
with op.batch_alter_table("xcom", schema=None) as batch_op:
batch_op.drop_column("value_old")


def downgrade():
"""Unapply Remove pickled data from xcom table."""
conn = op.get_bind()
dialect = conn.dialect.name

# Revert the value column back to LargeBinary
if dialect == "postgresql":
op.execute(
"""
ALTER TABLE xcom
ALTER COLUMN value TYPE BYTEA
USING CASE
WHEN value IS NOT NULL THEN CONVERT_TO(value::TEXT, 'UTF8')
ELSE NULL
END
"""
)
elif dialect == "mysql":
op.add_column("xcom", sa.Column("value_blob", LONGBLOB, nullable=True))
op.execute("UPDATE xcom SET value_blob = CAST(value AS BINARY);")
op.drop_column("xcom", "value")
op.alter_column("xcom", "value_blob", existing_type=LONGBLOB, new_column_name="value")

elif dialect == "sqlite":
with op.batch_alter_table("xcom", schema=None) as batch_op:
batch_op.alter_column("value", new_column_name="value_old")

with op.batch_alter_table("xcom", schema=None) as batch_op:
batch_op.add_column(sa.Column("value", sa.LargeBinary, nullable=True))

conn.execute(
text(
"""
UPDATE xcom
SET value = CAST(value_old AS BLOB)
WHERE value_old IS NOT NULL
"""
)
)

with op.batch_alter_table("xcom", schema=None) as batch_op:
batch_op.drop_column("value_old")
14 changes: 8 additions & 6 deletions airflow/models/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,17 @@
from typing import TYPE_CHECKING, Any, Iterable, cast

from sqlalchemy import (
JSON,
Column,
ForeignKeyConstraint,
Index,
Integer,
LargeBinary,
PrimaryKeyConstraint,
String,
delete,
select,
text,
)
from sqlalchemy.dialects.mysql import LONGBLOB
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import Query, reconstructor, relationship

Expand Down Expand Up @@ -80,7 +79,7 @@ class BaseXCom(TaskInstanceDependencies, LoggingMixin):
dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
run_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)

value = Column(LargeBinary().with_variant(LONGBLOB, "mysql"))
value = Column(JSON)
timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False)

__table_args__ = (
Expand Down Expand Up @@ -453,9 +452,12 @@ def serialize_value(
dag_id: str | None = None,
run_id: str | None = None,
map_index: int | None = None,
) -> Any:
) -> str:
"""Serialize XCom value to JSON str."""
return json.dumps(value, cls=XComEncoder).encode("UTF-8")
try:
return json.dumps(value, cls=XComEncoder)
except (ValueError, TypeError):
raise ValueError("XCom value must be JSON serializable")

@staticmethod
def _deserialize_value(result: XCom, orm: bool) -> Any:
Expand All @@ -466,7 +468,7 @@ def _deserialize_value(result: XCom, orm: bool) -> Any:
if result.value is None:
return None

return json.loads(result.value.decode("UTF-8"), cls=XComDecoder, object_hook=object_hook)
return json.loads(result.value, cls=XComDecoder, object_hook=object_hook)

@staticmethod
def deserialize_value(result: XCom) -> Any:
Expand Down
2 changes: 1 addition & 1 deletion airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class MappedClassProtocol(Protocol):
"2.9.2": "686269002441",
"2.10.0": "22ed7efa9da2",
"2.10.3": "5f2621c13b39",
"3.0.0": "9fc3fc5de720",
"3.0.0": "eed27faa34e3",
}


Expand Down
2 changes: 1 addition & 1 deletion airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3861,7 +3861,7 @@ class XComModelView(AirflowModelView):
permissions.ACTION_CAN_ACCESS_MENU,
]

search_columns = ["key", "value", "timestamp", "dag_id", "task_id", "run_id", "logical_date"]
search_columns = ["key", "timestamp", "dag_id", "task_id", "run_id", "logical_date"]
list_columns = ["key", "value", "timestamp", "dag_id", "task_id", "run_id", "map_index", "logical_date"]
base_order = ("dag_run_id", "desc")

Expand Down
2 changes: 1 addition & 1 deletion docs/apache-airflow/img/airflow_erd.sha256
Original file line number Diff line number Diff line change
@@ -1 +1 @@
028d2fec22a15bbf5794e2fc7522eaf880a8b6293ead484780ef1a14e6cd9b48
7748eec981f977cc97b852d1fe982aebe24ec2d090ae8493a65cea101f9d42a5
2 changes: 1 addition & 1 deletion docs/apache-airflow/img/airflow_erd.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 3 additions & 1 deletion docs/apache-airflow/migrations-ref.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ Here's the list of all the Database Migrations that are executed via when you ru
+-------------------------+------------------+-------------------+--------------------------------------------------------------+
| Revision ID | Revises ID | Airflow Version | Description |
+=========================+==================+===================+==============================================================+
| ``9fc3fc5de720`` (head) | ``2b47dc6bc8df`` | ``3.0.0`` | Add references between assets and triggers. |
| ``eed27faa34e3`` (head) | ``9fc3fc5de720`` | ``3.0.0`` | Remove pickled data from xcom table. |
+-------------------------+------------------+-------------------+--------------------------------------------------------------+
| ``9fc3fc5de720`` | ``2b47dc6bc8df`` | ``3.0.0`` | Add references between assets and triggers. |
+-------------------------+------------------+-------------------+--------------------------------------------------------------+
| ``2b47dc6bc8df`` | ``d03e4a635aa3`` | ``3.0.0`` | add dag versioning. |
+-------------------------+------------------+-------------------+--------------------------------------------------------------+
Expand Down
4 changes: 4 additions & 0 deletions newsfragments/aip-72.significant.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,7 @@ As part of this change the following breaking changes have occurred:

If you still need to use pickling, you can use a custom XCom backend that stores references in the metadata DB and
the pickled data can be stored in a separate storage like S3.

The ``value`` field in the XCom table has been changed to a ``JSON`` type via DB migration. The XCom records that
contains pickled data are archived in the ``_xcom_archive`` table. You can safely drop this table if you don't need
the data anymore.
22 changes: 17 additions & 5 deletions providers/src/airflow/providers/common/io/xcom/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
from urllib.parse import urlsplit

import fsspec.utils
from packaging.version import Version

from airflow import __version__ as airflow_version
from airflow.configuration import conf
from airflow.io.path import ObjectStoragePath
from airflow.models.xcom import BaseXCom
Expand All @@ -41,6 +43,10 @@
SECTION = "common.io"


AIRFLOW_VERSION = Version(airflow_version)
AIRFLOW_V_3_0_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("3.0.0")


def _get_compression_suffix(compression: str) -> str:
"""
Return the compression suffix for the given compression.
Expand Down Expand Up @@ -103,7 +109,7 @@ def _get_full_path(data: str) -> ObjectStoragePath:
raise ValueError(f"Not a valid url: {data}")

@staticmethod
def serialize_value(
def serialize_value( # type: ignore[override]
value: T,
*,
key: str | None = None,
Expand All @@ -114,16 +120,22 @@ def serialize_value(
) -> bytes | str:
# we will always serialize ourselves and not by BaseXCom as the deserialize method
# from BaseXCom accepts only XCom objects and not the value directly
s_val = json.dumps(value, cls=XComEncoder).encode("utf-8")
s_val = json.dumps(value, cls=XComEncoder)
s_val_encoded = s_val.encode("utf-8")

if compression := _get_compression():
suffix = f".{_get_compression_suffix(compression)}"
else:
suffix = ""

threshold = _get_threshold()
if threshold < 0 or len(s_val) < threshold: # Either no threshold or value is small enough.
return s_val
if threshold < 0 or len(s_val_encoded) < threshold: # Either no threshold or value is small enough.
if AIRFLOW_V_3_0_PLUS:
return s_val
else:
# TODO: Remove this branch once we drop support for Airflow 2
# This is for Airflow 2.10 where the value is expected to be bytes
return s_val_encoded

base_path = _get_base_path()
while True: # Safeguard against collisions.
Expand All @@ -138,7 +150,7 @@ def serialize_value(
p.parent.mkdir(parents=True, exist_ok=True)

with p.open(mode="wb", compression=compression) as f:
f.write(s_val)
f.write(s_val_encoded)
return BaseXCom.serialize_value(str(p))

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion tests/api_connexion/endpoints/test_xcom_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,7 @@ def test_handle_limit_offset(self, query_params, expected_xcom_ids):
xcom = XCom(
dag_run_id=dagrun.id,
key=f"TEST_XCOM_KEY{i}",
value=b"null",
value="null",
run_id=self.run_id,
task_id=self.task_id,
dag_id=self.dag_id,
Expand Down
6 changes: 3 additions & 3 deletions tests/models/test_xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,14 @@ def test_resolve_xcom_class(self):
def test_resolve_xcom_class_fallback_to_basexcom(self):
cls = resolve_xcom_backend()
assert issubclass(cls, BaseXCom)
assert cls.serialize_value([1]) == b"[1]"
assert cls.serialize_value([1]) == "[1]"

@conf_vars({("core", "xcom_backend"): "to be removed"})
def test_resolve_xcom_class_fallback_to_basexcom_no_config(self):
conf.remove_option("core", "xcom_backend")
cls = resolve_xcom_backend()
assert issubclass(cls, BaseXCom)
assert cls.serialize_value([1]) == b"[1]"
assert cls.serialize_value([1]) == "[1]"

@mock.patch("airflow.models.xcom.XCom.orm_deserialize_value")
def test_xcom_init_on_load_uses_orm_deserialize_value(self, mock_orm_deserialize):
Expand Down Expand Up @@ -182,7 +182,7 @@ def serialize_value(
run_id=run_id,
map_index=map_index,
)
return json.dumps(value).encode("utf-8")
return json.dumps(value)

get_import.return_value = CurrentSignatureXCom

Expand Down