diff --git a/.github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml b/.github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml index 4c3a0a98194c7..e8a221be9cd42 100644 --- a/.github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml +++ b/.github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml @@ -48,6 +48,7 @@ body: - celery - cloudant - cncf-kubernetes + - cohere - common-io - common-sql - daskexecutor diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index a3fe3386eb3b0..15783d1c9cd91 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -671,9 +671,9 @@ aiobotocore, airbyte, alibaba, all, all_dbs, amazon, apache.atlas, apache.beam, apache.drill, apache.druid, apache.flink, apache.hdfs, apache.hive, apache.impala, apache.kafka, apache.kylin, apache.livy, apache.pig, apache.pinot, apache.spark, apache.sqoop, apache.webhdfs, apprise, arangodb, asana, async, atlas, atlassian.jira, aws, azure, cassandra, celery, cgroups, -cloudant, cncf.kubernetes, common.io, common.sql, crypto, dask, daskexecutor, databricks, datadog, -dbt.cloud, deprecated_api, devel, devel_all, devel_ci, devel_hadoop, dingding, discord, doc, -doc_gen, docker, druid, elasticsearch, exasol, facebook, ftp, gcp, gcp_api, github, +cloudant, cncf.kubernetes, cohere, common.io, common.sql, crypto, dask, daskexecutor, databricks, +datadog, dbt.cloud, deprecated_api, devel, devel_all, devel_ci, devel_hadoop, dingding, discord, +doc, doc_gen, docker, druid, elasticsearch, exasol, facebook, ftp, gcp, gcp_api, github, github_enterprise, google, google_auth, grpc, hashicorp, hdfs, hive, http, imap, influxdb, jdbc, jenkins, kerberos, kubernetes, ldap, leveldb, microsoft.azure, microsoft.mssql, microsoft.psrp, microsoft.winrm, mongo, mssql, mysql, neo4j, odbc, openfaas, openlineage, opensearch, opsgenie, diff --git a/INSTALL b/INSTALL index b76af9d62200d..6e233fe0a0556 100644 --- a/INSTALL +++ b/INSTALL @@ -98,9 +98,9 @@ aiobotocore, airbyte, alibaba, all, all_dbs, amazon, apache.atlas, apache.beam, 
apache.drill, apache.druid, apache.flink, apache.hdfs, apache.hive, apache.impala, apache.kafka, apache.kylin, apache.livy, apache.pig, apache.pinot, apache.spark, apache.sqoop, apache.webhdfs, apprise, arangodb, asana, async, atlas, atlassian.jira, aws, azure, cassandra, celery, cgroups, -cloudant, cncf.kubernetes, common.io, common.sql, crypto, dask, daskexecutor, databricks, datadog, -dbt.cloud, deprecated_api, devel, devel_all, devel_ci, devel_hadoop, dingding, discord, doc, -doc_gen, docker, druid, elasticsearch, exasol, facebook, ftp, gcp, gcp_api, github, +cloudant, cncf.kubernetes, cohere, common.io, common.sql, crypto, dask, daskexecutor, databricks, +datadog, dbt.cloud, deprecated_api, devel, devel_all, devel_ci, devel_hadoop, dingding, discord, +doc, doc_gen, docker, druid, elasticsearch, exasol, facebook, ftp, gcp, gcp_api, github, github_enterprise, google, google_auth, grpc, hashicorp, hdfs, hive, http, imap, influxdb, jdbc, jenkins, kerberos, kubernetes, ldap, leveldb, microsoft.azure, microsoft.mssql, microsoft.psrp, microsoft.winrm, mongo, mssql, mysql, neo4j, odbc, openfaas, openlineage, opensearch, opsgenie, diff --git a/airflow/providers/cohere/.latest-doc-only-change.txt b/airflow/providers/cohere/.latest-doc-only-change.txt new file mode 100644 index 0000000000000..14a9c1ef7c65f --- /dev/null +++ b/airflow/providers/cohere/.latest-doc-only-change.txt @@ -0,0 +1 @@ +c645d8e40c167ea1f6c332cdc3ea0ca5a9363205 diff --git a/airflow/providers/cohere/CHANGELOG.rst b/airflow/providers/cohere/CHANGELOG.rst new file mode 100644 index 0000000000000..54baa1945c135 --- /dev/null +++ b/airflow/providers/cohere/CHANGELOG.rst @@ -0,0 +1,26 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. 
The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +``apache-airflow-providers-cohere`` + +Changelog +--------- + +1.0.0 +..... + +Initial version of the provider. diff --git a/airflow/providers/cohere/__init__.py b/airflow/providers/cohere/__init__.py new file mode 100644 index 0000000000000..c88b314548bd6 --- /dev/null +++ b/airflow/providers/cohere/__init__.py @@ -0,0 +1,24 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE +# OVERWRITTEN WHEN PREPARING DOCUMENTATION FOR THE PACKAGES. 
+# +# IF YOU WANT TO MODIFY IT, YOU SHOULD MODIFY THE TEMPLATE +# `PROVIDER__INIT__PY_TEMPLATE.py.jinja2` IN the `dev/provider_packages` DIRECTORY +# diff --git a/airflow/providers/cohere/hooks/__init__.py b/airflow/providers/cohere/hooks/__init__.py new file mode 100644 index 0000000000000..c88b314548bd6 --- /dev/null +++ b/airflow/providers/cohere/hooks/__init__.py @@ -0,0 +1,24 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE +# OVERWRITTEN WHEN PREPARING DOCUMENTATION FOR THE PACKAGES. +# +# IF YOU WANT TO MODIFY IT, YOU SHOULD MODIFY THE TEMPLATE +# `PROVIDER__INIT__PY_TEMPLATE.py.jinja2` IN the `dev/provider_packages` DIRECTORY +# diff --git a/airflow/providers/cohere/hooks/cohere.py b/airflow/providers/cohere/hooks/cohere.py new file mode 100644 index 0000000000000..20c77d58690e9 --- /dev/null +++ b/airflow/providers/cohere/hooks/cohere.py @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from functools import cached_property +from typing import Any + +import cohere + +from airflow.hooks.base import BaseHook + + +class CohereHook(BaseHook): + """ + Use Cohere Python SDK to interact with Cohere platform. + + .. seealso:: https://docs.cohere.com/docs + + :param conn_id: :ref:`Cohere connection id ` + :param timeout: Request timeout in seconds. + :param max_retries: Maximal number of retries for requests. 
+ """ + + conn_name_attr = "conn_id" + default_conn_name = "cohere_default" + conn_type = "cohere" + hook_name = "Cohere" + + def __init__( + self, + conn_id: str = default_conn_name, + timeout: int | None = None, + max_retries: int | None = None, + ) -> None: + super().__init__() + self.conn_id = conn_id + self.timeout = timeout + self.max_retries = max_retries + + @cached_property + def get_conn(self) -> cohere.Client: + conn = self.get_connection(self.conn_id) + return cohere.Client( + api_key=conn.password, timeout=self.timeout, max_retries=self.max_retries, api_url=conn.host + ) + + def create_embeddings( + self, texts: list[str], model: str = "embed-multilingual-v2.0" + ) -> list[list[float]]: + response = self.get_conn.embed(texts=texts, model=model) + embeddings = response.embeddings + return embeddings + + @staticmethod + def get_ui_field_behaviour() -> dict[str, Any]: + return { + "hidden_fields": ["schema", "login", "port", "extra"], + "relabeling": { + "password": "API Key", + }, + } + + def test_connection(self) -> tuple[bool, str]: + try: + self.get_conn.generate("Test", max_tokens=10) + return True, "Connection established" + except Exception as e: + return False, str(e) diff --git a/airflow/providers/cohere/operators/__init__.py b/airflow/providers/cohere/operators/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/providers/cohere/operators/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/cohere/operators/embedding.py b/airflow/providers/cohere/operators/embedding.py new file mode 100644 index 0000000000000..dba95e7e8f661 --- /dev/null +++ b/airflow/providers/cohere/operators/embedding.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from functools import cached_property +from typing import TYPE_CHECKING, Any, Sequence + +from airflow.models import BaseOperator +from airflow.providers.cohere.hooks.cohere import CohereHook + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class CohereEmbeddingOperator(BaseOperator): + """Creates the embedding base by interacting with cohere hosted services. + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CohereEmbeddingOperator` + + :param input_text: single string text or list of text items that need to be embedded. + :param conn_id: Optional. The name of the Airflow connection to get connection + information for Cohere. Defaults to "cohere_default". + :param timeout: Timeout in seconds for Cohere API. + :param max_retries: Number of times to retry before failing. + """ + + template_fields: Sequence[str] = ("input_text",) + + def __init__( + self, + input_text: list[str] | str, + conn_id: str = CohereHook.default_conn_name, + timeout: int | None = None, + max_retries: int | None = None, + **kwargs: Any, + ): + super().__init__(**kwargs) + if isinstance(input_text, str): + input_text = [input_text] + self.conn_id = conn_id + self.input_text = input_text + self.timeout = timeout + self.max_retries = max_retries + + @cached_property + def hook(self) -> CohereHook: + """Return an instance of the CohereHook.""" + return CohereHook(conn_id=self.conn_id, timeout=self.timeout, max_retries=self.max_retries) + + def execute(self, context: Context) -> list[list[float]]: + """Embed texts using Cohere embed services.""" + return self.hook.create_embeddings(self.input_text) diff --git a/airflow/providers/cohere/provider.yaml b/airflow/providers/cohere/provider.yaml new file mode 100644 index 0000000000000..55500e0e760ff --- /dev/null +++ b/airflow/providers/cohere/provider.yaml @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-cohere + +name: Cohere + +description: | + `Cohere `__ + +suspended: false + +versions: + - 1.0.0 + +integrations: + - integration-name: Cohere + external-doc-url: https://docs.cohere.com/docs + how-to-guide: + - /docs/apache-airflow-providers-cohere/operators/embedding.rst + tags: [software] + +dependencies: + - apache-airflow>=2.5.0 + - cohere>=4.27 + +hooks: + - integration-name: Cohere + python-modules: + - airflow.providers.cohere.hooks.cohere + +operators: + - integration-name: Cohere + python-modules: + - airflow.providers.cohere.operators.embedding + +connection-types: + - hook-class-name: airflow.providers.cohere.hooks.cohere.CohereHook + connection-type: cohere diff --git a/docs/apache-airflow-providers-cohere/changelog.rst b/docs/apache-airflow-providers-cohere/changelog.rst new file mode 100644 index 0000000000000..be95cf4557aa1 --- /dev/null +++ b/docs/apache-airflow-providers-cohere/changelog.rst @@ -0,0 +1,18 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. 
Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. include:: ../../airflow/providers/cohere/CHANGELOG.rst diff --git a/docs/apache-airflow-providers-cohere/commits.rst b/docs/apache-airflow-providers-cohere/commits.rst new file mode 100644 index 0000000000000..f69ac2c0ad76c --- /dev/null +++ b/docs/apache-airflow-providers-cohere/commits.rst @@ -0,0 +1,19 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Package apache-airflow-providers-cohere +------------------------------------------- diff --git a/docs/apache-airflow-providers-cohere/connections.rst b/docs/apache-airflow-providers-cohere/connections.rst new file mode 100644 index 0000000000000..0ab583dfd59e0 --- /dev/null +++ b/docs/apache-airflow-providers-cohere/connections.rst @@ -0,0 +1,37 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. 
The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. _howto/connection:cohere: + +Cohere Connection +======================= + +The `Cohere <https://docs.cohere.com/docs>`__ connection type enables access to Cohere APIs, which we can use to interact with multilingual models. + +Default Connection IDs +---------------------- + +Cohere hooks point to ``cohere_default`` by default. + +Configuring the Connection +-------------------------- + +Password (required) + Specify the Cohere API key used to authenticate (the field is labelled "API Key" in the connection form). + +Host (optional) + Specify the API host; it is passed to the Cohere client as ``api_url``. diff --git a/docs/apache-airflow-providers-cohere/index.rst b/docs/apache-airflow-providers-cohere/index.rst new file mode 100644 index 0000000000000..8bfd1678a3460 --- /dev/null +++ b/docs/apache-airflow-providers-cohere/index.rst @@ -0,0 +1,98 @@ + + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied.
See the License for the + specific language governing permissions and limitations + under the License. + +``apache-airflow-providers-cohere`` +====================================== + + +.. toctree:: + :hidden: + :maxdepth: 1 + :caption: Basics + + Home <self> + Changelog <changelog> + Security <security> + +.. toctree:: + :hidden: + :maxdepth: 1 + :caption: Guides + + Connection types <connections> + Operators <operators/embedding> + +.. toctree:: + :hidden: + :maxdepth: 1 + :caption: Resources + + Python API <_api/airflow/providers/cohere/index> + PyPI Repository <https://pypi.org/project/apache-airflow-providers-cohere/> + Installing from sources <installing-providers-from-sources> + +.. toctree:: + :hidden: + :maxdepth: 1 + :caption: System tests + + System Tests <_api/tests/system/providers/cohere/index> + +Package apache-airflow-providers-cohere +----------------------------------------- + +`Cohere <https://docs.cohere.com/docs>`__ + + +Release: 1.0.0 + +Provider package +---------------- + +This is a provider package for ``cohere`` APIs. All classes for this provider package +are in ``airflow.providers.cohere`` python module. + +Installation +------------ + +You can install this package on top of an existing Airflow 2 installation (see ``Requirements`` below +for the minimum Airflow version supported) via +``pip install apache-airflow-providers-cohere`` + + +Requirements +------------ + +The minimum Apache Airflow version supported by this provider package is ``2.5.0``. + +==================  ================== +PIP package         Version required +==================  ================== +``apache-airflow``  ``>=2.5.0`` +``cohere``          ``>=4.27`` +==================  ================== + +.. THE REMAINDER OF THE FILE IS AUTOMATICALLY GENERATED. IT WILL BE OVERWRITTEN AT RELEASE TIME! + + +..
toctree:: + :hidden: + :maxdepth: 1 + :caption: Commits + + Detailed list of commits <commits> diff --git a/docs/apache-airflow-providers-cohere/installing-providers-from-sources.rst b/docs/apache-airflow-providers-cohere/installing-providers-from-sources.rst new file mode 100644 index 0000000000000..b4e730f4ff21a --- /dev/null +++ b/docs/apache-airflow-providers-cohere/installing-providers-from-sources.rst @@ -0,0 +1,18 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. include:: ../exts/includes/installing-providers-from-sources.rst diff --git a/docs/apache-airflow-providers-cohere/operators/embedding.rst b/docs/apache-airflow-providers-cohere/operators/embedding.rst new file mode 100644 index 0000000000000..b765fe8e9d07c --- /dev/null +++ b/docs/apache-airflow-providers-cohere/operators/embedding.rst @@ -0,0 +1,40 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + ..
http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. _howto/operator:CohereEmbeddingOperator: + +CohereEmbeddingOperator +======================== + +Use the :class:`~airflow.providers.cohere.operators.embedding.CohereEmbeddingOperator` to +interact with Cohere APIs to create embeddings for a given text. + + +Using the Operator +^^^^^^^^^^^^^^^^^^ + +The CohereEmbeddingOperator requires ``input_text`` as the input to the embedding API. Use the ``conn_id`` parameter to specify the Cohere connection used to +connect to your account. + +Example Code: +------------- + +.. exampleinclude:: /../../tests/system/providers/cohere/example_cohere_embedding_operator.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_cohere_embedding] + :end-before: [END howto_operator_cohere_embedding] diff --git a/docs/apache-airflow-providers-cohere/security.rst b/docs/apache-airflow-providers-cohere/security.rst new file mode 100644 index 0000000000000..b8f95e6ecfa29 --- /dev/null +++ b/docs/apache-airflow-providers-cohere/security.rst @@ -0,0 +1,19 @@ + + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + ..
Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. include:: ../exts/includes/security.rst diff --git a/docs/apache-airflow/extra-packages-ref.rst b/docs/apache-airflow/extra-packages-ref.rst index dba750908d853..9f7fcc9a03480 100644 --- a/docs/apache-airflow/extra-packages-ref.rst +++ b/docs/apache-airflow/extra-packages-ref.rst @@ -172,6 +172,8 @@ These are extras that add dependencies needed for integration with external serv +---------------------+-----------------------------------------------------+-----------------------------------------------------+ | cloudant | ``pip install 'apache-airflow[cloudant]'`` | Cloudant hook | +---------------------+-----------------------------------------------------+-----------------------------------------------------+ +| cohere | ``pip install 'apache-airflow[cohere]'`` | Cohere hook and operators | ++---------------------+-----------------------------------------------------+-----------------------------------------------------+ | databricks | ``pip install 'apache-airflow[databricks]'`` | Databricks hooks and operators | +---------------------+-----------------------------------------------------+-----------------------------------------------------+ | datadog | ``pip install 'apache-airflow[datadog]'`` | Datadog hooks and sensors | diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 38d2259dab3de..8be76a03e9d9e 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -269,6 +269,14 @@ "cross-providers-deps": [], "excluded-python-versions": [] }, + "cohere": { + "deps": [ + "apache-airflow>=2.5.0", + "cohere>=4.27" + ], + "cross-providers-deps": [], + 
"excluded-python-versions": [] + }, "common.io": { "deps": [ "apache-airflow>=2.8.0" diff --git a/images/breeze/output-commands-hash.txt b/images/breeze/output-commands-hash.txt index 64ba323baa4ad..a5d9a89c8d328 100644 --- a/images/breeze/output-commands-hash.txt +++ b/images/breeze/output-commands-hash.txt @@ -2,7 +2,7 @@ # Please do not solve it but run `breeze setup regenerate-command-images`. # This command should fix the conflict and regenerate help images that you have conflict with. main:96b4884054753db922cb8ca2cc555368 -build-docs:2e9882744f219e56726548ce2d13c3f5 +build-docs:3863424b58f31d5a2b285055b044ad70 ci:find-backtracking-candidates:17fe56b867a745e5032a08dfcd3f73ee ci:fix-ownership:3e5a73533cc96045e72cb258783cfc96 ci:free-space:49af17b032039c05c41a7a8283f365cc @@ -36,26 +36,26 @@ prod-image:build:1628f7bff3e7e369f0358a646682e674 prod-image:pull:3817ef211b023b76df84ee1110ef64dd prod-image:verify:bd2b78738a7c388dbad6076c41a9f906 prod-image:6011405076eb0e1049d87e971e3adce1 -release-management:add-back-references:51960e2831d0e03a2b127d252929b843 +release-management:add-back-references:8154089f6b96923edd772c5a38f7d143 release-management:create-minor-branch:a3834afc4aa5d1e98002c9e9e7a9931d release-management:generate-constraints:01aef235b11e59ed7f10c970a5cdaba7 -release-management:generate-issue-content-providers:cda108e7f2506c2816af8f2a6c24070c +release-management:generate-issue-content-providers:97e29b10f93a0d0276469fc470110c95 release-management:generate-providers-metadata:d4e8e5cfaa024e3963af02d7a873048d release-management:install-provider-packages:34c38aca17d23dbb454fe7a6bfd8e630 release-management:prepare-airflow-package:85d01c57e5b5ee0fb9e5f9d9706ed3b5 -release-management:prepare-provider-documentation:eb861d68b8d72cd98dc8732fc5393796 -release-management:prepare-provider-packages:908e2c826f7b4959dfd8bc693f3857a7 -release-management:publish-docs:51ee9bf1268529513996a14bd5350c19 
+release-management:prepare-provider-documentation:8a142fef4a148279d7794571c4cfed33 +release-management:prepare-provider-packages:10c9e86ce1ba6c50acfd0d5367d06da8 +release-management:publish-docs:7eef8e54cbc9743afb865531ff92503d release-management:release-prod-images:cfbfe8b19fee91fd90718f98ef2fd078 release-management:start-rc-process:b27bd524dd3c89f50a747b60a7e892c1 release-management:start-release:419f48f6a4ff4457cb9de7ff496aebbe release-management:update-constraints:02ec4b119150e3fdbac52026e94820ef release-management:verify-provider-packages:96dce5644aad6b37080acf77b3d8de3a -release-management:59d956e45fccf55e47f16e33cfc5d04a +release-management:c0d470f72e53df330e573c7bc0a08470 sbom:build-all-airflow-images:32f8acade299c2b112e986bae99846db -sbom:generate-providers-requirements:3926848718283cf2ef00310a0892e867 +sbom:generate-providers-requirements:bc008fcd52f258fbd989a42e66e44349 sbom:update-sbom-information:653be48be70b4b7ff5172d491aadc694 -sbom:386048e0c00c0de30cf181eb9f3862ea +sbom:9d0e848adefac54a4ccda950ca3bc4ee setup:autocomplete:fffcd49e102e09ccd69b3841a9e3ea8e setup:check-all-params-in-groups:5c5e3c382fc8ce84899d224448b3f48a setup:config:3435f1f1535a82c30591dbf577294d2e diff --git a/images/breeze/output_build-docs.svg b/images/breeze/output_build-docs.svg index 6bb4ceffe65a2..7622be26f3eb2 100644 --- a/images/breeze/output_build-docs.svg +++ b/images/breeze/output_build-docs.svg @@ -163,7 +163,7 @@ [OPTIONS] [all-providers | providers-index | apache-airflow | docker-stack | helm-chart | airbyte | alibaba | amazon | apache.beam | apache.cassandra | apache.drill | apache.druid | apache.flink | apache.hdfs | apache.hive |              apache.impala | apache.kafka | apache.kylin | apache.livy | apache.pig | apache.pinot | apache.spark | apache.sqoop |  -apprise | arangodb | asana | atlassian.jira | celery | cloudant | cncf.kubernetes | common.io | common.sql |           +apprise | arangodb | asana | atlassian.jira | celery | cloudant | cncf.kubernetes | 
cohere | common.io | common.sql |  daskexecutor | databricks | datadog | dbt.cloud | dingding | discord | docker | elasticsearch | exasol | facebook |    ftp | github | google | grpc | hashicorp | http | imap | influxdb | jdbc | jenkins | microsoft.azure | microsoft.mssql microsoft.psrp | microsoft.winrm | mongo | mysql | neo4j | odbc | openfaas | openlineage | opensearch | opsgenie |   diff --git a/images/breeze/output_release-management_add-back-references.svg b/images/breeze/output_release-management_add-back-references.svg index ebcd1572ef1b6..fd0739906cc3b 100644 --- a/images/breeze/output_release-management_add-back-references.svg +++ b/images/breeze/output_release-management_add-back-references.svg @@ -134,10 +134,10 @@ [OPTIONS] [all-providers | apache-airflow | docker-stack | helm-chart | airbyte | alibaba | amazon | apache.beam |     apache.cassandra | apache.drill | apache.druid | apache.flink | apache.hdfs | apache.hive | apache.impala |            apache.kafka | apache.kylin | apache.livy | apache.pig | apache.pinot | apache.spark | apache.sqoop | apprise |        -arangodb | asana | atlassian.jira | celery | cloudant | cncf.kubernetes | common.io | common.sql | daskexecutor |      -databricks | datadog | dbt.cloud | dingding | discord | docker | elasticsearch | exasol | facebook | ftp | github |    -google | grpc | hashicorp | http | imap | influxdb | jdbc | jenkins | microsoft.azure | microsoft.mssql |              -microsoft.psrp | microsoft.winrm | mongo | mysql | neo4j | odbc | openfaas | openlineage | opensearch | opsgenie |     +arangodb | asana | atlassian.jira | celery | cloudant | cncf.kubernetes | cohere | common.io | common.sql |            +daskexecutor | databricks | datadog | dbt.cloud | dingding | discord | docker | elasticsearch | exasol | facebook |    +ftp | github | google | grpc | hashicorp | http | imap | influxdb | jdbc | jenkins | microsoft.azure | microsoft.mssql +microsoft.psrp | microsoft.winrm | mongo | mysql | neo4j | odbc 
| openfaas | openlineage | opensearch | opsgenie |   oracle | pagerduty | papermill | plexus | postgres | presto | redis | salesforce | samba | segment | sendgrid | sftp | singularity | slack | smtp | snowflake | sqlite | ssh | tableau | tabular | telegram | trino | vertica | yandex |      zendesk]...                                                                                                            diff --git a/images/breeze/output_release-management_generate-issue-content-providers.svg b/images/breeze/output_release-management_generate-issue-content-providers.svg index 10139cafc51e5..6f03ffa40db55 100644 --- a/images/breeze/output_release-management_generate-issue-content-providers.svg +++ b/images/breeze/output_release-management_generate-issue-content-providers.svg @@ -144,12 +144,12 @@ [OPTIONS] [apache-airflow | docker-stack | helm-chart | airbyte | alibaba | amazon | apache.beam | apache.cassandra |  apache.drill | apache.druid | apache.flink | apache.hdfs | apache.hive | apache.impala | apache.kafka | apache.kylin | apache.livy | apache.pig | apache.pinot | apache.spark | apache.sqoop | apprise | arangodb | asana | atlassian.jira |  -celery | cloudant | cncf.kubernetes | common.io | common.sql | daskexecutor | databricks | datadog | dbt.cloud |       -dingding | discord | docker | elasticsearch | exasol | facebook | ftp | github | google | grpc | hashicorp | http |    -imap | influxdb | jdbc | jenkins | microsoft.azure | microsoft.mssql | microsoft.psrp | microsoft.winrm | mongo |      -mysql | neo4j | odbc | openfaas | openlineage | opensearch | opsgenie | oracle | pagerduty | papermill | plexus |      -postgres | presto | redis | salesforce | samba | segment | sendgrid | sftp | singularity | slack | smtp | snowflake |  -sqlite | ssh | tableau | tabular | telegram | trino | vertica | yandex | zendesk]...                                   
+celery | cloudant | cncf.kubernetes | cohere | common.io | common.sql | daskexecutor | databricks | datadog |          +dbt.cloud | dingding | discord | docker | elasticsearch | exasol | facebook | ftp | github | google | grpc | hashicorp +http | imap | influxdb | jdbc | jenkins | microsoft.azure | microsoft.mssql | microsoft.psrp | microsoft.winrm |     +mongo | mysql | neo4j | odbc | openfaas | openlineage | opensearch | opsgenie | oracle | pagerduty | papermill |       +plexus | postgres | presto | redis | salesforce | samba | segment | sendgrid | sftp | singularity | slack | smtp |     +snowflake | sqlite | ssh | tableau | tabular | telegram | trino | vertica | yandex | zendesk]...                       Generates content for issue to test the release. diff --git a/images/breeze/output_release-management_prepare-provider-documentation.svg b/images/breeze/output_release-management_prepare-provider-documentation.svg index a8e4587c4b15c..27f198d97e31d 100644 --- a/images/breeze/output_release-management_prepare-provider-documentation.svg +++ b/images/breeze/output_release-management_prepare-provider-documentation.svg @@ -156,12 +156,12 @@ [OPTIONS] [apache-airflow | docker-stack | helm-chart | airbyte | alibaba | amazon | apache.beam | apache.cassandra |  apache.drill | apache.druid | apache.flink | apache.hdfs | apache.hive | apache.impala | apache.kafka | apache.kylin | apache.livy | apache.pig | apache.pinot | apache.spark | apache.sqoop | apprise | arangodb | asana | atlassian.jira |  -celery | cloudant | cncf.kubernetes | common.io | common.sql | daskexecutor | databricks | datadog | dbt.cloud |       -dingding | discord | docker | elasticsearch | exasol | facebook | ftp | github | google | grpc | hashicorp | http |    -imap | influxdb | jdbc | jenkins | microsoft.azure | microsoft.mssql | microsoft.psrp | microsoft.winrm | mongo |      -mysql | neo4j | odbc | openfaas | openlineage | opensearch | opsgenie | oracle | pagerduty | papermill | plexus |      
-postgres | presto | redis | salesforce | samba | segment | sendgrid | sftp | singularity | slack | smtp | snowflake |  -sqlite | ssh | tableau | tabular | telegram | trino | vertica | yandex | zendesk]...                                   +celery | cloudant | cncf.kubernetes | cohere | common.io | common.sql | daskexecutor | databricks | datadog |          +dbt.cloud | dingding | discord | docker | elasticsearch | exasol | facebook | ftp | github | google | grpc | hashicorp +http | imap | influxdb | jdbc | jenkins | microsoft.azure | microsoft.mssql | microsoft.psrp | microsoft.winrm |     +mongo | mysql | neo4j | odbc | openfaas | openlineage | opensearch | opsgenie | oracle | pagerduty | papermill |       +plexus | postgres | presto | redis | salesforce | samba | segment | sendgrid | sftp | singularity | slack | smtp |     +snowflake | sqlite | ssh | tableau | tabular | telegram | trino | vertica | yandex | zendesk]...                       Prepare CHANGELOG, README and COMMITS information for providers. 
diff --git a/images/breeze/output_release-management_prepare-provider-packages.svg b/images/breeze/output_release-management_prepare-provider-packages.svg index fa4c36063edd9..507cd13783471 100644 --- a/images/breeze/output_release-management_prepare-provider-packages.svg +++ b/images/breeze/output_release-management_prepare-provider-packages.svg @@ -141,12 +141,12 @@ [OPTIONS] [apache-airflow | docker-stack | helm-chart | airbyte | alibaba | amazon | apache.beam | apache.cassandra |  apache.drill | apache.druid | apache.flink | apache.hdfs | apache.hive | apache.impala | apache.kafka | apache.kylin | apache.livy | apache.pig | apache.pinot | apache.spark | apache.sqoop | apprise | arangodb | asana | atlassian.jira |  -celery | cloudant | cncf.kubernetes | common.io | common.sql | daskexecutor | databricks | datadog | dbt.cloud |       -dingding | discord | docker | elasticsearch | exasol | facebook | ftp | github | google | grpc | hashicorp | http |    -imap | influxdb | jdbc | jenkins | microsoft.azure | microsoft.mssql | microsoft.psrp | microsoft.winrm | mongo |      -mysql | neo4j | odbc | openfaas | openlineage | opensearch | opsgenie | oracle | pagerduty | papermill | plexus |      -postgres | presto | redis | salesforce | samba | segment | sendgrid | sftp | singularity | slack | smtp | snowflake |  -sqlite | ssh | tableau | tabular | telegram | trino | vertica | yandex | zendesk]...                                   
+celery | cloudant | cncf.kubernetes | cohere | common.io | common.sql | daskexecutor | databricks | datadog |          +dbt.cloud | dingding | discord | docker | elasticsearch | exasol | facebook | ftp | github | google | grpc | hashicorp +http | imap | influxdb | jdbc | jenkins | microsoft.azure | microsoft.mssql | microsoft.psrp | microsoft.winrm |     +mongo | mysql | neo4j | odbc | openfaas | openlineage | opensearch | opsgenie | oracle | pagerduty | papermill |       +plexus | postgres | presto | redis | salesforce | samba | segment | sendgrid | sftp | singularity | slack | smtp |     +snowflake | sqlite | ssh | tableau | tabular | telegram | trino | vertica | yandex | zendesk]...                       Prepare sdist/whl packages of Airflow Providers. diff --git a/images/breeze/output_release-management_publish-docs.svg b/images/breeze/output_release-management_publish-docs.svg index 8108ee78d10a7..38ffcaf32d39f 100644 --- a/images/breeze/output_release-management_publish-docs.svg +++ b/images/breeze/output_release-management_publish-docs.svg @@ -180,7 +180,7 @@ [OPTIONS] [all-providers | providers-index | apache-airflow | docker-stack | helm-chart | airbyte | alibaba | amazon | apache.beam | apache.cassandra | apache.drill | apache.druid | apache.flink | apache.hdfs | apache.hive |              apache.impala | apache.kafka | apache.kylin | apache.livy | apache.pig | apache.pinot | apache.spark | apache.sqoop |  -apprise | arangodb | asana | atlassian.jira | celery | cloudant | cncf.kubernetes | common.io | common.sql |           +apprise | arangodb | asana | atlassian.jira | celery | cloudant | cncf.kubernetes | cohere | common.io | common.sql |  daskexecutor | databricks | datadog | dbt.cloud | dingding | discord | docker | elasticsearch | exasol | facebook |    ftp | github | google | grpc | hashicorp | http | imap | influxdb | jdbc | jenkins | microsoft.azure | microsoft.mssql microsoft.psrp | microsoft.winrm | mongo | mysql | neo4j | odbc | openfaas | 
openlineage | opensearch | opsgenie |   diff --git a/images/breeze/output_sbom_generate-providers-requirements.svg b/images/breeze/output_sbom_generate-providers-requirements.svg index 9f08035158704..6dfe112fdf2eb 100644 --- a/images/breeze/output_sbom_generate-providers-requirements.svg +++ b/images/breeze/output_sbom_generate-providers-requirements.svg @@ -186,14 +186,14 @@ (airbyte | alibaba | amazon | apache.beam | apache.cassandra | apache.drill | apache.druid |   apache.flink | apache.hdfs | apache.hive | apache.impala | apache.kafka | apache.kylin |       apache.livy | apache.pig | apache.pinot | apache.spark | apache.sqoop | apprise | arangodb |   -asana | atlassian.jira | celery | cloudant | cncf.kubernetes | common.io | common.sql |        -daskexecutor | databricks | datadog | dbt.cloud | dingding | discord | docker | elasticsearch  -| exasol | facebook | ftp | github | google | grpc | hashicorp | http | imap | influxdb | jdbc -| jenkins | microsoft.azure | microsoft.mssql | microsoft.psrp | microsoft.winrm | mongo |     -mysql | neo4j | odbc | openfaas | openlineage | opensearch | opsgenie | oracle | pagerduty |   -papermill | plexus | postgres | presto | redis | salesforce | samba | segment | sendgrid |     -sftp | singularity | slack | smtp | snowflake | sqlite | ssh | tableau | tabular | telegram |  -trino | vertica | yandex | zendesk)                                                            +asana | atlassian.jira | celery | cloudant | cncf.kubernetes | cohere | common.io | common.sql +| daskexecutor | databricks | datadog | dbt.cloud | dingding | discord | docker |              +elasticsearch | exasol | facebook | ftp | github | google | grpc | hashicorp | http | imap |   +influxdb | jdbc | jenkins | microsoft.azure | microsoft.mssql | microsoft.psrp |               +microsoft.winrm | mongo | mysql | neo4j | odbc | openfaas | openlineage | opensearch |         +opsgenie | oracle | pagerduty | papermill | plexus | postgres | presto | redis | 
salesforce |  +samba | segment | sendgrid | sftp | singularity | slack | smtp | snowflake | sqlite | ssh |    +tableau | tabular | telegram | trino | vertica | yandex | zendesk)                             --provider-versionProvider version to generate the requirements for i.e `2.1.0`. `latest` is also a supported    value to account for the most recent version of the provider                                   (TEXT)                                                                                         diff --git a/tests/providers/cohere/__init__.py b/tests/providers/cohere/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/providers/cohere/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/cohere/hooks/__init__.py b/tests/providers/cohere/hooks/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/providers/cohere/hooks/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/cohere/hooks/test_cohere.py b/tests/providers/cohere/hooks/test_cohere.py new file mode 100644 index 0000000000000..8f566ec0c6405 --- /dev/null +++ b/tests/providers/cohere/hooks/test_cohere.py @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest.mock import patch + +from airflow.models import Connection +from airflow.providers.cohere.hooks.cohere import ( + CohereHook, +) + + +class TestCohereHook: + """ + Test for CohereHook + """ + + def test__get_api_key(self): + api_key = "test" + api_url = "http://some_host.com" + timeout = 150 + max_retries = 5 + with patch.object( + CohereHook, + "get_connection", + return_value=Connection(conn_type="cohere", password=api_key, host=api_url), + ), patch("cohere.Client") as client: + hook = CohereHook(timeout=timeout, max_retries=max_retries) + _ = hook.get_conn + client.assert_called_once_with( + api_key=api_key, timeout=timeout, max_retries=max_retries, api_url=api_url + ) diff --git a/tests/providers/cohere/operators/__init__.py b/tests/providers/cohere/operators/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/providers/cohere/operators/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/tests/providers/cohere/operators/test_embedding.py b/tests/providers/cohere/operators/test_embedding.py new file mode 100644 index 0000000000000..32dd83aa2614a --- /dev/null +++ b/tests/providers/cohere/operators/test_embedding.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from airflow.models import Connection +from airflow.providers.cohere.operators.embedding import CohereEmbeddingOperator + + +@patch("airflow.providers.cohere.hooks.cohere.CohereHook.get_connection") +@patch("cohere.Client") +def test_cohere_embedding_operator(cohere_client, get_connection): + """ + Test Cohere client is getting called with the correct key and that + the execute methods returns expected response. + """ + embedded_obj = [1, 2, 3] + + class resp: + embeddings = embedded_obj + + api_key = "test" + api_url = "http://some_host.com" + timeout = 150 + max_retries = 5 + texts = ["On Kernel-Target Alignment. 
We describe a family of global optimization procedures"] + + get_connection.return_value = Connection(conn_type="cohere", password=api_key, host=api_url) + client_obj = MagicMock() + cohere_client.return_value = client_obj + client_obj.embed.return_value = resp + + op = CohereEmbeddingOperator( + task_id="embed", conn_id="some_conn", input_text=texts, timeout=timeout, max_retries=max_retries + ) + + val = op.execute(context={}) + cohere_client.assert_called_once_with( + api_key=api_key, api_url=api_url, timeout=timeout, max_retries=max_retries + ) + assert val == embedded_obj diff --git a/tests/system/providers/cohere/__init__.py b/tests/system/providers/cohere/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/system/providers/cohere/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/tests/system/providers/cohere/example_cohere_embedding_operator.py b/tests/system/providers/cohere/example_cohere_embedding_operator.py new file mode 100644 index 0000000000000..ec97ee91e57cb --- /dev/null +++ b/tests/system/providers/cohere/example_cohere_embedding_operator.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime + +from airflow import DAG +from airflow.providers.cohere.operators.embedding import CohereEmbeddingOperator + +with DAG("example_cohere_embedding", schedule=None, start_date=datetime(2023, 1, 1), catchup=False) as dag: + # [START howto_operator_cohere_embedding] + texts = [ + "On Kernel-Target Alignment. 
We describe a family of global optimization procedures", + " that automatically decompose optimization problems into smaller loosely coupled", + " problems, then combine the solutions of these with message passing algorithms.", + ] + + CohereEmbeddingOperator(input_text=texts, task_id="embedding_via_text") + CohereEmbeddingOperator(input_text=texts[0], task_id="embedding_via_task") + # [END howto_operator_cohere_embedding] + + +from tests.system.utils import get_test_run + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag)