From 669cf2c9cb8cd683dd03fb2dc3709892263934f6 Mon Sep 17 00:00:00 2001 From: simon824 Date: Tue, 30 Jan 2024 17:52:36 +0800 Subject: [PATCH 1/4] refactor construct knowledge graph task --- hugegraph-llm/README.md | 62 +++- hugegraph-llm/examples/build_kg_test.py | 63 ++-- hugegraph-llm/src/config/config.ini | 2 +- .../src/hugegraph_llm/llms/ernie_bot.py | 4 +- .../operators/common_op/check_schema.py | 63 ++++ .../hugegraph_op/commit_to_hugegraph.py | 232 ++++----------- .../operators/hugegraph_op/schema_manager.py | 54 ++++ .../operators/kg_construction_task.py | 36 ++- .../operators/llm_op/disambiguate_data.py | 238 ++------------- .../operators/llm_op/info_extract.py | 276 ++++-------------- .../operators/common_op/test_check_schema.py | 86 ++++++ .../llm_op/test_disambiguate_data.py | 83 ++++++ .../operators/llm_op/test_info_extract.py | 142 +++++++++ 13 files changed, 676 insertions(+), 665 deletions(-) create mode 100644 hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py create mode 100644 hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py create mode 100644 hugegraph-llm/src/tests/operators/common_op/test_check_schema.py create mode 100644 hugegraph-llm/src/tests/operators/llm_op/test_disambiguate_data.py create mode 100644 hugegraph-llm/src/tests/operators/llm_op/test_info_extract.py diff --git a/hugegraph-llm/README.md b/hugegraph-llm/README.md index ab1eec2e5..cb8e3f43c 100644 --- a/hugegraph-llm/README.md +++ b/hugegraph-llm/README.md @@ -19,8 +19,64 @@ graph systems and large language models. ## Examples (knowledge graph construction by llm) -1. Start the HugeGraph database, you can do it via Docker. Refer to this [link](https://hub.docker.com/r/hugegraph/hugegraph) for guidance -2. Run example like `python hugegraph-llm/examples/build_kg_test.py` +> 1. Start the HugeGraph database, you can do it via Docker. Refer to this [link](https://hub.docker.com/r/hugegraph/hugegraph) for guidance +> 2. 
Run example like `python hugegraph-llm/examples/build_kg_test.py` +> +> Note: If you need a proxy to access OpenAI's API, please set your HTTP proxy in `build_kg_test.py`. -Note: If you need a proxy to access OpenAI's API, please set your HTTP proxy in `build_kg_test.py`. +The `KgBuilder` class is used to construct a knowledge graph. Here is a brief usage guide: +1. **Initialization**: The `KgBuilder` class is initialized with an instance of a language model. This can be obtained from the `LLMs` class. + +```python +from hugegraph_llm.llms.init_llm import LLMs +from hugegraph_llm.operators.kg_construction_task import KgBuilder + +TEXT = "" +builder = KgBuilder(LLMs().get_llm()) +( + builder + .import_schema(from_hugegraph="talent_graph").print_result() + .extract_triples(TEXT).print_result() + .disambiguate_word_sense().print_result() + .commit_to_hugegraph() + .run() +) +``` + +2. **Import Schema**: The `import_schema` method is used to import a schema from a source. The source can be a HugeGraph instance, a user-defined schema or an extraction result. The method `print_result` can be chained to print the result. +```python +# Import schema from a HugeGraph instance +import_schema(from_hugegraph="talent_graph").print_result() +# Import schema from an extraction result +import_schema(from_extraction="xxx").print_result() +# Import schema from user-defined schema +import_schema(from_user_defined="xxx").print_result() +``` + +3. **Extract Triples**: The `extract_triples` method is used to extract triples from a text. The text should be passed as a string argument to the method. + +```python +TEXT = "Meet Sarah, a 30-year-old attorney, and her roommate, James, whom she's shared a home with since 2010." +extract_triples(TEXT).print_result() +``` + +4. **Disambiguate Word Sense**: The `disambiguate_word_sense` method is used to disambiguate the sense of words in the extracted triples. + +```python +disambiguate_word_sense().print_result() +``` + +5. 
**Commit to HugeGraph**: The `commit_to_hugegraph` method is used to commit the constructed knowledge graph to a HugeGraph instance. + +```python +commit_to_hugegraph().print_result() +``` + +6. **Run**: The `run` method is used to execute the chained operations. + +```python +run() +``` + +The methods of the `KgBuilder` class can be chained together to perform a sequence of operations. diff --git a/hugegraph-llm/examples/build_kg_test.py b/hugegraph-llm/examples/build_kg_test.py index d9b56f633..1274ed101 100644 --- a/hugegraph-llm/examples/build_kg_test.py +++ b/hugegraph-llm/examples/build_kg_test.py @@ -21,7 +21,8 @@ if __name__ == "__main__": - default_llm = LLMs().get_llm() + builder = KgBuilder(LLMs().get_llm()) + TEXT = ( "Meet Sarah, a 30-year-old attorney, and her roommate, James, whom she's shared a home with" " since 2010. James, in his professional life, works as a journalist. Additionally, Sarah" @@ -31,47 +32,29 @@ " their distinctive digital presence through their respective webpages, showcasing their" " varied interests and experiences." 
) - builder = KgBuilder(default_llm) - # spo triple extract - builder.extract_spo_triple(TEXT).print_result().commit_to_hugegraph(spo=True).run() - # build kg with only text - builder.extract_nodes_relationships(TEXT).disambiguate_word_sense().commit_to_hugegraph().run() - # build kg with text and schemas - nodes_schemas = [ - { - "label": "Person", - "primary_key": "name", - "properties": { - "age": "int", - "name": "text", - "occupation": "text", - }, - }, - { - "label": "Webpage", - "primary_key": "name", - "properties": {"name": "text", "url": "text"}, - }, - ] - relationships_schemas = [ - { - "start": "Person", - "end": "Person", - "type": "roommate", - "properties": {"start": "int"}, - }, - { - "start": "Person", - "end": "Webpage", - "type": "owns", - "properties": {}, - }, - ] + schema = { + "vertices": [ + {"vertex_label": "person", "properties": ["name", "age", "occupation"]}, + {"vertex_label": "webpage", "properties": ["name", "url"]}, + ], + "edges": [ + { + "edge_label": "roommate", + "source_vertex_label": "person", + "target_vertex_label": "person", + "properties": {}, + } + ], + } ( - builder.parse_text_to_data_with_schemas(TEXT, nodes_schemas, relationships_schemas) - .disambiguate_data_with_schemas() - .commit_data_to_kg() + builder + .import_schema(from_hugegraph="talent_graph").print_result() + # .import_schema(from_extraction="fefe").print_result().run() + # .import_schema(from_input=schema).print_result() + .extract_triples(TEXT).print_result() + .disambiguate_word_sense() + .commit_to_hugegraph() .run() ) diff --git a/hugegraph-llm/src/config/config.ini b/hugegraph-llm/src/config/config.ini index a45d1f02c..6f9219ff3 100644 --- a/hugegraph-llm/src/config/config.ini +++ b/hugegraph-llm/src/config/config.ini @@ -27,6 +27,6 @@ graph = hugegraph type = openai api_key = xxx secret_key = xxx -ernie_url = https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token= +ernie_url = 
https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token= model_name = gpt-3.5-turbo-16k max_token = 4000 diff --git a/hugegraph-llm/src/hugegraph_llm/llms/ernie_bot.py b/hugegraph-llm/src/hugegraph_llm/llms/ernie_bot.py index 1ab5613c2..085f7a06e 100644 --- a/hugegraph-llm/src/hugegraph_llm/llms/ernie_bot.py +++ b/hugegraph-llm/src/hugegraph_llm/llms/ernie_bot.py @@ -54,9 +54,9 @@ def generate( messages = [{"role": "user", "content": prompt}] url = self.base_url + self.get_access_token() # parameter check failed, temperature range is (0, 1.0] - payload = json.dumps({"messages": messages, "temperature": 0.00000000001}) + payload = json.dumps({"messages": messages, "temperature": 0.1}) headers = {"Content-Type": "application/json"} - response = requests.request("POST", url, headers=headers, data=payload, timeout=10) + response = requests.request("POST", url, headers=headers, data=payload, timeout=30) if response.status_code != 200: raise Exception( f"Request failed with code {response.status_code}, message: {response.text}" diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py b/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py new file mode 100644 index 000000000..bd491bdfe --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +from typing import Any + + +class CheckSchema: + def __init__(self, data): + self.result = None + self.data = data + + def run(self, schema=None) -> Any: + data = self.data or schema + if not isinstance(data, dict): + raise ValueError("Input data is not a dictionary.") + if "vertices" not in data or "edges" not in data: + raise ValueError("Input data does not contain 'vertices' or 'edges'.") + if not isinstance(data["vertices"], list) or not isinstance(data["edges"], list): + raise ValueError("'vertices' or 'edges' in input data is not a list.") + for vertex in data["vertices"]: + if not isinstance(vertex, dict): + raise ValueError("Vertex in input data is not a dictionary.") + if "vertex_label" not in vertex: + raise ValueError("Vertex in input data does not contain 'vertex_label'.") + if not isinstance(vertex["vertex_label"], str): + raise ValueError("'vertex_label' in vertex is not of correct type.") + for edge in data["edges"]: + if not isinstance(edge, dict): + raise ValueError("Edge in input data is not a dictionary.") + if ( + "edge_label" not in edge + or "source_vertex_label" not in edge + or "target_vertex_label" not in edge + ): + raise ValueError( + "Edge in input data does not contain " + "'edge_label', 'source_vertex_label', 'target_vertex_label'." + ) + if ( + not isinstance(edge["edge_label"], str) + or not isinstance(edge["source_vertex_label"], str) + or not isinstance(edge["target_vertex_label"], str) + ): + raise ValueError( + "'edge_label', 'source_vertex_label', 'target_vertex_label' " + "in edge is not of correct type." 
+ ) + return data diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py index 350fb950e..62443c2fd 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py @@ -19,179 +19,10 @@ from hugegraph_llm.utils.config import Config from hugegraph_llm.utils.constants import Constants from pyhugegraph.client import PyHugeClient +from pyhugegraph.utils.exceptions import NotFoundError -def generate_new_relationships(nodes_schemas_data, relationships_data): - label_id = {} - i = 1 - old_label = [] - for item in nodes_schemas_data: - label = item["label"] - if label in old_label: - continue - label_id[label] = i - i += 1 - old_label.append(label) - new_relationships_data = [] - for relationship in relationships_data: - start = relationship["start"] - end = relationship["end"] - relationships_type = relationship["type"] - properties = relationship["properties"] - new_start = [] - new_end = [] - for key, value in label_id.items(): - for key1, value1 in start.items(): - if key1 == key: - new_start = f"{value}" + ":" + f"{value1}" - for key1, value1 in end.items(): - if key1 == key: - new_end = f"{value}" + ":" + f"{value1}" - relationships_data = {} - relationships_data["start"] = new_start - relationships_data["end"] = new_end - relationships_data["type"] = relationships_type - relationships_data["properties"] = properties - new_relationships_data.append(relationships_data) - return new_relationships_data - - -def generate_schema_properties(data): - schema_properties_statements = [] - if len(data) == 3: - for item in data: - properties = item["properties"] - for key, value in properties.items(): - if value == "int": - schema_properties_statements.append( - f"schema.propertyKey('{key}').asInt().ifNotExist().create()" - ) - elif value == "text": - 
schema_properties_statements.append( - f"schema.propertyKey('{key}').asText().ifNotExist().create()" - ) - else: - for item in data: - properties = item["properties"] - for key, value in properties.items(): - if value == "int": - schema_properties_statements.append( - f"schema.propertyKey('{key}').asInt().ifNotExist().create()" - ) - elif value == "text": - schema_properties_statements.append( - f"schema.propertyKey('{key}').asText().ifNotExist().create()" - ) - return schema_properties_statements - - -def generate_schema_nodes(data): - schema_nodes_statements = [] - for item in data: - label = item["label"] - primary_key = item["primary_key"] - properties = item["properties"] - schema_statement = f"schema.vertexLabel('{label}').properties(" - schema_statement += ", ".join(f"'{prop}'" for prop in properties.keys()) - schema_statement += ").nullableKeys(" - schema_statement += ", ".join( - f"'{prop}'" for prop in properties.keys() if prop != primary_key - ) - schema_statement += ( - f").usePrimaryKeyId().primaryKeys('{primary_key}').ifNotExist().create()" - ) - schema_nodes_statements.append(schema_statement) - return schema_nodes_statements - - -def generate_schema_relationships(data): - schema_relationships_statements = [] - for item in data: - start = item["start"] - end = item["end"] - schema_relationships_type = item["type"] - properties = item["properties"] - schema_statement = ( - f"schema.edgeLabel('{schema_relationships_type}')" - f".sourceLabel('{start}').targetLabel('{end}').properties(" - ) - schema_statement += ", ".join(f"'{prop}'" for prop in properties.keys()) - schema_statement += ").nullableKeys(" - schema_statement += ", ".join(f"'{prop}'" for prop in properties.keys()) - schema_statement += ").ifNotExist().create()" - schema_relationships_statements.append(schema_statement) - return schema_relationships_statements - - -def generate_nodes(data): - nodes = [] - for item in data: - label = item["label"] - properties = item["properties"] - 
nodes.append(f"g.addVertex('{label}', {properties})") - return nodes - - -def generate_relationships(data): - relationships = [] - for item in data: - start = item["start"] - end = item["end"] - types = item["type"] - properties = item["properties"] - relationships.append(f"g.addEdge('{types}', '{start}', '{end}', {properties})") - return relationships - - -class CommitDataToKg: - def __init__(self): - config = Config(section=Constants.HUGEGRAPH_CONFIG) - self.client = PyHugeClient( - config.get_graph_ip(), - config.get_graph_port(), - config.get_graph_user(), - config.get_graph_pwd(), - config.get_graph_name(), - ) - self.schema = self.client.schema() - - def run(self, data: dict): - nodes = data["nodes"] - relationships = data["relationships"] - nodes_schemas = data["nodes_schemas"] - relationships_schemas = data["relationships_schemas"] - # properties schema - schema_nodes_properties = generate_schema_properties(nodes_schemas) - schema_relationships_properties = generate_schema_properties(relationships_schemas) - for schema_nodes_property in schema_nodes_properties: - exec(schema_nodes_property) - - for schema_relationships_property in schema_relationships_properties: - exec(schema_relationships_property) - - # nodes schema - schema_nodes = generate_schema_nodes(nodes_schemas) - for schema_node in schema_nodes: - exec(schema_node) - - # relationships schema - schema_relationships = generate_schema_relationships(relationships_schemas) - for schema_relationship in schema_relationships: - exec(schema_relationship) - - # nodes - nodes = generate_nodes(nodes) - for node in nodes: - exec(node) - - # relationships - new_relationships = generate_new_relationships(nodes_schemas, relationships) - relationships_schemas = generate_relationships(new_relationships) - for relationship in relationships_schemas: - exec(relationship) - - -class CommitSpoToKg: +class CommitToKg: def __init__(self): config = Config(section=Constants.HUGEGRAPH_CONFIG) self.client = PyHugeClient( @@ 
-204,6 +35,59 @@ def __init__(self): self.schema = self.client.schema() def run(self, data: dict): + if "schema" not in data: + self.schema_free_mode(data["triples"]) + else: + schema = data["schema"] + vertices = data["vertices"] + edges = data["edges"] + self.init_schema(schema) + self.init_graph(vertices, edges) + + def init_graph(self, vertices, edges): + vids = {} + for vertex in vertices: + label = vertex["label"] + properties = vertex["properties"] + try: + vid = self.client.graph().addVertex(label, properties).id + vids[vertex["name"]] = vid + except NotFoundError as e: + print(e) + for edge in edges: + start = vids[edge["start"]] + end = vids[edge["end"]] + types = edge["type"] + properties = edge["properties"] + try: + self.client.graph().addEdge(types, start, end, properties) + except NotFoundError as e: + print(e) + + def init_schema(self, schema): + vertices = schema["vertices"] + edges = schema["edges"] + + for vertex in vertices: + vertex_label = vertex["vertex_label"] + properties = vertex["properties"] + for prop in properties: + self.schema.propertyKey(prop).asText().ifNotExist().create() + self.schema.vertexLabel(vertex_label).properties(*properties).nullableKeys( + *properties[1:] + ).usePrimaryKeyId().primaryKeys(properties[0]).ifNotExist().create() + for edge in edges: + edge_label = edge["edge_label"] + source_vertex_label = edge["source_vertex_label"] + target_vertex_label = edge["target_vertex_label"] + properties = edge["properties"] + for prop in properties: + self.schema.propertyKey(prop).asText().ifNotExist().create() + self.schema.edgeLabel(edge_label).sourceLabel(source_vertex_label).targetLabel( + target_vertex_label + ).properties(*properties).nullableKeys(*properties).ifNotExist().create() + + def schema_free_mode(self, data): self.schema.propertyKey("name").asText().ifNotExist().create() self.schema.vertexLabel("vertex").useCustomizeStringId().properties( "name" @@ -220,9 +104,9 @@ def run(self, data: dict): 
).secondary().ifNotExist().create() for item in data: - s = item[0] - p = item[1] - o = item[2] + s = item[0].strip() + p = item[1].strip() + o = item[2].strip() s_id = self.client.graph().addVertex("vertex", {"name": s}, id=s).id t_id = self.client.graph().addVertex("vertex", {"name": o}, id=o).id self.client.graph().addEdge("edge", s_id, t_id, {"name": p}) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py new file mode 100644 index 000000000..c3cbd45a9 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ + +from hugegraph_llm.utils.config import Config +from hugegraph_llm.utils.constants import Constants +from pyhugegraph.client import PyHugeClient + + +class SchemaManager: + def __init__(self, graph_name: str): + config = Config(section=Constants.HUGEGRAPH_CONFIG) + self.graph_name = graph_name + self.client = PyHugeClient( + config.get_graph_ip(), + config.get_graph_port(), + graph_name, + config.get_graph_user(), + config.get_graph_pwd(), + ) + self.schema = self.client.schema() + + def run(self, data: dict): + schema = self.schema.getSchema() + vertices = [] + for vl in schema["vertexlabels"]: + vertex = {"vertex_label": vl["name"], "properties": vl["properties"]} + vertices.append(vertex) + edges = [] + for el in schema["edgelabels"]: + edge = { + "edge_label": el["name"], + "source_vertex_label": el["source_label"], + "target_vertex_label": el["target_label"], + "properties": el["properties"], + } + edges.append(edge) + if not vertices and not edges: + raise Exception(f"Can not get {self.graph_name}'s schema from HugeGraph!") + return {"vertices": vertices, "edges": edges} diff --git a/hugegraph-llm/src/hugegraph_llm/operators/kg_construction_task.py b/hugegraph-llm/src/hugegraph_llm/operators/kg_construction_task.py index 5b8d019a4..082058dae 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/kg_construction_task.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/kg_construction_task.py @@ -17,11 +17,10 @@ from hugegraph_llm.llms.base import BaseLLM +from hugegraph_llm.operators.common_op.check_schema import CheckSchema from hugegraph_llm.operators.common_op.print_result import PrintResult -from hugegraph_llm.operators.hugegraph_op.commit_to_hugegraph import ( - CommitDataToKg, - CommitSpoToKg, -) +from hugegraph_llm.operators.hugegraph_op.commit_to_hugegraph import CommitToKg +from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager from hugegraph_llm.operators.llm_op.disambiguate_data import DisambiguateData from 
hugegraph_llm.operators.llm_op.info_extract import InfoExtract @@ -32,28 +31,27 @@ def __init__(self, llm: BaseLLM): self.llm = llm self.result = None - def extract_nodes_relationships( - self, text: str, nodes_schemas=None, relationships_schemas=None - ): - if nodes_schemas and relationships_schemas: - self.operators.append(InfoExtract(self.llm, text, nodes_schemas, relationships_schemas)) + def import_schema(self, from_hugegraph=None, from_extraction=None, from_user_defined=None): + if from_hugegraph: + self.operators.append(SchemaManager(from_hugegraph)) + elif from_user_defined: + self.operators.append(CheckSchema(from_user_defined)) + elif from_extraction: + raise Exception("Not implemented yet") else: - self.operators.append(InfoExtract(self.llm, text)) + raise Exception("No input data") return self - def extract_spo_triple(self, text: str): - self.operators.append(InfoExtract(self.llm, text, spo=True)) + def extract_triples(self, text: str): + self.operators.append(InfoExtract(self.llm, text)) return self - def disambiguate_word_sense(self, with_schemas=False): - self.operators.append(DisambiguateData(self.llm, with_schemas)) + def disambiguate_word_sense(self): + self.operators.append(DisambiguateData(self.llm)) return self - def commit_to_hugegraph(self, spo=False): - if spo: - self.operators.append(CommitSpoToKg()) - else: - self.operators.append(CommitDataToKg()) + def commit_to_hugegraph(self): + self.operators.append(CommitToKg()) return self def print_result(self): diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py index f7c16b9e4..a08793428 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py @@ -16,228 +16,38 @@ # under the License. 
-import json -import re -from itertools import groupby from typing import Dict, List, Any -from hugegraph_llm.operators.llm_op.unstructured_data_utils import ( - nodes_text_to_list_of_dict, - relationships_text_to_list_of_dict, - relationships_schemas_text_to_list_of_dict, - nodes_schemas_text_to_list_of_dict, -) from hugegraph_llm.llms.base import BaseLLM +from hugegraph_llm.operators.llm_op.info_extract import extract_by_regex -def disambiguate_nodes() -> str: - return """ -Your task is to identify if there are duplicated nodes and if so merge them into one nod. Only merge the nodes that refer to the same entity. -You will be given different datasets of nodes and some of these nodes may be duplicated or refer to the same entity. -The datasets contains nodes in the form [ENTITY_ID, TYPE, PROPERTIES]. When you have completed your task please give me the -resulting nodes in the same format. Only return the nodes and relationships no other text. If there is no duplicated nodes return the original nodes. - -Here is an example -The input you will be given: -["Alice", "Person", {"age" : 25, "occupation": "lawyer", "name":"Alice"}], ["Bob", "Person", {"occupation": "journalist", "name": "Bob"}], ["alice.com", "Webpage", {"url": "www.alice.com"}], ["bob.com", "Webpage", {"url": "www.bob.com"}], ["Bob", "Person", {"occupation": "journalist", "name": "Bob"}] -The output you need to provide: -["Alice", "Person", {"age" : 25, "occupation": "lawyer", "name":"Alice"}], ["Bob", "Person", {"occupation": "journalist", "name": "Bob"}], ["alice.com", "Webpage", {"url": "www.alice.com"}], ["bob.com", "Webpage", {"url": "www.bob.com"}] -""" - - -def disambiguate_relationships() -> str: - return """ -Your task is to identify if a set of relationships make sense. -If they do not make sense please remove them from the dataset. -Some relationships may be duplicated or refer to the same entity. -Please merge relationships that refer to the same entity. 
-The datasets contains relationships in the form [{"ENTITY_TYPE_1": "ENTITY_ID_1"}, RELATIONSHIP, {"ENTITY_TYPE_2": "ENTITY_ID_2"}, PROPERTIES]. -You will also be given a set of ENTITY_IDs that are valid. -Some relationships may use ENTITY_IDs that are not in the valid set but refer to a entity in the valid set. -If a relationships refer to a ENTITY_ID in the valid set please change the ID so it matches the valid ID. -When you have completed your task please give me the valid relationships in the same format. Only return the relationships no other text. - -Here is an example -The input you will be given: -[{"Person": "Alice"}, "roommate", {"Person": "bob"}, {"start": 2021}], [{"Person": "Alice"}, "owns", {"Webpage": "alice.com"}, {}], [{"Person": "Bob"}, "owns", {"Webpage": "bob.com"}, {}], [{"Person": "Alice"}, "owns", {"Webpage": "alice.com"}, {}] -The output you need to provide: -[{"Person": "Alice"}, "roommate", {"Person": "bob"}, {"start": 2021}], [{"Person": "Alice"}, "owns", {"Webpage": "alice.com"}, {}], [{"Person": "Bob"}, "owns", {"Webpage": "bob.com"}, {}] -""" - - -def disambiguate_nodes_schemas() -> str: - return """ -Your task is to identify if there are duplicated nodes schemas and if so merge them into one nod. Only merge the nodes schemas that refer to the same entty_types. -You will be given different node schemas, some of which may duplicate or reference the same entty_types. Note: For node schemas with the same entty_types, you need to merge them while merging all properties of the entty_types. -The datasets contains nodes schemas in the form [ENTITY_TYPE, PRIMARY KEY, PROPERTIES]. When you have completed your task please give me the -resulting nodes schemas in the same format. Only return the nodes schemas no other text. If there is no duplicated nodes return the original nodes schemas. 
- -Here is an example -The input you will be given: -["Person", "name", {"age": "int", "name": "text", "occupation": "text"}], ["Webpage", "url", {url: "text"}], ["Webpage", "url", {url: "text"}] -The output you need to provide: -["Person", "name", {"age": "int", "name": "text", "occupation": "text"}], ["Webpage", "url", {url: "text"}] -""" - - -def disambiguate_relationships_schemas() -> str: - return """ -Your task is to identify if a set of relationships schemas make sense. -If they do not make sense please remove them from the dataset. -Some relationships may be duplicated or refer to the same label. -Please merge relationships that refer to the same label. -The datasets contains relationships in the form [LABEL_ID_1, RELATIONSHIP, LABEL_ID_2, PROPERTIES]. -You will also be given a set of LABELS_IDs that are valid. -Some relationships may use LABELS_IDs that are not in the valid set but refer to a LABEL in the valid set. -If a relationships refer to a LABELS_IDs in the valid set please change the ID so it matches the valid ID. -When you have completed your task please give me the valid relationships in the same format. Only return the relationships no other text. - -Here is an example -["Person", "roommate", "Person", {"start": 2021}], ["Person", "owns", "Webpage", {}], ["Person", "roommate", "Person", {"start": 2021}] -The output you need to provide: -["Person", "roommate", "Person", {"start": 2021}], ["Person", "owns", "Webpage", {}] -""" - - -def generate_prompt(data) -> str: - return f""" Here is the data: -{data} -""" - - -INTERNAL_REGEX = r"\[(.*?)\]" +def generate_disambiguate_prompt(triples): + return f""" + Your task is to disambiguate the following triples: + {triples} + If the second element of the triples expresses the same meaning but in different ways, + unify them and keep the most concise expression. 
+ + For example, if the input is: + [("Alice", "friend", "Bob"), ("Simon", "is friends with", "Bob")] + + The output should be: + [("Alice", "friend", "Bob"), ("Simon", "friend", "Bob")] + """ class DisambiguateData: - def __init__(self, llm: BaseLLM, is_user_schema: bool) -> None: + def __init__(self, llm: BaseLLM) -> None: self.llm = llm - self.is_user_schema = is_user_schema def run(self, data: Dict) -> Dict[str, List[Any]]: - nodes = sorted(data["nodes"], key=lambda x: x.get("label", "")) - relationships = data["relationships"] - nodes_schemas = data["nodes_schemas"] - relationships_schemas = data["relationships_schemas"] - new_nodes = [] - new_relationships = [] - new_nodes_schemas = [] - new_relationships_schemas = [] - - node_groups = groupby(nodes, lambda x: x["label"]) - for group in node_groups: - dis_string = "" - nodes_in_group = list(group[1]) - if len(nodes_in_group) == 1: - new_nodes.extend(nodes_in_group) - continue - - for node in nodes_in_group: - dis_string += ( - '["' - + node["name"] - + '", "' - + node["label"] - + '", ' - + json.dumps(node["properties"]) - + "]\n" - ) - - messages = [ - {"role": "system", "content": disambiguate_nodes()}, - {"role": "user", "content": generate_prompt(dis_string)}, - ] - raw_nodes = self.llm.generate(messages) - n = re.findall(INTERNAL_REGEX, raw_nodes) - new_nodes.extend(nodes_text_to_list_of_dict(n)) - - relationship_data = "" - for relation in relationships: - relationship_data += ( - '["' - + json.dumps(relation["start"]) - + '", "' - + relation["type"] - + '", "' - + json.dumps(relation["end"]) - + '", ' - + json.dumps(relation["properties"]) - + "]\n" - ) - - node_labels = [node["name"] for node in new_nodes] - relationship_data += "Valid Nodes:\n" + "\n".join(node_labels) - - messages = [ - { - "role": "system", - "content": disambiguate_relationships(), - }, - {"role": "user", "content": generate_prompt(relationship_data)}, - ] - raw_relationships = self.llm.generate(messages) - rels = 
re.findall(INTERNAL_REGEX, raw_relationships) - new_relationships.extend(relationships_text_to_list_of_dict(rels)) - - if not self.is_user_schema: - nodes_schemas_data = "" - for node_schema in nodes_schemas: - nodes_schemas_data += ( - '["' - + node_schema["label"] - + '", ' - + node_schema["primary_key"] - + '", ' - + json.dumps(node_schema["properties"]) - + "]\n" - ) - - messages = [ - {"role": "system", "content": disambiguate_nodes_schemas()}, - {"role": "user", "content": generate_prompt(nodes_schemas_data)}, - ] - raw_nodes_schemas = self.llm.generate(messages) - n = re.findall(INTERNAL_REGEX, raw_nodes_schemas) - new_nodes_schemas.extend(nodes_schemas_text_to_list_of_dict(n)) - - relationships_schemas_data = "" - for relationships_schema in relationships_schemas: - relationships_schemas_data += ( - '["' - + relationships_schema["start"] - + '", "' - + relationships_schema["type"] - + '", "' - + relationships_schema["end"] - + '", ' - + json.dumps(relationships_schema["properties"]) - + "]\n" - ) - - node_schemas_labels = [nodes_schemas["label"] for nodes_schemas in new_nodes_schemas] - relationships_schemas_data += "Valid Labels:\n" + "\n".join(node_schemas_labels) - - messages = [ - { - "role": "system", - "content": disambiguate_relationships_schemas(), - }, - { - "role": "user", - "content": generate_prompt(relationships_schemas_data), - }, - ] - raw_relationships_schemas = self.llm.generate(messages) - schemas_rels = re.findall(INTERNAL_REGEX, raw_relationships_schemas) - new_relationships_schemas.extend( - relationships_schemas_text_to_list_of_dict(schemas_rels) - ) - else: - new_nodes_schemas = nodes_schemas - new_relationships_schemas = relationships_schemas - - return { - "nodes": new_nodes, - "relationships": new_relationships, - "nodes_schemas": new_nodes_schemas, - "relationships_schemas": new_relationships_schemas, - } + # only disambiguate triples + if "triples" in data: + triples = data["triples"] + prompt = 
generate_disambiguate_prompt(triples) + llm_output = self.llm.generate(prompt=prompt) + data = {"triples": []} + extract_by_regex(llm_output, data) + print(f"LLM input:{prompt} \n output: {llm_output} \n data: {data}") + return data diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py index 7386fd0bb..18b4083ee 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py @@ -19,141 +19,41 @@ from typing import List, Any, Dict from hugegraph_llm.llms.base import BaseLLM -from hugegraph_llm.operators.llm_op.unstructured_data_utils import ( - nodes_text_to_list_of_dict, - nodes_schemas_text_to_list_of_dict, - relationships_schemas_text_to_list_of_dict, - relationships_text_to_list_of_dict, -) -def generate_system_message() -> str: - return """You are a data scientist working for a company that is building a graph database. - Your task is to extract information from data and convert it into a graph database. Provide a - set of Nodes in the form [ENTITY_ID, TYPE, PROPERTIES] and a set of relationships in the form - [ENTITY_ID_1, RELATIONSHIP, ENTITY_ID_2, PROPERTIES] and a set of NodesSchemas in the form [ - ENTITY_TYPE, PRIMARY_KEY, PROPERTIES] and a set of RelationshipsSchemas in the form [ - ENTITY_TYPE_1, RELATIONSHIP, ENTITY_TYPE_2, PROPERTIES] It is important that the ENTITY_ID_1 - and ENTITY_ID_2 exists as nodes with a matching ENTITY_ID. If you can't pair a relationship - with a pair of nodes don't add it. When you find a node or relationship you want to add try - to create a generic TYPE for it that describes the entity you can also think of it as a label. - - Here is an example The input you will be given: Data: Alice lawyer and is 25 years old and Bob - is her roommate since 2001. Bob works as a journalist. 
Alice owns a the webpage www.alice.com - and Bob owns the webpage www.bob.com. The output you need to provide: Nodes: ["Alice", "Person", - {"age": 25, "occupation": "lawyer", "name": "Alice"}], ["Bob", "Person", {"occupation": - "journalist", "name": "Bob"}], ["alice.com", "Webpage", {"name": "alice.com", - "url": "www.alice.com"}], ["bob.com", "Webpage", {"name": "bob.com", "url": "www.bob.com"}] - Relationships: [{"Person": "Alice"}, "roommate", {"Person": "Bob"}, {"start": 2021}], - [{"Person": "Alice"}, "owns", {"Webpage": "alice.com"}, {}], [{"Person": "Bob"}, "owns", - {"Webpage": "bob.com"}, {}] NodesSchemas: ["Person", "name", {"age": "int", - "name": "text", "occupation": - "text"}], ["Webpage", "name", {"name": "text", "url": "text"}] RelationshipsSchemas :["Person", - "roommate", "Person", {"start": "int"}], ["Person", "owns", "Webpage", {}]""" - - -def generate_ernie_prompt_spo(data) -> str: - return f"""Extract subject-verb-object (SPO) triples from text strictly according to the - following format, each structure has only three elements: ("vertex_1", "edge", "vertex_2"). - for example: - Alice lawyer and is 25 years old and Bob is her roommate since 2001. Bob works as a journalist. - Alice owns a the webpage www.alice.com and Bob owns the webpage www.bob.com - output:[("Alice", "Age", "25"),("Alice", "Profession", "lawyer"),("Bob", "Job", "journalist"), - ("Alice", "Roommate of", "Bob"),("Alice", "Owns", "http://www.alice.com"), - ("Bob", "Owns", "http://www.bob.com")] +def generate_extract_triple_prompt(text, schema=None) -> str: + if schema: + return f""" + Given the graph schema: {schema} - The extracted text is: {data}""" - - -def generate_ernie_message(data) -> str: - return ( - """You are a data scientist working for a company that is building a graph database. - Your task is to extract information from data and convert it into a graph database. 
Provide - a set of Nodes in the form [ENTITY_ID, TYPE, PROPERTIES] and a set of relationships in the - form [ENTITY_ID_1, RELATIONSHIP, ENTITY_ID_2, PROPERTIES] and a set of NodesSchemas in the - form [ENTITY_TYPE, PRIMARY_KEY, PROPERTIES] and a set of RelationshipsSchemas in the form [ - ENTITY_TYPE_1, RELATIONSHIP, ENTITY_TYPE_2, PROPERTIES] It is important that the ENTITY_ID_1 - and ENTITY_ID_2 exists as nodes with a matching ENTITY_ID. If you can't pair a relationship - with a pair of nodes don't add it. When you find a node or relationship you want to add try - to create a generic TYPE for it that describes the entity you can also think of it as a - label. - - Here is an example The input you will be given: Data: Alice lawyer and is 25 years old and - Bob is her roommate since 2001. Bob works as a journalist. Alice owns a the webpage - www.alice.com and Bob owns the webpage www.bob.com. The output you need to provide: - Nodes: ["Alice", "Person", {"age": 25, "occupation": "lawyer", "name": "Alice"}], - ["Bob", "Person", {"occupation": - "journalist", "name": "Bob"}], ["alice.com", "Webpage", {"name": "alice.com", - "url": "www.alice.com"}], ["bob.com", "Webpage", {"name": "bob.com", "url": "www.bob.com"}] - Relationships: [{"Person": "Alice"}, "roommate", {"Person": "Bob"}, {"start": 2021}], - [{"Person": "Alice"}, "owns", {"Webpage": "alice.com"}, {}], [{"Person": "Bob"}, "owns", - {"Webpage": "bob.com"}, {}] NodesSchemas: ["Person", "name", {"age": "int", "name": - "text", "occupation": "text"}], ["Webpage", "name", {"name": "text", "url": "text"}] - RelationshipsSchemas :["Person", "roommate", "Person", {"start": "int"}], - ["Person", "owns", "Webpage", {}] + Based on the above schema, extract triples from the following text. + The output format must be: (X,Y,Z) - LABEL + In this format, Y must be a value from "properties" or "edge_label", + and LABEL must be X's vertex_label or Y's edge_label. 
- Now extract information from the following data: + The extracted text is: {text} """ - + data - ) - - -def generate_system_message_with_schemas() -> str: - return """You are a data scientist working for a company that is building a graph database. - Your task is to extract information from data and convert it into a graph database. Provide a - set of Nodes in the form [ENTITY_ID, TYPE, PROPERTIES] and a set of relationships in the form - [ENTITY_ID_1, RELATIONSHIP, ENTITY_ID_2, PROPERTIES] and a set of NodesSchemas in the form [ - ENTITY_TYPE, PRIMARY_KEY, PROPERTIES] and a set of RelationshipsSchemas in the form [ - ENTITY_TYPE_1, RELATIONSHIP, ENTITY_TYPE_2, PROPERTIES] It is important that the ENTITY_ID_1 - and ENTITY_ID_2 exists as nodes with a matching ENTITY_ID. If you can't pair a relationship - with a pair of nodes don't add it. When you find a node or relationship you want to add try - to create a generic TYPE for it that describes the entity you can also think of it as a label. - - Here is an example The input you will be given: Data: Alice lawyer and is 25 years old and Bob - is her roommate since 2001. Bob works as a journalist. Alice owns a the webpage www.alice.com - and Bob owns the webpage www.bob.com. 
NodesSchemas: ["Person", "name", {"age": "int", - "name": "text", "occupation": "text"}], ["Webpage", "name", {"name": "text", "url": "text"}] - RelationshipsSchemas :["Person", "roommate", "Person", {"start": "int"}], ["Person", "owns", - "Webpage", {}] The output you need to provide: Nodes: ["Alice", "Person", {"age": 25, - "occupation": "lawyer", "name": "Alice"}], ["Bob", "Person", {"occupation": "journalist", - "name": "Bob"}], ["alice.com", "Webpage", {"name": "alice.com", "url": "www.alice.com"}], - ["bob.com", "Webpage", {"name": "bob.com", "url": "www.bob.com"}] Relationships: [{"Person": - "Alice"}, "roommate", {"Person": "Bob"}, {"start": 2021}], [{"Person": "Alice"}, "owns", - {"Webpage": "alice.com"}, {}], [{"Person": "Bob"}, "owns", {"Webpage": "bob.com"}, - {}] NodesSchemas: ["Person", "name", {"age": "int", "name": "text", "occupation": "text"}], - ["Webpage", "name", {"name": "text", "url": "text"}] RelationshipsSchemas :["Person", - "roommate", "Person", {"start": "int"}], ["Person", "owns", "Webpage", {}] - """ - - -def generate_prompt(data) -> str: - return f""" - Data: {data} - """ - - -def generate_prompt_with_schemas(data, nodes_schemas, relationships_schemas) -> str: - return f""" - Data: {data} - NodesSchemas: {nodes_schemas} - RelationshipsSchemas: {relationships_schemas} - """ - + return f"""Extract subject-verb-object (SPO) triples from text strictly according to the + following format, each structure has only three elements: ("vertex_1", "edge", "vertex_2"). + for example: + Alice lawyer and is 25 years old and Bob is her roommate since 2001. Bob works as a journalist. 
+ Alice owns a the webpage www.alice.com and Bob owns the webpage www.bob.com + output:[("Alice", "Age", "25"),("Alice", "Profession", "lawyer"),("Bob", "Job", "journalist"), + ("Alice", "Roommate of", "Bob"),("Alice", "Owns", "http://www.alice.com"), + ("Bob", "Owns", "http://www.bob.com")] -def split_string(string, max_length) -> List[str]: - return [string[i : i + max_length] for i in range(0, len(string), max_length)] + The extracted text is: {text}""" -def split_string_to_fit_token_space( - llm: BaseLLM, string: str, token_use_per_string: int -) -> List[str]: - allowed_tokens = llm.max_allowed_token_length() - token_use_per_string - chunked_data = split_string(string, 500) +def fit_token_space_by_split_text(llm: BaseLLM, text: str, prompt_token: int) -> List[str]: + max_length = 500 + allowed_tokens = llm.max_allowed_token_length() - prompt_token + chunked_data = [text[i : i + max_length] for i in range(0, len(text), max_length)] combined_chunks = [] current_chunk = "" for chunk in chunked_data: if ( - llm.num_tokens_from_string(current_chunk) + llm.num_tokens_from_string(chunk) + int(llm.num_tokens_from_string(current_chunk)) + int(llm.num_tokens_from_string(chunk)) < allowed_tokens ): current_chunk += chunk @@ -165,51 +65,34 @@ def split_string_to_fit_token_space( return combined_chunks -def get_spo_from_result(result): - res = [] - for row in result: - row = row.replace("\\n", "").replace("\\", "") - pattern = r'\("(.*?)", "(.*?)", "(.*?)"\)' - res += re.findall(pattern, row) - return res +def extract_by_regex(text, triples): + text = text.replace("\\n", " ").replace("\\", " ").replace("\n", " ") + pattern = r"\((.*?), (.*?), (.*?)\)" + triples["triples"] += re.findall(pattern, text) + +def extract_by_regex_with_schema(schema, text, graph): + text = text.replace("\\n", " ").replace("\\", " ").replace("\n", " ") + pattern = r"\((.*?), (.*?), (.*?)\) - ([^ ]*)" + matches = re.findall(pattern, text) -def get_nodes_and_relationships_from_result(result): - regex 
= ( - r"Nodes:\s+(.*?)\s?\s?Relationships:\s+(.*?)\s?\s?NodesSchemas:\s+(.*?)\s?\s?\s?" - r"RelationshipsSchemas:\s?\s?(.*)" - ) - internal_regex = r"\[(.*?)\]" - nodes = [] - relationships = [] - nodes_schemas = [] - relationships_schemas = [] - for row in result: - row = row.replace("\n", "") - parsing = re.search(regex, row, flags=re.S) - if parsing is None: + vertices_dict = {} + for match in matches: + s, p, o, label = [item.strip() for item in match] + if None in [label, s, p, o]: continue - raw_nodes = str(parsing.group(1)) - raw_relationships = parsing.group(2) - raw_nodes_schemas = parsing.group(3) - raw_relationships_schemas = parsing.group(4) - nodes.extend(re.findall(internal_regex, raw_nodes)) - relationships.extend(re.findall(internal_regex, raw_relationships)) - nodes_schemas.extend(re.findall(internal_regex, raw_nodes_schemas)) - relationships_schemas.extend(re.findall(internal_regex, raw_relationships_schemas)) - result = { - "nodes": [], - "relationships": [], - "nodes_schemas": [], - "relationships_schemas": [], - } - result["nodes"].extend(nodes_text_to_list_of_dict(nodes)) - result["relationships"].extend(relationships_text_to_list_of_dict(relationships)) - result["nodes_schemas"].extend(nodes_schemas_text_to_list_of_dict(nodes_schemas)) - result["relationships_schemas"].extend( - relationships_schemas_text_to_list_of_dict(relationships_schemas) - ) - return result + for vertex in schema["vertices"]: + if vertex["vertex_label"] == label and p in vertex["properties"]: + if (s, label) not in vertices_dict: + vertices_dict[(s, label)] = {"name": s, "label": label, "properties": {p: o}} + else: + vertices_dict[(s, label)]["properties"].update({p: o}) + break + for edge in schema["edges"]: + if edge["edge_label"] == label: + graph["edges"].append({"start": s, "end": o, "type": label, "properties": {}}) + break + graph["vertices"] = list(vertices_dict.values()) class InfoExtract: @@ -217,58 +100,27 @@ def __init__( self, llm: BaseLLM, text: str, - 
nodes_schemas=None, - relationships_schemas=None, - spo=False, ) -> None: self.llm = llm self.text = text - self.nodes_schemas = nodes_schemas - self.relationships_schemas = relationships_schemas - self.spo = spo - def process(self, chunk): - if self.llm.get_llm_type() == "openai": - messages = [ - {"role": "system", "content": self.generate_system_message()}, - {"role": "user", "content": self.generate_prompt(chunk)}, - ] - elif self.llm.get_llm_type() == "ernie": - if self.spo: - messages = [{"role": "user", "content": generate_ernie_prompt_spo(chunk)}] - else: - messages = [{"role": "user", "content": generate_ernie_message(chunk)}] - else: - raise Exception("llm type is not supported !") - output = self.llm.generate(messages) - return output - - def generate_system_message(self) -> str: - if self.nodes_schemas and self.relationships_schemas: - return generate_system_message_with_schemas() - return generate_system_message() - - def generate_prompt(self, data) -> str: - if self.nodes_schemas and self.relationships_schemas: - return generate_prompt_with_schemas( - data, self.nodes_schemas, self.relationships_schemas - ) - return generate_prompt(data) + def run(self, schema=None) -> Dict[str, List[Any]]: + prompt_token = self.llm.num_tokens_from_string(generate_extract_triple_prompt("", schema)) - def run(self, data: Dict) -> Dict[str, List[Any]]: - token_usage_per_prompt = self.llm.num_tokens_from_string( - self.generate_system_message() + self.generate_prompt("") - ) - chunked_data = split_string_to_fit_token_space( - llm=self.llm, string=self.text, token_use_per_string=token_usage_per_prompt + chunked_text = fit_token_space_by_split_text( + llm=self.llm, text=self.text, prompt_token=int(prompt_token) ) - results = [] - for chunk in chunked_data: - proceeded_chunk = self.process(chunk) - results.append(proceeded_chunk) - if self.spo: - results = get_spo_from_result(results) + result = {"vertices": [], "edges": [], "schema": schema} if schema else {"triples": []} 
+ for chunk in chunked_text: + proceeded_chunk = self.extract_by_llm(schema, chunk) + print(f"[LLM] input: {chunk} \n output:{proceeded_chunk}") + if schema: + extract_by_regex_with_schema(schema, proceeded_chunk, result) else: - results = get_nodes_and_relationships_from_result(results) - return results + extract_by_regex(proceeded_chunk, result) + return result + + def extract_by_llm(self, schema, chunk): + prompt = generate_extract_triple_prompt(chunk, schema) + return self.llm.generate(prompt=prompt) diff --git a/hugegraph-llm/src/tests/operators/common_op/test_check_schema.py b/hugegraph-llm/src/tests/operators/common_op/test_check_schema.py new file mode 100644 index 000000000..6d002ff8c --- /dev/null +++ b/hugegraph-llm/src/tests/operators/common_op/test_check_schema.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import unittest + +from hugegraph_llm.operators.common_op.check_schema import CheckSchema + + +class TestCheckSchema(unittest.TestCase): + def setUp(self): + pass + + def test_schema_check_with_valid_input(self): + data = { + "vertices": [{"vertex_label": "person"}], + "edges": [ + { + "edge_label": "knows", + "source_vertex_label": "person", + "target_vertex_label": "person", + } + ], + } + check_schema = CheckSchema(data) + self.assertEqual(check_schema.run(), data) + + def test_schema_check_with_invalid_input(self): + data = "invalid input" + check_schema = CheckSchema(data) + with self.assertRaises(ValueError): + check_schema.run() + + def test_schema_check_with_missing_vertices(self): + data = { + "edges": [ + { + "edge_label": "knows", + "source_vertex_label": "person", + "target_vertex_label": "person", + } + ] + } + check_schema = CheckSchema(data) + with self.assertRaises(ValueError): + check_schema.run() + + def test_schema_check_with_missing_edges(self): + data = {"vertices": [{"vertex_label": "person"}]} + check_schema = CheckSchema(data) + with self.assertRaises(ValueError): + check_schema.run() + + def test_schema_check_with_invalid_vertices(self): + data = { + "vertices": "invalid vertices", + "edges": [ + { + "edge_label": "knows", + "source_vertex_label": "person", + "target_vertex_label": "person", + } + ], + } + check_schema = CheckSchema(data) + with self.assertRaises(ValueError): + check_schema.run() + + def test_schema_check_with_invalid_edges(self): + data = {"vertices": [{"vertex_label": "person"}], "edges": "invalid edges"} + check_schema = CheckSchema(data) + with self.assertRaises(ValueError): + check_schema.run() diff --git a/hugegraph-llm/src/tests/operators/llm_op/test_disambiguate_data.py b/hugegraph-llm/src/tests/operators/llm_op/test_disambiguate_data.py new file mode 100644 index 000000000..350eeba2c --- /dev/null +++ b/hugegraph-llm/src/tests/operators/llm_op/test_disambiguate_data.py @@ -0,0 +1,83 @@ +# Licensed to the Apache 
Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import unittest + +from hugegraph_llm.llms.init_llm import LLMs +from hugegraph_llm.operators.llm_op.disambiguate_data import DisambiguateData + + +class TestDisambiguateData(unittest.TestCase): + def setUp(self): + self.triples = { + "triples": [ + (' "Alice "', ' "Age "', ' "25 "'), + (' "Alice "', ' "Profession "', ' "lawyer "'), + (' "Bob "', ' "Job "', ' "journalist "'), + (' "Alice "', ' "Roommate of "', ' "Bob "'), + (' "lucy "', "roommate", ' "Bob "'), + (' "Alice "', ' "is the ownner of "', ' "http://www.alice.com "'), + (' "Bob "', ' "Owns "', ' "http://www.bob.com "'), + ] + } + + self.triples_with_schema = { + "vertices": [ + { + "name": "Alice", + "label": "person", + "properties": {"name": "Alice", "age": "25", "occupation": "lawyer"}, + }, + { + "name": "Bob", + "label": "person", + "properties": {"name": "Bob", "occupation": "journalist"}, + }, + { + "name": "www.alice.com", + "label": "webpage", + "properties": {"name": "www.alice.com", "url": "www.alice.com"}, + }, + { + "name": "www.bob.com", + "label": "webpage", + "properties": {"name": "www.bob.com", "url": "www.bob.com"}, + }, + ], + "edges": [{"start": "Alice", "end": "Bob", "type": "roommate", "properties": {}}], + "schema": { + 
"vertices": [ + {"vertex_label": "person", "properties": ["name", "age", "occupation"]}, + {"vertex_label": "webpage", "properties": ["name", "url"]}, + ], + "edges": [ + { + "edge_label": "roommate", + "source_vertex_label": "person", + "target_vertex_label": "person", + "properties": [], + } + ], + }, + } + self.llm = LLMs().get_llm() + # self.llm = None + self.disambiguate_data = DisambiguateData(self.llm) + + def test_run(self): + result = self.disambiguate_data.run(self.triples_with_schema) + self.assertEqual(result, self.triples_with_schema) diff --git a/hugegraph-llm/src/tests/operators/llm_op/test_info_extract.py b/hugegraph-llm/src/tests/operators/llm_op/test_info_extract.py new file mode 100644 index 000000000..596b0896f --- /dev/null +++ b/hugegraph-llm/src/tests/operators/llm_op/test_info_extract.py @@ -0,0 +1,142 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import unittest + +from hugegraph_llm.operators.llm_op.info_extract import ( + InfoExtract, + extract_by_regex_with_schema, + extract_by_regex, +) + + +class TestInfoExtract(unittest.TestCase): + def setUp(self): + self.schema = { + "vertices": [ + {"vertex_label": "person", "properties": ["name", "age", "occupation"]}, + {"vertex_label": "webpage", "properties": ["name", "url"]}, + ], + "edges": [ + { + "edge_label": "roommate", + "source_vertex_label": "person", + "target_vertex_label": "person", + "properties": [], + } + ], + } + # self.llm = LLMs().get_llm() + self.llm = None + self.info_extract = InfoExtract(self.llm, "text") + + self.llm_output = """ + {"id": "as-rymwkgbvqf", "object": "chat.completion", "created": 1706599975, + "result": "Based on the given graph schema and the extracted text, we can extract + the following triples:\n\n + 1. (Alice, name, Alice) - person\n + 2. (Alice, age, 25) - person\n + 3. (Alice, occupation, lawyer) - person\n + 4. (Bob, name, Bob) - person\n + 5. (Bob, occupation, journalist) - person\n + 6. (Alice, roommate, Bob) - roommate\n + 7. (www.alice.com, name, www.alice.com) - webpage\n + 8. (www.alice.com, url, www.alice.com) - webpage\n + 9. (www.bob.com, name, www.bob.com) - webpage\n + 10. (www.bob.com, url, www.bob.com) - webpage\n\n + However, the schema does not provide a direct relationship between people and + webpages they own. To establish such a relationship, we might need to introduce + a new edge label like \"owns\" or modify the schema accordingly. Assuming we + introduce a new edge label \"owns\", we can extract the following additional + triples:\n\n + 1. (Alice, owns, www.alice.com) - owns\n2. (Bob, owns, www.bob.com) - owns\n\n + Please note that the extraction of some triples, like the webpage name and URL, + might seem redundant since they are the same. However, + I included them to strictly follow the given format. 
In a real-world scenario, + such redundancy might be avoided or handled differently.", + "is_truncated": false, "need_clear_history": false, "finish_reason": "normal", + "usage": {"prompt_tokens": 221, "completion_tokens": 325, "total_tokens": 546}} + """ + + def test_extract_by_regex_with_schema(self): + graph = {"vertices": [], "edges": [], "schema": self.schema} + extract_by_regex_with_schema(self.schema, self.llm_output, graph) + self.assertEqual( + graph, + { + "vertices": [ + { + "name": "Alice", + "label": "person", + "properties": {"name": "Alice", "age": "25", "occupation": "lawyer"}, + }, + { + "name": "Bob", + "label": "person", + "properties": {"name": "Bob", "occupation": "journalist"}, + }, + { + "name": "www.alice.com", + "label": "webpage", + "properties": {"name": "www.alice.com", "url": "www.alice.com"}, + }, + { + "name": "www.bob.com", + "label": "webpage", + "properties": {"name": "www.bob.com", "url": "www.bob.com"}, + }, + ], + "edges": [{"start": "Alice", "end": "Bob", "type": "roommate", "properties": {}}], + "schema": { + "vertices": [ + {"vertex_label": "person", "properties": ["name", "age", "occupation"]}, + {"vertex_label": "webpage", "properties": ["name", "url"]}, + ], + "edges": [ + { + "edge_label": "roommate", + "source_vertex_label": "person", + "target_vertex_label": "person", + "properties": [], + } + ], + }, + }, + ) + + def test_extract_by_regex(self): + graph = {"triples": []} + extract_by_regex(self.llm_output, graph) + self.assertEqual( + graph, + { + "triples": [ + ("Alice", "name", "Alice"), + ("Alice", "age", "25"), + ("Alice", "occupation", "lawyer"), + ("Bob", "name", "Bob"), + ("Bob", "occupation", "journalist"), + ("Alice", "roommate", "Bob"), + ("www.alice.com", "name", "www.alice.com"), + ("www.alice.com", "url", "www.alice.com"), + ("www.bob.com", "name", "www.bob.com"), + ("www.bob.com", "url", "www.bob.com"), + ("Alice", "owns", "www.alice.com"), + ("Bob", "owns", "www.bob.com"), + ] + }, + ) From 
7277bb559c9250e9f92f0ed7f1405e787cb4da0c Mon Sep 17 00:00:00 2001 From: simon824 Date: Fri, 2 Feb 2024 16:41:27 +0800 Subject: [PATCH 2/4] fix codestyle --- hugegraph-llm/README.md | 5 +++-- hugegraph-llm/examples/build_kg_test.py | 6 +++--- .../operators/common_op/check_schema.py | 14 +++++++------- .../operators/hugegraph_op/commit_to_hugegraph.py | 4 +--- .../operators/llm_op/disambiguate_data.py | 4 ++-- .../hugegraph_llm/operators/llm_op/info_extract.py | 12 ++++++------ .../operators/llm_op/test_disambiguate_data.py | 3 +-- .../tests/operators/llm_op/test_info_extract.py | 10 +++++----- 8 files changed, 28 insertions(+), 30 deletions(-) diff --git a/hugegraph-llm/README.md b/hugegraph-llm/README.md index cb8e3f43c..493b2e7e6 100644 --- a/hugegraph-llm/README.md +++ b/hugegraph-llm/README.md @@ -44,10 +44,11 @@ builder = KgBuilder(LLMs().get_llm()) ) ``` -2. **Import Schema**: The `import_schema` method is used to import a schema from a source. The source can be a HugeGraph instance,a user-defined schema or an extraction result. The method `print_result` can be chained to print the result. +2. **Import Schema**: The `import_schema` method is used to import a schema from a source. The source can be a HugeGraph instance, a user-defined schema or an extraction result. The method `print_result` can be chained to print the result. 
+ ```python # Import schema from a HugeGraph instance -import_schema(from_hugegraph="talent_graph").print_result() +import_schema(from_hugegraph="xxx").print_result() # Import schema from an extraction result import_schema(from_extraction="xxx").print_result() # Import schema from user-defined schema diff --git a/hugegraph-llm/examples/build_kg_test.py b/hugegraph-llm/examples/build_kg_test.py index 1274ed101..3380c4295 100644 --- a/hugegraph-llm/examples/build_kg_test.py +++ b/hugegraph-llm/examples/build_kg_test.py @@ -50,9 +50,9 @@ ( builder - .import_schema(from_hugegraph="talent_graph").print_result() - # .import_schema(from_extraction="fefe").print_result().run() - # .import_schema(from_input=schema).print_result() + .import_schema(from_hugegraph="xxx").print_result() + # .import_schema(from_extraction="xxx").print_result() + # .import_schema(from_user_defined=xxx).print_result() .extract_triples(TEXT).print_result() .disambiguate_word_sense() .commit_to_hugegraph() diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py b/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py index bd491bdfe..0228a976a 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py @@ -25,21 +25,21 @@ def __init__(self, data): self.data = data def run(self, schema=None) -> Any: - data = self.data or schema - if not isinstance(data, dict): + schema = self.data or schema + if not isinstance(schema, dict): raise ValueError("Input data is not a dictionary.") - if "vertices" not in data or "edges" not in data: + if "vertices" not in schema or "edges" not in schema: raise ValueError("Input data does not contain 'vertices' or 'edges'.") - if not isinstance(data["vertices"], list) or not isinstance(data["edges"], list): + if not isinstance(schema["vertices"], list) or not isinstance(schema["edges"], list): raise ValueError("'vertices' or 'edges' in input 
data is not a list.") - for vertex in data["vertices"]: + for vertex in schema["vertices"]: if not isinstance(vertex, dict): raise ValueError("Vertex in input data is not a dictionary.") if "vertex_label" not in vertex: raise ValueError("Vertex in input data does not contain 'vertex_label'.") if not isinstance(vertex["vertex_label"], str): raise ValueError("'vertex_label' in vertex is not of correct type.") - for edge in data["edges"]: + for edge in schema["edges"]: if not isinstance(edge, dict): raise ValueError("Edge in input data is not a dictionary.") if ( @@ -60,4 +60,4 @@ def run(self, schema=None) -> Any: "'edge_label', 'source_vertex_label', 'target_vertex_label' " "in edge is not of correct type." ) - return data + return schema diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py index 62443c2fd..558a8ba84 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py @@ -104,9 +104,7 @@ def schema_free_mode(self, data): ).secondary().ifNotExist().create() for item in data: - s = item[0].strip() - p = item[1].strip() - o = item[2].strip() + s, p, o = (element.strip() for element in item) s_id = self.client.graph().addVertex("vertex", {"name": s}, id=s).id t_id = self.client.graph().addVertex("vertex", {"name": o}, id=o).id self.client.graph().addEdge("edge", s_id, t_id, {"name": p}) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py index a08793428..d279e52db 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py @@ -19,7 +19,7 @@ from typing import Dict, List, Any from hugegraph_llm.llms.base import BaseLLM -from 
hugegraph_llm.operators.llm_op.info_extract import extract_by_regex +from hugegraph_llm.operators.llm_op.info_extract import extract_triples_by_regex def generate_disambiguate_prompt(triples): @@ -48,6 +48,6 @@ def run(self, data: Dict) -> Dict[str, List[Any]]: prompt = generate_disambiguate_prompt(triples) llm_output = self.llm.generate(prompt=prompt) data = {"triples": []} - extract_by_regex(llm_output, data) + extract_triples_by_regex(llm_output, data) print(f"LLM input:{prompt} \n output: {llm_output} \n data: {data}") return data diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py index 18b4083ee..ab35da2b4 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py @@ -65,13 +65,13 @@ def fit_token_space_by_split_text(llm: BaseLLM, text: str, prompt_token: int) -> return combined_chunks -def extract_by_regex(text, triples): +def extract_triples_by_regex(text, triples): text = text.replace("\\n", " ").replace("\\", " ").replace("\n", " ") pattern = r"\((.*?), (.*?), (.*?)\)" triples["triples"] += re.findall(pattern, text) -def extract_by_regex_with_schema(schema, text, graph): +def extract_triples_by_regex_with_schema(schema, text, graph): text = text.replace("\\n", " ").replace("\\", " ").replace("\n", " ") pattern = r"\((.*?), (.*?), (.*?)\) - ([^ ]*)" matches = re.findall(pattern, text) @@ -113,14 +113,14 @@ def run(self, schema=None) -> Dict[str, List[Any]]: result = {"vertices": [], "edges": [], "schema": schema} if schema else {"triples": []} for chunk in chunked_text: - proceeded_chunk = self.extract_by_llm(schema, chunk) + proceeded_chunk = self.extract_triples_by_llm(schema, chunk) print(f"[LLM] input: {chunk} \n output:{proceeded_chunk}") if schema: - extract_by_regex_with_schema(schema, proceeded_chunk, result) + extract_triples_by_regex_with_schema(schema, 
proceeded_chunk, result) else: - extract_by_regex(proceeded_chunk, result) + extract_triples_by_regex(proceeded_chunk, result) return result - def extract_by_llm(self, schema, chunk): + def extract_triples_by_llm(self, schema, chunk): prompt = generate_extract_triple_prompt(chunk, schema) return self.llm.generate(prompt=prompt) diff --git a/hugegraph-llm/src/tests/operators/llm_op/test_disambiguate_data.py b/hugegraph-llm/src/tests/operators/llm_op/test_disambiguate_data.py index 350eeba2c..ba800ce10 100644 --- a/hugegraph-llm/src/tests/operators/llm_op/test_disambiguate_data.py +++ b/hugegraph-llm/src/tests/operators/llm_op/test_disambiguate_data.py @@ -74,8 +74,7 @@ def setUp(self): ], }, } - self.llm = LLMs().get_llm() - # self.llm = None + self.llm = None self.disambiguate_data = DisambiguateData(self.llm) def test_run(self): diff --git a/hugegraph-llm/src/tests/operators/llm_op/test_info_extract.py b/hugegraph-llm/src/tests/operators/llm_op/test_info_extract.py index 596b0896f..cd84521ea 100644 --- a/hugegraph-llm/src/tests/operators/llm_op/test_info_extract.py +++ b/hugegraph-llm/src/tests/operators/llm_op/test_info_extract.py @@ -19,8 +19,8 @@ from hugegraph_llm.operators.llm_op.info_extract import ( InfoExtract, - extract_by_regex_with_schema, - extract_by_regex, + extract_triples_by_regex_with_schema, + extract_triples_by_regex, ) @@ -40,7 +40,7 @@ def setUp(self): } ], } - # self.llm = LLMs().get_llm() + self.llm = None self.info_extract = InfoExtract(self.llm, "text") @@ -74,7 +74,7 @@ def setUp(self): def test_extract_by_regex_with_schema(self): graph = {"vertices": [], "edges": [], "schema": self.schema} - extract_by_regex_with_schema(self.schema, self.llm_output, graph) + extract_triples_by_regex_with_schema(self.schema, self.llm_output, graph) self.assertEqual( graph, { @@ -120,7 +120,7 @@ def test_extract_by_regex_with_schema(self): def test_extract_by_regex(self): graph = {"triples": []} - extract_by_regex(self.llm_output, graph) + 
extract_triples_by_regex(self.llm_output, graph) self.assertEqual( graph, { From bbf723eaaf388be06f9a0a83f43a876af4b6618d Mon Sep 17 00:00:00 2001 From: simon824 Date: Fri, 2 Feb 2024 16:43:26 +0800 Subject: [PATCH 3/4] fix codestyle --- hugegraph-llm/examples/build_kg_test.py | 7 ++++--- .../src/tests/operators/llm_op/test_disambiguate_data.py | 1 - 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/hugegraph-llm/examples/build_kg_test.py b/hugegraph-llm/examples/build_kg_test.py index 3380c4295..b0b8c5167 100644 --- a/hugegraph-llm/examples/build_kg_test.py +++ b/hugegraph-llm/examples/build_kg_test.py @@ -49,11 +49,12 @@ } ( - builder - .import_schema(from_hugegraph="xxx").print_result() + builder.import_schema(from_hugegraph="xxx") + .print_result() # .import_schema(from_extraction="xxx").print_result() # .import_schema(from_user_defined=xxx).print_result() - .extract_triples(TEXT).print_result() + .extract_triples(TEXT) + .print_result() .disambiguate_word_sense() .commit_to_hugegraph() .run() diff --git a/hugegraph-llm/src/tests/operators/llm_op/test_disambiguate_data.py b/hugegraph-llm/src/tests/operators/llm_op/test_disambiguate_data.py index ba800ce10..fb58ec38e 100644 --- a/hugegraph-llm/src/tests/operators/llm_op/test_disambiguate_data.py +++ b/hugegraph-llm/src/tests/operators/llm_op/test_disambiguate_data.py @@ -17,7 +17,6 @@ import unittest -from hugegraph_llm.llms.init_llm import LLMs from hugegraph_llm.operators.llm_op.disambiguate_data import DisambiguateData From eca9f7f46adc2cf0d34c41b45eaf41e49232a56c Mon Sep 17 00:00:00 2001 From: Simon Cheung Date: Tue, 20 Feb 2024 18:34:33 +0800 Subject: [PATCH 4/4] Update test_disambiguate_data.py --- .../src/tests/operators/llm_op/test_disambiguate_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugegraph-llm/src/tests/operators/llm_op/test_disambiguate_data.py b/hugegraph-llm/src/tests/operators/llm_op/test_disambiguate_data.py index fb58ec38e..04ee42142 100644 --- 
a/hugegraph-llm/src/tests/operators/llm_op/test_disambiguate_data.py +++ b/hugegraph-llm/src/tests/operators/llm_op/test_disambiguate_data.py @@ -78,4 +78,4 @@ def setUp(self): def test_run(self): result = self.disambiguate_data.run(self.triples_with_schema) - self.assertEqual(result, self.triples_with_schema) + self.assertDictEqual(result, self.triples_with_schema)