Source code for xmipp3.viewers.viewer_structure_map

# **************************************************************************
# *
# * Authors:     David Herreros Calero (dherreros@cnb.csic.es)
# *
# * Unidad de  Bioinformatica of Centro Nacional de Biotecnologia , CSIC
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 2 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307  USA
# *
# *  All comments concerning this program package may be sent to the
# *  e-mail address 'scipion@cnb.csic.es'
# *
# **************************************************************************

import os, matplotlib, math
from scipy import ndimage
from scipy.spatial import KDTree
import tkinter as tk
#matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.widgets import RadioButtons
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk
from mpl_toolkits.mplot3d.proj3d import proj_transform
from matplotlib.text import Annotation

import numpy as np
from scipy.ndimage.filters import gaussian_filter

from pyworkflow.viewer import DESKTOP_TKINTER, WEB_DJANGO, ProtocolViewer
from pyworkflow.gui.plotter import plt
import pyworkflow.protocol.params as params

from xmipp3.protocols.protocol_structure_map import XmippProtStructureMap
from xmipp3.protocols.protocol_structure_map_zernike3d import XmippProtStructureMapZernike3D


class XmippProtStructureMapViewer(ProtocolViewer):
    """ Viewer for the structure map protocols.

    Loads the coordinate file(s) written by the protocol and shows them in
    an interactive 3D / 2D matplotlib figure (see projectionPlot).
    """
    _label = 'viewer structure map'
    _environments = [DESKTOP_TKINTER, WEB_DJANGO]
    _targets = [XmippProtStructureMap, XmippProtStructureMapZernike3D]

    def _defineParams(self, form):
        """ Define the viewer form.

        The metric selector and the two-set switch are only offered when the
        protocol being visualized is a XmippProtStructureMapZernike3D.
        """
        form.addSection(label='Show StructMap')
        if isinstance(self.protocol, XmippProtStructureMapZernike3D):
            form.addParam('map', params.EnumParam,
                          choices=['Deformation', 'Correlation', 'Consensus'],
                          default=0,
                          help='Choose the type of metric to display the coordinates.')
            form.addParam('twoSets', params.BooleanParam, default=False,
                          label='Two set analysis',
                          help='Activate if the analysis of the deformation included the option of analysing '
                               'two sets of volumes independently.')
        form.addParam('doShowPlot', params.LabelParam,
                      label="Display the StructMap")

    def _getVisualizeDict(self):
        """ Map each form parameter to the method that displays it. """
        return {'doShowPlot': self._visualize}

    def getOutputFile(self):
        """ Return the list of coordinate files to be displayed.

        For the Zernike3D protocol the result depends on the chosen metric
        ('map') and on the 'twoSets' switch. The combination
        Consensus + two sets has no associated file, in which case ['']
        is returned (and reported as missing by _visualize).
        """
        protocol = self.protocol
        if not isinstance(protocol, XmippProtStructureMapZernike3D):
            return [protocol._defineResultsName(3)]

        metric = self.map.get()
        # Same truthiness test as `self.twoSets` / `not self.twoSets`.
        two_sets = bool(self.twoSets)
        subsets = ('Sub_1_', 'Sub_2_')

        fnOutput = ['']
        if metric == 0:
            if two_sets:
                fnOutput = [protocol._defineResultsName(3, sub) for sub in subsets]
            else:
                fnOutput = [protocol._defineResultsName(3)]
        elif metric == 1:
            if two_sets:
                fnOutput = [protocol._defineResultsName2(3, sub) for sub in subsets]
            else:
                fnOutput = [protocol._defineResultsName2(3)]
        elif metric == 2 and not two_sets:
            fnOutput = [protocol._defineResultsName3(3)]
        return fnOutput

    def _visualize(self, e=None):
        """ Load the coordinates and open the interactive plot.

        Returns a list with an error message when any expected result file
        is missing; otherwise returns the projectionPlot instance (after its
        blocking initializePlot call has finished).
        """
        fnOutput = self.getOutputFile()
        for path in fnOutput:
            if not os.path.exists(path):
                return [self.errorMessage('The necessary metadata was not produced\n'
                                          'Execute again the protocol\n',
                                          title='Missing result file')]

        # np.vstack requires a real sequence: generators are rejected by
        # current NumPy releases, so the arrays are collected in a list.
        self.coordinates = np.vstack([np.loadtxt(path) for path in fnOutput])

        labels = [v.getObjLabel() for v in self.protocol.inputVolumes.get()]
        if '' in labels:
            # At least one volume has no label: fall back to 1-based indices.
            labels = [str(idp) for idp in range(1, self.coordinates.shape[0] + 1)]

        # NOTE(review): the weights file is looked up in the viewer's own
        # extra path, not in the protocol's one - confirm this is intended.
        fnWeights = self._getExtraPath('weigths.txt')
        weights = np.loadtxt(fnWeights) if os.path.isfile(fnWeights) else None

        plot = projectionPlot(self.coordinates, labels, weights)
        plot.initializePlot()
        return plot

    def _validate(self):
        """ No validation is performed: always returns an empty list. """
        errors = []
        return errors
