diff --git a/pyproject.toml b/pyproject.toml index b62981a..551f41e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ classifiers = [ ] dependencies = [ "matplotlib", - "numpy", + "numpy>=1.20", ] dynamic = ["version"] diff --git a/src/daft/__init__.py b/src/daft/__init__.py index 0ae9017..979390c 100644 --- a/src/daft/__init__.py +++ b/src/daft/__init__.py @@ -2,13 +2,20 @@ from importlib.metadata import version as get_distribution -from . import _core, _exceptions, _utils -from ._core import PGM, Node, Edge, Plate, Text -from ._exceptions import SameLocationError -from ._utils import _rendering_context, _pop_multiple +from . import exceptions, node, edge, pgm, plate, types, utils +from .pgm import PGM +from .node import Node +from .edge import Edge +from .plate import Plate, Text +from .exceptions import SameLocationError +from .utils import RenderingContext, _pop_multiple __version__ = get_distribution("daft") __all__ = [] -__all__ += _core.__all__ -__all__ += _exceptions.__all__ -__all__ += _utils.__all__ +__all__ += pgm.__all__ +__all__ += node.__all__ +__all__ += edge.__all__ +__all__ += plate.__all__ +__all__ += exceptions.__all__ +__all__ += utils.__all__ +__all__ += types.__all__ diff --git a/src/daft/_core.py b/src/daft/_core.py deleted file mode 100644 index 7a6516f..0000000 --- a/src/daft/_core.py +++ /dev/null @@ -1,1200 +0,0 @@ -"""Code for Daft""" - -__all__ = ["PGM", "Node", "Edge", "Plate"] -# TODO: should Text be added? - -import matplotlib as mpl -import matplotlib.pyplot as plt -from matplotlib.patches import Ellipse -from matplotlib.patches import FancyArrow -from matplotlib.patches import Rectangle - -import numpy as np - -from ._exceptions import SameLocationError -from ._utils import _rendering_context, _pop_multiple - -# pylint: disable=too-many-arguments, protected-access, unused-argument, too-many-lines - - -class PGM: - """ - The base object for building a graphical model representation. - - :param shape: (optional) - The number of rows and columns in the grid. Will automatically - determine is not provided. - - :param origin: (optional) - The coordinates of the bottom left corner of the plot. Will - automatically determine if not provided. - - :param grid_unit: (optional) - The size of the grid spacing measured in centimeters. - - :param node_unit: (optional) - The base unit for the node size. This is a number in centimeters that - sets the default diameter of the nodes. - - :param observed_style: (optional) - How should the "observed" nodes be indicated? This must be one of: - ``"shaded"``, ``"inner"`` or ``"outer"`` where ``inner`` and - ``outer`` nodes are shown as double circles with the second circle - plotted inside or outside of the standard one, respectively. - - :param alternate_style: (optional) - How should the "alternate" nodes be indicated? This must be one of: - ``"shaded"``, ``"inner"`` or ``"outer"`` where ``inner`` and - ``outer`` nodes are shown as double circles with the second circle - plotted inside or outside of the standard one, respectively. - - :param node_ec: (optional) - The default edge color for the nodes. - - :param node_fc: (optional) - The default face color for the nodes. - - :param plate_fc: (optional) - The default face color for plates. - - :param directed: (optional) - Should the edges be directed by default? - - :param aspect: (optional) - The default aspect ratio for the nodes. - - :param label_params: (optional) - Default node label parameters. See :class:`PGM.Node` for details. - - :param dpi: (optional) - Set DPI for display and saving files. - - """ - - def __init__( - self, - shape=None, - origin=None, - grid_unit=2.0, - node_unit=1.0, - observed_style="shaded", - alternate_style="inner", - line_width=1.0, - node_ec="k", - node_fc="w", - plate_fc="w", - directed=True, - aspect=1.0, - label_params=None, - dpi=None, - ): - self._nodes = {} - self._edges = [] - self._plates = [] - self._dpi = dpi - - # if shape and origin are not given, pass a default - # and we will determine at rendering time - self.shape = shape - self.origin = origin - if shape is None: - shape = [1, 1] - if origin is None: - origin = [0, 0] - - self._ctx = _rendering_context( - shape=shape, - origin=origin, - grid_unit=grid_unit, - node_unit=node_unit, - observed_style=observed_style, - alternate_style=alternate_style, - line_width=line_width, - node_ec=node_ec, - node_fc=node_fc, - plate_fc=plate_fc, - directed=directed, - aspect=aspect, - label_params=label_params, - dpi=dpi, - ) - - def __enter__(self): - return self - - def __exit__(self, *args): - self._ctx.close() - - def add_node( - self, - node, - content="", - x=0, - y=0, - scale=1.0, - aspect=None, - observed=False, - fixed=False, - alternate=False, - offset=(0.0, 0.0), - fontsize=None, - plot_params=None, - label_params=None, - shape="ellipse", - ): - """ - Add a :class:`Node` to the model. - - :param node: - The plain-text identifier for the nodeself. - Can also be the :class:`Node` to retain backward compatibility. - - :param content: - The display form of the variable. - - :param x: - The x-coordinate of the node in *model units*. - - :param y: - The y-coordinate of the node. - - :param scale: (optional) - The diameter (or height) of the node measured in multiples of - ``node_unit`` as defined by the :class:`PGM` object. - - :param aspect: (optional) - The aspect ratio width/height for elliptical nodes; default 1. - - :param observed: (optional) - Should this be a conditioned variable? - - :param fixed: (optional) - Should this be a fixed (not permitted to vary) variable? - If `True`, modifies or over-rides ``diameter``, ``offset``, - ``facecolor``, and a few other ``plot_params`` settings. - This setting conflicts with ``observed``. - - :param alternate: (optional) - Should this use the alternate style? - - :param offset: (optional) - The ``(dx, dy)`` offset of the label (in points) from the default - centered position. - - :param fontsize: (optional) - The fontsize to use. - - :param plot_params: (optional) - A dictionary of parameters to pass to the - :class:`matplotlib.patches.Ellipse` constructor. - - :param label_params: (optional) - A dictionary of parameters to pass to the - :class:`matplotlib.text.Annotation` constructor. Any kwargs not - used by Annontation get passed to :class:`matplotlib.text.Text`. - - :param shape: (optional) - String in {ellipse (default), rectangle} - If rectangle, aspect and scale holds for rectangle - - """ - if isinstance(node, Node): - _node = node - else: - _node = Node( - node, - content, - x, - y, - scale, - aspect, - observed, - fixed, - alternate, - offset, - fontsize, - plot_params, - label_params, - shape, - ) - - self._nodes[_node.name] = _node - - return node - - def add_edge( - self, - name1, - name2, - directed=None, - xoffset=0.0, - yoffset=0.1, - label=None, - plot_params=None, - label_params=None, - **kwargs, # pylint: disable=unused-argument - ): - """ - Construct an :class:`Edge` between two named :class:`Node` objects. - - :param name1: - The name identifying the first node. - - :param name2: - The name identifying the second node. If the edge is directed, - the arrow will point to this node. - - :param directed: (optional) - Should the edge be directed from ``node1`` to ``node2``? In other - words: should it have an arrow? - - :param label: (optional) - A string to annotate the edge. - - :param xoffset: (optional) - The x-offset from the middle of the arrow to plot the label. - Only takes effect if `label` is defined in `plot_params`. - - :param yoffset: (optional) - The y-offset from the middle of the arrow to plot the label. - Only takes effect if `label` is defined in `plot_params`. - - :param plot_params: (optional) - A dictionary of parameters to pass to the - :class:`matplotlib.patches.FancyArrow` constructor. - - :param label_params: (optional) - A dictionary of parameters to pass to the - :class:`matplotlib.axes.Axes.annotate` constructor. - - """ - if directed is None: - directed = self._ctx.directed - - e = Edge( - self._nodes[name1], - self._nodes[name2], - directed=directed, - label=label, - xoffset=xoffset, - yoffset=yoffset, - plot_params=plot_params, - label_params=label_params, - ) - self._edges.append(e) - - return e - - def add_plate( - self, - plate, - label=None, - label_offset=(5, 5), - shift=0, - position="bottom left", - fontsize=None, - rect_params=None, - bbox=None, - ): - """ - Add a :class:`Plate` object to the model. - - :param plate: - The rectangle describing the plate bounds in model coordinates. - Can also be the :class:`Plate` to retain backward compatibility. - - :param label: (optional) - A string to annotate the plate. - - :param label_offset: (optional) - The x and y offsets of the label text measured in points. - - :param shift: (optional) - The vertical "shift" of the plate measured in model units. This - will move the bottom of the panel by ``shift`` units. - - :param position: (optional) - One of ``"{vertical} {horizontal}"`` where vertical is ``"bottom"`` - or ``"middle"`` or ``"top"`` and horizontal is ``"left"`` or - ``"center"`` or ``"right"``. - - :param fontsize: (optional) - The fontsize to use. - - :param rect_params: (optional) - A dictionary of parameters to pass to the - :class:`matplotlib.patches.Rectangle` constructor. - - """ - if isinstance(plate, Plate): - _plate = plate - else: - _plate = Plate( - plate, - label, - label_offset, - shift, - position, - fontsize, - rect_params, - bbox, - ) - - self._plates.append(_plate) - - def add_text(self, x, y, label, fontsize=None): - """ - A subclass of plate to writing text using grid coordinates. Any - ``**kwargs`` are passed through to :class:`PGM.Plate`. - - :param x: - The x-coordinate of the text in *model units*. - - :param y: - The y-coordinate of the text. - - :param label: - A string to write. - - :param fontsize: (optional) - The fontsize to use. - - """ - - text = Text(x=x, y=y, label=label, fontsize=fontsize) - self._plates.append(text) - - return None - - def render(self, dpi=None): - """ - Render the :class:`Plate`, :class:`Edge` and :class:`Node` objects in - the model. This will create a new figure with the correct dimensions - and plot the model in this area. - - :param dpi: (optional) - The DPI value to use for rendering. - - """ - - if dpi is None: - self._ctx.dpi = self._dpi - else: - self._ctx.dpi = dpi - - def get_max(maxsize, artist): - if isinstance(artist, Ellipse): - maxsize = np.maximum( - maxsize, - artist.center - + np.array([artist.width, artist.height]) / 2, - dtype=np.float64, - ) - elif isinstance(artist, Rectangle): - maxsize = np.maximum( - maxsize, - np.array([artist._x0, artist._y0], dtype=np.float64) - + np.array([artist._width, artist._height]), - dtype=np.float64, - ) - return maxsize - - def get_min(minsize, artist): - if isinstance(artist, Ellipse): - minsize = np.minimum( - minsize, - artist.center - - np.array([artist.width, artist.height]) / 2, - dtype=np.float64, - ) - elif isinstance(artist, Rectangle): - minsize = np.minimum( - minsize, - np.array([artist._x0, artist._y0], dtype=np.float64), - ) - return minsize - - # Auto-set shape - # We pass through each object once to find the maximum coordinates - if self.shape is None: - maxsize = np.copy(self._ctx.origin) - - for plate in self._plates: - artist = plate.render(self._ctx) - maxsize = get_max(maxsize, artist) - - for name in self._nodes: - if self._nodes[name].fixed: - self._nodes[name].offset[1] -= 12.5 - artist = self._nodes[name].render(self._ctx) - maxsize = get_max(maxsize, artist) - - self._ctx.reset_shape(maxsize) - - # Pass through each object to find the minimum coordinates - if self.origin is None: - minsize = np.copy(self._ctx.shape * self._ctx.grid_unit) - - for plate in self._plates: - artist = plate.render(self._ctx) - minsize = get_min(minsize, artist) - - for name in self._nodes: - artist = self._nodes[name].render(self._ctx) - minsize = get_min(minsize, artist) - - self._ctx.reset_origin(minsize, self.shape is None) - - # Clear the figure from rendering context - self._ctx.reset_figure() - - for plate in self._plates: - plate.render(self._ctx) - - for edge in self._edges: - edge.render(self._ctx) - - for name in self._nodes: - self._nodes[name].render(self._ctx) - - return self.ax - - @property - def figure(self): - """Figure as a property.""" - return self._ctx.figure() - - @property - def ax(self): - """Axes as a property.""" - return self._ctx.ax() - - def show(self, *args, dpi=None, **kwargs): - """ - Wrapper on :class:`PGM.render()` that calls `matplotlib.show()` - immediately after. - - :param dpi: (optional) - The DPI value to use for rendering. - - """ - - self.render(dpi=dpi) - plt.show(*args, **kwargs) - - def savefig(self, fname, *args, **kwargs): - """ - Wrapper on ``matplotlib.Figure.savefig()`` that sets default image - padding using ``bbox_inchaes = tight``. - ``*args`` and ``**kwargs`` are passed to `matplotlib.Figure.savefig()`. - - :param fname: - The filename to save as. - - :param dpi: (optional) - The DPI value to use for saving. - - """ - kwargs["bbox_inches"] = kwargs.get("bbox_inches", "tight") - kwargs["dpi"] = kwargs.get("dpi", self._dpi) - if not self.figure: - self.render() - self.figure.savefig(fname, *args, **kwargs) - - -class Node: - """ - The representation of a random variable in a :class:`PGM`. - - :param name: - The plain-text identifier for the node. - - :param content: - The display form of the variable. - - :param x: - The x-coordinate of the node in *model units*. - - :param y: - The y-coordinate of the node. - - :param scale: (optional) - The diameter (or height) of the node measured in multiples of - ``node_unit`` as defined by the :class:`PGM` object. - - :param aspect: (optional) - The aspect ratio width/height for elliptical nodes; default 1. - - :param observed: (optional) - Should this be a conditioned variable? - - :param fixed: (optional) - Should this be a fixed (not permitted to vary) variable? - If `True`, modifies or over-rides ``diameter``, ``offset``, - ``facecolor``, and a few other ``plot_params`` settings. - This setting conflicts with ``observed``. - - :param alternate: (optional) - Should this use the alternate style? - - :param offset: (optional) - The ``(dx, dy)`` offset of the label (in points) from the default - centered position. - - :param fontsize: (optional) - The fontsize to use. - - :param plot_params: (optional) - A dictionary of parameters to pass to the - :class:`matplotlib.patches.Ellipse` constructor. - - :param label_params: (optional) - A dictionary of parameters to pass to the - :class:`matplotlib.text.Annotation` constructor. Any kwargs not - used by Annontation get passed to :class:`matplotlib.text.Text`. - - :param shape: (optional) - String in {ellipse (default), rectangle} - If rectangle, aspect and scale holds for rectangle. - - """ - - def __init__( - self, - name, - content, - x, - y, - scale=1.0, - aspect=None, - observed=False, - fixed=False, - alternate=False, - offset=(0.0, 0.0), - fontsize=None, - plot_params=None, - label_params=None, - shape="ellipse", - ): - # Check Node style. - # Iterable is consumed, so first condition checks if two or more are - # true - node_style = iter((observed, alternate, fixed)) - if not ( - (any(node_style) and not any(node_style)) - or not any((observed, alternate, fixed)) - ): - msg = "A node cannot be more than one of `observed`, `fixed`, or `alternate`." - raise ValueError(msg) - - self.observed = observed - self.fixed = fixed - self.alternate = alternate - - # Metadata. - self.name = name - self.content = content - - # Coordinates and dimensions. - self.x, self.y = float(x), float(y) - self.scale = float(scale) - if self.fixed: - self.scale /= 6.0 - if aspect is not None: - self.aspect = float(aspect) - else: - self.aspect = aspect - - # Set fontsize - self.fontsize = fontsize if fontsize else mpl.rcParams["font.size"] - - # Display parameters. - self.plot_params = dict(plot_params) if plot_params else {} - - # Text parameters. - self.offset = list(offset) - self.label_params = dict(label_params) if label_params else None - - # Shape - if shape in ["ellipse", "rectangle"]: - self.shape = shape - else: - print("Warning: wrong shape value, set to ellipse instead") - self.shape = "ellipse" - - def render(self, ctx): - """ - Render the node. - - :param ctx: - The :class:`_rendering_context` object. - - """ - # Get the axes and default plotting parameters from the rendering - # context. - ax = ctx.ax() - - # Resolve the plotting parameters. - plot_params = dict(self.plot_params) - - plot_params["lw"] = _pop_multiple( - plot_params, ctx.line_width, "lw", "linewidth" - ) - - plot_params["ec"] = plot_params["edgecolor"] = _pop_multiple( - plot_params, ctx.node_ec, "ec", "edgecolor" - ) - - fc_is_set = "fc" in plot_params or "facecolor" in plot_params - plot_params["fc"] = _pop_multiple( - plot_params, ctx.node_fc, "fc", "facecolor" - ) - fc = plot_params["fc"] - - plot_params["alpha"] = plot_params.get("alpha", 1) - - # And the label parameters. - if self.label_params is None: - label_params = dict(ctx.label_params) - else: - label_params = dict(self.label_params) - - label_params["va"] = _pop_multiple( - label_params, "center", "va", "verticalalignment" - ) - - label_params["ha"] = _pop_multiple( - label_params, "center", "ha", "horizontalalignment" - ) - - # Deal with ``fixed`` nodes. - scale = self.scale - if self.fixed: - # MAGIC: These magic numbers should depend on the grid/node units. - self.offset[1] += 6 - - label_params["va"] = "baseline" - label_params.pop("verticalalignment", None) - label_params.pop("ma", None) - - if not fc_is_set: - plot_params["fc"] = "k" - - diameter = ctx.node_unit * scale - if self.aspect is not None: - aspect = self.aspect - else: - aspect = ctx.aspect - - # Set up an observed node or alternate node. Note the fc INSANITY. - if self.observed and not self.fixed: - style = ctx.observed_style - elif self.alternate and not self.fixed: - style = ctx.alternate_style - else: - style = False - - if style: - # Update the plotting parameters depending on the style of - # observed node. - h = float(diameter) - w = aspect * float(diameter) - if style == "shaded": - plot_params["fc"] = "0.7" - elif style == "outer": - h = diameter + 0.1 * diameter - w = aspect * diameter + 0.1 * diameter - elif style == "inner": - h = diameter - 0.1 * diameter - w = aspect * diameter - 0.1 * diameter - plot_params["fc"] = fc - - # Draw the background ellipse. - if self.shape == "ellipse": - bg = Ellipse( - xy=ctx.convert(self.x, self.y), - width=w, - height=h, - **plot_params, - ) - elif self.shape == "rectangle": - # Adapt to make Rectangle the same api than ellipse - wi = w - xy = ctx.convert(self.x, self.y) - xy[0] = xy[0] - wi / 2.0 - xy[1] = xy[1] - h / 2.0 - - bg = Rectangle(xy=xy, width=wi, height=h, **plot_params) - else: - # Should never append - raise ( - ValueError( - "Wrong shape in object causes an error in render" - ) - ) - - ax.add_artist(bg) - - # Reset the face color. - plot_params["fc"] = fc - - # Draw the foreground ellipse. - if not fc_is_set and not self.fixed and self.observed: - plot_params["fc"] = "none" - - if self.shape == "ellipse": - el = Ellipse( - xy=ctx.convert(self.x, self.y), - width=diameter * aspect, - height=diameter, - **plot_params, - ) - elif self.shape == "rectangle": - # Adapt to make Rectangle the same api than ellipse - wi = diameter * aspect - xy = ctx.convert(self.x, self.y) - xy[0] = xy[0] - wi / 2.0 - xy[1] = xy[1] - diameter / 2.0 - - el = Rectangle(xy=xy, width=wi, height=diameter, **plot_params) - else: - # Should never append - raise ( - ValueError("Wrong shape in object causes an error in render") - ) - - ax.add_artist(el) - - # Reset the face color. - plot_params["fc"] = fc - - # Annotate the node. - ax.annotate( - self.content, - ctx.convert(self.x, self.y), - xycoords="data", - xytext=self.offset, - textcoords="offset points", - size=self.fontsize, - **label_params, - ) - - return el - - def get_frontier_coord(self, target_xy, ctx, edge): - """ - Get the coordinates of the point of intersection between the - shape of the node and a line starting from the center of the node to an - arbitrary point. Will throw a :class:`SameLocationError` if the nodes - contain the same `x` and `y` coordinates. See the example of rectangle - below: - - .. code-block:: python - - _____________ - | | ____--X (target_node) - | __--X---- - | X-- |(return coordinate of this point) - | | - |____________| - - :target_xy: (x float, y float) - A tuple of coordinate of target node - - """ - - # Scale the coordinates appropriately. - x1, y1 = ctx.convert(self.x, self.y) - x2, y2 = target_xy[0], target_xy[1] - - # Aspect ratios. - if self.aspect is not None: - aspect = self.aspect - else: - aspect = ctx.aspect - - if self.shape == "ellipse": - # Compute the distances. - dx, dy = x2 - x1, y2 - y1 - if dx == 0.0 and dy == 0.0: - raise SameLocationError(edge) - dist1 = np.sqrt(dy * dy + dx * dx / float(aspect * aspect)) - - # Compute the fractional effect of the radii of the nodes. - alpha1 = 0.5 * ctx.node_unit * self.scale / dist1 - - # Get the coordinates of the starting position. - x0, y0 = x1 + alpha1 * dx, y1 + alpha1 * dy - - return x0, y0 - - elif self.shape == "rectangle": - dx, dy = x2 - x1, y2 - y1 - - # theta = np.angle(complex(dx, dy)) - # print(theta) - # left or right intersection - dxx1 = self.scale * aspect / 2.0 * (np.sign(dx) or 1.0) - dyy1 = ( - self.scale - * aspect - / 2.0 - * np.abs(dy / dx) - * (np.sign(dy) or 1.0) - ) - val1 = np.abs(complex(dxx1, dyy1)) - - # up or bottom intersection - dxx2 = self.scale * 0.5 * np.abs(dx / dy) * (np.sign(dx) or 1.0) - dyy2 = self.scale * 0.5 * (np.sign(dy) or 1.0) - val2 = np.abs(complex(dxx2, dyy2)) - - if val1 < val2: - return x1 + dxx1, y1 + dyy1 - else: - return x1 + dxx2, y1 + dyy2 - - else: - # Should never append - raise ValueError("Wrong shape in object causes an error") - - -class Edge: - """ - An edge between two :class:`Node` objects. - - :param node1: - The first :class:`Node`. - - :param node2: - The second :class:`Node`. The arrow will point towards this node. - - :param directed: (optional) - Should the edge be directed from ``node1`` to ``node2``? In other - words: should it have an arrow? - - :param label: (optional) - A string to annotate the edge. - - :param xoffset: (optional) - The x-offset from the middle of the arrow to plot the label. - Only takes effect if `label` is defined in `plot_params`. - - :param yoffset: (optional) - The y-offset from the middle of the arrow to plot the label. - Only takes effect if `label` is defined in `plot_params`. - - :param plot_params: (optional) - A dictionary of parameters to pass to the - :class:`matplotlib.patches.FancyArrow` constructor to adjust - edge behavior. - - :param label_params: (optional) - A dictionary of parameters to pass to the - :class:`matplotlib.axes.Axes.annotate` constructor to adjust - label behavior. - - """ - - def __init__( - self, - node1, - node2, - directed=True, - label=None, - xoffset=0, - yoffset=0.1, - plot_params=None, - label_params=None, - ): - self.node1 = node1 - self.node2 = node2 - self.directed = directed - self.label = label - self.xoffset = xoffset - self.yoffset = yoffset - self.plot_params = dict(plot_params) if plot_params else {} - self.label_params = dict(label_params) if label_params else {} - - def _get_coords(self, ctx): - """ - Get the coordinates of the line. - - :param conv: - A callable coordinate conversion. - - :returns: - * ``x0``, ``y0``: the coordinates of the start of the line. - * ``dx0``, ``dy0``: the displacement vector. - - """ - # Scale the coordinates appropriately. - x1, y1 = ctx.convert(self.node1.x, self.node1.y) - x2, y2 = ctx.convert(self.node2.x, self.node2.y) - - x3, y3 = self.node1.get_frontier_coord((x2, y2), ctx, self) - x4, y4 = self.node2.get_frontier_coord((x1, y1), ctx, self) - - return x3, y3, x4 - x3, y4 - y3 - - def render(self, ctx): - """ - Render the edge in the given axes. - - :param ctx: - The :class:`_rendering_context` object. - - """ - ax = ctx.ax() - - plot_params = self.plot_params - plot_params["linewidth"] = _pop_multiple( - plot_params, ctx.line_width, "lw", "linewidth" - ) - - plot_params["linestyle"] = plot_params.get("linestyle", "-") - - # Add edge annotation. - if self.label is not None: - x, y, dx, dy = self._get_coords(ctx) - ax.annotate( - self.label, - [x + 0.5 * dx + self.xoffset, y + 0.5 * dy + self.yoffset], - xycoords="data", - xytext=[0, 3], - textcoords="offset points", - ha="center", - va="center", - **self.label_params, - ) - - if self.directed: - plot_params["ec"] = _pop_multiple( - plot_params, "k", "ec", "edgecolor" - ) - plot_params["fc"] = _pop_multiple( - plot_params, "k", "fc", "facecolor" - ) - plot_params["head_length"] = plot_params.get("head_length", 0.25) - plot_params["head_width"] = plot_params.get("head_width", 0.1) - - # Build an arrow. - args = self._get_coords(ctx) - - # zero lengh arrow produce error - if not (args[2] == 0.0 and args[3] == 0.0): - ar = FancyArrow( - *self._get_coords(ctx), - width=0, - length_includes_head=True, - **plot_params, - ) - - # Add the arrow to the axes. - ax.add_artist(ar) - return ar - - else: - print(args[2], args[3]) - - else: - plot_params["color"] = plot_params.get("color", "k") - - # Get the right coordinates. - x, y, dx, dy = self._get_coords(ctx) - - # Plot the line. - line = ax.plot([x, x + dx], [y, y + dy], **plot_params) - return line - - -class Plate: - """ - A plate to encapsulate repeated independent processes in the model. - - :param rect: - The rectangle describing the plate bounds in model coordinates. - This is [x-start, y-start, x-length, y-length]. - - :param label: (optional) - A string to annotate the plate. - - :param label_offset: (optional) - The x- and y- offsets of the label text measured in points. - - :param shift: (optional) - The vertical "shift" of the plate measured in model units. This will - move the bottom of the panel by ``shift`` units. - - :param position: (optional) - One of ``"{vertical} {horizontal}"`` where vertical is ``"bottom"`` - or ``"middle"`` or ``"top"`` and horizontal is ``"left"`` or - ``"center"`` or ``"right"``. - - :param fontsize: (optional) - The fontsize to use. - - :param rect_params: (optional) - A dictionary of parameters to pass to the - :class:`matplotlib.patches.Rectangle` constructor, which defines - the properties of the plate. - - :param bbox: (optional) - A dictionary of parameters to pass to the - :class:`matplotlib.axes.Axes.annotate` constructor, which defines - the box drawn around the text. - - """ - - def __init__( - self, - rect, - label=None, - label_offset=(5, 5), - shift=0, - position="bottom left", - fontsize=None, - rect_params=None, - bbox=None, - ): - self.rect = rect - self.label = label - self.label_offset = label_offset - self.shift = shift - - if fontsize is not None: - self.fontsize = fontsize - else: - self.fontsize = mpl.rcParams["font.size"] - - if rect_params is not None: - self.rect_params = dict(rect_params) - else: - self.rect_params = None - - if bbox is not None: - self.bbox = dict(bbox) - - # Set the awful default blue color to transparent - if "fc" not in self.bbox.keys(): - self.bbox["fc"] = "none" - else: - self.bbox = None - - self.position = position - - def render(self, ctx): - """ - Render the plate in the given axes. - - :param ctx: - The :class:`_rendering_context` object. - - """ - ax = ctx.ax() - - shift = np.array([0, self.shift], dtype=np.float64) - rect = np.atleast_1d(self.rect) - bottom_left = ctx.convert(*(rect[:2] + shift)) - top_right = ctx.convert(*(rect[:2] + rect[2:])) - rect = np.concatenate([bottom_left, top_right - bottom_left]) - - if self.rect_params is not None: - rect_params = self.rect_params - else: - rect_params = {} - - rect_params["ec"] = _pop_multiple(rect_params, "k", "ec", "edgecolor") - rect_params["fc"] = _pop_multiple( - rect_params, ctx.plate_fc, "fc", "facecolor" - ) - rect_params["lw"] = _pop_multiple( - rect_params, ctx.line_width, "lw", "linewidth" - ) - rectangle = Rectangle(rect[:2], *rect[2:], **rect_params) - - ax.add_artist(rectangle) - - if self.label is not None: - offset = np.array(self.label_offset, dtype=np.float64) - if "left" in self.position: - position = rect[:2] - ha = "left" - elif "right" in self.position: - position = rect[:2] - position[0] += rect[2] - ha = "right" - offset[0] = -offset[0] - elif "center" in self.position: - position = rect[:2] - position[0] = rect[2] / 2 + rect[0] - ha = "center" - else: - raise RuntimeError( - f"Unknown positioning string: {self.position}" - ) - - if "bottom" in self.position: - va = "bottom" - elif "top" in self.position: - position[1] = rect[1] + rect[3] - offset[1] = -offset[1] - 0.1 - va = "top" - elif "middle" in self.position: - position[1] += rect[3] / 2 - va = "center" - else: - raise RuntimeError( - f"Unknown positioning string: {self.position}" - ) - - ax.annotate( - self.label, - xy=position, - xycoords="data", - xytext=offset, - textcoords="offset points", - size=self.fontsize, - bbox=self.bbox, - horizontalalignment=ha, - verticalalignment=va, - ) - - return rectangle - - -class Text(Plate): - """ - A subclass of plate to writing text using grid coordinates. Any **kwargs - are passed through to :class:`PGM.Plate`. - - :param x: - The x-coordinate of the text in *model units*. - - :param y: - The y-coordinate of the text. - - :param label: - A string to write. - - :param fontsize: (optional) - The fontsize to use. - - """ - - def __init__(self, x, y, label, fontsize=None): - self.rect = [x, y, 0.0, 0.0] - self.label = label - self.fontsize = fontsize - self.label_offset = [0.0, 0.0] - self.bbox = {"fc": "none", "ec": "none"} - self.rect_params = {"ec": "none"} - - super().__init__( - rect=self.rect, - label=self.label, - label_offset=self.label_offset, - fontsize=self.fontsize, - rect_params=self.rect_params, - bbox=self.bbox, - ) diff --git a/src/daft/edge.py b/src/daft/edge.py new file mode 100644 index 0000000..21ede14 --- /dev/null +++ b/src/daft/edge.py @@ -0,0 +1,167 @@ +"""Edge""" + +__all__ = ["Edge"] + + +from matplotlib.lines import Line2D +from matplotlib.patches import FancyArrow + +from typing import Any, cast, Optional, Union + +from .utils import _pop_multiple, RenderingContext +from .types import Tuple4F, PlotParams, LabelParams + + +class Edge: + """ + An edge between two :class:`Node` objects. + + :param node1: + The first :class:`Node`. + + :param node2: + The second :class:`Node`. The arrow will point towards this node. + + :param directed: (optional) + Should the edge be directed from ``node1`` to ``node2``? In other + words: should it have an arrow? + + :param label: (optional) + A string to annotate the edge. + + :param xoffset: (optional) + The x-offset from the middle of the arrow to plot the label. + Only takes effect if `label` is defined in `plot_params`. + + :param yoffset: (optional) + The y-offset from the middle of the arrow to plot the label. + Only takes effect if `label` is defined in `plot_params`. + + :param plot_params: (optional) + A dictionary of parameters to pass to the + :class:`matplotlib.patches.FancyArrow` constructor to adjust + edge behavior. + + :param label_params: (optional) + A dictionary of parameters to pass to the + :class:`matplotlib.axes.Axes.annotate` constructor to adjust + label behavior. + + """ + + def __init__( + self, + node1: "Node", + node2: "Node", + directed: bool = True, + label: Optional[str] = None, + xoffset: float = 0, + yoffset: float = 0.1, + plot_params: Optional[PlotParams] = None, + label_params: Optional[LabelParams] = None, + ) -> None: + self.node1 = node1 + self.node2 = node2 + self.directed = directed + self.label = label + self.xoffset = xoffset + self.yoffset = yoffset + self.plot_params = dict(plot_params) if plot_params else {} + self.label_params = dict(label_params) if label_params else {} + + def _get_coords(self, ctx: RenderingContext) -> Tuple4F: + """ + Get the coordinates of the line. + + :param conv: + A callable coordinate conversion. + + :returns: + * ``x0``, ``y0``: the coordinates of the start of the line. + * ``dx0``, ``dy0``: the displacement vector. + + """ + # Scale the coordinates appropriately. + x1, y1 = ctx.convert(self.node1.x, self.node1.y) + x2, y2 = ctx.convert(self.node2.x, self.node2.y) + + x3, y3 = self.node1.get_frontier_coord((x2, y2), ctx, self) + x4, y4 = self.node2.get_frontier_coord((x1, y1), ctx, self) + + return x3, y3, x4 - x3, y4 - y3 + + def render(self, ctx: RenderingContext) -> Union[FancyArrow, list[Line2D]]: + """ + Render the edge in the given axes. + + :param ctx: + The :class:`_rendering_context` object. + + """ + ax = ctx.ax() + + plot_params = self.plot_params + plot_params["linewidth"] = _pop_multiple( + plot_params, ctx.line_width, "lw", "linewidth" + ) + + plot_params["linestyle"] = plot_params.get("linestyle", "-") + + # Add edge annotation. + if self.label is not None: + x, y, dx, dy = self._get_coords(ctx) + ax.annotate( + self.label, + xy=(x + 0.5 * dx + self.xoffset, y + 0.5 * dy + self.yoffset), + xycoords="data", + xytext=(0, 3), + textcoords="offset points", + ha="center", + va="center", + **cast(dict[str, Any], self.label_params), + ) + + if self.directed: + plot_params["ec"] = _pop_multiple( + plot_params, "k", "ec", "edgecolor" + ) + plot_params["fc"] = _pop_multiple( + plot_params, "k", "fc", "facecolor" + ) + plot_params["head_length"] = plot_params.get("head_length", 0.25) + plot_params["head_width"] = plot_params.get("head_width", 0.1) + + # Build an arrow. + args = self._get_coords(ctx) + + # zero lengh arrow produce error + if not (args[2] == 0.0 and args[3] == 0.0): + ar = FancyArrow( + *self._get_coords(ctx), + width=0, + length_includes_head=True, + **cast(dict[str, Any], plot_params), + ) + + # Add the arrow to the axes. + ax.add_artist(ar) + return ar + + else: + print(args[2], args[3]) + return [] + + else: + plot_params["color"] = plot_params.get("color", "k") + + # Get the right coordinates. + x, y, dx, dy = self._get_coords(ctx) + + # Plot the line. + line = ax.plot( + (x, x + dx), (y, y + dy), **cast(dict[str, Any], plot_params) + ) + return line + + +from .node import Node diff --git a/src/daft/_exceptions.py b/src/daft/exceptions.py similarity index 86% rename from src/daft/_exceptions.py rename to src/daft/exceptions.py index 605b7f4..753e615 100644 --- a/src/daft/_exceptions.py +++ b/src/daft/exceptions.py @@ -11,9 +11,12 @@ class SameLocationError(Exception): The Edge object whose nodes are being added. """ - def __init__(self, edge): + def __init__(self, edge: "Edge") -> None: self.message = ( "Attempted to add edge between `{}` and `{}` but they " + "share the same location." ).format(edge.node1.name, edge.node2.name) super().__init__(self.message) + + +from .edge import Edge diff --git a/src/daft/node.py b/src/daft/node.py new file mode 100644 index 0000000..cf32236 --- /dev/null +++ b/src/daft/node.py @@ -0,0 +1,398 @@ +"""Node""" + +__all__ = ["Node"] + +import matplotlib as mpl +from copy import deepcopy +from matplotlib.patches import Ellipse, Rectangle + +import numpy as np + +from typing import Any, cast, Optional, Union + +from .utils import _pop_multiple, RenderingContext +from .types import Tuple2F, PlotParams, LabelParams, Shape + + +class Node: + """ + The representation of a random variable in a :class:`PGM`. + + :param name: + The plain-text identifier for the node. + + :param content: + The display form of the variable. + + :param x: + The x-coordinate of the node in *model units*. + + :param y: + The y-coordinate of the node. + + :param scale: (optional) + The diameter (or height) of the node measured in multiples of + ``node_unit`` as defined by the :class:`PGM` object. + + :param aspect: (optional) + The aspect ratio width/height for elliptical nodes; default 1. + + :param observed: (optional) + Should this be a conditioned variable? + + :param fixed: (optional) + Should this be a fixed (not permitted to vary) variable? + If `True`, modifies or over-rides ``diameter``, ``offset``, + ``facecolor``, and a few other ``plot_params`` settings. + This setting conflicts with ``observed``. + + :param alternate: (optional) + Should this use the alternate style? + + :param offset: (optional) + The ``(dx, dy)`` offset of the label (in points) from the default + centered position. + + :param fontsize: (optional) + The fontsize to use. + + :param plot_params: (optional) + A dictionary of parameters to pass to the + :class:`matplotlib.patches.Ellipse` constructor. + + :param label_params: (optional) + A dictionary of parameters to pass to the + :class:`matplotlib.text.Annotation` constructor. Any kwargs not + used by Annontation get passed to :class:`matplotlib.text.Text`. + + :param shape: (optional) + String in {ellipse (default), rectangle} + If rectangle, aspect and scale holds for rectangle. + + """ + + def __init__( + self, + name: str, + content: str, + x: float, + y: float, + scale: float = 1.0, + aspect: Optional[float] = None, + observed: bool = False, + fixed: bool = False, + alternate: bool = False, + offset: Tuple2F = (0.0, 0.0), + fontsize: Optional[float] = None, + plot_params: Optional[PlotParams] = None, + label_params: Optional[LabelParams] = None, + shape: Shape = "ellipse", + ) -> None: + # Check Node style. + # Iterable is consumed, so first condition checks if two or more are + # true + node_style = iter((observed, alternate, fixed)) + if not ( + (any(node_style) and not any(node_style)) + or not any((observed, alternate, fixed)) + ): + msg = "A node cannot be more than one of `observed`, `fixed`, or `alternate`." + raise ValueError(msg) + + self.observed = observed + self.fixed = fixed + self.alternate = alternate + + # Metadata. + self.name = name + self.content = content + + # Coordinates and dimensions. + self.x = float(x) + self.y = float(y) + self.scale = float(scale) + if self.fixed: + self.scale /= 6.0 + if aspect is not None: + self.aspect: Optional[float] = float(aspect) + else: + self.aspect = aspect + + # Set fontsize + self.fontsize = fontsize if fontsize else mpl.rcParams["font.size"] + + # Display parameters. + self.plot_params = cast( + PlotParams, dict(plot_params) if plot_params else {} + ) + + # Text parameters. + self.offset = offset + self.label_params = cast( + Optional[LabelParams], dict(label_params) if label_params else None + ) + + # Shape + if shape in ["ellipse", "rectangle"]: + self.shape = shape + else: + print("Warning: wrong shape value, set to ellipse instead") + self.shape = "ellipse" + + def render(self, ctx: RenderingContext) -> Union[Ellipse, Rectangle]: + """ + Render the node. + + :param ctx: + The :class:`_rendering_context` object. + + """ + # Get the axes and default plotting parameters from the rendering + # context. + ax = ctx.ax() + + # Resolve the plotting parameters. + plot_params = cast(PlotParams, dict(self.plot_params)) + + plot_params["lw"] = _pop_multiple( + cast(dict[str, Any], plot_params), + ctx.line_width, + "lw", + "linewidth", + ) + + plot_params["ec"] = plot_params["edgecolor"] = _pop_multiple( + cast(dict[str, Any], plot_params), ctx.node_ec, "ec", "edgecolor" + ) + + fc_is_set = "fc" in plot_params or "facecolor" in plot_params # type: ignore[unreachable] + plot_params["fc"] = _pop_multiple( + cast(dict[str, Any], plot_params), ctx.node_fc, "fc", "facecolor" + ) + fc = plot_params["fc"] + + plot_params["alpha"] = plot_params.get("alpha", 1) + + # And the label parameters. + if self.label_params is None: + label_params = deepcopy(ctx.label_params) + else: + label_params = deepcopy(self.label_params) + + label_params["va"] = _pop_multiple( + label_params, "center", "va", "verticalalignment" + ) + + label_params["ha"] = _pop_multiple( + label_params, "center", "ha", "horizontalalignment" + ) + + # Deal with ``fixed`` nodes. + scale = self.scale + if self.fixed: + # MAGIC: These magic numbers should depend on the grid/node units. + self.offset = (self.offset[0], self.offset[1] + 6) + + label_params["va"] = "baseline" + + if not fc_is_set: + plot_params["fc"] = "k" + + diameter = ctx.node_unit * scale + if self.aspect is not None: + aspect = self.aspect + else: + aspect = ctx.aspect + + # Set up an observed node or alternate node. Note the fc INSANITY. + if self.observed and not self.fixed: + style = ctx.observed_style + elif self.alternate and not self.fixed: + style = ctx.alternate_style + else: + style = "none" + + if style != "none": + # Update the plotting parameters depending on the style of + # observed node. + h = float(diameter) + w = aspect * float(diameter) + if style == "shaded": + plot_params["fc"] = "0.7" + elif style == "outer": + h = diameter + 0.1 * diameter + w = aspect * diameter + 0.1 * diameter + elif style == "inner": + h = diameter - 0.1 * diameter + w = aspect * diameter - 0.1 * diameter + plot_params["fc"] = fc + + # Draw the background ellipse. + if self.shape == "ellipse": + bg: Union[Ellipse, Rectangle] = Ellipse( + xy=ctx.convert(self.x, self.y), + width=w, + height=h, + **plot_params, + ) + elif self.shape == "rectangle": + # Adapt to make Rectangle the same api than ellipse + wi = w + x, y = ctx.convert(self.x, self.y) + x -= wi / 2.0 + y -= h / 2.0 + + bg = Rectangle( + xy=(x, y), + width=wi, + height=h, + **plot_params, + ) + else: + # Should never append + raise ( + ValueError( + "Wrong shape in object causes an error in render" + ) + ) + + ax.add_artist(bg) + + # Reset the face color. + plot_params["fc"] = fc + + # Draw the foreground ellipse. + if not fc_is_set and not self.fixed and self.observed: + plot_params["fc"] = "none" + + if self.shape == "ellipse": + el: Union[Ellipse, Rectangle] = Ellipse( + xy=ctx.convert(self.x, self.y), + width=diameter * aspect, + height=diameter, + **plot_params, + ) + elif self.shape == "rectangle": + # Adapt to make Rectangle the same api than ellipse + wi = diameter * aspect + x, y = ctx.convert(self.x, self.y) + x -= wi / 2.0 + y -= diameter / 2.0 + + el = Rectangle( + xy=(x, y), + width=wi, + height=diameter, + **plot_params, + ) + else: + # Should never append + raise ( + ValueError("Wrong shape in object causes an error in render") + ) + + ax.add_artist(el) + + # Reset the face color. + plot_params["fc"] = fc + + # pop extra params + _label_params = cast(dict[str, Any], label_params) + _label_params.pop("verticalalignment", None) + _label_params.pop("ma", None) + + # Annotate the node. + ax.annotate( + self.content, + ctx.convert(self.x, self.y), + xycoords="data", + xytext=self.offset, + textcoords="offset points", + size=self.fontsize, + **_label_params, + ) + + return el + + def get_frontier_coord( + self, target_xy: Tuple2F, ctx: RenderingContext, edge: "Edge" + ) -> Tuple2F: + """ + Get the coordinates of the point of intersection between the + shape of the node and a line starting from the center of the node to an + arbitrary point. Will throw a :class:`SameLocationError` if the nodes + contain the same `x` and `y` coordinates. See the example of rectangle + below: + + .. code-block:: python + + _____________ + | | ____--X (target_node) + | __--X---- + | X-- |(return coordinate of this point) + | | + |____________| + + :target_xy: (x float, y float) + A tuple of coordinate of target node + + """ + + # Scale the coordinates appropriately. + x1, y1 = ctx.convert(self.x, self.y) + x2, y2 = target_xy[0], target_xy[1] + + # Aspect ratios. + if self.aspect is not None: + aspect = self.aspect + else: + aspect = ctx.aspect + + if self.shape == "ellipse": + # Compute the distances. + dx, dy = x2 - x1, y2 - y1 + if dx == 0.0 and dy == 0.0: + raise SameLocationError(edge) + dist1 = np.sqrt(dy * dy + dx * dx / float(aspect * aspect)) + + # Compute the fractional effect of the radii of the nodes. + alpha1 = 0.5 * ctx.node_unit * self.scale / dist1 + + # Get the coordinates of the starting position. + x0, y0 = x1 + alpha1 * dx, y1 + alpha1 * dy + + return x0, y0 + + elif self.shape == "rectangle": + dx, dy = x2 - x1, y2 - y1 + + # theta = np.angle(complex(dx, dy)) + # print(theta) + # left or right intersection + dxx1 = self.scale * aspect / 2.0 * (np.sign(dx) or 1.0) + dyy1 = ( + self.scale + * aspect + / 2.0 + * np.abs(dy / dx) + * (np.sign(dy) or 1.0) + ) + val1 = np.abs(complex(dxx1, dyy1)) + + # up or bottom intersection + dxx2 = self.scale * 0.5 * np.abs(dx / dy) * (np.sign(dx) or 1.0) + dyy2 = self.scale * 0.5 * (np.sign(dy) or 1.0) + val2 = np.abs(complex(dxx2, dyy2)) + + if val1 < val2: + return x1 + dxx1, y1 + dyy1 + else: + return x1 + dxx2, y1 + dyy2 + + else: + # Should never append + raise ValueError("Wrong shape in object causes an error") + + +from .edge import Edge +from .exceptions import SameLocationError diff --git a/src/daft/pgm.py b/src/daft/pgm.py new file mode 100644 index 0000000..0b802f8 --- /dev/null +++ b/src/daft/pgm.py @@ -0,0 +1,529 @@ +"""Code for Daft""" + +__all__ = ["PGM"] +# TODO: should Text be added? + +import matplotlib.pyplot as plt +from matplotlib.patches import Ellipse, Rectangle + +import numpy as np + +from typing import Any, Optional, Union + +from .node import Node +from .edge import Edge +from .plate import Plate, Text +from .utils import RenderingContext +from .types import ( + Tuple2F, + NDArrayF, + Shape, + Position, + CTX_Kwargs, + PlotParams, + LabelParams, + RectParams, +) + +# pylint: disable=too-many-arguments, protected-access, unused-argument, too-many-lines + + +class PGM: + """ + The base object for building a graphical model representation. + + :param shape: (optional) + The number of rows and columns in the grid. Will automatically + determine is not provided. + + :param origin: (optional) + The coordinates of the bottom left corner of the plot. Will + automatically determine if not provided. + + :param grid_unit: (optional) + The size of the grid spacing measured in centimeters. + + :param node_unit: (optional) + The base unit for the node size. This is a number in centimeters that + sets the default diameter of the nodes. + + :param observed_style: (optional) + How should the "observed" nodes be indicated? This must be one of: + ``"shaded"``, ``"inner"`` or ``"outer"`` where ``inner`` and + ``outer`` nodes are shown as double circles with the second circle + plotted inside or outside of the standard one, respectively. + + :param alternate_style: (optional) + How should the "alternate" nodes be indicated? This must be one of: + ``"shaded"``, ``"inner"`` or ``"outer"`` where ``inner`` and + ``outer`` nodes are shown as double circles with the second circle + plotted inside or outside of the standard one, respectively. + + :param node_ec: (optional) + The default edge color for the nodes. + + :param node_fc: (optional) + The default face color for the nodes. + + :param plate_fc: (optional) + The default face color for plates. + + :param directed: (optional) + Should the edges be directed by default? + + :param aspect: (optional) + The default aspect ratio for the nodes. + + :param label_params: (optional) + Default node label parameters. See :class:`PGM.Node` for details. + + :param dpi: (optional) + Set DPI for display and saving files. + + """ + + def __init__( + self, + shape: Optional[Tuple2F] = None, + origin: Optional[Tuple2F] = None, + grid_unit: float = 2.0, + node_unit: float = 1.0, + observed_style: str = "shaded", + alternate_style: str = "inner", + line_width: float = 1.0, + node_ec: str = "k", + node_fc: str = "w", + plate_fc: str = "w", + directed: bool = True, + aspect: float = 1.0, + label_params: Optional[LabelParams] = None, + dpi: Optional[int] = None, + ) -> None: + self._nodes: dict[str, Node] = {} + self._edges: list[Edge] = [] + self._plates: list[Plate] = [] + self._dpi = dpi + + # if shape and origin are not given, pass a default + # and we will determine at rendering time + if shape is None: + _shape: Tuple2F = (1, 1) + self.shape = None + else: + _shape = shape + self.shape = tuple(shape) + + if origin is None: + _origin: Tuple2F = (0, 0) + self.origin = None + else: + _origin = origin + self.origin = tuple(origin) + + self._ctx = RenderingContext( + CTX_Kwargs( + shape=np.asarray(_shape, dtype=np.float64), + origin=np.asarray(_origin, dtype=np.float64), + grid_unit=grid_unit, + node_unit=node_unit, + observed_style=observed_style, + alternate_style=alternate_style, + line_width=line_width, + node_ec=node_ec, + node_fc=node_fc, + plate_fc=plate_fc, + directed=directed, + aspect=aspect, + label_params=label_params, + dpi=dpi, + ) + ) + + def __enter__(self) -> "PGM": + return self + + def __exit__(self, *args: Any) -> None: + self._ctx.close() + + def add_node( + self, + node: Node, + content: str = "", + x: float = 0, + y: float = 0, + scale: float = 1.0, + aspect: Optional[float] = None, + observed: bool = False, + fixed: bool = False, + alternate: bool = False, + offset: Tuple2F = (0, 0), + fontsize: Optional[float] = None, + plot_params: Optional[PlotParams] = None, + label_params: Optional[LabelParams] = None, + shape: Shape = "ellipse", + ) -> Node: + """ + Add a :class:`Node` to the model. + + :param node: + The plain-text identifier for the nodeself. + Can also be the :class:`Node` to retain backward compatibility. + + :param content: + The display form of the variable. + + :param x: + The x-coordinate of the node in *model units*. + + :param y: + The y-coordinate of the node. + + :param scale: (optional) + The diameter (or height) of the node measured in multiples of + ``node_unit`` as defined by the :class:`PGM` object. + + :param aspect: (optional) + The aspect ratio width/height for elliptical nodes; default 1. + + :param observed: (optional) + Should this be a conditioned variable? + + :param fixed: (optional) + Should this be a fixed (not permitted to vary) variable? + If `True`, modifies or over-rides ``diameter``, ``offset``, + ``facecolor``, and a few other ``plot_params`` settings. + This setting conflicts with ``observed``. + + :param alternate: (optional) + Should this use the alternate style? + + :param offset: (optional) + The ``(dx, dy)`` offset of the label (in points) from the default + centered position. + + :param fontsize: (optional) + The fontsize to use. + + :param plot_params: (optional) + A dictionary of parameters to pass to the + :class:`matplotlib.patches.Ellipse` constructor. + + :param label_params: (optional) + A dictionary of parameters to pass to the + :class:`matplotlib.text.Annotation` constructor. Any kwargs not + used by Annontation get passed to :class:`matplotlib.text.Text`. + + :param shape: (optional) + String in {ellipse (default), rectangle} + If rectangle, aspect and scale holds for rectangle + + """ + if isinstance(node, Node): + _node = node + else: + _node = Node( # type: ignore[unreachable] + node, + content, + x, + y, + scale, + aspect, + observed, + fixed, + alternate, + offset, + fontsize, + plot_params, + label_params, + shape, + ) + + self._nodes[_node.name] = _node + + return node + + def add_edge( + self, + name1: str, + name2: str, + directed: Optional[bool] = None, + xoffset: float = 0.0, + yoffset: float = 0.1, + label: Optional[str] = None, + plot_params: Optional[PlotParams] = None, + label_params: Optional[LabelParams] = None, + **kwargs: dict[str, Any], # pylint: disable=unused-argument + ) -> Edge: + """ + Construct an :class:`Edge` between two named :class:`Node` objects. + + :param name1: + The name identifying the first node. + + :param name2: + The name identifying the second node. If the edge is directed, + the arrow will point to this node. + + :param directed: (optional) + Should the edge be directed from ``node1`` to ``node2``? In other + words: should it have an arrow? + + :param label: (optional) + A string to annotate the edge. + + :param xoffset: (optional) + The x-offset from the middle of the arrow to plot the label. + Only takes effect if `label` is defined in `plot_params`. + + :param yoffset: (optional) + The y-offset from the middle of the arrow to plot the label. + Only takes effect if `label` is defined in `plot_params`. + + :param plot_params: (optional) + A dictionary of parameters to pass to the + :class:`matplotlib.patches.FancyArrow` constructor. + + :param label_params: (optional) + A dictionary of parameters to pass to the + :class:`matplotlib.axes.Axes.annotate` constructor. + + """ + if directed is None: + directed = self._ctx.directed + + e = Edge( + self._nodes[name1], + self._nodes[name2], + directed=directed, + label=label, + xoffset=xoffset, + yoffset=yoffset, + plot_params=plot_params, + label_params=label_params, + ) + self._edges.append(e) + + return e + + def add_plate( + self, + plate: Plate, + label: Optional[str] = None, + label_offset: Tuple2F = (5, 5), + shift: float = 0, + position: Position = "bottom left", + fontsize: Optional[float] = None, + rect_params: Optional[RectParams] = None, + bbox: Optional[bool] = None, + ) -> None: + """ + Add a :class:`Plate` object to the model. + + :param plate: + The rectangle describing the plate bounds in model coordinates. + Can also be the :class:`Plate` to retain backward compatibility. + + :param label: (optional) + A string to annotate the plate. + + :param label_offset: (optional) + The x and y offsets of the label text measured in points. + + :param shift: (optional) + The vertical "shift" of the plate measured in model units. This + will move the bottom of the panel by ``shift`` units. + + :param position: (optional) + One of ``"{vertical} {horizontal}"`` where vertical is ``"bottom"`` + or ``"middle"`` or ``"top"`` and horizontal is ``"left"`` or + ``"center"`` or ``"right"``. + + :param fontsize: (optional) + The fontsize to use. + + :param rect_params: (optional) + A dictionary of parameters to pass to the + :class:`matplotlib.patches.Rectangle` constructor. + + """ + if isinstance(plate, Plate): + _plate = plate + else: + _plate = Plate( # type: ignore[unreachable] + plate, + label, + label_offset, + shift, + position, + fontsize, + rect_params, + bbox, + ) + + self._plates.append(_plate) + + def add_text( + self, x: float, y: float, label: str, fontsize: Optional[float] = None + ) -> None: + """ + A subclass of plate to writing text using grid coordinates. Any + ``**kwargs`` are passed through to :class:`PGM.Plate`. + + :param x: + The x-coordinate of the text in *model units*. + + :param y: + The y-coordinate of the text. + + :param label: + A string to write. + + :param fontsize: (optional) + The fontsize to use. + + """ + + text = Text(x=x, y=y, label=label, fontsize=fontsize) + self._plates.append(text) + + return None + + def render(self, dpi: Optional[int] = None) -> plt.Axes: + """ + Render the :class:`Plate`, :class:`Edge` and :class:`Node` objects in + the model. This will create a new figure with the correct dimensions + and plot the model in this area. + + :param dpi: (optional) + The DPI value to use for rendering. + + """ + + if dpi is None: + self._ctx.dpi = self._dpi + else: + self._ctx.dpi = dpi + + def get_max( + maxsize: NDArrayF, patch: Union[Ellipse, Rectangle] + ) -> NDArrayF: + if isinstance(patch, Ellipse): + maxsize = np.maximum( + maxsize, + patch.center + np.array([patch.width, patch.height]) / 2, + dtype=np.float64, + ) + elif isinstance(patch, Rectangle): + maxsize = np.maximum( + maxsize, + np.array([patch._x0, patch._y0], dtype=np.float64) # type: ignore[attr-defined] + + np.array([patch._width, patch._height]), # type: ignore[attr-defined] + dtype=np.float64, + ) + return maxsize + + def get_min( + minsize: NDArrayF, patch: Union[Ellipse, Rectangle] + ) -> NDArrayF: + if isinstance(patch, Ellipse): + minsize = np.minimum( + minsize, + patch.center - np.array([patch.width, patch.height]) / 2, + dtype=np.float64, + ) + elif isinstance(patch, Rectangle): + minsize = np.minimum( + minsize, + np.array([patch._x0, patch._y0], dtype=np.float64), # type: ignore[attr-defined] + ) + return minsize + + # Auto-set shape + # We pass through each object once to find the maximum coordinates + if self.shape is None: + maxsize = np.copy(self._ctx.origin) + + for plate in self._plates: + artist: Union[Ellipse, Rectangle] = plate.render(self._ctx) + maxsize = get_max(maxsize, artist) + + for name in self._nodes: + if self._nodes[name].fixed: + offx, offy = self._nodes[name].offset + offy -= 12.5 + self._nodes[name].offset = (offx, offy) + + artist = self._nodes[name].render(self._ctx) + maxsize = get_max(maxsize, artist) + + self._ctx.reset_shape(maxsize) + + # Pass through each object to find the minimum coordinates + if self.origin is None: + minsize = np.copy(self._ctx.shape * self._ctx.grid_unit) + + for plate in self._plates: + artist = plate.render(self._ctx) + minsize = get_min(minsize, artist) + + for name in self._nodes: + artist = self._nodes[name].render(self._ctx) + minsize = get_min(minsize, artist) + + self._ctx.reset_origin(minsize, self.shape is None) + + # Clear the figure from rendering context + self._ctx.reset_figure() + + for plate in self._plates: + plate.render(self._ctx) + + for edge in self._edges: + edge.render(self._ctx) + + for name in self._nodes: + self._nodes[name].render(self._ctx) + + return self.ax + + @property + def figure(self) -> plt.Figure: + """Figure as a property.""" + return self._ctx.figure() + + @property + def ax(self) -> plt.Axes: + """Axes as a property.""" + return self._ctx.ax() + + def show( + self, *args: Any, dpi: Optional[int] = None, **kwargs: Any + ) -> None: + """ + Wrapper on :class:`PGM.render()` that calls `matplotlib.show()` + immediately after. + + :param dpi: (optional) + The DPI value to use for rendering. + + """ + + self.render(dpi=dpi) + plt.show(*args, **kwargs) + + def savefig(self, fname: str, *args: Any, **kwargs: Any) -> None: + """ + Wrapper on ``matplotlib.Figure.savefig()`` that sets default image + padding using ``bbox_inchaes = tight``. + ``*args`` and ``**kwargs`` are passed to `matplotlib.Figure.savefig()`. + + :param fname: + The filename to save as. + + :param dpi: (optional) + The DPI value to use for saving. + + """ + kwargs["bbox_inches"] = kwargs.get("bbox_inches", "tight") + kwargs["dpi"] = kwargs.get("dpi", self._dpi) + self.figure.savefig(fname, *args, **kwargs) diff --git a/src/daft/plate.py b/src/daft/plate.py new file mode 100644 index 0000000..bf4e311 --- /dev/null +++ b/src/daft/plate.py @@ -0,0 +1,218 @@ +"""Daft errors""" + +__all__: list[str] = [] + + +import matplotlib as mpl +from matplotlib.patches import Rectangle + +import numpy as np + +from typing import Any, cast, Optional + +from .utils import _pop_multiple, RenderingContext +from .types import Tuple2F, Tuple4F, Position, RectParams + + +class Plate: + """ + A plate to encapsulate repeated independent processes in the model. + + :param rect: + The rectangle describing the plate bounds in model coordinates. + This is [x-start, y-start, x-length, y-length]. + + :param label: (optional) + A string to annotate the plate. + + :param label_offset: (optional) + The x- and y- offsets of the label text measured in points. + + :param shift: (optional) + The vertical "shift" of the plate measured in model units. This will + move the bottom of the panel by ``shift`` units. + + :param position: (optional) + One of ``"{vertical} {horizontal}"`` where vertical is ``"bottom"`` + or ``"middle"`` or ``"top"`` and horizontal is ``"left"`` or + ``"center"`` or ``"right"``. + + :param fontsize: (optional) + The fontsize to use. + + :param rect_params: (optional) + A dictionary of parameters to pass to the + :class:`matplotlib.patches.Rectangle` constructor, which defines + the properties of the plate. + + :param bbox: (optional) + A dictionary of parameters to pass to the + :class:`matplotlib.axes.Axes.annotate` constructor, which defines + the box drawn around the text. + + """ + + def __init__( + self, + rect: Tuple4F, + label: Optional[str] = None, + label_offset: Tuple2F = (5, 5), + shift: float = 0, + position: Position = "bottom left", + fontsize: Optional[float] = None, + rect_params: Optional[RectParams] = None, + bbox: dict[str, Optional[Any]] = None, + ) -> None: + self.rect = rect + self.label = label + self.label_offset = label_offset + self.shift = shift + + if fontsize is not None: + self.fontsize: Optional[float] = fontsize + else: + self.fontsize = mpl.rcParams["font.size"] + + if rect_params is not None: + self.rect_params: Optional[RectParams] = rect_params + else: + self.rect_params = None + + if bbox is not None: + self.bbox: dict[str, Optional[Any]] = dict(bbox) + + # Set the awful default blue color to transparent + if "fc" not in self.bbox.keys(): + self.bbox["fc"] = "none" + else: + self.bbox = None + + self.position = position + + def render(self, ctx: RenderingContext) -> Rectangle: + """ + Render the plate in the given axes. + + :param ctx: + The :class:`_rendering_context` object. + + """ + ax = ctx.ax() + + shift = np.array([0, self.shift], dtype=np.float64) + rect = np.atleast_1d(np.asarray(self.rect, dtype=np.float64)) + bottom_left = np.asarray( + ctx.convert(*(rect[:2] + shift)), dtype=np.float64 + ) + top_right = np.asarray( + ctx.convert(*(rect[:2] + rect[2:])), dtype=np.float64 + ) + rect = np.concatenate([bottom_left, top_right - bottom_left]) + + if self.rect_params is not None: + rect_params = self.rect_params + else: + rect_params = {} + + rect_params["ec"] = _pop_multiple(rect_params, "k", "ec", "edgecolor") + rect_params["fc"] = _pop_multiple( + rect_params, ctx.plate_fc, "fc", "facecolor" + ) + rect_params["lw"] = _pop_multiple( + rect_params, ctx.line_width, "lw", "linewidth" + ) + + rectangle = Rectangle( + xy=(rect[0], rect[1]), width=rect[2], height=rect[3], **rect_params + ) + + ax.add_artist(rectangle) + + if self.label is not None: + offset = np.array(self.label_offset, dtype=np.float64) + if "left" in self.position: + position = rect[:2] + ha = "left" + elif "right" in self.position: + position = rect[:2] + position[0] += rect[2] + ha = "right" + offset[0] = -offset[0] + elif "center" in self.position: + position = rect[:2] + position[0] = rect[2] / 2 + rect[0] + ha = "center" + else: + raise RuntimeError( + f"Unknown positioning string: {self.position}" + ) + + if "bottom" in self.position: + va = "bottom" + elif "top" in self.position: + position[1] = rect[1] + rect[3] + offset[1] = -offset[1] - 0.1 + va = "top" + elif "middle" in self.position: + position[1] += rect[3] / 2 + va = "center" + else: + raise RuntimeError( + f"Unknown positioning string: {self.position}" + ) + + posx, posy = position + offx, offy = offset + + ax.annotate( + self.label, + xy=(posx, posy), + xycoords="data", + xytext=(offx, offy), + textcoords="offset points", + size=self.fontsize, + bbox=self.bbox, + horizontalalignment=ha, + verticalalignment=va, + ) + + return rectangle + + +class Text(Plate): + """ + A subclass of plate to writing text using grid coordinates. Any **kwargs + are passed through to :class:`PGM.Plate`. + + :param x: + The x-coordinate of the text in *model units*. + + :param y: + The y-coordinate of the text. + + :param label: + A string to write. + + :param fontsize: (optional) + The fontsize to use. + + """ + + def __init__( + self, x: float, y: float, label: str, fontsize: Optional[float] = None + ) -> None: + self.rect = (x, y, 0.0, 0.0) + self.label = label + self.fontsize = fontsize + self.label_offset = (0.0, 0.0) + self.rect_params = cast(RectParams, {"ec": "none"}) + self.bbox = {"fc": "none", "ec": "none"} + + super().__init__( + rect=self.rect, + label=self.label, + label_offset=self.label_offset, + fontsize=self.fontsize, + rect_params=self.rect_params, + bbox=self.bbox, + ) diff --git a/src/daft/types.py b/src/daft/types.py new file mode 100644 index 0000000..dd8f932 --- /dev/null +++ b/src/daft/types.py @@ -0,0 +1,81 @@ +"""Daft types""" + +__all__: list[str] = [] + +import numpy as np +from numpy.typing import NDArray +from typing import Any, Literal, TypeVar, TypedDict, Optional, Union + +T = TypeVar("T") + +NDArrayF = NDArray[np.float64] +NDArrayI = NDArray[np.int64] + +Tuple2 = tuple[T, T] +Tuple4 = tuple[T, T, T, T] + +Tuple2F = Tuple2[float] +Tuple4F = Tuple4[float] + +Shape = Literal["ellipse", "rectangle"] + +Position = Literal[ + "bottom left", + "bottom center", + "bottom right", + "middle left", + "middle center", + "middle right", + "top left", + "top center", + "top right", +] + + +class PlotParams(TypedDict): + lw: str + linewidth: str + ec: str + edgecolor: str + fc: str + facecolor: str + alpha: float + + +class LabelParams(TypedDict): + va: str + verticalalignment: str + ha: str + horizontalalignment: str + ma: str + + +class RectParams(TypedDict, total=False): + ec: str + edgecolor: str + fc: str + facecolor: str + lw: str + linewidth: str + + +class CTX_Kwargs(TypedDict): + shape: NDArrayF + origin: NDArrayF + grid_unit: float + node_unit: float + observed_style: str + alternate_style: str + line_width: float + node_ec: str + node_fc: str + plate_fc: str + directed: bool + aspect: float + label_params: Optional[LabelParams] + dpi: Optional[int] + + +AnyDict = Union[ + dict[str, Any], PlotParams, LabelParams, RectParams, CTX_Kwargs +] diff --git a/src/daft/_utils.py b/src/daft/utils.py similarity index 83% rename from src/daft/_utils.py rename to src/daft/utils.py index 7d41243..7bbe520 100644 --- a/src/daft/_utils.py +++ b/src/daft/utils.py @@ -5,8 +5,12 @@ import matplotlib.pyplot as plt import numpy as np +from typing import Any, cast, Optional -class _rendering_context: +from .types import NDArrayF, CTX_Kwargs, LabelParams, AnyDict + + +class RenderingContext: """ :param shape: The number of rows and columns in the grid. @@ -56,7 +60,7 @@ class _rendering_context: """ - def __init__(self, **kwargs): + def __init__(self, kwargs: CTX_Kwargs) -> None: # Save the style defaults. self.line_width = kwargs.get("line_width", 1.0) @@ -81,8 +85,8 @@ def __init__(self, **kwargs): self.padding = 0.1 self.shp_fig_scale = 2.54 - self.shape = np.array(kwargs.get("shape", [1, 1]), dtype=np.float64) - self.origin = np.array(kwargs.get("origin", [0, 0]), dtype=np.float64) + self.shape = kwargs.get("shape", np.atleast_1d((1, 1))) + self.origin = kwargs.get("origin", (0, 0)) self.grid_unit = kwargs.get("grid_unit", 2.0) self.figsize = self.grid_unit * self.shape / self.shp_fig_scale @@ -92,22 +96,24 @@ def __init__(self, **kwargs): self.plate_fc = kwargs.get("plate_fc", "w") self.directed = kwargs.get("directed", True) self.aspect = kwargs.get("aspect", 1.0) - self.label_params = dict(kwargs.get("label_params", {}) or {}) + self.label_params = cast( + LabelParams, kwargs.get("label_params", {}) or {} + ) self.dpi = kwargs.get("dpi", None) # Initialize the figure to ``None`` to handle caching later. - self._figure = None - self._ax = None + self._figure: Optional[plt.Figure] = None + self._ax: Optional[plt.Axis] = None - def reset_shape(self, shape, adj_origin=False): + def reset_shape(self, shape: NDArrayF, adj_origin: bool = False) -> None: """Reset the shape and figure size.""" # shape is scaled by grid_unit # so divide by grid_unit for proper shape self.shape = shape / self.grid_unit + self.padding self.figsize = self.grid_unit * self.shape / self.shp_fig_scale - def reset_origin(self, origin, adj_shape=False): + def reset_origin(self, origin: NDArrayF, adj_shape: bool = False) -> None: """Reset the origin.""" # origin is scaled by grid_unit # so divide by grid_unit for proper shape @@ -116,28 +122,28 @@ def reset_origin(self, origin, adj_shape=False): self.shape -= self.origin self.figsize = self.grid_unit * self.shape / self.shp_fig_scale - def reset_figure(self): + def reset_figure(self) -> None: """Reset the figure.""" self.close() - def close(self): + def close(self) -> None: """Close the figure if it is set up.""" if self._figure is not None: plt.close(self._figure) self._figure = None self._ax = None - def figure(self): + def figure(self) -> plt.Figure: """Return the current figure else create a new one.""" if self._figure is not None: return self._figure - args = {"figsize": self.figsize} + args: dict[str, Any] = {"figsize": self.figsize} if self.dpi is not None: args["dpi"] = self.dpi self._figure = plt.figure(**args) return self._figure - def ax(self): + def ax(self) -> plt.Axes: """Return the current axes else create a new one.""" if self._ax is not None: return self._ax @@ -155,19 +161,17 @@ def ax(self): return self._ax - def convert(self, *xy): + def convert(self, x: float, y: float) -> tuple[float, float]: """ Convert from model coordinates to plot coordinates. """ - if len(xy) != 2: - raise ValueError( - "You must provide two coordinates to `convert()`." - ) - return self.grid_unit * (np.atleast_1d(xy) - self.origin) + return self.grid_unit * (x - self.origin[0]), self.grid_unit * ( + y - self.origin[1] + ) -def _pop_multiple(_dict, default, *args): +def _pop_multiple(_dict: AnyDict, default: Any, *args: str) -> Any: """ A helper function for dealing with the way that matplotlib annoyingly allows multiple keyword arguments. For example, ``edgecolor`` and ``ec`` @@ -190,10 +194,10 @@ def _pop_multiple(_dict, default, *args): if len(args) == 0: raise ValueError("You must provide at least one argument to `pop()`.") - results = [] + results: list[Any] = [] for arg in args: try: - results.append((arg, _dict.pop(arg))) + results.append((arg, _dict.pop(arg))) # type: ignore[misc] except KeyError: pass diff --git a/test/baseline_images/test_examples/wet_grass.png b/test/baseline_images/test_examples/wet_grass.png new file mode 100644 index 0000000..2f55464 Binary files /dev/null and b/test/baseline_images/test_examples/wet_grass.png differ diff --git a/test/example.py b/test/example.py new file mode 100644 index 0000000..dad82f3 --- /dev/null +++ b/test/example.py @@ -0,0 +1,46 @@ +import daft +import matplotlib as mpl +import matplotlib.pyplot as plt + +# Colors. +no_circle = {"ec": "#fff"} +pgm = daft.PGM(grid_unit=1.5, node_unit=1, dpi=200) + +# x_offset, y_offset = -0.5, 0.5 # adjust the coordinates of the nodes +x_offset, y_offset = 0, 0 +pgm.add_node("z1", r"$z_1$", 1 + x_offset, 1 + y_offset) +pgm.add_node("z2", r"$z_2$", 2 + x_offset, 1 + y_offset) +pgm.add_node( + "z...", r"$\cdots$", 3 + x_offset, 1 + y_offset, plot_params=no_circle +) +pgm.add_node("zT", r"$z_T$", 4 + x_offset, 1 + y_offset) +pgm.add_node("y1", r"$x_1$", 1 + x_offset, 2.3 + y_offset) +pgm.add_node("y2", r"$x_2$", 2 + x_offset, 2.3 + y_offset) +pgm.add_node( + "y...", r"$\cdots$", 3 + x_offset, 2.3 + y_offset, plot_params=no_circle +) +pgm.add_node("yT", r"$x_T$", 4 + x_offset, 2.3 + y_offset) +pgm.add_node("x1", r"$x_1$", 1 + x_offset, y_offset) +pgm.add_node("x2", r"$x_2$", 2 + x_offset, y_offset) +pgm.add_node( + "x...", r"$\cdots$", 3 + x_offset, y_offset, plot_params=no_circle +) +pgm.add_node("xT", r"$x_T$", 4 + x_offset, y_offset) + +# Edges. +# pgm.add_edge("z1", "z2") +# pgm.add_edge("z2", "z...") +# pgm.add_edge("z...", "zT") +pgm.add_edge("x1", "z1") +pgm.add_edge("x2", "z2") +pgm.add_edge("xT", "zT") +pgm.add_edge("z1", "y2") +pgm.add_edge("z1", "yT") +pgm.add_edge("z2", "yT") + +# Render and save. +# pgm.render() +# plt.tight_layout() +# plt.show() +pgm.show() +# pgm.figure.savefig("../src/nar.png") diff --git a/test/test_daft.py b/test/test_daft.py index 2f31b02..be6e97c 100644 --- a/test/test_daft.py +++ b/test/test_daft.py @@ -63,7 +63,7 @@ def test_add_text(): pgm.add_text(x=0, y=0, label="text1") plate = pgm._plates[0] - assert plate.rect == [0, 0, 0, 0] + assert plate.rect == (0, 0, 0, 0) assert plate.label == "text1" diff --git a/test/test_examples.py b/test/test_examples.py index e4f5b02..4dc53e6 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -3,9 +3,11 @@ import daft from matplotlib.testing.decorators import image_comparison +from typing import Any + @image_comparison(baseline_images=["bca"], extensions=["png"]) -def test_bca(): +def test_bca() -> None: pgm = daft.PGM() pgm.add_node("a", r"$a$", 1, 5) pgm.add_node("b", r"$b$", 1, 4) @@ -17,7 +19,7 @@ def test_bca(): @image_comparison(baseline_images=["classic"], extensions=["png"]) -def test_classic(): +def test_classic() -> None: pgm = daft.PGM() # Hierarchical parameters. @@ -43,7 +45,7 @@ def test_classic(): @image_comparison(baseline_images=["deconvolution"], extensions=["png"]) -def test_deconvolution(): +def test_deconvolution() -> None: scale = 1.6 pgm = daft.PGM() @@ -151,7 +153,7 @@ def test_deconvolution(): @image_comparison(baseline_images=["exoplanets"], extensions=["png"]) -def test_exoplanets(): +def test_exoplanets() -> None: # Colors. p_color = {"ec": "#46a546"} s_color = {"ec": "#f89406"} @@ -190,7 +192,7 @@ def test_exoplanets(): @image_comparison(baseline_images=["fixed"], extensions=["png"]) -def test_fixed(): +def test_fixed() -> None: pgm = daft.PGM(aspect=1.5, node_unit=1.75) pgm.add_node("unobs", r"Unobserved!", 1, 4) pgm.add_node("obs", r"Observed!", 1, 3, observed=True) @@ -202,7 +204,7 @@ def test_fixed(): @image_comparison(baseline_images=["gaia"], extensions=["png"]) -def test_gaia(): +def test_gaia() -> None: pgm = daft.PGM() pgm.add_node("omega", r"$\omega$", 2, 5) pgm.add_node("true", r"$\tilde{X}_n$", 2, 4) @@ -220,7 +222,7 @@ def test_gaia(): @image_comparison(baseline_images=["galex"], extensions=["png"]) -def test_galex(): +def test_galex() -> None: pgm = daft.PGM() wide = 1.5 verywide = 1.5 * wide @@ -303,7 +305,7 @@ def test_galex(): @image_comparison(baseline_images=["huey_p_newton"], extensions=["png"]) -def test_huey_p_newton(): +def test_huey_p_newton() -> None: pgm = daft.PGM() kx, ky = 1.5, 1.0 @@ -338,7 +340,7 @@ def test_huey_p_newton(): @image_comparison(baseline_images=["logo"], extensions=["png"]) -def test_logo(): +def test_logo() -> None: pgm = daft.PGM() pgm.add_node("d", r"$D$", 0.5, 0.5) pgm.add_node("a", r"$a$", 1.5, 0.5, observed=True) @@ -351,7 +353,7 @@ def test_logo(): @image_comparison(baseline_images=["mrf"], extensions=["png"]) -def test_mrf(): +def test_mrf() -> None: pgm = daft.PGM(node_unit=0.4, grid_unit=1, directed=False) for i, (xi, yi) in enumerate(itertools.product(range(1, 5), range(1, 5))): @@ -381,7 +383,7 @@ def test_mrf(): @image_comparison(baseline_images=["no_circles"], extensions=["png"]) -def test_no_circles(): +def test_no_circles() -> None: pgm = daft.PGM(node_ec="none") pgm.add_node("cloudy", r"cloudy", 3, 3) pgm.add_node("rain", r"rain", 2, 2) @@ -395,7 +397,7 @@ def test_no_circles(): @image_comparison(baseline_images=["no_gray"], extensions=["png"]) -def test_no_gray(): +def test_no_gray() -> None: pgm = daft.PGM(observed_style="inner") # Hierarchical parameters. @@ -422,8 +424,8 @@ def test_no_gray(): @image_comparison(baseline_images=["recursive"], extensions=["png"]) -def test_recursive(): - def recurse(pgm, nodename, level, c): +def test_recursive() -> None: + def recurse(pgm: Any, nodename: Any, level: Any, c: Any) -> Any: if level > 4: return nodename r = c // 2 @@ -477,7 +479,7 @@ def recurse(pgm, nodename, level, c): @image_comparison(baseline_images=["thick_lines"], extensions=["png"]) -def test_thick_lines(): +def test_thick_lines() -> None: pgm = daft.PGM(line_width=2.5) # Hierarchical parameters. @@ -504,7 +506,7 @@ def test_thick_lines(): @image_comparison(baseline_images=["weaklensing"], extensions=["png"]) -def test_weaklensing(): +def test_weaklensing() -> None: pgm = daft.PGM() pgm.add_node("Omega", r"$\Omega$", -1, 4) pgm.add_node("gamma", r"$\gamma$", 0, 4) @@ -525,8 +527,8 @@ def test_weaklensing(): pgm.render() -@image_comparison(baseline_images=["wordy"], extensions=["png"]) -def test_wordy(): +@image_comparison(baseline_images=["wet_grass"], extensions=["png"]) +def test_wet_grass() -> None: pgm = daft.PGM() pgm.add_node("cloudy", r"cloudy", 3, 3, aspect=1.8) pgm.add_node("rain", r"rain", 2, 2, aspect=1.2) @@ -552,7 +554,7 @@ def test_wordy(): @image_comparison(baseline_images=["wordy"], extensions=["png"]) -def test_wordy(): +def test_wordy() -> None: pgm = daft.PGM() pgm.add_node("obs", r"$\epsilon^{obs}_n$", 2, 3, observed=True) pgm.add_node("true", r"$\epsilon^{true}_n$", 1, 3)