"""An abstract base class for trackers used in sksurgery"""
from abc import ABCMeta, abstractmethod
import numpy as np
from sksurgerycore.algorithms.tracking_smoothing import RollingMean, \
RollingMeanRotation, quaternion_to_matrix
class SKSBaseTracker(metaclass=ABCMeta):
    """Abstract base class for trackers used in sksurgery.

    Defines the methods that all trackers should implement, and keeps one
    set of smoothing buffers per port handle. Derived classes fill the
    buffers with add_frame_to_buffer and read them with get_smooth_frame.
    """

    def __init__(self, configuration=None, tracked_objects=None):
        """
        Reads the smoothing settings and creates the smoothing buffers.

        :param configuration: optional object with a dict style ``get``
            method. The keys read here are 'smoothing buffer' (default 1)
            and 'use quaternions' (default False).
        :param tracked_objects: optional iterable of objects that have a
            ``name`` attribute; each name is registered as a port handle
            and gets its own set of buffers.
        """
        self.buffer_size = 1
        self.use_quaternions = False
        if configuration is not None:
            self.buffer_size = configuration.get('smoothing buffer', 1)
            self.use_quaternions = configuration.get('use quaternions',
                                                     False)

        # Parallel lists: entry i of every list belongs to port_handles[i].
        self.port_handles = []
        self.time_stamps = []
        self.frame_numbers = []
        self.qualities = []
        self.rvec_rolling_means = []
        self.tvec_rolling_means = []

        if tracked_objects is not None:
            for tracked_object in tracked_objects:
                self._append_buffers(tracked_object.name, int)

    def _append_buffers(self, port_handle, frame_number_type):
        """
        Registers port_handle and appends a new set of buffers for it.

        :param port_handle: the handle to register.
        :param frame_number_type: datatype passed to the frame number
            buffer.
        :returns: the index of the new entry in the buffer lists.
        """
        self.port_handles.append(port_handle)
        self.time_stamps.append(RollingMean(1, self.buffer_size,
                                            datatype=float))
        # let's not average the frame numbers, set buffer = 1
        self.frame_numbers.append(RollingMean(1, 1,
                                              datatype=frame_number_type))
        self.qualities.append(RollingMean(1, self.buffer_size,
                                          datatype=float))
        self.rvec_rolling_means.append(
            RollingMeanRotation(self.buffer_size))
        self.tvec_rolling_means.append(RollingMean(3, self.buffer_size))
        return len(self.port_handles) - 1

    def get_smooth_frame(self, port_handles):
        """
        Called by derived classes to return smoothed data

        :param port_handles: a list of port handles to get data for
        :returns:
            port_numbers : list of port handles, one per tool

            time_stamps : list of timestamps (cpu clock), one per tool

            frame_numbers : list of framenumbers (tracker clock) one per
            tool

            tracking : list of 4x4 tracking matrices, rotation and
            position, or if use_quaternions is true, a list of 7x1
            arrays where rows 0-2 are the translation and rows 3-6 are
            the rotation as a quaternion.

            tracking_quality : list the tracking quality, one per tool.
        :raises ValueError: if a requested port handle has no buffers.
        """
        smth_handles = []
        smth_times = []
        smth_frame_nos = []
        smth_tracking = []
        smth_qual = []

        for port_handle in port_handles:
            # Only the lookup is guarded, so that a ValueError raised
            # further down (e.g. by a numpy shape mismatch) is not
            # misreported as a missing port handle.
            try:
                my_index = self.port_handles.index(port_handle)
            except ValueError as err:
                raise ValueError(str(port_handle) + " not found in tracking " +
                                 "buffers, did you call smooth_tracking " +
                                 "before add_frame?") from err

            smth_handles.append(port_handle)
            smth_times.append(self.time_stamps[my_index].getmean()[0])
            smth_frame_nos.append(self.frame_numbers[my_index].getmean()[0])
            mean_quat = self.rvec_rolling_means[my_index].getmean()
            mean_tvec = self.tvec_rolling_means[my_index].getmean()
            smth_qual.append(self.qualities[my_index].getmean()[0])

            if self.use_quaternions:
                # Translation stacked on top of the quaternion.
                output_matrix = np.full((7, 1), np.nan)
                output_matrix[0:3, 0] = mean_tvec
                output_matrix[3:7, 0] = mean_quat
            else:
                # Homogeneous transform: rotation block plus translation
                # column.
                output_matrix = np.identity(4, dtype=np.float64)
                output_matrix[0:3, 0:3] = quaternion_to_matrix(mean_quat)
                output_matrix[0:3, 3] = mean_tvec
            smth_tracking.append(output_matrix)

        return smth_handles, smth_times, smth_frame_nos, smth_tracking, \
            smth_qual

    def add_frame_to_buffer(self, port_handles, time_stamps, frame_numbers,
                            tracking_rot, tracking_trans, quality,
                            rot_is_quaternion=False):
        """
        Called by derived classes to add data to the smoothing buffers.
        A port handle that has no buffers yet gets a new set before its
        data is added.

        :param port_handles: a list of port handles
        :param time_stamps: list of timestamps (cpu clock), one per tool
        :param frame_numbers: list of framenumbers (tracker clock) one per
            tool
        :param tracking_rot: list of tracking rotations
        :param tracking_trans: list of tracking translations
        :param quality: list the tracking quality, one per tool.
        :param rot_is_quaternion: True if rotation is a quaternion.
        """
        for your_index, port_handle in enumerate(port_handles):
            try:
                my_index = self.port_handles.index(port_handle)
            except ValueError:
                # NOTE(review): handles first seen here get a float frame
                # number buffer, while those registered in __init__ get
                # an int one - confirm the difference is intended.
                my_index = self._append_buffers(port_handle, float)

            self.time_stamps[my_index].pop(time_stamps[your_index])
            self.frame_numbers[my_index].pop(frame_numbers[your_index])
            self.rvec_rolling_means[my_index].pop(tracking_rot[your_index],
                                                  rot_is_quaternion)
            self.tvec_rolling_means[my_index].pop(tracking_trans[your_index])
            self.qualities[my_index].pop(quality[your_index])

    @abstractmethod
    def close(self):
        """Closes the connection to the tracker"""
        raise NotImplementedError

    @abstractmethod
    def get_frame(self):
        """
        Gets a frame of tracking data; implemented by derived classes.

        :return:
            port_numbers : list of port handles, one per tool

            time_stamps : list of timestamps (cpu clock), one per tool

            frame_numbers : list of framenumbers (tracker clock) one per
            tool

            tracking : list of 4x4 tracking matrices, rotation and
            position, or if use_quaternions is true, a list of tracking
            quaternions (x, y, z followed by the rotation as a
            quaternion).

            tracking_quality : list the tracking quality, one per tool.
        """
        raise NotImplementedError

    @abstractmethod
    def start_tracking(self):
        """
        Tells the tracker to start tracking.
        """
        raise NotImplementedError

    @abstractmethod
    def stop_tracking(self):
        """
        Tells the tracker to stop tracking.
        """
        raise NotImplementedError