Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions examples/generate/generate_masked_fill_in_blank_qa/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Generate Masked Fill-in-blank QAs
In this module, we generate fill-in-blank QAs from unstructured corpora by randomly masking core entities in a knowledge graph. The key is that a rule-based validator can automatically verify the answers to these questions. For example:
> **Question:** Hematogenous long-bone osteomyelitis is an infection of the bone, primarily affecting the long bones, and often results from blood-borne pathogens. This condition is characterized by several key symptoms, including ___ and swelling. ___ is a prominent symptom in both primary and recurrent cases of hematogenous long-bone osteomyelitis, manifesting as persistent discomfort in the affected area.
> **Answer:** pain

Because the answers to these questions can be easily verified, they are well-suited for RLVR (Reinforcement Learning with Verifiable Rewards).

For more details, please see our paper "Knowledge-to-Verification: Exploring RLVR for LLMs in Knowledge-Intensive Domains". It has been accepted to the ACL 2026 Main Conference, and we will update the link soon.
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Run the GraphGen pipeline with the masked fill-in-blank example configuration.
python3 -m graphgen.run \
--config_file examples/generate/generate_masked_fill_in_blank_qa/masked_fill_in_blank_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
global_params:
working_dir: cache
graph_backend: networkx # graph database backend, support: kuzu, networkx
kv_backend: json_kv # key-value store backend, support: rocksdb, json_kv

nodes:
- id: read_files # id is unique in the pipeline, and can be referenced by other steps
op_name: read
type: source
dependencies: []
params:
input_path:
- examples/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See examples/input_examples for examples

- id: chunk_documents
op_name: chunk
type: map_batch
dependencies:
- read_files
execution_params:
replicas: 4
params:
chunk_size: 1024 # chunk size for text splitting
chunk_overlap: 100 # chunk overlap for text splitting

- id: build_kg
op_name: build_kg
type: map_batch
dependencies:
- chunk_documents
execution_params:
replicas: 1
batch_size: 128

- id: partition
op_name: partition
type: aggregate
dependencies:
- build_kg
params:
method: quintuple

- id: generate
op_name: generate
type: map_batch
dependencies:
- partition
execution_params:
replicas: 1
batch_size: 128
save_output: true # save output
params:
method: masked_fill_in_blank # supported: atomic, aggregated, multi_hop, cot, vqa, masked_fill_in_blank
data_format: QA_pairs # Alpaca, Sharegpt, ChatML, QA_pairs
6 changes: 6 additions & 0 deletions graphgen/bases/base_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,10 @@ def format_generation_results(
{"role": "assistant", "content": answer},
]
}

