Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/generate/generate_masked_fill_in_blank_qa/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Generate Masked Fill-in-blank QAs

This example builds a knowledge graph from the input documents, partitions it into
(node, edge, node) triples, rephrases each triple into a short passage, and then masks
one entity name in the passage to produce a fill-in-the-blank question whose answer is
the masked entity. Configuration lives in `masked_fill_in_blank_config.yaml`
(`partition.method: triple`, `generate.method: masked_fill_in_blank`). Run it with:

```bash
python3 -m graphgen.run \
  --config_file examples/generate/generate_masked_fill_in_blank_qa/masked_fill_in_blank_config.yaml
```
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The README file currently only contains "TODO". To make this example useful for other developers and users, please add a brief description of what this feature does, what the configuration options mean, and how to run it.

Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
python3 -m graphgen.run \
--config_file examples/generate/generate_masked_fill_in_blank_qa/masked_fill_in_blank_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
global_params:
  working_dir: cache
  graph_backend: networkx # graph database backend, support: kuzu, networkx
  kv_backend: json_kv # key-value store backend, support: rocksdb, json_kv

nodes:
  - id: read_files # id is unique in the pipeline, and can be referenced by other steps
    op_name: read
    type: source
    dependencies: []
    params:
      input_path:
        - examples/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See examples/input_examples for examples

  - id: chunk_documents
    op_name: chunk
    type: map_batch
    dependencies:
      - read_files
    execution_params:
      replicas: 4
    params:
      chunk_size: 1024 # chunk size for text splitting
      chunk_overlap: 100 # chunk overlap for text splitting

  - id: build_kg
    op_name: build_kg
    type: map_batch
    dependencies:
      - chunk_documents
    execution_params:
      replicas: 1
      batch_size: 128

  - id: partition
    op_name: partition
    type: aggregate
    dependencies:
      - build_kg
    params:
      method: triple # yields one (node, edge, node) triple per community

  - id: generate
    op_name: generate
    type: map_batch
    dependencies:
      - partition
    execution_params:
      replicas: 1
      batch_size: 128
      save_output: true # save output
    params:
      method: masked_fill_in_blank # atomic, aggregated, multi_hop, cot, vqa, masked_fill_in_blank
      data_format: QA_pairs # Alpaca, Sharegpt, ChatML, QA_pairs
6 changes: 6 additions & 0 deletions graphgen/bases/base_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,10 @@ def format_generation_results(
{"role": "assistant", "content": answer},
]
}

