# **************************************************************************
# *
# * Authors: David Herreros Calero (dherreros@cnb.csic.es)
# *
# *
# * Unidad de Bioinformatica of Centro Nacional de Biotecnologia , CSIC
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 3 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307 USA
# *
# * All comments concerning this program package may be sent to the
# * e-mail address 'scipion@cnb.csic.es'
# *
# **************************************************************************
import numpy as np
import tkinter as tk
from matplotlib.widgets import RadioButtons, Slider
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk
from pyworkflow.gui.plotter import plt
from ...emlib.image import ImageHandler
from .callbacks import DraggablePoint
class MaskVolumeWizard(object):
    """Create a mask for a volume interactively.

    The volume is thresholded into a point cloud, downsampled and displayed
    in a 3D matplotlib axes embedded in a Tk window. The user places and
    sizes the mask on top of that cloud.

    Masks currently implemented:
        - Spherical mask
    """

    # Voxel size used for downsampling until the 'Downsampling' slider exists.
    DEFAULT_VOXEL_SIZE = 5.01
    # Points whose mean (signed) z-score exceeds this value are discarded.
    OUTLIER_Z_THRESHOLD = 2

    def __init__(self, filename):
        """Read the volume stored in *filename* and create the Tk/matplotlib
        window. Nothing is drawn until `initializePlot` is called.
        """
        volume = ImageHandler().read(filename).getData()
        self.volume = np.squeeze(np.copy(volume))
        self.coords = None
        self.coordsDownsampled = None
        self.xmippOrigin = np.array((self.volume.shape[0] / 2,
                                     self.volume.shape[1] / 2,
                                     self.volume.shape[2] / 2))
        self.origin = np.array((0, 0, 0))  # (z,y,x)
        self.pressed = False
        self.radio = None
        self.cb = None
        self.sphere_artist = None
        self.radius = 0
        self.running = True

        # Widgets and handles created later (plotScatter / initializePlot).
        self.dr = None
        self.M = None
        self.slider = None
        self.slider_thr = None
        self.slider_radius = None
        self._release_cid = None

        self.root = tk.Tk()
        self.root.resizable(False, False)  # For matplotlib <= 3.3.x
        self.fig = plt.Figure(figsize=plt.figaspect(1) * 1.5)
        self.canvas = FigureCanvasTkAgg(self.fig, master=self.root)
        self.ax_3d = self.fig.add_subplot(projection='3d')

    def get_sphere_params(self, mass_center=True):
        """Return x,y,z,r array. Two possible conventions are implemented:
            - Bottom Left Corner (All Coordinates positive)
            - Center of Mass (Coordinates can be positive or negative) -- Default
        """
        if mass_center:
            # Center of mass convention - Xmipp convention
            # (coordinates can be positive or negative)
            origin = self.origin
        else:
            # Bottom Left convention (all coordinates positive)
            origin = self.origin + self.xmippOrigin.astype(int)
        # self.origin is stored as (z,y,x); the result is ordered x,y,z,r.
        return np.hstack([origin[2],
                          origin[1],
                          origin[0],
                          self.radius])

    def is_window_closed(self):
        """WM_DELETE_WINDOW handler: ask the GUI loop to stop."""
        self.running = False

    def set_axes_equal(self, ax: "plt.Axes"):
        """Set 3D plot axes to equal scale.

        Make axes of 3D plot have equal scale so that spheres appear as
        spheres and cubes as cubes. Required since `ax.axis('equal')`
        and `ax.set_aspect('equal')` don't work on 3D.
        """
        limits = np.array([
            ax.get_xlim3d(),
            ax.get_ylim3d(),
            ax.get_zlim3d(),
        ])
        origin = np.mean(limits, axis=1)
        radius = 0.5 * np.max(np.abs(limits[:, 1] - limits[:, 0]))
        self._set_axes_radius(ax, origin, radius)

    def _set_axes_radius(self, ax, origin, radius):
        """Centre the three axes of *ax* on *origin* with half-width *radius*."""
        x, y, z = origin
        ax.set_xlim3d([x - radius, x + radius])
        ax.set_ylim3d([y - radius, y + radius])
        ax.set_zlim3d([z - radius, z + radius])

    def _view_angles(self):
        """Return the current view of the 3D axes as [-azim, elev, 0] in radians."""
        return [-self.ax_3d.azim * np.pi / 180, self.ax_3d.elev * np.pi / 180, 0]

    def plotScatter(self):
        """Redraw the downsampled point cloud, the origin marker and the sphere."""
        self.ax_3d.clear()
        # clear() dropped every artist from the axes, so the stored sphere
        # handle is stale: forget it instead of trying to remove it again.
        self.sphere_artist = None

        cloud = self.coordsDownsampled
        plt.ion()
        self.ax_3d.scatter(cloud[:, 0], cloud[:, 1], cloud[:, 2],
                           s=12, c='purple', edgecolors='k', alpha=0.3)
        self.M = self._view_angles()
        scatter_origin = self.ax_3d.scatter(self.origin[0], self.origin[1], self.origin[2],
                                            s=100, c='cyan', edgecolors='k')
        self.plot_sphere(self.radius)
        self.dr = DraggablePoint(self.origin, self.fig, self.ax_3d, scatter_origin, self.M)
        self.ax_3d.set_axis_off()
        # self.ax_3d.set_box_aspect([1, 1, 1])  # For matplotlib >= 3.3.x
        self.set_axes_equal(self.ax_3d)

    def plot_sphere(self, radius):
        """Store *radius* and (re)draw the mask sphere centred on the origin."""
        self.radius = radius
        u = np.linspace(0, 2 * np.pi, 100)
        v = np.linspace(0, np.pi, 100)
        x = radius * np.outer(np.cos(u), np.sin(v)) + self.origin[0]
        y = radius * np.outer(np.sin(u), np.sin(v)) + self.origin[1]
        z = radius * np.outer(np.ones(np.size(u)), np.cos(v)) + self.origin[2]
        if self.sphere_artist is not None:
            # Only one sphere is shown at a time.
            self.sphere_artist.remove()
        self.sphere_artist = self.ax_3d.plot_surface(x, y, z, rstride=4, cstride=4, color='r',
                                                     linewidth=0, alpha=0.1)

    def changeThreshold(self, threshold):
        """Select the voxels with value >= *threshold* and redraw the cloud.

        The cloud is downsampled with the current value of the
        'Downsampling' slider, or with DEFAULT_VOXEL_SIZE while that slider
        has not been created yet.
        """
        self.coords = np.argwhere(self.volume >= threshold)
        self.coordsDownsampled = np.copy(self.coords)
        # Check the slider that is actually read (the original tested
        # 'slider_thr', which is created before 'slider').
        slider = getattr(self, "slider", None)
        voxel_size = slider.val if slider is not None else self.DEFAULT_VOXEL_SIZE
        self.downsamplingPC(voxel_size)

    @staticmethod
    def _voxel_barycenters(coords, voxel_size):
        """Return the barycenter of the points falling in each occupied voxel.

        *coords* is an (N, D) array of point coordinates that is binned on a
        regular grid of spacing *voxel_size* anchored at its per-axis
        minimum. The result is an (M, D) float array with one row per
        non-empty voxel, in the sorted order produced by `np.unique`.
        """
        coords = np.asarray(coords)
        if coords.shape[0] == 0:
            n_dims = coords.shape[1] if coords.ndim == 2 else 3
            return np.empty((0, n_dims), dtype=float)

        voxel_keys = ((coords - np.min(coords, axis=0)) // voxel_size).astype(int)
        _, inverse, counts = np.unique(voxel_keys, axis=0,
                                       return_inverse=True,
                                       return_counts=True)
        # Some NumPy 2.x releases return a non-1D inverse when axis is given.
        inverse = inverse.ravel()

        sums = np.zeros((counts.size, coords.shape[1]), dtype=float)
        np.add.at(sums, inverse, coords)
        return sums / counts[:, None]

    def downsamplingPC(self, voxel_size):
        """Downsample the thresholded point cloud and redraw it.

        Points are replaced by the barycenter of the voxel of size
        *voxel_size* they fall in, outliers are removed and the result is
        shifted by `xmippOrigin` before plotting.
        """
        self.coordsDownsampled = self._voxel_barycenters(self.coords, voxel_size)
        self.remove_outliers()
        self.coordsDownsampled = self.coordsDownsampled - self.xmippOrigin
        self.plotScatter()

    def remove_outliers(self):
        """Drop the points whose mean z-score is above OUTLIER_Z_THRESHOLD.

        The z-score of a point is the mean over the axes of its signed,
        standardised coordinates.
        """
        points = np.asarray(self.coordsDownsampled)
        if points.shape[0] == 0:
            return
        mean_coords = np.mean(points, axis=0)
        std_coords = np.std(points, axis=0)
        # An axis with no spread has zero deviation everywhere; dividing by
        # 1 instead of 0 avoids NaN scores that would discard every point.
        std_coords = np.where(std_coords > 0, std_coords, 1.0)
        z_score = np.mean((points - mean_coords) / std_coords, axis=1)
        # A boolean mask keeps the result 2D even when one point survives.
        self.coordsDownsampled = points[z_score <= self.OUTLIER_Z_THRESHOLD]

    def _connect_release_handler(self):
        """(Re)register `on_release`, keeping a single live connection."""
        if getattr(self, "_release_cid", None) is not None:
            self.canvas.mpl_disconnect(self._release_cid)
        self._release_cid = self.canvas.mpl_connect('button_release_event', self.on_release)

    def press_shift(self, event):
        """Enable dragging of the origin marker while shift is held."""
        # Keyboard auto-repeat sends many press events for a single hold;
        # connect the draggable point only on the first one.
        if event.key == 'shift' and not self.pressed:
            self.pressed = True
            self.dr.connect()

    def release_shift(self, event):
        """Stop dragging, adopt the new origin and redraw the sphere there."""
        if self.pressed and event.key == 'shift':
            self.pressed = False
            self.dr.disconnect()
            self._connect_release_handler()
            self.origin = self.dr.point
            self.plot_sphere(self.radius)
            self.fig.canvas.draw()

    def on_release(self, event):
        """Propagate the current view angles to the draggable point."""
        self.M = self._view_angles()
        self.dr.M = self.M

    def change_view(self, event):
        """Point the camera along the axis named by *event* ('X', 'Y' or 'Z').

        Any other label only triggers a redraw.
        """
        # label -> (elev, azim, view angles handed to the draggable point)
        views = {
            "X": (90., 0., [0, np.pi / 2., 0]),
            "Y": (0., 90., [-np.pi / 2., 0., 0]),
            "Z": (0., 0., [-0., 0., 0]),
        }
        if event in views:
            elev, azim, angles = views[event]
            self.ax_3d.view_init(elev=elev, azim=azim)
            self.M = angles
            self.dr.M = self.M
        self.fig.canvas.draw()

    def initializePlot(self):
        """Build the widgets, draw the initial scene and run the GUI loop.

        Blocks until the window is closed, then destroys the Tk root.
        """
        vol_min = np.amin(self.volume)
        vol_max = np.amax(self.volume)
        vol_range = vol_max - vol_min
        # Midpoint of the value range. The original used 0.5 * range, which
        # leaves [min, max] (and selects no voxel) for negative minima.
        initial_threshold = vol_min + 0.5 * vol_range

        self.changeThreshold(initial_threshold)
        self.change_view("Z")

        # Buttons
        axcolor = 'grey'
        # rax = self.fig.add_axes([0.1, 0.4, 0.12, 0.25], facecolor=axcolor)  # For matplotlib >= 3.3.x
        rax = self.fig.add_axes([0.05, 0.5, 0.12, 0.15], facecolor=axcolor)  # For matplotlib <= 3.3.x
        self.radio = RadioButtons(rax, ('X', 'Y', 'Z'), activecolor='navy', active=2)
        self.radio.on_clicked(self.change_view)
        self.canvas.get_tk_widget().pack(side=tk.BOTTOM, fill=tk.BOTH, expand=1)

        # Sliders
        stax = self.fig.add_axes([0.2, 0.02, 0.65, 0.03], facecolor=axcolor)
        self.slider_thr = Slider(stax, 'Threshold', vol_min, vol_max,
                                 valinit=initial_threshold,
                                 # A constant volume has no range to step through.
                                 valstep=(0.01 * vol_range) or None,
                                 color='navy')
        self.slider_thr.on_changed(self.changeThreshold)

        srax = self.fig.add_axes([0.2, 0.06, 0.65, 0.03], facecolor=axcolor)
        max_radius = self.volume.shape[0]
        self.slider_radius = Slider(srax, 'Radius', 0, max_radius, valinit=0, valstep=1, color='navy')
        self.slider_radius.on_changed(self.plot_sphere)

        sax = self.fig.add_axes([0.2, 0.11, 0.65, 0.03], facecolor=axcolor)
        self.slider = Slider(sax, 'Downsampling', 0.01, 10, valinit=self.DEFAULT_VOXEL_SIZE,
                             valstep=0.2, color='navy')
        self.slider.on_changed(self.downsamplingPC)

        # Toolbar
        toolbar = NavigationToolbar2Tk(self.canvas, self.root)
        toolbar.update()
        # The canvas widget is re-packed after the toolbar exists; the calls
        # are kept in their original order so the layout is unchanged.
        self.canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=1)
        self.canvas.get_tk_widget().pack(side=tk.BOTTOM, fill=tk.BOTH, expand=1)

        self.root.protocol("WM_DELETE_WINDOW", self.is_window_closed)
        self.canvas.mpl_connect('key_press_event', self.press_shift)
        self.canvas.mpl_connect('key_release_event', self.release_shift)
        self._connect_release_handler()

        # GUI Running Loop: pump Tk events by hand until the window is closed.
        while self.running:
            self.root.update_idletasks()
            self.root.update()
        self.root.destroy()