From a1995e1aca4b678855dd8dd9d477ffc38c42f910 Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 16 Jun 2025 13:14:05 -0400 Subject: [PATCH 1/8] Add MySQL vector writer. --- .../schemaio-expansion-service/build.gradle | 2 + .../apache_beam/ml/rag/ingestion/cloudsql.py | 61 +- .../ml/rag/ingestion/cloudsql_it_test.py | 1056 ++++++++++++++--- .../apache_beam/ml/rag/ingestion/mysql.py | 220 ++++ .../ml/rag/ingestion/mysql_common.py | 454 +++++++ .../ml/rag/ingestion/test_utils.py | 22 + 6 files changed, 1674 insertions(+), 141 deletions(-) create mode 100644 sdks/python/apache_beam/ml/rag/ingestion/mysql.py create mode 100644 sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py diff --git a/sdks/java/extensions/schemaio-expansion-service/build.gradle b/sdks/java/extensions/schemaio-expansion-service/build.gradle index 15873d58e615..12ee92a9e109 100644 --- a/sdks/java/extensions/schemaio-expansion-service/build.gradle +++ b/sdks/java/extensions/schemaio-expansion-service/build.gradle @@ -64,6 +64,8 @@ dependencies { permitUnusedDeclared 'com.google.cloud:alloydb-jdbc-connector:1.2.0' implementation 'com.google.cloud.sql:postgres-socket-factory:1.25.0' permitUnusedDeclared 'com.google.cloud.sql:postgres-socket-factory:1.25.0' + implementation 'com.google.cloud.sql:mysql-socket-factory-connector-j-8:1.25.0' + permitUnusedDeclared 'com.google.cloud.sql:mysql-socket-factory-connector-j-8:1.25.0' testImplementation library.java.junit testImplementation library.java.mockito_core runtimeOnly ("org.xerial:sqlite-jdbc:3.49.1.0") diff --git a/sdks/python/apache_beam/ml/rag/ingestion/cloudsql.py b/sdks/python/apache_beam/ml/rag/ingestion/cloudsql.py index 69ead961a763..d3710a7f70a4 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/cloudsql.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/cloudsql.py @@ -21,12 +21,12 @@ from typing import List from typing import Optional +from apache_beam.ml.rag.ingestion import mysql +from apache_beam.ml.rag.ingestion import mysql_common +from apache_beam.ml.rag.ingestion import postgres +from apache_beam.ml.rag.ingestion import postgres_common from apache_beam.ml.rag.ingestion.jdbc_common import ConnectionConfig from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig -from apache_beam.ml.rag.ingestion.postgres import ColumnSpecsBuilder -from apache_beam.ml.rag.ingestion.postgres import PostgresVectorWriterConfig -from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpec -from apache_beam.ml.rag.ingestion.postgres_common import ConflictResolution @dataclass @@ -138,7 +138,7 @@ def from_base_config(cls, config: LanguageConnectorConfig): return cls(**asdict(config)) -class CloudSQLPostgresVectorWriterConfig(PostgresVectorWriterConfig): +class CloudSQLPostgresVectorWriterConfig(postgres.PostgresVectorWriterConfig): def __init__( self, connection_config: LanguageConnectorConfig, @@ -146,10 +146,11 @@ def __init__( *, # pylint: disable=dangerous-default-value write_config: WriteConfig = WriteConfig(), - column_specs: List[ColumnSpec] = ColumnSpecsBuilder.with_defaults().build( - ), - conflict_resolution: Optional[ConflictResolution] = ConflictResolution( - on_conflict_fields=[], action='IGNORE')): + column_specs: List[postgres_common.ColumnSpec] = postgres_common. + ColumnSpecsBuilder.with_defaults().build(), + conflict_resolution: Optional[ + postgres_common.ConflictResolution] = postgres_common. + ConflictResolution(on_conflict_fields=[], action='IGNORE')): """Configuration for writing vectors to ClouSQL Postgres. 
Supports flexible schema configuration through column specifications and @@ -218,3 +219,45 @@ def __init__( table_name=table_name, column_specs=column_specs, conflict_resolution=conflict_resolution) + + +@dataclass +class _MySQLConnectorConfig(LanguageConnectorConfig): + def to_jdbc_url(self) -> str: + """Convert options to a properly formatted MySQL JDBC URL.""" + return self._build_jdbc_url( + socketFactory="com.google.cloud.sql.mysql.SocketFactory", + database_type="mysql") + + def additional_jdbc_args(self) -> Dict[str, List[Any]]: + return { + 'classpath': [ + "mysql:mysql-connector-java:8.0.22", + "com.google.cloud.sql:mysql-socket-factory-connector-j-8:1.25.0" + ] + } + + @classmethod + def from_base_config(cls, config: LanguageConnectorConfig): + return cls(**asdict(config)) + + +class CloudSQLMySQLVectorWriterConfig(mysql.MySQLVectorWriterConfig): + def __init__( + self, + connection_config: LanguageConnectorConfig, + table_name: str, + *, + write_config: WriteConfig = WriteConfig(), + # pylint: disable=dangerous-default-value + column_specs: List[mysql_common.ColumnSpec] = mysql_common. + ColumnSpecsBuilder.with_defaults().build(), + conflict_resolution: Optional[mysql_common.ConflictResolution] = None): + self.connector_config = _MySQLConnectorConfig.from_base_config( + connection_config) + super().__init__( + connection_config=self.connector_config.to_connection_config(), + write_config=write_config, + table_name=table_name, + column_specs=column_specs, + conflict_resolution=conflict_resolution) diff --git a/sdks/python/apache_beam/ml/rag/ingestion/cloudsql_it_test.py b/sdks/python/apache_beam/ml/rag/ingestion/cloudsql_it_test.py index 959e4cadb137..7ae49ba51823 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/cloudsql_it_test.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/cloudsql_it_test.py @@ -15,209 +15,1001 @@ # limitations under the License. 
# +import json import logging import os import secrets import time import unittest +from dataclasses import dataclass +from typing import Any +from typing import List +from typing import Literal +from typing import Optional import pytest import sqlalchemy from google.cloud.sql.connector import Connector +from parameterized import parameterized from sqlalchemy import text import apache_beam as beam from apache_beam.io.jdbc import ReadFromJdbc +from apache_beam.ml.rag.ingestion import mysql_common +from apache_beam.ml.rag.ingestion import postgres_common from apache_beam.ml.rag.ingestion import test_utils from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteTransform +from apache_beam.ml.rag.ingestion.cloudsql import CloudSQLMySQLVectorWriterConfig from apache_beam.ml.rag.ingestion.cloudsql import CloudSQLPostgresVectorWriterConfig from apache_beam.ml.rag.ingestion.cloudsql import LanguageConnectorConfig +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content +from apache_beam.ml.rag.types import Embedding from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to +_LOGGER = logging.getLogger(__name__) + + +@dataclass +class DatabaseTestConfig: + """Database-specific test configuration.""" + database_type: Literal["postgresql", "mysql"] + writer_config_class: type + jdbc_driver: str + connector_module: Literal["pg8000", "pymysql"] + table_prefix: str + + password_env_var: str + username: str + database: str + instance_uri: str + + vector_column_type: str + metadata_column_type: str + common_module: Any + id_column_type: str = "VARCHAR(255)" + + +class DatabaseTestHelper: + """Helper class to manage database setup, connections, and operations.""" + def __init__(self, db_config: DatabaseTestConfig, table_suffix: str): + self.db_config = db_config + self.table_suffix = table_suffix + self.connector = None + self.engine = None + self.connection_config = None + + self.default_table_name = f"{db_config.table_prefix}{table_suffix}" + self.custom_table_name = f"{db_config.table_prefix}_custom_{table_suffix}" + self.metadata_conflicts_table = f"{db_config.table_prefix}_meta_conf_" \ + f"{table_suffix}" + + self._setup_read_queries() + + def _setup_read_queries(self): + if self.db_config.database_type == "postgresql": + self.read_queries = { + self.default_table_name: f""" + SELECT + CAST(id AS VARCHAR(255)), + CAST(content AS VARCHAR(255)), + CAST(embedding AS text), + CAST(metadata AS text) + FROM {self.default_table_name} + """, + self.custom_table_name: f""" + SELECT + CAST(custom_id AS VARCHAR(255)), + CAST(embedding_vec AS text), + CAST(content_col AS VARCHAR(255)), + CAST(metadata AS text) + FROM {self.custom_table_name} + ORDER BY custom_id + """, + self.metadata_conflicts_table: f""" + SELECT + CAST(id AS VARCHAR(255)), + CAST(embedding AS text), + CAST(content AS VARCHAR(255)), + CAST(source AS VARCHAR(255)), + CAST(timestamp AS VARCHAR(255)) + FROM {self.metadata_conflicts_table} + ORDER BY timestamp, id + """ + } + elif self.db_config.database_type == "mysql": + self.read_queries = { + self.default_table_name: f""" + SELECT + CAST(id AS CHAR(255)) as id, + CAST(content AS CHAR(255)) as content, + vector_to_string(embedding) as embedding, + CAST(metadata AS CHAR(10000)) as metadata + FROM {self.default_table_name} + """, + self.custom_table_name: f""" + SELECT + CAST(custom_id AS CHAR(255)) as custom_id, + vector_to_string(embedding_vec) as 
embedding_vec, + CAST(content_col AS CHAR(255)) as content_col, + CAST(metadata AS CHAR(10000)) as metadata + FROM {self.custom_table_name} + ORDER BY custom_id + """, + self.metadata_conflicts_table: f""" + SELECT + CAST(id AS CHAR(255)) as id, + vector_to_string(embedding) as embedding, + CAST(content AS CHAR(255)) as content, + CAST(source AS CHAR(255)) as source, + CAST(timestamp AS CHAR(255)) as timestamp + FROM {self.metadata_conflicts_table} + ORDER BY timestamp, id + """ + } + + def get_read_query(self, table_name: str) -> str: + if table_name not in self.read_queries: + raise ValueError(f"No read query defined for table: {table_name}") + return self.read_queries[table_name] + + def setup_connection(self): + """Set up database connection and engine.""" + if not os.environ.get(self.db_config.password_env_var): + raise ValueError("Password environment variable not set.") + password = os.environ.get(self.db_config.password_env_var) + + self.connection_config = LanguageConnectorConfig( + username=self.db_config.username, + password=password, + database_name=self.db_config.database, + instance_name=self.db_config.instance_uri) + + self.connector = Connector(refresh_strategy="LAZY") -@pytest.mark.uses_gcp_java_expansion_service -@unittest.skipUnless( - os.environ.get('EXPANSION_JARS'), - "EXPANSION_JARS environment var is not provided, " - "indicating that jars have not been built") -@unittest.skipUnless( - os.environ.get('ALLOYDB_PASSWORD'), - "ALLOYDB_PASSWORD environment var is not provided") -class CloudSQLPostgresVectorWriterConfigTest(unittest.TestCase): - POSTGRES_TABLE_PREFIX = 'python_rag_postgres_' - - @classmethod - def _create_engine(cls): - """Create SQLAlchemy engine using Cloud SQL connector.""" def getconn(): - conn = cls.connector.connect( - cls.instance_uri, - "pg8000", - user=cls.username, - password=cls.password, - db=cls.database, + return self.connector.connect( + self.db_config.instance_uri, + self.db_config.connector_module, + user=self.db_config.username, + password=password, + db=self.db_config.database, ) - return conn - - engine = sqlalchemy.create_engine( - "postgresql+pg8000://", - creator=getconn, - ) - return engine - - @classmethod - def setUpClass(cls): - cls.database = os.environ.get('POSTGRES_DATABASE', 'postgres') - cls.username = os.environ.get('POSTGRES_USERNAME', 'postgres') - if not os.environ.get('ALLOYDB_PASSWORD'): - raise ValueError('ALLOYDB_PASSWORD env not set') - cls.password = os.environ.get('ALLOYDB_PASSWORD') - cls.instance_uri = os.environ.get( - 'POSTGRES_INSTANCE_URI', - 'apache-beam-testing:us-central1:beam-integration-tests') - - # Create unique table name suffix - cls.table_suffix = '%d%s' % (int(time.time()), secrets.token_hex(3)) - - # Setup database connection - cls.connector = Connector(refresh_strategy="LAZY") - cls.engine = cls._create_engine() - def skip_if_dataflow_runner(self): - if self._runner and "dataflowrunner" in self._runner.lower(): - self.skipTest( - "Skipping some tests on Dataflow Runner to avoid bloat and timeouts") + dialect = "postgresql+pg8000" \ + if self.db_config.database_type == "postgresql" else "mysql+pymysql" + self.engine = sqlalchemy.create_engine(f"{dialect}://", creator=getconn) - def setUp(self): - self.write_test_pipeline = TestPipeline(is_integration_test=True) - self.read_test_pipeline = TestPipeline(is_integration_test=True) - self._runner = type(self.read_test_pipeline.runner).__name__ + def create_all_tables(self): + if not self.engine: + raise ValueError("Engine not initialized. 
Call setup_connection() first.") - self.default_table_name = f"{self.POSTGRES_TABLE_PREFIX}" \ - f"{self.table_suffix}" + vector_type_large = self.db_config.vector_column_type.format( + size=test_utils.VECTOR_SIZE) + vector_type_small = self.db_config.vector_column_type.format(size=2) + metadata_type = self.db_config.metadata_column_type + id_type = self.db_config.id_column_type - # Create test table with self.engine.connect() as connection: - connection.execute( - text( - f""" + default_table_sql = f""" CREATE TABLE {self.default_table_name} ( - id TEXT PRIMARY KEY, - embedding VECTOR({test_utils.VECTOR_SIZE}), + id {id_type} PRIMARY KEY, + embedding {vector_type_large}, content TEXT, - metadata JSONB + metadata {metadata_type} ) - """)) - connection.commit() - _LOGGER = logging.getLogger(__name__) - _LOGGER.info("Created table %s", self.default_table_name) + """ + connection.execute(text(default_table_sql)) - def tearDown(self): - # Drop test table - with self.engine.connect() as connection: - connection.execute( - text(f"DROP TABLE IF EXISTS {self.default_table_name}")) + custom_table_sql = f""" + CREATE TABLE {self.custom_table_name} ( + custom_id {id_type} PRIMARY KEY, + embedding_vec {vector_type_small}, + content_col TEXT, + metadata {metadata_type} + ) + """ + connection.execute(text(custom_table_sql)) + + if self.db_config.database_type == "postgresql": + metadata_conflicts_sql = f""" + CREATE TABLE {self.metadata_conflicts_table} ( + id {id_type}, + source TEXT, + timestamp TIMESTAMP, + content TEXT, + embedding {vector_type_small}, + PRIMARY KEY (id), + UNIQUE (source, timestamp) + ) + """ + elif self.db_config.database_type == "mysql": + metadata_conflicts_sql = f""" + CREATE TABLE {self.metadata_conflicts_table} ( + id {id_type}, + source TEXT, + timestamp TIMESTAMP, + content TEXT, + embedding {vector_type_small}, + PRIMARY KEY (id), + UNIQUE KEY unique_source_timestamp (source(255), timestamp) + ) + """ + connection.execute(text(metadata_conflicts_sql)) connection.commit() - _LOGGER = logging.getLogger(__name__) - _LOGGER.info("Dropped table %s", self.default_table_name) - - @classmethod - def tearDownClass(cls): - if hasattr(cls, 'connector'): - cls.connector.close() - if hasattr(cls, 'engine'): - cls.engine.dispose() - - def test_language_connector(self): - """Test language connector.""" - self.skip_if_dataflow_runner() - connection_config = LanguageConnectorConfig( - username=self.username, - password=self.password, - database_name=self.database, - instance_name=self.instance_uri) - writer_config = CloudSQLPostgresVectorWriterConfig( - connection_config=connection_config, table_name=self.default_table_name) + def create_writer_config( + self, + table_name: Optional[str] = None, + column_specs=None, + conflict_resolution=None): + if not self.connection_config: + raise ValueError( + "Connection not initialized. 
Call setup_connection() first.") - # Create test chunks - num_records = 150 - sample_size = min(500, num_records // 2) - chunks = test_utils.ChunkTestUtils.get_expected_values(0, num_records) + table_name = table_name or self.default_table_name - self.write_test_pipeline.not_use_test_runner_api = True + kwargs = { + 'connection_config': self.connection_config, + 'table_name': table_name, + } - with self.write_test_pipeline as p: - _ = ( - p - | beam.Create(chunks) - | VectorDatabaseWriteTransform(writer_config)) + if column_specs is not None: + kwargs['column_specs'] = column_specs + if conflict_resolution is not None: + kwargs['conflict_resolution'] = conflict_resolution - self.read_test_pipeline.not_use_test_runner_api = True - read_query = f""" - SELECT - CAST(id AS VARCHAR(255)), - CAST(content AS VARCHAR(255)), - CAST(embedding AS text), - CAST(metadata AS text) - FROM {self.default_table_name} - """ + return self.db_config.writer_config_class(**kwargs) - with self.read_test_pipeline as p: - rows = ( - p - | ReadFromJdbc( - table_name=self.default_table_name, - driver_class_name="org.postgresql.Driver", - jdbc_url=writer_config.connector_config.to_connection_config( - ).jdbc_url, - username=self.username, - password=self.password, - query=read_query, - classpath=writer_config.connector_config.additional_jdbc_args() - ['classpath'])) + def cleanup(self): + if self.engine: + table_names = [ + self.default_table_name, + self.custom_table_name, + self.metadata_conflicts_table + ] + + try: + with self.engine.connect() as connection: + for table_name in table_names: + connection.execute(text(f"DROP TABLE IF EXISTS {table_name}")) + connection.commit() + _LOGGER.info( + "Dropped %s tables: %s", + self.db_config.database_type, + ', '.join(table_names)) + except Exception as e: + _LOGGER.warning( + "Error dropping %s tables: %s", self.db_config.database_type, e) + + if self.connector: + try: + self.connector.close() + except Exception as e: + _LOGGER.warning("Error closing connector: %s", e) + + if self.engine: + try: + self.engine.dispose() + except Exception as e: + _LOGGER.warning("Error disposing engine: %s", e) + + +class PipelineVerificationHelper: + """Helper class for common pipeline verification patterns.""" + @staticmethod + def build_jdbc_params(helper: DatabaseTestHelper, table_name: str) -> dict: + """Build JDBC parameters dictionary for ReadFromJdbc.""" + writer_config = helper.create_writer_config(table_name) + + return { + 'table_name': table_name, + 'driver_class_name': helper.db_config.jdbc_driver, + 'jdbc_url': writer_config.connector_config.to_connection_config(). 
+ jdbc_url, + 'username': helper.db_config.username, + 'password': helper.connection_config.password, + 'query': helper.get_read_query(table_name), + 'classpath': writer_config.connector_config.additional_jdbc_args() + ['classpath'] + } + @staticmethod + def verify_standard_operations( + pipeline, jdbc_params: dict, expected_chunks: List[Chunk]): + num_records = len(expected_chunks) + sample_size = min(500, num_records // 2) + + with pipeline as p: + rows = (p | ReadFromJdbc(**jdbc_params)) + + # Count verification count_result = rows | "Count All" >> beam.combiners.Count.Globally() assert_that(count_result, equal_to([num_records]), label='count_check') + # Hash verification chunks = (rows | "To Chunks" >> beam.Map(test_utils.row_to_chunk)) chunk_hashes = chunks | "Hash Chunks" >> beam.CombineGlobally( test_utils.HashingFn()) - assert_that( - chunk_hashes, - equal_to([test_utils.generate_expected_hash(num_records)]), - label='hash_check') + expected_hash = test_utils.generate_expected_hash(num_records) + assert_that(chunk_hashes, equal_to([expected_hash]), label='hash_check') - # Sample validation + # Sample validation - first N first_n = ( chunks | "Key on Index" >> beam.Map(test_utils.key_on_id) | f"Get First {sample_size}" >> beam.transforms.combiners.Top.Of( sample_size, key=lambda x: x[0], reverse=True) | "Remove Keys 1" >> beam.Map(lambda xs: [x[1] for x in xs])) - expected_first_n = test_utils.ChunkTestUtils.get_expected_values( - 0, sample_size) + expected_first_n = expected_chunks[:sample_size] assert_that( first_n, equal_to([expected_first_n]), label=f"first_{sample_size}_check") + # Sample validation - last N last_n = ( chunks | "Key on Index 2" >> beam.Map(test_utils.key_on_id) | f"Get Last {sample_size}" >> beam.transforms.combiners.Top.Of( sample_size, key=lambda x: x[0]) | "Remove Keys 2" >> beam.Map(lambda xs: [x[1] for x in xs])) - expected_last_n = test_utils.ChunkTestUtils.get_expected_values( - num_records - sample_size, num_records)[::-1] + expected_last_n = expected_chunks[-sample_size:][::-1] assert_that( last_n, equal_to([expected_last_n]), label=f"last_{sample_size}_check") +# Database configurations +POSTGRES_CONFIG = DatabaseTestConfig( + database_type="postgresql", + writer_config_class=CloudSQLPostgresVectorWriterConfig, + jdbc_driver="org.postgresql.Driver", + connector_module="pg8000", + table_prefix="python_rag_postgres_", + password_env_var="ALLOYDB_PASSWORD", + username="postgres", + database="postgres", + instance_uri="apache-beam-testing:us-central1:beam-integration-tests", + vector_column_type="VECTOR({size})", + metadata_column_type="JSONB", + common_module=postgres_common) + +MYSQL_CONFIG = DatabaseTestConfig( + database_type="mysql", + writer_config_class=CloudSQLMySQLVectorWriterConfig, + jdbc_driver="com.mysql.cj.jdbc.Driver", + connector_module="pymysql", + table_prefix="python_rag_mysql_", + password_env_var="ALLOYDB_PASSWORD", + username="mysql", + database="embeddings", + instance_uri="apache-beam-testing:us-central1:beam-integration-tests-mysql", + vector_column_type="VECTOR({size}) USING VARBINARY", + metadata_column_type="JSON", + common_module=mysql_common) + + +@pytest.mark.uses_gcp_java_expansion_service +@unittest.skipUnless( + os.environ.get('EXPANSION_JARS'), + "EXPANSION_JARS environment var is not provided, " + "indicating that jars have not been built") +class CloudSQLVectorWriterConfigTest(unittest.TestCase): + def setUp(self): + self.write_test_pipeline = TestPipeline(is_integration_test=True) + self.read_test_pipeline = 
TestPipeline(is_integration_test=True) + self.write_test_pipeline2 = TestPipeline(is_integration_test=True) + self.read_test_pipeline2 = TestPipeline(is_integration_test=True) + + self.write_test_pipeline.not_use_test_runner_api = True + self.read_test_pipeline.not_use_test_runner_api = True + self.write_test_pipeline2.not_use_test_runner_api = True + self.read_test_pipeline2.not_use_test_runner_api = True + self._runner = type(self.read_test_pipeline.runner).__name__ + + self.db_helpers = {} + self.table_suffix = '%d%s' % (int(time.time()), secrets.token_hex(3)) + + # Set up database helpers + for config in [POSTGRES_CONFIG, MYSQL_CONFIG]: + helper = DatabaseTestHelper(config, self.table_suffix) + helper.setup_connection() + helper.create_all_tables() + self.db_helpers[config.database_type] = helper + _LOGGER.info("Successfully set up %s database", config.database_type) + + def tearDown(self): + for helper in self.db_helpers.values(): + helper.cleanup() + + def skip_if_dataflow_runner(self): + if self._runner and "dataflowrunner" in self._runner.lower(): + self.skipTest( + "Skipping some tests on Dataflow Runner to avoid bloat and timeouts") + + @parameterized.expand([(POSTGRES_CONFIG), (MYSQL_CONFIG)]) + def test_default_config(self, db_config): + """Test basic write and read operations with default configuration. + + This test validates the most basic CloudSQL vector database functionality: + - Default table schema: id (VARCHAR), content (TEXT), embedding (VECTOR), + metadata (JSON/JSONB) + - Default column specifications (no customization) + - Default conflict resolution (IGNORE on primary key conflicts) + - Write chunks to database and read them back + - Verify data integrity through count, hash, and sample validation + """ + self.skip_if_dataflow_runner() + + helper = self.db_helpers[db_config.database_type] + num_records = 150 + + # Create test data + test_chunks = test_utils.ChunkTestUtils.get_expected_values(0, num_records) + + # Write test + writer_config = helper.create_writer_config() + self.write_test_pipeline.not_use_test_runner_api = True + with self.write_test_pipeline as p: + _ = ( + p | beam.Create(test_chunks) + | VectorDatabaseWriteTransform(writer_config)) + + # Read and verify + self.read_test_pipeline.not_use_test_runner_api = True + jdbc_params = PipelineVerificationHelper.build_jdbc_params( + helper, helper.default_table_name) + PipelineVerificationHelper.verify_standard_operations( + self.read_test_pipeline, jdbc_params, test_chunks) + + @parameterized.expand([ + (POSTGRES_CONFIG, "UPDATE", ["embedding", "content"]), + (MYSQL_CONFIG, "UPDATE", ["embedding", "content"]), + (POSTGRES_CONFIG, "IGNORE", None), + (MYSQL_CONFIG, "IGNORE", None), + (POSTGRES_CONFIG, "UPDATE_ALL", None), # Default update fields + (MYSQL_CONFIG, "UPDATE_ALL", None), + ]) + def test_conflict_resolution(self, db_config, action, update_fields): + """Test conflict resolution strategies when primary key conflicts occur. + + This test validates different approaches to handling duplicate primary + keys: + + UPDATE with specific fields: + - When duplicate ID encountered, update only specified fields (embedding, + content) + - Other fields (metadata) remain unchanged from original record + + IGNORE: + - When duplicate ID encountered, keep original record unchanged + + UPDATE_ALL (default update fields): + - When duplicate ID encountered, update ALL non-key fields + - This includes content, embedding, AND metadata + + Scenario for all strategies: + 1. Insert initial records + 2. 
Insert records with same IDs but different content/embeddings + 3. Verify final state matches expected conflict resolution behavior + """ + self.skip_if_dataflow_runner() + + helper = self.db_helpers[db_config.database_type] + num_records = 20 + + common_module = db_config.common_module + if action == "IGNORE": + if db_config.database_type == "mysql": + conflict_resolution = common_module.ConflictResolution( + action="IGNORE", primary_key_field="id") + else: + conflict_resolution = None # Default behavior for PostgreSQL + elif action == "UPDATE": + if db_config.database_type == "postgresql": + conflict_resolution = common_module.ConflictResolution( + on_conflict_fields="id", + action="UPDATE", + update_fields=update_fields) + else: + conflict_resolution = common_module.ConflictResolution( + action="UPDATE", update_fields=update_fields) + else: # UPDATE_ALL + if db_config.database_type == "postgresql": + conflict_resolution = common_module.ConflictResolution( + on_conflict_fields="id", action="UPDATE") + else: + conflict_resolution = common_module.ConflictResolution(action="UPDATE") + + initial_chunks = test_utils.ChunkTestUtils.get_expected_values( + 0, num_records) + writer_config = helper.create_writer_config( + conflict_resolution=conflict_resolution) + + self.write_test_pipeline.not_use_test_runner_api = True + with self.write_test_pipeline as p: + _ = ( + p | "Write Initial" >> beam.Create(initial_chunks) + | VectorDatabaseWriteTransform(writer_config)) + + # Write conflicting data + updated_chunks = test_utils.ChunkTestUtils.get_expected_values( + 0, num_records, content_prefix="Updated", seed_multiplier=2) + + self.write_test_pipeline2.not_use_test_runner_api = True + with self.write_test_pipeline2 as p: + _ = ( + p | "Write Conflicts" >> beam.Create(updated_chunks) + | VectorDatabaseWriteTransform(writer_config)) + + jdbc_params = PipelineVerificationHelper.build_jdbc_params( + helper, helper.default_table_name) + expected_chunks = updated_chunks if action != "IGNORE" else initial_chunks + + self.read_test_pipeline.not_use_test_runner_api = True + with self.read_test_pipeline as p: + rows = (p | ReadFromJdbc(**jdbc_params)) + + count_result = rows | "Count All" >> beam.combiners.Count.Globally() + assert_that(count_result, equal_to([num_records]), label='count_check') + + chunks = rows | "To Chunks" >> beam.Map(test_utils.row_to_chunk) + assert_that(chunks, equal_to(expected_chunks), label='chunks_check') + + @parameterized.expand([(POSTGRES_CONFIG), (MYSQL_CONFIG)]) + def test_custom_column_names_and_value_functions(self, db_config): + """Test completely custom column specifications with custom value + extraction. + + This test validates advanced customization of how chunk data is stored: + + Custom column names: + - custom_id (instead of 'id') + - embedding_vec (instead of 'embedding') + - content_col (instead of 'content') + + Custom value extraction functions: + - ID: Extract timestamp from metadata and prefix with "timestamp_" + - Content: Prefix content with its character length "10:actual_content" + - Embedding: Use custom embedding extraction function + + This tests the flexibility to completely reshape how chunk data maps + to database columns, useful for integrating with existing database schemas + or applying business-specific transformations. 
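+
+    For example, the chunk with id "3" in this test (content "content_3",
+    metadata timestamp "2024-02-02T03:00:00") is expected to land in the
+    custom columns as custom_id="timestamp_2024-02-02T03:00:00" and
+    content_col="9:content_3".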
+ """ + self.skip_if_dataflow_runner() + + helper = self.db_helpers[db_config.database_type] + num_records = 20 + common_module = db_config.common_module + + test_chunks = [ + Chunk( + id=str(i), + content=Content(text=f"content_{i}"), + embedding=Embedding(dense_embedding=[float(i), float(i + 1)]), + metadata={"timestamp": f"2024-02-02T{i:02d}:00:00"}) + for i in range(num_records) + ] + + chunk_embedding_fn = common_module.chunk_embedding_fn + specs = ( + common_module.ColumnSpecsBuilder().add_custom_column_spec( + common_module.ColumnSpec.text( + column_name="custom_id", + value_fn=lambda chunk: + f"timestamp_{chunk.metadata.get('timestamp', '')}") + ).add_custom_column_spec( + common_module.ColumnSpec.vector( + column_name="embedding_vec", + value_fn=chunk_embedding_fn)).add_custom_column_spec( + common_module.ColumnSpec.text( + column_name="content_col", + value_fn=lambda chunk: + f"{len(chunk.content.text)}:{chunk.content.text}")). + with_metadata_spec().build()) + + def custom_row_to_chunk(row): + timestamp = row.custom_id.split('timestamp_')[1] + i = int(timestamp.split('T')[1][:2]) + + embedding_list = [ + float(x) for x in row.embedding_vec.strip('[]').split(',') + ] + + content = row.content_col.split(':', 1)[1] + + return Chunk( + id=str(i), + content=Content(text=content), + embedding=Embedding(dense_embedding=embedding_list), + metadata=json.loads(row.metadata)) + + writer_config = helper.create_writer_config(helper.custom_table_name, specs) + self.write_test_pipeline.not_use_test_runner_api = True + with self.write_test_pipeline as p: + _ = ( + p | beam.Create(test_chunks) + | VectorDatabaseWriteTransform(writer_config)) + + jdbc_params = PipelineVerificationHelper.build_jdbc_params( + helper, helper.custom_table_name) + + self.read_test_pipeline.not_use_test_runner_api = True + with self.read_test_pipeline as p: + rows = (p | ReadFromJdbc(**jdbc_params)) + + count_result = rows | "Count All" >> beam.combiners.Count.Globally() + assert_that(count_result, equal_to([num_records]), label='count_check') + + chunks = rows | "To Chunks" >> beam.Map(custom_row_to_chunk) + assert_that(chunks, equal_to(test_chunks), label='chunks_check') + + @parameterized.expand([(POSTGRES_CONFIG), (MYSQL_CONFIG)]) + def test_custom_type_conversion_with_default_columns(self, db_config): + """Test custom type conversion and SQL typecasting with modified column + names. + + This test validates data type handling and database-specific SQL features: + + Type conversion: + - Convert string IDs to integers before storage + - Apply length-prefix transformation to content + + SQL typecasting (database-specific): + - PostgreSQL: Use ::text typecast for converted integers + - MySQL: Rely on automatic type conversion (no explicit typecast) + + Column name customization: + - Use custom names but with standard spec builders (not completely custom + functions) + + This tests the ability to adapt data types for database constraints + while maintaining the standard chunk-to-database mapping logic. 
+ """ + self.skip_if_dataflow_runner() + + helper = self.db_helpers[db_config.database_type] + num_records = 20 + common_module = db_config.common_module + + test_chunks = [ + Chunk( + id=str(i), + content=Content(text=f"content_{i}"), + embedding=Embedding(dense_embedding=[float(i), float(i + 1)]), + metadata={"timestamp": f"2024-02-02T{i:02d}:00:00"}) + for i in range(num_records) + ] + + if db_config.database_type == "postgresql": + specs = ( + common_module.ColumnSpecsBuilder().with_id_spec( + column_name="custom_id", + python_type=int, + convert_fn=lambda x: int(x), + sql_typecast="::text").with_content_spec( + column_name="content_col", + convert_fn=lambda x: f"{len(x)}:{x}" # Add length prefix + ).with_embedding_spec( + column_name="embedding_vec").with_metadata_spec().build()) + else: # MySQL + specs = ( + common_module.ColumnSpecsBuilder().with_id_spec( + column_name="custom_id", + python_type=int, + convert_fn=lambda x: int(x)).with_content_spec( + column_name="content_col", + convert_fn=lambda x: f"{len(x)}:{x}").with_embedding_spec( + column_name="embedding_vec").with_metadata_spec().build()) + + def type_conversion_row_to_chunk(row): + embedding_list = [ + float(x) for x in row.embedding_vec.strip('[]').split(',') + ] + + content = row.content_col.split(':', 1)[1] + + return Chunk( + id=row.custom_id, # custom_id is the converted ID field + content=Content(text=content), + embedding=Embedding(dense_embedding=embedding_list), + metadata=json.loads(row.metadata)) + + writer_config = helper.create_writer_config(helper.custom_table_name, specs) + self.write_test_pipeline.not_use_test_runner_api = True + with self.write_test_pipeline as p: + _ = ( + p | beam.Create(test_chunks) + | VectorDatabaseWriteTransform(writer_config)) + + jdbc_params = PipelineVerificationHelper.build_jdbc_params( + helper, helper.custom_table_name) + + self.read_test_pipeline.not_use_test_runner_api = True + with self.read_test_pipeline as p: + rows = (p | ReadFromJdbc(**jdbc_params)) + + count_result = rows | "Count All" >> beam.combiners.Count.Globally() + assert_that(count_result, equal_to([num_records]), label='count_check') + + chunks = rows | "To Chunks" >> beam.Map(type_conversion_row_to_chunk) + assert_that(chunks, equal_to(test_chunks), label='chunks_check') + + @parameterized.expand([(POSTGRES_CONFIG), (MYSQL_CONFIG)]) + def test_default_id_embedding_specs(self, db_config): + """Test minimal schema with only ID and embedding columns. + + This test validates the ability to create a minimal vector database + schema: + - Only stores id and embedding fields + - content and metadata columns are excluded from the table + - Tests that the system correctly handles missing/null fields + + Use case: When you only need vector similarity search without storing + the original content or metadata (perhaps stored elsewhere). + + Validation: + - Chunks written with content/metadata are stored with those fields as + null + - Reading back produces chunks with null content and empty metadata + - Vector similarity operations still work normally + """ + self.skip_if_dataflow_runner() + + helper = self.db_helpers[db_config.database_type] + num_records = 20 + common_module = db_config.common_module + + specs = ( + common_module.ColumnSpecsBuilder().with_id_spec().with_embedding_spec(). 
+ build()) + + writer_config = helper.create_writer_config(column_specs=specs) + + test_chunks = test_utils.ChunkTestUtils.get_expected_values(0, num_records) + + with self.write_test_pipeline as p: + _ = ( + p | beam.Create(test_chunks) + | VectorDatabaseWriteTransform(writer_config)) + + expected_chunks = test_utils.ChunkTestUtils.get_expected_values( + 0, num_records) + for chunk in expected_chunks: + chunk.content.text = None # Content column not included in schema + chunk.metadata = {} # Metadata column not included in schema + + jdbc_params = PipelineVerificationHelper.build_jdbc_params( + helper, helper.default_table_name) + if db_config.database_type == "postgresql": + jdbc_params['query'] = f""" + SELECT + CAST(id AS VARCHAR(255)), + CAST(embedding AS text) + FROM {helper.default_table_name} + ORDER BY id + """ + elif db_config.database_type == "mysql": + jdbc_params['query'] = f""" + SELECT + CAST(id AS CHAR(255)) as id, + vector_to_string(embedding) as embedding + FROM {helper.default_table_name} + """ + + with self.read_test_pipeline as p: + rows = (p | ReadFromJdbc(**jdbc_params)) + chunks = rows | "To Chunks" >> beam.Map(test_utils.row_to_chunk) + assert_that(chunks, equal_to(expected_chunks), label='chunks_check') + + @parameterized.expand([(POSTGRES_CONFIG), (MYSQL_CONFIG)]) + def test_metadata_field_extraction(self, db_config): + """Test extracting specific metadata fields into separate database columns. + + This test validates the ability to: + - Extract specific fields from the JSON metadata object + - Map them to dedicated database columns (e.g., metadata.source -> source + column) + - Apply database-specific SQL typecasts (PostgreSQL ::timestamp vs MySQL + default) + - Store and retrieve the extracted fields correctly + + This is different from default metadata handling which stores the entire + metadata object as JSON in a single column. 
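+
+    For example, a chunk with metadata {"source": "source_1",
+    "timestamp": "2024-02-02T01:00:00"} is written with "source_1" in the
+    dedicated source column and the timestamp in the timestamp column,
+    rather than as a single JSON metadata value.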
+ """ + self.skip_if_dataflow_runner() + + helper = self.db_helpers[db_config.database_type] + num_records = 20 + common_module = db_config.common_module + + if db_config.database_type == "postgresql": + specs = ( + common_module.ColumnSpecsBuilder().with_id_spec().with_embedding_spec( + ).with_content_spec().add_metadata_field( + field="source", + column_name="source", + python_type=str, + sql_typecast=None).add_metadata_field( + field="timestamp", + python_type=str, + sql_typecast="::timestamp").build()) + else: + specs = ( + common_module.ColumnSpecsBuilder().with_id_spec().with_embedding_spec( + ).with_content_spec().add_metadata_field( + field="source", column_name="source", + python_type=str).add_metadata_field( + field="timestamp", python_type=str).build()) + + writer_config = helper.create_writer_config( + helper.metadata_conflicts_table, specs, conflict_resolution=None) + + test_chunks = [ + Chunk( + id=str(i), + content=Content(text=f"content_{i}"), + embedding=Embedding(dense_embedding=[float(i), float(i + 1)]), + metadata={ + "source": f"source_{i % 3}", + "timestamp": f"2024-02-02T{i:02d}:00:00" + }) for i in range(num_records) + ] + + self.write_test_pipeline.not_use_test_runner_api = True + with self.write_test_pipeline as p: + _ = ( + p | beam.Create(test_chunks) + | VectorDatabaseWriteTransform(writer_config)) + + def metadata_row_to_chunk(row): + embedding_list = [float(x) for x in row.embedding.strip('[]').split(',')] + timestamp = row.timestamp.replace( + ' ', 'T') if ' ' in row.timestamp else row.timestamp + return Chunk( + id=row.id, + content=Content(text=row.content), + embedding=Embedding(dense_embedding=embedding_list), + metadata={ + "source": row.source, "timestamp": timestamp + }) + + jdbc_params = PipelineVerificationHelper.build_jdbc_params( + helper, helper.metadata_conflicts_table) + + self.read_test_pipeline.not_use_test_runner_api = True + with self.read_test_pipeline as p: + rows = (p | ReadFromJdbc(**jdbc_params)) + chunks = rows | "To Chunks" >> beam.Map(metadata_row_to_chunk) + assert_that(chunks, equal_to(test_chunks), label='chunks_check') + + @parameterized.expand([(POSTGRES_CONFIG), (MYSQL_CONFIG)]) + def test_composite_unique_constraint_conflicts(self, db_config): + """Test conflict resolution when unique constraints span multiple columns. + + This test validates conflict resolution when the unique constraint is NOT + on the primary key, but on a combination of other columns (source + + timestamp). + + Scenario: + 1. Insert records with unique (source, timestamp) combinations + 2. Attempt to insert records with same (source, timestamp) but different + IDs and content + 3. Verify that conflict resolution (UPDATE) works correctly based on + composite key + + This is different from test_conflict_resolution which tests conflicts on + the primary key field only. 
+ """ + self.skip_if_dataflow_runner() + + helper = self.db_helpers[db_config.database_type] + num_records = 5 + common_module = db_config.common_module + + if db_config.database_type == "postgresql": + specs = ( + common_module.ColumnSpecsBuilder().with_id_spec().with_embedding_spec( + ).with_content_spec().add_metadata_field( + field="source", + column_name="source", + python_type=str, + sql_typecast=None).add_metadata_field( + field="timestamp", + python_type=str, + sql_typecast="::timestamp").build()) + + conflict_resolution = common_module.ConflictResolution( + on_conflict_fields=["source", "timestamp"], + action="UPDATE", + update_fields=["embedding", "content"]) + elif db_config.database_type == "mysql": + specs = ( + common_module.ColumnSpecsBuilder().with_id_spec().with_embedding_spec( + ).with_content_spec().add_metadata_field( + field="source", column_name="source", + python_type=str).add_metadata_field( + field="timestamp", python_type=str).build()) + + # MySQL conflict resolution - detects unique constraint automatically + conflict_resolution = common_module.ConflictResolution( + action="UPDATE", update_fields=["embedding", "content"]) + + writer_config = helper.create_writer_config( + helper.metadata_conflicts_table, specs, conflict_resolution) + + initial_chunks = [ + Chunk( + id=str(i), + content=Content(text=f"content_{i}"), + embedding=Embedding(dense_embedding=[float(i), float(i + 1)]), + metadata={ + "source": "source_A", "timestamp": f"2024-02-02T{i:02d}:00:00" + }) for i in range(num_records) + ] + + with self.write_test_pipeline as p: + _ = ( + p | "Write Initial" >> beam.Create(initial_chunks) + | VectorDatabaseWriteTransform(writer_config)) + + conflicting_chunks = [ + Chunk( + id=f"new_{i}", + content=Content(text=f"updated_content_{i}"), + embedding=Embedding( + dense_embedding=[float(i) * 2, float(i + 1) * 2]), + metadata={ + "source": "source_A", "timestamp": f"2024-02-02T{i:02d}:00:00" + }) for i in range(num_records) + ] + + with self.write_test_pipeline2 as p: + _ = ( + p | "Write Conflicts" >> beam.Create(conflicting_chunks) + | VectorDatabaseWriteTransform(writer_config)) + + expected_chunks = [ + Chunk( + id=str(i), + content=Content(text=f"updated_content_{i}"), + embedding=Embedding( + dense_embedding=[float(i) * 2, float(i + 1) * 2]), + metadata={ + "source": "source_A", "timestamp": f"2024-02-02T{i:02d}:00:00" + }) for i in range(num_records) + ] + + def metadata_row_to_chunk(row): + embedding_list = [float(x) for x in row.embedding.strip('[]').split(',')] + timestamp = row.timestamp.replace( + ' ', 'T') if ' ' in row.timestamp else row.timestamp + return Chunk( + id=row.id, + content=Content(text=row.content), + embedding=Embedding(dense_embedding=embedding_list), + metadata={ + "source": row.source, "timestamp": timestamp + }) + + jdbc_params = PipelineVerificationHelper.build_jdbc_params( + helper, helper.metadata_conflicts_table) + + with self.read_test_pipeline as p: + rows = (p | ReadFromJdbc(**jdbc_params)) + + count_result = rows | "Count All" >> beam.combiners.Count.Globally() + assert_that(count_result, equal_to([num_records]), label='count_check') + + chunks = rows | "To Chunks" >> beam.Map(metadata_row_to_chunk) + assert_that(chunks, equal_to(expected_chunks), label='chunks_check') + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/ml/rag/ingestion/mysql.py b/sdks/python/apache_beam/ml/rag/ingestion/mysql.py new file mode 100644 index 000000000000..13ebeacbc866 --- 
/dev/null +++ b/sdks/python/apache_beam/ml/rag/ingestion/mysql.py @@ -0,0 +1,220 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +from typing import Callable +from typing import List +from typing import NamedTuple +from typing import Optional + +import apache_beam as beam +from apache_beam.coders import registry +from apache_beam.coders.row_coder import RowCoder +from apache_beam.io.jdbc import WriteToJdbc +from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig +from apache_beam.ml.rag.ingestion.jdbc_common import ConnectionConfig +from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig +from apache_beam.ml.rag.ingestion.mysql_common import ColumnSpec +from apache_beam.ml.rag.ingestion.mysql_common import ColumnSpecsBuilder +from apache_beam.ml.rag.ingestion.mysql_common import ConflictResolution +from apache_beam.ml.rag.types import Chunk + +_LOGGER = logging.getLogger(__name__) + + +class _MySQLQueryBuilder: + def __init__( + self, + table_name: str, + *, + column_specs: List[ColumnSpec], + conflict_resolution: Optional[ConflictResolution] = None): + """Builds SQL queries for writing Chunks with Embeddings to MySQL. 
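+
+    For example, with the default column specs and a hypothetical table
+    named 'embeddings', build_insert() produces roughly:
+
+        INSERT INTO embeddings (id, embedding, content, metadata)
+        VALUES (?, string_to_vector(?), ?, ?)
+
+    followed by the clause returned by the conflict resolution strategy
+    (ON DUPLICATE KEY UPDATE ...) when one is configured.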
+ """ + self.table_name = table_name + + self.column_specs = column_specs + self.conflict_resolution = conflict_resolution + + names = [col.column_name for col in self.column_specs] + duplicates = set(name for name in names if names.count(name) > 1) + if duplicates: + raise ValueError(f"Duplicate column names found: {duplicates}") + + fields = [(col.column_name, col.python_type) for col in self.column_specs] + type_name = f"VectorRecord_{table_name}" + self.record_type = NamedTuple(type_name, fields) # type: ignore + + registry.register_coder(self.record_type, RowCoder) + + # Set default update fields to all non-conflict fields if update fields are + # not specified + if self.conflict_resolution: + self.conflict_resolution.maybe_set_default_update_fields( + [col.column_name for col in self.column_specs if col.column_name]) + + def build_insert(self) -> str: + fields = [col.column_name for col in self.column_specs] + placeholders = [col.placeholder for col in self.column_specs] + + # Build base query + query = f""" + INSERT INTO {self.table_name} + ({', '.join(fields)}) + VALUES ({', '.join(placeholders)}) + """ + + if self.conflict_resolution: + query += f" {self.conflict_resolution.get_conflict_clause()}" + + _LOGGER.info("MySQL Query with placeholders %s", query) + return query + + def create_converter(self) -> Callable[[Chunk], NamedTuple]: + """Creates a function to convert Chunks to records.""" + def convert(chunk: Chunk) -> self.record_type: # type: ignore + return self.record_type( + **{col.column_name: col.value_fn(chunk) + for col in self.column_specs}) # type: ignore + + return convert + + +class MySQLVectorWriterConfig(VectorDatabaseWriteConfig): + def __init__( + self, + connection_config: ConnectionConfig, + table_name: str, + *, + # pylint: disable=dangerous-default-value + write_config: WriteConfig = WriteConfig(), + column_specs: List[ColumnSpec] = ColumnSpecsBuilder.with_defaults().build( + ), + conflict_resolution: Optional[ConflictResolution] = None): + """Configuration for writing vectors to MySQL using jdbc. + + Supports flexible schema configuration through column specifications and + conflict resolution strategies with MySQL-specific syntax. + + Args: + connection_config: + :class:`~apache_beam.ml.rag.ingestion.jdbc_common.ConnectionConfig`. + table_name: Target table name. + write_config: JdbcIO :class:`~.jdbc_common.WriteConfig` to control + batch sizes, authosharding, etc. + column_specs: + Use :class:`~.mysql_common.ColumnSpecsBuilder` to configure how + embeddings and metadata are written to the database + schema. If None, uses default Chunk schema with MySQL vector + functions. + conflict_resolution: Optional + :class:`~.mysql_common.ConflictResolution` + strategy for handling insert conflicts. ON DUPLICATE KEY UPDATE. + None by default, meaning errors are thrown when attempting to insert + duplicates. + + Examples: + Simple case with default schema: + + >>> config = MySQLVectorWriterConfig( + ... connection_config=ConnectionConfig(...), + ... table_name='embeddings' + ... ) + + Custom schema with metadata fields and MySQL functions: + + >>> specs = (ColumnSpecsBuilder() + ... .with_id_spec(column_name="my_id_column") + ... .with_embedding_spec( + ... column_name="embedding_vec", + ... placeholder="string_to_vector(?)" + ... ) + ... .add_metadata_field(field="source", column_name="src") + ... .add_metadata_field( + ... "timestamp", + ... column_name="created_at", + ... placeholder="STR_TO_DATE(?, '%Y-%m-%d %H:%i:%s')" + ... ) + ... 
.build()) + + Minimal schema (only ID + embedding written): + + >>> column_specs = (ColumnSpecsBuilder() + ... .with_id_spec() + ... .with_embedding_spec() + ... .build()) + + >>> config = MySQLVectorWriterConfig( + ... connection_config=ConnectionConfig(...), + ... table_name='embeddings', + ... column_specs=specs, + ... conflict_resolution=ConflictResolution( + ... on_conflict_fields=["id"], + ... action="UPDATE", + ... update_fields=["embedding", "content"] + ... ) + ... ) + + Using MySQL JSON functions: + + >>> specs = (ColumnSpecsBuilder() + ... .with_id_spec() + ... .with_embedding_spec() + ... .with_metadata_spec( + ... column_name="metadata_json", + ... placeholder="CAST(? AS JSON)" + ... ) + ... .build()) + """ + self.connection_config = connection_config + self.write_config = write_config + # NamedTuple is created and registered here during pipeline construction + self.query_builder = _MySQLQueryBuilder( + table_name, + column_specs=column_specs, + conflict_resolution=conflict_resolution) + + def create_write_transform(self) -> beam.PTransform: + return _WriteToMySQLVectorDatabase(self) + + +class _WriteToMySQLVectorDatabase(beam.PTransform): + """Implementation of MySQL vector database write.""" + def __init__(self, config: MySQLVectorWriterConfig): + self.config = config + self.query_builder = config.query_builder + self.connection_config = config.connection_config + self.write_config = config.write_config + + def expand(self, pcoll: beam.PCollection[Chunk]): + return ( + pcoll + | + "Convert to Records" >> beam.Map(self.query_builder.create_converter()) + | "Write to MySQL" >> WriteToJdbc( + table_name=self.query_builder.table_name, + driver_class_name="com.mysql.cj.jdbc.Driver", + jdbc_url=self.connection_config.jdbc_url, + username=self.connection_config.username, + password=self.connection_config.password, + statement=self.query_builder.build_insert(), + connection_properties=self.connection_config.connection_properties, + connection_init_sqls=self.connection_config.connection_init_sqls, + autosharding=self.write_config.autosharding, + max_connections=self.write_config.max_connections, + write_batch_size=self.write_config.write_batch_size, + **self.connection_config.additional_jdbc_args)) diff --git a/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py b/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py new file mode 100644 index 000000000000..983f3f59fa87 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py @@ -0,0 +1,454 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +from dataclasses import dataclass +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Literal +from typing import Optional +from typing import Type + +from apache_beam.ml.rag.types import Chunk + + +def chunk_embedding_fn(chunk: Chunk) -> str: + """Convert embedding to MySQL vector string format. + + Formats dense embedding as a MySQL-compatible vector string. + Example: [1.0, 2.0] -> '[1.0,2.0]' + + Args: + chunk: Input Chunk object. + + Returns: + str: MySQL vector string representation of the embedding. + + Raises: + ValueError: If chunk has no dense embedding. + """ + if chunk.embedding is None or chunk.embedding.dense_embedding is None: + raise ValueError(f'Expected chunk to contain embedding. {chunk}') + return '[' + ','.join(str(x) for x in chunk.embedding.dense_embedding) + ']' + + +@dataclass +class ColumnSpec: + """Specification for mapping Chunk fields to MySQL columns for insertion. + + Defines how to extract and format values from Chunks into MySQL database + columns, handling the full pipeline from Python value to SQL insertion. + + The insertion process works as follows: + - value_fn extracts a value from the Chunk and formats it as needed + - The value is stored in a NamedTuple field with the specified python_type + - During SQL insertion, the value is bound to a ? placeholder + + Attributes: + column_name: The column name in the database table. + python_type: Python type for the NamedTuple field that will hold the + value. Must be compatible with + :class:`~apache_beam.coders.row_coder.RowCoder`. + value_fn: Function to extract and format the value from a Chunk. + Takes a Chunk and returns a value of python_type. + placeholder: Optional placeholder to apply typecasts or functions to + value ? placeholder e.g. "string_to_vector(?)" for vector columns. + + Examples: + + Basic text column (uses standard JDBC type mapping): + + >>> ColumnSpec.text( + ... column_name="content", + ... value_fn=lambda chunk: chunk.content.text + ... ) + ... # Results in: INSERT INTO table (content) VALUES (?) + + Timestamp from metadata: + + >>> ColumnSpec( + ... column_name="created_at", + ... python_type=str, + ... value_fn=lambda chunk: chunk.metadata.get("timestamp") + ... ) + ... # Results in: INSERT INTO table (created_at) VALUES (?) + + + Factory Methods: + text: Creates a text column specification. + integer: Creates an integer column specification. + float: Creates a float column specification. + vector: Creates a vector column specification with string_to_vector(). + json: Creates a JSON column specification. + """ + column_name: str + python_type: Type + value_fn: Callable[[Chunk], Any] + placeholder: str = '?' 
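+
+  # Note: placeholder wraps the single JDBC "?" parameter for this column;
+  # e.g. the vector() factory below uses "string_to_vector(?)" so the string
+  # produced by chunk_embedding_fn is parsed into a MySQL VECTOR value.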
+ + @classmethod + def text( + cls, column_name: str, value_fn: Callable[[Chunk], Any]) -> 'ColumnSpec': + """Create a text column specification.""" + return cls(column_name, str, value_fn) + + @classmethod + def integer( + cls, column_name: str, value_fn: Callable[[Chunk], Any]) -> 'ColumnSpec': + """Create an integer column specification.""" + return cls(column_name, int, value_fn) + + @classmethod + def float( + cls, column_name: str, value_fn: Callable[[Chunk], Any]) -> 'ColumnSpec': + """Create a float column specification.""" + return cls(column_name, float, value_fn) + + @classmethod + def vector( + cls, + column_name: str, + value_fn: Callable[[Chunk], Any] = chunk_embedding_fn) -> 'ColumnSpec': + """Create a vector column specification with string_to_vector() function.""" + return cls(column_name, str, value_fn, "string_to_vector(?)") + + @classmethod + def json( + cls, column_name: str, value_fn: Callable[[Chunk], Any]) -> 'ColumnSpec': + """Create a JSON column specification.""" + return cls(column_name, str, value_fn) + + +class ColumnSpecsBuilder: + """Builder for :class:`.ColumnSpec`'s with chainable methods.""" + def __init__(self): + self._specs: List[ColumnSpec] = [] + + @staticmethod + def with_defaults() -> 'ColumnSpecsBuilder': + """Add all default column specifications.""" + return ( + ColumnSpecsBuilder().with_id_spec().with_embedding_spec(). + with_content_spec().with_metadata_spec()) + + def with_id_spec( + self, + column_name: str = "id", + python_type: Type = str, + convert_fn: Optional[Callable[[str], + Any]] = None) -> 'ColumnSpecsBuilder': + """Add ID :class:`.ColumnSpec` with optional type and conversion. + + Args: + column_name: Name for the ID column (defaults to "id") + python_type: Python type for the column (defaults to str) + convert_fn: Optional function to convert the chunk ID + If None, uses ID as-is + + Returns: + Self for method chaining + + Example: + >>> builder.with_id_spec( + ... column_name="doc_id", + ... python_type=int, + ... convert_fn=lambda id: int(id.split('_')[1]) + ... ) + """ + def value_fn(chunk: Chunk) -> Any: + value = chunk.id + return convert_fn(value) if convert_fn else value + + self._specs.append( + ColumnSpec( + column_name=column_name, python_type=python_type, + value_fn=value_fn)) + return self + + def with_content_spec( + self, + column_name: str = "content", + python_type: Type = str, + convert_fn: Optional[Callable[[str], + Any]] = None) -> 'ColumnSpecsBuilder': + """Add content :class:`.ColumnSpec` with optional type and conversion. + + Args: + column_name: Name for the content column (defaults to "content") + python_type: Python type for the column (defaults to str) + convert_fn: Optional function to convert the content text + If None, uses content text as-is + + Returns: + Self for method chaining + + Example: + >>> builder.with_content_spec( + ... column_name="content_length", + ... python_type=int, + ... convert_fn=len # Store content length instead of content + ... ) + """ + def value_fn(chunk: Chunk) -> Any: + if chunk.content.text is None: + raise ValueError(f'Expected chunk to contain content. 
{chunk}') + value = chunk.content.text + return convert_fn(value) if convert_fn else value + + self._specs.append( + ColumnSpec( + column_name=column_name, python_type=python_type, + value_fn=value_fn)) + return self + + def with_metadata_spec( + self, + column_name: str = "metadata", + python_type: Type = str, + convert_fn: Optional[Callable[[Dict[str, Any]], Any]] = None + ) -> 'ColumnSpecsBuilder': + """Add metadata :class:`.ColumnSpec` with optional type and conversion. + + Args: + column_name: Name for the metadata column (defaults to "metadata") + python_type: Python type for the column (defaults to str) + convert_fn: Optional function to convert the metadata dictionary + If None and python_type is str, converts to JSON string + + Returns: + Self for method chaining + + Example: + >>> builder.with_metadata_spec( + ... column_name="meta_tags", + ... python_type=str, + ... convert_fn=lambda meta: ','.join(meta.keys()) + ... ) + """ + def value_fn(chunk: Chunk) -> Any: + if convert_fn: + return convert_fn(chunk.metadata) + return json.dumps( + chunk.metadata) if python_type == str else chunk.metadata + + self._specs.append( + ColumnSpec( + column_name=column_name, python_type=python_type, + value_fn=value_fn)) + return self + + def with_embedding_spec( + self, + column_name: str = "embedding", + convert_fn: Optional[Callable[[List[float]], Any]] = None + ) -> 'ColumnSpecsBuilder': + """Add embedding :class:`.ColumnSpec` with optional conversion. + + Args: + column_name: Name for the embedding column (defaults to "embedding") + convert_fn: Optional function to convert the dense embedding values + If None, uses default MySQL vector format + + Returns: + Self for method chaining + + Example: + >>> builder.with_embedding_spec( + ... column_name="embedding_vector", + ... convert_fn=lambda values: '[' + ','.join(f"{x:.4f}" + ... for x in values) + ']' + ... ) + """ + def value_fn(chunk: Chunk) -> Any: + if chunk.embedding is None or chunk.embedding.dense_embedding is None: + raise ValueError(f'Expected chunk to contain embedding. {chunk}') + values = chunk.embedding.dense_embedding + if convert_fn: + return convert_fn(values) + return '[' + ','.join(str(x) for x in values) + ']' + + self._specs.append( + ColumnSpec.vector(column_name=column_name, value_fn=value_fn)) + return self + + def add_metadata_field( + self, + field: str, + python_type: Type, + column_name: Optional[str] = None, + convert_fn: Optional[Callable[[Any], Any]] = None, + default: Any = None) -> 'ColumnSpecsBuilder': + """Add a :class:`.ColumnSpec` that extracts and converts a field from + chunk metadata. + + Args: + field: Key to extract from chunk metadata + python_type: Python type for the column (e.g. str, int, float) + column_name: Name for the column (defaults to metadata field name) + convert_fn: Optional function to convert the extracted value to + desired type. If None, value is used as-is + default: Default value if field is missing from metadata + + Returns: + Self for chaining + + Examples: + Simple string field: + >>> builder.add_metadata_field("source", str) + + Integer with default: + >>> builder.add_metadata_field( + ... field="count", + ... python_type=int, + ... column_name="item_count", + ... default=0 + ... ) + + Float with conversion and default: + >>> builder.add_metadata_field( + ... field="confidence", + ... python_type=float, + ... convert_fn=lambda x: round(float(x), 2), + ... default=0.0 + ... ) + + Timestamp with conversion: + >>> builder.add_metadata_field( + ... field="created_at", + ... 
python_type=str, + ... convert_fn=lambda ts: ts.replace('T', ' ') + ... ) + """ + name = column_name or field + + def value_fn(chunk: Chunk) -> Any: + value = chunk.metadata.get(field, default) + if value is not None and convert_fn is not None: + value = convert_fn(value) + return value + + spec = ColumnSpec( + column_name=name, python_type=python_type, value_fn=value_fn) + + self._specs.append(spec) + return self + + def add_custom_column_spec(self, spec: ColumnSpec) -> 'ColumnSpecsBuilder': + """Add a custom :class:`.ColumnSpec` to the builder. + + Use this method when you need complete control over the + :class:`.ColumnSpec`, including custom value extraction and type handling. + + Args: + spec: A :class:`.ColumnSpec` instance defining the column name, type, + value extraction, and optional MySQL function. + + Returns: + Self for method chaining + + Examples: + Custom text column from chunk metadata: + >>> builder.add_custom_column_spec( + ... ColumnSpec.text( + ... column_name="source_and_id", + ... value_fn=lambda chunk: + ... f"{chunk.metadata.get('source')}_{chunk.id}" + ... ) + ... ) + """ + self._specs.append(spec) + return self + + def build(self) -> List[ColumnSpec]: + """Build the final list of column specifications.""" + return self._specs.copy() + + +@dataclass +class ConflictResolution: + """Specification for how to handle conflicts during insert. + + Configures conflict handling behavior when inserting records that may + violate unique constraints using MySQL's ON DUPLICATE KEY UPDATE syntax. + + MySQL automatically detects conflicts based on PRIMARY KEY or UNIQUE + constraints defined on the table. + + Attributes: + action: How to handle conflicts - either "UPDATE" or "IGNORE". + UPDATE: Updates existing record with new values. + IGNORE: Skips conflicting records (uses no-op update). + update_fields: Optional list of fields to update on conflict. If None, + all fields are updated (for UPDATE action only). + primary_key_field: Required for IGNORE action. The primary key field + name to use for the no-op update. + + Examples: + Update all fields on conflict: + >>> ConflictResolution(action="UPDATE") + + Update specific fields on conflict: + >>> ConflictResolution( + ... action="UPDATE", + ... update_fields=["embedding", "content"] + ... ) + + Ignore conflicts with explicit primary key: + >>> ConflictResolution( + ... action="IGNORE", + ... primary_key_field="id" + ... ) + + Ignore conflicts with custom primary key: + >>> ConflictResolution( + ... action="IGNORE", + ... primary_key_field="custom_id" + ... 
) + """ + action: Literal["UPDATE", "IGNORE"] = "UPDATE" + update_fields: Optional[List[str]] = None + primary_key_field: Optional[str] = None + + def __post_init__(self): + """Validate configuration after initialization.""" + if self.action == "IGNORE" and self.primary_key_field is None: + raise ValueError("primary_key_field is required when action='IGNORE'") + + def maybe_set_default_update_fields(self, columns: List[str]): + """Set default update fields to all columns if not specified.""" + if self.action != "UPDATE": + return + if self.update_fields is not None: + return + # Default to updating all fields + self.update_fields = columns + + def get_conflict_clause(self) -> str: + """Get MySQL conflict clause using ON DUPLICATE KEY UPDATE syntax.""" + if self.action == "IGNORE": + # Use no-op update with user-specified primary key field + assert self.primary_key_field is not None, \ + "primary_key_field must be set for IGNORE action" + return f"ON DUPLICATE KEY UPDATE {self.primary_key_field} = "\ + f"{self.primary_key_field}" + + # update_fields should be set by query builder before this is called + assert self.update_fields is not None, \ + "update_fields must be set before generating conflict clause" + updates = [f"{field} = VALUES({field})" for field in self.update_fields] + return f"ON DUPLICATE KEY UPDATE {', '.join(updates)}" diff --git a/sdks/python/apache_beam/ml/rag/ingestion/test_utils.py b/sdks/python/apache_beam/ml/rag/ingestion/test_utils.py index cd30766a2886..366d27164d04 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/test_utils.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/test_utils.py @@ -33,6 +33,28 @@ ('metadata', str)]) registry.register_coder(TestRow, RowCoder) +CustomSpecsRow = NamedTuple( + 'CustomSpecsRow', + [ + ('custom_id', str), # For id_spec test + ('embedding_vec', List[float]), # For embedding_spec test + ('content_col', str), # For content_spec test + ('metadata', str) + ]) +registry.register_coder(CustomSpecsRow, RowCoder) + +MetadataConflictRow = NamedTuple( + 'MetadataConflictRow', + [ + ('id', str), + ('source', str), # For metadata_spec and composite key + ('timestamp', str), # For metadata_spec and composite key + ('content', str), + ('embedding', List[float]), + ('metadata', str) + ]) +registry.register_coder(MetadataConflictRow, RowCoder) + VECTOR_SIZE = 768 From a829328c3439892da320b0bd60174319de684c95 Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 23 Jun 2025 09:20:31 -0400 Subject: [PATCH 2/8] Trigger tests again. --- .../trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json index d6a91b7e2e86..38ae1cf68222 100644 --- a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json +++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 7 + "modification": 8 } From 5f602dbc4b4b67208b40dd46937dff6f44314805 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 1 Jul 2025 12:09:41 -0400 Subject: [PATCH 3/8] Comments. 
--- .../apache_beam/ml/rag/ingestion/mysql.py | 64 ++++++++++++++++--- .../ml/rag/ingestion/mysql_common.py | 24 ------- .../ml/rag/ingestion/test_utils.py | 31 --------- 3 files changed, 55 insertions(+), 64 deletions(-) diff --git a/sdks/python/apache_beam/ml/rag/ingestion/mysql.py b/sdks/python/apache_beam/ml/rag/ingestion/mysql.py index 13ebeacbc866..f485417bc9ac 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/mysql.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/mysql.py @@ -16,6 +16,8 @@ # import logging +from abc import ABC +from abc import abstractmethod from typing import Callable from typing import List from typing import NamedTuple @@ -36,6 +38,56 @@ _LOGGER = logging.getLogger(__name__) +class _ConflictResolutionStrategy(ABC): + """Abstract base class for conflict resolution strategies.""" + @abstractmethod + def get_conflict_clause(self, all_columns: List[str]) -> str: + """Generate the MySQL conflict clause.""" + pass + + +class _NoConflictStrategy(_ConflictResolutionStrategy): + """Strategy for when no conflict resolution is needed.""" + def get_conflict_clause(self, all_columns: List[str]) -> str: + return "" + + +class _UpdateStrategy(_ConflictResolutionStrategy): + """Strategy for UPDATE action on conflict.""" + def __init__(self, update_fields: Optional[List[str]] = None): + self.update_fields = update_fields + + def get_conflict_clause(self, all_columns: List[str]) -> str: + # Use provided fields or default to all columns + fields_to_update = self.update_fields or all_columns + assert len(fields_to_update) > 0 + + updates = [f"{field} = VALUES({field})" for field in fields_to_update] + return f"ON DUPLICATE KEY UPDATE {', '.join(updates)}" + + +class _IgnoreStrategy(_ConflictResolutionStrategy): + """Strategy for IGNORE action on conflict.""" + def __init__(self, primary_key_field: str): + self.primary_key_field = primary_key_field + + def get_conflict_clause(self, all_columns: List[str]) -> str: + return f"ON DUPLICATE KEY UPDATE {self.primary_key_field}"\ + f" = {self.primary_key_field}" + + +def _create_conflict_strategy( + conflict_resolution: Optional[ConflictResolution] +) -> _ConflictResolutionStrategy: + if conflict_resolution is None: + return _NoConflictStrategy + if conflict_resolution.action == "UPDATE": + return _UpdateStrategy(conflict_resolution.update_fields) + if conflict_resolution.action == "IGNORE": + return _IgnoreStrategy(conflict_resolution.primary_key_field) + raise ValueError(f"Unknown conflict resolution {conflict_resolution.action}") + + class _MySQLQueryBuilder: def __init__( self, @@ -48,7 +100,8 @@ def __init__( self.table_name = table_name self.column_specs = column_specs - self.conflict_resolution = conflict_resolution + self.conflict_resolution_strategy = _create_conflict_strategy( + conflict_resolution) names = [col.column_name for col in self.column_specs] duplicates = set(name for name in names if names.count(name) > 1) @@ -61,12 +114,6 @@ def __init__( registry.register_coder(self.record_type, RowCoder) - # Set default update fields to all non-conflict fields if update fields are - # not specified - if self.conflict_resolution: - self.conflict_resolution.maybe_set_default_update_fields( - [col.column_name for col in self.column_specs if col.column_name]) - def build_insert(self) -> str: fields = [col.column_name for col in self.column_specs] placeholders = [col.placeholder for col in self.column_specs] @@ -78,8 +125,7 @@ def build_insert(self) -> str: VALUES ({', '.join(placeholders)}) """ - if self.conflict_resolution: - 
query += f" {self.conflict_resolution.get_conflict_clause()}" + query += f" {self.conflict_resolution.get_conflict_clause(fields)}" _LOGGER.info("MySQL Query with placeholders %s", query) return query diff --git a/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py b/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py index 983f3f59fa87..930b97946848 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py @@ -428,27 +428,3 @@ def __post_init__(self): """Validate configuration after initialization.""" if self.action == "IGNORE" and self.primary_key_field is None: raise ValueError("primary_key_field is required when action='IGNORE'") - - def maybe_set_default_update_fields(self, columns: List[str]): - """Set default update fields to all columns if not specified.""" - if self.action != "UPDATE": - return - if self.update_fields is not None: - return - # Default to updating all fields - self.update_fields = columns - - def get_conflict_clause(self) -> str: - """Get MySQL conflict clause using ON DUPLICATE KEY UPDATE syntax.""" - if self.action == "IGNORE": - # Use no-op update with user-specified primary key field - assert self.primary_key_field is not None, \ - "primary_key_field must be set for IGNORE action" - return f"ON DUPLICATE KEY UPDATE {self.primary_key_field} = "\ - f"{self.primary_key_field}" - - # update_fields should be set by query builder before this is called - assert self.update_fields is not None, \ - "update_fields must be set before generating conflict clause" - updates = [f"{field} = VALUES({field})" for field in self.update_fields] - return f"ON DUPLICATE KEY UPDATE {', '.join(updates)}" diff --git a/sdks/python/apache_beam/ml/rag/ingestion/test_utils.py b/sdks/python/apache_beam/ml/rag/ingestion/test_utils.py index 366d27164d04..0373874c09d2 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/test_utils.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/test_utils.py @@ -18,43 +18,12 @@ import hashlib import json from typing import List -from typing import NamedTuple import apache_beam as beam -from apache_beam.coders import registry -from apache_beam.coders.row_coder import RowCoder from apache_beam.ml.rag.types import Chunk from apache_beam.ml.rag.types import Content from apache_beam.ml.rag.types import Embedding -TestRow = NamedTuple( - 'TestRow', - [('id', str), ('embedding', List[float]), ('content', str), - ('metadata', str)]) -registry.register_coder(TestRow, RowCoder) - -CustomSpecsRow = NamedTuple( - 'CustomSpecsRow', - [ - ('custom_id', str), # For id_spec test - ('embedding_vec', List[float]), # For embedding_spec test - ('content_col', str), # For content_spec test - ('metadata', str) - ]) -registry.register_coder(CustomSpecsRow, RowCoder) - -MetadataConflictRow = NamedTuple( - 'MetadataConflictRow', - [ - ('id', str), - ('source', str), # For metadata_spec and composite key - ('timestamp', str), # For metadata_spec and composite key - ('content', str), - ('embedding', List[float]), - ('metadata', str) - ]) -registry.register_coder(MetadataConflictRow, RowCoder) - VECTOR_SIZE = 768 From 566bba9f0fe35dbdb0733881dafc3103f3e6c5a9 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 1 Jul 2025 18:34:37 +0000 Subject: [PATCH 4/8] Fix lints etc. 
--- sdks/python/apache_beam/ml/rag/ingestion/mysql.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sdks/python/apache_beam/ml/rag/ingestion/mysql.py b/sdks/python/apache_beam/ml/rag/ingestion/mysql.py index f485417bc9ac..c64c083b6c9c 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/mysql.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/mysql.py @@ -80,10 +80,11 @@ def _create_conflict_strategy( conflict_resolution: Optional[ConflictResolution] ) -> _ConflictResolutionStrategy: if conflict_resolution is None: - return _NoConflictStrategy + return _NoConflictStrategy() if conflict_resolution.action == "UPDATE": return _UpdateStrategy(conflict_resolution.update_fields) if conflict_resolution.action == "IGNORE": + assert conflict_resolution.primary_key_field is not None return _IgnoreStrategy(conflict_resolution.primary_key_field) raise ValueError(f"Unknown conflict resolution {conflict_resolution.action}") @@ -124,8 +125,9 @@ def build_insert(self) -> str: ({', '.join(fields)}) VALUES ({', '.join(placeholders)}) """ - - query += f" {self.conflict_resolution.get_conflict_clause(fields)}" + conflict_clause = self.conflict_resolution_strategy.get_conflict_clause( + all_columns=fields) + query += f" {conflict_clause}" _LOGGER.info("MySQL Query with placeholders %s", query) return query From 15a4d3dc8aea1b1fadcd8b35413d4a1d8e84b46b Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 1 Jul 2025 19:59:28 +0000 Subject: [PATCH 5/8] Comment. --- sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py b/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py index 930b97946848..6eab4fdc3a9f 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py @@ -255,7 +255,10 @@ def value_fn(chunk: Chunk) -> Any: def with_embedding_spec( self, column_name: str = "embedding", - convert_fn: Optional[Callable[[List[float]], Any]] = None + convert_fn: Callable[ + [List[float]], + Any] = lambda embeddig: '[' + ','.join(str(x) + for x in embedding) + ']' ) -> 'ColumnSpecsBuilder': """Add embedding :class:`.ColumnSpec` with optional conversion. @@ -278,9 +281,7 @@ def value_fn(chunk: Chunk) -> Any: if chunk.embedding is None or chunk.embedding.dense_embedding is None: raise ValueError(f'Expected chunk to contain embedding. 
{chunk}') values = chunk.embedding.dense_embedding - if convert_fn: - return convert_fn(values) - return '[' + ','.join(str(x) for x in values) + ']' + return convert_fn(values) self._specs.append( ColumnSpec.vector(column_name=column_name, value_fn=value_fn)) From ed23b532f4ca6af8af5db00b358552cffe20ac86 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 2 Jul 2025 09:52:56 -0400 Subject: [PATCH 6/8] Fix typo --- sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py b/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py index 6eab4fdc3a9f..6bb2ef8dc122 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py @@ -257,8 +257,8 @@ def with_embedding_spec( column_name: str = "embedding", convert_fn: Callable[ [List[float]], - Any] = lambda embeddig: '[' + ','.join(str(x) - for x in embedding) + ']' + Any] = lambda embedding: '[' + ','.join(str(x) + for x in embedding) + ']' ) -> 'ColumnSpecsBuilder': """Add embedding :class:`.ColumnSpec` with optional conversion. From ab82c6042306c729b6f4b1e9fb3e9d74d94f4f69 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 2 Jul 2025 16:54:06 +0000 Subject: [PATCH 7/8] Lint fix. --- .../apache_beam/ml/rag/ingestion/mysql_common.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py b/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py index 6bb2ef8dc122..985fdb57117d 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py @@ -134,6 +134,11 @@ def json( return cls(column_name, str, value_fn) +def embedding_to_string(embedding: List[Float]) -> str: + """Convert embedding to MySQL vector string format.""" + return '[' + ','.join(str(x) for x in embedding) + ']' + + class ColumnSpecsBuilder: """Builder for :class:`.ColumnSpec`'s with chainable methods.""" def __init__(self): @@ -255,10 +260,7 @@ def value_fn(chunk: Chunk) -> Any: def with_embedding_spec( self, column_name: str = "embedding", - convert_fn: Callable[ - [List[float]], - Any] = lambda embedding: '[' + ','.join(str(x) - for x in embedding) + ']' + convert_fn: Callable[[List[float]], Any] = embedding_to_string ) -> 'ColumnSpecsBuilder': """Add embedding :class:`.ColumnSpec` with optional conversion. From 97b7381000e44492b12d6ae217c6ed03f0b68c32 Mon Sep 17 00:00:00 2001 From: claudevdm <33973061+claudevdm@users.noreply.github.com> Date: Mon, 7 Jul 2025 12:14:43 -0400 Subject: [PATCH 8/8] Update sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py Co-authored-by: Danny McCormick --- sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py b/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py index 985fdb57117d..c1ee703a5f2e 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py @@ -134,7 +134,7 @@ def json( return cls(column_name, str, value_fn) -def embedding_to_string(embedding: List[Float]) -> str: +def embedding_to_string(embedding: List[float]) -> str: """Convert embedding to MySQL vector string format.""" return '[' + ','.join(str(x) for x in embedding) + ']'
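
Reviewer note (illustrative, not part of the patch series): a minimal sketch of how the ColumnSpecsBuilder and ConflictResolution classes added in mysql_common.py might be combined by a caller. The chosen metadata field and the table name below are assumptions made only for this example.

    from apache_beam.ml.rag.ingestion.mysql_common import ColumnSpecsBuilder
    from apache_beam.ml.rag.ingestion.mysql_common import ConflictResolution

    # Default columns (id, embedding, content, metadata) plus one extra
    # column populated from chunk.metadata["source"] (assumed field name).
    specs = (
        ColumnSpecsBuilder.with_defaults().add_metadata_field(
            field="source", python_type=str).build())

    # On a duplicate key, refresh only the embedding and content columns.
    conflict = ConflictResolution(
        action="UPDATE", update_fields=["embedding", "content"])

With these settings, the build_insert() path of _MySQLQueryBuilder as refactored in patches 3 and 4 would emit a statement along the lines of (table name "my_vector_table" is hypothetical):

    INSERT INTO my_vector_table
    (id, embedding, content, metadata, source)
    VALUES (?, string_to_vector(?), ?, ?, ?)
    ON DUPLICATE KEY UPDATE embedding = VALUES(embedding),
        content = VALUES(content)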