if output_data_format == "QA_pairs":
return {
"question": question,
"answer": answer,
}
raise ValueError(f"Unknown output data format: {output_data_format}")
6 changes: 6 additions & 0 deletions graphgen/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
AtomicGenerator,
CoTGenerator,
FillInBlankGenerator,
MaskedFillInBlankGenerator,
MultiAnswerGenerator,
MultiChoiceGenerator,
MultiHopGenerator,
Expand All @@ -30,6 +31,8 @@
DFSPartitioner,
ECEPartitioner,
LeidenPartitioner,
QuintuplePartitioner,
TriplePartitioner,
)
from .reader import (
CSVReader,
Expand Down Expand Up @@ -73,6 +76,7 @@
"QuizGenerator": ".generator",
"TrueFalseGenerator": ".generator",
"VQAGenerator": ".generator",
"MaskedFillInBlankGenerator": ".generator",
# KG Builder
"LightRAGKGBuilder": ".kg_builder",
"MMKGBuilder": ".kg_builder",
Expand All @@ -86,6 +90,8 @@
"DFSPartitioner": ".partitioner",
"ECEPartitioner": ".partitioner",
"LeidenPartitioner": ".partitioner",
"TriplePartitioner": ".partitioner",
"QuintuplePartitioner": ".partitioner",
# Reader
"CSVReader": ".reader",
"JSONReader": ".reader",
Expand Down
1 change: 1 addition & 0 deletions graphgen/models/generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
from .quiz_generator import QuizGenerator
from .true_false_generator import TrueFalseGenerator
from .vqa_generator import VQAGenerator
from .masked_fill_in_blank_generator import MaskedFillInBlankGenerator
134 changes: 134 additions & 0 deletions graphgen/models/generator/masked_fill_in_blank_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import random
import re
from typing import Any, Optional

from graphgen.bases import BaseGenerator
from graphgen.templates import AGGREGATED_GENERATION_PROMPT
from graphgen.utils import detect_main_language, logger

# NOTE(review): seeding the *global* `random` module affects the entire
# process, not just this module; a local `random.Random` instance would be
# safer — TODO confirm nothing else relies on this global seed before removing.
random.seed(42)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Setting a global random seed with random.seed(42) is generally discouraged as it affects the entire application's random number generation, which can lead to unexpected behavior in other parts of the code. For reproducibility, it's better to create a local random.Random instance within your class, for example in the __init__ method, and use that for random operations like random.choice on line 103.



class MaskedFillInBlankGenerator(BaseGenerator):
    """
    Masked fill-in-blank generator following a TWO-STEP process:

    1. Rephrase: rephrase the input nodes and edges into a coherent text
       that preserves the original meaning.
    2. Mask: randomly select one of the input nodes and replace every
       occurrence of its name in the rephrased text with "___".

    The masked text becomes the question and the masked entity name the
    answer, so the resulting QA pair can be verified by string matching.
    """

    # Dedicated RNG so mask-node selection is reproducible without mutating
    # the global `random` module state shared by the whole application.
    _rng = random.Random(42)

    @staticmethod
    def build_prompt(
        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
    ) -> str:
        """
        Build the REPHRASE prompt from a batch of nodes and edges.

        :param batch: (nodes, edges) where each node is (name, attrs) and each
            edge is (src, dst, attrs); attrs must contain a "description" key.
        :return: the formatted rephrasing prompt.
        """
        nodes, edges = batch
        entities_str = "\n".join(
            f"{index + 1}. {node[0]}: {node[1]['description']}"
            for index, node in enumerate(nodes)
        )
        relations_str = "\n".join(
            f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
            for index, edge in enumerate(edges)
        )
        # Choose the prompt language from the dominant language of the input.
        language = detect_main_language(entities_str + relations_str)

        # TODO(add_context): optionally resolve the original source chunks via
        # each node/edge "source_id" and prepend them to the prompt for
        # additional grounding.

        return AGGREGATED_GENERATION_PROMPT[language]["ANSWER_REPHRASING"].format(
            entities=entities_str, relationships=relations_str
        )

    @staticmethod
    def parse_rephrased_text(response: str) -> Optional[str]:
        """
        Extract the rephrased text from an LLM response.

        :param response: raw response expected to contain a
            <rephrased_text>...</rephrased_text> section.
        :return: the stripped rephrased text, or None if it cannot be parsed.
        """
        rephrased_match = re.search(
            r"<rephrased_text>(.*?)</rephrased_text>", response, re.DOTALL
        )
        if not rephrased_match:
            logger.warning("Failed to parse rephrased text from response: %s", response)
            return None
        return rephrased_match.group(1).strip().strip('"').strip("'")

    @staticmethod
    def parse_response(response: str) -> list[dict]:
        """Unused: this generator overrides `generate` directly."""
        raise NotImplementedError(
            "MaskedFillInBlankGenerator overrides `generate` and does not use "
            "`parse_response`."
        )

    async def generate(
        self,
        batch: tuple[
            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
        ],
    ) -> list[dict]:
        """
        Generate a masked fill-in-blank QA pair from a partitioned batch.

        :param batch: (nodes, edges) quintuple with exactly 3 nodes and 2 edges.
        :return: a single-element list of {"question", "answer"}, or [] when
            rephrasing fails or the masked entity is absent from the text.
        :raises ValueError: if the batch is not a 3-node / 2-edge quintuple.
        """
        rephrasing_prompt = self.build_prompt(batch)
        response = await self.llm_client.generate_answer(rephrasing_prompt)
        context = self.parse_rephrased_text(response)
        if not context:
            return []

        nodes, edges = batch

        # Validate with explicit exceptions rather than `assert`, which is
        # stripped when Python runs with -O.
        if len(nodes) != 3:
            raise ValueError(
                "MaskedFillInBlankGenerator currently only supports quintuples "
                f"that have 3 nodes, but got {len(nodes)} nodes."
            )
        if len(edges) != 2:
            raise ValueError(
                "MaskedFillInBlankGenerator currently only supports quintuples "
                f"that have 2 edges, but got {len(edges)} edges."
            )

        mask_node = self._rng.choice(nodes)
        mask_node_name = mask_node[1]["entity_name"].strip("'\" \n\r\t")
        mask_pattern = re.compile(re.escape(mask_node_name), re.IGNORECASE)

        match = mask_pattern.search(context)
        if not match:
            logger.debug(
                "Regex Match Failed!\n"
                "Expected name of node: %s\n"
                "Actual context: %s\n",
                mask_node_name,
                context,
            )
            return []

        # Ground-truth answer preserves the casing actually used in the text.
        gth = match.group(0)
        masked_context = mask_pattern.sub("___", context)

        logger.debug("masked_context: %s", masked_context)
        return [{"question": masked_context, "answer": gth}]
2 changes: 2 additions & 0 deletions graphgen/models/partitioner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
from .dfs_partitioner import DFSPartitioner
from .ece_partitioner import ECEPartitioner
from .leiden_partitioner import LeidenPartitioner
from .quintuple_partitioner import QuintuplePartitioner
from .triple_partitioner import TriplePartitioner
74 changes: 74 additions & 0 deletions graphgen/models/partitioner/quintuple_partitioner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import random
from collections import deque
from typing import Any, Iterable, Set

from graphgen.bases import BaseGraphStorage, BasePartitioner
from graphgen.bases.datatypes import Community

# NOTE(review): seeding the *global* `random` module affects the entire
# process; a local `random.Random` instance would be safer — TODO confirm
# nothing else relies on this global seed before removing.
random.seed(42)


class QuintuplePartitioner(BasePartitioner):
    """
    Quintuple partitioner that splits the graph into distinct quintuples
    (node, edge, node, edge, node).

    1. Isolated nodes are ignored automatically (they have no usable edges).
    2. Within each connected component, quintuples are yielded in BFS order.

    Each edge is consumed at most once: two yet-unused edges sharing a
    center node form one quintuple.
    """

    def partition(
        self,
        g: BaseGraphStorage,
        **kwargs: Any,
    ) -> Iterable[Community]:
        """
        Yield quintuple communities from *g*.

        :param g: graph storage exposing get_all_nodes() / get_neighbors().
        :return: iterable of Community objects, each with 3 nodes and 2 edges.
        """
        # Local RNG keeps shuffling reproducible without mutating the global
        # `random` state; every call produces the same partitioning for the
        # same graph.
        rng = random.Random(42)

        nodes = [n[0] for n in g.get_all_nodes()]
        rng.shuffle(nodes)

        visited_nodes: Set[str] = set()
        used_edges: Set[frozenset[str]] = set()

        for seed in nodes:
            if seed in visited_nodes:
                continue

            # BFS over the connected component containing `seed`.
            queue = deque([seed])
            visited_nodes.add(seed)

            while queue:
                center = queue.popleft()

                # Neighbors reachable from `center` via edges not yet consumed.
                available_neighbors = []
                for neighbor in g.get_neighbors(center):
                    if frozenset((center, neighbor)) not in used_edges:
                        available_neighbors.append(neighbor)

                    # Standard BFS bookkeeping (independent of edge usage).
                    if neighbor not in visited_nodes:
                        visited_nodes.add(neighbor)
                        queue.append(neighbor)

                rng.shuffle(available_neighbors)

                # Pair neighbors two at a time around the center node.
                # With an odd count, the leftover edge stays unused for now;
                # it may still be matched later when its other endpoint is
                # processed as a center node.
                for i in range(0, len(available_neighbors) - 1, 2):
                    v1 = available_neighbors[i]
                    v2 = available_neighbors[i + 1]

                    used_edges.add(frozenset((center, v1)))
                    used_edges.add(frozenset((center, v2)))

                    # Sort endpoints so the community ID is canonical.
                    v1_s, v2_s = sorted((v1, v2))

                    yield Community(
                        id=f"{v1_s}-{center}-{v2_s}",
                        nodes=[v1_s, center, v2_s],
                        edges=[
                            tuple(sorted((v1_s, center))),
                            tuple(sorted((center, v2_s))),
                        ],
                    )
58 changes: 58 additions & 0 deletions graphgen/models/partitioner/triple_partitioner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import random
from collections import deque
from typing import Any, Iterable, Set

from graphgen.bases import BaseGraphStorage, BasePartitioner
from graphgen.bases.datatypes import Community

# NOTE(review): seeding the *global* `random` module affects the entire
# process; a local `random.Random` instance would be safer — TODO confirm
# nothing else relies on this global seed before removing.
random.seed(42)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Setting a global random seed with random.seed(42) is generally discouraged as it affects the entire application's random number generation. This can lead to unexpected behavior in other parts of the code. For reproducibility, it's better to create a local random.Random instance within your class, for example in the __init__ method, and use that for random operations like random.shuffle.



class TriplePartitioner(BasePartitioner):
    """
    Triple partitioner that splits the graph into distinct triples
    (node, edge, node).

    1. Isolated nodes are ignored automatically (they have no edges).
    2. Within each connected component, triples are yielded in BFS order.

    Each distinct edge yields exactly one triple.
    """

    def partition(
        self,
        g: BaseGraphStorage,
        **kwargs: Any,
    ) -> Iterable[Community]:
        """
        Yield one Community per distinct edge of *g*.

        :param g: graph storage exposing get_all_nodes() / get_neighbors().
        :return: iterable of Community objects, each with 2 nodes and 1 edge.
        """
        # Local RNG keeps shuffling reproducible without mutating the global
        # `random` state; every call produces the same partitioning for the
        # same graph.
        rng = random.Random(42)

        nodes = [n[0] for n in g.get_all_nodes()]
        rng.shuffle(nodes)

        visited_nodes: Set[str] = set()
        used_edges: Set[frozenset[str]] = set()

        for seed in nodes:
            if seed in visited_nodes:
                continue

            # BFS over the connected component containing `seed`.
            queue = deque([seed])
            visited_nodes.add(seed)

            while queue:
                center = queue.popleft()

                for neighbor in g.get_neighbors(center):
                    edge_key = frozenset((center, neighbor))

                    # Each unused edge produces exactly one new triple.
                    if edge_key not in used_edges:
                        used_edges.add(edge_key)

                        # Sort endpoints so the community ID is canonical.
                        u_sorted, v_sorted = sorted((center, neighbor))
                        yield Community(
                            id=f"{u_sorted}-{v_sorted}",
                            nodes=[u_sorted, v_sorted],
                            edges=[(u_sorted, v_sorted)],
                        )

                    # Continue the BFS.
                    if neighbor not in visited_nodes:
                        visited_nodes.add(neighbor)
                        queue.append(neighbor)
Loading
Loading