class projectionPlot(object):
    """ Interactive Tk/matplotlib figure for a set of 3D coordinates.

    The left axes show the points in 3D (joined by a minimum spanning tree
    when one can be computed); the right axes show the 2D projection of the
    points for the current 3D view as a scatter, contour or convolution plot.
    """

    def __init__(self, coords, labels, weights):
        """
        :param coords: array with one row per point and (at least) three
                       columns (x, y, z).
        :param labels: sequence of text labels, one per point.
        :param weights: optional array with one weight per point (or None),
                        used by the convolution plot.
        """
        self.coords = coords
        self.labels = labels
        self.weights = weights
        try:
            self.minimum_spanning_tree()
        except IndexError:
            # Without a tree the 3D view only shows the points.
            self.T = None
        self.proj_coords = None
        self.radio = None
        self.cb = None
        self.prevlabel = 'Scatter'
        self.root = tk.Tk()
        self.fig = plt.Figure(figsize=(10, 4))
        self.canvas = FigureCanvasTkAgg(self.fig, master=self.root)
        self.ax_3d = self.fig.add_subplot(121, projection="3d")
        self.ax_3d.set_title("3D view")
        self.ax_3d.set_axis_off()
        self.fig.canvas.mpl_connect('button_release_event', self.onRelease)
        self.ax_2d = self.fig.add_subplot(122)
        self.ax_2d.set_title("Projection from current 3D view")

    def minimum_spanning_tree(self):
        """ Compute the minimum spanning tree of the k-nearest-neighbour
        graph of the points and store it in self.T.

        Up to 10 neighbours are used per point; the number is capped to the
        number of points so that small data sets also get a tree. With fewer
        than two points self.T is set to None.
        """
        N = self.coords.shape[0]
        if N < 2:
            self.T = None
            return

        # Asking the KDTree for more neighbours than points returns
        # out-of-range indices, hence the cap.
        k = min(10, N)
        tree = KDTree(self.coords)
        distances, indices = tree.query(self.coords, k=k)

        nn_mat = np.zeros((N, N))
        for idn in range(N):
            nn_mat[idn, indices[idn]] += distances[idn].reshape(-1)
            nn_mat[idn, idn] = 0

        # Only non-zero distances become edges (u, v, weight).
        triples = [(int(u), int(v), float(nn_mat[u, v]))
                   for u, v in zip(*nn_mat.nonzero())]
        self.T = self.KruskalMST(triples, N)

    def KruskalMST(self, triples, N):
        """ Kruskal's algorithm over the edges in 'triples'.

        :param triples: iterable of (u, v, weight) edges.
        :param N: number of nodes (labelled 0..N-1).
        :return: list of [u, v, weight] edges. It has N - 1 entries when the
                 graph is connected; for a disconnected graph the minimum
                 spanning forest is returned instead of failing.
        """
        result = []
        parent = list(range(N))
        rank = [0] * N
        for u, v, w in sorted(triples, key=lambda item: item[2]):
            if len(result) >= N - 1:
                break
            x = self.find(parent, u)
            y = self.find(parent, v)
            if x != y:
                result.append([u, v, w])
                self.union(parent, rank, x, y)
        return result

    def find(self, parent, i):
        """ Return the root of the set containing node i.

        Iterative, so long parent chains cannot hit the recursion limit;
        'parent' is not modified.
        """
        while parent[i] != i:
            i = parent[i]
        return i

    def union(self, parent, rank, x, y):
        """ Merge (in place) the sets containing x and y, by rank. """
        xroot = self.find(parent, x)
        yroot = self.find(parent, y)
        if rank[xroot] < rank[yroot]:
            parent[xroot] = yroot
        elif rank[xroot] > rank[yroot]:
            parent[yroot] = xroot
        else:
            parent[yroot] = xroot
            rank[xroot] += 1

    def mst_3D(self):
        """ Draw the points, the tree edges and the labels in the 3D axes.

        Points are coloured and sized by their degree in the tree. When no
        tree is available all points get the same colour and size.
        """
        N = self.coords.shape[0]
        degree = [0] * N
        if self.T:
            # Iterate over the edges actually present instead of assuming
            # there are exactly N - 1 of them.
            for edge in self.T:
                u, v = edge[0], edge[1]
                degree[u] += 1
                degree[v] += 1
                x = np.array((self.coords[u][0], self.coords[v][0]))
                y = np.array((self.coords[u][1], self.coords[v][1]))
                z = np.array((self.coords[u][2], self.coords[v][2]))
                self.ax_3d.plot(x, y, z, c='black', alpha=0.5)
            edge_max = max(degree)
            colors = [plt.cm.plasma(val / edge_max) for val in degree]
        else:
            colors = [plt.cm.plasma(1) for _ in self.coords]

        for idn, row in enumerate(self.coords):
            self.ax_3d.scatter(row[0], row[1], row[2], c=[colors[idn]],
                               s=20 + 20 * degree[idn], edgecolors='k',
                               alpha=0.7)

        annotate3D(self.ax_3d, s=self.labels, xyz=self.coords, fontsize=10,
                   xytext=(-3, 3), textcoords='offset points',
                   ha='right', va='bottom')

    def projectMatrix(self, M, coords):
        """ Apply the matrix M to the points in homogeneous coordinates.

        :param M: projection matrix (as returned by Axes3D.get_proj()).
        :param coords: array with one point per row.
        :return: array with one transformed point per row; a 1.0 is appended
                 to every input point before multiplying by M.
        """
        coords = np.atleast_2d(np.asarray(coords, dtype=float))
        homogeneous = np.hstack([coords, np.ones((coords.shape[0], 1))])
        return homogeneous.dot(np.asarray(M).T)

    def onRelease(self, event):
        """ Mouse-release callback: when the 3D view was the target, update
        the 2D projection using the current view. """
        if event.inaxes == self.ax_3d:
            M = event.inaxes.get_proj()
            self.proj_coords = self.projectMatrix(M, self.coords)
            self.plotType(self.prevlabel)

    def plotScatter(self, x, y):
        """ Show the projected points as a scatter plot in the 2D axes. """
        self.ax_2d.clear()
        self.ax_2d.scatter(x, y, color="green")
        self.ax_2d.set_title("Projection Minimum Spanning Tree")
        self.fig.canvas.draw()

    def plotContour(self, x, y):
        """ Show a Gaussian density estimate of the projected points as a
        filled contour plot in the 2D axes.

        The density is evaluated on a 100 x 100 grid spanning the data range
        (plus a 0.1 margin) and normalized to a maximum of 1.
        """
        self.ax_2d.clear()
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)

        sigma = max(np.max(x) - np.min(x), np.max(y) - np.min(y)) / 50
        if sigma <= 0:
            # All points coincide: use the width of the margin instead so
            # the Gaussian stays well defined.
            sigma = 0.2 / 50

        xi = np.linspace(np.min(x) - 0.1, np.max(x) + 0.1, 100)
        # The y axis of the grid must follow the range of y (not of x).
        yi = np.linspace(np.min(y) - 0.1, np.max(y) + 0.1, 100)

        # The 2D Gaussian is separable, so the sum over the points is a
        # matrix product: z[r, c] = sum_d gy[r, d] * gx[c, d]. The constant
        # normalization factor is dropped because z is rescaled by its
        # maximum right after.
        gx = np.exp(-(xi[:, None] - x[None, :]) ** 2 / (2 * sigma ** 2))
        gy = np.exp(-(yi[:, None] - y[None, :]) ** 2 / (2 * sigma ** 2))
        z = gy.dot(gx.T)
        zMax = np.max(z)
        if zMax > 0:
            z = z / zMax

        self.ax_2d.contour(xi, yi, z, 15, linewidths=0.5, colors='k')
        cf = self.ax_2d.contourf(xi, yi, z, 15, cmap=plt.cm.jet)
        self.ax_2d.set_title("Projection Scatter Plot")
        cbaxes = self.fig.add_axes([0.92, 0.1, 0.01, 0.8])
        self.cb = self.fig.colorbar(mappable=cf, cax=cbaxes)
        self.cb.set_ticks([])
        self.fig.canvas.draw()

    @staticmethod
    def _addKernel(image, kernel, row, col, weight=1.0):
        """ Add weight * kernel to image (in place), centred at (row, col).

        The part of the (square, odd-sized) kernel falling outside the image
        is discarded.
        """
        half = kernel.shape[0] // 2
        r0, r1 = row - half, row + half + 1
        c0, c1 = col - half, col + half + 1
        rs, re = max(r0, 0), min(r1, image.shape[0])
        cs, ce = max(c0, 0), min(c1, image.shape[1])
        if rs >= re or cs >= ce:
            return
        image[rs:re, cs:ce] += weight * kernel[rs - r0:re - r0, cs - c0:ce - c0]

    def plotConvolution(self, x, y):
        """ Show the projected points as an image in which every point is
        replaced by a Gaussian kernel (scaled by its weight, if available).

        The image grid has a fixed step of 0.001 and spans twice the largest
        absolute coordinate in both directions.
        NOTE(review): the grid size grows with the coordinate magnitude, so
        large projected coordinates lead to very large images.
        """
        self.ax_2d.clear()
        Xr = np.round(np.stack((x, y), axis=1), decimals=3)
        n_points = Xr.shape[0]

        # Use the absolute value so that data with only non-positive
        # coordinates still produce a non-empty grid.
        size_grid = 2 * np.amax(np.abs(Xr)) if n_points else 0
        grid_coords = np.arange(-size_grid, size_grid, 0.001)
        if grid_coords.size == 0:
            # Nothing to draw (no points, or all of them at the origin).
            self.ax_2d.set_title('Projection Scatter Plot')
            self.fig.canvas.draw()
            return

        n_grid = grid_coords.size
        S = np.zeros((n_grid, n_grid))
        sigma = n_grid / (200 / 5)
        lbox = int(6 * sigma)
        if lbox % 2 == 0:
            lbox += 1
        half = lbox // 2
        kernel = np.zeros((lbox, lbox))
        kernel[half, half] = 1
        kernel = ndimage.gaussian_filter(kernel, sigma=sigma)

        # One weight per point; fall back to unit weights when no weights
        # were given or their number does not match the number of points.
        weights = np.ones(n_points)
        if self.weights is not None:
            candidate = np.ravel(np.asarray(self.weights, dtype=float))
            if candidate.size == n_points:
                weights = candidate

        for p in range(n_points):
            indx = int(np.argmin(np.abs(grid_coords - Xr[p, 0])))
            indy = int(np.argmin(np.abs(grid_coords - Xr[p, 1])))
            self._addKernel(S, kernel, indx, indy, weights[p])

        # Crop the empty border and orient the image like the other plots.
        S = S[~np.all(S == 0, axis=1)]
        S = S[:, ~np.all(S == 0, axis=0)]
        S = ndimage.rotate(S, 90)

        cf = self.ax_2d.imshow(S)
        cbaxes = self.fig.add_axes([0.92, 0.1, 0.01, 0.8])
        self.ax_2d.set_title('Projection Scatter Plot')
        self.cb = self.fig.colorbar(mappable=cf, cax=cbaxes)
        self.cb.set_ticks([])
        self.fig.canvas.draw()

    def plotType(self, label):
        """ Redraw the 2D axes with the representation named by 'label'
        ('Scatter', 'Contour' or 'Convolution'), removing any colorbar left
        by the previous representation. """
        self.prevlabel = label
        x = self.proj_coords[:, 0]
        y = self.proj_coords[:, 1]
        if self.cb is not None:
            self.cb.remove()
            self.cb = None
        if label == 'Scatter':
            self.plotScatter(x, y)
        elif label == 'Contour':
            self.plotContour(x, y)
        elif label == 'Convolution':
            self.plotConvolution(x, y)

    def initializePlot(self):
        """ Draw the initial 3D view and 2D scatter, create the radio
        buttons and the toolbar, and enter the Tk main loop (blocking). """
        self.mst_3D()
        M = self.ax_3d.get_proj()
        self.proj_coords = self.projectMatrix(M, self.coords)
        self.ax_2d.scatter(self.proj_coords[:, 0], self.proj_coords[:, 1],
                           color="green")

        # Buttons
        axcolor = 'silver'
        rax = self.fig.add_axes([0.01, 0.4, 0.12, 0.25], facecolor=axcolor)
        self.radio = RadioButtons(rax, ('Scatter', 'Contour', 'Convolution'))
        self.radio.on_clicked(self.plotType)
        self.canvas.get_tk_widget().pack(side=tk.BOTTOM, fill=tk.BOTH, expand=1)
        self.canvas._tkcanvas.pack(side=tk.BOTTOM, fill=tk.BOTH, expand=1)

        # Toolbar
        toolbar = NavigationToolbar2Tk(self.canvas, self.root)
        toolbar.update()
        self.canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=1)
        tk.mainloop()
