diff --git a/examples/generate/generate_masked_fill_in_blank_qa/README.md b/examples/generate/generate_masked_fill_in_blank_qa/README.md
new file mode 100644
index 00000000..3251d5b9
--- /dev/null
+++ b/examples/generate/generate_masked_fill_in_blank_qa/README.md
@@ -0,0 +1,2 @@
+# Generate Masked Fill-in-blank QAs
+TODO: add documentation for this example.
diff --git a/examples/generate/generate_masked_fill_in_blank_qa/generate_masked_fill_in_blank.sh b/examples/generate/generate_masked_fill_in_blank_qa/generate_masked_fill_in_blank.sh
new file mode 100644
index 00000000..c974bffa
--- /dev/null
+++ b/examples/generate/generate_masked_fill_in_blank_qa/generate_masked_fill_in_blank.sh
@@ -0,0 +1,2 @@
+python3 -m graphgen.run \
+--config_file examples/generate/generate_masked_fill_in_blank_qa/masked_fill_in_blank_config.yaml
diff --git a/examples/generate/generate_masked_fill_in_blank_qa/masked_fill_in_blank_config.yaml b/examples/generate/generate_masked_fill_in_blank_qa/masked_fill_in_blank_config.yaml
new file mode 100644
index 00000000..0054c6c8
--- /dev/null
+++ b/examples/generate/generate_masked_fill_in_blank_qa/masked_fill_in_blank_config.yaml
@@ -0,0 +1,54 @@
+global_params:
+  working_dir: cache
+  graph_backend: networkx # graph database backend, support: kuzu, networkx
+  kv_backend: json_kv # key-value store backend, support: rocksdb, json_kv
+
+nodes:
+  - id: read_files # id is unique in the pipeline, and can be referenced by other steps
+    op_name: read
+    type: source
+    dependencies: []
+    params:
+      input_path:
+        - examples/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See examples/input_examples for examples
+
+  - id: chunk_documents
+    op_name: chunk
+    type: map_batch
+    dependencies:
+      - read_files
+    execution_params:
+      replicas: 4
+    params:
+      chunk_size: 1024 # chunk size for text splitting
+      chunk_overlap: 100 # chunk overlap for text splitting
+
+  - id: build_kg
+    op_name: build_kg
+    type: map_batch
+    dependencies:
+      - chunk_documents
+    execution_params:
+      replicas: 1
+      batch_size: 128
+
+  - id: partition
+    op_name: partition
+    type: aggregate
+    dependencies:
+      - build_kg
+    params:
+      method: triple
+
+  - id: generate
+    op_name: generate
+    type: map_batch
+    dependencies:
+      - partition
+    execution_params:
+      replicas: 1
+      batch_size: 128
+      save_output: true # save output
+    params:
+      method: masked_fill_in_blank # atomic, aggregated, multi_hop, cot, vqa, masked_fill_in_blank
+      data_format: QA_pairs # Alpaca, Sharegpt, ChatML, QA_pairs
diff --git a/graphgen/bases/base_generator.py b/graphgen/bases/base_generator.py
index eb204535..b83be604 100644
--- a/graphgen/bases/base_generator.py
+++ b/graphgen/bases/base_generator.py
@@ -74,4 +74,10 @@ def format_generation_results(
                 {"role": "assistant", "content": answer},
             ]
         }
+
+        if output_data_format == "QA_pairs":
+            return {
+                "question": question,
+                "answer": answer,
+            }
         raise ValueError(f"Unknown output data format: {output_data_format}")
diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py
index 2381d9b1..4c94b467 100644
--- a/graphgen/models/__init__.py
+++ b/graphgen/models/__init__.py
@@ -21,6 +21,7 @@
     QuizGenerator,
     TrueFalseGenerator,
     VQAGenerator,
+    MaskedFillInBlankGenerator,
 )
 from .kg_builder import LightRAGKGBuilder, MMKGBuilder
 from .llm import HTTPClient, OllamaClient, OpenAIClient
@@ -30,6 +31,7 @@
     DFSPartitioner,
     ECEPartitioner,
     LeidenPartitioner,
