diff --git a/.wordlist.txt b/.wordlist.txt index c2a21baf..593d0568 100644 --- a/.wordlist.txt +++ b/.wordlist.txt @@ -27,3 +27,4 @@ www faq Ollama ollama +Cypher \ No newline at end of file diff --git a/README.md b/README.md index 6e9f6649..fcc3804c 100644 --- a/README.md +++ b/README.md @@ -170,9 +170,9 @@ After generating the initial ontology, you can review it and make any necessary Once you are satisfied with the ontology, you can proceed to use it for creating and managing your Knowledge Graph (KG). -### Knowledge Graph +### Knowledge Graph Agent -Now, you can use the SDK to create a Knowledge Graph (KG) from your sources and ontology. +Now, you can use the SDK to create a Knowledge Graph (KG) from your sources and ontology for Q&A. ```python # After approving the ontology, load it from disk. @@ -191,6 +191,45 @@ kg.process_sources(sources) You can update the KG at any time by processing more sources with the `process_sources` method. +#### Configurable Prompts +When creating your Knowledge Graph (KG) agent, you can customize the prompts to tailor its behavior. This step is optional but can enhance functionality. + +There are five types of prompts: + +1. **`cypher_system_instruction`** + - System instructions for the Cypher generation step. + - **Note:** Ensure your prompt includes `{ontology}`. + +2. **`qa_system_instruction`** + - System instructions for the Q&A step. + +3. **`cypher_gen_prompt`** + - The prompt used during the Cypher generation step. + - **Note:** Include `{question}` in your prompt. + +4. **`cypher_gen_prompt_history`** + - The prompt for Cypher generation when history needs to be considered. + - **Note:** Include `{question}` and `{last_answer}` in your prompt. + +5. **`qa_prompt`** + - The prompt used during the Q&A step. + - **Note:** Include `{question}`, `{context}`, and `{cypher}` in your prompt. + +Here’s an example configuration: + +```python +kg = KnowledgeGraph( + name="kg_name", + model_config=KnowledgeGraphModelConfig.with_model(model), + ontology=ontology, + cypher_system_instruction=cypher_system_instruction, + qa_system_instruction=qa_system_instruction, + cypher_gen_prompt=cypher_gen_prompt, + cypher_gen_prompt_history=cypher_gen_prompt_history, + qa_prompt=qa_prompt +) +``` + ### Graph RAG At this point, you have a Knowledge Graph that can be queried using this SDK. Use the method `chat_session` for start a conversation. diff --git a/examples/movies/demo-movies.ipynb b/examples/movies/demo-movies.ipynb index ecbec637..440d91e1 100644 --- a/examples/movies/demo-movies.ipynb +++ b/examples/movies/demo-movies.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -20,7 +20,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -29,7 +29,7 @@ "True" ] }, - "execution_count": 11, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -59,7 +59,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -84,7 +84,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -113,7 +113,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -136,17 +136,17 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "The directors of the movie \"The Matrix\" are Lilly Wachowski and Lana Wachowski.\n", - "The directors, Lilly Wachowski and Lana Wachowski, are connected to Keanu Reeves through their collaboration on the movies \"The Matrix,\" \"The Matrix Reloaded,\" and \"The Matrix Revolutions,\" in which Keanu Reeves acted.\n", - "The director of the movie \"Side by Side\" is Chris Kenneally.\n", - "The directors, ordered lexically, are Chris Kenneally, Lana Wachowski, and Lilly Wachowski.\n" + "{'question': 'Who is the director of the movie The Matrix?', 'response': 'The directors of the movie \"The Matrix\" are Lilly Wachowski and Lana Wachowski.', 'context': '[[\\'(:Movie{box_office_gross_usa:0,duration:\"2h 18m\",genre:\"Sci-Fi, Action\",popcornmeter_score:72,rating:\"R\",release_date:\"2003-05-15\",runtime:\"2h 18m\",title:\"The Matrix Reloaded\",tomatometer_score:74})\\', \\'(:Person{name:\"Lilly Wachowski\"})\\'], [\\'(:Movie{box_office_gross_usa:0,duration:\"2h 18m\",genre:\"Sci-Fi, Action\",popcornmeter_score:72,rating:\"R\",release_date:\"2003-05-15\",runtime:\"2h 18m\",title:\"The Matrix Reloaded\",tomatometer_score:74})\\', \\'(:Person{name:\"Lana Wachowski\"})\\'], [\\'(:Movie{box_office_gross_usa:0,duration:\"2h 16m\",genre:\"Sci-Fi, Action, Mystery & Thriller\",popcornmeter_score:85,rating:\"R\",release_date:\"1999-03-31\",title:\"The Matrix\",tomatometer_score:83})\\', \\'(:Person{name:\"Lilly Wachowski\"})\\'], [\\'(:Movie{box_office_gross_usa:0,duration:\"2h 16m\",genre:\"Sci-Fi, Action, Mystery & Thriller\",popcornmeter_score:85,rating:\"R\",release_date:\"1999-03-31\",title:\"The Matrix\",tomatometer_score:83})\\', \\'(:Person{name:\"Lana Wachowski\"})\\']]', 'cypher': \"\\nMATCH (m:Movie)-[:DIRECTED_BY]->(p:Person)\\nWHERE m.title CONTAINS 'The Matrix'\\nRETURN m, p\\n\"}\n", + "{'question': 'How this director connected to Keanu Reeves?', 'response': 'The directors Lilly Wachowski and Lana Wachowski are connected to Keanu Reeves through the movie \"The Matrix,\" where Keanu Reeves acted in the film.', 'context': '[[\\'(:Person{name:\"Lilly Wachowski\"})\\', \\'(:Movie{box_office_gross_usa:0,duration:\"2h 16m\",genre:\"Sci-Fi, Action, Mystery & Thriller\",popcornmeter_score:85,rating:\"R\",release_date:\"1999-03-31\",title:\"The Matrix\",tomatometer_score:83})\\', \\'(:Person{name:\"Keanu Reeves\",role:\"John Wick\"})\\'], [\\'(:Person{name:\"Lana Wachowski\"})\\', \\'(:Movie{box_office_gross_usa:0,duration:\"2h 16m\",genre:\"Sci-Fi, Action, Mystery & Thriller\",popcornmeter_score:85,rating:\"R\",release_date:\"1999-03-31\",title:\"The Matrix\",tomatometer_score:83})\\', \\'(:Person{name:\"Keanu Reeves\",role:\"John Wick\"})\\']]', 'cypher': \"\\nMATCH (p:Person)-[:DIRECTED]->(m:Movie)<-[:ACTED_IN]-(k:Person)\\nWHERE p.name IN ['Lilly Wachowski', 'Lana Wachowski'] AND k.name CONTAINS 'Keanu Reeves'\\nRETURN p, m, k\\n\"}\n", + "{'question': 'Who is the director of the movie Side by Side?', 'response': 'The director of the movie \"Side by Side\" is Chris Kenneally.', 'context': '[[\\'(:Movie{box_office_gross_usa:0,duration:\"1h 38m\",genre:\"Documentary\",popcornmeter_score:86,rating:\"Not Rated\",release_date:\"2012-08-17\",runtime:\"1h 38m\",title:\"Side by Side\",tomatometer_score:92})\\', \\'(:Person{name:\"Chris Kenneally\"})\\']]', 'cypher': \"\\nMATCH (m:Movie)-[:DIRECTED_BY]->(p:Person)\\nWHERE m.title CONTAINS 'Side by Side'\\nRETURN m, p\\n\"}\n", + "{'question': 'Order the directors that you mentioned in all of our conversation by lexical order.', 'response': 'The directors, ordered lexically, are Chris Kenneally, Lana Wachowski, and Lilly Wachowski.', 'context': '[[\\'(:Person{name:\"Chris Kenneally\"})\\'], [\\'(:Person{name:\"Lana Wachowski\"})\\'], [\\'(:Person{name:\"Lilly Wachowski\"})\\']]', 'cypher': \"\\nMATCH (p:Person)\\nWHERE p.name IN ['Lilly Wachowski', 'Lana Wachowski', 'Chris Kenneally']\\nRETURN p\\nORDER BY p.name\\n\"}\n" ] } ], diff --git a/examples/ufc/demo-ufc.ipynb b/examples/ufc/demo-ufc.ipynb index 51728e8e..43eca283 100644 --- a/examples/ufc/demo-ufc.ipynb +++ b/examples/ufc/demo-ufc.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -20,7 +20,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -29,7 +29,7 @@ "True" ] }, - "execution_count": 1, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -62,7 +62,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -86,7 +86,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -124,7 +124,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -152,73 +152,67 @@ " \"description\": \"\"\n", " },\n", " {\n", - " \"label\": \"Fight\",\n", + " \"label\": \"FightEvent\",\n", " \"attributes\": [\n", " {\n", - " \"name\": \"date\",\n", + " \"name\": \"name\",\n", " \"type\": \"string\",\n", " \"unique\": true,\n", " \"required\": true\n", " },\n", " {\n", - " \"name\": \"location\",\n", + " \"name\": \"date\",\n", " \"type\": \"string\",\n", " \"unique\": false,\n", " \"required\": true\n", " },\n", " {\n", - " \"name\": \"method\",\n", + " \"name\": \"location\",\n", " \"type\": \"string\",\n", " \"unique\": false,\n", " \"required\": true\n", - " },\n", + " }\n", + " ],\n", + " \"description\": \"\"\n", + " },\n", + " {\n", + " \"label\": \"Bout\",\n", + " \"attributes\": [\n", " {\n", - " \"name\": \"round\",\n", - " \"type\": \"number\",\n", + " \"name\": \"weight_class\",\n", + " \"type\": \"string\",\n", " \"unique\": false,\n", " \"required\": true\n", " },\n", " {\n", - " \"name\": \"time\",\n", + " \"name\": \"method\",\n", " \"type\": \"string\",\n", " \"unique\": false,\n", " \"required\": true\n", " },\n", " {\n", - " \"name\": \"time_format\",\n", - " \"type\": \"string\",\n", + " \"name\": \"round\",\n", + " \"type\": \"number\",\n", " \"unique\": false,\n", " \"required\": true\n", " },\n", " {\n", - " \"name\": \"referee\",\n", + " \"name\": \"time\",\n", " \"type\": \"string\",\n", " \"unique\": false,\n", " \"required\": true\n", " },\n", " {\n", - " \"name\": \"details\",\n", + " \"name\": \"time_format\",\n", " \"type\": \"string\",\n", " \"unique\": false,\n", - " \"required\": false\n", - " }\n", - " ],\n", - " \"description\": \"\"\n", - " },\n", - " {\n", - " \"label\": \"Event\",\n", - " \"attributes\": [\n", - " {\n", - " \"name\": \"title\",\n", - " \"type\": \"string\",\n", - " \"unique\": true,\n", " \"required\": true\n", " },\n", " {\n", - " \"name\": \"date\",\n", + " \"name\": \"referee\",\n", " \"type\": \"string\",\n", " \"unique\": false,\n", - " \"required\": true\n", + " \"required\": false\n", " }\n", " ],\n", " \"description\": \"\"\n", @@ -226,12 +220,12 @@ " ],\n", " \"relations\": [\n", " {\n", - " \"label\": \"PARTICIPATED_IN\",\n", + " \"label\": \"FOUGHT_IN\",\n", " \"source\": {\n", " \"label\": \"Person\"\n", " },\n", " \"target\": {\n", - " \"label\": \"Fight\"\n", + " \"label\": \"Bout\"\n", " },\n", " \"attributes\": [\n", " {\n", @@ -241,68 +235,96 @@ " \"required\": true\n", " },\n", " {\n", - " \"name\": \"total_knockdowns\",\n", + " \"name\": \"knockdowns\",\n", " \"type\": \"number\",\n", " \"unique\": false,\n", - " \"required\": true\n", + " \"required\": false\n", " },\n", " {\n", " \"name\": \"significant_strikes\",\n", " \"type\": \"number\",\n", " \"unique\": false,\n", - " \"required\": true\n", + " \"required\": false\n", " },\n", " {\n", - " \"name\": \"significant_strike_percentage\",\n", + " \"name\": \"significant_strikes_attempted\",\n", " \"type\": \"number\",\n", " \"unique\": false,\n", - " \"required\": true\n", + " \"required\": false\n", + " },\n", + " {\n", + " \"name\": \"significant_strikes_percentage\",\n", + " \"type\": \"string\",\n", + " \"unique\": false,\n", + " \"required\": false\n", " },\n", " {\n", " \"name\": \"total_strikes\",\n", " \"type\": \"number\",\n", " \"unique\": false,\n", - " \"required\": true\n", + " \"required\": false\n", + " },\n", + " {\n", + " \"name\": \"total_strikes_attempted\",\n", + " \"type\": \"number\",\n", + " \"unique\": false,\n", + " \"required\": false\n", " },\n", " {\n", " \"name\": \"takedowns\",\n", " \"type\": \"number\",\n", " \"unique\": false,\n", - " \"required\": true\n", + " \"required\": false\n", " },\n", " {\n", - " \"name\": \"takedown_percentage\",\n", + " \"name\": \"takedowns_attempted\",\n", " \"type\": \"number\",\n", " \"unique\": false,\n", - " \"required\": true\n", + " \"required\": false\n", + " },\n", + " {\n", + " \"name\": \"takedown_percentage\",\n", + " \"type\": \"string\",\n", + " \"unique\": false,\n", + " \"required\": false\n", " },\n", " {\n", " \"name\": \"submissions_attempted\",\n", " \"type\": \"number\",\n", " \"unique\": false,\n", - " \"required\": true\n", + " \"required\": false\n", " },\n", " {\n", " \"name\": \"passes\",\n", " \"type\": \"number\",\n", " \"unique\": false,\n", - " \"required\": true\n", + " \"required\": false\n", " },\n", " {\n", " \"name\": \"reversals\",\n", " \"type\": \"number\",\n", " \"unique\": false,\n", - " \"required\": true\n", + " \"required\": false\n", " }\n", " ]\n", " },\n", " {\n", - " \"label\": \"ORGANIZED_AT\",\n", + " \"label\": \"PART_OF_EVENT\",\n", " \"source\": {\n", - " \"label\": \"Fight\"\n", + " \"label\": \"Bout\"\n", " },\n", " \"target\": {\n", - " \"label\": \"Event\"\n", + " \"label\": \"FightEvent\"\n", + " },\n", + " \"attributes\": []\n", + " },\n", + " {\n", + " \"label\": \"PART_OF\",\n", + " \"source\": {\n", + " \"label\": \"Bout\"\n", + " },\n", + " \"target\": {\n", + " \"label\": \"FightEvent\"\n", " },\n", " \"attributes\": []\n", " }\n", @@ -326,7 +348,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -354,15 +376,15 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Waldo Cortes-Acosta is known as Salsa Boy.\n", - "One of Salsa Boy's fights took place on August 26, 2023, in Kallang, Singapore. In this fight, he achieved a win by KO/TKO through punches to the head at distance. The fight was officiated by referee Mark Craig and ended in the first round at 3:01.\n" + "{'question': 'Who is Salsa Boy?', 'response': 'Waldo Cortes-Acosta is Salsa Boy.', 'context': '[[\\'(:Person{name:\"Waldo Cortes-Acosta\",nickname:\"Salsa Boy\"})\\']]', 'cypher': \"MATCH (p:Person) WHERE p.nickname CONTAINS 'Salsa Boy' RETURN p\"}\n", + "{'question': 'Tell me about one of his fights?', 'response': 'Waldo Cortes-Acosta, also known as Salsa Boy, fought in a bout at UFC 284: Makhachev vs. Volkanovski, held on February 11, 2023, in Perth, Western Australia, Australia. He won the fight by submission in the second round, at 4:05. In this fight, he displayed impressive skills with 52 significant strikes out of 101 attempted, achieving a 51.49% success rate. He also attempted 4 takedowns, successfully completing 3 for a 75.00% success rate, and attempted 1 submission.', 'context': '[[\\'(:Person{name:\"Waldo Cortes-Acosta\",nickname:\"Salsa Boy\"})\\', \\'()-[:FOUGHT_IN{knockdowns:0,passes:0,result:\"Win\",reversals:0,significant_strikes:52,significant_strikes_attempted:101,significant_strikes_percentage:\"51.49%\",submissions_attempted:1,takedown_percentage:\"75.00%\",takedowns:3,takedowns_attempted:4,total_strikes:130,total_strikes_attempted:204}]->()\\', \\'(:Bout{method:\"Submission\",referee:\"Osiris Maia\",round:2,time:\"4:05\",time_format:\"3 Rounds (5-5-5)\",weight_class:\"Lightweight\"})\\', \\'(:FightEvent{date:\"2023-02-11\",location:\"Perth, Western Australia, Australia\",name:\"UFC 284: Makhachev vs. Volkanovski\"})\\']]', 'cypher': \"MATCH (p:Person)-[f:FOUGHT_IN]->(b:Bout)-[:PART_OF_EVENT]->(e:FightEvent) WHERE p.nickname CONTAINS 'Salsa Boy' RETURN p, f, b, e LIMIT 1\"}\n" ] } ], diff --git a/graphrag_sdk/agents/kg_agent.py b/graphrag_sdk/agents/kg_agent.py index 0a3ea1de..908303ac 100644 --- a/graphrag_sdk/agents/kg_agent.py +++ b/graphrag_sdk/agents/kg_agent.py @@ -136,7 +136,7 @@ def run(self, params: dict) -> str: """ output = self.chat_session.send_message(params["prompt"]) - return output + return output['response'] def __repr__(self): """ diff --git a/graphrag_sdk/chat_session.py b/graphrag_sdk/chat_session.py index 6e392804..ea1676d6 100644 --- a/graphrag_sdk/chat_session.py +++ b/graphrag_sdk/chat_session.py @@ -1,9 +1,8 @@ +from falkordb import Graph from graphrag_sdk.ontology import Ontology +from graphrag_sdk.steps.qa_step import QAStep from graphrag_sdk.model_config import KnowledgeGraphModelConfig from graphrag_sdk.steps.graph_query_step import GraphQueryGenerationStep -from graphrag_sdk.steps.qa_step import QAStep -from graphrag_sdk.fixtures.prompts import GRAPH_QA_SYSTEM, CYPHER_GEN_SYSTEM -from falkordb import Graph class ChatSession: @@ -25,7 +24,9 @@ class ChatSession: >>> chat_session.send_message("What is the capital of France?") """ - def __init__(self, model_config: KnowledgeGraphModelConfig, ontology: Ontology, graph: Graph): + def __init__(self, model_config: KnowledgeGraphModelConfig, ontology: Ontology, graph: Graph, + cypher_system_instruction: str, qa_system_instruction: str, + cypher_gen_prompt: str, qa_prompt: str, cypher_gen_prompt_history: str): """ Initializes a new ChatSession object. @@ -44,14 +45,21 @@ def __init__(self, model_config: KnowledgeGraphModelConfig, ontology: Ontology, self.model_config = model_config self.graph = graph self.ontology = ontology + cypher_system_instruction = cypher_system_instruction.format(ontology=str(ontology.to_json())) + + + self.cypher_prompt = cypher_gen_prompt + self.qa_prompt = qa_prompt + self.cypher_prompt_with_history = cypher_gen_prompt_history + self.cypher_chat_session = ( model_config.cypher_generation.with_system_instruction( - CYPHER_GEN_SYSTEM.replace("#ONTOLOGY", str(ontology.to_json())) + cypher_system_instruction ).start_chat() ) self.qa_chat_session = model_config.qa.with_system_instruction( - GRAPH_QA_SYSTEM - ).start_chat() + qa_system_instruction + ).start_chat() self.last_answer = None def send_message(self, message: str): @@ -62,24 +70,42 @@ def send_message(self, message: str): message (str): The message to send. Returns: - str: The response to the message. + dict: The response to the message in the following format: + {"question": message, + "response": answer, + "context": context, + "cypher": cypher} """ cypher_step = GraphQueryGenerationStep( graph=self.graph, chat_session=self.cypher_chat_session, ontology=self.ontology, last_answer=self.last_answer, + cypher_prompt=self.cypher_prompt, + cypher_prompt_with_history=self.cypher_prompt_with_history ) (context, cypher) = cypher_step.run(message) if not cypher or len(cypher) == 0: - return "I am sorry, I could not find the answer to your question" + return { + "question": message, + "response": "I am sorry, I could not find the answer to your question", + "context": None, + "cypher": None + } qa_step = QAStep( chat_session=self.qa_chat_session, + qa_prompt=self.qa_prompt, ) answer = qa_step.run(message, cypher, context) self.last_answer = answer - return answer \ No newline at end of file + + return { + "question": message, + "response": answer, + "context": context, + "cypher": cypher + } \ No newline at end of file diff --git a/graphrag_sdk/fixtures/prompts.py b/graphrag_sdk/fixtures/prompts.py index 1fde8768..929af0af 100644 --- a/graphrag_sdk/fixtures/prompts.py +++ b/graphrag_sdk/fixtures/prompts.py @@ -429,7 +429,7 @@ Single-Source minimal-weight paths: Find minimal-weight paths from a given source entity using algo.SSpaths(). Ontology: -#ONTOLOGY +{ontology} For example, given the question "Which managers own Neo4j stocks?", the OpenCypher statement should look like this: @@ -442,7 +442,7 @@ CYPHER_GEN_PROMPT = """ Using the ontology provided, generate an OpenCypher statement to query the graph database returning all relevant entities, relationships, and attributes to answer the question below. -If you cannot generate a OpenCypher statement for any reason, return an empty string. +If you cannot generate a OpenCypher statement for any reason, return an empty response. Respect the order of the relationships, the arrows should always point from the "source" to the "target". Please think if your answer is a valid Cypher query, and correct it if it is not. @@ -482,8 +482,8 @@ The provided information is authoritative, you must never doubt it or try to use your internal knowledge to correct it. Make the answer sound as a response to the question. Do not mention that you based the result on the given information. Do not answer more than the question asks for. -Here is an example: +Here is an example: Question: Which managers own Neo4j stocks? Context:[manager:CTL LLC, manager:JANE STREET GROUP LLC] Helpful Answer: CTL LLC, JANE STREET GROUP LLC owns Neo4j stocks. diff --git a/graphrag_sdk/kg.py b/graphrag_sdk/kg.py index 42562fa1..831a05d3 100644 --- a/graphrag_sdk/kg.py +++ b/graphrag_sdk/kg.py @@ -1,16 +1,15 @@ import logging -from graphrag_sdk.ontology import Ontology +import warnings from falkordb import FalkorDB +from graphrag_sdk.ontology import Ontology from graphrag_sdk.source import AbstractSource -from graphrag_sdk.model_config import KnowledgeGraphModelConfig -from graphrag_sdk.steps.extract_data_step import ExtractDataStep -from graphrag_sdk.steps.graph_query_step import GraphQueryGenerationStep -from graphrag_sdk.fixtures.prompts import GRAPH_QA_SYSTEM, CYPHER_GEN_SYSTEM -from graphrag_sdk.steps.qa_step import QAStep from graphrag_sdk.chat_session import ChatSession -from graphrag_sdk.helpers import map_dict_to_cypher_properties from graphrag_sdk.attribute import AttributeType, Attribute -from graphrag_sdk.models import GenerativeModelChatSession +from graphrag_sdk.helpers import map_dict_to_cypher_properties +from graphrag_sdk.model_config import KnowledgeGraphModelConfig +from graphrag_sdk.steps.extract_data_step import ExtractDataStep +from graphrag_sdk.fixtures.prompts import (GRAPH_QA_SYSTEM, CYPHER_GEN_SYSTEM, + CYPHER_GEN_PROMPT, GRAPH_QA_PROMPT, CYPHER_GEN_PROMPT_WITH_HISTORY) logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -31,6 +30,11 @@ def __init__( port: int = 6379, username: str | None = None, password: str | None = None, + cypher_system_instruction: str = None, + qa_system_instruction: str = None, + cypher_gen_prompt: str = None, + qa_prompt: str = None, + cypher_gen_prompt_history: str = None, ): """ Initialize Knowledge Graph @@ -43,6 +47,11 @@ def __init__( username (str|None): FalkorDB username. password (str|None): FalkorDB password. ontology (Ontology|None): Ontology to use. + cypher_system_instruction (str|None): Cypher system instruction. Make sure you have {ontology} in the instruction. + qa_system_instruction (str|None): QA system instruction. + cypher_gen_prompt (str|None): Cypher generation prompt. Make sure you have {question} in the prompt. + qa_prompt (str|None): QA prompt. Make sure you have {question}, {context} and {cypher} in the prompt. + cypher_gen_prompt_history (str|None): Cypher generation prompt with history. Make sure you have {question} and {last_answer} in the prompt. """ if not isinstance(name, str) or name == "": @@ -57,6 +66,44 @@ def __init__( self._model_config = model_config self.sources = set([]) + if cypher_system_instruction is None: + cypher_system_instruction = CYPHER_GEN_SYSTEM + else: + if "{ontology}" not in cypher_system_instruction: + warnings.warn("Cypher system instruction should contain {ontology}", category=UserWarning) + + if qa_system_instruction is None: + qa_system_instruction = GRAPH_QA_SYSTEM + + if cypher_gen_prompt is None: + cypher_gen_prompt = CYPHER_GEN_PROMPT + else: + if "{question}" not in cypher_gen_prompt: + raise Exception("Cypher generation prompt should contain {question}") + + if qa_prompt is None: + qa_prompt = GRAPH_QA_PROMPT + else: + if "{question}" not in qa_prompt or "{context}" not in qa_prompt: + raise Exception("QA prompt should contain {question} and {context}") + if "{cypher}" not in qa_prompt: + warnings.warn("QA prompt should contain {cypher}", category=UserWarning) + + if cypher_gen_prompt_history is None: + cypher_gen_prompt_history = CYPHER_GEN_PROMPT_WITH_HISTORY + else: + if "{question}" not in cypher_gen_prompt_history: + raise Exception("Cypher generation prompt with history should contain {question}") + if "{last_answer}" not in cypher_gen_prompt_history: + warnings.warn("Cypher generation prompt with history should contain {last_answer}", category=UserWarning) + + # Assign the validated values + self.cypher_system_instruction = cypher_system_instruction + self.qa_system_instruction = qa_system_instruction + self.cypher_gen_prompt = cypher_gen_prompt + self.qa_prompt = qa_prompt + self.cypher_gen_prompt_history = cypher_gen_prompt_history + # Attributes @property @@ -135,8 +182,9 @@ def delete(self) -> None: setattr(self, key, None) def chat_session(self) -> ChatSession: - return ChatSession(self._model_config, self.ontology, self.graph) - + chat_session = ChatSession(self._model_config, self.ontology, self.graph, self.cypher_system_instruction, + self.qa_system_instruction, self.cypher_gen_prompt, self.qa_prompt, self.cypher_gen_prompt_history) + return chat_session def add_node(self, entity: str, attributes: dict): """ Add a node to the knowledge graph, checking if it matches the ontology diff --git a/graphrag_sdk/steps/graph_query_step.py b/graphrag_sdk/steps/graph_query_step.py index 96d8de79..a9ac8bb2 100644 --- a/graphrag_sdk/steps/graph_query_step.py +++ b/graphrag_sdk/steps/graph_query_step.py @@ -1,21 +1,15 @@ +import logging +from falkordb import Graph from graphrag_sdk.steps.Step import Step from graphrag_sdk.ontology import Ontology from graphrag_sdk.models import ( GenerativeModelChatSession, ) -from graphrag_sdk.fixtures.prompts import ( - CYPHER_GEN_SYSTEM, - CYPHER_GEN_PROMPT, - CYPHER_GEN_PROMPT_WITH_ERROR, - CYPHER_GEN_PROMPT_WITH_HISTORY, -) -import logging from graphrag_sdk.helpers import ( extract_cypher, validate_cypher, stringify_falkordb_response, ) -from falkordb import Graph logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -33,22 +27,26 @@ def __init__( chat_session: GenerativeModelChatSession, config: dict = None, last_answer: str = None, + cypher_prompt: str = None, + cypher_prompt_with_history: str = None, ) -> None: self.ontology = ontology self.config = config or {} self.graph = graph self.chat_session = chat_session self.last_answer = last_answer + self.cypher_prompt = cypher_prompt + self.cypher_prompt_with_history = cypher_prompt_with_history def run(self, question: str, retries: int = 10): cypher = "" for i in range(retries): try: cypher_prompt = ( - (CYPHER_GEN_PROMPT.format(question=question) + (self.cypher_prompt.format(question=question) if self.last_answer is None - else CYPHER_GEN_PROMPT_WITH_HISTORY.format(question=question, last_answer=self.last_answer)) - ) + else self.cypher_prompt_with_history.format(question=question, last_answer=self.last_answer)) + ) logger.debug(f"Cypher Prompt: {cypher_prompt}") cypher_statement_response = self.chat_session.send_message( cypher_prompt, diff --git a/graphrag_sdk/steps/qa_step.py b/graphrag_sdk/steps/qa_step.py index 45e4624b..286cfd8b 100644 --- a/graphrag_sdk/steps/qa_step.py +++ b/graphrag_sdk/steps/qa_step.py @@ -1,9 +1,7 @@ +import logging from graphrag_sdk.steps.Step import Step from graphrag_sdk.models import GenerativeModelChatSession -from graphrag_sdk.fixtures.prompts import GRAPH_QA_SYSTEM, GRAPH_QA_PROMPT -import logging - logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -17,17 +15,18 @@ def __init__( self, chat_session: GenerativeModelChatSession, config: dict = None, + qa_prompt: str = None, ) -> None: self.config = config or {} self.chat_session = chat_session + self.qa_prompt = qa_prompt def run(self, question: str, cypher: str, context: str): - - qa_prompt = GRAPH_QA_PROMPT.format( + qa_prompt = self.qa_prompt.format( context=context, cypher=cypher, question=question ) - # logger.debug(f"QA Prompt: {qa_prompt}") + logger.debug(f"QA Prompt: {qa_prompt}") qa_response = self.chat_session.send_message(qa_prompt) return qa_response.text diff --git a/tests/test_kg_gemini.py b/tests/test_kg_gemini.py index 8778fe91..cbfdf2b1 100644 --- a/tests/test_kg_gemini.py +++ b/tests/test_kg_gemini.py @@ -91,6 +91,7 @@ def test_kg_creation(self): chat = self.kg.chat_session() answer = chat.send_message("How many actors acted in a movie?") + answer = answer['response'] logger.info(f"Answer: {answer}") diff --git a/tests/test_kg_ollama.py b/tests/test_kg_ollama.py index 52429e27..06ec03ee 100644 --- a/tests/test_kg_ollama.py +++ b/tests/test_kg_ollama.py @@ -91,6 +91,7 @@ def test_kg_creation(self): chat = self.kg.chat_session() answer = chat.send_message("How many actors acted in a movie?") + answer = answer['response'] logger.info(f"Answer: {answer}") diff --git a/tests/test_kg_openai.py b/tests/test_kg_openai.py index d7b25c4f..785a28df 100644 --- a/tests/test_kg_openai.py +++ b/tests/test_kg_openai.py @@ -85,6 +85,7 @@ def test_kg_creation(self): chat = self.kg.chat_session() answer = chat.send_message("How many actors acted in a movie?") + answer = answer['response'] logger.info(f"Answer: {answer}") diff --git a/tests/test_multi_agent.py b/tests/test_multi_agent.py index c445a382..02ae5978 100644 --- a/tests/test_multi_agent.py +++ b/tests/test_multi_agent.py @@ -1,25 +1,20 @@ +import logging +import unittest +from json import loads from dotenv import load_dotenv - -load_dotenv() -from graphrag_sdk.ontology import Ontology from graphrag_sdk.entity import Entity from graphrag_sdk.relation import Relation +from graphrag_sdk.ontology import Ontology +from graphrag_sdk.agents.kg_agent import KGAgent +from graphrag_sdk.orchestrator import Orchestrator from graphrag_sdk.attribute import Attribute, AttributeType -import unittest from graphrag_sdk.models.gemini import GeminiGenerativeModel from graphrag_sdk import KnowledgeGraph, KnowledgeGraphModelConfig -from graphrag_sdk.orchestrator import Orchestrator -from graphrag_sdk.agents.kg_agent import KGAgent -import vertexai -import os -import logging -from json import loads +load_dotenv() logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) -vertexai.init(project=os.getenv("PROJECT_ID"), location=os.getenv("REGION")) - class TestMultiAgent(unittest.TestCase):