efaar_benchmarking/core.py (12 changes: 3 additions & 9 deletions)
@@ -1,3 +1,6 @@
from importlib.metadata import version


def get_version() -> str:
"""Returns a string representation of the version of efaar_benchmarking currently in use

@@ -7,15 +10,6 @@ def get_version() -> str:
the version number installed of this package
"""
try:
from importlib.metadata import version # type: ignore

return version("efaar_benchmarking")
except ImportError:
try:
import pkg_resources

return pkg_resources.get_distribution("efaar_benchmarking").version
except pkg_resources.DistributionNotFound:
return "set_version_placeholder"
except ModuleNotFoundError:
return "set_version_placeholder"
notebooks/map_building_benchmarking.ipynb (135 changes: 108 additions & 27 deletions)
@@ -21,9 +21,11 @@
"import pickle\n",
"\n",
"pc_count = 128\n",
"save_results = False # Results already uploaded to the notebooks/data folder in the repo. If True, will replace these files.\n",
"save_results = (\n",
" False # Results already uploaded to the notebooks/data folder in the repo. If True, will replace these files.\n",
")\n",
"pert_signal_pval_cutoff = 0.05\n",
"recall_thr_pairs = [(.05, .95)]"
"recall_thr_pairs = [(0.05, 0.95)]"
]
},
{
@@ -53,17 +55,43 @@
"# Run EFAAR pipelines\n",
"all_embeddings_pre_agg = {}\n",
"print(\"Running for embedding size\", pc_count)\n",
"all_embeddings_pre_agg[f\"scVI{pc_count}\"] = embed_by_scvi_anndata(adata_raw, batch_col=gem_group_colname, n_latent=pc_count, n_hidden=pc_count*2)\n",
"all_embeddings_pre_agg[f\"scVI{pc_count}\"] = embed_by_scvi_anndata(\n",
" adata_raw, batch_col=gem_group_colname, n_latent=pc_count, n_hidden=pc_count * 2\n",
")\n",
"print(\"embed_by_scvi_anndata completed\")\n",
"all_embeddings_pre_agg[f\"scVI{pc_count}-CS\"] = centerscale_on_controls(all_embeddings_pre_agg[f\"scVI{pc_count}\"], metadata, pert_col=pert_colname, control_key=control_key, batch_col=gem_group_colname)\n",
"all_embeddings_pre_agg[f\"scVI{pc_count}-CS\"] = centerscale_on_controls(\n",
" all_embeddings_pre_agg[f\"scVI{pc_count}\"],\n",
" metadata,\n",
" pert_col=pert_colname,\n",
" control_key=control_key,\n",
" batch_col=gem_group_colname,\n",
")\n",
"print(\"centerscale completed\")\n",
"all_embeddings_pre_agg[f\"scVI{pc_count}-TVN\"] = tvn_on_controls(all_embeddings_pre_agg[f\"scVI{pc_count}\"], metadata, pert_col=pert_colname, control_key=control_key, batch_col=gem_group_colname)\n",
"all_embeddings_pre_agg[f\"scVI{pc_count}-TVN\"] = tvn_on_controls(\n",
" all_embeddings_pre_agg[f\"scVI{pc_count}\"],\n",
" metadata,\n",
" pert_col=pert_colname,\n",
" control_key=control_key,\n",
" batch_col=gem_group_colname,\n",
")\n",
"print(\"tvn completed\")\n",
"all_embeddings_pre_agg[f\"PCA{pc_count}\"] = embed_by_pca_anndata(adata_raw, gem_group_colname, pc_count)\n",
"print(\"embed_by_pca_anndata completed\")\n",
"all_embeddings_pre_agg[f\"PCA{pc_count}-CS\"] = centerscale_on_controls(all_embeddings_pre_agg[f\"PCA{pc_count}\"], metadata, pert_col=pert_colname, control_key=control_key, batch_col=gem_group_colname)\n",
"all_embeddings_pre_agg[f\"PCA{pc_count}-CS\"] = centerscale_on_controls(\n",
" all_embeddings_pre_agg[f\"PCA{pc_count}\"],\n",
" metadata,\n",
" pert_col=pert_colname,\n",
" control_key=control_key,\n",
" batch_col=gem_group_colname,\n",
")\n",
"print(\"centerscale completed\")\n",
"all_embeddings_pre_agg[f\"PCA{pc_count}-TVN\"] = tvn_on_controls(all_embeddings_pre_agg[f\"PCA{pc_count}\"], metadata, pert_col=pert_colname, control_key=control_key, batch_col=gem_group_colname)\n",
"all_embeddings_pre_agg[f\"PCA{pc_count}-TVN\"] = tvn_on_controls(\n",
" all_embeddings_pre_agg[f\"PCA{pc_count}\"],\n",
" metadata,\n",
" pert_col=pert_colname,\n",
" control_key=control_key,\n",
" batch_col=gem_group_colname,\n",
")\n",
"print(\"tvn completed\")\n",
"\n",
"# Run biological relationship benchmarks\n",
@@ -77,9 +105,9 @@
"\n",
"# Save results\n",
"if save_results:\n",
" with open(f'data/{dataset}_map_cache.pkl', 'wb') as f:\n",
" with open(f\"data/{dataset}_map_cache.pkl\", \"wb\") as f:\n",
" pickle.dump(map_data, f) # storing the PCA-TVN map data for downstream analysis\n",
" with open(f'data/{dataset}_metadata.pkl', 'wb') as f:\n",
" with open(f\"data/{dataset}_metadata.pkl\", \"wb\") as f:\n",
" pickle.dump(metadata, f) # storing the metadata for downstream analysis"
]
},
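The save block above caches the PCA-TVN map data and metadata as pickles under `notebooks/data`. A minimal sketch of how a downstream analysis might reload them; the `dataset` value is a hypothetical stand-in for whatever name the cell assigns:

```python
import pickle

dataset = "example_dataset"  # hypothetical stand-in; use the dataset name set in the cell above
with open(f"data/{dataset}_map_cache.pkl", "rb") as f:
    map_data = pickle.load(f)  # PCA-TVN map data cached by the benchmarking cell
with open(f"data/{dataset}_metadata.pkl", "rb") as f:
    metadata = pickle.load(f)  # matching metadata for downstream analysis
```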
@@ -109,26 +137,54 @@
"features, metadata = filter_cell_profiler_features(features, metadata)\n",
"\n",
"expression_data_folder = \"../efaar_benchmarking/expression_data\"\n",
"expr = pd.read_csv(f\"{expression_data_folder}/U2OS_expression.csv\", index_col=0).groupby(\"gene\").zfpkm.agg(\"median\").reset_index()\n",
"expr = (\n",
" pd.read_csv(f\"{expression_data_folder}/U2OS_expression.csv\", index_col=0)\n",
" .groupby(\"gene\")\n",
" .zfpkm.agg(\"median\")\n",
" .reset_index()\n",
")\n",
"unexpr_genes = list(expr.loc[expr.zfpkm < -3, \"gene\"])\n",
"expr_genes = list(expr.loc[expr.zfpkm >= -3, \"gene\"])\n",
"expr_ind = metadata[pert_colname].isin(expr_genes + [control_key])\n",
"\n",
"# Run EFAAR pipelines\n",
"all_embeddings_pre_agg = {}\n",
"print(\"Computing PCA embedding for\", pc_count, \"dimensions...\")\n",
"all_embeddings_pre_agg[f\"PCA{pc_count}\"] = embed_by_pca(features.values, metadata, variance_or_ncomp=pc_count, batch_col=plate_colname)\n",
"all_embeddings_pre_agg[f\"PCA{pc_count}\"] = embed_by_pca(\n",
" features.values, metadata, variance_or_ncomp=pc_count, batch_col=plate_colname\n",
")\n",
"print(\"Computing centerscale...\")\n",
"all_embeddings_pre_agg[f\"PCA{pc_count}-CS\"] = centerscale_on_controls(all_embeddings_pre_agg[f\"PCA{pc_count}\"], metadata, pert_col=pert_colname, control_key=control_key, batch_col=run_colname)\n",
"all_embeddings_pre_agg[f\"PCA{pc_count}-CS\"] = centerscale_on_controls(\n",
" all_embeddings_pre_agg[f\"PCA{pc_count}\"],\n",
" metadata,\n",
" pert_col=pert_colname,\n",
" control_key=control_key,\n",
" batch_col=run_colname,\n",
")\n",
"print(\"Computing TVN...\")\n",
"all_embeddings_pre_agg[f\"PCA{pc_count}-TVN\"] = tvn_on_controls(all_embeddings_pre_agg[f\"PCA{pc_count}\"], metadata, pert_col=pert_colname, control_key=control_key, batch_col=run_colname)\n",
"all_embeddings_pre_agg[f\"PCA{pc_count}-TVN\"] = tvn_on_controls(\n",
" all_embeddings_pre_agg[f\"PCA{pc_count}\"],\n",
" metadata,\n",
" pert_col=pert_colname,\n",
" control_key=control_key,\n",
" batch_col=run_colname,\n",
")\n",
"\n",
"# Run perturbation signal benchmarks\n",
"for k, emb in all_embeddings_pre_agg.items():\n",
" cons_res = pert_signal_consistency_benchmark(emb, metadata, pert_col=pert_colname, neg_ctrl_perts=unexpr_genes, keys_to_drop=all_controls)\n",
" cons_res = pert_signal_consistency_benchmark(\n",
" emb, metadata, pert_col=pert_colname, neg_ctrl_perts=unexpr_genes, keys_to_drop=all_controls\n",
" )\n",
" print(k, round(sum(cons_res.pval <= pert_signal_pval_cutoff) / sum(~pd.isna(cons_res.pval)) * 100, 1))\n",
"\n",
" magn_res = pert_signal_distance_benchmark(emb, metadata, pert_col=pert_colname, neg_ctrl_perts=unexpr_genes, control_key=control_key, keys_to_drop=[x for x in all_controls if x!=control_key])\n",
" magn_res = pert_signal_distance_benchmark(\n",
" emb,\n",
" metadata,\n",
" pert_col=pert_colname,\n",
" neg_ctrl_perts=unexpr_genes,\n",
" control_key=control_key,\n",
" keys_to_drop=[x for x in all_controls if x != control_key],\n",
" )\n",
" print(k, round(sum(magn_res.pval <= pert_signal_pval_cutoff) / sum(~pd.isna(magn_res.pval)) * 100, 1))\n",
"\n",
"# Run biological relationship benchmarks\n",
@@ -137,14 +193,14 @@
" print(\"Aggregating...\")\n",
" map_data = aggregate(emb[expr_ind], metadata[expr_ind], pert_col=pert_colname, keys_to_remove=all_controls)\n",
" print(\"Computing recall...\")\n",
" metrics = known_relationship_benchmark(map_data, recall_thr_pairs=[(.05, .95)], pert_col=pert_colname)\n",
" metrics = known_relationship_benchmark(map_data, recall_thr_pairs=[(0.05, 0.95)], pert_col=pert_colname)\n",
" print(metrics[list(metrics.columns)[::-1]])\n",
"\n",
"# Save results\n",
"if save_results:\n",
" with open(f'data/{dataset}_map_cache.pkl', 'wb') as f:\n",
" with open(f\"data/{dataset}_map_cache.pkl\", \"wb\") as f:\n",
" pickle.dump(map_data, f) # storing the PCA-TVN map data for downstream analysis\n",
" with open(f'data/{dataset}_metadata.pkl', 'wb') as f:\n",
" with open(f\"data/{dataset}_metadata.pkl\", \"wb\") as f:\n",
" pickle.dump(metadata, f) # storing the metadata for downstream analysis"
]
},
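The embed → center-scale → TVN sequence above is repeated in each cell with only the batch columns changing. A hedged sketch of a helper that would factor out the repetition; the helper name and signature are assumptions, not part of this PR, and the efaar_benchmarking functions are assumed to be imported as in the notebook's setup cell:

```python
def build_pca_maps(features, metadata, pc_count, pert_col, control_key, pca_batch_col, norm_batch_col):
    """Hypothetical helper: PCA embedding plus center-scaled and TVN variants, keyed as in the cells above."""
    # Assumes embed_by_pca, centerscale_on_controls, and tvn_on_controls are already imported.
    embs = {}
    embs[f"PCA{pc_count}"] = embed_by_pca(features, metadata, variance_or_ncomp=pc_count, batch_col=pca_batch_col)
    embs[f"PCA{pc_count}-CS"] = centerscale_on_controls(
        embs[f"PCA{pc_count}"], metadata, pert_col=pert_col, control_key=control_key, batch_col=norm_batch_col
    )
    embs[f"PCA{pc_count}-TVN"] = tvn_on_controls(
        embs[f"PCA{pc_count}"], metadata, pert_col=pert_col, control_key=control_key, batch_col=norm_batch_col
    )
    return embs
```

For the cell above, this would be called as `build_pca_maps(features.values, metadata, pc_count, pert_colname, control_key, plate_colname, run_colname)`.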
@@ -172,7 +228,9 @@
"print(\"Perturbation dataset loaded\")\n",
"\n",
"expression_data_folder = \"../efaar_benchmarking/expression_data\"\n",
"expr = pd.read_csv(f\"{expression_data_folder}/HeLa_expression.csv\") # note that we assume the HeLa expression data was used for PERISCOPE which is the default option in load_periscope()\n",
"expr = pd.read_csv(\n",
" f\"{expression_data_folder}/HeLa_expression.csv\"\n",
") # note that we assume the HeLa expression data was used for PERISCOPE which is the default option in load_periscope()\n",
"expr.columns = [\"gene\", \"tpm\"]\n",
"expr.gene = expr.gene.apply(lambda x: x.split(\" \")[0])\n",
"unexpr_genes = list(expr.loc[expr.tpm == 0, \"gene\"])\n",
@@ -182,18 +240,41 @@
"# Run EFAAR pipelines\n",
"all_embeddings_pre_agg = {}\n",
"print(\"Computing PCA embedding for\", pc_count, \"dimensions...\")\n",
"all_embeddings_pre_agg[f\"PCA{pc_count}\"] = embed_by_pca(features.values, metadata, variance_or_ncomp=pc_count, batch_col=plate_colname)\n",
"all_embeddings_pre_agg[f\"PCA{pc_count}\"] = embed_by_pca(\n",
" features.values, metadata, variance_or_ncomp=pc_count, batch_col=plate_colname\n",
")\n",
"print(\"Computing centerscale...\")\n",
"all_embeddings_pre_agg[f\"PCA{pc_count}-CS\"] = centerscale_on_controls(all_embeddings_pre_agg[f\"PCA{pc_count}\"], metadata, pert_col=pert_colname, control_key=control_key, batch_col=plate_colname)\n",
"all_embeddings_pre_agg[f\"PCA{pc_count}-CS\"] = centerscale_on_controls(\n",
" all_embeddings_pre_agg[f\"PCA{pc_count}\"],\n",
" metadata,\n",
" pert_col=pert_colname,\n",
" control_key=control_key,\n",
" batch_col=plate_colname,\n",
")\n",
"print(\"Computing TVN...\")\n",
"all_embeddings_pre_agg[f\"PCA{pc_count}-TVN\"] = tvn_on_controls(all_embeddings_pre_agg[f\"PCA{pc_count}\"], metadata, pert_col=pert_colname, control_key=control_key, batch_col=plate_colname)\n",
"all_embeddings_pre_agg[f\"PCA{pc_count}-TVN\"] = tvn_on_controls(\n",
" all_embeddings_pre_agg[f\"PCA{pc_count}\"],\n",
" metadata,\n",
" pert_col=pert_colname,\n",
" control_key=control_key,\n",
" batch_col=plate_colname,\n",
")\n",
"\n",
"# Run perturbation signal benchmarks\n",
"for k, emb in all_embeddings_pre_agg.items():\n",
" cons_res = pert_signal_consistency_benchmark(emb, metadata, pert_col=pert_colname, neg_ctrl_perts=unexpr_genes, keys_to_drop=all_controls)\n",
" cons_res = pert_signal_consistency_benchmark(\n",
" emb, metadata, pert_col=pert_colname, neg_ctrl_perts=unexpr_genes, keys_to_drop=all_controls\n",
" )\n",
" print(k, round(sum(cons_res.pval <= pert_signal_pval_cutoff) / sum(~pd.isna(cons_res.pval)) * 100, 1))\n",
"\n",
" magn_res = pert_signal_distance_benchmark(emb, metadata, pert_col=pert_colname, neg_ctrl_perts=unexpr_genes, control_key=control_key, keys_to_drop=[x for x in all_controls if x!=control_key])\n",
" magn_res = pert_signal_distance_benchmark(\n",
" emb,\n",
" metadata,\n",
" pert_col=pert_colname,\n",
" neg_ctrl_perts=unexpr_genes,\n",
" control_key=control_key,\n",
" keys_to_drop=[x for x in all_controls if x != control_key],\n",
" )\n",
" print(k, round(sum(magn_res.pval <= pert_signal_pval_cutoff) / sum(~pd.isna(magn_res.pval)) * 100, 1))\n",
"\n",
"# Run biological relationship benchmarks\n",
@@ -202,14 +283,14 @@
" print(\"Aggregating...\")\n",
" map_data = aggregate(emb[expr_ind], metadata[expr_ind], pert_col=pert_colname, keys_to_remove=all_controls)\n",
" print(\"Computing recall...\")\n",
" metrics = known_relationship_benchmark(map_data, recall_thr_pairs=[(.05, .95)], pert_col=pert_colname)\n",
" metrics = known_relationship_benchmark(map_data, recall_thr_pairs=[(0.05, 0.95)], pert_col=pert_colname)\n",
" print(metrics[list(metrics.columns)[::-1]])\n",
"\n",
"# Save results\n",
"if save_results:\n",
" with open(f'data/{dataset}_map_cache.pkl', 'wb') as f:\n",
" with open(f\"data/{dataset}_map_cache.pkl\", \"wb\") as f:\n",
" pickle.dump(map_data, f) # storing the PCA-TVN map data for downstream analysis\n",
" with open(f'data/{dataset}_metadata.pkl', 'wb') as f:\n",
" with open(f\"data/{dataset}_metadata.pkl\", \"wb\") as f:\n",
" pickle.dump(metadata, f) # storing the metadata for downstream analysis"
]
}
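Both perturbation-signal loops print the same statistic: the percentage of perturbations whose benchmark p-value falls at or below `pert_signal_pval_cutoff`, taken over perturbations with a defined p-value. A minimal worked example with made-up p-values:

```python
import numpy as np
import pandas as pd

# Toy p-values standing in for cons_res.pval / magn_res.pval (illustrative only).
pval = pd.Series([0.01, 0.20, np.nan, 0.04, 0.60])
pert_signal_pval_cutoff = 0.05

# NaN comparisons are False, so the numerator counts only defined, significant p-values (2 here),
# while the denominator counts all defined p-values (4 here): 2 / 4 * 100 = 50.0.
pct_significant = round(sum(pval <= pert_signal_pval_cutoff) / sum(~pd.isna(pval)) * 100, 1)
print(pct_significant)  # 50.0
```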