if output_data_format == "QA_pairs":
return {
"question": question,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There is a trailing whitespace after question,. Please remove it to maintain code style consistency.

Suggested change
"question": question,
"question": question,

"answer": answer,
}
raise ValueError(f"Unknown output data format: {output_data_format}")
4 changes: 4 additions & 0 deletions graphgen/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
QuizGenerator,
TrueFalseGenerator,
VQAGenerator,
MaskedFillInBlankGenerator,
)
from .kg_builder import LightRAGKGBuilder, MMKGBuilder
from .llm import HTTPClient, OllamaClient, OpenAIClient
Expand All @@ -30,6 +31,7 @@
DFSPartitioner,
ECEPartitioner,
LeidenPartitioner,
TriplePartitioner,
)
from .reader import (
CSVReader,
Expand Down Expand Up @@ -71,6 +73,7 @@
"QuizGenerator": ".generator",
"TrueFalseGenerator": ".generator",
"VQAGenerator": ".generator",
"MaskedFillInBlankGenerator": ".generator",
# KG Builder
"LightRAGKGBuilder": ".kg_builder",
"MMKGBuilder": ".kg_builder",
Expand All @@ -84,6 +87,7 @@
"DFSPartitioner": ".partitioner",
"ECEPartitioner": ".partitioner",
"LeidenPartitioner": ".partitioner",
"TriplePartitioner": ".partitioner",
# Reader
"CSVReader": ".reader",
"JSONReader": ".reader",
Expand Down
1 change: 1 addition & 0 deletions graphgen/models/generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
from .quiz_generator import QuizGenerator
from .true_false_generator import TrueFalseGenerator
from .vqa_generator import VQAGenerator
from .masked_fill_in_blank_generator import MaskedFillInBlankGenerator
121 changes: 121 additions & 0 deletions graphgen/models/generator/masked_fill_in_blank_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import random
import re
from typing import Any, Optional

from graphgen.bases import BaseGenerator
from graphgen.templates import AGGREGATED_GENERATION_PROMPT
from graphgen.utils import detect_main_language, logger


class MaskedFillInBlankGenerator(BaseGenerator):
    """
    Masked Fill-in-blank Generator follows a TWO-STEP process:
    1. rephrase: Rephrase the input nodes and edges into a coherent text that maintains the original meaning.
    2. mask: Randomly select a node from the input nodes, and then mask the name of the node in the rephrased text.
    """

    # Dedicated RNG so mask selection is reproducible without mutating the
    # process-wide `random` module state shared by the rest of the application.
    _rng = random.Random(42)

    @staticmethod
    def build_prompt(
        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
    ) -> str:
        """
        Build the REPHRASE prompt for one batch.

        :param batch: (nodes, edges) where each node is (name, attrs) and
            each edge is (src, dst, attrs); both attr dicts must provide
            a "description" entry.
        :return: rephrasing prompt in the language detected from the
            combined entity/relation descriptions.
        """
        nodes, edges = batch
        entities_str = "\n".join(
            f"{index + 1}. {node[0]}: {node[1]['description']}"
            for index, node in enumerate(nodes)
        )
        relations_str = "\n".join(
            f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
            for index, edge in enumerate(edges)
        )
        language = detect_main_language(entities_str + relations_str)

        # TODO(masked-fill-in-blank): add an `add_context` option that fetches
        # the original source chunks (via their source_id) and injects them
        # into the prompt for extra grounding.
        prompt = AGGREGATED_GENERATION_PROMPT[language]["ANSWER_REPHRASING"].format(
            entities=entities_str, relationships=relations_str
        )
        return prompt

    @staticmethod
    def parse_rephrased_text(response: str) -> Optional[str]:
        """
        Parse the rephrased text from the LLM response.

        :param response: raw LLM output expected to contain a
            <rephrased_text>...</rephrased_text> tag.
        :return: the rephrased text, or None when the tag is missing.
        """
        rephrased_match = re.search(
            r"<rephrased_text>(.*?)</rephrased_text>", response, re.DOTALL
        )
        if not rephrased_match:
            logger.warning("Failed to parse rephrased text from response: %s", response)
            return None
        rephrased_text = rephrased_match.group(1).strip()
        return rephrased_text.strip('"').strip("'")

    @staticmethod
    def parse_response(response: str) -> list[dict]:
        """Not used: this generator overrides `generate` end-to-end."""
        raise NotImplementedError(
            "MaskedFillInBlankGenerator overrides `generate` and does not use "
            "`parse_response`."
        )

    async def generate(
        self,
        batch: tuple[
            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
        ],
    ) -> list[dict]:
        """
        Generate masked fill-in-blank QAs for a single triple batch.

        :param batch: (nodes, edges) — exactly 2 nodes and 1 edge.
        :return: a list with one {"question", "answer"} pair, or [] when
            rephrasing or masking fails.
        """
        rephrasing_prompt = self.build_prompt(batch)
        response = await self.llm_client.generate_answer(rephrasing_prompt)
        context = self.parse_rephrased_text(response)
        if not context:
            return []

        nodes, edges = batch
        assert (
            len(nodes) == 2
        ), "MaskedFillInBlankGenerator currently only supports triples, which should have 2 nodes."
        assert (
            len(edges) == 1
        ), "MaskedFillInBlankGenerator currently only supports triples, which should have 1 edge."

        mask_node = self._rng.choice(nodes)
        mask_node_name = mask_node[1]["entity_name"].strip("'\" \n\r\t")

        mask_pattern = re.compile(re.escape(mask_node_name), re.IGNORECASE)
        # For accuracy, extract the actual matched text from the context as the
        # ground truth (the LLM may use different casing than the entity name).
        match = mask_pattern.search(context)
        if not match:
            logger.warning(
                "Mask entity %r not found in rephrased context: %s",
                mask_node_name,
                context,
            )
            return []
        gth = match.group(0)
        masked_context = mask_pattern.sub("___", context)

        logger.debug("masked_context: %s", masked_context)
        qa_pairs = {
            "question": masked_context,
            "answer": gth,
        }
        return [qa_pairs]
1 change: 1 addition & 0 deletions graphgen/models/partitioner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .dfs_partitioner import DFSPartitioner
from .ece_partitioner import ECEPartitioner
from .leiden_partitioner import LeidenPartitioner
from .triple_partitioner import TriplePartitioner
58 changes: 58 additions & 0 deletions graphgen/models/partitioner/triple_partitioner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import random
from collections import deque
from typing import Any, Iterable, Set

from graphgen.bases import BaseGraphStorage, BasePartitioner
from graphgen.bases.datatypes import Community


class TriplePartitioner(BasePartitioner):
    """
    Triple Partitioner that partitions the graph into multiple distinct triples (node, edge, node).
    1. Automatically ignore isolated points.
    2. In each connected component, yield triples in the order of BFS.
    """

    # Dedicated RNG so seed shuffling is reproducible without mutating the
    # process-wide `random` module state shared by the rest of the application.
    _rng = random.Random(42)

    def partition(
        self,
        g: BaseGraphStorage,
        **kwargs: Any,
    ) -> Iterable[Community]:
        """
        Yield one Community per undirected edge, discovered by BFS starting
        from randomly shuffled seed nodes.

        :param g: graph storage exposing get_all_nodes() and get_neighbors().
        :return: lazily generated triple Communities (2 nodes, 1 edge each);
            isolated nodes produce no output.
        """
        nodes = [n[0] for n in g.get_all_nodes()]
        self._rng.shuffle(nodes)

        visited_nodes: Set[str] = set()
        used_edges: Set[frozenset[str]] = set()

        for seed in nodes:
            if seed in visited_nodes:
                continue

            # start BFS in a connected component
            queue = deque([seed])
            visited_nodes.add(seed)

            while queue:
                u = queue.popleft()

                for v in g.get_neighbors(u):
                    edge_key = frozenset((u, v))

                    # if this edge has not been used, a new triple has been found
                    if edge_key not in used_edges:
                        used_edges.add(edge_key)

                        # sort endpoints so the community ID is deterministic
                        # regardless of traversal direction
                        u_sorted, v_sorted = sorted((u, v))
                        yield Community(
                            id=f"{u_sorted}-{v_sorted}",
                            nodes=[u_sorted, v_sorted],
                            edges=[(u_sorted, v_sorted)],
                        )

                    # continue to BFS
                    if v not in visited_nodes:
                        visited_nodes.add(v)
                        queue.append(v)
4 changes: 4 additions & 0 deletions graphgen/operators/generate/generate_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ def __init__(
self.llm_client,
num_of_questions=generate_kwargs.get("num_of_questions", 5),
)
elif self.method == "masked_fill_in_blank":
from graphgen.models import MaskedFillInBlankGenerator

self.generator = MaskedFillInBlankGenerator(self.llm_client)
elif self.method == "true_false":
from graphgen.models import TrueFalseGenerator

Expand Down
6 changes: 5 additions & 1 deletion graphgen/operators/partition/partition_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(

self.tokenizer_instance: BaseTokenizer = Tokenizer(model_name=tokenizer_model)
method = partition_kwargs["method"]
self.method_params = partition_kwargs["method_params"]
self.method_params = partition_kwargs.get("method_params", {})

if method == "bfs":
from graphgen.models import BFSPartitioner
Expand Down Expand Up @@ -57,6 +57,10 @@ def __init__(
if self.method_params.get("anchor_ids")
else None,
)
elif method == "triple":
from graphgen.models import TriplePartitioner

self.partitioner = TriplePartitioner()
else:
raise ValueError(f"Unsupported partition method: {method}")

Expand Down