class Annotation3D(Annotation):
    """Annotate every 3D point in xyz with the matching text in s."""

    def __init__(self, s, xyz, *args, **kwargs):
        """
        :param s: sequence of texts, one per point.
        :param xyz: sequence of points; the first three components of each
                    one are used as (x, y, z).
        Remaining arguments are forwarded to matplotlib's Annotation.
        """
        Annotation.__init__(self, "", xy=(0, 0),
                            bbox=dict(boxstyle="round,pad=0.3",
                                      fc="whitesmoke", ec="lightgray", lw=2),
                            *args, **kwargs)
        self._verts3d = xyz
        self.s = s

    def draw(self, renderer):
        """Project every point with the current 3D view and draw its text."""
        # Newer Matplotlib releases no longer attach the projection matrix to
        # the renderer, so take it from the owning 3D axes when available and
        # only fall back to the renderer attribute otherwise.
        M = getattr(self.axes, 'M', None)
        if M is None:
            M = renderer.M
        for coord, text in zip(self._verts3d, self.s):
            xs, ys, zs = proj_transform(coord[0], coord[1], coord[2], M)
            self.xy = (xs, ys)
            Annotation.set_text(self, text)
            Annotation.draw(self, renderer)
def annotate3D(ax, s, *args, **kwargs):
    """Add the annotation texts s to the Axes3D ax.

    All extra arguments are forwarded to Annotation3D.
    """
    ax.add_artist(Annotation3D(s, *args, **kwargs))