diff --git a/README.md b/README.md
index aaee4b7..756f7bc 100644
--- a/README.md
+++ b/README.md
@@ -73,21 +73,21 @@ dlclive.benchmark_videos('/path/to/exported/model', ['/path/to/video1', '/path/t
```
##### command line
```
-dlc-live-analyze /path/to/exported/model /path/to/video1 /path/to/video2 -o /path/to/output -r 1.0 0.75 0.5
+dlc-live-benchmark /path/to/exported/model /path/to/video1 /path/to/video2 -o /path/to/output -r 1.0 0.75 0.5
```
2. Display keypoints to visually inspect the accuracy of exported models on different image sizes (note, this is slow and only for testing purposes):
##### python
```python
-dlclive.benchmark_videos('/path/to/exported/model', '/path/to/video', resize=[0.5], display=True, pcutoff=0.5, display_radius=4, cmap='bmy')
+dlclive.benchmark_videos('/path/to/exported/model', '/path/to/video', resize=0.5, display=True, pcutoff=0.5, display_radius=4, cmap='bmy')
```
##### command line
```
-dlc-live-analyze /path/to/exported/model /path/to/video -r 0.5 --display --pcutoff 0.5 --display-radius 4 --cmap bmy
+dlc-live-benchmark /path/to/exported/model /path/to/video -r 0.5 --display --pcutoff 0.5 --display-radius 4 --cmap bmy
```
-3. Analyze and create a labeled video using the exported model and desired resize parameters. This option functions similar to `deeplabcut.analyze_videos` and `deeplabcut.create_labeled_video` (note, this is slow and only for testing purposes).
+3. Analyze and create a labeled video using the exported model and desired resize parameters. This option functions similar to `deeplabcut.benchmark_videos` and `deeplabcut.create_labeled_video` (note, this is slow and only for testing purposes).
##### python
```python
@@ -95,7 +95,7 @@ dlclive.benchmark_videos('/path/to/exported/model', '/path/to/video', resize=[1.
```
##### command line
```
-dlc-live-analyze /path/to/exported/model /path/to/video -r 0.5 --pcutoff 0.5 --display-radius 4 --cmap bmy --save_poses --save_video
+dlc-live-benchmark /path/to/exported/model /path/to/video -r 0.5 --pcutoff 0.5 --display-radius 4 --cmap bmy --save-poses --save-video
```
### Citation:
diff --git a/dlclive/__init__.py b/dlclive/__init__.py
index 3fc2b89..98096ac 100644
--- a/dlclive/__init__.py
+++ b/dlclive/__init__.py
@@ -9,3 +9,4 @@
from dlclive.dlclive import DLCLive
from dlclive.processor import Processor
from dlclive.bench import benchmark_model_by_size
+from dlclive.benchmark import benchmark, benchmark_videos
diff --git a/dlclive/bench.py b/dlclive/bench.py
index fc95612..837969b 100644
--- a/dlclive/bench.py
+++ b/dlclive/bench.py
@@ -10,6 +10,7 @@
import os
import time
import sys
+import warnings
import argparse
import pickle
import subprocess
@@ -99,8 +100,7 @@ def get_system_info() -> dict:
'dlclive_version': VERSION
}
-
-def run_benchmark(model_path, video_path, resize=None, pixels=None, n_frames=10000, print_rate=False) -> typing.Tuple[np.ndarray, int, bool]:
+def run_benchmark(model_path, video_path, tf_config=None, resize=None, pixels=None, n_frames=10000, print_rate=False, display=False, pcutoff=0.0, display_radius=3) -> typing.Tuple[np.ndarray, int, bool]:
""" Benchmark on inference times for a given DLC model and video
Parameters
@@ -139,13 +139,13 @@ def run_benchmark(model_path, video_path, resize=None, pixels=None, n_frames=100
### initialize live object
- live = DLCLive(model_path, resize=resize)
+ live = DLCLive(model_path, tf_config=tf_config, resize=resize, display=display, pcutoff=pcutoff, display_radius=display_radius)
live.init_inference(frame)
TFGPUinference = True if len(live.outputs) == 1 else False
### perform inference
- iterator = range(n_frames) if print_rate else tqdm(range(n_frames))
+ iterator = range(n_frames) if (print_rate) or (display) else tqdm(range(n_frames))
inf_times = np.zeros(n_frames)
for i in iterator:
@@ -155,7 +155,7 @@ def run_benchmark(model_path, video_path, resize=None, pixels=None, n_frames=100
if not ret:
warnings.warn("Did not complete {:d} frames. There probably were not enough frames in the video {}.".format(n_frames, video_path))
break
-
+
start_pose = time.time()
live.get_pose(frame)
inf_times[i] = time.time() - start_pose
@@ -238,7 +238,7 @@ def read_pickle(filename):
with open(filename, "rb") as handle:
return pickle.load(handle)
-def benchmark_model_by_size(model_path, video_path, fn_ind, out_dir=None, n_frames=10000, resize=None, pixels=None, print_rate=False):
+def benchmark_model_by_size(model_path, video_path, output=None, n_frames=10000, tf_config=None, resize=None, pixels=None, print_rate=False, display=False, pcutoff=0.5, display_radius=3):
"""Benchmark DLC model by image size
Parameters
@@ -278,17 +278,12 @@ def benchmark_model_by_size(model_path, video_path, fn_ind, out_dir=None, n_fram
### initialize full inference times
- #inf_times = np.zeros((len(resize), n_frames))
- #pixels_out = np.zeros(len(resize))
- print(resize)
-
# get system info once, shouldn't change between runs
sys_info = get_system_info()
for i in range(len(resize)):
sys_info = get_system_info()
- #print("Your system info:", sys_info)
datafilename=get_savebenchmarkfn(sys_info ,i, fn_ind, out_dir=out_dir)
#Check if a subset was already been completed?
@@ -313,19 +308,26 @@ def main():
parser.add_argument('model_path', type=str)
parser.add_argument('video_path', type=str)
parser.add_argument('-o', '--output', type=str, default=os.getcwd())
- parser.add_argument('-n', '--n_frames', type=int, default=10000)
+ parser.add_argument('-n', '--n-frames', type=int, default=10000)
parser.add_argument('-r', '--resize', type=float, nargs='+')
parser.add_argument('-p', '--pixels', type=float, nargs='+')
parser.add_argument('-v', '--print_rate', default=False, action='store_true')
+ parser.add_argument('-d', '--display', default=False, action='store_true')
+ parser.add_argument('-l', '--pcutoff', default=0.5, type=float)
+ parser.add_argument('-s', '--display-radius', default=3, type=int)
args = parser.parse_args()
+
benchmark_model_by_size(args.model_path,
args.video_path,
output=args.output,
resize=args.resize,
pixels=args.pixels,
n_frames=args.n_frames,
- print_rate=args.print_rate)
+ print_rate=args.print_rate,
+ display=args.display,
+ pcutoff=args.pcutoff,
+ display_radius=args.display_radius)
if __name__ == "__main__":
diff --git a/dlclive/benchmark.py b/dlclive/benchmark.py
new file mode 100644
index 0000000..a18b8d1
--- /dev/null
+++ b/dlclive/benchmark.py
@@ -0,0 +1,492 @@
+"""
+DeepLabCut Toolbox (deeplabcut.org)
+© A. & M. Mathis Labs
+
+Licensed under GNU Lesser General Public License v3.0
+"""
+
+
+import platform
+import os
+import time
+import sys
+import warnings
+import subprocess
+import typing
+import pickle
+import colorcet as cc
+from PIL import ImageColor
+import ruamel
+import pandas as pd
+
+try:
+ from pip._internal.operations import freeze
+except ImportError:
+ from pip.operations import freeze
+
+from tqdm import tqdm
+import numpy as np
+import tensorflow as tf
+import cv2
+
+from dlclive import DLCLive
+from dlclive import VERSION
+from dlclive import __file__ as dlcfile
+
+
+def get_system_info() -> dict:
+ """ Return summary info for system running benchmark
+ Returns
+ -------
+ dict
+ Dictionary containing the following system information:
+ * ``host_name`` (str): name of machine
+ * ``op_sys`` (str): operating system
+ * ``python`` (str): path to python (which conda/virtual environment)
+ * ``device`` (tuple): (device type (``'GPU'`` or ``'CPU'```), device information)
+ * ``freeze`` (list): list of installed packages and versions
+ * ``python_version`` (str): python version
+ * ``git_hash`` (str, None): If installed from git repository, hash of HEAD commit
+ * ``dlclive_version`` (str): dlclive version from :data:`dlclive.VERSION`
+ """
+
+
+ ### get os
+
+ op_sys = platform.platform()
+ host_name = platform.node().replace(' ', '')
+
+ # A string giving the absolute path of the executable binary for the Python interpreter, on systems where this makes sense.
+ if platform.system() == 'Windows':
+ host_python = sys.executable.split(os.path.sep)[-2]
+ else:
+ host_python = sys.executable.split(os.path.sep)[-3]
+
+ # try to get git hash if possible
+ dlc_basedir = os.path.dirname(os.path.dirname(dlcfile))
+ git_hash = None
+ try:
+ git_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=dlc_basedir)
+ git_hash = git_hash.decode('utf-8').rstrip('\n')
+ except subprocess.CalledProcessError:
+ # not installed from git repo, eg. pypi
+ # fine, pass quietly
+ pass
+
+ # get device info (GPU or CPU)
+ dev = None
+ if tf.test.is_gpu_available():
+ gpu_name = tf.test.gpu_device_name()
+ from tensorflow.python.client import device_lib
+ dev_desc = [d.physical_device_desc for d in device_lib.list_local_devices() if d.name == gpu_name]
+ dev = [d.split(",")[1].split(':')[1].strip() for d in dev_desc]
+ dev_type = "GPU"
+ else:
+ from cpuinfo import get_cpu_info
+ dev = [get_cpu_info()['brand']]
+ dev_type = "CPU"
+
+ return {
+ 'host_name': host_name,
+ 'op_sys' : op_sys,
+ 'python': host_python,
+ 'device_type': dev_type,
+ 'device': dev,
+ 'freeze': list(freeze.freeze()), # pip freeze to get versions of all packages
+ 'python_version': sys.version,
+ 'git_hash': git_hash,
+ 'dlclive_version': VERSION
+ }
+
+def benchmark(model_path,
+ video_path,
+ tf_config=None,
+ resize=None,
+ pixels=None,
+ n_frames=1000,
+ print_rate=False,
+ display=False,
+ pcutoff=0.0,
+ display_radius=3,
+ cmap='bmy',
+ save_poses=False,
+ save_video=False,
+ output=None) -> typing.Tuple[np.ndarray, int, bool]:
+ """ Analyze DeepLabCut-live exported model on a video:
+ Calculate inference time,
+ display keypoints, or
+ get poses/create a labeled video
+
+ Parameters
+ ----------
+ model_path : str
+ path to exported DeepLabCut model
+ video_path : str
+ path to video file
+ tf_config : :class:`tensorflow.ConfigProto`
+ tensorflow session configuration
+ resize : int, optional
+ resize factor. Can only use one of resize or pixels. If both are provided, will use pixels. by default None
+ pixels : int, optional
+ downsize image to this number of pixels, maintaining aspect ratio. Can only use one of resize or pixels. If both are provided, will use pixels. by default None
+ n_frames : int, optional
+ number of frames to run inference on, by default 1000
+ print_rate : bool, optional
+ flat to print inference rate frame by frame, by default False
+ display : bool, optional
+ flag to display keypoints on images. Useful for checking the accuracy of exported models.
+ pcutoff : float, optional
+ likelihood threshold to display keypoints
+ display_radius : int, optional
+ size (radius in pixels) of keypoint to display
+ cmap : str, optional
+ a string indicating the :package:`colorcet` colormap, `options here `, by default "bmy"
+ save_poses : bool, optional
+ flag to save poses to an hdf5 file. If True, operates similar to :function:`DeepLabCut.benchmark_videos`, by default False
+ save_video : bool, optional
+ flag to save a labeled video. If True, operates similar to :function:`DeepLabCut.create_labeled_video`, by default False
+ output : str, optional
+ path to directory to save pose and/or video file. If not specified, will use the directory of video_path, by default None
+
+ Returns
+ -------
+ :class:`numpy.ndarray`
+ vector of inference times
+ float
+ number of pixels in each image
+
+ Example
+ -------
+ Return a vector of inference times for 10000 frames:
+ dlclive.benchmark('/my/exported/model', 'my_video.avi', n_frames=10000)
+
+ Return a vector of inference times, resizing images to half the width and height for inference
+ dlclive.benchmark('/my/exported/model', 'my_video.avi', n_frames=10000, resize=0.5)
+
+ Display keypoints to check the accuracy of an exported model
+ dlclive.benchmark('/my/exported/model', 'my_video.avi', display=True)
+
+ Analyze a video (save poses to hdf5) and create a labeled video, similar to :function:`DeepLabCut.benchmark_videos` and :function:`create_labeled_video`
+ dlclive.benchmark('/my/exported/model', 'my_video.avi', save_poses=True, save_video=True)
+ """
+
+ ### load video
+
+ cap = cv2.VideoCapture(video_path)
+ ret, frame = cap.read()
+ n_frames = n_frames if n_frames < cap.get(cv2.CAP_PROP_FRAME_COUNT)-1 else cap.get(cv2.CAP_PROP_FRAME_COUNT)-1
+ n_frames = int(n_frames)
+ im_size = (cap.get(cv2.CAP_PROP_FRAME_WIDTH), cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+
+ ### get resize factor
+
+ if pixels is not None:
+ resize = np.sqrt(pixels / (im_size[0] * im_size[1]))
+ if resize is not None:
+ im_size = (int(im_size[0]*resize), int(im_size[1]*resize))
+
+ ### initialize live object
+
+ live = DLCLive(model_path, tf_config=tf_config, resize=resize, display=display, pcutoff=pcutoff, display_radius=display_radius, display_cmap=cmap)
+ live.init_inference(frame)
+ TFGPUinference = True if len(live.outputs) == 1 else False
+
+ ### create video writer
+
+ if save_video:
+ colors = None
+ out_dir = output if output is not None else os.path.dirname(os.path.realpath(video_path))
+ out_vid_base = os.path.basename(video_path)
+ out_vid_file = os.path.normpath(f"{out_dir}/{os.path.splitext(video_path)[0]}_DLCLIVE_LABELED.avi")
+ fourcc = cv2.VideoWriter_fourcc(*'DIVX')
+ fps = cap.get(cv2.CAP_PROP_FPS)
+ vwriter = cv2.VideoWriter(out_vid_file, fourcc, fps, im_size)
+
+ ### perform inference
+
+ iterator = range(n_frames) if (print_rate) or (display) else tqdm(range(n_frames))
+ inf_times = np.zeros(n_frames)
+ poses = []
+
+ for i in iterator:
+
+ ret, frame = cap.read()
+
+ if not ret:
+ warnings.warn("Did not complete {:d} frames. There probably were not enough frames in the video {}.".format(n_frames, video_path))
+ break
+
+ start_pose = time.time()
+ poses.append(live.get_pose(frame))
+ inf_times[i] = time.time() - start_pose
+
+ if save_video:
+
+ if colors is None:
+ all_colors = getattr(cc, cmap)
+ colors = [ImageColor.getcolor(c, "RGB")[::-1] for c in all_colors[::int(len(all_colors)/poses[-1].shape[0])]]
+
+ this_pose = poses[-1]
+ for j in range(this_pose.shape[0]):
+ if this_pose[j, 2] > pcutoff:
+ x = int(this_pose[j, 0])
+ y = int(this_pose[j, 1])
+ frame = cv2.circle(frame, (x, y), display_radius, colors[j], thickness=-1)
+
+ if resize is not None:
+ frame = cv2.resize(frame, im_size)
+ vwriter.write(frame)
+
+ if print_rate:
+ print("pose rate = {:d}".format(int(1 / inf_times[i])))
+
+ if print_rate:
+ print("mean pose rate = {:d}".format(int(np.mean(1/inf_times))))
+
+ ### close video and tensorflow session
+
+ cap.release()
+ live.close()
+
+ if save_video:
+ vwriter.release()
+
+ if save_poses:
+
+ cfg_path = os.path.normpath(f"{model_path}/pose_cfg.yaml")
+ ruamel_file = ruamel.yaml.YAML()
+ dlc_cfg = ruamel_file.load(open(cfg_path, 'r'))
+ bodyparts = dlc_cfg['all_joints_names']
+ poses = np.array(poses)
+ poses = poses.reshape((poses.shape[0], poses.shape[1]*poses.shape[2]))
+ pdindex = pd.MultiIndex.from_product([bodyparts, ['x', 'y', 'likelihood']], names=['bodyparts', 'coords'])
+ pose_df = pd.DataFrame(poses, columns=pdindex)
+
+ out_dir = output if output is not None else os.path.dirname(os.path.realpath(video_path))
+ out_vid_base = os.path.basename(video_path)
+ out_dlc_file = os.path.normpath(f"{out_dir}/{os.path.splitext(video_path)[0]}_DLCLIVE_POSES.h5")
+ pose_df.to_hdf(out_dlc_file, key='df_with_missing', mode='w')
+
+ return inf_times, im_size[0]*im_size[1], TFGPUinference
+
+
+def save_inf_times(sys_info,
+ inf_times,
+ pixels,
+ TFGPUinference,
+ model=None,
+ output=None):
+ """ Save inference time data collected using :function:`benchmark` with system information to a pickle file.
+ This is primarily used through :function:`benchmark_videos`
+
+ Parameters
+ ----------
+ sys_info : tuple
+ system information generated by :func:`get_system_info`
+ inf_times : :class:`numpy.ndarray`
+ array of inference times generated by :func:`benchmark`
+ pixels : float or :class:`numpy.ndarray`
+ number of pixels for each benchmark run. If an array, each index corresponds to a row in inf_times
+ TFGPUinference: bool
+ flag if using tensorflow inference or numpy inference DLC model
+ model: str, optional
+ name of model
+ output : str, optional
+ path to directory to save data. If None, uses pwd, by default None
+
+ Returns
+ -------
+ bool
+ flag indicating successful save
+ """
+
+ output = output if output is not None else os.getcwd()
+ model_type = None
+ if model is not None:
+ if 'resnet' in model:
+ model_type = 'resnet'
+ elif 'mobilenet' in model:
+ model_type = 'mobilenet'
+ else:
+ model_type = None
+
+ fn_ind = 0
+ base_name = f"benchmark_{sys_info['host_name']}_{sys_info['device'][0]}_{fn_ind}.pickle"
+ while os.path.isfile(os.path.normpath(output + '/' + base_name)):
+ fn_ind += 1
+ base_name = f"benchmark_{sys_info['host_name']}_{sys_info['device'][0]}_{fn_ind}.pickle"
+
+ data = {'model' : model,
+ 'model_type' : model_type,
+ 'TFGPUinference' : TFGPUinference,
+ 'pixels' : pixels,
+ 'inference_times' : inf_times}
+
+ data.update(sys_info)
+
+ os.makedirs(os.path.normpath(output), exist_ok=True)
+ pickle.dump(data, open(os.path.normpath(f"{output}/{base_name}"), 'wb'))
+
+ return True
+
+
+def benchmark_videos(model_path,
+ video_path,
+ output=None,
+ n_frames=1000,
+ tf_config=None,
+ resize=None,
+ pixels=None,
+ print_rate=False,
+ display=False,
+ pcutoff=0.5,
+ display_radius=3,
+ cmap='bmy',
+ save_poses=False,
+ save_video=False):
+ """Analyze videos using DeepLabCut-live exported models.
+ Analyze multiple videos and/or multiple options for the size of the video
+ by specifying a resizing factor or the number of pixels to use in the image (keeping aspect ratio constant).
+ Options to record inference times (to examine inference speed),
+ display keypoints to visually check the accuracy,
+ or save poses to an hdf5 file as in :function:`deeplabcut.benchmark_videos` and
+ create a labeled video as in :function:`deeplabcut.create_labeled_video`.
+
+ Parameters
+ ----------
+ model_path : str
+ path to exported DeepLabCut model
+ video_path : str or list
+ path to video file or list of paths to video files
+ output : str
+ path to directory to save results
+ tf_config : :class:`tensorflow.ConfigProto`
+ tensorflow session configuration
+ resize : int, optional
+ resize factor. Can only use one of resize or pixels. If both are provided, will use pixels. by default None
+ pixels : int, optional
+ downsize image to this number of pixels, maintaining aspect ratio. Can only use one of resize or pixels. If both are provided, will use pixels. by default None
+ n_frames : int, optional
+ number of frames to run inference on, by default 1000
+ print_rate : bool, optional
+ flat to print inference rate frame by frame, by default False
+ display : bool, optional
+ flag to display keypoints on images. Useful for checking the accuracy of exported models.
+ pcutoff : float, optional
+ likelihood threshold to display keypoints
+ display_radius : int, optional
+ size (radius in pixels) of keypoint to display
+ cmap : str, optional
+ a string indicating the :package:`colorcet` colormap, `options here `, by default "bmy"
+ save_poses : bool, optional
+ flag to save poses to an hdf5 file. If True, operates similar to :function:`DeepLabCut.benchmark_videos`, by default False
+ save_video : bool, optional
+ flag to save a labeled video. If True, operates similar to :function:`DeepLabCut.create_labeled_video`, by default False
+
+ Example
+ -------
+ Return a vector of inference times for 10000 frames on one video or two videos:
+ dlclive.benchmark_videos('/my/exported/model', 'my_video.avi', n_frames=10000)
+ dlclive.benchmark_videos('/my/exported/model', ['my_video1.avi', 'my_video2.avi'], n_frames=10000)
+
+ Return a vector of inference times, testing full size and resizing images to half the width and height for inference, for two videos
+ dlclive.benchmark_videos('/my/exported/model', ['my_video1.avi', 'my_video2.avi'], n_frames=10000, resize=[1.0, 0.5])
+
+ Display keypoints to check the accuracy of an exported model
+ dlclive.benchmark_videos('/my/exported/model', 'my_video.avi', display=True)
+
+ Analyze a video (save poses to hdf5) and create a labeled video, similar to :function:`DeepLabCut.benchmark_videos` and :function:`create_labeled_video`
+ dlclive.benchmark_videos('/my/exported/model', 'my_video.avi', save_poses=True, save_video=True)
+ """
+
+ ### convert video_paths to list
+
+ video_path = video_path if type(video_path) is list else [video_path]
+
+ ### fix resize
+
+ if pixels:
+ pixels = pixels if type(pixels) is list else [pixels]
+ resize = [None for p in pixels]
+ elif resize:
+ resize = resize if type(resize) is list else [resize]
+ pixels = [None for r in resize]
+ else:
+ resize = [None]
+ pixels = [None]
+
+ ### loop over videos
+
+ for v in video_path:
+
+ ### initialize full inference times
+
+ inf_times = np.zeros((len(resize), n_frames))
+ pixels_out = np.zeros(len(resize))
+
+ for i in range(len(resize)):
+
+ print(f"\nRun {i+1} / {len(resize)}\n")
+
+ inf_times[i], pixels_out[i], TFGPUinference = benchmark(model_path,
+ v,
+ tf_config=tf_config,
+ resize=resize[i],
+ pixels=pixels[i],
+ n_frames=n_frames,
+ print_rate=print_rate,
+ display=display,
+ pcutoff=pcutoff,
+ display_radius=display_radius,
+ cmap=cmap,
+ save_poses=save_poses,
+ save_video=save_video,
+ output=output)
+
+ ### save results
+
+ if output is not None:
+ sys_info = get_system_info()
+ save_inf_times(sys_info, inf_times, pixels_out, TFGPUinference, model=os.path.basename(model_path), output=output)
+
+
+def main():
+ """Provides a command line interface :function:`benchmark_videos`
+ """
+
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('model_path', type=str)
+ parser.add_argument('video_path', type=str, nargs='+')
+ parser.add_argument('-o', '--output', type=str, default=None)
+ parser.add_argument('-n', '--n-frames', type=int, default=1000)
+ parser.add_argument('-r', '--resize', type=float, nargs='+')
+ parser.add_argument('-p', '--pixels', type=float, nargs='+')
+ parser.add_argument('-v', '--print-rate', default=False, action='store_true')
+ parser.add_argument('-d', '--display', default=False, action='store_true')
+ parser.add_argument('-l', '--pcutoff', default=0.5, type=float)
+ parser.add_argument('-s', '--display-radius', default=3, type=int)
+ parser.add_argument('-c', '--cmap', type=str, default='bmy')
+ parser.add_argument('--save-poses', action='store_true')
+ parser.add_argument('--save-video', action='store_true')
+ args = parser.parse_args()
+
+ benchmark_videos(args.model_path,
+ args.video_path,
+ output=args.output,
+ resize=args.resize,
+ pixels=args.pixels,
+ n_frames=args.n_frames,
+ print_rate=args.print_rate,
+ display=args.display,
+ pcutoff=args.pcutoff,
+ display_radius=args.display_radius,
+ cmap=args.cmap,
+ save_poses=args.save_poses,
+ save_video=args.save_video)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/dlclive/display.py b/dlclive/display.py
index 5c5fbc8..534cb90 100644
--- a/dlclive/display.py
+++ b/dlclive/display.py
@@ -19,19 +19,19 @@ class Display(object):
-----------
cmap : string
string indicating the Matoplotlib colormap to use.
- lik : float
+ pcutoff : float
likelihood threshold to display points
'''
- def __init__(self, cmap='bmy', radius=3, lik=0.5):
+ def __init__(self, cmap='bmy', radius=3, pcutoff=0.5):
""" Constructor method
"""
self.cmap = cmap
self.colors = None
self.radius = radius
- self.lik = lik
+ self.pcutoff = pcutoff
self.window = None
@@ -79,18 +79,18 @@ def display_frame(self, frame, pose=None):
draw = ImageDraw.Draw(img)
for i in range(pose.shape[0]):
- if pose[i,2] > self.lik:
+ if pose[i,2] > self.pcutoff:
try:
- x0 = pose[i,0] - self.radius if pose[i,1]-self.radius > 0 else 0
- x1 = pose[i,0] + self.radius if pose[i,1] + self.radius < im_size[1] else im_size[1]
- y0 = pose[i,1] - self.radius if pose[i,0] - self.radius > 0 else 0
- y1 = pose[i,1] + self.radius if pose[i,0] + self.radius < im_size[0] else im_size[0]
+ x0 = pose[i,0] - self.radius if pose[i,0] - self.radius > 0 else 0
+ x1 = pose[i,0] + self.radius if pose[i,0] + self.radius < im_size[1] else im_size[1]
+ y0 = pose[i,1] - self.radius if pose[i,1] - self.radius > 0 else 0
+ y1 = pose[i,1] + self.radius if pose[i,1] + self.radius < im_size[0] else im_size[0]
coords = [x0, y0, x1, y1]
draw.ellipse(coords, fill=self.colors[i], outline=self.colors[i])
- except:
- pass
+ except Exception as e:
+ print(e)
- img_tk = ImageTk.PhotoImage(image=img)
+ img_tk = ImageTk.PhotoImage(image=img, master=self.window)
self.lab.configure(image=img_tk)
self.window.update()
diff --git a/dlclive/dlclive.py b/dlclive/dlclive.py
index d8b3654..3059b70 100644
--- a/dlclive/dlclive.py
+++ b/dlclive/dlclive.py
@@ -23,7 +23,7 @@
from dlclive.graph import read_graph, finalize_graph, get_output_nodes, get_output_tensors, extract_graph
from dlclive.pose import extract_cnn_output, argmax_pose_predict, multi_pose_predict
-from dlclive.display import Display
+from dlclive.display import Display
from dlclive import utils
from dlclive.exceptions import DLCLiveError, DLCLiveWarning
@@ -66,34 +66,51 @@ class DLCLive(object):
i) to run a forward predicting model that will predict the future pose from past history of poses (history can be stored in the processor object, but is not stored in this DLCLive object)
ii) to trigger external hardware based on pose estimation (e.g. see 'TeensyLaser' processor)
+ convert2rgb : bool, optional
+ boolean flag to convert frames from BGR to RGB color scheme
+
display : bool, optional
Display frames with DeepLabCut labels?
This is useful for testing model accuracy and cropping parameters, but it is very slow.
display_lik : float, optional
Likelihood threshold for display
+
+ display_raidus : int, optional
+ radius for keypoint display in pixels, default=3
'''
- def __init__(self, model_path,
- model_type='base', precision='FP16',
- cropping=None, dynamic=(False,.5,10), resize=None,
- processor=None, display=False, display_lik=0.5):
+ def __init__(self,
+ model_path,
+ model_type='base',
+ precision='FP32',
+ tf_config=None,
+ cropping=None,
+ dynamic=(False,.5,10),
+ resize=None,
+ convert2rgb=True,
+ processor=None,
+ display=False,
+ pcutoff=0.5,
+ display_radius=3,
+ display_cmap='bmy'):
self.path = model_path
self.cfg = None
self.model_type = model_type
+ self.tf_config = tf_config
self.precision = precision
self.cropping = cropping
self.dynamic = dynamic
self.dynamic_cropping = None
self.resize = resize
self.processor = processor
- self.display = Display(lik=display_lik) if display else None
+ self.convert2rgb = convert2rgb
+ self.display = Display(pcutoff=pcutoff, radius=display_radius, cmap=display_cmap) if display else None
self.sess = None
self.inputs = None
self.outputs = None
- self.graph = None
self.tflite_interpreter = None
self.pose = None
self.is_initialized = False
@@ -176,14 +193,13 @@ def process_frame(self, frame):
if self.resize != 1:
frame = utils.resize_frame(frame, self.resize)
-
- if frame.ndim == 2:
- frame = utils.gray_to_rgb(frame)
+ if self.convert2rgb:
+ frame = utils.img_to_rgb(frame)
return frame
- def init_inference(self, frame=None):
+ def init_inference(self, frame=None, **kwargs):
'''
Load model and perform inference on first frame -- the first inference is usually very slow.
@@ -209,15 +225,19 @@ def init_inference(self, frame=None):
if frame is None:
raise DLCLiveError("No image was passed to initialize inference. An image must be passed to the init_inference method")
- process_frame = self.process_frame(frame)
+ if frame.ndim == 2:
+ self.convert2rgb = True
+ frame = self.process_frame(frame)
### load model
if self.model_type == 'base':
graph_def = read_graph(model_file)
- self.graph, self.inputs = finalize_graph(graph_def)
- self.sess, self.outputs = extract_graph(self.graph)
+ graph = finalize_graph(graph_def)
+ self.sess, self.inputs, self.outputs = extract_graph(graph, tf_config=self.tf_config)
+
+ #self.sess, self.inputs, self.outputs = load_graph(model_file)
elif self.model_type == 'tflite':
@@ -228,13 +248,13 @@ def init_inference(self, frame=None):
# get input and output tensor names from graph_def
graph_def = read_graph(model_file)
- graph, _ = finalize_graph(graph_def)
+ graph = finalize_graph(graph_def)
output_nodes = get_output_nodes(graph)
- output_nodes = [on.replace('Placeholder_1/', '') for on in output_nodes]
+ output_nodes = [on.replace('DLC/', '') for on in output_nodes]
converter = tf.lite.TFLiteConverter.from_frozen_graph(model_file,
['Placeholder'],
output_nodes,
- input_shapes={'Placeholder' : [1, process_frame.shape[0], process_frame.shape[1], 3]})
+ input_shapes={'Placeholder' : [1, frame.shape[0], frame.shape[1], 3]})
try:
tflite_model = converter.convert()
except Exception:
@@ -251,9 +271,9 @@ def init_inference(self, frame=None):
elif self.model_type == 'tensorrt':
graph_def = read_graph(model_file)
- graph, _ = finalize_graph(graph_def)
+ graph = finalize_graph(graph_def)
output_tensors = get_output_tensors(graph)
- output_tensors = [ot.replace('Placeholder_1/', '') for ot in output_tensors]
+ output_tensors = [ot.replace('DLC/', '') for ot in output_tensors]
if (TFVER[0] > 1) | (TFVER[0]==1 & TFVER[1] >= 14):
converter = trt.TrtGraphConverter(input_graph_def=graph_def,
@@ -267,8 +287,8 @@ def init_inference(self, frame=None):
precision_mode=self.precision,
is_dynamic_op=True)
- self.graph, self.inputs = finalize_graph(graph_def)
- self.sess, self.outputs = extract_graph(self.graph)
+ graph = finalize_graph(graph_def)
+ self.sess, self.inputs, self.outputs = extract_graph(graph, tf_config=self.tf_config)
else:
@@ -276,14 +296,14 @@ def init_inference(self, frame=None):
### get pose of first frame (first inference is often very slow)
- pose = self.get_pose(frame)
+ pose = self.get_pose(frame, **kwargs)
self.is_initialized = True
return pose
- def get_pose(self, frame=None):
+ def get_pose(self, frame=None, **kwargs):
'''
Get the pose of an image
@@ -336,7 +356,8 @@ def get_pose(self, frame=None):
else:
self.pose = argmax_pose_predict(scmap, locref, self.cfg['stride'])
else:
- self.pose = pose_output[0]
+ pose = np.array(pose_output[0])
+ self.pose = pose[:, [1,0,2]]
# display image if display=True before correcting pose for cropping/resizing
@@ -359,7 +380,7 @@ def get_pose(self, frame=None):
# process the pose
if self.processor:
- self.pose = self.processor.process(self.pose)
+ self.pose = self.processor.process(self.pose, **kwargs)
return self.pose
@@ -371,3 +392,5 @@ def close(self):
self.sess.close()
self.sess = None
self.is_initialized = False
+ if self.display is not None:
+ self.display.destroy()
diff --git a/dlclive/graph.py b/dlclive/graph.py
index af43f8e..d2d16e3 100644
--- a/dlclive/graph.py
+++ b/dlclive/graph.py
@@ -20,15 +20,14 @@ def read_graph(file):
Returns
--------
- graph_def :class:`tensorflow.GraphDef`
+ graph_def :class:`tensorflow.tf.compat.v1.GraphDef`
The graph definition of the DeepLabCut model found at the object's path
'''
- graph = tf.Graph()
- graph_def = tf.GraphDef()
- graph_def.ParseFromString(tf.gfile.GFile(file, 'rb').read())
-
- return graph_def
+ with tf.io.gfile.GFile(file, 'rb') as f:
+ graph_def = tf.compat.v1.GraphDef()
+ graph_def.ParseFromString(f.read())
+ return graph_def
def finalize_graph(graph_def):
@@ -37,12 +36,12 @@ def finalize_graph(graph_def):
Parameters
-----------
- graph_def :class:`tensorflow.GraphDef`
+ graph_def :class:`tensorflow.compat.v1.GraphDef`
The graph of the DeepLabCut model, read using the :func:`read_graph` method
Returns
--------
- graph :class:`tensorflow.Graph`
+ graph :class:`tensorflow.compat.v1.GraphDef`
The finalized graph of the DeepLabCut model
inputs :class:`tensorflow.Tensor`
Input tensor(s) for the model
@@ -50,11 +49,10 @@ def finalize_graph(graph_def):
graph = tf.Graph()
with graph.as_default():
- inputs = tf.placeholder(tf.float32, shape=[1, None, None, 3])
- tf.import_graph_def(graph_def, {'Placeholder' : inputs}, name='Placeholder')
+ tf.import_graph_def(graph_def, name="DLC")
graph.finalize()
- return graph, inputs
+ return graph
def get_output_nodes(graph):
@@ -97,11 +95,17 @@ def get_output_tensors(graph):
'''
output_nodes = get_output_nodes(graph)
- output_tensor = [out+':0' for out in output_nodes]
+ output_tensor = [out+":0" for out in output_nodes]
return output_tensor
-def extract_graph(graph):
+def get_input_tensor(graph):
+
+ input_tensor = str(graph.get_operations()[0].name) + ":0"
+ return input_tensor
+
+
+def extract_graph(graph, tf_config=None):
'''
Initializes a tensorflow session with the specified graph and extracts the model's inputs and outputs
@@ -109,6 +113,7 @@ def extract_graph(graph):
-----------
graph :class:`tensorflow.Graph`
a tensorflow graph containing the desired model
+ tf_config :class:`tensorflow.ConfigProto`
Returns
--------
@@ -118,8 +123,10 @@ def extract_graph(graph):
the output tensor(s) for the model
'''
+ input_tensor = get_input_tensor(graph)
output_tensor = get_output_tensors(graph)
- sess = tf.Session(graph=graph)
+ sess = tf.compat.v1.Session(graph=graph, config=tf_config)
+ inputs = graph.get_tensor_by_name(input_tensor)
outputs = [graph.get_tensor_by_name(out) for out in output_tensor]
- return sess, outputs
+ return sess, inputs, outputs
diff --git a/dlclive/processor/__init__.py b/dlclive/processor/__init__.py
index 6b7b977..3e6b6b2 100644
--- a/dlclive/processor/__init__.py
+++ b/dlclive/processor/__init__.py
@@ -1 +1,2 @@
from dlclive.processor.processor import Processor
+from dlclive.processor.kalmanfilter import KalmanFilterPredictor
diff --git a/dlclive/processor/kalmanfilter.py b/dlclive/processor/kalmanfilter.py
new file mode 100644
index 0000000..c3ceaf1
--- /dev/null
+++ b/dlclive/processor/kalmanfilter.py
@@ -0,0 +1,146 @@
+"""
+DeepLabCut-live Toolbox (deeplabcut.org)
+Please see AUTHORS for contributors.
+Licensed under GNU Lesser General Public License v3.0
+"""
+
+
+import time
+import numpy as np
+from dlclive.processor import Processor
+
+
+class KalmanFilterPredictor(Processor):
+
+
+ def __init__(self,
+ adapt=True,
+ forward=0.002,
+ fps=30,
+ nderiv=2,
+ priors=[10, 1],
+ initial_var=5,
+ process_var=5,
+ dlc_var=10,
+ **kwargs):
+
+ super().__init__(**kwargs)
+
+ self.adapt=adapt
+ self.forward = forward
+ self.dt = 1.0 / fps
+ self.nderiv = nderiv
+ self.priors = np.hstack(([1e5], priors))
+ self.initial_var = initial_var
+ self.process_var = process_var
+ self.dlc_var = dlc_var
+ self.is_initialized = False
+ self.last_pose_time = 0
+
+
+ def _get_forward_model(self, dt):
+
+ F = np.zeros((self.n_states, self.n_states))
+ for d in range(self.nderiv+1):
+ for i in range(self.n_states - (d * self.bp * 2)):
+ F[i, i + (2 * self.bp * d)] = (dt ** d) / max(1, d)
+
+ return F
+
+
+ def _init_kf(self, pose):
+
+ # get number of body parts
+ self.bp = pose.shape[0]
+ self.n_states = self.bp * 2 * (self.nderiv+1)
+
+ # initialize state matrix, set position to first pose
+ self.X = np.zeros((self.n_states, 1))
+ self.X[:(self.bp * 2)] = pose[:, :2].reshape(self.bp * 2, 1)
+
+ # initialize covariance matrix, measurement noise and process noise
+ self.P = np.eye(self.n_states) * self.initial_var
+ self.R = np.eye(self.n_states) * self.dlc_var
+ self.Q = np.eye(self.n_states) * self.process_var
+
+ # initialize forward model:
+ self.F = self._get_forward_model(self.dt)
+
+ self.H = np.eye(self.n_states)
+ self.K = np.zeros((self.n_states, self.n_states))
+ self.I = np.eye(self.n_states)
+
+ # initialize priors for forward prediction step only
+ B = np.repeat(self.priors, self.bp * 2)
+ self.B = B.reshape(B.size, 1)
+
+ self.is_initialized = True
+
+
+ def _predict(self):
+
+ #self.F = self._get_forward_model(time.time()-self.last_pose_time)
+ self.Xp = np.dot(self.F, self.X)
+ self.Pp = np.dot(np.dot(self.F, self.P), self.F.T) + self.Q
+
+
+ def _get_residuals(self, pose):
+
+ z = np.zeros((self.n_states, 1))
+ z[:(self.bp * 2)] = pose[:self.bp, :2].reshape(self.bp * 2, 1)
+ for i in range(self.bp * 2, self.n_states):
+ z[i] = (z[i - (self.bp * 2)] - self.X[i - (self.bp * 2)]) / self.dt
+ self.y = z - np.dot(self.H, self.Xp)
+
+
+ def _update(self):
+
+ S = np.dot(self.H, np.dot(self.Pp, self.H.T)) + self.R
+ K = np.dot(np.dot(self.Pp, self.H.T), np.linalg.inv(S))
+ self.X = self.Xp + np.dot(K, self.y)
+ self.P = np.dot(self.I - np.dot(K, self.H), self.Pp)
+
+
+ def _get_future_pose(self, dt):
+
+ Ff = self._get_forward_model(dt)
+
+ Pf = np.diag(self.P).reshape(self.P.shape[0], 1)
+ Xf = (1 / ((1 / Pf) + (1 / self.B))) * (self.X / Pf)
+ Xfp = np.dot(Ff, Xf)
+
+ future_pose = Xfp[:(self.bp * 2)].reshape(self.bp, 2)
+
+ return future_pose
+
+
+ def _get_state_likelihood(self, pose):
+
+ liks = pose[:, 2]
+ liks_xy = np.repeat(liks, 2)
+ liks_xy_deriv = np.tile(liks_xy, self.nderiv)
+ liks_state = liks_xy_deriv.reshape(liks_xy_deriv.shape[0], 1)
+ return(liks_state)
+
+
+ def process(self, pose, **kwargs):
+
+ if not self.is_initialized:
+
+ self._init_kf(pose)
+
+ self.last_pose_time = time.time()
+ return pose
+
+ else:
+
+ self._predict()
+ self._get_residuals(pose)
+ self._update()
+
+ forward_time = (time.time() - kwargs['frame_time'] + self.forward) if self.adapt else self.forward
+ future_pose = self._get_future_pose(forward_time)
+ future_pose = np.hstack((future_pose, pose[:,2].reshape(self.bp,1)))
+
+ self.last_pose_time = time.time()
+ return future_pose
diff --git a/dlclive/processor/processor.py b/dlclive/processor/processor.py
index 6d26e91..a7d6e83 100644
--- a/dlclive/processor/processor.py
+++ b/dlclive/processor/processor.py
@@ -16,10 +16,10 @@
class Processor(object):
- def __init__(self):
+ def __init__(self, **kwargs):
pass
- def process(self, pose):
+ def process(self, pose, **kwargs):
return pose
def save(self, file=''):
diff --git a/dlclive/processor/teensy_laser/teensy_laser.py b/dlclive/processor/teensy_laser/teensy_laser.py
index 30f0479..b3ef888 100644
--- a/dlclive/processor/teensy_laser/teensy_laser.py
+++ b/dlclive/processor/teensy_laser/teensy_laser.py
@@ -8,7 +8,7 @@
Licensed under GNU Lesser General Public License v3.0
"""
-from ..processor import Processor
+from dlclive.processor.processor import Processor
import serial
import struct
import time
@@ -50,7 +50,7 @@ def stim_off(self):
self.stim_off_time.append(time.time())
- def process(self, pose):
+ def process(self, pose, **kwargs):
# define criteria to stimulate (e.g. if first point is in a corner of the video)
box = [[0,100],[0,100]]
diff --git a/dlclive/utils.py b/dlclive/utils.py
index c3d1389..244697f 100644
--- a/dlclive/utils.py
+++ b/dlclive/utils.py
@@ -55,10 +55,8 @@ def resize_frame(frame, resize=None):
an image as a numpy array
"""
-
if (resize is not None) and (resize != 1):
-
if OPEN_CV:
new_x = int(frame.shape[0] * resize)
@@ -75,10 +73,9 @@ def resize_frame(frame, resize=None):
return frame
+def img_to_rgb(frame):
+ """ Convert an image to RGB. Uses OpenCV is installed, otherwise uses pillow.
-def gray_to_rgb(frame):
- """ Convert an image from grayscale to RGB. Uses OpenCV is installed, otherwise uses pillow.
-
Parameters
----------
frame : :class:`numpy.ndarray
@@ -87,21 +84,58 @@ def gray_to_rgb(frame):
if frame.ndim == 2:
- if OPEN_CV:
+ return gray_to_rgb(frame)
- return cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
-
- else:
+ elif frame.ndim == 3:
- img = Image.fromarray(frame)
- img = img.convert('RGB')
- return np.asarray(img)
+ return bgr_to_rgb(frame)
else:
+ warnings.warn(f"Image has {frame.ndim} dimensions. Must be 2 or 3 dimensions to convert to RGB", DLCLiveWarning)
return frame
+def gray_to_rgb(frame):
+ """ Convert an image from grayscale to RGB. Uses OpenCV is installed, otherwise uses pillow.
+
+ Parameters
+ ----------
+ frame : :class:`numpy.ndarray
+ an image as a numpy array
+ """
+
+ if OPEN_CV:
+
+ return cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
+
+ else:
+
+ img = Image.fromarray(frame)
+ img = img.convert('RGB')
+ return np.asarray(img)
+
+
+def bgr_to_rgb(frame):
+ """ Convert an image from BGR to RGB. Uses OpenCV is installed, otherwise uses pillow.
+
+ Parameters
+ ----------
+ frame : :class:`numpy.ndarray
+ an image as a numpy array
+ """
+
+ if OPEN_CV:
+
+ return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+
+ else:
+
+ img = Image.fromarray(frame)
+ img = img.convert('RGB')
+ return np.asarray(img)
+
+
def _img_as_ubyte_np(frame):
""" Converts an image as a numpy array to unsinged 8-bit integer.
As in scikit-image img_as_ubyte, converts negative pixels to 0 and converts range to [0, 255]
diff --git a/setup.py b/setup.py
index 0c0abb7..64e9e56 100644
--- a/setup.py
+++ b/setup.py
@@ -15,7 +15,7 @@
with open("README.md", "r") as fh:
long_description = fh.read()
-install_requires = ['numpy', 'ruamel.yaml', 'colorcet', 'pillow', 'py-cpuinfo==5.0.0', 'tqdm']
+install_requires = ['numpy', 'ruamel.yaml', 'colorcet', 'pillow', 'py-cpuinfo==5.0.0', 'tqdm', 'pandas', 'tables']
if find_spec('cv2') is None:
install_requires.append('opencv-python')
@@ -41,5 +41,6 @@
"License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)",
"Operating System :: OS Independent",
),
- entry_points = {'console_scripts' : ['dlc-live-bench=dlclive.bench:main']}
+ entry_points = {'console_scripts' : ['dlc-live-bench=dlclive.bench:main',
+ 'dlc-live-benchmark=dlclive.benchmark:main']}
)