-
Notifications
You must be signed in to change notification settings - Fork 71
feat: support synthesizing masked fill_in_blank QA pairs #173
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| # Generate Masked Fill-in-blank QAs | ||
| # TODO | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| python3 -m graphgen.run \ | ||
| --config_file examples/generate/generate_masked_fill_in_blank_qa/masked_fill_in_blank_config.yaml |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| global_params: | ||
| working_dir: cache | ||
| graph_backend: networkx # graph database backend, support: kuzu, networkx | ||
| kv_backend: json_kv # key-value store backend, support: rocksdb, json_kv | ||
|
|
||
| nodes: | ||
| - id: read_files # id is unique in the pipeline, and can be referenced by other steps | ||
| op_name: read | ||
| type: source | ||
| dependencies: [] | ||
| params: | ||
| input_path: | ||
| - examples/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See examples/input_examples for examples | ||
|
|
||
| - id: chunk_documents | ||
| op_name: chunk | ||
| type: map_batch | ||
| dependencies: | ||
| - read_files | ||
| execution_params: | ||
| replicas: 4 | ||
| params: | ||
| chunk_size: 1024 # chunk size for text splitting | ||
| chunk_overlap: 100 # chunk overlap for text splitting | ||
|
|
||
| - id: build_kg | ||
| op_name: build_kg | ||
| type: map_batch | ||
| dependencies: | ||
| - chunk_documents | ||
| execution_params: | ||
| replicas: 1 | ||
| batch_size: 128 | ||
|
|
||
| - id: partition | ||
| op_name: partition | ||
| type: aggregate | ||
| dependencies: | ||
| - build_kg | ||
| params: | ||
| method: triple | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| - id: generate | ||
| op_name: generate | ||
| type: map_batch | ||
| dependencies: | ||
| - partition | ||
| execution_params: | ||
| replicas: 1 | ||
| batch_size: 128 | ||
| save_output: true # save output | ||
| params: | ||
| method: masked_fill_in_blank # atomic, aggregated, multi_hop, cot, vqa | ||
| data_format: QA_pairs # Alpaca, Sharegpt, ChatML, QA_pairs | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -74,4 +74,10 @@ def format_generation_results( | |
| {"role": "assistant", "content": answer}, | ||
| ] | ||
| } | ||
|
|
||
| if output_data_format == "QA_pairs": | ||
| return { | ||
| "question": question, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| "answer": answer, | ||
| } | ||
| raise ValueError(f"Unknown output data format: {output_data_format}") | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,121 @@ | ||||||||||||||||||||||
| import random | ||||||||||||||||||||||
| import re | ||||||||||||||||||||||
| from typing import Any, Optional | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| from graphgen.bases import BaseGenerator | ||||||||||||||||||||||
| from graphgen.templates import AGGREGATED_GENERATION_PROMPT | ||||||||||||||||||||||
| from graphgen.utils import detect_main_language, logger | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| random.seed(42) | ||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Setting a global random seed with |
||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| class MaskedFillInBlankGenerator(BaseGenerator): | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| Masked Fill-in-blank Generator follows a TWO-STEP process: | ||||||||||||||||||||||
| 1. rephrase: Rephrase the input nodes and edges into a coherent text that maintains the original meaning. | ||||||||||||||||||||||
| 2. mask: Randomly select a node from the input nodes, and then mask the name of the node in the rephrased text. | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||
| def build_prompt( | ||||||||||||||||||||||
| batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]] | ||||||||||||||||||||||
| ) -> str: | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| Build prompts for REPHRASE. | ||||||||||||||||||||||
| :param batch | ||||||||||||||||||||||
| :return: | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| nodes, edges = batch | ||||||||||||||||||||||
| entities_str = "\n".join( | ||||||||||||||||||||||
| [ | ||||||||||||||||||||||
| f"{index + 1}. {node[0]}: {node[1]['description']}" | ||||||||||||||||||||||
| for index, node in enumerate(nodes) | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| relations_str = "\n".join( | ||||||||||||||||||||||
| [ | ||||||||||||||||||||||
| f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}" | ||||||||||||||||||||||
| for index, edge in enumerate(edges) | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| language = detect_main_language(entities_str + relations_str) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # TODO: configure add_context | ||||||||||||||||||||||
| # if add_context: | ||||||||||||||||||||||
| # original_ids = [ | ||||||||||||||||||||||
| # node["source_id"].split("<SEP>")[0] for node in _process_nodes | ||||||||||||||||||||||
| # ] + [edge[2]["source_id"].split("<SEP>")[0] for edge in _process_edges] | ||||||||||||||||||||||
| # original_ids = list(set(original_ids)) | ||||||||||||||||||||||
| # original_text = await text_chunks_storage.get_by_ids(original_ids) | ||||||||||||||||||||||
| # original_text = "\n".join( | ||||||||||||||||||||||
| # [ | ||||||||||||||||||||||
| # f"{index + 1}. {text['content']}" | ||||||||||||||||||||||
| # for index, text in enumerate(original_text) | ||||||||||||||||||||||
| # ] | ||||||||||||||||||||||
| # ) | ||||||||||||||||||||||
|
Comment on lines
+43
to
+55
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||
| prompt = AGGREGATED_GENERATION_PROMPT[language]["ANSWER_REPHRASING"].format( | ||||||||||||||||||||||
| entities=entities_str, relationships=relations_str | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| return prompt | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||
| def parse_rephrased_text(response: str) -> Optional[str]: | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| Parse the rephrased text from the response. | ||||||||||||||||||||||
| :param response: | ||||||||||||||||||||||
| :return: rephrased text | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| rephrased_match = re.search( | ||||||||||||||||||||||
| r"<rephrased_text>(.*?)</rephrased_text>", response, re.DOTALL | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| if rephrased_match: | ||||||||||||||||||||||
| rephrased_text = rephrased_match.group(1).strip() | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| logger.warning("Failed to parse rephrased text from response: %s", response) | ||||||||||||||||||||||
| return None | ||||||||||||||||||||||
| return rephrased_text.strip('"').strip("'") | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||
| def parse_response(response: str) -> dict: | ||||||||||||||||||||||
| pass | ||||||||||||||||||||||
|
Comment on lines
+78
to
+80
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| async def generate( | ||||||||||||||||||||||
| self, | ||||||||||||||||||||||
| batch: tuple[ | ||||||||||||||||||||||
| list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]] | ||||||||||||||||||||||
| ], | ||||||||||||||||||||||
| ) -> list[dict]: | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| Generate QAs based on a given batch. | ||||||||||||||||||||||
| :param batch | ||||||||||||||||||||||
| :return: QA pairs | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| rephrasing_prompt = self.build_prompt(batch) | ||||||||||||||||||||||
| response = await self.llm_client.generate_answer(rephrasing_prompt) | ||||||||||||||||||||||
| context = self.parse_rephrased_text(response) | ||||||||||||||||||||||
| if not context: | ||||||||||||||||||||||
| return [] | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| nodes, edge = batch | ||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The variable
Suggested change
|
||||||||||||||||||||||
| assert ( | ||||||||||||||||||||||
| len(nodes) == 2 | ||||||||||||||||||||||
| ), "MaskedFillInBlankGenerator currently only supports triples, which should have 2 nodes." | ||||||||||||||||||||||
| assert ( | ||||||||||||||||||||||
| len(edge) == 1 | ||||||||||||||||||||||
| ), "MaskedFillInBlankGenerator currently only supports triples, which should have 1 edge." | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| node1, node2 = nodes | ||||||||||||||||||||||
| mask_node = random.choice([node1, node2]) | ||||||||||||||||||||||
| mask_node_name = mask_node[1]["entity_name"].strip("'\" \n\r\t") | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| mask_pattern = re.compile(re.escape(mask_node_name), re.IGNORECASE) | ||||||||||||||||||||||
| masked_context = mask_pattern.sub("___", context) | ||||||||||||||||||||||
| # For accuracy, extract the actual replaced text from the context as the ground truth | ||||||||||||||||||||||
| gth = re.search(mask_pattern, context).group(0) | ||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The call to
Suggested change
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| logger.debug("masked_context: %s", masked_context) | ||||||||||||||||||||||
| qa_pairs = { | ||||||||||||||||||||||
| "question": masked_context, | ||||||||||||||||||||||
| "answer": gth, | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| return [qa_pairs] | ||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
| import random | ||
| from collections import deque | ||
| from typing import Any, Iterable, Set | ||
|
|
||
| from graphgen.bases import BaseGraphStorage, BasePartitioner | ||
| from graphgen.bases.datatypes import Community | ||
|
|
||
| random.seed(42) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Setting a global random seed with |
||
|
|
||
|
|
||
| class TriplePartitioner(BasePartitioner): | ||
| """ | ||
| Triple Partitioner that partitions the graph into multiple distinct triples (node, edge, node). | ||
| 1. Automatically ignore isolated points. | ||
| 2. In each connected component, yield triples in the order of BFS. | ||
| """ | ||
|
|
||
| def partition( | ||
| self, | ||
| g: BaseGraphStorage, | ||
| **kwargs: Any, | ||
| ) -> Iterable[Community]: | ||
| nodes = [n[0] for n in g.get_all_nodes()] | ||
| random.shuffle(nodes) | ||
|
|
||
| visited_nodes: Set[str] = set() | ||
| used_edges: Set[frozenset[str]] = set() | ||
|
|
||
| for seed in nodes: | ||
| if seed in visited_nodes: | ||
| continue | ||
|
|
||
| # start BFS in a connected component | ||
| queue = deque([seed]) | ||
| visited_nodes.add(seed) | ||
|
|
||
| while queue: | ||
| u = queue.popleft() | ||
|
|
||
| for v in g.get_neighbors(u): | ||
| edge_key = frozenset((u, v)) | ||
|
|
||
| # if this edge has not been used, a new triple has been found | ||
| if edge_key not in used_edges: | ||
| used_edges.add(edge_key) | ||
|
|
||
| # use the edge name to ensure the uniqueness of the ID | ||
| u_sorted, v_sorted = sorted((u, v)) | ||
| yield Community( | ||
| id=f"{u_sorted}-{v_sorted}", | ||
| nodes=[u_sorted, v_sorted], | ||
| edges=[(u_sorted, v_sorted)], | ||
| ) | ||
|
|
||
| # continue to BFS | ||
| if v not in visited_nodes: | ||
| visited_nodes.add(v) | ||
| queue.append(v) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The README file currently only contains "TODO". To make this example useful for other developers and users, please add a brief description of what this feature does, what the configuration options mean, and how to run it.