From d8d6313b058033e3f1a8a03f7801d7525e8ec816 Mon Sep 17 00:00:00 2001 From: David Fulford Date: Thu, 25 Jul 2024 12:22:01 -0500 Subject: [PATCH 01/15] refactor code to improve clarity, begin adding typing Signed-off-by: David Fulford --- pyproject.toml | 2 +- src/daft/__init__.py | 7 +- src/daft/_core.py | 754 ++-------------------------------------- src/daft/_exceptions.py | 5 +- src/daft/_types.py | 14 + src/daft/edge.py | 154 ++++++++ src/daft/node.py | 368 ++++++++++++++++++++ src/daft/plate.py | 202 +++++++++++ test/example.py | 40 +++ 9 files changed, 819 insertions(+), 727 deletions(-) create mode 100644 src/daft/_types.py create mode 100644 src/daft/edge.py create mode 100644 src/daft/node.py create mode 100644 src/daft/plate.py create mode 100644 test/example.py diff --git a/pyproject.toml b/pyproject.toml index b62981a..551f41e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ classifiers = [ ] dependencies = [ "matplotlib", - "numpy", + "numpy>=1.20", ] dynamic = ["version"] diff --git a/src/daft/__init__.py b/src/daft/__init__.py index 0ae9017..6061c3f 100644 --- a/src/daft/__init__.py +++ b/src/daft/__init__.py @@ -2,13 +2,18 @@ from importlib.metadata import version as get_distribution +from .Plate import Plate, Text + from . import _core, _exceptions, _utils -from ._core import PGM, Node, Edge, Plate, Text +from ._core import PGM, Node, Edge from ._exceptions import SameLocationError from ._utils import _rendering_context, _pop_multiple __version__ = get_distribution("daft") __all__ = [] __all__ += _core.__all__ +__all__ += node.__all__ +__all__ += edge.__all__ __all__ += _exceptions.__all__ __all__ += _utils.__all__ +__all__ += _types.__all__ diff --git a/src/daft/_core.py b/src/daft/_core.py index 7a6516f..da4209f 100644 --- a/src/daft/_core.py +++ b/src/daft/_core.py @@ -1,18 +1,22 @@ """Code for Daft""" -__all__ = ["PGM", "Node", "Edge", "Plate"] +__all__ = ["PGM", ] # TODO: should Text be added? -import matplotlib as mpl import matplotlib.pyplot as plt -from matplotlib.patches import Ellipse -from matplotlib.patches import FancyArrow -from matplotlib.patches import Rectangle +from matplotlib.patches import Ellipse, FancyArrow, Rectangle import numpy as np -from ._exceptions import SameLocationError -from ._utils import _rendering_context, _pop_multiple +from typing import Any +from numpy.typing import NDArray, ArrayLike + + +from .node import Node +from .edge import Edge +from .plate import Plate, Text +from ._utils import _rendering_context +from ._types import NDArray2, NDArrayF, NDArrayI # pylint: disable=too-many-arguments, protected-access, unused-argument, too-many-lines @@ -73,24 +77,24 @@ class PGM: def __init__( self, - shape=None, - origin=None, - grid_unit=2.0, - node_unit=1.0, - observed_style="shaded", - alternate_style="inner", - line_width=1.0, - node_ec="k", - node_fc="w", - plate_fc="w", - directed=True, - aspect=1.0, - label_params=None, - dpi=None, - ): - self._nodes = {} - self._edges = [] - self._plates = [] + shape: tuple[float, float] | list[float] | NDArrayF | None = None, + origin: ArrayLike | None = None, + grid_unit: float = 2.0, + node_unit: float = 1.0, + observed_style: str = "shaded", + alternate_style: str = "inner", + line_width: float = 1.0, + node_ec: str = "k", + node_fc: str = "w", + plate_fc: str = "w", + directed: bool = True, + aspect: float = 1.0, + label_params: dict[str, Any] | None = None, + dpi: int | None = None, + ) -> None: + self._nodes: dict[str, Node] = {} + self._edges: dict[str, Edge] = [] + self._plates: dict[str, Plate] = [] self._dpi = dpi # if shape and origin are not given, pass a default @@ -366,7 +370,7 @@ def add_text(self, x, y, label, fontsize=None): return None - def render(self, dpi=None): + def render(self, dpi: int | None = None) -> plt.Axes: """ Render the :class:`Plate`, :class:`Edge` and :class:`Node` objects in the model. This will create a new figure with the correct dimensions @@ -500,701 +504,3 @@ def savefig(self, fname, *args, **kwargs): if not self.figure: self.render() self.figure.savefig(fname, *args, **kwargs) - - -class Node: - """ - The representation of a random variable in a :class:`PGM`. - - :param name: - The plain-text identifier for the node. - - :param content: - The display form of the variable. - - :param x: - The x-coordinate of the node in *model units*. - - :param y: - The y-coordinate of the node. - - :param scale: (optional) - The diameter (or height) of the node measured in multiples of - ``node_unit`` as defined by the :class:`PGM` object. - - :param aspect: (optional) - The aspect ratio width/height for elliptical nodes; default 1. - - :param observed: (optional) - Should this be a conditioned variable? - - :param fixed: (optional) - Should this be a fixed (not permitted to vary) variable? - If `True`, modifies or over-rides ``diameter``, ``offset``, - ``facecolor``, and a few other ``plot_params`` settings. - This setting conflicts with ``observed``. - - :param alternate: (optional) - Should this use the alternate style? - - :param offset: (optional) - The ``(dx, dy)`` offset of the label (in points) from the default - centered position. - - :param fontsize: (optional) - The fontsize to use. - - :param plot_params: (optional) - A dictionary of parameters to pass to the - :class:`matplotlib.patches.Ellipse` constructor. - - :param label_params: (optional) - A dictionary of parameters to pass to the - :class:`matplotlib.text.Annotation` constructor. Any kwargs not - used by Annontation get passed to :class:`matplotlib.text.Text`. - - :param shape: (optional) - String in {ellipse (default), rectangle} - If rectangle, aspect and scale holds for rectangle. - - """ - - def __init__( - self, - name, - content, - x, - y, - scale=1.0, - aspect=None, - observed=False, - fixed=False, - alternate=False, - offset=(0.0, 0.0), - fontsize=None, - plot_params=None, - label_params=None, - shape="ellipse", - ): - # Check Node style. - # Iterable is consumed, so first condition checks if two or more are - # true - node_style = iter((observed, alternate, fixed)) - if not ( - (any(node_style) and not any(node_style)) - or not any((observed, alternate, fixed)) - ): - msg = "A node cannot be more than one of `observed`, `fixed`, or `alternate`." - raise ValueError(msg) - - self.observed = observed - self.fixed = fixed - self.alternate = alternate - - # Metadata. - self.name = name - self.content = content - - # Coordinates and dimensions. - self.x, self.y = float(x), float(y) - self.scale = float(scale) - if self.fixed: - self.scale /= 6.0 - if aspect is not None: - self.aspect = float(aspect) - else: - self.aspect = aspect - - # Set fontsize - self.fontsize = fontsize if fontsize else mpl.rcParams["font.size"] - - # Display parameters. - self.plot_params = dict(plot_params) if plot_params else {} - - # Text parameters. - self.offset = list(offset) - self.label_params = dict(label_params) if label_params else None - - # Shape - if shape in ["ellipse", "rectangle"]: - self.shape = shape - else: - print("Warning: wrong shape value, set to ellipse instead") - self.shape = "ellipse" - - def render(self, ctx): - """ - Render the node. - - :param ctx: - The :class:`_rendering_context` object. - - """ - # Get the axes and default plotting parameters from the rendering - # context. - ax = ctx.ax() - - # Resolve the plotting parameters. - plot_params = dict(self.plot_params) - - plot_params["lw"] = _pop_multiple( - plot_params, ctx.line_width, "lw", "linewidth" - ) - - plot_params["ec"] = plot_params["edgecolor"] = _pop_multiple( - plot_params, ctx.node_ec, "ec", "edgecolor" - ) - - fc_is_set = "fc" in plot_params or "facecolor" in plot_params - plot_params["fc"] = _pop_multiple( - plot_params, ctx.node_fc, "fc", "facecolor" - ) - fc = plot_params["fc"] - - plot_params["alpha"] = plot_params.get("alpha", 1) - - # And the label parameters. - if self.label_params is None: - label_params = dict(ctx.label_params) - else: - label_params = dict(self.label_params) - - label_params["va"] = _pop_multiple( - label_params, "center", "va", "verticalalignment" - ) - - label_params["ha"] = _pop_multiple( - label_params, "center", "ha", "horizontalalignment" - ) - - # Deal with ``fixed`` nodes. - scale = self.scale - if self.fixed: - # MAGIC: These magic numbers should depend on the grid/node units. - self.offset[1] += 6 - - label_params["va"] = "baseline" - label_params.pop("verticalalignment", None) - label_params.pop("ma", None) - - if not fc_is_set: - plot_params["fc"] = "k" - - diameter = ctx.node_unit * scale - if self.aspect is not None: - aspect = self.aspect - else: - aspect = ctx.aspect - - # Set up an observed node or alternate node. Note the fc INSANITY. - if self.observed and not self.fixed: - style = ctx.observed_style - elif self.alternate and not self.fixed: - style = ctx.alternate_style - else: - style = False - - if style: - # Update the plotting parameters depending on the style of - # observed node. - h = float(diameter) - w = aspect * float(diameter) - if style == "shaded": - plot_params["fc"] = "0.7" - elif style == "outer": - h = diameter + 0.1 * diameter - w = aspect * diameter + 0.1 * diameter - elif style == "inner": - h = diameter - 0.1 * diameter - w = aspect * diameter - 0.1 * diameter - plot_params["fc"] = fc - - # Draw the background ellipse. - if self.shape == "ellipse": - bg = Ellipse( - xy=ctx.convert(self.x, self.y), - width=w, - height=h, - **plot_params, - ) - elif self.shape == "rectangle": - # Adapt to make Rectangle the same api than ellipse - wi = w - xy = ctx.convert(self.x, self.y) - xy[0] = xy[0] - wi / 2.0 - xy[1] = xy[1] - h / 2.0 - - bg = Rectangle(xy=xy, width=wi, height=h, **plot_params) - else: - # Should never append - raise ( - ValueError( - "Wrong shape in object causes an error in render" - ) - ) - - ax.add_artist(bg) - - # Reset the face color. - plot_params["fc"] = fc - - # Draw the foreground ellipse. - if not fc_is_set and not self.fixed and self.observed: - plot_params["fc"] = "none" - - if self.shape == "ellipse": - el = Ellipse( - xy=ctx.convert(self.x, self.y), - width=diameter * aspect, - height=diameter, - **plot_params, - ) - elif self.shape == "rectangle": - # Adapt to make Rectangle the same api than ellipse - wi = diameter * aspect - xy = ctx.convert(self.x, self.y) - xy[0] = xy[0] - wi / 2.0 - xy[1] = xy[1] - diameter / 2.0 - - el = Rectangle(xy=xy, width=wi, height=diameter, **plot_params) - else: - # Should never append - raise ( - ValueError("Wrong shape in object causes an error in render") - ) - - ax.add_artist(el) - - # Reset the face color. - plot_params["fc"] = fc - - # Annotate the node. - ax.annotate( - self.content, - ctx.convert(self.x, self.y), - xycoords="data", - xytext=self.offset, - textcoords="offset points", - size=self.fontsize, - **label_params, - ) - - return el - - def get_frontier_coord(self, target_xy, ctx, edge): - """ - Get the coordinates of the point of intersection between the - shape of the node and a line starting from the center of the node to an - arbitrary point. Will throw a :class:`SameLocationError` if the nodes - contain the same `x` and `y` coordinates. See the example of rectangle - below: - - .. code-block:: python - - _____________ - | | ____--X (target_node) - | __--X---- - | X-- |(return coordinate of this point) - | | - |____________| - - :target_xy: (x float, y float) - A tuple of coordinate of target node - - """ - - # Scale the coordinates appropriately. - x1, y1 = ctx.convert(self.x, self.y) - x2, y2 = target_xy[0], target_xy[1] - - # Aspect ratios. - if self.aspect is not None: - aspect = self.aspect - else: - aspect = ctx.aspect - - if self.shape == "ellipse": - # Compute the distances. - dx, dy = x2 - x1, y2 - y1 - if dx == 0.0 and dy == 0.0: - raise SameLocationError(edge) - dist1 = np.sqrt(dy * dy + dx * dx / float(aspect * aspect)) - - # Compute the fractional effect of the radii of the nodes. - alpha1 = 0.5 * ctx.node_unit * self.scale / dist1 - - # Get the coordinates of the starting position. - x0, y0 = x1 + alpha1 * dx, y1 + alpha1 * dy - - return x0, y0 - - elif self.shape == "rectangle": - dx, dy = x2 - x1, y2 - y1 - - # theta = np.angle(complex(dx, dy)) - # print(theta) - # left or right intersection - dxx1 = self.scale * aspect / 2.0 * (np.sign(dx) or 1.0) - dyy1 = ( - self.scale - * aspect - / 2.0 - * np.abs(dy / dx) - * (np.sign(dy) or 1.0) - ) - val1 = np.abs(complex(dxx1, dyy1)) - - # up or bottom intersection - dxx2 = self.scale * 0.5 * np.abs(dx / dy) * (np.sign(dx) or 1.0) - dyy2 = self.scale * 0.5 * (np.sign(dy) or 1.0) - val2 = np.abs(complex(dxx2, dyy2)) - - if val1 < val2: - return x1 + dxx1, y1 + dyy1 - else: - return x1 + dxx2, y1 + dyy2 - - else: - # Should never append - raise ValueError("Wrong shape in object causes an error") - - -class Edge: - """ - An edge between two :class:`Node` objects. - - :param node1: - The first :class:`Node`. - - :param node2: - The second :class:`Node`. The arrow will point towards this node. - - :param directed: (optional) - Should the edge be directed from ``node1`` to ``node2``? In other - words: should it have an arrow? - - :param label: (optional) - A string to annotate the edge. - - :param xoffset: (optional) - The x-offset from the middle of the arrow to plot the label. - Only takes effect if `label` is defined in `plot_params`. - - :param yoffset: (optional) - The y-offset from the middle of the arrow to plot the label. - Only takes effect if `label` is defined in `plot_params`. - - :param plot_params: (optional) - A dictionary of parameters to pass to the - :class:`matplotlib.patches.FancyArrow` constructor to adjust - edge behavior. - - :param label_params: (optional) - A dictionary of parameters to pass to the - :class:`matplotlib.axes.Axes.annotate` constructor to adjust - label behavior. - - """ - - def __init__( - self, - node1, - node2, - directed=True, - label=None, - xoffset=0, - yoffset=0.1, - plot_params=None, - label_params=None, - ): - self.node1 = node1 - self.node2 = node2 - self.directed = directed - self.label = label - self.xoffset = xoffset - self.yoffset = yoffset - self.plot_params = dict(plot_params) if plot_params else {} - self.label_params = dict(label_params) if label_params else {} - - def _get_coords(self, ctx): - """ - Get the coordinates of the line. - - :param conv: - A callable coordinate conversion. - - :returns: - * ``x0``, ``y0``: the coordinates of the start of the line. - * ``dx0``, ``dy0``: the displacement vector. - - """ - # Scale the coordinates appropriately. - x1, y1 = ctx.convert(self.node1.x, self.node1.y) - x2, y2 = ctx.convert(self.node2.x, self.node2.y) - - x3, y3 = self.node1.get_frontier_coord((x2, y2), ctx, self) - x4, y4 = self.node2.get_frontier_coord((x1, y1), ctx, self) - - return x3, y3, x4 - x3, y4 - y3 - - def render(self, ctx): - """ - Render the edge in the given axes. - - :param ctx: - The :class:`_rendering_context` object. - - """ - ax = ctx.ax() - - plot_params = self.plot_params - plot_params["linewidth"] = _pop_multiple( - plot_params, ctx.line_width, "lw", "linewidth" - ) - - plot_params["linestyle"] = plot_params.get("linestyle", "-") - - # Add edge annotation. - if self.label is not None: - x, y, dx, dy = self._get_coords(ctx) - ax.annotate( - self.label, - [x + 0.5 * dx + self.xoffset, y + 0.5 * dy + self.yoffset], - xycoords="data", - xytext=[0, 3], - textcoords="offset points", - ha="center", - va="center", - **self.label_params, - ) - - if self.directed: - plot_params["ec"] = _pop_multiple( - plot_params, "k", "ec", "edgecolor" - ) - plot_params["fc"] = _pop_multiple( - plot_params, "k", "fc", "facecolor" - ) - plot_params["head_length"] = plot_params.get("head_length", 0.25) - plot_params["head_width"] = plot_params.get("head_width", 0.1) - - # Build an arrow. - args = self._get_coords(ctx) - - # zero lengh arrow produce error - if not (args[2] == 0.0 and args[3] == 0.0): - ar = FancyArrow( - *self._get_coords(ctx), - width=0, - length_includes_head=True, - **plot_params, - ) - - # Add the arrow to the axes. - ax.add_artist(ar) - return ar - - else: - print(args[2], args[3]) - - else: - plot_params["color"] = plot_params.get("color", "k") - - # Get the right coordinates. - x, y, dx, dy = self._get_coords(ctx) - - # Plot the line. - line = ax.plot([x, x + dx], [y, y + dy], **plot_params) - return line - - -class Plate: - """ - A plate to encapsulate repeated independent processes in the model. - - :param rect: - The rectangle describing the plate bounds in model coordinates. - This is [x-start, y-start, x-length, y-length]. - - :param label: (optional) - A string to annotate the plate. - - :param label_offset: (optional) - The x- and y- offsets of the label text measured in points. - - :param shift: (optional) - The vertical "shift" of the plate measured in model units. This will - move the bottom of the panel by ``shift`` units. - - :param position: (optional) - One of ``"{vertical} {horizontal}"`` where vertical is ``"bottom"`` - or ``"middle"`` or ``"top"`` and horizontal is ``"left"`` or - ``"center"`` or ``"right"``. - - :param fontsize: (optional) - The fontsize to use. - - :param rect_params: (optional) - A dictionary of parameters to pass to the - :class:`matplotlib.patches.Rectangle` constructor, which defines - the properties of the plate. - - :param bbox: (optional) - A dictionary of parameters to pass to the - :class:`matplotlib.axes.Axes.annotate` constructor, which defines - the box drawn around the text. - - """ - - def __init__( - self, - rect, - label=None, - label_offset=(5, 5), - shift=0, - position="bottom left", - fontsize=None, - rect_params=None, - bbox=None, - ): - self.rect = rect - self.label = label - self.label_offset = label_offset - self.shift = shift - - if fontsize is not None: - self.fontsize = fontsize - else: - self.fontsize = mpl.rcParams["font.size"] - - if rect_params is not None: - self.rect_params = dict(rect_params) - else: - self.rect_params = None - - if bbox is not None: - self.bbox = dict(bbox) - - # Set the awful default blue color to transparent - if "fc" not in self.bbox.keys(): - self.bbox["fc"] = "none" - else: - self.bbox = None - - self.position = position - - def render(self, ctx): - """ - Render the plate in the given axes. - - :param ctx: - The :class:`_rendering_context` object. - - """ - ax = ctx.ax() - - shift = np.array([0, self.shift], dtype=np.float64) - rect = np.atleast_1d(self.rect) - bottom_left = ctx.convert(*(rect[:2] + shift)) - top_right = ctx.convert(*(rect[:2] + rect[2:])) - rect = np.concatenate([bottom_left, top_right - bottom_left]) - - if self.rect_params is not None: - rect_params = self.rect_params - else: - rect_params = {} - - rect_params["ec"] = _pop_multiple(rect_params, "k", "ec", "edgecolor") - rect_params["fc"] = _pop_multiple( - rect_params, ctx.plate_fc, "fc", "facecolor" - ) - rect_params["lw"] = _pop_multiple( - rect_params, ctx.line_width, "lw", "linewidth" - ) - rectangle = Rectangle(rect[:2], *rect[2:], **rect_params) - - ax.add_artist(rectangle) - - if self.label is not None: - offset = np.array(self.label_offset, dtype=np.float64) - if "left" in self.position: - position = rect[:2] - ha = "left" - elif "right" in self.position: - position = rect[:2] - position[0] += rect[2] - ha = "right" - offset[0] = -offset[0] - elif "center" in self.position: - position = rect[:2] - position[0] = rect[2] / 2 + rect[0] - ha = "center" - else: - raise RuntimeError( - f"Unknown positioning string: {self.position}" - ) - - if "bottom" in self.position: - va = "bottom" - elif "top" in self.position: - position[1] = rect[1] + rect[3] - offset[1] = -offset[1] - 0.1 - va = "top" - elif "middle" in self.position: - position[1] += rect[3] / 2 - va = "center" - else: - raise RuntimeError( - f"Unknown positioning string: {self.position}" - ) - - ax.annotate( - self.label, - xy=position, - xycoords="data", - xytext=offset, - textcoords="offset points", - size=self.fontsize, - bbox=self.bbox, - horizontalalignment=ha, - verticalalignment=va, - ) - - return rectangle - - -class Text(Plate): - """ - A subclass of plate to writing text using grid coordinates. Any **kwargs - are passed through to :class:`PGM.Plate`. - - :param x: - The x-coordinate of the text in *model units*. - - :param y: - The y-coordinate of the text. - - :param label: - A string to write. - - :param fontsize: (optional) - The fontsize to use. - - """ - - def __init__(self, x, y, label, fontsize=None): - self.rect = [x, y, 0.0, 0.0] - self.label = label - self.fontsize = fontsize - self.label_offset = [0.0, 0.0] - self.bbox = {"fc": "none", "ec": "none"} - self.rect_params = {"ec": "none"} - - super().__init__( - rect=self.rect, - label=self.label, - label_offset=self.label_offset, - fontsize=self.fontsize, - rect_params=self.rect_params, - bbox=self.bbox, - ) diff --git a/src/daft/_exceptions.py b/src/daft/_exceptions.py index 605b7f4..dc40c8a 100644 --- a/src/daft/_exceptions.py +++ b/src/daft/_exceptions.py @@ -11,9 +11,12 @@ class SameLocationError(Exception): The Edge object whose nodes are being added. """ - def __init__(self, edge): + def __init__(self, edge: 'Edge') -> None: self.message = ( "Attempted to add edge between `{}` and `{}` but they " + "share the same location." ).format(edge.node1.name, edge.node2.name) super().__init__(self.message) + + +from ._core import Edge diff --git a/src/daft/_types.py b/src/daft/_types.py new file mode 100644 index 0000000..2ff1d20 --- /dev/null +++ b/src/daft/_types.py @@ -0,0 +1,14 @@ +"""Daft types""" + +__all__: list[str] = [] + +import numpy as np +from numpy.typing import NDArray, ArrayLike +from typing import Any, Annotated, Literal, TypeVar + +DType = TypeVar("DType", bound=np.generic) + +NDArray2 = Annotated[NDArray[DType], Literal[2]] + +NDArrayF = NDArray[np.float64] +NDArrayI = NDArray[np.int64] diff --git a/src/daft/edge.py b/src/daft/edge.py new file mode 100644 index 0000000..5955375 --- /dev/null +++ b/src/daft/edge.py @@ -0,0 +1,154 @@ +"""Edge""" + +__all__ = ["Edge"] + +from ._utils import _pop_multiple + + +class Edge: + """ + An edge between two :class:`Node` objects. + + :param node1: + The first :class:`Node`. + + :param node2: + The second :class:`Node`. The arrow will point towards this node. + + :param directed: (optional) + Should the edge be directed from ``node1`` to ``node2``? In other + words: should it have an arrow? + + :param label: (optional) + A string to annotate the edge. + + :param xoffset: (optional) + The x-offset from the middle of the arrow to plot the label. + Only takes effect if `label` is defined in `plot_params`. + + :param yoffset: (optional) + The y-offset from the middle of the arrow to plot the label. + Only takes effect if `label` is defined in `plot_params`. + + :param plot_params: (optional) + A dictionary of parameters to pass to the + :class:`matplotlib.patches.FancyArrow` constructor to adjust + edge behavior. + + :param label_params: (optional) + A dictionary of parameters to pass to the + :class:`matplotlib.axes.Axes.annotate` constructor to adjust + label behavior. + + """ + + def __init__( + self, + node1, + node2, + directed=True, + label=None, + xoffset=0, + yoffset=0.1, + plot_params=None, + label_params=None, + ): + self.node1 = node1 + self.node2 = node2 + self.directed = directed + self.label = label + self.xoffset = xoffset + self.yoffset = yoffset + self.plot_params = dict(plot_params) if plot_params else {} + self.label_params = dict(label_params) if label_params else {} + + def _get_coords(self, ctx): + """ + Get the coordinates of the line. + + :param conv: + A callable coordinate conversion. + + :returns: + * ``x0``, ``y0``: the coordinates of the start of the line. + * ``dx0``, ``dy0``: the displacement vector. + + """ + # Scale the coordinates appropriately. + x1, y1 = ctx.convert(self.node1.x, self.node1.y) + x2, y2 = ctx.convert(self.node2.x, self.node2.y) + + x3, y3 = self.node1.get_frontier_coord((x2, y2), ctx, self) + x4, y4 = self.node2.get_frontier_coord((x1, y1), ctx, self) + + return x3, y3, x4 - x3, y4 - y3 + + def render(self, ctx): + """ + Render the edge in the given axes. + + :param ctx: + The :class:`_rendering_context` object. + + """ + ax = ctx.ax() + + plot_params = self.plot_params + plot_params["linewidth"] = _pop_multiple( + plot_params, ctx.line_width, "lw", "linewidth" + ) + + plot_params["linestyle"] = plot_params.get("linestyle", "-") + + # Add edge annotation. + if self.label is not None: + x, y, dx, dy = self._get_coords(ctx) + ax.annotate( + self.label, + [x + 0.5 * dx + self.xoffset, y + 0.5 * dy + self.yoffset], + xycoords="data", + xytext=[0, 3], + textcoords="offset points", + ha="center", + va="center", + **self.label_params, + ) + + if self.directed: + plot_params["ec"] = _pop_multiple( + plot_params, "k", "ec", "edgecolor" + ) + plot_params["fc"] = _pop_multiple( + plot_params, "k", "fc", "facecolor" + ) + plot_params["head_length"] = plot_params.get("head_length", 0.25) + plot_params["head_width"] = plot_params.get("head_width", 0.1) + + # Build an arrow. + args = self._get_coords(ctx) + + # zero lengh arrow produce error + if not (args[2] == 0.0 and args[3] == 0.0): + ar = FancyArrow( + *self._get_coords(ctx), + width=0, + length_includes_head=True, + **plot_params, + ) + + # Add the arrow to the axes. + ax.add_artist(ar) + return ar + + else: + print(args[2], args[3]) + + else: + plot_params["color"] = plot_params.get("color", "k") + + # Get the right coordinates. + x, y, dx, dy = self._get_coords(ctx) + + # Plot the line. + line = ax.plot([x, x + dx], [y, y + dy], **plot_params) + return line diff --git a/src/daft/node.py b/src/daft/node.py new file mode 100644 index 0000000..d30d9ff --- /dev/null +++ b/src/daft/node.py @@ -0,0 +1,368 @@ +"""Node""" + +__all__ = ["Node"] + +import matplotlib as mpl +from matplotlib.patches import Ellipse, Rectangle + +import numpy as np + +from ._utils import _rendering_context, _pop_multiple +from ._exceptions import SameLocationError + + +class Node: + """ + The representation of a random variable in a :class:`PGM`. + + :param name: + The plain-text identifier for the node. + + :param content: + The display form of the variable. + + :param x: + The x-coordinate of the node in *model units*. + + :param y: + The y-coordinate of the node. + + :param scale: (optional) + The diameter (or height) of the node measured in multiples of + ``node_unit`` as defined by the :class:`PGM` object. + + :param aspect: (optional) + The aspect ratio width/height for elliptical nodes; default 1. + + :param observed: (optional) + Should this be a conditioned variable? + + :param fixed: (optional) + Should this be a fixed (not permitted to vary) variable? + If `True`, modifies or over-rides ``diameter``, ``offset``, + ``facecolor``, and a few other ``plot_params`` settings. + This setting conflicts with ``observed``. + + :param alternate: (optional) + Should this use the alternate style? + + :param offset: (optional) + The ``(dx, dy)`` offset of the label (in points) from the default + centered position. + + :param fontsize: (optional) + The fontsize to use. + + :param plot_params: (optional) + A dictionary of parameters to pass to the + :class:`matplotlib.patches.Ellipse` constructor. + + :param label_params: (optional) + A dictionary of parameters to pass to the + :class:`matplotlib.text.Annotation` constructor. Any kwargs not + used by Annontation get passed to :class:`matplotlib.text.Text`. + + :param shape: (optional) + String in {ellipse (default), rectangle} + If rectangle, aspect and scale holds for rectangle. + + """ + + def __init__( + self, + name, + content, + x, + y, + scale=1.0, + aspect=None, + observed=False, + fixed=False, + alternate=False, + offset=(0.0, 0.0), + fontsize=None, + plot_params=None, + label_params=None, + shape="ellipse", + ): + # Check Node style. + # Iterable is consumed, so first condition checks if two or more are + # true + node_style = iter((observed, alternate, fixed)) + if not ( + (any(node_style) and not any(node_style)) + or not any((observed, alternate, fixed)) + ): + msg = "A node cannot be more than one of `observed`, `fixed`, or `alternate`." + raise ValueError(msg) + + self.observed = observed + self.fixed = fixed + self.alternate = alternate + + # Metadata. + self.name = name + self.content = content + + # Coordinates and dimensions. + self.x, self.y = float(x), float(y) + self.scale = float(scale) + if self.fixed: + self.scale /= 6.0 + if aspect is not None: + self.aspect = float(aspect) + else: + self.aspect = aspect + + # Set fontsize + self.fontsize = fontsize if fontsize else mpl.rcParams["font.size"] + + # Display parameters. + self.plot_params = dict(plot_params) if plot_params else {} + + # Text parameters. + self.offset = list(offset) + self.label_params = dict(label_params) if label_params else None + + # Shape + if shape in ["ellipse", "rectangle"]: + self.shape = shape + else: + print("Warning: wrong shape value, set to ellipse instead") + self.shape = "ellipse" + + def render(self, ctx) -> Ellipse | Rectangle: + """ + Render the node. + + :param ctx: + The :class:`_rendering_context` object. + + """ + # Get the axes and default plotting parameters from the rendering + # context. + ax = ctx.ax() + + # Resolve the plotting parameters. + plot_params = dict(self.plot_params) + + plot_params["lw"] = _pop_multiple( + plot_params, ctx.line_width, "lw", "linewidth" + ) + + plot_params["ec"] = plot_params["edgecolor"] = _pop_multiple( + plot_params, ctx.node_ec, "ec", "edgecolor" + ) + + fc_is_set = "fc" in plot_params or "facecolor" in plot_params + plot_params["fc"] = _pop_multiple( + plot_params, ctx.node_fc, "fc", "facecolor" + ) + fc = plot_params["fc"] + + plot_params["alpha"] = plot_params.get("alpha", 1) + + # And the label parameters. + if self.label_params is None: + label_params = dict(ctx.label_params) + else: + label_params = dict(self.label_params) + + label_params["va"] = _pop_multiple( + label_params, "center", "va", "verticalalignment" + ) + + label_params["ha"] = _pop_multiple( + label_params, "center", "ha", "horizontalalignment" + ) + + # Deal with ``fixed`` nodes. + scale = self.scale + if self.fixed: + # MAGIC: These magic numbers should depend on the grid/node units. + self.offset[1] += 6 + + label_params["va"] = "baseline" + label_params.pop("verticalalignment", None) + label_params.pop("ma", None) + + if not fc_is_set: + plot_params["fc"] = "k" + + diameter = ctx.node_unit * scale + if self.aspect is not None: + aspect = self.aspect + else: + aspect = ctx.aspect + + # Set up an observed node or alternate node. Note the fc INSANITY. + if self.observed and not self.fixed: + style = ctx.observed_style + elif self.alternate and not self.fixed: + style = ctx.alternate_style + else: + style = False + + if style: + # Update the plotting parameters depending on the style of + # observed node. + h = float(diameter) + w = aspect * float(diameter) + if style == "shaded": + plot_params["fc"] = "0.7" + elif style == "outer": + h = diameter + 0.1 * diameter + w = aspect * diameter + 0.1 * diameter + elif style == "inner": + h = diameter - 0.1 * diameter + w = aspect * diameter - 0.1 * diameter + plot_params["fc"] = fc + + # Draw the background ellipse. + if self.shape == "ellipse": + bg = Ellipse( + xy=ctx.convert(self.x, self.y), + width=w, + height=h, + **plot_params, + ) + elif self.shape == "rectangle": + # Adapt to make Rectangle the same api than ellipse + wi = w + xy = ctx.convert(self.x, self.y) + xy[0] = xy[0] - wi / 2.0 + xy[1] = xy[1] - h / 2.0 + + bg = Rectangle(xy=xy, width=wi, height=h, **plot_params) + else: + # Should never append + raise ( + ValueError( + "Wrong shape in object causes an error in render" + ) + ) + + ax.add_artist(bg) + + # Reset the face color. + plot_params["fc"] = fc + + # Draw the foreground ellipse. + if not fc_is_set and not self.fixed and self.observed: + plot_params["fc"] = "none" + + if self.shape == "ellipse": + el = Ellipse( + xy=ctx.convert(self.x, self.y), + width=diameter * aspect, + height=diameter, + **plot_params, + ) + elif self.shape == "rectangle": + # Adapt to make Rectangle the same api than ellipse + wi = diameter * aspect + xy = ctx.convert(self.x, self.y) + xy[0] = xy[0] - wi / 2.0 + xy[1] = xy[1] - diameter / 2.0 + + el = Rectangle(xy=xy, width=wi, height=diameter, **plot_params) + else: + # Should never append + raise ( + ValueError("Wrong shape in object causes an error in render") + ) + + ax.add_artist(el) + + # Reset the face color. + plot_params["fc"] = fc + + # Annotate the node. + ax.annotate( + self.content, + ctx.convert(self.x, self.y), + xycoords="data", + xytext=self.offset, + textcoords="offset points", + size=self.fontsize, + **label_params, + ) + + return el + + def get_frontier_coord(self, target_xy, ctx, edge): + """ + Get the coordinates of the point of intersection between the + shape of the node and a line starting from the center of the node to an + arbitrary point. Will throw a :class:`SameLocationError` if the nodes + contain the same `x` and `y` coordinates. See the example of rectangle + below: + + .. code-block:: python + + _____________ + | | ____--X (target_node) + | __--X---- + | X-- |(return coordinate of this point) + | | + |____________| + + :target_xy: (x float, y float) + A tuple of coordinate of target node + + """ + + # Scale the coordinates appropriately. + x1, y1 = ctx.convert(self.x, self.y) + x2, y2 = target_xy[0], target_xy[1] + + # Aspect ratios. + if self.aspect is not None: + aspect = self.aspect + else: + aspect = ctx.aspect + + if self.shape == "ellipse": + # Compute the distances. + dx, dy = x2 - x1, y2 - y1 + if dx == 0.0 and dy == 0.0: + raise SameLocationError(edge) + dist1 = np.sqrt(dy * dy + dx * dx / float(aspect * aspect)) + + # Compute the fractional effect of the radii of the nodes. + alpha1 = 0.5 * ctx.node_unit * self.scale / dist1 + + # Get the coordinates of the starting position. + x0, y0 = x1 + alpha1 * dx, y1 + alpha1 * dy + + return x0, y0 + + elif self.shape == "rectangle": + dx, dy = x2 - x1, y2 - y1 + + # theta = np.angle(complex(dx, dy)) + # print(theta) + # left or right intersection + dxx1 = self.scale * aspect / 2.0 * (np.sign(dx) or 1.0) + dyy1 = ( + self.scale + * aspect + / 2.0 + * np.abs(dy / dx) + * (np.sign(dy) or 1.0) + ) + val1 = np.abs(complex(dxx1, dyy1)) + + # up or bottom intersection + dxx2 = self.scale * 0.5 * np.abs(dx / dy) * (np.sign(dx) or 1.0) + dyy2 = self.scale * 0.5 * (np.sign(dy) or 1.0) + val2 = np.abs(complex(dxx2, dyy2)) + + if val1 < val2: + return x1 + dxx1, y1 + dyy1 + else: + return x1 + dxx2, y1 + dyy2 + + else: + # Should never append + raise ValueError("Wrong shape in object causes an error") diff --git a/src/daft/plate.py b/src/daft/plate.py new file mode 100644 index 0000000..b76bab3 --- /dev/null +++ b/src/daft/plate.py @@ -0,0 +1,202 @@ +import matplotlib as mpl +from matplotlib.patches import Rectangle + +import numpy as np + +from ._utils import _pop_multiple + +# Move exception import to end of file to resolve circular dependency +from ._exceptions import SameLocationError + + + +class Plate: + """ + A plate to encapsulate repeated independent processes in the model. + + :param rect: + The rectangle describing the plate bounds in model coordinates. + This is [x-start, y-start, x-length, y-length]. + + :param label: (optional) + A string to annotate the plate. + + :param label_offset: (optional) + The x- and y- offsets of the label text measured in points. + + :param shift: (optional) + The vertical "shift" of the plate measured in model units. This will + move the bottom of the panel by ``shift`` units. + + :param position: (optional) + One of ``"{vertical} {horizontal}"`` where vertical is ``"bottom"`` + or ``"middle"`` or ``"top"`` and horizontal is ``"left"`` or + ``"center"`` or ``"right"``. + + :param fontsize: (optional) + The fontsize to use. + + :param rect_params: (optional) + A dictionary of parameters to pass to the + :class:`matplotlib.patches.Rectangle` constructor, which defines + the properties of the plate. + + :param bbox: (optional) + A dictionary of parameters to pass to the + :class:`matplotlib.axes.Axes.annotate` constructor, which defines + the box drawn around the text. + + """ + + def __init__( + self, + rect, + label=None, + label_offset=(5, 5), + shift=0, + position="bottom left", + fontsize=None, + rect_params=None, + bbox=None, + ): + self.rect = rect + self.label = label + self.label_offset = label_offset + self.shift = shift + + if fontsize is not None: + self.fontsize = fontsize + else: + self.fontsize = mpl.rcParams["font.size"] + + if rect_params is not None: + self.rect_params = dict(rect_params) + else: + self.rect_params = None + + if bbox is not None: + self.bbox = dict(bbox) + + # Set the awful default blue color to transparent + if "fc" not in self.bbox.keys(): + self.bbox["fc"] = "none" + else: + self.bbox = None + + self.position = position + + def render(self, ctx): + """ + Render the plate in the given axes. + + :param ctx: + The :class:`_rendering_context` object. + + """ + ax = ctx.ax() + + shift = np.array([0, self.shift], dtype=np.float64) + rect = np.atleast_1d(self.rect) + bottom_left = ctx.convert(*(rect[:2] + shift)) + top_right = ctx.convert(*(rect[:2] + rect[2:])) + rect = np.concatenate([bottom_left, top_right - bottom_left]) + + if self.rect_params is not None: + rect_params = self.rect_params + else: + rect_params = {} + + rect_params["ec"] = _pop_multiple(rect_params, "k", "ec", "edgecolor") + rect_params["fc"] = _pop_multiple( + rect_params, ctx.plate_fc, "fc", "facecolor" + ) + rect_params["lw"] = _pop_multiple( + rect_params, ctx.line_width, "lw", "linewidth" + ) + rectangle = Rectangle(rect[:2], *rect[2:], **rect_params) + + ax.add_artist(rectangle) + + if self.label is not None: + offset = np.array(self.label_offset, dtype=np.float64) + if "left" in self.position: + position = rect[:2] + ha = "left" + elif "right" in self.position: + position = rect[:2] + position[0] += rect[2] + ha = "right" + offset[0] = -offset[0] + elif "center" in self.position: + position = rect[:2] + position[0] = rect[2] / 2 + rect[0] + ha = "center" + else: + raise RuntimeError( + f"Unknown positioning string: {self.position}" + ) + + if "bottom" in self.position: + va = "bottom" + elif "top" in self.position: + position[1] = rect[1] + rect[3] + offset[1] = -offset[1] - 0.1 + va = "top" + elif "middle" in self.position: + position[1] += rect[3] / 2 + va = "center" + else: + raise RuntimeError( + f"Unknown positioning string: {self.position}" + ) + + ax.annotate( + self.label, + xy=position, + xycoords="data", + xytext=offset, + textcoords="offset points", + size=self.fontsize, + bbox=self.bbox, + horizontalalignment=ha, + verticalalignment=va, + ) + + return rectangle + + +class Text(Plate): + """ + A subclass of plate to writing text using grid coordinates. Any **kwargs + are passed through to :class:`PGM.Plate`. + + :param x: + The x-coordinate of the text in *model units*. + + :param y: + The y-coordinate of the text. + + :param label: + A string to write. + + :param fontsize: (optional) + The fontsize to use. + + """ + + def __init__(self, x, y, label, fontsize=None): + self.rect = [x, y, 0.0, 0.0] + self.label = label + self.fontsize = fontsize + self.label_offset = [0.0, 0.0] + self.bbox = {"fc": "none", "ec": "none"} + self.rect_params = {"ec": "none"} + + super().__init__( + rect=self.rect, + label=self.label, + label_offset=self.label_offset, + fontsize=self.fontsize, + rect_params=self.rect_params, + bbox=self.bbox, + ) diff --git a/test/example.py b/test/example.py new file mode 100644 index 0000000..703fdda --- /dev/null +++ b/test/example.py @@ -0,0 +1,40 @@ +import daft +import matplotlib as mpl +import matplotlib.pyplot as plt + +# Colors. +no_circle = {"ec": "#fff"} +pgm = daft.PGM(grid_unit=1.5, node_unit=1, dpi=200) + +# x_offset, y_offset = -0.5, 0.5 # adjust the coordinates of the nodes +x_offset, y_offset = 0, 0 +pgm.add_node("z1", r"$z_1$", 1 + x_offset, 1 + y_offset) +pgm.add_node("z2", r"$z_2$", 2 + x_offset, 1 + y_offset) +pgm.add_node("z...", r"$\cdots$", 3 + x_offset, 1 + y_offset, plot_params=no_circle) +pgm.add_node("zT", r"$z_T$", 4 + x_offset, 1 + y_offset) +pgm.add_node("y1", r"$x_1$", 1 + x_offset, 2.3 + y_offset) +pgm.add_node("y2", r"$x_2$", 2 + x_offset, 2.3 + y_offset) +pgm.add_node("y...", r"$\cdots$", 3 + x_offset, 2.3 + y_offset, plot_params=no_circle) +pgm.add_node("yT", r"$x_T$", 4 + x_offset, 2.3 + y_offset) +pgm.add_node("x1", r"$x_1$", 1 + x_offset, y_offset) +pgm.add_node("x2", r"$x_2$", 2 + x_offset, y_offset) +pgm.add_node("x...", r"$\cdots$", 3 + x_offset, y_offset, plot_params=no_circle) +pgm.add_node("xT", r"$x_T$", 4 + x_offset, y_offset) + +# Edges. +# pgm.add_edge("z1", "z2") +# pgm.add_edge("z2", "z...") +# pgm.add_edge("z...", "zT") +pgm.add_edge("x1", "z1") +pgm.add_edge("x2", "z2") +pgm.add_edge("xT", "zT") +pgm.add_edge("z1", "y2") +pgm.add_edge("z1", "yT") +pgm.add_edge("z2", "yT") + +# Render and save. +# pgm.render() +# plt.tight_layout() +# plt.show() +pgm.show() +# pgm.figure.savefig("../src/nar.png") From 6402ea9dcdd3b27d47306bbca72655636cc4d7ff Mon Sep 17 00:00:00 2001 From: David Fulford Date: Thu, 25 Jul 2024 15:17:00 -0500 Subject: [PATCH 02/15] add typing, some refactoring Signed-off-by: David Fulford --- src/daft/__init__.py | 14 +-- src/daft/_exceptions.py | 2 +- src/daft/_types.py | 78 +++++++++++++++- src/daft/_utils.py | 46 ++++----- src/daft/edge.py | 49 ++++++---- src/daft/node.py | 112 ++++++++++++---------- src/daft/{_core.py => pgm.py} | 171 ++++++++++++++++++---------------- src/daft/plate.py | 73 +++++++++------ 8 files changed, 341 insertions(+), 204 deletions(-) rename src/daft/{_core.py => pgm.py} (77%) diff --git a/src/daft/__init__.py b/src/daft/__init__.py index 6061c3f..c8af2a4 100644 --- a/src/daft/__init__.py +++ b/src/daft/__init__.py @@ -2,18 +2,20 @@ from importlib.metadata import version as get_distribution -from .Plate import Plate, Text - -from . import _core, _exceptions, _utils -from ._core import PGM, Node, Edge +from . import node, edge, pgm, plate, _exceptions, _utils, _types +from .pgm import PGM +from .node import Node +from .edge import Edge +from .plate import Plate, Text from ._exceptions import SameLocationError -from ._utils import _rendering_context, _pop_multiple +from ._utils import _RenderingContext, _pop_multiple __version__ = get_distribution("daft") __all__ = [] -__all__ += _core.__all__ +__all__ += pgm.__all__ __all__ += node.__all__ __all__ += edge.__all__ +__all__ += plate.__all__ __all__ += _exceptions.__all__ __all__ += _utils.__all__ __all__ += _types.__all__ diff --git a/src/daft/_exceptions.py b/src/daft/_exceptions.py index dc40c8a..4de4b98 100644 --- a/src/daft/_exceptions.py +++ b/src/daft/_exceptions.py @@ -19,4 +19,4 @@ def __init__(self, edge: 'Edge') -> None: super().__init__(self.message) -from ._core import Edge +from .edge import Edge diff --git a/src/daft/_types.py b/src/daft/_types.py index 2ff1d20..9a2d926 100644 --- a/src/daft/_types.py +++ b/src/daft/_types.py @@ -4,11 +4,81 @@ import numpy as np from numpy.typing import NDArray, ArrayLike -from typing import Any, Annotated, Literal, TypeVar +from typing import Any, Annotated, Literal, TypeVar, TypedDict -DType = TypeVar("DType", bound=np.generic) - -NDArray2 = Annotated[NDArray[DType], Literal[2]] +T = TypeVar('T') NDArrayF = NDArray[np.float64] NDArrayI = NDArray[np.int64] + +# Tuple2 = tuple[T, T] | list[T] +Tuple2 = tuple[T, T] +Tuple4 = tuple[T, T, T, T] + +Tuple2F = Tuple2[float] +Tuple4F = Tuple4[float] + +Shape = Literal["ellipse", "rectangle"] + +Position = Literal[ + "bottom left", + "bottom center", + "bottom right", + "middle left", + "middle center", + "middle right", + "top left", + "top center", + "top right" +] + + +class PlotParams(TypedDict): + lw: str + linewidth: str + ec: str + edgecolor: str + fc: str + facecolor: str + alpha: float + + +class LabelParams(TypedDict): + va: str + verticalalignment: str + ha: str + horizontalalignment: str + ma: str + + +RectParams = TypedDict( + "RectParams", { + "ec": str, + "edgecolor": str, + "fc": str, + "facecolor": str, + "lw": str, + "linewidth": str, + }, + total=False +) + + +class CTX_Kwargs(TypedDict): + shape: NDArrayF + origin: NDArrayF + grid_unit: float + node_unit: float + observed_style: str + alternate_style: str + line_width: float + node_ec: str + node_fc: str + plate_fc: str + directed: bool + aspect: float + label_params: LabelParams | None + dpi: int | None + + +AnyDict = dict[str, Any] | PlotParams | LabelParams | RectParams | CTX_Kwargs diff --git a/src/daft/_utils.py b/src/daft/_utils.py index 7d41243..2c882c2 100644 --- a/src/daft/_utils.py +++ b/src/daft/_utils.py @@ -5,8 +5,12 @@ import matplotlib.pyplot as plt import numpy as np +from typing import Any, Literal, cast -class _rendering_context: +from ._types import NDArrayF, CTX_Kwargs, LabelParams, AnyDict + + +class _RenderingContext: """ :param shape: The number of rows and columns in the grid. @@ -56,7 +60,7 @@ class _rendering_context: """ - def __init__(self, **kwargs): + def __init__(self, kwargs: CTX_Kwargs) -> None: # Save the style defaults. self.line_width = kwargs.get("line_width", 1.0) @@ -81,8 +85,8 @@ def __init__(self, **kwargs): self.padding = 0.1 self.shp_fig_scale = 2.54 - self.shape = np.array(kwargs.get("shape", [1, 1]), dtype=np.float64) - self.origin = np.array(kwargs.get("origin", [0, 0]), dtype=np.float64) + self.shape = kwargs.get("shape", np.atleast_1d((1, 1))) + self.origin = kwargs.get("origin", (0, 0)) self.grid_unit = kwargs.get("grid_unit", 2.0) self.figsize = self.grid_unit * self.shape / self.shp_fig_scale @@ -92,22 +96,22 @@ def __init__(self, **kwargs): self.plate_fc = kwargs.get("plate_fc", "w") self.directed = kwargs.get("directed", True) self.aspect = kwargs.get("aspect", 1.0) - self.label_params = dict(kwargs.get("label_params", {}) or {}) + self.label_params = kwargs.get("label_params", cast(LabelParams, {})) self.dpi = kwargs.get("dpi", None) # Initialize the figure to ``None`` to handle caching later. - self._figure = None - self._ax = None + self._figure: plt.Figure | None = None + self._ax: plt.Axis | None = None - def reset_shape(self, shape, adj_origin=False): + def reset_shape(self, shape: NDArrayF, adj_origin: bool = False) -> None: """Reset the shape and figure size.""" # shape is scaled by grid_unit # so divide by grid_unit for proper shape self.shape = shape / self.grid_unit + self.padding self.figsize = self.grid_unit * self.shape / self.shp_fig_scale - def reset_origin(self, origin, adj_shape=False): + def reset_origin(self, origin: NDArrayF, adj_shape: bool = False) -> None: """Reset the origin.""" # origin is scaled by grid_unit # so divide by grid_unit for proper shape @@ -116,28 +120,28 @@ def reset_origin(self, origin, adj_shape=False): self.shape -= self.origin self.figsize = self.grid_unit * self.shape / self.shp_fig_scale - def reset_figure(self): + def reset_figure(self) -> None: """Reset the figure.""" self.close() - def close(self): + def close(self) -> None: """Close the figure if it is set up.""" if self._figure is not None: plt.close(self._figure) self._figure = None self._ax = None - def figure(self): + def figure(self) -> plt.Figure: """Return the current figure else create a new one.""" if self._figure is not None: return self._figure - args = {"figsize": self.figsize} + args: dict[str, Any] = {"figsize": self.figsize} if self.dpi is not None: args["dpi"] = self.dpi self._figure = plt.figure(**args) return self._figure - def ax(self): + def ax(self) -> plt.Axes: """Return the current axes else create a new one.""" if self._ax is not None: return self._ax @@ -155,19 +159,15 @@ def ax(self): return self._ax - def convert(self, *xy): + def convert(self, x: float, y: float) -> tuple[float, float]: """ Convert from model coordinates to plot coordinates. """ - if len(xy) != 2: - raise ValueError( - "You must provide two coordinates to `convert()`." - ) - return self.grid_unit * (np.atleast_1d(xy) - self.origin) + return self.grid_unit * (x - self.origin[0]), self.grid_unit * (y - self.origin[1]) -def _pop_multiple(_dict, default, *args): +def _pop_multiple(_dict: AnyDict, default: Any, *args: str) -> Any: """ A helper function for dealing with the way that matplotlib annoyingly allows multiple keyword arguments. For example, ``edgecolor`` and ``ec`` @@ -190,10 +190,10 @@ def _pop_multiple(_dict, default, *args): if len(args) == 0: raise ValueError("You must provide at least one argument to `pop()`.") - results = [] + results: list[Any] = [] for arg in args: try: - results.append((arg, _dict.pop(arg))) + results.append((arg, _dict.pop(arg))) # type: ignore[misc] except KeyError: pass diff --git a/src/daft/edge.py b/src/daft/edge.py index 5955375..b7d7e9e 100644 --- a/src/daft/edge.py +++ b/src/daft/edge.py @@ -2,7 +2,14 @@ __all__ = ["Edge"] -from ._utils import _pop_multiple + +from matplotlib.lines import Line2D +from matplotlib.patches import FancyArrow + +from typing import Any, cast + +from ._utils import _pop_multiple, _RenderingContext +from ._types import Tuple4F, PlotParams, LabelParams class Edge: @@ -44,15 +51,15 @@ class Edge: def __init__( self, - node1, - node2, - directed=True, - label=None, - xoffset=0, - yoffset=0.1, - plot_params=None, - label_params=None, - ): + node1: 'Node', + node2: 'Node', + directed: bool = True, + label: str | None = None, + xoffset: float = 0, + yoffset: float = 0.1, + plot_params: PlotParams | None = None, + label_params: LabelParams | None = None, + ) -> None: self.node1 = node1 self.node2 = node2 self.directed = directed @@ -62,7 +69,7 @@ def __init__( self.plot_params = dict(plot_params) if plot_params else {} self.label_params = dict(label_params) if label_params else {} - def _get_coords(self, ctx): + def _get_coords(self, ctx: _RenderingContext) -> Tuple4F: """ Get the coordinates of the line. @@ -83,7 +90,7 @@ def _get_coords(self, ctx): return x3, y3, x4 - x3, y4 - y3 - def render(self, ctx): + def render(self, ctx: _RenderingContext) -> FancyArrow | list[Line2D]: """ Render the edge in the given axes. @@ -105,13 +112,13 @@ def render(self, ctx): x, y, dx, dy = self._get_coords(ctx) ax.annotate( self.label, - [x + 0.5 * dx + self.xoffset, y + 0.5 * dy + self.yoffset], + xy=(x + 0.5 * dx + self.xoffset, y + 0.5 * dy + self.yoffset), xycoords="data", - xytext=[0, 3], + xytext=(0, 3), textcoords="offset points", ha="center", va="center", - **self.label_params, + **cast(dict[str, Any], self.label_params) ) if self.directed: @@ -133,7 +140,7 @@ def render(self, ctx): *self._get_coords(ctx), width=0, length_includes_head=True, - **plot_params, + **cast(dict[str, Any], plot_params) ) # Add the arrow to the axes. @@ -142,6 +149,7 @@ def render(self, ctx): else: print(args[2], args[3]) + return [] else: plot_params["color"] = plot_params.get("color", "k") @@ -150,5 +158,12 @@ def render(self, ctx): x, y, dx, dy = self._get_coords(ctx) # Plot the line. - line = ax.plot([x, x + dx], [y, y + dy], **plot_params) + line = ax.plot( + (x, x + dx), + (y, y + dy), + **cast(dict[str, Any], plot_params) + ) return line + + +from .node import Node diff --git a/src/daft/node.py b/src/daft/node.py index d30d9ff..9bb939f 100644 --- a/src/daft/node.py +++ b/src/daft/node.py @@ -7,8 +7,10 @@ import numpy as np -from ._utils import _rendering_context, _pop_multiple -from ._exceptions import SameLocationError +from typing import Any, Literal, TypedDict, cast + +from ._utils import _pop_multiple, _RenderingContext +from ._types import Tuple2F, CTX_Kwargs, PlotParams, LabelParams, Shape class Node: @@ -70,21 +72,21 @@ class Node: def __init__( self, - name, - content, - x, - y, - scale=1.0, - aspect=None, - observed=False, - fixed=False, - alternate=False, - offset=(0.0, 0.0), - fontsize=None, - plot_params=None, - label_params=None, - shape="ellipse", - ): + name: str, + content: str, + x: float, + y: float, + scale: float = 1.0, + aspect: float | None = None, + observed: bool = False, + fixed: bool = False, + alternate: bool = False, + offset: Tuple2F = (0.0, 0.0), + fontsize: float | None = None, + plot_params: PlotParams | None = None, + label_params: LabelParams | None = None, + shape: Shape = "ellipse", + ) -> None: # Check Node style. # Iterable is consumed, so first condition checks if two or more are # true @@ -105,12 +107,13 @@ def __init__( self.content = content # Coordinates and dimensions. - self.x, self.y = float(x), float(y) + self.x = float(x) + self.y = float(y) self.scale = float(scale) if self.fixed: self.scale /= 6.0 if aspect is not None: - self.aspect = float(aspect) + self.aspect: float | None = float(aspect) else: self.aspect = aspect @@ -118,11 +121,11 @@ def __init__( self.fontsize = fontsize if fontsize else mpl.rcParams["font.size"] # Display parameters. - self.plot_params = dict(plot_params) if plot_params else {} + self.plot_params = cast(PlotParams, dict(plot_params) if plot_params else {}) # Text parameters. - self.offset = list(offset) - self.label_params = dict(label_params) if label_params else None + self.offset = offset + self.label_params = cast(LabelParams | None, dict(label_params) if label_params else None) # Shape if shape in ["ellipse", "rectangle"]: @@ -131,7 +134,7 @@ def __init__( print("Warning: wrong shape value, set to ellipse instead") self.shape = "ellipse" - def render(self, ctx) -> Ellipse | Rectangle: + def render(self, ctx: _RenderingContext) -> Ellipse | Rectangle: """ Render the node. @@ -144,19 +147,19 @@ def render(self, ctx) -> Ellipse | Rectangle: ax = ctx.ax() # Resolve the plotting parameters. - plot_params = dict(self.plot_params) + plot_params = cast(PlotParams, dict(self.plot_params)) plot_params["lw"] = _pop_multiple( - plot_params, ctx.line_width, "lw", "linewidth" + cast(dict[str, Any], plot_params), ctx.line_width, "lw", "linewidth" ) plot_params["ec"] = plot_params["edgecolor"] = _pop_multiple( - plot_params, ctx.node_ec, "ec", "edgecolor" + cast(dict[str, Any], plot_params), ctx.node_ec, "ec", "edgecolor" ) - fc_is_set = "fc" in plot_params or "facecolor" in plot_params + fc_is_set = "fc" in plot_params or "facecolor" in plot_params # type: ignore[unreachable] plot_params["fc"] = _pop_multiple( - plot_params, ctx.node_fc, "fc", "facecolor" + cast(dict[str, Any], plot_params), ctx.node_fc, "fc", "facecolor" ) fc = plot_params["fc"] @@ -164,9 +167,9 @@ def render(self, ctx) -> Ellipse | Rectangle: # And the label parameters. if self.label_params is None: - label_params = dict(ctx.label_params) + label_params = cast(LabelParams, ctx.label_params) else: - label_params = dict(self.label_params) + label_params = cast(LabelParams, self.label_params) label_params["va"] = _pop_multiple( label_params, "center", "va", "verticalalignment" @@ -180,11 +183,12 @@ def render(self, ctx) -> Ellipse | Rectangle: scale = self.scale if self.fixed: # MAGIC: These magic numbers should depend on the grid/node units. - self.offset[1] += 6 + self.offset = (self.offset[0], self.offset[1] + 6) label_params["va"] = "baseline" - label_params.pop("verticalalignment", None) - label_params.pop("ma", None) + _label_params = cast(dict[str, Any], label_params) + _label_params.pop("verticalalignment", None) + _label_params.pop("ma", None) if not fc_is_set: plot_params["fc"] = "k" @@ -201,9 +205,9 @@ def render(self, ctx) -> Ellipse | Rectangle: elif self.alternate and not self.fixed: style = ctx.alternate_style else: - style = False + style = "none" - if style: + if style != "none": # Update the plotting parameters depending on the style of # observed node. h = float(diameter) @@ -220,7 +224,7 @@ def render(self, ctx) -> Ellipse | Rectangle: # Draw the background ellipse. if self.shape == "ellipse": - bg = Ellipse( + bg: Ellipse | Rectangle = Ellipse( xy=ctx.convert(self.x, self.y), width=w, height=h, @@ -229,11 +233,16 @@ def render(self, ctx) -> Ellipse | Rectangle: elif self.shape == "rectangle": # Adapt to make Rectangle the same api than ellipse wi = w - xy = ctx.convert(self.x, self.y) - xy[0] = xy[0] - wi / 2.0 - xy[1] = xy[1] - h / 2.0 + x, y = ctx.convert(self.x, self.y) + x -= wi / 2.0 + y -= h / 2.0 - bg = Rectangle(xy=xy, width=wi, height=h, **plot_params) + bg = Rectangle( + xy=(x, y), + width=wi, + height=h, + **plot_params, + ) else: # Should never append raise ( @@ -252,7 +261,7 @@ def render(self, ctx) -> Ellipse | Rectangle: plot_params["fc"] = "none" if self.shape == "ellipse": - el = Ellipse( + el: Ellipse | Rectangle = Ellipse( xy=ctx.convert(self.x, self.y), width=diameter * aspect, height=diameter, @@ -261,11 +270,16 @@ def render(self, ctx) -> Ellipse | Rectangle: elif self.shape == "rectangle": # Adapt to make Rectangle the same api than ellipse wi = diameter * aspect - xy = ctx.convert(self.x, self.y) - xy[0] = xy[0] - wi / 2.0 - xy[1] = xy[1] - diameter / 2.0 + x, y = ctx.convert(self.x, self.y) + x -= wi / 2.0 + y -= diameter / 2.0 - el = Rectangle(xy=xy, width=wi, height=diameter, **plot_params) + el = Rectangle( + xy=(x, y), + width=wi, + height=diameter, + **plot_params, + ) else: # Should never append raise ( @@ -285,12 +299,12 @@ def render(self, ctx) -> Ellipse | Rectangle: xytext=self.offset, textcoords="offset points", size=self.fontsize, - **label_params, + **_label_params, ) return el - def get_frontier_coord(self, target_xy, ctx, edge): + def get_frontier_coord(self, target_xy: Tuple2F, ctx: _RenderingContext, edge: 'Edge') -> Tuple2F: """ Get the coordinates of the point of intersection between the shape of the node and a line starting from the center of the node to an @@ -366,3 +380,7 @@ def get_frontier_coord(self, target_xy, ctx, edge): else: # Should never append raise ValueError("Wrong shape in object causes an error") + + +from .edge import Edge +from ._exceptions import SameLocationError diff --git a/src/daft/_core.py b/src/daft/pgm.py similarity index 77% rename from src/daft/_core.py rename to src/daft/pgm.py index da4209f..b5bc70e 100644 --- a/src/daft/_core.py +++ b/src/daft/pgm.py @@ -1,6 +1,6 @@ """Code for Daft""" -__all__ = ["PGM", ] +__all__ = ["PGM"] # TODO: should Text be added? import matplotlib.pyplot as plt @@ -8,15 +8,13 @@ import numpy as np -from typing import Any -from numpy.typing import NDArray, ArrayLike - +from typing import Any, Literal, cast from .node import Node from .edge import Edge from .plate import Plate, Text -from ._utils import _rendering_context -from ._types import NDArray2, NDArrayF, NDArrayI +from ._utils import _RenderingContext +from ._types import Tuple2F, NDArrayF, Shape, Position, CTX_Kwargs, PlotParams, LabelParams # pylint: disable=too-many-arguments, protected-access, unused-argument, too-many-lines @@ -77,8 +75,8 @@ class PGM: def __init__( self, - shape: tuple[float, float] | list[float] | NDArrayF | None = None, - origin: ArrayLike | None = None, + shape: Tuple2F | None = None, + origin: Tuple2F | None = None, grid_unit: float = 2.0, node_unit: float = 1.0, observed_style: str = "shaded", @@ -89,26 +87,33 @@ def __init__( plate_fc: str = "w", directed: bool = True, aspect: float = 1.0, - label_params: dict[str, Any] | None = None, + label_params: LabelParams | None = None, dpi: int | None = None, ) -> None: self._nodes: dict[str, Node] = {} - self._edges: dict[str, Edge] = [] - self._plates: dict[str, Plate] = [] + self._edges: list[Edge] = [] + self._plates: list[Plate] = [] self._dpi = dpi # if shape and origin are not given, pass a default # and we will determine at rendering time - self.shape = shape - self.origin = origin if shape is None: - shape = [1, 1] + _shape: Tuple2F = (1, 1) + self.shape = None + else: + _shape = shape + self.shape = tuple(shape) + if origin is None: - origin = [0, 0] + _origin: Tuple2F = (0, 0) + self.origin = None + else: + _origin = origin + self.origin = tuple(origin) - self._ctx = _rendering_context( - shape=shape, - origin=origin, + self._ctx = _RenderingContext(CTX_Kwargs( + shape=np.asarray(_shape, dtype=np.float64), + origin=np.asarray(_origin, dtype=np.float64), grid_unit=grid_unit, node_unit=node_unit, observed_style=observed_style, @@ -120,32 +125,32 @@ def __init__( directed=directed, aspect=aspect, label_params=label_params, - dpi=dpi, - ) + dpi=dpi + )) - def __enter__(self): + def __enter__(self) -> "PGM": return self - def __exit__(self, *args): + def __exit__(self, *args: Any) -> None: self._ctx.close() def add_node( self, - node, - content="", - x=0, - y=0, - scale=1.0, - aspect=None, - observed=False, - fixed=False, - alternate=False, - offset=(0.0, 0.0), - fontsize=None, - plot_params=None, - label_params=None, - shape="ellipse", - ): + node: Node, + content: str = "", + x: float = 0, + y: float = 0, + scale: float = 1.0, + aspect: float | None = None, + observed: bool = False, + fixed: bool = False, + alternate: bool = False, + offset: Tuple2F = (0, 0), + fontsize: float | None = None, + plot_params: PlotParams | None = None, + label_params: LabelParams | None = None, + shape: Shape = "ellipse", + ) -> Node: """ Add a :class:`Node` to the model. @@ -205,7 +210,7 @@ def add_node( if isinstance(node, Node): _node = node else: - _node = Node( + _node = Node( # type: ignore[unreachable] node, content, x, @@ -228,16 +233,16 @@ def add_node( def add_edge( self, - name1, - name2, - directed=None, - xoffset=0.0, - yoffset=0.1, - label=None, - plot_params=None, - label_params=None, - **kwargs, # pylint: disable=unused-argument - ): + name1: str, + name2: str, + directed: bool | None = None, + xoffset: float = 0.0, + yoffset: float = 0.1, + label: str | None = None, + plot_params: PlotParams | None = None, + label_params: LabelParams | None = None, + **kwargs: dict[str, Any], # pylint: disable=unused-argument + ) -> Edge: """ Construct an :class:`Edge` between two named :class:`Node` objects. @@ -291,15 +296,15 @@ def add_edge( def add_plate( self, - plate, - label=None, - label_offset=(5, 5), - shift=0, - position="bottom left", - fontsize=None, - rect_params=None, - bbox=None, - ): + plate: Plate, + label: str | None = None, + label_offset: Tuple2F = (5, 5), + shift: float = 0, + position: Position = "bottom left", + fontsize: float | None = None, + rect_params: dict[str, Any] | None = None, + bbox: bool | None = None, + ) -> None: """ Add a :class:`Plate` object to the model. @@ -333,7 +338,7 @@ def add_plate( if isinstance(plate, Plate): _plate = plate else: - _plate = Plate( + _plate = Plate( # type: ignore[unreachable] plate, label, label_offset, @@ -346,7 +351,7 @@ def add_plate( self._plates.append(_plate) - def add_text(self, x, y, label, fontsize=None): + def add_text(self, x: float, y: float, label: str, fontsize: float | None = None) -> None: """ A subclass of plate to writing text using grid coordinates. Any ``**kwargs`` are passed through to :class:`PGM.Plate`. @@ -365,7 +370,12 @@ def add_text(self, x, y, label, fontsize=None): """ - text = Text(x=x, y=y, label=label, fontsize=fontsize) + text = Text( + x=x, + y=y, + label=label, + fontsize=fontsize + ) self._plates.append(text) return None @@ -386,35 +396,35 @@ def render(self, dpi: int | None = None) -> plt.Axes: else: self._ctx.dpi = dpi - def get_max(maxsize, artist): - if isinstance(artist, Ellipse): + def get_max(maxsize: NDArrayF, patch: Ellipse | Rectangle) -> NDArrayF: + if isinstance(patch, Ellipse): maxsize = np.maximum( maxsize, - artist.center - + np.array([artist.width, artist.height]) / 2, + patch.center + + np.array([patch.width, patch.height]) / 2, dtype=np.float64, ) - elif isinstance(artist, Rectangle): + elif isinstance(patch, Rectangle): maxsize = np.maximum( maxsize, - np.array([artist._x0, artist._y0], dtype=np.float64) - + np.array([artist._width, artist._height]), + np.array([patch._x0, patch._y0], dtype=np.float64) # type: ignore[attr-defined] + + np.array([patch._width, patch._height]), # type: ignore[attr-defined] dtype=np.float64, ) return maxsize - def get_min(minsize, artist): - if isinstance(artist, Ellipse): + def get_min(minsize: NDArrayF, patch: Ellipse | Rectangle) -> NDArrayF: + if isinstance(patch, Ellipse): minsize = np.minimum( minsize, - artist.center - - np.array([artist.width, artist.height]) / 2, + patch.center + - np.array([patch.width, patch.height]) / 2, dtype=np.float64, ) - elif isinstance(artist, Rectangle): + elif isinstance(patch, Rectangle): minsize = np.minimum( minsize, - np.array([artist._x0, artist._y0], dtype=np.float64), + np.array([patch._x0, patch._y0], dtype=np.float64), # type: ignore[attr-defined] ) return minsize @@ -424,12 +434,15 @@ def get_min(minsize, artist): maxsize = np.copy(self._ctx.origin) for plate in self._plates: - artist = plate.render(self._ctx) + artist: Ellipse | Rectangle = plate.render(self._ctx) maxsize = get_max(maxsize, artist) for name in self._nodes: if self._nodes[name].fixed: - self._nodes[name].offset[1] -= 12.5 + offx, offy = self._nodes[name].offset + offy -= 12.5 + self._nodes[name].offset = (offx, offy) + artist = self._nodes[name].render(self._ctx) maxsize = get_max(maxsize, artist) @@ -464,16 +477,16 @@ def get_min(minsize, artist): return self.ax @property - def figure(self): + def figure(self) -> plt.Figure: """Figure as a property.""" return self._ctx.figure() @property - def ax(self): + def ax(self) -> plt.Axes: """Axes as a property.""" return self._ctx.ax() - def show(self, *args, dpi=None, **kwargs): + def show(self, *args: Any, dpi: int | None = None, **kwargs: Any) -> None: """ Wrapper on :class:`PGM.render()` that calls `matplotlib.show()` immediately after. @@ -486,7 +499,7 @@ def show(self, *args, dpi=None, **kwargs): self.render(dpi=dpi) plt.show(*args, **kwargs) - def savefig(self, fname, *args, **kwargs): + def savefig(self, fname: str, *args: Any, **kwargs: Any) -> None: """ Wrapper on ``matplotlib.Figure.savefig()`` that sets default image padding using ``bbox_inchaes = tight``. diff --git a/src/daft/plate.py b/src/daft/plate.py index b76bab3..d50b889 100644 --- a/src/daft/plate.py +++ b/src/daft/plate.py @@ -1,13 +1,17 @@ +"""Daft errors""" + +__all__: list[str] = [] + + import matplotlib as mpl from matplotlib.patches import Rectangle import numpy as np -from ._utils import _pop_multiple - -# Move exception import to end of file to resolve circular dependency -from ._exceptions import SameLocationError +from typing import Any, cast +from ._utils import _pop_multiple, _RenderingContext +from ._types import Tuple2F, Tuple4F, Position, RectParams class Plate: @@ -50,32 +54,32 @@ class Plate: def __init__( self, - rect, - label=None, - label_offset=(5, 5), - shift=0, - position="bottom left", - fontsize=None, - rect_params=None, - bbox=None, - ): + rect: Tuple4F, + label: str | None = None, + label_offset: Tuple2F = (5, 5), + shift: float = 0, + position: Position = "bottom left", + fontsize: float | None = None, + rect_params: RectParams | None = None, + bbox: dict[str, Any] | None = None, + ) -> None: self.rect = rect self.label = label self.label_offset = label_offset self.shift = shift if fontsize is not None: - self.fontsize = fontsize + self.fontsize: float | None = fontsize else: self.fontsize = mpl.rcParams["font.size"] if rect_params is not None: - self.rect_params = dict(rect_params) + self.rect_params: RectParams | None = rect_params else: self.rect_params = None if bbox is not None: - self.bbox = dict(bbox) + self.bbox: dict[str, Any] | None = dict(bbox) # Set the awful default blue color to transparent if "fc" not in self.bbox.keys(): @@ -85,7 +89,7 @@ def __init__( self.position = position - def render(self, ctx): + def render(self, ctx: _RenderingContext) -> Rectangle: """ Render the plate in the given axes. @@ -97,8 +101,8 @@ def render(self, ctx): shift = np.array([0, self.shift], dtype=np.float64) rect = np.atleast_1d(self.rect) - bottom_left = ctx.convert(*(rect[:2] + shift)) - top_right = ctx.convert(*(rect[:2] + rect[2:])) + bottom_left = np.atleast_1d(ctx.convert(*(rect[:2] + shift))) + top_right = np.atleast_1d(ctx.convert(*(rect[:2] + rect[2:]))) rect = np.concatenate([bottom_left, top_right - bottom_left]) if self.rect_params is not None: @@ -106,14 +110,26 @@ def render(self, ctx): else: rect_params = {} - rect_params["ec"] = _pop_multiple(rect_params, "k", "ec", "edgecolor") + rect_params["ec"] = _pop_multiple( + rect_params, "k", "ec", "edgecolor" + ) rect_params["fc"] = _pop_multiple( rect_params, ctx.plate_fc, "fc", "facecolor" ) rect_params["lw"] = _pop_multiple( rect_params, ctx.line_width, "lw", "linewidth" ) - rectangle = Rectangle(rect[:2], *rect[2:], **rect_params) + + x: float + y: float + x, y = rect[:2] + + rectangle = Rectangle( + xy=(x, y), + width=x, + height=y, + **rect_params + ) ax.add_artist(rectangle) @@ -150,11 +166,14 @@ def render(self, ctx): f"Unknown positioning string: {self.position}" ) + posx, posy = position + offx, offy = offset + ax.annotate( self.label, - xy=position, + xy=(posx, posy), xycoords="data", - xytext=offset, + xytext=(offx, offy), textcoords="offset points", size=self.fontsize, bbox=self.bbox, @@ -184,13 +203,13 @@ class Text(Plate): """ - def __init__(self, x, y, label, fontsize=None): - self.rect = [x, y, 0.0, 0.0] + def __init__(self, x: float, y: float, label: str, fontsize: float | None = None) -> None: + self.rect = (x, y, 0.0, 0.0) self.label = label self.fontsize = fontsize - self.label_offset = [0.0, 0.0] + self.label_offset = (0.0, 0.0) + self.rect_params = cast(RectParams, {"ec": "none"}) self.bbox = {"fc": "none", "ec": "none"} - self.rect_params = {"ec": "none"} super().__init__( rect=self.rect, From d3890e96b8e2cf6ad3318551a9b86802ec4419f1 Mon Sep 17 00:00:00 2001 From: David Fulford Date: Thu, 25 Jul 2024 15:34:18 -0500 Subject: [PATCH 03/15] debugging Signed-off-by: David Fulford --- src/daft/_utils.py | 2 +- src/daft/node.py | 8 +++-- test/test_daft.py | 2 +- test/test_examples.py | 74 ++++++++++++++++++++++--------------------- 4 files changed, 45 insertions(+), 41 deletions(-) diff --git a/src/daft/_utils.py b/src/daft/_utils.py index 2c882c2..f3e1443 100644 --- a/src/daft/_utils.py +++ b/src/daft/_utils.py @@ -96,7 +96,7 @@ def __init__(self, kwargs: CTX_Kwargs) -> None: self.plate_fc = kwargs.get("plate_fc", "w") self.directed = kwargs.get("directed", True) self.aspect = kwargs.get("aspect", 1.0) - self.label_params = kwargs.get("label_params", cast(LabelParams, {})) + self.label_params = cast(LabelParams, kwargs.get("label_params", {}) or {}) self.dpi = kwargs.get("dpi", None) diff --git a/src/daft/node.py b/src/daft/node.py index 9bb939f..46b8b83 100644 --- a/src/daft/node.py +++ b/src/daft/node.py @@ -186,9 +186,6 @@ def render(self, ctx: _RenderingContext) -> Ellipse | Rectangle: self.offset = (self.offset[0], self.offset[1] + 6) label_params["va"] = "baseline" - _label_params = cast(dict[str, Any], label_params) - _label_params.pop("verticalalignment", None) - _label_params.pop("ma", None) if not fc_is_set: plot_params["fc"] = "k" @@ -291,6 +288,11 @@ def render(self, ctx: _RenderingContext) -> Ellipse | Rectangle: # Reset the face color. plot_params["fc"] = fc + # pop extra params + _label_params = cast(dict[str, Any], label_params) + _label_params.pop("verticalalignment", None) + _label_params.pop("ma", None) + # Annotate the node. ax.annotate( self.content, diff --git a/test/test_daft.py b/test/test_daft.py index 2f31b02..be6e97c 100644 --- a/test/test_daft.py +++ b/test/test_daft.py @@ -63,7 +63,7 @@ def test_add_text(): pgm.add_text(x=0, y=0, label="text1") plate = pgm._plates[0] - assert plate.rect == [0, 0, 0, 0] + assert plate.rect == (0, 0, 0, 0) assert plate.label == "text1" diff --git a/test/test_examples.py b/test/test_examples.py index e4f5b02..da5b3af 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -3,9 +3,11 @@ import daft from matplotlib.testing.decorators import image_comparison +from typing import Any + @image_comparison(baseline_images=["bca"], extensions=["png"]) -def test_bca(): +def test_bca() -> None: pgm = daft.PGM() pgm.add_node("a", r"$a$", 1, 5) pgm.add_node("b", r"$b$", 1, 4) @@ -13,11 +15,11 @@ def test_bca(): pgm.add_plate([0.5, 2.25, 1, 1.25], label=r"data $n$") pgm.add_edge("a", "b") pgm.add_edge("b", "c") - pgm.render() + pgm.show() @image_comparison(baseline_images=["classic"], extensions=["png"]) -def test_classic(): +def test_classic() -> None: pgm = daft.PGM() # Hierarchical parameters. @@ -39,11 +41,11 @@ def test_classic(): # And a plate. pgm.add_plate([0.5, 0.5, 2, 1], label=r"$n = 1, \cdots, N$", shift=-0.1) - pgm.render() + pgm.show() @image_comparison(baseline_images=["deconvolution"], extensions=["png"]) -def test_deconvolution(): +def test_deconvolution() -> None: scale = 1.6 pgm = daft.PGM() @@ -147,11 +149,11 @@ def test_deconvolution(): pgm.add_text(x0, y0 + 5 * dy, "= pressure drop at sandface") # Render and save. - pgm.render() + pgm.show() @image_comparison(baseline_images=["exoplanets"], extensions=["png"]) -def test_exoplanets(): +def test_exoplanets() -> None: # Colors. p_color = {"ec": "#46a546"} s_color = {"ec": "#f89406"} @@ -186,11 +188,11 @@ def test_exoplanets(): pgm.add_plate([2, 0.5, 1, 1], label=r"pixel $j$", shift=-0.1) # Render and save. - pgm.render() + pgm.show() @image_comparison(baseline_images=["fixed"], extensions=["png"]) -def test_fixed(): +def test_fixed() -> None: pgm = daft.PGM(aspect=1.5, node_unit=1.75) pgm.add_node("unobs", r"Unobserved!", 1, 4) pgm.add_node("obs", r"Observed!", 1, 3, observed=True) @@ -198,11 +200,11 @@ def test_fixed(): pgm.add_node( "fixed", r"Fixed!", 1, 1, fixed=True, aspect=1.0, offset=[0, 5] ) - pgm.render() + pgm.show() @image_comparison(baseline_images=["gaia"], extensions=["png"]) -def test_gaia(): +def test_gaia() -> None: pgm = daft.PGM() pgm.add_node("omega", r"$\omega$", 2, 5) pgm.add_node("true", r"$\tilde{X}_n$", 2, 4) @@ -216,11 +218,11 @@ def test_gaia(): pgm.add_edge("alpha", "true") pgm.add_edge("Sigma", "sigma") pgm.add_edge("sigma", "obs") - pgm.render() + pgm.show() @image_comparison(baseline_images=["galex"], extensions=["png"]) -def test_galex(): +def test_galex() -> None: pgm = daft.PGM() wide = 1.5 verywide = 1.5 * wide @@ -299,11 +301,11 @@ def test_galex(): pgm.add_edge("star pos", "star adt") # done - pgm.render() + pgm.show() @image_comparison(baseline_images=["huey_p_newton"], extensions=["png"]) -def test_huey_p_newton(): +def test_huey_p_newton() -> None: pgm = daft.PGM() kx, ky = 1.5, 1.0 @@ -334,11 +336,11 @@ def test_huey_p_newton(): pgm.add_edge("sigman", "Yn") # Render and save. - pgm.render() + pgm.show() @image_comparison(baseline_images=["logo"], extensions=["png"]) -def test_logo(): +def test_logo() -> None: pgm = daft.PGM() pgm.add_node("d", r"$D$", 0.5, 0.5) pgm.add_node("a", r"$a$", 1.5, 0.5, observed=True) @@ -347,11 +349,11 @@ def test_logo(): pgm.add_edge("d", "a") pgm.add_edge("a", "f") pgm.add_edge("f", "t") - pgm.render() + pgm.show() @image_comparison(baseline_images=["mrf"], extensions=["png"]) -def test_mrf(): +def test_mrf() -> None: pgm = daft.PGM(node_unit=0.4, grid_unit=1, directed=False) for i, (xi, yi) in enumerate(itertools.product(range(1, 5), range(1, 5))): @@ -377,11 +379,11 @@ def test_mrf(): ]: pgm.add_edge(str(e[0]), str(e[1])) - pgm.render() + pgm.show() @image_comparison(baseline_images=["no_circles"], extensions=["png"]) -def test_no_circles(): +def test_no_circles() -> None: pgm = daft.PGM(node_ec="none") pgm.add_node("cloudy", r"cloudy", 3, 3) pgm.add_node("rain", r"rain", 2, 2) @@ -391,11 +393,11 @@ def test_no_circles(): pgm.add_edge("cloudy", "sprinkler") pgm.add_edge("rain", "wet") pgm.add_edge("sprinkler", "wet") - pgm.render() + pgm.show() @image_comparison(baseline_images=["no_gray"], extensions=["png"]) -def test_no_gray(): +def test_no_gray() -> None: pgm = daft.PGM(observed_style="inner") # Hierarchical parameters. @@ -418,12 +420,12 @@ def test_no_gray(): pgm.add_plate([0.5, 0.5, 2, 1], label=r"$n = 1, \ldots, N$", shift=-0.1) # Render and save. - pgm.render() + pgm.show() @image_comparison(baseline_images=["recursive"], extensions=["png"]) -def test_recursive(): - def recurse(pgm, nodename, level, c): +def test_recursive() -> None: + def recurse(pgm: Any, nodename: Any, level: Any, c: Any) -> Any: if level > 4: return nodename r = c // 2 @@ -473,11 +475,11 @@ def recurse(pgm, nodename, level, c): ) pgm.add_edge("output", "answer") - pgm.render() + pgm.show() @image_comparison(baseline_images=["thick_lines"], extensions=["png"]) -def test_thick_lines(): +def test_thick_lines() -> None: pgm = daft.PGM(line_width=2.5) # Hierarchical parameters. @@ -500,11 +502,11 @@ def test_thick_lines(): pgm.add_plate([0.5, 0.5, 2, 1], label=r"$n = 1, \cdots, N$", shift=-0.1) # Render and save. - pgm.render() + pgm.show() @image_comparison(baseline_images=["weaklensing"], extensions=["png"]) -def test_weaklensing(): +def test_weaklensing() -> None: pgm = daft.PGM() pgm.add_node("Omega", r"$\Omega$", -1, 4) pgm.add_node("gamma", r"$\gamma$", 0, 4) @@ -522,11 +524,11 @@ def test_weaklensing(): pgm.add_edge("x", "obs") pgm.add_edge("Sigma", "sigma") pgm.add_edge("sigma", "obs") - pgm.render() + pgm.show() -@image_comparison(baseline_images=["wordy"], extensions=["png"]) -def test_wordy(): +@image_comparison(baseline_images=["no_circles"], extensions=["png"]) +def test_no_circles() -> None: pgm = daft.PGM() pgm.add_node("cloudy", r"cloudy", 3, 3, aspect=1.8) pgm.add_node("rain", r"rain", 2, 2, aspect=1.2) @@ -548,11 +550,11 @@ def test_wordy(): ) pgm.add_edge("rain", "wet") pgm.add_edge("sprinkler", "wet") - pgm.render() + pgm.show() @image_comparison(baseline_images=["wordy"], extensions=["png"]) -def test_wordy(): +def test_wordy() -> None: pgm = daft.PGM() pgm.add_node("obs", r"$\epsilon^{obs}_n$", 2, 3, observed=True) pgm.add_node("true", r"$\epsilon^{true}_n$", 1, 3) @@ -569,4 +571,4 @@ def test_wordy(): pgm.add_edge("sigma", "obs") pgm.add_plate([0.5, 2.25, 2, 1.25], label=r"galaxies $n$") pgm.add_plate([0.25, 1.75, 2.5, 2.75], label=r"patches $m$") - pgm.render() + pgm.show() From da9993a917c02772e9581164956135762989ddd9 Mon Sep 17 00:00:00 2001 From: David Fulford Date: Thu, 25 Jul 2024 16:06:37 -0500 Subject: [PATCH 04/15] further debugging Signed-off-by: David Fulford --- src/daft/pgm.py | 4 ++-- src/daft/plate.py | 16 ++++++---------- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/src/daft/pgm.py b/src/daft/pgm.py index b5bc70e..14ffcbb 100644 --- a/src/daft/pgm.py +++ b/src/daft/pgm.py @@ -14,7 +14,7 @@ from .edge import Edge from .plate import Plate, Text from ._utils import _RenderingContext -from ._types import Tuple2F, NDArrayF, Shape, Position, CTX_Kwargs, PlotParams, LabelParams +from ._types import Tuple2F, NDArrayF, Shape, Position, CTX_Kwargs, PlotParams, LabelParams, RectParams # pylint: disable=too-many-arguments, protected-access, unused-argument, too-many-lines @@ -302,7 +302,7 @@ def add_plate( shift: float = 0, position: Position = "bottom left", fontsize: float | None = None, - rect_params: dict[str, Any] | None = None, + rect_params: RectParams | None = None, bbox: bool | None = None, ) -> None: """ diff --git a/src/daft/plate.py b/src/daft/plate.py index d50b889..3a302a9 100644 --- a/src/daft/plate.py +++ b/src/daft/plate.py @@ -100,9 +100,9 @@ def render(self, ctx: _RenderingContext) -> Rectangle: ax = ctx.ax() shift = np.array([0, self.shift], dtype=np.float64) - rect = np.atleast_1d(self.rect) - bottom_left = np.atleast_1d(ctx.convert(*(rect[:2] + shift))) - top_right = np.atleast_1d(ctx.convert(*(rect[:2] + rect[2:]))) + rect = np.atleast_1d(np.asarray(self.rect, dtype=np.float64)) + bottom_left = np.asarray(ctx.convert(*(rect[:2] + shift)), dtype=np.float64) + top_right = np.asarray(ctx.convert(*(rect[:2] + rect[2:])), dtype=np.float64) rect = np.concatenate([bottom_left, top_right - bottom_left]) if self.rect_params is not None: @@ -120,14 +120,10 @@ def render(self, ctx: _RenderingContext) -> Rectangle: rect_params, ctx.line_width, "lw", "linewidth" ) - x: float - y: float - x, y = rect[:2] - rectangle = Rectangle( - xy=(x, y), - width=x, - height=y, + xy=(rect[0], rect[1]), + width=rect[2], + height=rect[3], **rect_params ) From 0efa73f2d01c331a4565d8811b599768468b3290 Mon Sep 17 00:00:00 2001 From: David Fulford Date: Thu, 25 Jul 2024 16:24:02 -0500 Subject: [PATCH 05/15] debug text alignment Signed-off-by: David Fulford --- src/daft/node.py | 6 ++++-- test/test_examples.py | 40 ++++++++++++++++++++-------------------- 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/src/daft/node.py b/src/daft/node.py index 46b8b83..189d0bd 100644 --- a/src/daft/node.py +++ b/src/daft/node.py @@ -3,6 +3,7 @@ __all__ = ["Node"] import matplotlib as mpl +from copy import deepcopy from matplotlib.patches import Ellipse, Rectangle import numpy as np @@ -167,9 +168,9 @@ def render(self, ctx: _RenderingContext) -> Ellipse | Rectangle: # And the label parameters. if self.label_params is None: - label_params = cast(LabelParams, ctx.label_params) + label_params = deepcopy(ctx.label_params) else: - label_params = cast(LabelParams, self.label_params) + label_params = deepcopy(self.label_params) label_params["va"] = _pop_multiple( label_params, "center", "va", "verticalalignment" @@ -190,6 +191,7 @@ def render(self, ctx: _RenderingContext) -> Ellipse | Rectangle: if not fc_is_set: plot_params["fc"] = "k" + diameter = ctx.node_unit * scale if self.aspect is not None: aspect = self.aspect diff --git a/test/test_examples.py b/test/test_examples.py index da5b3af..2d325e2 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -15,7 +15,7 @@ def test_bca() -> None: pgm.add_plate([0.5, 2.25, 1, 1.25], label=r"data $n$") pgm.add_edge("a", "b") pgm.add_edge("b", "c") - pgm.show() + pgm.render() @image_comparison(baseline_images=["classic"], extensions=["png"]) @@ -41,7 +41,7 @@ def test_classic() -> None: # And a plate. pgm.add_plate([0.5, 0.5, 2, 1], label=r"$n = 1, \cdots, N$", shift=-0.1) - pgm.show() + pgm.render() @image_comparison(baseline_images=["deconvolution"], extensions=["png"]) @@ -149,7 +149,7 @@ def test_deconvolution() -> None: pgm.add_text(x0, y0 + 5 * dy, "= pressure drop at sandface") # Render and save. - pgm.show() + pgm.render() @image_comparison(baseline_images=["exoplanets"], extensions=["png"]) @@ -188,7 +188,7 @@ def test_exoplanets() -> None: pgm.add_plate([2, 0.5, 1, 1], label=r"pixel $j$", shift=-0.1) # Render and save. - pgm.show() + pgm.render() @image_comparison(baseline_images=["fixed"], extensions=["png"]) @@ -200,7 +200,7 @@ def test_fixed() -> None: pgm.add_node( "fixed", r"Fixed!", 1, 1, fixed=True, aspect=1.0, offset=[0, 5] ) - pgm.show() + pgm.render() @image_comparison(baseline_images=["gaia"], extensions=["png"]) @@ -218,7 +218,7 @@ def test_gaia() -> None: pgm.add_edge("alpha", "true") pgm.add_edge("Sigma", "sigma") pgm.add_edge("sigma", "obs") - pgm.show() + pgm.render() @image_comparison(baseline_images=["galex"], extensions=["png"]) @@ -301,7 +301,7 @@ def test_galex() -> None: pgm.add_edge("star pos", "star adt") # done - pgm.show() + pgm.render() @image_comparison(baseline_images=["huey_p_newton"], extensions=["png"]) @@ -336,7 +336,7 @@ def test_huey_p_newton() -> None: pgm.add_edge("sigman", "Yn") # Render and save. - pgm.show() + pgm.render() @image_comparison(baseline_images=["logo"], extensions=["png"]) @@ -349,7 +349,7 @@ def test_logo() -> None: pgm.add_edge("d", "a") pgm.add_edge("a", "f") pgm.add_edge("f", "t") - pgm.show() + pgm.render() @image_comparison(baseline_images=["mrf"], extensions=["png"]) @@ -379,7 +379,7 @@ def test_mrf() -> None: ]: pgm.add_edge(str(e[0]), str(e[1])) - pgm.show() + pgm.render() @image_comparison(baseline_images=["no_circles"], extensions=["png"]) @@ -393,7 +393,7 @@ def test_no_circles() -> None: pgm.add_edge("cloudy", "sprinkler") pgm.add_edge("rain", "wet") pgm.add_edge("sprinkler", "wet") - pgm.show() + pgm.render() @image_comparison(baseline_images=["no_gray"], extensions=["png"]) @@ -420,7 +420,7 @@ def test_no_gray() -> None: pgm.add_plate([0.5, 0.5, 2, 1], label=r"$n = 1, \ldots, N$", shift=-0.1) # Render and save. - pgm.show() + pgm.render() @image_comparison(baseline_images=["recursive"], extensions=["png"]) @@ -475,7 +475,7 @@ def recurse(pgm: Any, nodename: Any, level: Any, c: Any) -> Any: ) pgm.add_edge("output", "answer") - pgm.show() + pgm.render() @image_comparison(baseline_images=["thick_lines"], extensions=["png"]) @@ -502,7 +502,7 @@ def test_thick_lines() -> None: pgm.add_plate([0.5, 0.5, 2, 1], label=r"$n = 1, \cdots, N$", shift=-0.1) # Render and save. - pgm.show() + pgm.render() @image_comparison(baseline_images=["weaklensing"], extensions=["png"]) @@ -524,11 +524,11 @@ def test_weaklensing() -> None: pgm.add_edge("x", "obs") pgm.add_edge("Sigma", "sigma") pgm.add_edge("sigma", "obs") - pgm.show() + pgm.render() -@image_comparison(baseline_images=["no_circles"], extensions=["png"]) -def test_no_circles() -> None: +@image_comparison(baseline_images=["wordy"], extensions=["png"]) +def test_wordy() -> None: pgm = daft.PGM() pgm.add_node("cloudy", r"cloudy", 3, 3, aspect=1.8) pgm.add_node("rain", r"rain", 2, 2, aspect=1.2) @@ -550,11 +550,11 @@ def test_no_circles() -> None: ) pgm.add_edge("rain", "wet") pgm.add_edge("sprinkler", "wet") - pgm.show() + pgm.render() @image_comparison(baseline_images=["wordy"], extensions=["png"]) -def test_wordy() -> None: +def test_wordy2() -> None: pgm = daft.PGM() pgm.add_node("obs", r"$\epsilon^{obs}_n$", 2, 3, observed=True) pgm.add_node("true", r"$\epsilon^{true}_n$", 1, 3) @@ -571,4 +571,4 @@ def test_wordy() -> None: pgm.add_edge("sigma", "obs") pgm.add_plate([0.5, 2.25, 2, 1.25], label=r"galaxies $n$") pgm.add_plate([0.25, 1.75, 2.5, 2.75], label=r"patches $m$") - pgm.show() + pgm.render() From 8f85ecec4429e478bf26a7aa2333f6e370d2af58 Mon Sep 17 00:00:00 2001 From: David Fulford Date: Thu, 25 Jul 2024 16:55:37 -0500 Subject: [PATCH 06/15] fix wet_grass test Signed-off-by: David Fulford --- src/daft/pgm.py | 2 -- test/test_examples.py | 6 +++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/daft/pgm.py b/src/daft/pgm.py index 14ffcbb..0213eea 100644 --- a/src/daft/pgm.py +++ b/src/daft/pgm.py @@ -514,6 +514,4 @@ def savefig(self, fname: str, *args: Any, **kwargs: Any) -> None: """ kwargs["bbox_inches"] = kwargs.get("bbox_inches", "tight") kwargs["dpi"] = kwargs.get("dpi", self._dpi) - if not self.figure: - self.render() self.figure.savefig(fname, *args, **kwargs) diff --git a/test/test_examples.py b/test/test_examples.py index 2d325e2..4dc53e6 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -527,8 +527,8 @@ def test_weaklensing() -> None: pgm.render() -@image_comparison(baseline_images=["wordy"], extensions=["png"]) -def test_wordy() -> None: +@image_comparison(baseline_images=["wet_grass"], extensions=["png"]) +def test_wet_grass() -> None: pgm = daft.PGM() pgm.add_node("cloudy", r"cloudy", 3, 3, aspect=1.8) pgm.add_node("rain", r"rain", 2, 2, aspect=1.2) @@ -554,7 +554,7 @@ def test_wordy() -> None: @image_comparison(baseline_images=["wordy"], extensions=["png"]) -def test_wordy2() -> None: +def test_wordy() -> None: pgm = daft.PGM() pgm.add_node("obs", r"$\epsilon^{obs}_n$", 2, 3, observed=True) pgm.add_node("true", r"$\epsilon^{true}_n$", 1, 3) From ce7921583afda888fe62066209e6dbcb78311b84 Mon Sep 17 00:00:00 2001 From: David Fulford Date: Thu, 25 Jul 2024 17:00:02 -0500 Subject: [PATCH 07/15] final renaming for style Signed-off-by: David Fulford --- src/daft/__init__.py | 12 ++++++------ src/daft/edge.py | 8 ++++---- src/daft/{_exceptions.py => exceptions.py} | 0 src/daft/node.py | 10 +++++----- src/daft/pgm.py | 6 +++--- src/daft/plate.py | 6 +++--- src/daft/{_types.py => types.py} | 0 src/daft/{_utils.py => utils.py} | 4 ++-- 8 files changed, 23 insertions(+), 23 deletions(-) rename src/daft/{_exceptions.py => exceptions.py} (100%) rename src/daft/{_types.py => types.py} (100%) rename src/daft/{_utils.py => utils.py} (98%) diff --git a/src/daft/__init__.py b/src/daft/__init__.py index c8af2a4..979390c 100644 --- a/src/daft/__init__.py +++ b/src/daft/__init__.py @@ -2,13 +2,13 @@ from importlib.metadata import version as get_distribution -from . import node, edge, pgm, plate, _exceptions, _utils, _types +from . import exceptions, node, edge, pgm, plate, types, utils from .pgm import PGM from .node import Node from .edge import Edge from .plate import Plate, Text -from ._exceptions import SameLocationError -from ._utils import _RenderingContext, _pop_multiple +from .exceptions import SameLocationError +from .utils import RenderingContext, _pop_multiple __version__ = get_distribution("daft") __all__ = [] @@ -16,6 +16,6 @@ __all__ += node.__all__ __all__ += edge.__all__ __all__ += plate.__all__ -__all__ += _exceptions.__all__ -__all__ += _utils.__all__ -__all__ += _types.__all__ +__all__ += exceptions.__all__ +__all__ += utils.__all__ +__all__ += types.__all__ diff --git a/src/daft/edge.py b/src/daft/edge.py index b7d7e9e..751c07f 100644 --- a/src/daft/edge.py +++ b/src/daft/edge.py @@ -8,8 +8,8 @@ from typing import Any, cast -from ._utils import _pop_multiple, _RenderingContext -from ._types import Tuple4F, PlotParams, LabelParams +from .utils import _pop_multiple, RenderingContext +from .types import Tuple4F, PlotParams, LabelParams class Edge: @@ -69,7 +69,7 @@ def __init__( self.plot_params = dict(plot_params) if plot_params else {} self.label_params = dict(label_params) if label_params else {} - def _get_coords(self, ctx: _RenderingContext) -> Tuple4F: + def _get_coords(self, ctx: RenderingContext) -> Tuple4F: """ Get the coordinates of the line. @@ -90,7 +90,7 @@ def _get_coords(self, ctx: _RenderingContext) -> Tuple4F: return x3, y3, x4 - x3, y4 - y3 - def render(self, ctx: _RenderingContext) -> FancyArrow | list[Line2D]: + def render(self, ctx: RenderingContext) -> FancyArrow | list[Line2D]: """ Render the edge in the given axes. diff --git a/src/daft/_exceptions.py b/src/daft/exceptions.py similarity index 100% rename from src/daft/_exceptions.py rename to src/daft/exceptions.py diff --git a/src/daft/node.py b/src/daft/node.py index 189d0bd..fd27ee4 100644 --- a/src/daft/node.py +++ b/src/daft/node.py @@ -10,8 +10,8 @@ from typing import Any, Literal, TypedDict, cast -from ._utils import _pop_multiple, _RenderingContext -from ._types import Tuple2F, CTX_Kwargs, PlotParams, LabelParams, Shape +from .utils import _pop_multiple, RenderingContext +from .types import Tuple2F, CTX_Kwargs, PlotParams, LabelParams, Shape class Node: @@ -135,7 +135,7 @@ def __init__( print("Warning: wrong shape value, set to ellipse instead") self.shape = "ellipse" - def render(self, ctx: _RenderingContext) -> Ellipse | Rectangle: + def render(self, ctx: RenderingContext) -> Ellipse | Rectangle: """ Render the node. @@ -308,7 +308,7 @@ def render(self, ctx: _RenderingContext) -> Ellipse | Rectangle: return el - def get_frontier_coord(self, target_xy: Tuple2F, ctx: _RenderingContext, edge: 'Edge') -> Tuple2F: + def get_frontier_coord(self, target_xy: Tuple2F, ctx: RenderingContext, edge: 'Edge') -> Tuple2F: """ Get the coordinates of the point of intersection between the shape of the node and a line starting from the center of the node to an @@ -387,4 +387,4 @@ def get_frontier_coord(self, target_xy: Tuple2F, ctx: _RenderingContext, edge: ' from .edge import Edge -from ._exceptions import SameLocationError +from .exceptions import SameLocationError diff --git a/src/daft/pgm.py b/src/daft/pgm.py index 0213eea..3251669 100644 --- a/src/daft/pgm.py +++ b/src/daft/pgm.py @@ -13,8 +13,8 @@ from .node import Node from .edge import Edge from .plate import Plate, Text -from ._utils import _RenderingContext -from ._types import Tuple2F, NDArrayF, Shape, Position, CTX_Kwargs, PlotParams, LabelParams, RectParams +from .utils import RenderingContext +from .types import Tuple2F, NDArrayF, Shape, Position, CTX_Kwargs, PlotParams, LabelParams, RectParams # pylint: disable=too-many-arguments, protected-access, unused-argument, too-many-lines @@ -111,7 +111,7 @@ def __init__( _origin = origin self.origin = tuple(origin) - self._ctx = _RenderingContext(CTX_Kwargs( + self._ctx = RenderingContext(CTX_Kwargs( shape=np.asarray(_shape, dtype=np.float64), origin=np.asarray(_origin, dtype=np.float64), grid_unit=grid_unit, diff --git a/src/daft/plate.py b/src/daft/plate.py index 3a302a9..819dd0c 100644 --- a/src/daft/plate.py +++ b/src/daft/plate.py @@ -10,8 +10,8 @@ from typing import Any, cast -from ._utils import _pop_multiple, _RenderingContext -from ._types import Tuple2F, Tuple4F, Position, RectParams +from .utils import _pop_multiple, RenderingContext +from .types import Tuple2F, Tuple4F, Position, RectParams class Plate: @@ -89,7 +89,7 @@ def __init__( self.position = position - def render(self, ctx: _RenderingContext) -> Rectangle: + def render(self, ctx: RenderingContext) -> Rectangle: """ Render the plate in the given axes. diff --git a/src/daft/_types.py b/src/daft/types.py similarity index 100% rename from src/daft/_types.py rename to src/daft/types.py diff --git a/src/daft/_utils.py b/src/daft/utils.py similarity index 98% rename from src/daft/_utils.py rename to src/daft/utils.py index f3e1443..95702a7 100644 --- a/src/daft/_utils.py +++ b/src/daft/utils.py @@ -7,10 +7,10 @@ from typing import Any, Literal, cast -from ._types import NDArrayF, CTX_Kwargs, LabelParams, AnyDict +from .types import NDArrayF, CTX_Kwargs, LabelParams, AnyDict -class _RenderingContext: +class RenderingContext: """ :param shape: The number of rows and columns in the grid. From 1a7d2bd6b24fa8f5d7df4c90e2fd95f0dc50674c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Jul 2024 22:01:57 +0000 Subject: [PATCH 08/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/daft/edge.py | 12 ++++---- src/daft/exceptions.py | 2 +- src/daft/node.py | 18 ++++++++---- src/daft/pgm.py | 62 +++++++++++++++++++++++------------------- src/daft/plate.py | 21 +++++++------- src/daft/types.py | 22 ++++++--------- src/daft/utils.py | 8 ++++-- test/example.py | 12 ++++++-- 8 files changed, 88 insertions(+), 69 deletions(-) diff --git a/src/daft/edge.py b/src/daft/edge.py index 751c07f..48b0da7 100644 --- a/src/daft/edge.py +++ b/src/daft/edge.py @@ -51,8 +51,8 @@ class Edge: def __init__( self, - node1: 'Node', - node2: 'Node', + node1: "Node", + node2: "Node", directed: bool = True, label: str | None = None, xoffset: float = 0, @@ -118,7 +118,7 @@ def render(self, ctx: RenderingContext) -> FancyArrow | list[Line2D]: textcoords="offset points", ha="center", va="center", - **cast(dict[str, Any], self.label_params) + **cast(dict[str, Any], self.label_params), ) if self.directed: @@ -140,7 +140,7 @@ def render(self, ctx: RenderingContext) -> FancyArrow | list[Line2D]: *self._get_coords(ctx), width=0, length_includes_head=True, - **cast(dict[str, Any], plot_params) + **cast(dict[str, Any], plot_params), ) # Add the arrow to the axes. @@ -159,9 +159,7 @@ def render(self, ctx: RenderingContext) -> FancyArrow | list[Line2D]: # Plot the line. line = ax.plot( - (x, x + dx), - (y, y + dy), - **cast(dict[str, Any], plot_params) + (x, x + dx), (y, y + dy), **cast(dict[str, Any], plot_params) ) return line diff --git a/src/daft/exceptions.py b/src/daft/exceptions.py index 4de4b98..753e615 100644 --- a/src/daft/exceptions.py +++ b/src/daft/exceptions.py @@ -11,7 +11,7 @@ class SameLocationError(Exception): The Edge object whose nodes are being added. """ - def __init__(self, edge: 'Edge') -> None: + def __init__(self, edge: "Edge") -> None: self.message = ( "Attempted to add edge between `{}` and `{}` but they " + "share the same location." diff --git a/src/daft/node.py b/src/daft/node.py index fd27ee4..b84fb63 100644 --- a/src/daft/node.py +++ b/src/daft/node.py @@ -122,11 +122,15 @@ def __init__( self.fontsize = fontsize if fontsize else mpl.rcParams["font.size"] # Display parameters. - self.plot_params = cast(PlotParams, dict(plot_params) if plot_params else {}) + self.plot_params = cast( + PlotParams, dict(plot_params) if plot_params else {} + ) # Text parameters. self.offset = offset - self.label_params = cast(LabelParams | None, dict(label_params) if label_params else None) + self.label_params = cast( + LabelParams | None, dict(label_params) if label_params else None + ) # Shape if shape in ["ellipse", "rectangle"]: @@ -151,7 +155,10 @@ def render(self, ctx: RenderingContext) -> Ellipse | Rectangle: plot_params = cast(PlotParams, dict(self.plot_params)) plot_params["lw"] = _pop_multiple( - cast(dict[str, Any], plot_params), ctx.line_width, "lw", "linewidth" + cast(dict[str, Any], plot_params), + ctx.line_width, + "lw", + "linewidth", ) plot_params["ec"] = plot_params["edgecolor"] = _pop_multiple( @@ -191,7 +198,6 @@ def render(self, ctx: RenderingContext) -> Ellipse | Rectangle: if not fc_is_set: plot_params["fc"] = "k" - diameter = ctx.node_unit * scale if self.aspect is not None: aspect = self.aspect @@ -308,7 +314,9 @@ def render(self, ctx: RenderingContext) -> Ellipse | Rectangle: return el - def get_frontier_coord(self, target_xy: Tuple2F, ctx: RenderingContext, edge: 'Edge') -> Tuple2F: + def get_frontier_coord( + self, target_xy: Tuple2F, ctx: RenderingContext, edge: "Edge" + ) -> Tuple2F: """ Get the coordinates of the point of intersection between the shape of the node and a line starting from the center of the node to an diff --git a/src/daft/pgm.py b/src/daft/pgm.py index 3251669..e72c8dc 100644 --- a/src/daft/pgm.py +++ b/src/daft/pgm.py @@ -14,7 +14,16 @@ from .edge import Edge from .plate import Plate, Text from .utils import RenderingContext -from .types import Tuple2F, NDArrayF, Shape, Position, CTX_Kwargs, PlotParams, LabelParams, RectParams +from .types import ( + Tuple2F, + NDArrayF, + Shape, + Position, + CTX_Kwargs, + PlotParams, + LabelParams, + RectParams, +) # pylint: disable=too-many-arguments, protected-access, unused-argument, too-many-lines @@ -111,22 +120,24 @@ def __init__( _origin = origin self.origin = tuple(origin) - self._ctx = RenderingContext(CTX_Kwargs( - shape=np.asarray(_shape, dtype=np.float64), - origin=np.asarray(_origin, dtype=np.float64), - grid_unit=grid_unit, - node_unit=node_unit, - observed_style=observed_style, - alternate_style=alternate_style, - line_width=line_width, - node_ec=node_ec, - node_fc=node_fc, - plate_fc=plate_fc, - directed=directed, - aspect=aspect, - label_params=label_params, - dpi=dpi - )) + self._ctx = RenderingContext( + CTX_Kwargs( + shape=np.asarray(_shape, dtype=np.float64), + origin=np.asarray(_origin, dtype=np.float64), + grid_unit=grid_unit, + node_unit=node_unit, + observed_style=observed_style, + alternate_style=alternate_style, + line_width=line_width, + node_ec=node_ec, + node_fc=node_fc, + plate_fc=plate_fc, + directed=directed, + aspect=aspect, + label_params=label_params, + dpi=dpi, + ) + ) def __enter__(self) -> "PGM": return self @@ -351,7 +362,9 @@ def add_plate( self._plates.append(_plate) - def add_text(self, x: float, y: float, label: str, fontsize: float | None = None) -> None: + def add_text( + self, x: float, y: float, label: str, fontsize: float | None = None + ) -> None: """ A subclass of plate to writing text using grid coordinates. Any ``**kwargs`` are passed through to :class:`PGM.Plate`. @@ -370,12 +383,7 @@ def add_text(self, x: float, y: float, label: str, fontsize: float | None = None """ - text = Text( - x=x, - y=y, - label=label, - fontsize=fontsize - ) + text = Text(x=x, y=y, label=label, fontsize=fontsize) self._plates.append(text) return None @@ -400,8 +408,7 @@ def get_max(maxsize: NDArrayF, patch: Ellipse | Rectangle) -> NDArrayF: if isinstance(patch, Ellipse): maxsize = np.maximum( maxsize, - patch.center - + np.array([patch.width, patch.height]) / 2, + patch.center + np.array([patch.width, patch.height]) / 2, dtype=np.float64, ) elif isinstance(patch, Rectangle): @@ -417,8 +424,7 @@ def get_min(minsize: NDArrayF, patch: Ellipse | Rectangle) -> NDArrayF: if isinstance(patch, Ellipse): minsize = np.minimum( minsize, - patch.center - - np.array([patch.width, patch.height]) / 2, + patch.center - np.array([patch.width, patch.height]) / 2, dtype=np.float64, ) elif isinstance(patch, Rectangle): diff --git a/src/daft/plate.py b/src/daft/plate.py index 819dd0c..68d89de 100644 --- a/src/daft/plate.py +++ b/src/daft/plate.py @@ -101,8 +101,12 @@ def render(self, ctx: RenderingContext) -> Rectangle: shift = np.array([0, self.shift], dtype=np.float64) rect = np.atleast_1d(np.asarray(self.rect, dtype=np.float64)) - bottom_left = np.asarray(ctx.convert(*(rect[:2] + shift)), dtype=np.float64) - top_right = np.asarray(ctx.convert(*(rect[:2] + rect[2:])), dtype=np.float64) + bottom_left = np.asarray( + ctx.convert(*(rect[:2] + shift)), dtype=np.float64 + ) + top_right = np.asarray( + ctx.convert(*(rect[:2] + rect[2:])), dtype=np.float64 + ) rect = np.concatenate([bottom_left, top_right - bottom_left]) if self.rect_params is not None: @@ -110,9 +114,7 @@ def render(self, ctx: RenderingContext) -> Rectangle: else: rect_params = {} - rect_params["ec"] = _pop_multiple( - rect_params, "k", "ec", "edgecolor" - ) + rect_params["ec"] = _pop_multiple(rect_params, "k", "ec", "edgecolor") rect_params["fc"] = _pop_multiple( rect_params, ctx.plate_fc, "fc", "facecolor" ) @@ -121,10 +123,7 @@ def render(self, ctx: RenderingContext) -> Rectangle: ) rectangle = Rectangle( - xy=(rect[0], rect[1]), - width=rect[2], - height=rect[3], - **rect_params + xy=(rect[0], rect[1]), width=rect[2], height=rect[3], **rect_params ) ax.add_artist(rectangle) @@ -199,7 +198,9 @@ class Text(Plate): """ - def __init__(self, x: float, y: float, label: str, fontsize: float | None = None) -> None: + def __init__( + self, x: float, y: float, label: str, fontsize: float | None = None + ) -> None: self.rect = (x, y, 0.0, 0.0) self.label = label self.fontsize = fontsize diff --git a/src/daft/types.py b/src/daft/types.py index 9a2d926..7a4a976 100644 --- a/src/daft/types.py +++ b/src/daft/types.py @@ -6,7 +6,7 @@ from numpy.typing import NDArray, ArrayLike from typing import Any, Annotated, Literal, TypeVar, TypedDict -T = TypeVar('T') +T = TypeVar("T") NDArrayF = NDArray[np.float64] NDArrayI = NDArray[np.int64] @@ -29,7 +29,7 @@ "middle right", "top left", "top center", - "top right" + "top right", ] @@ -51,17 +51,13 @@ class LabelParams(TypedDict): ma: str -RectParams = TypedDict( - "RectParams", { - "ec": str, - "edgecolor": str, - "fc": str, - "facecolor": str, - "lw": str, - "linewidth": str, - }, - total=False -) +class RectParams(TypedDict, total=False): + ec: str + edgecolor: str + fc: str + facecolor: str + lw: str + linewidth: str class CTX_Kwargs(TypedDict): diff --git a/src/daft/utils.py b/src/daft/utils.py index 95702a7..04657b5 100644 --- a/src/daft/utils.py +++ b/src/daft/utils.py @@ -96,7 +96,9 @@ def __init__(self, kwargs: CTX_Kwargs) -> None: self.plate_fc = kwargs.get("plate_fc", "w") self.directed = kwargs.get("directed", True) self.aspect = kwargs.get("aspect", 1.0) - self.label_params = cast(LabelParams, kwargs.get("label_params", {}) or {}) + self.label_params = cast( + LabelParams, kwargs.get("label_params", {}) or {} + ) self.dpi = kwargs.get("dpi", None) @@ -164,7 +166,9 @@ def convert(self, x: float, y: float) -> tuple[float, float]: Convert from model coordinates to plot coordinates. """ - return self.grid_unit * (x - self.origin[0]), self.grid_unit * (y - self.origin[1]) + return self.grid_unit * (x - self.origin[0]), self.grid_unit * ( + y - self.origin[1] + ) def _pop_multiple(_dict: AnyDict, default: Any, *args: str) -> Any: diff --git a/test/example.py b/test/example.py index 703fdda..dad82f3 100644 --- a/test/example.py +++ b/test/example.py @@ -10,15 +10,21 @@ x_offset, y_offset = 0, 0 pgm.add_node("z1", r"$z_1$", 1 + x_offset, 1 + y_offset) pgm.add_node("z2", r"$z_2$", 2 + x_offset, 1 + y_offset) -pgm.add_node("z...", r"$\cdots$", 3 + x_offset, 1 + y_offset, plot_params=no_circle) +pgm.add_node( + "z...", r"$\cdots$", 3 + x_offset, 1 + y_offset, plot_params=no_circle +) pgm.add_node("zT", r"$z_T$", 4 + x_offset, 1 + y_offset) pgm.add_node("y1", r"$x_1$", 1 + x_offset, 2.3 + y_offset) pgm.add_node("y2", r"$x_2$", 2 + x_offset, 2.3 + y_offset) -pgm.add_node("y...", r"$\cdots$", 3 + x_offset, 2.3 + y_offset, plot_params=no_circle) +pgm.add_node( + "y...", r"$\cdots$", 3 + x_offset, 2.3 + y_offset, plot_params=no_circle +) pgm.add_node("yT", r"$x_T$", 4 + x_offset, 2.3 + y_offset) pgm.add_node("x1", r"$x_1$", 1 + x_offset, y_offset) pgm.add_node("x2", r"$x_2$", 2 + x_offset, y_offset) -pgm.add_node("x...", r"$\cdots$", 3 + x_offset, y_offset, plot_params=no_circle) +pgm.add_node( + "x...", r"$\cdots$", 3 + x_offset, y_offset, plot_params=no_circle +) pgm.add_node("xT", r"$x_T$", 4 + x_offset, y_offset) # Edges. From 32998c1b2a4f02dec46745c97caac84ff09b7c48 Mon Sep 17 00:00:00 2001 From: David Fulford Date: Thu, 25 Jul 2024 17:06:29 -0500 Subject: [PATCH 09/15] fix use of Union | in expressions Signed-off-by: David Fulford --- src/daft/types.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/daft/types.py b/src/daft/types.py index 9a2d926..8eeaa27 100644 --- a/src/daft/types.py +++ b/src/daft/types.py @@ -4,7 +4,7 @@ import numpy as np from numpy.typing import NDArray, ArrayLike -from typing import Any, Annotated, Literal, TypeVar, TypedDict +from typing import Any, Annotated, Literal, TypeVar, TypedDict, Optional, Union T = TypeVar('T') @@ -77,8 +77,8 @@ class CTX_Kwargs(TypedDict): plate_fc: str directed: bool aspect: float - label_params: LabelParams | None - dpi: int | None + label_params: Optional[LabelParams] + dpi: Optional[int] -AnyDict = dict[str, Any] | PlotParams | LabelParams | RectParams | CTX_Kwargs +AnyDict = Union[dict[str, Any], PlotParams, LabelParams, RectParams, CTX_Kwargs] From 521f8b10cbdbe008447b4240104a799b52bc9fb7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Jul 2024 22:07:35 +0000 Subject: [PATCH 10/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/daft/types.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/daft/types.py b/src/daft/types.py index e047e66..2133688 100644 --- a/src/daft/types.py +++ b/src/daft/types.py @@ -77,4 +77,6 @@ class CTX_Kwargs(TypedDict): dpi: Optional[int] -AnyDict = Union[dict[str, Any], PlotParams, LabelParams, RectParams, CTX_Kwargs] +AnyDict = Union[ + dict[str, Any], PlotParams, LabelParams, RectParams, CTX_Kwargs +] From 4af6e4f18c37305c9e3b1ead147608b86ab467c8 Mon Sep 17 00:00:00 2001 From: David Fulford Date: Thu, 25 Jul 2024 17:11:06 -0500 Subject: [PATCH 11/15] fix use of Union | in all typing Signed-off-by: David Fulford --- src/daft/edge.py | 8 ++++---- src/daft/node.py | 16 ++++++++-------- src/daft/pgm.py | 42 +++++++++++++++++++++--------------------- src/daft/plate.py | 18 +++++++++--------- src/daft/utils.py | 6 +++--- 5 files changed, 45 insertions(+), 45 deletions(-) diff --git a/src/daft/edge.py b/src/daft/edge.py index 48b0da7..b0a03cf 100644 --- a/src/daft/edge.py +++ b/src/daft/edge.py @@ -6,7 +6,7 @@ from matplotlib.lines import Line2D from matplotlib.patches import FancyArrow -from typing import Any, cast +from typing import Any, cast, Optional from .utils import _pop_multiple, RenderingContext from .types import Tuple4F, PlotParams, LabelParams @@ -54,11 +54,11 @@ def __init__( node1: "Node", node2: "Node", directed: bool = True, - label: str | None = None, + label: Optional[str] = None, xoffset: float = 0, yoffset: float = 0.1, - plot_params: PlotParams | None = None, - label_params: LabelParams | None = None, + plot_params: Optional[PlotParams] = None, + label_params: Optional[LabelParams] = None, ) -> None: self.node1 = node1 self.node2 = node2 diff --git a/src/daft/node.py b/src/daft/node.py index b84fb63..eceba78 100644 --- a/src/daft/node.py +++ b/src/daft/node.py @@ -8,10 +8,10 @@ import numpy as np -from typing import Any, Literal, TypedDict, cast +from typing import Any, cast, Optional from .utils import _pop_multiple, RenderingContext -from .types import Tuple2F, CTX_Kwargs, PlotParams, LabelParams, Shape +from .types import Tuple2F, PlotParams, LabelParams, Shape class Node: @@ -78,14 +78,14 @@ def __init__( x: float, y: float, scale: float = 1.0, - aspect: float | None = None, + aspect: Optional[float] = None, observed: bool = False, fixed: bool = False, alternate: bool = False, offset: Tuple2F = (0.0, 0.0), - fontsize: float | None = None, - plot_params: PlotParams | None = None, - label_params: LabelParams | None = None, + fontsize: Optional[float] = None, + plot_params: Optional[PlotParams] = None, + label_params: Optional[LabelParams] = None, shape: Shape = "ellipse", ) -> None: # Check Node style. @@ -114,7 +114,7 @@ def __init__( if self.fixed: self.scale /= 6.0 if aspect is not None: - self.aspect: float | None = float(aspect) + self.aspect: Optional[float] = float(aspect) else: self.aspect = aspect @@ -129,7 +129,7 @@ def __init__( # Text parameters. self.offset = offset self.label_params = cast( - LabelParams | None, dict(label_params) if label_params else None + Optional[LabelParams], dict(label_params) if label_params else None ) # Shape diff --git a/src/daft/pgm.py b/src/daft/pgm.py index e72c8dc..cd0d7a0 100644 --- a/src/daft/pgm.py +++ b/src/daft/pgm.py @@ -4,11 +4,11 @@ # TODO: should Text be added? import matplotlib.pyplot as plt -from matplotlib.patches import Ellipse, FancyArrow, Rectangle +from matplotlib.patches import Ellipse, Rectangle import numpy as np -from typing import Any, Literal, cast +from typing import Any, Optional from .node import Node from .edge import Edge @@ -84,8 +84,8 @@ class PGM: def __init__( self, - shape: Tuple2F | None = None, - origin: Tuple2F | None = None, + shape: Optional[Tuple2F] = None, + origin: Optional[Tuple2F] = None, grid_unit: float = 2.0, node_unit: float = 1.0, observed_style: str = "shaded", @@ -96,8 +96,8 @@ def __init__( plate_fc: str = "w", directed: bool = True, aspect: float = 1.0, - label_params: LabelParams | None = None, - dpi: int | None = None, + label_params: Optional[LabelParams] = None, + dpi: Optional[int] = None, ) -> None: self._nodes: dict[str, Node] = {} self._edges: list[Edge] = [] @@ -152,14 +152,14 @@ def add_node( x: float = 0, y: float = 0, scale: float = 1.0, - aspect: float | None = None, + aspect: Optional[float] = None, observed: bool = False, fixed: bool = False, alternate: bool = False, offset: Tuple2F = (0, 0), - fontsize: float | None = None, - plot_params: PlotParams | None = None, - label_params: LabelParams | None = None, + fontsize: Optional[float] = None, + plot_params: Optional[PlotParams] = None, + label_params: Optional[LabelParams] = None, shape: Shape = "ellipse", ) -> Node: """ @@ -246,12 +246,12 @@ def add_edge( self, name1: str, name2: str, - directed: bool | None = None, + directed: Optional[bool] = None, xoffset: float = 0.0, yoffset: float = 0.1, - label: str | None = None, - plot_params: PlotParams | None = None, - label_params: LabelParams | None = None, + label: Optional[str] = None, + plot_params: Optional[PlotParams] = None, + label_params: Optional[LabelParams] = None, **kwargs: dict[str, Any], # pylint: disable=unused-argument ) -> Edge: """ @@ -308,13 +308,13 @@ def add_edge( def add_plate( self, plate: Plate, - label: str | None = None, + label: Optional[str] = None, label_offset: Tuple2F = (5, 5), shift: float = 0, position: Position = "bottom left", - fontsize: float | None = None, - rect_params: RectParams | None = None, - bbox: bool | None = None, + fontsize: Optional[float] = None, + rect_params: Optional[RectParams] = None, + bbox: Optional[bool] = None, ) -> None: """ Add a :class:`Plate` object to the model. @@ -363,7 +363,7 @@ def add_plate( self._plates.append(_plate) def add_text( - self, x: float, y: float, label: str, fontsize: float | None = None + self, x: float, y: float, label: str, fontsize: Optional[float] = None ) -> None: """ A subclass of plate to writing text using grid coordinates. Any @@ -388,7 +388,7 @@ def add_text( return None - def render(self, dpi: int | None = None) -> plt.Axes: + def render(self, dpi: Optional[int] = None) -> plt.Axes: """ Render the :class:`Plate`, :class:`Edge` and :class:`Node` objects in the model. This will create a new figure with the correct dimensions @@ -492,7 +492,7 @@ def ax(self) -> plt.Axes: """Axes as a property.""" return self._ctx.ax() - def show(self, *args: Any, dpi: int | None = None, **kwargs: Any) -> None: + def show(self, *args: Any, dpi: Optional[int] = None, **kwargs: Any) -> None: """ Wrapper on :class:`PGM.render()` that calls `matplotlib.show()` immediately after. diff --git a/src/daft/plate.py b/src/daft/plate.py index 68d89de..bf4e311 100644 --- a/src/daft/plate.py +++ b/src/daft/plate.py @@ -8,7 +8,7 @@ import numpy as np -from typing import Any, cast +from typing import Any, cast, Optional from .utils import _pop_multiple, RenderingContext from .types import Tuple2F, Tuple4F, Position, RectParams @@ -55,13 +55,13 @@ class Plate: def __init__( self, rect: Tuple4F, - label: str | None = None, + label: Optional[str] = None, label_offset: Tuple2F = (5, 5), shift: float = 0, position: Position = "bottom left", - fontsize: float | None = None, - rect_params: RectParams | None = None, - bbox: dict[str, Any] | None = None, + fontsize: Optional[float] = None, + rect_params: Optional[RectParams] = None, + bbox: dict[str, Optional[Any]] = None, ) -> None: self.rect = rect self.label = label @@ -69,17 +69,17 @@ def __init__( self.shift = shift if fontsize is not None: - self.fontsize: float | None = fontsize + self.fontsize: Optional[float] = fontsize else: self.fontsize = mpl.rcParams["font.size"] if rect_params is not None: - self.rect_params: RectParams | None = rect_params + self.rect_params: Optional[RectParams] = rect_params else: self.rect_params = None if bbox is not None: - self.bbox: dict[str, Any] | None = dict(bbox) + self.bbox: dict[str, Optional[Any]] = dict(bbox) # Set the awful default blue color to transparent if "fc" not in self.bbox.keys(): @@ -199,7 +199,7 @@ class Text(Plate): """ def __init__( - self, x: float, y: float, label: str, fontsize: float | None = None + self, x: float, y: float, label: str, fontsize: Optional[float] = None ) -> None: self.rect = (x, y, 0.0, 0.0) self.label = label diff --git a/src/daft/utils.py b/src/daft/utils.py index 04657b5..7bbe520 100644 --- a/src/daft/utils.py +++ b/src/daft/utils.py @@ -5,7 +5,7 @@ import matplotlib.pyplot as plt import numpy as np -from typing import Any, Literal, cast +from typing import Any, cast, Optional from .types import NDArrayF, CTX_Kwargs, LabelParams, AnyDict @@ -103,8 +103,8 @@ def __init__(self, kwargs: CTX_Kwargs) -> None: self.dpi = kwargs.get("dpi", None) # Initialize the figure to ``None`` to handle caching later. - self._figure: plt.Figure | None = None - self._ax: plt.Axis | None = None + self._figure: Optional[plt.Figure] = None + self._ax: Optional[plt.Axis] = None def reset_shape(self, shape: NDArrayF, adj_origin: bool = False) -> None: """Reset the shape and figure size.""" From 6ad54873556a4c756d7b6641538da1277f58419a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Jul 2024 22:11:26 +0000 Subject: [PATCH 12/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/daft/pgm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/daft/pgm.py b/src/daft/pgm.py index cd0d7a0..55a977e 100644 --- a/src/daft/pgm.py +++ b/src/daft/pgm.py @@ -492,7 +492,9 @@ def ax(self) -> plt.Axes: """Axes as a property.""" return self._ctx.ax() - def show(self, *args: Any, dpi: Optional[int] = None, **kwargs: Any) -> None: + def show( + self, *args: Any, dpi: Optional[int] = None, **kwargs: Any + ) -> None: """ Wrapper on :class:`PGM.render()` that calls `matplotlib.show()` immediately after. From 40d0f833cfbb1c0b8fab76c06a823e442f56a9c6 Mon Sep 17 00:00:00 2001 From: David Fulford Date: Thu, 25 Jul 2024 17:14:41 -0500 Subject: [PATCH 13/15] missed the Union types... Signed-off-by: David Fulford --- src/daft/edge.py | 4 ++-- src/daft/node.py | 8 ++++---- src/daft/pgm.py | 8 ++++---- src/daft/types.py | 5 ++--- 4 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/daft/edge.py b/src/daft/edge.py index b0a03cf..21ede14 100644 --- a/src/daft/edge.py +++ b/src/daft/edge.py @@ -6,7 +6,7 @@ from matplotlib.lines import Line2D from matplotlib.patches import FancyArrow -from typing import Any, cast, Optional +from typing import Any, cast, Optional, Union from .utils import _pop_multiple, RenderingContext from .types import Tuple4F, PlotParams, LabelParams @@ -90,7 +90,7 @@ def _get_coords(self, ctx: RenderingContext) -> Tuple4F: return x3, y3, x4 - x3, y4 - y3 - def render(self, ctx: RenderingContext) -> FancyArrow | list[Line2D]: + def render(self, ctx: RenderingContext) -> Union[FancyArrow, list[Line2D]]: """ Render the edge in the given axes. diff --git a/src/daft/node.py b/src/daft/node.py index eceba78..cf32236 100644 --- a/src/daft/node.py +++ b/src/daft/node.py @@ -8,7 +8,7 @@ import numpy as np -from typing import Any, cast, Optional +from typing import Any, cast, Optional, Union from .utils import _pop_multiple, RenderingContext from .types import Tuple2F, PlotParams, LabelParams, Shape @@ -139,7 +139,7 @@ def __init__( print("Warning: wrong shape value, set to ellipse instead") self.shape = "ellipse" - def render(self, ctx: RenderingContext) -> Ellipse | Rectangle: + def render(self, ctx: RenderingContext) -> Union[Ellipse, Rectangle]: """ Render the node. @@ -229,7 +229,7 @@ def render(self, ctx: RenderingContext) -> Ellipse | Rectangle: # Draw the background ellipse. if self.shape == "ellipse": - bg: Ellipse | Rectangle = Ellipse( + bg: Union[Ellipse, Rectangle] = Ellipse( xy=ctx.convert(self.x, self.y), width=w, height=h, @@ -266,7 +266,7 @@ def render(self, ctx: RenderingContext) -> Ellipse | Rectangle: plot_params["fc"] = "none" if self.shape == "ellipse": - el: Ellipse | Rectangle = Ellipse( + el: Union[Ellipse, Rectangle] = Ellipse( xy=ctx.convert(self.x, self.y), width=diameter * aspect, height=diameter, diff --git a/src/daft/pgm.py b/src/daft/pgm.py index cd0d7a0..e3bd8e7 100644 --- a/src/daft/pgm.py +++ b/src/daft/pgm.py @@ -8,7 +8,7 @@ import numpy as np -from typing import Any, Optional +from typing import Any, Optional, Union from .node import Node from .edge import Edge @@ -404,7 +404,7 @@ def render(self, dpi: Optional[int] = None) -> plt.Axes: else: self._ctx.dpi = dpi - def get_max(maxsize: NDArrayF, patch: Ellipse | Rectangle) -> NDArrayF: + def get_max(maxsize: NDArrayF, patch: Union[Ellipse, Rectangle]) -> NDArrayF: if isinstance(patch, Ellipse): maxsize = np.maximum( maxsize, @@ -420,7 +420,7 @@ def get_max(maxsize: NDArrayF, patch: Ellipse | Rectangle) -> NDArrayF: ) return maxsize - def get_min(minsize: NDArrayF, patch: Ellipse | Rectangle) -> NDArrayF: + def get_min(minsize: NDArrayF, patch: Union[Ellipse, Rectangle]) -> NDArrayF: if isinstance(patch, Ellipse): minsize = np.minimum( minsize, @@ -440,7 +440,7 @@ def get_min(minsize: NDArrayF, patch: Ellipse | Rectangle) -> NDArrayF: maxsize = np.copy(self._ctx.origin) for plate in self._plates: - artist: Ellipse | Rectangle = plate.render(self._ctx) + artist: Union[Ellipse, Rectangle] = plate.render(self._ctx) maxsize = get_max(maxsize, artist) for name in self._nodes: diff --git a/src/daft/types.py b/src/daft/types.py index 2133688..dd8f932 100644 --- a/src/daft/types.py +++ b/src/daft/types.py @@ -3,15 +3,14 @@ __all__: list[str] = [] import numpy as np -from numpy.typing import NDArray, ArrayLike -from typing import Any, Annotated, Literal, TypeVar, TypedDict, Optional, Union +from numpy.typing import NDArray +from typing import Any, Literal, TypeVar, TypedDict, Optional, Union T = TypeVar("T") NDArrayF = NDArray[np.float64] NDArrayI = NDArray[np.int64] -# Tuple2 = tuple[T, T] | list[T] Tuple2 = tuple[T, T] Tuple4 = tuple[T, T, T, T] From 7ab8bd9757ceaf179702f3c7c1d609e81383ac27 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Jul 2024 22:14:51 +0000 Subject: [PATCH 14/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/daft/pgm.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/daft/pgm.py b/src/daft/pgm.py index 5394b6d..0b802f8 100644 --- a/src/daft/pgm.py +++ b/src/daft/pgm.py @@ -404,7 +404,9 @@ def render(self, dpi: Optional[int] = None) -> plt.Axes: else: self._ctx.dpi = dpi - def get_max(maxsize: NDArrayF, patch: Union[Ellipse, Rectangle]) -> NDArrayF: + def get_max( + maxsize: NDArrayF, patch: Union[Ellipse, Rectangle] + ) -> NDArrayF: if isinstance(patch, Ellipse): maxsize = np.maximum( maxsize, @@ -420,7 +422,9 @@ def get_max(maxsize: NDArrayF, patch: Union[Ellipse, Rectangle]) -> NDArrayF: ) return maxsize - def get_min(minsize: NDArrayF, patch: Union[Ellipse, Rectangle]) -> NDArrayF: + def get_min( + minsize: NDArrayF, patch: Union[Ellipse, Rectangle] + ) -> NDArrayF: if isinstance(patch, Ellipse): minsize = np.minimum( minsize, From 3e4b2286e702073bb7d0091667af6c52166522db Mon Sep 17 00:00:00 2001 From: David Fulford Date: Thu, 25 Jul 2024 17:23:09 -0500 Subject: [PATCH 15/15] add wet_grass.png to tracking Signed-off-by: David Fulford --- .../baseline_images/test_examples/wet_grass.png | Bin 0 -> 16487 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 test/baseline_images/test_examples/wet_grass.png diff --git a/test/baseline_images/test_examples/wet_grass.png b/test/baseline_images/test_examples/wet_grass.png new file mode 100644 index 0000000000000000000000000000000000000000..2f55464a536c6a577fcde296ba3922ad940e526c GIT binary patch literal 16487 zcmZvDcRbd8{O+AfD0_w?*@UDhd(SemLdhy4iO9$*WMyv^S!JYDR%RJll@Se*kU}Dv zNu2BZob&pf^T+AcQ=YiL<1^mxYkeXOPH0k7u~U&qBx-FfH6#2zi~ss5$nobbTCOSl zb<|Vc-1DT{IZtmJ4+qk58_!D@+&nKh*>YcT@VM;ccJYvyjF^-tx1*=$rOU^}#a;jB z2gKYw&Wm&L?f1e(C@*Q9x=bR`+7SP3DO9}gL?S8nYpW?6``rF{$J?CA>{EEq)15oD z5@>Wx)QWPs%?(5k3TJ*uG`Ed@`60=iAurO( zY;*1$IR^)axus?E^XE##IftCt_%H6^77^KL_R#7@iip{-FE6(8^6~8x5~36m62fQs zI(LStP;4V>x0g(QckBM6N8P#?TwJ1Q_8fcu{ynRhm>3;}IbX@X=Z^!22EWgbw^e&i zF^E|=-0VsdADb*rB}o9|=Zv|P`-e}4xxHT5OVHZEPUG*zAiZh84V`q>H- z1>zg>Vq!GCy}j{W6OM!BLdU#kcUfCoudc1-UjF{3=iNKg_ifL8cBWTXsA*`BUk?eP zXJduN>y4Wm9M`U0Tm3uNxP`QpeAn~zgLc2($j&|Q=`lFq-Z9}Z z^z8DEaGE_$c$0+mbc!82cHG<}%VO6S8&+RmZ!uRgMoJU&$8 z7407oKuHQ%9S!%K`z-GjZtT!vd+^`#CpY_Z=QRBM{K{=xsS>q#r)i?++4Y8;&*CZa zi;A={U!Jo=N4>&Ny4s&V|CZI-UzzM$u;1dOaz;kRG)?6EriA+lS%#j5b-)tymb;OR z0*U6LrruhIn=n+m_wL>M`sIt#$7Hh`*={f3m4^KxQu&h+# zWhjz;z@B4SbrI7LURinMnAddZ$@?dYWb-v>f6x?@F+^?)3=FKkx*|ryDH9y{Zw=qE z*e5N`$`!PJ@I;neB*r%%Y6i&7XNVq!h?s{#X5bnO*oOG^&op7xTb4r{gFt`DZIWGG6CSr^+{p|L4YB?;y0hd*r>2UGzr8E9Z13n$j@v7r zTsczfO{D_3PVr zY6&tW#>Ub-tF2AK6K=VCzB{Y;v0WPl>2zb@9r5gh1m!bl&TI>|sJx)f&(FUt_}r`X z9M7TA=J1oZj=1XFp1bANu$4`=S z>f;p_75!SZxiS*8_^bIn_rU{E9v+<&C-~2wKTq~NT34;V!o_U5CnuzDKHo^_l!ZlKd&2&%fJe1ge)5o# z)6>mIzc&>sd^q$lN=;KUPcv?>u>7YDEH%fEPxy{p_&~z@1=c=#B)+(~sFNJDv3^uV zMMdJ=E4Ia-pIAxH-8lo-rUX8G_;B{^z2hmdE4Vi1y9F!bPjSH0bFa%hr}`Ki1oX3R za!6kktnlEHy|SyQsAzTOX?Lf*zI%%Nkt3CT2PKPX7|ql@JUomu6-82Xu&`YhXFf7! zW@ZZSkrfuEq6#C!*3$WO`0(LgybWrl>9No4?OG_4sE`Fehikk<440c}6+<6CK3I8i zxYO-Zr7MR;g|o@nskP%#ij41+OlmREq2tv?3$pdvfq@E4RNJwIOcbu*Ytr)%AF5ki zUcP)OC@LCiC1xgXPC}#U85lU=J}+2ma~vBAo456I{Z&JA^9Zz}+X?~fi~=Vfl-)2~ z&yZR+T^_sO^z@8RPcQuaIGQhLdFE8)`z-mZ8XHu&FETXQ{=c=u z9Pai&TA3eGo>ty{1Is|@F z)NOA3Eh{T4lvqa7K%=a#tu-@GHH;~I=KkRGrO~GCWJ+n`)(^{OVmPF=co_^)i>A<4 zl2cPPu(Q#$cskqL1uIOBa&rfhdfrpbxJ#1XNGIj zNOYuQ?qi`k*}7F1hYPJ61BFu@jJfozm=n4t?CtDSQJhVR);~Ty-_hHfz`BKW^w_aJ zytpp2nONGw>0_-ESlUzFikqVDGp*6ANlM(@BwjiS$IfJ-McgxOvhBE*7Pr2>{)P`$ z#v(SN$kKSj2^*VzyLazSC@nqEG4Y#fWIRbgpMr$qRJ52OQA~?kn=b9PcQ-q`k&aGO z=U+Mh)%d37W+4TIGX=jf`kw#RyrN=b`7`uPFFr#-FmiTw4r{AJ(@EX)3C5$ zcXD<%Qc?=mdHg2xaI~kF*A$kAaLReS+mKE?h9WgJwZFvdkw-!CBYzAJ(E))rO3P-@ z*eFP)=8tn(c4gZxzE}C#-6m(NXUrfEd-&^bL<`us14r9&gJ@+{b3rk`3ONhsNR)*+S zpOs;6dREr(_V#w2C9^}sn==FC{90LQLr;y~BYWiw(UZKjDpdUZj@_Wy6NLi6EhSan zZ)kpV@Y>S2=@xV=>?)hpU!!b?otYZ`&ej^44Y+8jsjCx(9F3qUY$t1a`Lss;;ia`0 zV2bRZjSRfaKGhpNz+nS>jAAY=P8#;TeVd1(Yp9}fy{6fQ&gFJ7OH|{BckkYL{`#Vv za&Kc`M`B{4zVq1u8mMyvOb5r0r&kvkFgEXASXd~!x-#41&gssnNosFux{j&S z!SLKeF<>vTRR~y120+lqR2=p7^Gh@k2c3$&-c6yDAD&$n2+r-8Pe@3p_WvVEpz2@0E-^7NDWS}sIB}w^ zvhqen#o-GVF077}mX_xHZ*cAC?5x5~rzD)B!k?u4KL}gh$cO>6yE>9B=SxroQgBGf zR$$0K*&#b7htO2Nq z_SfgDbED{9m)zZpQtcYVT+tU$!;{aXY4O&z)9#fsNYo$bSZ$3^Q{n&z~`3 zr)!?~^*vGh6C6wmJbb;l{VM9BsCgL)jZbByApnma4#W&7y9XFkCTK%WO3Hj~cKSv@ zuG`tOXGy0|pH}#jVI}tc?aMkE`l96XMLbT|cs~03`|F&JF1&k}1aPE^H~hV@@Tc~H zl~_2~&Zgb}<~r$_nWn;T36QbbqP9b)Ug%_x(SFi}|s%%7EJoJVp2I+b4B#NQT%g7=EErF@w~)CGA1m&Lq^g zzONlg9iKkk0x7v)ey7lJ;s)A5UvoHZ3InxOyFL^3UIy)-bvfjE zc`VqLX*664dP)4pb5ZgQt*kid>4zJGvM;#03JD8GMMP-Cu&Wf@xPtMNn9&u-;+I~DsE<;aGPf$oGQk;v0MFW_e2b+Ab!i6

