42 changes: 17 additions & 25 deletions sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py
@@ -24,54 +24,42 @@

import apache_beam as beam
from apache_beam.ml.rag.embeddings.huggingface import HuggingfaceTextEmbeddings
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.rag.types import Content
from apache_beam.ml.rag.types import Embedding
from apache_beam.ml.rag.test_utils import TestHelpers
from apache_beam.ml.rag.types import Chunk, Content, Embedding
from apache_beam.ml.transforms.base import MLTransform
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.testing.util import assert_that, equal_to

# pylint: disable=unused-import
try:
from sentence_transformers import SentenceTransformer

SENTENCE_TRANSFORMERS_AVAILABLE = True
except ImportError:
SENTENCE_TRANSFORMERS_AVAILABLE = False


def chunk_approximately_equals(expected, actual):
"""Compare embeddings allowing for numerical differences."""
if not isinstance(expected, Chunk) or not isinstance(actual, Chunk):
return False

return (
expected.id == actual.id and expected.metadata == actual.metadata and
expected.content == actual.content and
len(expected.embedding.dense_embedding) == len(
actual.embedding.dense_embedding) and
all(isinstance(x, float) for x in actual.embedding.dense_embedding))


@pytest.mark.uses_transformers
@unittest.skipIf(
not SENTENCE_TRANSFORMERS_AVAILABLE, "sentence-transformers not available")
class HuggingfaceTextEmbeddingsTest(unittest.TestCase):
def setUp(self):
self.artifact_location = tempfile.mkdtemp(prefix='sentence_transformers_')
self.artifact_location = tempfile.mkdtemp(prefix="sentence_transformers_")
self.test_chunks = [
Chunk(
content=Content(text="This is a test sentence."),
id="1",
metadata={
"source": "test.txt", "language": "en"
}),
},
),
Chunk(
content=Content(text="Another example."),
id="2",
metadata={
"source": "test.txt", "language": "en"
})
},
),
]

def tearDown(self) -> None:
@@ -85,14 +73,16 @@ def test_embedding_pipeline(self):
metadata={
"source": "test.txt", "language": "en"
},
content=Content(text="This is a test sentence.")),
content=Content(text="This is a test sentence."),
),
Chunk(
id="2",
embedding=Embedding(dense_embedding=[0.0] * 384),
metadata={
"source": "test.txt", "language": "en"
},
content=Content(text="Another example."))
content=Content(text="Another example."),
),
]
embedder = HuggingfaceTextEmbeddings(
model_name="sentence-transformers/all-MiniLM-L6-v2")
@@ -105,8 +95,10 @@ def test_embedding_pipeline(self):
with_transform(embedder))

assert_that(
embeddings, equal_to(expected, equals_fn=chunk_approximately_equals))
embeddings,
equal_to(expected, equals_fn=TestHelpers.chunk_approximately_equals),
)


if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
80 changes: 80 additions & 0 deletions sdks/python/apache_beam/ml/rag/embeddings/open_ai.py
@@ -0,0 +1,80 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""RAG-specific embedding implementations using OpenAI models."""

from typing import Optional

import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.rag.embeddings.base import create_rag_adapter
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.transforms.base import EmbeddingsManager
from apache_beam.ml.transforms.base import _TextEmbeddingHandler
from apache_beam.ml.transforms.embeddings.open_ai import _OpenAITextEmbeddingHandler

__all__ = ['OpenAITextEmbeddings']


class OpenAITextEmbeddings(EmbeddingsManager):
def __init__(
self,
model_name: str,
*,
api_key: Optional[str] = None,
organization: Optional[str] = None,
dimensions: Optional[int] = None,
user: Optional[str] = None,
max_batch_size: Optional[int] = None,
**kwargs):
"""Utilizes OpenAI text embeddings for semantic search and RAG pipelines.

Args:
model_name: Name of the OpenAI embedding model
api_key: OpenAI API key
organization: OpenAI organization ID
dimensions: Specific embedding dimensions to use (if supported)
user: End-user identifier for tracking and rate limit calculations
max_batch_size: Maximum batch size for requests to OpenAI API
**kwargs: Additional arguments passed to EmbeddingsManager including
ModelHandler inference_args.
"""
super().__init__(type_adapter=create_rag_adapter(), **kwargs)
self.model_name = model_name
self.api_key = api_key
self.organization = organization
self.dimensions = dimensions
self.user = user
self.max_batch_size = max_batch_size

def get_model_handler(self):
"""Returns model handler configured with RAG adapter."""
return _OpenAITextEmbeddingHandler(
model_name=self.model_name,
api_key=self.api_key,
organization=self.organization,
dimensions=self.dimensions,
user=self.user,
max_batch_size=self.max_batch_size,
)

def get_ptransform_for_processing(
self, **kwargs
) -> beam.PTransform[beam.PCollection[Chunk], beam.PCollection[Chunk]]:
"""Returns PTransform that uses the RAG adapter."""
return RunInference(
model_handler=_TextEmbeddingHandler(self),
inference_args=self.inference_args).with_output_types(Chunk)
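
For context, a minimal usage sketch of the new transform (not part of the diff): it wires OpenAITextEmbeddings into an MLTransform pipeline the same way the test below does. It assumes the openai client package is installed, that OPENAI_API_KEY is set, and uses a temporary artifact directory chosen for illustration.

import os
import tempfile

import apache_beam as beam
from apache_beam.ml.rag.embeddings.open_ai import OpenAITextEmbeddings
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.rag.types import Content
from apache_beam.ml.transforms.base import MLTransform

# Chunks to embed; ids and metadata pass through alongside the embeddings.
chunks = [
    Chunk(id="1", content=Content(text="This is a test sentence.")),
    Chunk(id="2", content=Content(text="Another example.")),
]

embedder = OpenAITextEmbeddings(
    model_name="text-embedding-3-small",
    dimensions=512,  # optional; only honored by models that support it
    api_key=os.environ.get("OPENAI_API_KEY"))

# MLTransform needs a writable artifact location; a temp dir is used here.
artifact_location = tempfile.mkdtemp(prefix="openai_")

with beam.Pipeline() as p:
    _ = (
        p
        | beam.Create(chunks)
        | MLTransform(
            write_artifact_location=artifact_location).with_transform(embedder)
        | beam.Map(
            lambda chunk: print(chunk.id, len(chunk.embedding.dense_embedding))))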
127 changes: 127 additions & 0 deletions sdks/python/apache_beam/ml/rag/embeddings/open_ai_test.py
@@ -0,0 +1,127 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import tempfile
import unittest

import apache_beam as beam
from apache_beam.ml.rag.embeddings.open_ai import OpenAITextEmbeddings
from apache_beam.ml.rag.test_utils import TestHelpers
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.rag.types import Content
from apache_beam.ml.rag.types import Embedding
from apache_beam.ml.transforms.base import MLTransform
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to


@unittest.skipIf(
not os.environ.get('OPENAI_API_KEY'),
'OPENAI_API_KEY environment variable is not set')
class OpenAITextEmbeddingsTest(unittest.TestCase):
def setUp(self):
self.artifact_location = tempfile.mkdtemp(prefix='openai_')
self.test_chunks = [
Chunk(
content=Content(text="This is a test sentence."),
id="1",
metadata={
"source": "test.txt", "language": "en"
}),
Chunk(
content=Content(text="Another example."),
id="2",
metadata={
"source": "test.txt", "language": "en"
})
]

def tearDown(self) -> None:
shutil.rmtree(self.artifact_location)

def test_embedding_pipeline(self):
expected = [
Chunk(
id="1",
embedding=Embedding(dense_embedding=[0.0] * 1536),
metadata={
"source": "test.txt", "language": "en"
},
content=Content(text="This is a test sentence.")),
Chunk(
id="2",
embedding=Embedding(dense_embedding=[0.0] * 1536),
metadata={
"source": "test.txt", "language": "en"
},
content=Content(text="Another example."))
]

embedder = OpenAITextEmbeddings(
model_name="text-embedding-3-small",
dimensions=1536,
api_key=os.environ.get("OPENAI_API_KEY"))

with TestPipeline() as p:
embeddings = (
p
| beam.Create(self.test_chunks)
| MLTransform(write_artifact_location=self.artifact_location).
with_transform(embedder))

assert_that(
embeddings,
equal_to(expected, equals_fn=TestHelpers.chunk_approximately_equals))

def test_embedding_pipeline_with_dimensions(self):
expected = [
Chunk(
id="1",
embedding=Embedding(dense_embedding=[0.0] * 512),
metadata={
"source": "test.txt", "language": "en"
},
content=Content(text="This is a test sentence.")),
Chunk(
id="2",
embedding=Embedding(dense_embedding=[0.0] * 512),
metadata={
"source": "test.txt", "language": "en"
},
content=Content(text="Another example."))
]

embedder = OpenAITextEmbeddings(
model_name="text-embedding-3-small",
dimensions=512,
api_key=os.environ.get("OPENAI_API_KEY"))

with TestPipeline() as p:
embeddings = (
p
| beam.Create(self.test_chunks)
| MLTransform(write_artifact_location=self.artifact_location).
with_transform(embedder))

assert_that(
embeddings,
equal_to(expected, equals_fn=TestHelpers.chunk_approximately_equals))


if __name__ == '__main__':
unittest.main()
17 changes: 3 additions & 14 deletions sdks/python/apache_beam/ml/rag/embeddings/vertex_ai_test.py
@@ -21,6 +21,7 @@
import unittest

import apache_beam as beam
from apache_beam.ml.rag.test_utils import TestHelpers
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.rag.types import Content
from apache_beam.ml.rag.types import Embedding
@@ -39,19 +40,6 @@
VERTEX_AI_AVAILABLE = False


def chunk_approximately_equals(expected, actual):
"""Compare embeddings allowing for numerical differences."""
if not isinstance(expected, Chunk) or not isinstance(actual, Chunk):
return False

return (
expected.id == actual.id and expected.metadata == actual.metadata and
expected.content == actual.content and
len(expected.embedding.dense_embedding) == len(
actual.embedding.dense_embedding) and
all(isinstance(x, float) for x in actual.embedding.dense_embedding))


@unittest.skipIf(
not VERTEX_AI_AVAILABLE, "Vertex AI dependencies not available")
class VertexAITextEmbeddingsTest(unittest.TestCase):
@@ -104,7 +92,8 @@ def test_embedding_pipeline(self):
with_transform(embedder))

assert_that(
embeddings, equal_to(expected, equals_fn=chunk_approximately_equals))
embeddings,
equal_to(expected, equals_fn=TestHelpers.chunk_approximately_equals))


if __name__ == '__main__':
17 changes: 13 additions & 4 deletions sdks/python/apache_beam/ml/rag/test_utils.py
@@ -80,6 +80,19 @@ def find_free_port():
# Return the port number assigned by OS.
return s.getsockname()[1]

@staticmethod
def chunk_approximately_equals(expected, actual):
"""Compare embeddings allowing for numerical differences."""
if not isinstance(expected, Chunk) or not isinstance(actual, Chunk):
return False

return (
expected.id == actual.id and expected.metadata == actual.metadata and
expected.content == actual.content and
len(expected.embedding.dense_embedding) == len(
actual.embedding.dense_embedding) and
all(isinstance(x, float) for x in actual.embedding.dense_embedding))


class CustomMilvusContainer(MilvusContainer):
"""Custom Milvus container with configurable ports and environment setup.
@@ -407,7 +420,3 @@ def assert_chunks_equivalent(
# Validate field metadata.
err_msg = f"Field Metadata doesn't match for chunk {actual.id}"
assert a_f['metadata'] == e_f['metadata'], err_msg


if __name__ == '__main__':
unittest.main()
@@ -109,7 +109,7 @@ def request(
"input": batch,
}
if self.dimensions:
kwargs["dimensions"] = [str(self.dimensions)]
kwargs["dimensions"] = self.dimensions
if self.user:
kwargs["user"] = self.user

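The hunk above corrects how dimensions is forwarded in the request body: the OpenAI embeddings endpoint expects a plain integer, not a single-element list of strings. A minimal sketch of the equivalent direct call with the openai v1 client, for illustration only; the handler's internal request plumbing is not shown in this diff.

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

response = client.embeddings.create(
    model="text-embedding-3-small",
    input=["This is a test sentence.", "Another example."],
    dimensions=512,  # plain int, matching the corrected kwargs above
)
print([len(item.embedding) for item in response.data])  # -> [512, 512]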