airflow/serialization/serde.py (2 changes: 1 addition & 1 deletion)

@@ -319,7 +319,7 @@ def _is_pydantic(cls: Any) -> bool:
     Checking is done by attributes as it is significantly faster than
     using isinstance.
     """
-    return hasattr(cls, "__validators__") and hasattr(cls, "__fields__") and hasattr(cls, "dict")
+    return hasattr(cls, "model_config") and hasattr(cls, "model_fields") and hasattr(cls, "model_fields_set")


 def _register():
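For context on the new attribute probe: Pydantic v2 exposes model_config, model_fields, and model_fields_set on every BaseModel subclass, while the v1 dunders checked before (__validators__, __fields__) no longer exist in v2 models. A minimal sketch of why the check works, using a throwaway model (ExampleModel is hypothetical, not part of Airflow):

from pydantic import BaseModel  # requires pydantic>=2


class ExampleModel(BaseModel):  # hypothetical model, for illustration only
    name: str


def _is_pydantic(cls) -> bool:
    # Attribute checks are significantly faster than isinstance, and these
    # three attributes exist together only on Pydantic v2 model classes.
    return (
        hasattr(cls, "model_config")
        and hasattr(cls, "model_fields")
        and hasattr(cls, "model_fields_set")
    )


assert _is_pydantic(ExampleModel)
assert not _is_pydantic(dict)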
setup.cfg (2 changes: 1 addition & 1 deletion)

@@ -122,7 +122,7 @@ install_requires =
     pendulum>=2.0
     pluggy>=1.0
     psutil>=4.2.0
-    pydantic>=1.10.0
+    pydantic>=2.3.0
     pygments>=2.0.1
     pyjwt>=2.0.0
     python-daemon>=3.0.0
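setup.cfg only enforces the new floor at install time, so a runtime guard can be a useful sanity check before v2-only APIs are called. A minimal sketch, not part of this PR:

import pydantic

# pydantic.VERSION is present in both v1 and v2; the major component tells
# us whether v2-only APIs such as model_validate_json are available.
if int(pydantic.VERSION.split(".")[0]) < 2:
    raise RuntimeError(f"pydantic>=2.3.0 expected, found {pydantic.VERSION}")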
setup.py (6 changes: 5 additions & 1 deletion)

@@ -409,7 +409,11 @@ def write_version(filename: str = str(AIRFLOW_SOURCES_ROOT / "airflow" / "git_ve
 
 _devel_only_amazon = [
     "aws_xray_sdk",
-    "moto[cloudformation, glue]>=4.0",
+    "moto[glue]>=4.0",
+    # TODO: Remove the two below after https://github.com/aws/serverless-application-model/pull/3282
+    # gets released and add back "cloudformation" extra to moto above
+    "openapi-spec-validator >=0.2.8",
+    "jsonschema>=3.0",
     f"mypy-boto3-rds>={_MIN_BOTO3_VERSION}",
     f"mypy-boto3-redshift-data>={_MIN_BOTO3_VERSION}",
     f"mypy-boto3-s3>={_MIN_BOTO3_VERSION}",
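The two direct pins stand in for dependencies that moto's "cloudformation" extra previously pulled in transitively. A small sketch to confirm both resolve at the pinned floors (purely illustrative, not part of this PR):

from importlib.metadata import version

# Distribution names as pinned above; both should satisfy the new floors.
print(version("openapi-spec-validator"))  # expected >= 0.2.8
print(version("jsonschema"))              # expected >= 3.0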
tests/serialization/test_pydantic_models.py (14 changes: 6 additions & 8 deletions)

@@ -17,8 +17,6 @@
 # under the License.
 from __future__ import annotations
 
-from pydantic import parse_raw_as
-
 from airflow.jobs.job import Job
 from airflow.jobs.local_task_job_runner import LocalTaskJobRunner
 from airflow.models.dataset import (
@@ -49,7 +47,7 @@ def test_serializing_pydantic_task_instance(session, create_task_instance):
     json_string = pydantic_task_instance.json()
     print(json_string)
 
-    deserialized_model = parse_raw_as(TaskInstancePydantic, json_string)
+    deserialized_model = TaskInstancePydantic.model_validate_json(json_string)
     assert deserialized_model.dag_id == dag_id
     assert deserialized_model.state == State.RUNNING
     assert deserialized_model.try_number == ti.try_number
@@ -68,7 +66,7 @@ def test_serializing_pydantic_dagrun(session, create_task_instance):
     json_string = pydantic_dag_run.json()
     print(json_string)
 
-    deserialized_model = parse_raw_as(DagRunPydantic, json_string)
+    deserialized_model = DagRunPydantic.model_validate_json(json_string)
     assert deserialized_model.dag_id == dag_id
     assert deserialized_model.state == State.RUNNING
 
@@ -85,7 +83,7 @@ def test_serializing_pydantic_local_task_job(session, create_task_instance):
     json_string = pydantic_job.json()
     print(json_string)
 
-    deserialized_model = parse_raw_as(JobPydantic, json_string)
+    deserialized_model = JobPydantic.model_validate_json(json_string)
     assert deserialized_model.dag_id == dag_id
     assert deserialized_model.job_type == "LocalTaskJob"
     assert deserialized_model.state == State.RUNNING
@@ -139,17 +137,17 @@ def test_serializing_pydantic_dataset_event(session, create_task_instance, creat
     json_string_dr = pydantic_dag_run.json()
     print(json_string_dr)
 
-    deserialized_model1 = parse_raw_as(DatasetEventPydantic, json_string1)
+    deserialized_model1 = DatasetEventPydantic.model_validate_json(json_string1)
     assert deserialized_model1.dataset.id == 1
     assert deserialized_model1.dataset.uri == "one"
     assert len(deserialized_model1.dataset.consuming_dags) == 1
     assert len(deserialized_model1.dataset.producing_tasks) == 1
 
-    deserialized_model2 = parse_raw_as(DatasetEventPydantic, json_string2)
+    deserialized_model2 = DatasetEventPydantic.model_validate_json(json_string2)
     assert deserialized_model2.dataset.id == 2
     assert deserialized_model2.dataset.uri == "two"
     assert len(deserialized_model2.dataset.consuming_dags) == 0
     assert len(deserialized_model2.dataset.producing_tasks) == 0
 
-    deserialized_dr = parse_raw_as(DagRunPydantic, json_string_dr)
+    deserialized_dr = DagRunPydantic.model_validate_json(json_string_dr)
     assert len(deserialized_dr.consumed_dataset_events) == 3
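The pattern across every hunk above is the same: v1's standalone pydantic.parse_raw_as(Model, data) helper is replaced by the v2 classmethod Model.model_validate_json(data). A minimal sketch of the round trip with a hypothetical model (TinyModel is not an Airflow class):

from pydantic import BaseModel


class TinyModel(BaseModel):  # hypothetical stand-in for TaskInstancePydantic etc.
    dag_id: str
    try_number: int


# v1 style (removed): deserialized = parse_raw_as(TinyModel, json_string)
json_string = TinyModel(dag_id="example", try_number=1).model_dump_json()
deserialized = TinyModel.model_validate_json(json_string)
assert deserialized.dag_id == "example"

Note that .json(), still used on the serializing side in these tests, works under Pydantic v2 but is deprecated in favour of model_dump_json().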