Q>~M+a%G(;{M8#7fn_FjLJ zt=R2mS$SdiDRXlJQ`7(ErLL|nqR??!e`R*qX|g9LIWbWMof}kKq$X%nq0vn>u`>Pl zY#sVZ!KtU`dcjN`*A}OM_73qF1%3mtzdzTwnYE2mC1AzKEIjKcSNlDxY)Uu$uT1FW46ui>l{ln8mxCa4@U%Ys+ zKF^+8-kL1nc8dD)qd2*)&dyxW=Y(t5whY%_t!inYcA&C%ee2kgqQ%Qsd=BWkfaQQi zd}E{H>ukmB+js6bfe)$~FNV1Q7ylfrh!z)TiJroAt`CP5+UQ#9`3`)u4P2ksEVvo> zsIk(fnH+;BW>sx4w)FL@JzD-fOxQaUV@me0YTL)@V3kCR0XpV05@cFu6_As=-SFod z2m6r==8?O?NBkG412+DAS6ghn828?3S0BsC&v%H0j#=z@|K7}rerBwj7vLZ55UY=_ zcQYSvEb2A$A%P$(7cX9vavhX9Qlh2KURT}PM-|q@P*Q7VJYd-vC|`l~+3yka-?m#p zi_I@6P(O6;)fd)2eqrZ&@5`6-00D&$1n$|p_rw8MoNeDBZlb{-z=pG+wyqtf~`T$$O7C!yzM42L%O z4lybP?`J>P>#~Yj{nsaLZFx4s{0^N|$CC5)@ey2lf8-qY(p<*MWN%(SiaZ0wHZ-sb z=Yc3pVaxXt;f*<-7Xi)|;!J$lQ-R@=eoo&Ta7XS(nu##F2td3{v)~$hO zE_9_HNNF&4tGvR`%lm8V_Vn`2ygFGo4-d)Hs6?@Q<#iw<6`W!F{whNixOs8?&t;1S z|D#J&Z^Jd>_P&lgW0$NS!kiW{3F@{UUkfzU1M;BFP1jRq)9l!MgPMtr&HS9n!*ttg z8RdM24z?->3%oXWVW$%ljIR5wJNuDTa7f|g*wxLAfOlt2sHk;@*4O<98-p7AY81d= z?<$9o6Y2~4G!%q@q8Ncviwfb(%nJOjAB{E6_y=+x9`sgTfNDK2N~8OJFyXQ&Y+(PcmX`3A$->{`?NCzQc9Vem80!aSkZ^ z`^z169ZUvgxL+RBoc;hyX8dJm4*a z0D{^72eu$27Q8H!eyZ|h+Wn&xh5HIBH&23=)Abe1WJhyz2=Ss2@t8O{^`@q#bpAS= zJJ)w}_aTrWA(Nzxa+CX_zI+UnpgW~%**?n3TlUENCHt++9tRX4YAY2bWzo;FAA6`( zh8j24qEYeg$oVElM3_Fb(&lV41Q;YL4$(vg2L~5_R*n2t{cZZdT)Cm4f!ML1s@%?( zwhcrGch?yeB?AurL>s8_o;wlnce#LQbWmQLpbd%&T;7hSpBt{#Jf10)vwbHEQK^Zt zj~jKERWL9xIQHCC&dBiK#RTx-4Koa%goHK;Ft&_0k>LH`n(8o zk&CA!oCauAA(i7-n(s+Z%D6)q&O{<)GVovG}t~LLlHYMFHEsa6HYh+($mwE(#w~qQF#j*3yg$NZ+C>LWw1(AS&l3iUpL4NfH{a&KT zlrC0TiM0i7{tE|}C=$x8IUYjq{pa@tp?Q+HWo0=?NtFtZFMO1|xmTW@#7(z8=kjO3 z%zB@g*oo6N1E<(StAY8=YdnQM`&%h`O($pa*c!7y6Pfw=bO+E)NMs~ssLDQ)se@Rw zh0|^5UoD_abQ~PfczCGy%Azc)!2j6u*bxLvh>qTAb^3Jgz2h8?E-odW7x2tz$m1w= zV4|ewoA<#A80H%OUW}<1*LWR&=T3}x!z!7wvhsQ*HOG;&z#R^=`%`oPDR1=m8|UA< z$HT`LwkapZ*q0yM6{j)d{6!j88~-Q6*`4>eN1desXEI z&uV~$JUl#f;NP!a=@kS-KfPjMZvOoF^We1d)m*Wp4(_zBeeg|^l0wSMrR?qP1qTLY zVnd*&aB^~fnVO=6s^R17`(j{#85&!~+NnD2iYVw%k(lH?e^=%RarLaDV+5G?T_7-U zqH7rIUDVWFLyDMYfH!qLyxnaH3vI+!=`kB?6 z&2UkxYC7;vv%fpK3isyb`RSUPVFGbB$ndFR_5kfV8{gt`!$U~7Yfk=%RMG^02sL+o9n?vSeH5o zlofUJ=7ZR(W`je}qRKs|j3ABk{Ks+$fJw>7-vim6fJ4sO{42jFbk~4b9YLK4Y6D0v z0CkI02xVjOt-fL}iiQhcY83Nb%_O0gJtw=r*qe5=)}(9UOW=l$_e*_n~yFlG+Cb zQdmUIm7u*HzVcZ(`Y)b~UBO>w%UzILkqN1X+600khziNx(NPN?QfbP9WHA>!k#JhB zdk;_7GLTSAbBMQt0uJ)2G(2jO3GM>CClRyz&<)H!C248qV-gsYsV(LEssf!c)TYE+ zw>YMzr~7c5lEd649~*-d@xO_+wX*U7YsLo-N)+EP7A`OKPJi??hIxRiyB8+r=JW|;0E*QcIEIB7qHV!oQp={)7lfll zk~#{ZvDw?&8huYRQ4Cy9zj^Z}p*dd2DLmE=_=BrmBUl>r+L&XzZ$qtuFq(*Re^gf1 z?3sX?gFC*XOLk33sNU)H7@9L7HEyi`nWhPw-)x`o0Bqmwpja$ zy1r_K zf@PUFWmC1nw-AIn`_7T;mP$@x<+yHv5#YdGTE+D;i`ce8&VvUJ_UCIevP;@WSg=z8 z!RwHU`>TR}f7K>47`>YXu`L-3rkLHfh(HaWs>Ff(%3E}j3JhociB1No{%l{q_5{s` zz{N@WST02x9u<^O(^88tqT(s_a?iu!JKK@SXI^!2kK)IFUReqOQ4kZ*auz;q{9Cae zs`F(8f`z2_0+9YOpvTwm-_vB5EmT#vqLMU$Byp$cbRFm_3``X}O@;md!)l&U(K!hZ z$AmftJrZM6MIctJ*-$KuK=8XLZ^78lQQd;cE8~n5|eU!KtFB zRA57i^^Rm@Wt|{2jG37We!ss`mRdJv5ucRIxCuQ?cUQ(bWXvfr1>U5NNsuof49US- zZ}ED+C7r2ihhcXvIUgnB^Q)^qJi>otEm41=0TDIT;_12LMO%|0p{Zg2@}}rap&=6A zAf9D&?TJGd%&wBA|~X{i>-89bVZ!pY4cFoe4L`VNR9kX-Q%G*ZTbd-r~~39(#waeaG!Y3Xs$lm2q2 zlbu~%A{E6mK))Ls8~qT>3vLBJeR`5I%(D|MK(k&2ROK5`UK+u)JIG-pW{OG|@}Y(!Z|=vAk+;kAQxm zq^9o50b^3n{b+0+%#z!i3J{Z{NIlpFx|~(!Ltnm}#Uo4Ez1&rB@TFUn2upO4Nr~y# z>FLgqc+ipn_ypUOredd&AmD84{qWfCj|tDnj4+3kBVGN%8(D|IST$z3@pJa}&tJaO znDFz%a!J`p)hYN;X0sJ^Ddvdl;PcO)v$}N>GBYhk=ZA)eb>%GaYd;@fK~n$(VqJUn zv;CadoB)?loK%Y*G15y)N`CqBgXxIvSCaHc*}pY8ki@8S9if3{&A z`pl1o^&M2)T)#|6v`U7b$lKfVyiX;@yk^AV^O~b=`pd*P?tUqeph%(I6uo-O% zGj9y6HzKQ*KI*q92#_3#vUyiMnuQeghjUQI?BouZYcq)FTPRYQ4uHoOZ zr6>y6NLE(X!$cpqXk+0pE#Bh!1&JuZ)Li%nuOwkDv=Rde2V$#r;2%CL5WnS_NI+gG zx6hnPhLvs7LZ4I9((aenOLkp0{l@=mb>0q=`-hJoi#%tBhOFR3T0Qe9XmJ6;)PR|i zQfb2E4KFe5aD+$e01R70dBYU6Zsi<<RCtJK^d{~D@r#I! z6?=LUQu4PqccM`hctk{^S$%*u{59IfLa#kKS@-(Ys|V+<2Zi81w$QGKk#^F_Q-aZmy^2$*+lrxP-iCx3K7fmKFst*XXxzro;Qv zqO^Eb@DFRtKck_Fk3rg>NSy;cCMEl%mJbRR*xv}foI3XrJLK&1;4S9Wm-eNKTgRgNaUVE9jX7$q^Ia%&c`ugIeHEP?PAMH> zbN|`i<>GNdS9d=N1i!)K*1axOQI@ypQ-a5i)x0xG0*ZmmP2F zw69z_GIAhfv-_106O|LD$-gS1q(mg(*NYue=ihKDT(xyy>G^S5Yl8Z?Xo)PLckp%6 zut@}iKXl4J1U@B{V5|mA%FPI!xwhBc-S@!_h=1DLAQ5%?jzU1D_;=JW#Hj=q4$XY7 zmxb8+xXCGyIo3pML*fS`Jfw&3l)vm6pxm`fbj3D1B}EAI*+XRBo&GgNf4Es_dM+y(l@9K7nc8;k`h5?J{_7v1MdKu0LOD3t}!9#&`a$4cI+aO0U{Gl zaQNKnd0kHD-@O|Q<>umk;;gN(6d%|GIl)ri;<2>B*D;+4CPYtVO{=wi>UvQN>cl7OAGv_-~m5>fpTDYj{71>zEfj%7pv} zu|kLN0MEV7q$fRj`ZN#JTPQ{p0^#}JM-0%1hU=@}VAfDzGlNuUiF zUN+Epjhk%OFk-2h^mnmPx-UvVL%A?qQ*msBM)fs>t6aP*DfsB+Z+E?ZysP$I;74Vl zV`knSsYZ5J@Ov`q{{NA#3WAoFmN1PkHtFc?C3k=IIEp{7UA0a&?Cohd=WdIJo_9H*2FjO;^;#4 zq0H+YpCk&!oH!&W?maAWtIJCCLK2KfoAqdO$77?r}KJ=I}B_|s`PV8ChkD!x&Ndrh&#CAvb)(Uh&%8@ zBj$`bL)?K;!pVh&e|EF^YpFtd`C(edmBZE#u)j8&ZJb8I_cyUukIc54T$WvB|tiMR)0x_S@#{;aPQB)nob-^m9sBeN;R+y zzf4RNl)IwVrAydc&*UlQ1S`3YhJXs5V<1Q(!7QkTd|F?#Dn)K?@{BzL(WFRMYybJ} z;NZnK9EH0PD@)1W9t@odms5!`Jg5^AX&FNxCO8_`5KC=5F3MtuWJ%!<=3A(L1mM-_ zBU3^GV{~$|@LnOCgUU^3zJ2?y5rvhM99Zf4>E3Kl=U1d;>|I@t17nlo5o7^H7(_M& z5W*|GO%NkRO-L13orLz#*VDsIB7#AXHl}Gxjft-7nJ@*dfsc=`NLEziXxYmmcHV;r z>BrmR+E6)3$rl__JM57&Vc#p~6O0rq5kQ9qj9-wpG!Ws2uE?FBwJLmJK8TfmzHL;7 z26D3?lZvM1jjTg^y3R6UZ7=9(jdu2zZFi-<;bFV9YU~NYL8m>gtMc={BHYXsA$L>8Qf+ zn?wpz*+QtbmIemFfMz)epTKQ1Dk|?k?}87ae}4G%X+P9w)NG&ccf9G?*eFCqO4$F& zu?=ixtRja+(FAA`3zSFzIdSU-V+Ig-!aN|9`nDo^YSzC#qnRSiA!GmTo;M@!^_J{08gmHBocUy zjI693)YJ4MuI#T~zaHQ`kcTq^%xIbB70%3Xr^0|6l*Q7vlEB|bzkUt%9tmItJ`7bd z*z_BM4@N*<$k-xFKiJvXb3ts(q2NJz8Q^pnzY%Gv{m{BG)_39CC|EVz2#j&r)2FJ? z07vKM;vg#)%eFYHgGTW1kez>XTlvBTQ6x1;BzOmH{Y57ax_y=_2z1f-`Mkv>0pcm4 zp=05Zq8h5<9Vllj1WZqyt|%r>L?Gy24OI}?UbL2jo|6U-EURKo3S8*CmIlKS80lpg|3$mfFEc zn-R+;``(@1=GOR+{IU5iz)2zi4hS<%Q=*JQ6NAlxBN1cJXC)*go`V*AnVo$&TQ<#% zHHbjjO9-CabQHgU2yKw*!lz!mX!#ekSHT}qd|D*OzSOULZbq8zT_O9O;#-)weSk5r zkr6thf_{*X7(J7o6$#lTWuQ0gISKDsE6<;UCd2yYCT=z9t48$P_3@60o7d?}1 zLU3vb(qt^$il}&cxyhH+`EkPKpxU`Jr7~KL=fY4mJznAij5neh+@s*w8*^p|j2rp1 zFF1PWM9(_&Ah$8pIjPMX=!cEfuQ`K zgy8@{g4k$6d^`!BC_vnH>^p?`Lm?C?o)o~Mfr~65E!_s?nK&Ku=8kOhn>RFJVPW8- zDuAlk{wNbPTwGj^PEJC1q9f>b?~a5D2gR4zDz9Htqx<@2Ku)l#y+RHD8r7f-zDszx zW=82}01t?)*Kru6j}!?T0voIxp@ET9+zOO27pA>V_X@;dRjgaAJi=9)m}u4idv$G5 zDEqY}o)1kK$BJlaH~!%W$OH4oWJDrOlPe&URs2lZa!6Hvufn)66V>F8A6o_4?G4#+ z%V<_{4SRdu?0d$dEX@dlarxiP%iBsKh8$|#JXF8@`}b9?t-0XOY5=t!1SNtB+x+U) zE~Nx+42U@*7^0%0B&L`E8VU*}q?F*%?_-EWW|kbQ5s-YVpz{)rJtzS5AYW*JFgI2? z5ue6fh%d?0(IdTq2$$b4-|Rz@k`gve+wkgylAGNOD03bC?<_hQDQY$!ieVE}A{%FC zD#VC+1q6sO-N`COD7(b@52#9vm@|nZN4|xfK@dGsSr{Rt4^?hlP@PSj1I2DkOAX4- zLUD}G&W?*y+^Bvfi4I2$03d1@%HS5k>5}z{G;5s$LGxqkyEu;Q+OHWKY9JV@$G?@N zcjp*0f{g3Kg4lyQSwx>>1rSApq;7W?v^fT?ZQ-TVS5B)Igv;gM;fRR{*XC(bY4Y3Kw$r&LyUm&+YLOGWxS@D zhDUkS*t5$sA5k$h0l&`L+LGbC8U&xMM7|;D`{VNR7DQrx!B855 zu1hpnxK4;ohUKV+laY|T+y(_hWg1^d=C*rye4XikNhoUl$YVQa) z3EFEnn=fi;#v5v>L3B$r1%~BX;V)RU1en@dT~#0ud=776tV8lpC+x> zjUm+hd}PFWV#Sgkt&o(=Mah(UNM4nKnQ!G?ov!MWeNrRgKpe$CG-;{tH-w?wKL%@X=G!L40n!3?y1S}_f{5G+u;U#xf zDX`|{`#X}A@ zh*(4OyLYKLPQe#bSuAJYPDiXC;2t)`K>wmCGo0KB1CyQ^tk?-9;X2e0@?C7f7;B&v zgMw$za2N=SN&|u@4s342%oLpz@P|A80j0|N1Y+bU(kE-;aZx+0`m)E5w;1%luZ6=_ zz;X)bDpO4Rp8l6H*uCj*srR@CES!%$6WZ!d97%vqMr@V!D**uk1pd)ENL*P4=}Mfp z0b?n1{dA-ut|lFIQbjDy4xS-x{GT^2f9+sYI8cOsg-E`fy!@&0C@QRuvxDVM?XO>}Be;&>3xUjb=?!fPh=}aO zU6gSS5kGnm$IK)(O%4iw!zi{Qk<85!{TSgz;;;+Aneq!jMtA|}+tJ-KGJv#;Ki=yP z+YE?6*Le(2L(XrJ0=V?U6NjBIWA~H*mvRD^Qqa)6gl?CRolT2^C2CP#sP>*bj~8(u z43i?LUe6& ztn(KBhfzc^CJt4A@{scoMi9)pYw(+tVF3@14zQzTs9w4x0gF>f4a~45%{nj=VZO4~ zMrR;>z78C^14^aiK}BoeY=&SEikw^IzAs9WQWn$6uf2YDkbI8cDy&fwI3x0>7R*_ zG$n50s0BIg-kU%KM~W>pPKV+MFE2m8E%ay5Gd7|h+{ccw<4l;=@#DmC=EnF6oZYMOoVtz(%2_D%GaoBxAR%l462`_OVr@7dyc4lo z9GCfG{auEJApaQ|y8-cRE?%U8<1Xd$ZWnA9?5`pgZJe1%xPAL3!7mF6_5l!}YZGw7kj<6n_G!b6SD&chA7gc=sAF9C5a7R+cIE#GHpIwkRLtI4U@ah$3Ye+usb!ThuK_S>Ma z5$+glMNr5PP-!GIUYo|aaZ9-GzUfEd;(GS~b9ny$IY<9HNQH+|EcSO8ap<3