diff --git a/.gitignore b/.gitignore index 3311046..e67cf2b 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ dist *.egg-info *~ *.png +*ipynb_checkpoints* diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..0957235 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,11 @@ +FROM andrewosh/binder-base + +MAINTAINER Boris Leistedt + +USER root + +# Add dependency +RUN apt-get update +RUN apt-get install -qy texlive-full + +USER main diff --git a/README.rst b/README.rst index 672f9f2..1e4816b 100644 --- a/README.rst +++ b/README.rst @@ -6,5 +6,15 @@ in a journal or on the internet. With a short Python script and an intuitive model-building syntax you can design directed and undirected graphs and save them in any formats that matplotlib supports. -Get more information at: `daft-pgm.org `_ +Get more information at `daft-pgm.org `_ + +Try making some PGMs with the `example notebooks `_ + +.. image:: http://mybinder.org/badge.svg + :target: http://mybinder.org:/repo/dfm/daft + + +(You may need to `rebuild the binder `_.) + ************************************************************** + diff --git a/daft.py b/daft.py index fbdc0c8..7271b41 100644 --- a/daft.py +++ b/daft.py @@ -9,9 +9,24 @@ from matplotlib.patches import Ellipse from matplotlib.patches import FancyArrow from matplotlib.patches import Rectangle as Rectangle +from matplotlib.text import Annotation import numpy as np +class Tree(object): + """ + Simple generic tree implementation for storing + and connecting artists when rendering the PGM. + """ + def __init__(self, data=None, branches=None): + self.root = data + if branches is None: + self.branches = [] + else: + self.branches = branches + + def add_branch(self, obj): + self.branches.append(obj) class PGM(object): """ @@ -111,7 +126,22 @@ def add_plate(self, plate): self._plates.append(plate) return None - def render(self): + def __str__(self): + """ + Print the positions of :class:`Plate` and :class:`Node` objects in + the model. This is useful if you interactively edited the model + and want to know the new parameters for later reuse. + """ + st = "" + for name in self._nodes: + st += self._nodes[name].__str__() + "\n" + + for plate in self._plates: + st += plate.__str__() + "\n" + + return st + + def render(self, interactive=False): """ Render the :class:`Plate`, :class:`Edge` and :class:`Node` objects in the model. This will create a new figure with the correct dimensions @@ -120,18 +150,95 @@ def render(self): """ self.figure = self._ctx.figure() self.ax = self._ctx.ax() + self.artistTreeList = {} + # Artist tree will contain a dictionary of Nodes and Plates + # with pointers to their artists (lines, ellipses, text, arrows) + + for name in self._nodes: + artistTree = self._nodes[name].render(self._ctx) + self.artistTreeList.update({self._nodes[name]: artistTree}) for plate in self._plates: - plate.render(self._ctx) + artistTree = plate.render(self._ctx) + self.artistTreeList.update({plate: artistTree}) for edge in self._edges: edge.render(self._ctx) - - for name in self._nodes: - self._nodes[name].render(self._ctx) + # Add each arrow to the node1 and node2 trees. + self.artistTreeList[edge.node1].add_branch(edge) + self.artistTreeList[edge.node2].add_branch(edge) + + if interactive: + # Collect artists + self.artists = [key.root for key in self.artistTreeList.values()] + tolerance = 5 # some tolerance for grabbing artists + for artist in self.artists: + artist.set_picker(tolerance) + self.currently_dragging = False + self.current_artist = None + self.offset = (0, 0) + for canvas in set(artist.figure.canvas for artist in self.artists): + canvas.mpl_connect('button_press_event', self.on_press) + canvas.mpl_connect('button_release_event', self.on_release) + canvas.mpl_connect('pick_event', self.on_pick) + canvas.mpl_connect('motion_notify_event', self.on_motion) + plt.show() return self.ax + def on_press(self, event): + """Event: click""" + self.currently_dragging = True + + def on_release(self, event): + """Event: releasing artists""" + self.currently_dragging = False + self.current_artist = None + + def on_pick(self, event): + """ + Picking artists + """ + if self.current_artist is None: + self.current_artist = event.artist + x1, y1 = event.mouseevent.xdata, event.mouseevent.ydata + # Offset of artist position with respect to click + if isinstance(self.current_artist, Rectangle): + x0, y0 = self.current_artist.xy + elif isinstance(self.current_artist, Ellipse): + x0, y0 = self.current_artist.center + else: + x0, y0 = x1, y1 + self.offset = (x0 - x1), (y0 - y1) + + def on_motion(self, event): + """ + Moving artist and changing the content of the relevant Node/Plate + """ + if not self.currently_dragging: + return + if self.current_artist is None: + return + dx, dy = self.offset # offset of artist center w.r. to click + xp, yp = event.xdata + dx, event.ydata + dy # plot space + xm, ym = self._ctx.invconvert(xp, yp) # model space + for k, v in self.artistTreeList.items(): + if v.root == self.current_artist: + k.move(xm, ym, xp, yp) + for key, tree in self.artistTreeList.items(): + # Traverse list to get the right Plate/Node and its artistTree + if tree.root == self.current_artist: + # Move the dependent artists + for ar in tree.branches: + if isinstance(ar, Edge): + ar.ar.remove() # removing artist + ar.render(self._ctx) # drawing it again + if isinstance(ar, (Rectangle, Annotation)): + ar.xy = xp, yp + if isinstance(ar, Ellipse): + ar.center = xp, yp + self.current_artist.figure.canvas.draw() + class Node(object): """ @@ -190,8 +297,9 @@ def __init__(self, name, content, x, y, scale=1, aspect=None, # Coordinates and dimensions. self.x, self.y = x, y self.scale = scale + self.scalefac = 6.0 if self.fixed: - self.scale /= 6.0 + self.scale /= self.scalefac self.aspect = aspect # Display parameters. @@ -204,6 +312,26 @@ def __init__(self, name, content, x, y, scale=1, aspect=None, else: self.label_params = None + def __str__(self): + """ + Print the input parameters of + """ + st = "Node(" + st += "'" + str(self.name) + "'" + st += ", " + "r'" + str(self.content) + "'" + st += ", " + str(self.x) + if self.fixed: + st += ", scale=" + str(self.scalefac * self.scale) + else: + st += ", scale=" + str(self.scale) + st += ", " + str(self.y) + for atnm in ['aspect', 'observed', 'fixed', 'offset', 'plot_params', 'label_params']: + at = getattr(self, atnm) + if at is not None: + st += ", " + atnm + "=" + str(at) + st += ")" + return st + def render(self, ctx): """ Render the node. @@ -256,6 +384,7 @@ def render(self, ctx): else: aspect = ctx.aspect + self.artistTree = Tree() # Set up an observed node. Note the fc INSANITY. if self.observed: # Update the plotting parameters depending on the style of @@ -276,6 +405,7 @@ def render(self, ctx): bg = Ellipse(xy=ctx.convert(self.x, self.y), width=w, height=h, **p) ax.add_artist(bg) + self.artistTree.add_branch(bg) # Reset the face color. p["fc"] = fc @@ -283,20 +413,25 @@ def render(self, ctx): # Draw the foreground ellipse. if ctx.observed_style == "inner" and not self.fixed: p["fc"] = "none" - el = Ellipse(xy=ctx.convert(self.x, self.y), + self.artistTree.root = Ellipse(xy=ctx.convert(self.x, self.y), width=diameter * aspect, height=diameter, **p) - ax.add_artist(el) + ax.add_artist(self.artistTree.root) # Reset the face color. p["fc"] = fc # Annotate the node. - ax.annotate(self.content, ctx.convert(self.x, self.y), + an = ax.annotate(self.content, ctx.convert(self.x, self.y), xycoords="data", xytext=self.offset, textcoords="offset points", **l) + self.artistTree.add_branch(an) + + return self.artistTree - return el + def move(self, xm, ym, xp, yp): + self.x, self.y = xm, ym + self.artistTree.root.center = xp, yp class Edge(object): @@ -382,7 +517,7 @@ def render(self, ctx): # Add edge annotation. if "label" in self.plot_params: x, y, dx, dy = self._get_coords(ctx) - ax.annotate(self.plot_params["label"], + an = ax.annotate(self.plot_params["label"], [x + 0.5 * dx, y + 0.5 * dy], xycoords="data", xytext=[0, 3], textcoords="offset points", ha="center", va="center") @@ -394,13 +529,12 @@ def render(self, ctx): p["head_width"] = p.get("head_width", 0.1) # Build an arrow. - ar = FancyArrow(*self._get_coords(ctx), width=0, + self.ar = FancyArrow(*self._get_coords(ctx), width=0, length_includes_head=True, **p) # Add the arrow to the axes. - ax.add_artist(ar) - return ar + ax.add_artist(self.ar) else: p["color"] = p.get("color", "k") @@ -408,8 +542,8 @@ def render(self, ctx): x, y, dx, dy = self._get_coords(ctx) # Plot the line. - line = ax.plot([x, x + dx], [y, y + dy], **p) - return line + self.ar = ax.plot([x, x + dx], [y, y + dy], **p) + return self.ar class Plate(object): @@ -445,8 +579,26 @@ def __init__(self, rect, label=None, label_offset=[5, 5], shift=0, self.shift = shift self.rect_params = dict(rect_params) self.bbox = dict(bbox) + self.bbox['alpha'] = 0 + self.bbox['fc'] = 'w' + self.bbox['ec'] = None self.position = position + def __str__(self): + """ + Print the input parameters of + """ + st = "Plate(" + st += str(self.rect) + st += ", label=r'" + self.label + "'" + st += ", position='" + self.position + "'" + for atnm in ['label_offset', 'shift', 'rect_params', 'bbox']: + at = getattr(self, atnm) + if at is not None: + st += ", " + atnm + "=" + str(at) + st += ")" + return st + def render(self, ctx): """ Render the plate in the given axes. @@ -485,13 +637,18 @@ def render(self, ctx): raise RuntimeError("Unknown positioning string: {0}" .format(self.position)) - ax.annotate(self.label, pos, xycoords="data", + an = ax.annotate(self.label, pos, xycoords="data", xytext=offset, textcoords="offset points", bbox=self.bbox, horizontalalignment=ha) - return rect + self.artistTree = Tree(rect, [an]) + return self.artistTree + + def move(self, xm, ym, xp, yp): + self.x, self.y = xm, ym + self.artistTree.root.xy = xp, yp class _rendering_context(object): """ @@ -585,6 +742,14 @@ def convert(self, *xy): assert len(xy) == 2 return self.grid_unit * (np.atleast_1d(xy) - self.origin) + def invconvert(self, *xy): + """ + Convert from plot coordinates to model coordinates. + + """ + assert len(xy) == 2 + return self.origin + np.atleast_1d(xy) / self.grid_unit + def _pop_multiple(d, default, *args): """ diff --git a/examples/astronomy.py b/examples/astronomy.py index d2763fa..b6b64f3 100644 --- a/examples/astronomy.py +++ b/examples/astronomy.py @@ -155,6 +155,6 @@ pgm.add_edge("cosmic rays", "noise patch") # Render and save. -pgm.render() +pgm.render(interactive=True) pgm.figure.savefig("astronomy.pdf") pgm.figure.savefig("astronomy.png", dpi=150) diff --git a/examples/badfont.py b/examples/badfont.py index 2b8902b..de27585 100644 --- a/examples/badfont.py +++ b/examples/badfont.py @@ -26,6 +26,6 @@ pgm.add_edge("confused", "ugly") pgm.add_edge("ugly", "bad") pgm.add_edge("confused", "bad") -pgm.render() +pgm.render(interactive=True) pgm.figure.savefig("badfont.pdf") pgm.figure.savefig("badfont.png", dpi=150) diff --git a/examples/bca.py b/examples/bca.py index 7b8c0b0..e4fc6e9 100644 --- a/examples/bca.py +++ b/examples/bca.py @@ -11,6 +11,6 @@ pgm.add_plate(daft.Plate([0.5, 2.25, 1, 1.25], label=r"data $n$")) pgm.add_edge("a", "b") pgm.add_edge("b", "c") - pgm.render() + pgm.render(interactive=True) pgm.figure.savefig("bca.pdf") pgm.figure.savefig("bca.png", dpi=150) diff --git a/examples/classic.ipynb b/examples/classic.ipynb new file mode 100644 index 0000000..bae7bdd --- /dev/null +++ b/examples/classic.ipynb @@ -0,0 +1,919 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Interactive PGM Plotting\n", + "\n", + "_Boris Leistedt, October 2016_\n", + "\n", + "In this notebook we will make a simple PGM with `daft`, and then explore the interactive capabilities of the `PGM` class. \n", + "\n", + "### Requirements\n", + "\n", + "You will need `matplotlib`." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib notebook \n", + "import matplotlib.pyplot as plt\n", + "from matplotlib import rc\n", + "\n", + "import sys\n", + "sys.path.append('../.')\n", + "import daft\n", + "rc(\"font\", family=\"serif\", size=12)\n", + "rc(\"text\", usetex=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "data": { + "application/javascript": [ + "/* Put everything inside the global mpl namespace */\n", + "window.mpl = {};\n", + "\n", + "mpl.get_websocket_type = function() {\n", + " if (typeof(WebSocket) !== 'undefined') {\n", + " return WebSocket;\n", + " } else if (typeof(MozWebSocket) !== 'undefined') {\n", + " return MozWebSocket;\n", + " } else {\n", + " alert('Your browser does not have WebSocket support.' +\n", + " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n", + " 'Firefox 4 and 5 are also supported but you ' +\n", + " 'have to enable WebSockets in about:config.');\n", + " };\n", + "}\n", + "\n", + "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", + " this.id = figure_id;\n", + "\n", + " this.ws = websocket;\n", + "\n", + " this.supports_binary = (this.ws.binaryType != undefined);\n", + "\n", + " if (!this.supports_binary) {\n", + " var warnings = document.getElementById(\"mpl-warnings\");\n", + " if (warnings) {\n", + " warnings.style.display = 'block';\n", + " warnings.textContent = (\n", + " \"This browser does not support binary websocket messages. \" +\n", + " \"Performance may be slow.\");\n", + " }\n", + " }\n", + "\n", + " this.imageObj = new Image();\n", + "\n", + " this.context = undefined;\n", + " this.message = undefined;\n", + " this.canvas = undefined;\n", + " this.rubberband_canvas = undefined;\n", + " this.rubberband_context = undefined;\n", + " this.format_dropdown = undefined;\n", + "\n", + " this.image_mode = 'full';\n", + "\n", + " this.root = $('
');\n", + " this._root_extra_style(this.root)\n", + " this.root.attr('style', 'display: inline-block');\n", + "\n", + " $(parent_element).append(this.root);\n", + "\n", + " this._init_header(this);\n", + " this._init_canvas(this);\n", + " this._init_toolbar(this);\n", + "\n", + " var fig = this;\n", + "\n", + " this.waiting = false;\n", + "\n", + " this.ws.onopen = function () {\n", + " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", + " fig.send_message(\"send_image_mode\", {});\n", + " fig.send_message(\"refresh\", {});\n", + " }\n", + "\n", + " this.imageObj.onload = function() {\n", + " if (fig.image_mode == 'full') {\n", + " // Full images could contain transparency (where diff images\n", + " // almost always do), so we need to clear the canvas so that\n", + " // there is no ghosting.\n", + " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", + " }\n", + " fig.context.drawImage(fig.imageObj, 0, 0);\n", + " };\n", + "\n", + " this.imageObj.onunload = function() {\n", + " this.ws.close();\n", + " }\n", + "\n", + " this.ws.onmessage = this._make_on_message_function(this);\n", + "\n", + " this.ondownload = ondownload;\n", + "}\n", + "\n", + "mpl.figure.prototype._init_header = function() {\n", + " var titlebar = $(\n", + " '
');\n", + " var titletext = $(\n", + " '
');\n", + " titlebar.append(titletext)\n", + " this.root.append(titlebar);\n", + " this.header = titletext[0];\n", + "}\n", + "\n", + "\n", + "\n", + "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", + "\n", + "}\n", + "\n", + "\n", + "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", + "\n", + "}\n", + "\n", + "mpl.figure.prototype._init_canvas = function() {\n", + " var fig = this;\n", + "\n", + " var canvas_div = $('
');\n", + "\n", + " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", + "\n", + " function canvas_keyboard_event(event) {\n", + " return fig.key_event(event, event['data']);\n", + " }\n", + "\n", + " canvas_div.keydown('key_press', canvas_keyboard_event);\n", + " canvas_div.keyup('key_release', canvas_keyboard_event);\n", + " this.canvas_div = canvas_div\n", + " this._canvas_extra_style(canvas_div)\n", + " this.root.append(canvas_div);\n", + "\n", + " var canvas = $('');\n", + " canvas.addClass('mpl-canvas');\n", + " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", + "\n", + " this.canvas = canvas[0];\n", + " this.context = canvas[0].getContext(\"2d\");\n", + "\n", + " var rubberband = $('');\n", + " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", + "\n", + " var pass_mouse_events = true;\n", + "\n", + " canvas_div.resizable({\n", + " start: function(event, ui) {\n", + " pass_mouse_events = false;\n", + " },\n", + " resize: function(event, ui) {\n", + " fig.request_resize(ui.size.width, ui.size.height);\n", + " },\n", + " stop: function(event, ui) {\n", + " pass_mouse_events = true;\n", + " fig.request_resize(ui.size.width, ui.size.height);\n", + " },\n", + " });\n", + "\n", + " function mouse_event_fn(event) {\n", + " if (pass_mouse_events)\n", + " return fig.mouse_event(event, event['data']);\n", + " }\n", + "\n", + " rubberband.mousedown('button_press', mouse_event_fn);\n", + " rubberband.mouseup('button_release', mouse_event_fn);\n", + " // Throttle sequential mouse events to 1 every 20ms.\n", + " rubberband.mousemove('motion_notify', mouse_event_fn);\n", + "\n", + " rubberband.mouseenter('figure_enter', mouse_event_fn);\n", + " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", + "\n", + " canvas_div.on(\"wheel\", function (event) {\n", + " event = event.originalEvent;\n", + " event['data'] = 'scroll'\n", + " if (event.deltaY < 0) {\n", + " event.step = 1;\n", + " } else {\n", + " event.step = -1;\n", + " }\n", + " mouse_event_fn(event);\n", + " });\n", + "\n", + " canvas_div.append(canvas);\n", + " canvas_div.append(rubberband);\n", + "\n", + " this.rubberband = rubberband;\n", + " this.rubberband_canvas = rubberband[0];\n", + " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", + " this.rubberband_context.strokeStyle = \"#000000\";\n", + "\n", + " this._resize_canvas = function(width, height) {\n", + " // Keep the size of the canvas, canvas container, and rubber band\n", + " // canvas in synch.\n", + " canvas_div.css('width', width)\n", + " canvas_div.css('height', height)\n", + "\n", + " canvas.attr('width', width);\n", + " canvas.attr('height', height);\n", + "\n", + " rubberband.attr('width', width);\n", + " rubberband.attr('height', height);\n", + " }\n", + "\n", + " // Set the figure to an initial 600x600px, this will subsequently be updated\n", + " // upon first draw.\n", + " this._resize_canvas(600, 600);\n", + "\n", + " // Disable right mouse context menu.\n", + " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", + " return false;\n", + " });\n", + "\n", + " function set_focus () {\n", + " canvas.focus();\n", + " canvas_div.focus();\n", + " }\n", + "\n", + " window.setTimeout(set_focus, 100);\n", + "}\n", + "\n", + "mpl.figure.prototype._init_toolbar = function() {\n", + " var fig = this;\n", + "\n", + " var nav_element = $('
')\n", + " nav_element.attr('style', 'width: 100%');\n", + " this.root.append(nav_element);\n", + "\n", + " // Define a callback function for later on.\n", + " function toolbar_event(event) {\n", + " return fig.toolbar_button_onclick(event['data']);\n", + " }\n", + " function toolbar_mouse_event(event) {\n", + " return fig.toolbar_button_onmouseover(event['data']);\n", + " }\n", + "\n", + " for(var toolbar_ind in mpl.toolbar_items) {\n", + " var name = mpl.toolbar_items[toolbar_ind][0];\n", + " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", + " var image = mpl.toolbar_items[toolbar_ind][2];\n", + " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", + "\n", + " if (!name) {\n", + " // put a spacer in here.\n", + " continue;\n", + " }\n", + " var button = $('