diff --git a/examples/search/search_dna/search_dna_config.yaml b/examples/search/search_dna/search_dna_config.yaml
index db87b16e..81bbfb37 100644
--- a/examples/search/search_dna/search_dna_config.yaml
+++ b/examples/search/search_dna/search_dna_config.yaml
@@ -22,7 +22,7 @@ nodes:
     batch_size: 10
     save_output: true
    params:
-      data_sources: [ncbi] # data source for searcher, support: wikipedia, google, uniprot, ncbi, rnacentral
+      data_source: ncbi # data source for searcher, supported: wikipedia, google, uniprot, ncbi, rnacentral
       ncbi_params:
         email: test@example.com # NCBI requires an email address
         tool: GraphGen # tool name for NCBI API
diff --git a/examples/search/search_protein/search_protein_config.yaml b/examples/search/search_protein/search_protein_config.yaml
index 6e6f085c..bbf42abd 100644
--- a/examples/search/search_protein/search_protein_config.yaml
+++ b/examples/search/search_protein/search_protein_config.yaml
@@ -22,7 +22,7 @@ nodes:
     batch_size: 10
     save_output: true
     params:
-      data_sources: [uniprot] # data source for searcher, support: wikipedia, google, uniprot
+      data_source: uniprot # data source for searcher, supported: wikipedia, google, uniprot
       uniprot_params:
         use_local_blast: true # whether to use local blast for uniprot search
         local_blast_db: /path/to/uniprot_sprot # format: /path/to/${RELEASE}/uniprot_sprot
diff --git a/examples/search/search_rna/search_rna_config.yaml b/examples/search/search_rna/search_rna_config.yaml
index c19793e8..5c02a484 100644
--- a/examples/search/search_rna/search_rna_config.yaml
+++ b/examples/search/search_rna/search_rna_config.yaml
@@ -22,7 +22,7 @@ nodes:
     batch_size: 10
     save_output: true
     params:
-      data_sources: [rnacentral] # data source for searcher, support: wikipedia, google, uniprot, ncbi, rnacentral
+      data_source: rnacentral # data source for searcher, supported: wikipedia, google, uniprot, ncbi, rnacentral
       rnacentral_params:
         use_local_blast: true # whether to use local blast for RNA search
         local_blast_db: rnacentral_ensembl_gencode_YYYYMMDD/ensembl_gencode_YYYYMMDD # path to local BLAST database (without .nhr extension)
diff --git a/graphgen/operators/search/search_service.py b/graphgen/operators/search/search_service.py
index 7e25e225..1a599e25 100644
--- a/graphgen/operators/search/search_service.py
+++ b/graphgen/operators/search/search_service.py
@@ -1,9 +1,9 @@
 from functools import partial
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Tuple
 
 from graphgen.bases import BaseOperator
 from graphgen.common.init_storage import init_storage
-from graphgen.utils import compute_content_hash, logger, run_concurrent
+from graphgen.utils import logger, run_concurrent
 
 if TYPE_CHECKING:
     import pandas as pd
@@ -19,42 +19,47 @@ def __init__(
         self,
         working_dir: str = "cache",
         kv_backend: str = "rocksdb",
-        data_sources: list = None,
+        data_source: str = None,
         **kwargs,
     ):
-        super().__init__(working_dir=working_dir, op_name="search_service")
-        self.working_dir = working_dir
-        self.data_sources = data_sources or []
+        super().__init__(
+            working_dir=working_dir, kv_backend=kv_backend, op_name="search"
+        )
+        self.data_source = data_source
         self.kwargs = kwargs
         self.search_storage = init_storage(
             backend=kv_backend, working_dir=working_dir, namespace="search"
         )
-        self.searchers = {}
+        self.searcher = None
 
-    def _init_searchers(self):
+    def _init_searcher(self):
         """
-        Initialize all searchers (deferred import to avoid circular imports).
+        Initialize the searcher (deferred import to avoid circular imports).
         """
-        for datasource in self.data_sources:
-            if datasource in self.searchers:
-                continue
-            if datasource == "uniprot":
-                from graphgen.models import UniProtSearch
+        if self.searcher is not None:
+            return
+
+        if not self.data_source:
+            logger.error("Data source not specified")
+            return
 
-                params = self.kwargs.get("uniprot_params", {})
-                self.searchers[datasource] = UniProtSearch(**params)
-            elif datasource == "ncbi":
-                from graphgen.models import NCBISearch
+        if self.data_source == "uniprot":
+            from graphgen.models import UniProtSearch
 
-                params = self.kwargs.get("ncbi_params", {})
-                self.searchers[datasource] = NCBISearch(**params)
-            elif datasource == "rnacentral":
-                from graphgen.models import RNACentralSearch
+            params = self.kwargs.get("uniprot_params", {})
+            self.searcher = UniProtSearch(**params)
+        elif self.data_source == "ncbi":
+            from graphgen.models import NCBISearch
 
-                params = self.kwargs.get("rnacentral_params", {})
-                self.searchers[datasource] = RNACentralSearch(**params)
-            else:
-                logger.error(f"Unknown data source: {datasource}, skipping")
+            params = self.kwargs.get("ncbi_params", {})
+            self.searcher = NCBISearch(**params)
+        elif self.data_source == "rnacentral":
+            from graphgen.models import RNACentralSearch
+
+            params = self.kwargs.get("rnacentral_params", {})
+            self.searcher = RNACentralSearch(**params)
+        else:
+            logger.error(f"Unknown data source: {self.data_source}")
 
     @staticmethod
     async def _perform_search(
@@ -76,91 +81,59 @@ async def _perform_search(
         result = searcher_obj.search(query)
 
         if result:
-            result["_doc_id"] = compute_content_hash(str(data_source) + query, "doc-")
             result["data_source"] = data_source
             result["type"] = seed.get("type", "text")
 
         return result
 
-    def _process_single_source(
-        self, data_source: str, seed_data: list[dict]
-    ) -> list[dict]:
-        """
-        process a single data source: check cache, search missing, update cache.
+    def process(self, batch: list) -> Tuple[list, dict]:
         """
-        searcher = self.searchers[data_source]
-
-        seeds_with_ids = []
-        for seed in seed_data:
-            query = seed.get("content", "")
-            if not query:
-                continue
-            doc_id = compute_content_hash(str(data_source) + query, "doc-")
-            seeds_with_ids.append((doc_id, seed))
-
-        if not seeds_with_ids:
-            return []
-
-        doc_ids = [doc_id for doc_id, _ in seeds_with_ids]
-        cached_results = self.search_storage.get_by_ids(doc_ids)
-
-        to_search_seeds = []
-        final_results = []
+        Search for items in the batch using the configured data source.
 
-        for (doc_id, seed), cached in zip(seeds_with_ids, cached_results):
-            if cached is not None:
-                if "_doc_id" not in cached:
-                    cached["_doc_id"] = doc_id
-                final_results.append(cached)
-            else:
-                to_search_seeds.append(seed)
-
-        if to_search_seeds:
-            new_results = run_concurrent(
-                partial(
-                    self._perform_search, searcher_obj=searcher, data_source=data_source
-                ),
-                to_search_seeds,
-                desc=f"Searching {data_source} database",
-                unit="keyword",
-            )
-            new_results = [res for res in new_results if res is not None]
-
-            if new_results:
-                upsert_data = {res["_doc_id"]: res for res in new_results}
-                self.search_storage.upsert(upsert_data)
-                logger.info(
-                    f"Saved {len(upsert_data)} new results to {data_source} cache"
-                )
-
-            final_results.extend(new_results)
-
-        return final_results
-
-    def process(self, batch: "pd.DataFrame") -> "pd.DataFrame":
-        import pandas as pd
-
-        docs = batch.to_dict(orient="records")
+        :param batch: List of items with 'content' and '_trace_id' fields
+        :return: A tuple of (results, meta_updates)
+            results: A list of search results.
+            meta_updates: A dict mapping each seed's trace ID to the trace IDs of its search results.
+        """
+        self._init_searcher()
 
-        self._init_searchers()
+        if not self.searcher:
+            logger.error("Searcher not initialized")
+            return [], {}
 
-        seed_data = [doc for doc in docs if doc and "content" in doc]
+        # Filter seeds with valid content and _trace_id
+        seed_data = [
+            item for item in batch if item and "content" in item and "_trace_id" in item
+        ]
         if not seed_data:
             logger.warning("No valid seeds in batch")
-            return pd.DataFrame([])
-
-        all_results = []
+            return [], {}
+
+        # Perform concurrent searches
+        results = run_concurrent(
+            partial(
+                self._perform_search,
+                searcher_obj=self.searcher,
+                data_source=self.data_source,
+            ),
+            seed_data,
+            desc=f"Searching {self.data_source} database",
+            unit="keyword",
+        )
 
-        for data_source in self.data_sources:
-            if data_source not in self.searchers:
-                logger.error(f"Data source {data_source} not initialized, skipping")
+        # Filter out None results and add _trace_id from original seeds
+        final_results = []
+        meta_updates = {}
+        for result, seed in zip(results, seed_data):
+            if result is None:
                 continue
+            result["_trace_id"] = self.get_trace_id(result)
+            final_results.append(result)
+            # Map from source seed trace ID to search result trace ID
+            meta_updates.setdefault(seed["_trace_id"], []).append(result["_trace_id"])
 
-            source_results = self._process_single_source(data_source, seed_data)
-            all_results.extend(source_results)
-
-        if not all_results:
+        if not final_results:
             logger.warning("No search results generated for this batch")
 
-        return pd.DataFrame(all_results)
+        return final_results, meta_updates
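For reference, a minimal sketch of how the reworked operator might be driven after this change. It assumes the operator class in search_service.py is importable as SearchService and that process() can be called directly with a list of seed dicts; the class name, import path, seed contents, and trace IDs are illustrative assumptions, not part of this patch.

# Sketch only: SearchService name/import path and the seed dicts below are assumed.
from graphgen.operators.search.search_service import SearchService  # assumed export

# Mirrors examples/search/search_dna/search_dna_config.yaml: `data_source` is now a
# single string, and source-specific options are passed through as keyword params.
service = SearchService(
    working_dir="cache",
    kv_backend="rocksdb",
    data_source="ncbi",
    ncbi_params={"email": "test@example.com", "tool": "GraphGen"},
)

# Each seed needs `content` and `_trace_id`; items missing either field are skipped.
batch = [
    {"content": "BRCA1", "type": "text", "_trace_id": "seed-001"},
    {"content": "TP53", "type": "text", "_trace_id": "seed-002"},
]

results, meta_updates = service.process(batch)
# results: one dict per successful search, tagged with data_source, type,
#          and a freshly assigned `_trace_id`.
# meta_updates: each seed's trace ID mapped to the trace IDs of its results,
#               e.g. {"seed-001": ["<result trace id>"], "seed-002": [...]}.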