Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/pvalplot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
def hello() -> str:
return "Hello from pvalplot!"
return "Hello from pvalplot!"
81 changes: 81 additions & 0 deletions src/pvalplot/annotation_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

from .annotation_sample import add_sample_annot2
from .annotation_hline import add_horizontal_line_annot
from .annotation_pvalues import add_pbars_annot

def wrapped(original_func):
"""Wrapper to add functionalities to default sns plot function calls. Use example:

from matplotlib import pyplot as plt
import seaborn as sns
from pvalplot.annotation_wrapper import wrapped

sns.boxplot = wrapped(sns.boxplot)
df = sns.load_dataset("tips")
sns.boxplot(
y="total_bill",
x="day",
data=df,
sample_size_annotation=True,
horizontal_line = "median",
pbars=dict({
'font_ylim': None,
'yshift_c': None,
'font_ndigits': 2,
'group': None,
'hue_col': None
}),
)

"""

def wrapper(*args, **kwargs):

# Pop out custom variables
sample_size_annotation = kwargs.pop('sample_size_annotation', None)
horizontal_line = kwargs.pop('horizontal_line', None)
pbars = kwargs.pop('pbars', None)

# Original SNS function call
ax = original_func(*args, **kwargs)

# Access original arguments
group_col = kwargs.get("x", args[0] if len(args) > 0 else None)
val_col = kwargs.get("y", args[1] if len(args) > 0 else None)
data = kwargs.get("data", args[2] if len(args) > 2 else None)

# Add p value brackets
if pbars is not None:

font_ylim = pbars.get('font_ylim', None)
yshift_c = pbars.get('yshift_c', None)
font_ndigits = pbars.get('font_ndigits', None)
group = pbars.get('group', None)
hue_col = pbars.get('hue_col', None)

add_pbars_annot(ax, data, group_col, val_col, font_ylim, yshift_c, font_ndigits, group, hue_col)

# Add sample size annotations
if sample_size_annotation:
ax = add_sample_annot2(ax, data, group_col)

# Add horizontal lines
if horizontal_line is not None:
if horizontal_line == True:
horizontal_line = "mean"

# Get colors of the box plots
palette = [
patch.get_facecolor()
for patch in ax.patches
if isinstance(patch, mpatches.PathPatch)]

add_horizontal_line_annot(ax, data, group_col, val_col, palette, horizontal_line)

return ax #input type and annotation type as well.
return ax

return wrapper