diff --git a/src/pvalplot/__init__.py b/src/pvalplot/__init__.py index d6648b6..d838a7b 100644 --- a/src/pvalplot/__init__.py +++ b/src/pvalplot/__init__.py @@ -1,2 +1,2 @@ def hello() -> str: - return "Hello from pvalplot!" + return "Hello from pvalplot!" \ No newline at end of file diff --git a/src/pvalplot/annotation_wrapper.py b/src/pvalplot/annotation_wrapper.py new file mode 100644 index 0000000..5c745b2 --- /dev/null +++ b/src/pvalplot/annotation_wrapper.py @@ -0,0 +1,81 @@ +import seaborn as sns +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches + +from .annotation_sample import add_sample_annot2 +from .annotation_hline import add_horizontal_line_annot +from .annotation_pvalues import add_pbars_annot + +def wrapped(original_func): + """Wrapper to add functionalities to default sns plot function calls. Use example: + + from matplotlib import pyplot as plt + import seaborn as sns + from pvalplot.annotation_wrapper import wrapped + + sns.boxplot = wrapped(sns.boxplot) + df = sns.load_dataset("tips") + sns.boxplot( + y="total_bill", + x="day", + data=df, + sample_size_annotation=True, + horizontal_line = "median", + pbars=dict({ + 'font_ylim': None, + 'yshift_c': None, + 'font_ndigits': 2, + 'group': None, + 'hue_col': None + }), + ) + + """ + + def wrapper(*args, **kwargs): + + # Pop out custom variables + sample_size_annotation = kwargs.pop('sample_size_annotation', None) + horizontal_line = kwargs.pop('horizontal_line', None) + pbars = kwargs.pop('pbars', None) + + # Original SNS function call + ax = original_func(*args, **kwargs) + + # Access original arguments + group_col = kwargs.get("x", args[0] if len(args) > 0 else None) + val_col = kwargs.get("y", args[1] if len(args) > 0 else None) + data = kwargs.get("data", args[2] if len(args) > 2 else None) + + # Add p value brackets + if pbars is not None: + + font_ylim = pbars.get('font_ylim', None) + yshift_c = pbars.get('yshift_c', None) + font_ndigits = pbars.get('font_ndigits', None) + group = pbars.get('group', None) + hue_col = pbars.get('hue_col', None) + + add_pbars_annot(ax, data, group_col, val_col, font_ylim, yshift_c, font_ndigits, group, hue_col) + + # Add sample size annotations + if sample_size_annotation: + ax = add_sample_annot2(ax, data, group_col) + + # Add horizontal lines + if horizontal_line is not None: + if horizontal_line == True: + horizontal_line = "mean" + + # Get colors of the box plots + palette = [ + patch.get_facecolor() + for patch in ax.patches + if isinstance(patch, mpatches.PathPatch)] + + add_horizontal_line_annot(ax, data, group_col, val_col, palette, horizontal_line) + + return ax #input type and annotation type as well. + return ax + + return wrapper \ No newline at end of file