+    TriplePartitioner,
 )
 from .reader import (
     CSVReader,
@@ -71,6 +73,7 @@
     "QuizGenerator": ".generator",
     "TrueFalseGenerator": ".generator",
     "VQAGenerator": ".generator",
+    "MaskedFillInBlankGenerator": ".generator",
     # KG Builder
"LightRAGKGBuilder": ".kg_builder", "MMKGBuilder": ".kg_builder", @@ -84,6 +87,7 @@ "DFSPartitioner": ".partitioner", "ECEPartitioner": ".partitioner", "LeidenPartitioner": ".partitioner", + "TriplePartitioner": ".partitioner", # Reader "CSVReader": ".reader", "JSONReader": ".reader", diff --git a/graphgen/models/generator/__init__.py b/graphgen/models/generator/__init__.py index 8562c34b..6fd25629 100644 --- a/graphgen/models/generator/__init__.py +++ b/graphgen/models/generator/__init__.py @@ -8,3 +8,4 @@ from .quiz_generator import QuizGenerator from .true_false_generator import TrueFalseGenerator from .vqa_generator import VQAGenerator +from .masked_fill_in_blank_generator import MaskedFillInBlankGenerator diff --git a/graphgen/models/generator/masked_fill_in_blank_generator.py b/graphgen/models/generator/masked_fill_in_blank_generator.py new file mode 100644 index 00000000..6254c874 --- /dev/null +++ b/graphgen/models/generator/masked_fill_in_blank_generator.py @@ -0,0 +1,121 @@ +import random +import re +from typing import Any, Optional + +from graphgen.bases import BaseGenerator +from graphgen.templates import AGGREGATED_GENERATION_PROMPT +from graphgen.utils import detect_main_language, logger + +random.seed(42) + + +class MaskedFillInBlankGenerator(BaseGenerator): + """ + Masked Fill-in-blank Generator follows a TWO-STEP process: + 1. rephrase: Rephrase the input nodes and edges into a coherent text that maintains the original meaning. + 2. mask: Randomly select a node from the input nodes, and then mask the name of the node in the rephrased text. + """ + + @staticmethod + def build_prompt( + batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]] + ) -> str: + """ + Build prompts for REPHRASE. + :param batch + :return: + """ + nodes, edges = batch + entities_str = "\n".join( + [ + f"{index + 1}. {node[0]}: {node[1]['description']}" + for index, node in enumerate(nodes) + ] + ) + relations_str = "\n".join( + [ + f"{index + 1}. 
{edge[0]} -- {edge[1]}: {edge[2]['description']}" + for index, edge in enumerate(edges) + ] + ) + language = detect_main_language(entities_str + relations_str) + + # TODO: configure add_context + # if add_context: + # original_ids = [ + # node["source_id"].split("")[0] for node in _process_nodes + # ] + [edge[2]["source_id"].split("")[0] for edge in _process_edges] + # original_ids = list(set(original_ids)) + # original_text = await text_chunks_storage.get_by_ids(original_ids) + # original_text = "\n".join( + # [ + # f"{index + 1}. {text['content']}" + # for index, text in enumerate(original_text) + # ] + # ) + prompt = AGGREGATED_GENERATION_PROMPT[language]["ANSWER_REPHRASING"].format( + entities=entities_str, relationships=relations_str + ) + return prompt + + @staticmethod + def parse_rephrased_text(response: str) -> Optional[str]: + """ + Parse the rephrased text from the response. + :param response: + :return: rephrased text + """ + rephrased_match = re.search( + r"(.*?)", response, re.DOTALL + ) + if rephrased_match: + rephrased_text = rephrased_match.group(1).strip() + else: + logger.warning("Failed to parse rephrased text from response: %s", response) + return None + return rephrased_text.strip('"').strip("'") + + @staticmethod + def parse_response(response: str) -> dict: + pass + + async def generate( + self, + batch: tuple[ + list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]] + ], + ) -> list[dict]: + """ + Generate QAs based on a given batch. + :param batch + :return: QA pairs + """ + rephrasing_prompt = self.build_prompt(batch) + response = await self.llm_client.generate_answer(rephrasing_prompt) + context = self.parse_rephrased_text(response) + if not context: + return [] + + nodes, edge = batch + assert ( + len(nodes) == 2 + ), "MaskedFillInBlankGenerator currently only supports triples, which should has 2 nodes." 
+ assert ( + len(edge) == 1 + ), "MaskedFillInBlankGenerator currently only supports triples, which should has 1 edge." + + node1, node2 = nodes + mask_node = random.choice([node1, node2]) + mask_node_name = mask_node[1]["entity_name"].strip("'\" \n\r\t") + + mask_pattern = re.compile(re.escape(mask_node_name), re.IGNORECASE) + masked_context = mask_pattern.sub("___", context) + # For accuracy, extract the actual replaced text from the context as the ground truth + gth = re.search(mask_pattern, context).group(0) + + logger.debug("masked_context: %s", masked_context) + qa_pairs = { + "question": masked_context, + "answer": gth, + } + return [qa_pairs] diff --git a/graphgen/models/partitioner/__init__.py b/graphgen/models/partitioner/__init__.py index 2e1bcb68..b13dfc6d 100644 --- a/graphgen/models/partitioner/__init__.py +++ b/graphgen/models/partitioner/__init__.py @@ -3,3 +3,4 @@ from .dfs_partitioner import DFSPartitioner from .ece_partitioner import ECEPartitioner from .leiden_partitioner import LeidenPartitioner +from .triple_partitioner import TriplePartitioner diff --git a/graphgen/models/partitioner/triple_partitioner.py b/graphgen/models/partitioner/triple_partitioner.py new file mode 100644 index 00000000..2bdfe8d5 --- /dev/null +++ b/graphgen/models/partitioner/triple_partitioner.py @@ -0,0 +1,58 @@ +import random +from collections import deque +from typing import Any, Iterable, Set + +from graphgen.bases import BaseGraphStorage, BasePartitioner +from graphgen.bases.datatypes import Community + +random.seed(42) + + +class TriplePartitioner(BasePartitioner): + """ + Triple Partitioner that partitions the graph into multiple distinct triples (node, edge, node). + 1. Automatically ignore isolated points. + 2. In each connected component, yield triples in the order of BFS. 
+ """ + + def partition( + self, + g: BaseGraphStorage, + **kwargs: Any, + ) -> Iterable[Community]: + nodes = [n[0] for n in g.get_all_nodes()] + random.shuffle(nodes) + + visited_nodes: Set[str] = set() + used_edges: Set[frozenset[str]] = set() + + for seed in nodes: + if seed in visited_nodes: + continue + + # start BFS in a connected component + queue = deque([seed]) + visited_nodes.add(seed) + + while queue: + u = queue.popleft() + + for v in g.get_neighbors(u): + edge_key = frozenset((u, v)) + + # if this edge has not been used, a new triple has been found + if edge_key not in used_edges: + used_edges.add(edge_key) + + # use the edge name to ensure the uniqueness of the ID + u_sorted, v_sorted = sorted((u, v)) + yield Community( + id=f"{u_sorted}-{v_sorted}", + nodes=[u_sorted, v_sorted], + edges=[(u_sorted, v_sorted)], + ) + + # continue to BFS + if v not in visited_nodes: + visited_nodes.add(v) + queue.append(v) diff --git a/graphgen/operators/generate/generate_service.py b/graphgen/operators/generate/generate_service.py index 1868a50e..18ce1d43 100644 --- a/graphgen/operators/generate/generate_service.py +++ b/graphgen/operators/generate/generate_service.py @@ -71,6 +71,10 @@ def __init__( self.llm_client, num_of_questions=generate_kwargs.get("num_of_questions", 5), ) + elif self.method == "masked_fill_in_blank": + from graphgen.models import MaskedFillInBlankGenerator + + self.generator = MaskedFillInBlankGenerator(self.llm_client) elif self.method == "true_false": from graphgen.models import TrueFalseGenerator diff --git a/graphgen/operators/partition/partition_service.py b/graphgen/operators/partition/partition_service.py index dfadf8da..e2ff6789 100644 --- a/graphgen/operators/partition/partition_service.py +++ b/graphgen/operators/partition/partition_service.py @@ -28,7 +28,7 @@ def __init__( self.tokenizer_instance: BaseTokenizer = Tokenizer(model_name=tokenizer_model) method = partition_kwargs["method"] - self.method_params = 
partition_kwargs["method_params"] + self.method_params = partition_kwargs.get("method_params", {}) if method == "bfs": from graphgen.models import BFSPartitioner @@ -57,6 +57,10 @@ def __init__( if self.method_params.get("anchor_ids") else None, ) + elif method == "triple": + from graphgen.models import TriplePartitioner + + self.partitioner = TriplePartitioner() else: raise ValueError(f"Unsupported partition method: {method}")