Source code for xmipp3.protocols.protocol_classes_2d_mapping

# ******************************************************************************
# *
# * Authors:     David Herreros Calero (dherreros@cnb.csic.es)
# *
# * Unidad de  Bioinformatica of Centro Nacional de Biotecnologia , CSIC
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 2 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307  USA
# *
# *  All comments concerning this program package may be sent to the
# *  e-mail address 'scipion@cnb.csic.es'
# *
# ******************************************************************************


import os.path
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from matplotlib.widgets import Slider, RectangleSelector, Button
from matplotlib.cm import get_cmap, ScalarMappable

from pyworkflow.object import Float
from pyworkflow.utils.properties import Message
from pyworkflow.gui.dialog import askYesNo
import pyworkflow.protocol.params as param

from pwem.emlib.image import ImageHandler
import pwem.emlib.metadata as md
from pwem.protocols import ProtAnalysis2D
from pwem.objects import SetOfClasses2D

from xmipp3.convert import xmippToLocation, writeSetOfClasses2D


class XmippProtCL2DMap(ProtAnalysis2D):
    """ Create a low dimensional mapping from a SetOfClasses2D with interactive selection of
    classes. Use mouse left-click to select/deselect classes individually or mouse right-click
    and drag to select/deselect several classes.

    AI Generated:

    This protocol helps you explore and curate your 2D classes visually. It takes a set of 2D
    class averages and places them in a 2D map where similar classes appear close to each
    other. You can then interactively select the classes you want to keep and export them as
    a new set.

    Think of it as a visual “landscape” of your 2D classes that helps you:
    - Identify groups of similar views
    - Detect outliers or junk classes
    - Separate different conformations or compositions
    - Select coherent subsets for downstream processing

    When should you use this?
    You typically use this protocol after running 2D classification (e.g., CL2D, Relion 2D,
    etc.) when:
    - You have many classes and it’s hard to interpret them one by one.
    - You suspect the presence of heterogeneous conformations.
    - You want to quickly identify junk clusters or rare views.
    - You want to define a clean subset of classes for 3D refinement.
    Instead of manually browsing dozens or hundreds of class averages in a grid, this protocol
    shows them in a meaningful spatial organization.

    What does the map represent?
    Each point in the map corresponds to one 2D class average. Classes that look similar (same
    view, similar signal) tend to appear near each other. Classes that are very different
    (different view, contamination, broken particles) tend to be far apart. The map can
    optionally color classes according to their occupancy (how many particles they contain).
    This makes it much easier to see structure in your classification results.

    How to use the interactive selection
    If interactive mode is enabled (default):
    - A window opens showing the 2D map.
    - Each class average appears as a small image at its position.
    You can:
    - Left-click on a class to select or deselect it.
    - Right-click and drag to select/deselect multiple classes at once.
    - Use the zoom slider to change thumbnail size.
    - Use the Select all / Select none buttons.
    Selected classes are highlighted in green. When you close the window, you will be asked
    whether you want to save the selection. If you confirm, a new set of 2D classes is created
    containing only the selected classes. If interactive selection is disabled, all classes are
    selected automatically.

    Choosing the mapping method
    You can choose different dimensionality reduction methods (PCA, Diffusion Maps, etc.) and
    distance metrics (Euclidean or Correlation). For most users, PCA + Correlation is a good
    default. If the map looks messy or unstructured, you can try another method (e.g.,
    Diffusion Maps) to see if separation improves. You don’t need to understand the
    mathematics to use this effectively; just check whether the map produces meaningful
    grouping.

    What do you get as output?
    You get a new SetOfClasses2D containing only the classes you selected. This new set can be
    used for:
    - Exporting selected classes
    - Creating a clean subset of particles
    - Feeding into 3D reconstruction
    - Further classification steps
    The original classes remain untouched.

    Practical workflow example
    Run 2D classification, then use this protocol to map the classes. In the map:
    - Select the large, coherent cluster of good-looking classes.
    - Deselect obvious junk or tiny, isolated classes.
    - Export the selection.
    - Use the selected classes (or their particles) for 3D refinement.
    This approach is especially powerful when your dataset contains subtle heterogeneity that
    is not obvious from simple grid browsing.

    Why this is useful
    Manual selection of classes can be subjective and tedious. The 2D map:
    - Gives a global overview of your classification.
    - Makes outliers immediately visible.
    - Helps detect continuous variability.
    - Speeds up dataset cleaning.
    In short, it turns class inspection from a linear browsing task into a structured,
    intuitive visual decision process.
    """
    _label = '2D classes mapping'
    red_methods = ['PCA', 'LTSA', 'DM', 'LLTSA', 'LPP', 'kPCA', 'pPCA', 'LE', 'HLLE', 'SPE', 'NPE']
    distances = ['Euclidean', 'Correlation']
    CLASSES_MD = 'classes.xmd'

    def __init__(self, **args):
        ProtAnalysis2D.__init__(self, **args)

    # --------------------------- DEFINE param functions ------------------------
    def _defineParams(self, form):
        form.addSection(label='Input')
        form.addParam('inputClasses', param.PointerParam, label="Input 2D classes",
                      important=True, pointerClass='SetOfClasses2D',
                      help='Select the input classes to be mapped.')
        form.addParam('intSel', param.BooleanParam, default=True,
                      label="Interactive class selection?")
        form.addSection(label='Mapping')
        form.addParam('method', param.EnumParam, choices=self.red_methods,
                      default=0, label='Dimension reduction method')
        form.addParam('distance', param.EnumParam, choices=self.distances,
                      default=1, label='Distance metric to compare images')

    # --------------------------- INSERT steps functions ------------------------
    def _insertAllSteps(self):
        self._insertFunctionStep('convertStep')
        self._insertFunctionStep('computeMappingStep')
        self._insertFunctionStep('interactiveSelStep', interactive=True)

    # --------------------------- STEPS functions -------------------------------
    def convertStep(self):
        """Write the input classes (including their particles) to the
        classes metadata file in the extra folder."""
        metadata = self._getExtraPath(self.CLASSES_MD)
        writeSetOfClasses2D(self.inputClasses.get(), metadata, writeParticles=True)

    def computeMappingStep(self):
        """Run xmipp_transform_dimred on the 'classes' block of the metadata.

        The resulting coordinates are read back later from the MDL_DIMRED
        column of the same file."""
        metadata = self._getExtraPath(self.CLASSES_MD)
        params = dict(metadata=metadata,
                      method=self.red_methods[self.method.get()],
                      metric=self.distances[self.distance.get()])
        args = '-i classes@{metadata} -m {method} --distance {metric}'.format(**params)
        self.runJob("xmipp_transform_dimred", args)

    def interactiveSelStep(self):
        """Select classes (interactively on the 2D map, or all of them when
        interactive selection is disabled) and create the output set.

        In interactive mode, the chosen ids are stored in the extra folder so
        that re-running the step starts from the previous selection."""
        selection_file = self._getExtraPath('selected_ids.txt')
        metadata = self._getExtraPath(self.CLASSES_MD)
        self.classes = self.inputClasses.get()
        self.mdOut = md.MetaData(metadata)

        if not self.intSel.get():
            # Non-interactive mode: every class in the metadata is kept
            self.selection = [self.mdOut.getValue(md.MDL_REF, row.getObjId())
                              for row in md.iterRows(self.mdOut)]
            self._createOutputStep()
            return

        # Order-preserving de-duplication: np.unique would sort the names,
        # which can break the correspondence with the metadata row order
        # when the representatives live in several files.
        img_paths = list(dict.fromkeys(rep.getFileName()
                                       for rep in self.classes.iterRepresentatives()))
        pos = np.vstack(self.mdOut.getColumnValues(md.MDL_DIMRED))
        img_ids = self.mdOut.getColumnValues(md.MDL_REF)
        occupancy = np.asarray(self.mdOut.getColumnValues(md.MDL_CLASS_COUNT))

        # Read selected ids from previous runs (if they exist)
        if os.path.isfile(selection_file):
            # atleast_1d turns the 0-d array of a single saved id into a 1-element array
            previous = np.atleast_1d(np.loadtxt(selection_file))
            self.selection = previous.astype(int).tolist()
        else:
            self.selection = None

        view = ScatterImageMarker(pos=pos, img_paths=img_paths, ids=img_ids,
                                  occupancy=occupancy, prevsel=self.selection)
        view.initializePlot()
        self.selection = view.selected_ids

        # Save selected ids for interactive mode (ids are integers)
        np.savetxt(selection_file, np.asarray(self.selection, dtype=int), fmt='%d')

        if askYesNo(Message.TITLE_SAVE_OUTPUT, Message.LABEL_SAVE_OUTPUT, None):
            self._createOutputStep()

    def _createOutputStep(self):
        """Create a new output SetOfClasses2D holding only the selected classes."""
        self._loadClassesInfo()
        suffix = self._getOutputSuffix()
        selected_classes = self._createSetOfClasses2D(self.classes.getImages(), suffix=suffix)
        selected_classes.copyInfo(self.classes)
        selected_classes.appendFromClasses(self.classes,
                                           updateClassCallback=self._updateClass,
                                           filterClassFunc=self.isSelected)
        result = {'selectedClasses2D_' + suffix: selected_classes}
        self._defineOutputs(**result)
        self._defineSourceRelation(self.inputClasses, selected_classes)

    def isSelected(self, cls):
        """Return True if the class id is part of the current selection."""
        return cls.getObjId() in self.selection

    def _updateClass(self, item):
        """Point the class representative to the image stored in the
        metadata and attach its 2D map coordinates (_c_x, _c_y)."""
        classId = item.getObjId()
        index, fn, _, c = self._classesInfo[classId]
        item.setAlignment2D()
        rep = item.getRepresentative()
        rep.setLocation(index, fn)
        rep.setSamplingRate(self.inputClasses.get().getSamplingRate())
        rep._c_x = Float(c[0])
        rep._c_y = Float(c[1])

    def _loadClassesInfo(self):
        """ Read some information about the produced 2D classes
        from the metadata file.
        """
        self._classesInfo = {}  # store classes info, indexed by class id
        for classNumber, row in enumerate(md.iterRows(self.mdOut)):
            index, fn = xmippToLocation(row.getValue(md.MDL_IMAGE))
            c = row.getValue(md.MDL_DIMRED)
            # Key by the class id kept in MDL_REF (the same id compared with
            # getObjId() elsewhere), so non-consecutive class ids work. Fall
            # back to the 1-based row position if the label is absent.
            ref = row.getValue(md.MDL_REF)
            classId = ref if ref is not None else classNumber + 1
            # Store info indexed by id, we need to store the row.clone() since
            # the same reference is used for iteration
            self._classesInfo[classId] = (index, fn, row.clone(), c)

    # --------------------------- INFO functions --------------------------------
    def _summary(self):
        summary = []
        if self.getOutputsSize() >= 1:
            totalClasses = self.inputClasses.get().getSize()
            for key, outClasses in self.iterOutputAttributes():
                summary.append("*Output %s*" % key.split('_')[-1])
                summary.append("A total of %d classes out of %d were selected"
                               % (outClasses.getSize(), totalClasses))
        else:
            summary.append("No classes selected yet.")
        return summary

    def _methods(self):
        methods = []
        methods.append("*Dimensionality reduction method:* %s" % self.red_methods[self.method.get()])
        methods.append("*Classes comparison metric:* %s" % self.distances[self.distance.get()])
        return methods

    # --------------------------- UTILS functions -------------------------------
    def _getOutputSuffix(self):
        """Return the numeric suffix for the next 'selectedClasses2D_' output:
        one more than the highest existing suffix, or '1' if there is none."""
        maxCounter = -1
        for attrName, _ in self.iterOutputAttributes(SetOfClasses2D):
            suffix = attrName.replace('selectedClasses2D_', '')
            try:
                counter = int(suffix)
            except ValueError:
                counter = 1  # when there is not number assume 1
            maxCounter = max(counter, maxCounter)

        return str(maxCounter + 1) if maxCounter > 0 else '1'  # '1' when there are no outputs
class ScatterImageMarker(object):
    """Interactive 2D scatter plot of class averages with a mouse-driven selector.

    Every class average is drawn as a thumbnail at its mapped position.
    Left-clicking a thumbnail toggles its selection. Dragging a rectangle with
    the right mouse button toggles every class inside it. A slider rescales the
    thumbnails, and two buttons select all / none. The ids of the selected
    classes are available in ``selected_ids``.
    """

    def __init__(self, img_paths, pos, ids, occupancy=None, prevsel=None):
        """Read the images and prepare the (still empty) figure.

        :param img_paths: image file(s) with the class averages; a single
            path is read as one stack, several paths as one image each.
        :param pos: array whose columns 0 and 1 hold the x/y position of
            each class.
        :param ids: class ids, in the same order as ``pos``.
        :param occupancy: optional per-class particle counts; they are
            normalized by their sum and mapped to patch edge colors.
        :param prevsel: optional list of ids selected beforehand; this same
            list object is updated when classes are toggled individually.
        """
        self.running = True
        self.ids = ids
        self.occupancy = occupancy / np.sum(occupancy) if occupancy is not None else None
        self.selected_ids = prevsel if prevsel is not None else []
        self.pos = pos
        self.readImages(img_paths)
        # Initial thumbnail scale: the largest image side becomes ~30 display units
        self.zoom = 30 / np.amax(self.images[0].shape)
        self.pad = 0.6
        self.artists = []

        with plt.style.context(self._gridStyle()):
            self.fig, self.ax = plt.subplots(figsize=(10, 8))
        self.fig.patch.set_facecolor('whitesmoke')

        # Symmetric limits with a 50% margin around the farthest point
        lim_x = 1.5 * np.amax(np.abs(self.pos[:, 0]))
        lim_y = 1.5 * np.amax(np.abs(self.pos[:, 1]))
        self.ax.set_xlim(-lim_x, lim_x)
        self.ax.set_ylim(-lim_y, lim_y)
        plt.title('Interactive class selector', fontweight="bold", fontsize=15)
        plt.setp(self.ax.get_yticklabels(), fontweight="bold")
        plt.setp(self.ax.get_xticklabels(), fontweight="bold")
        plt.rcParams["font.weight"] = "bold"

    @staticmethod
    def _gridStyle():
        """Return the seaborn dark-grid style name known to the installed
        Matplotlib ('seaborn-darkgrid' was renamed to 'seaborn-v0_8-darkgrid'
        in 3.6 and the old name was removed in 3.8)."""
        new_name = 'seaborn-v0_8-darkgrid'
        return new_name if new_name in plt.style.available else 'seaborn-darkgrid'

    @staticmethod
    def _asImageStack(data):
        """Return ``data`` as an (n, y, x) array.

        A stack holding a single image becomes 2D after squeezing; a leading
        axis is restored so that iterating yields images instead of rows."""
        return np.reshape(data, (-1,) + np.shape(data)[-2:])

    def readImages(self, img_paths):
        """Load the class averages into ``self.images``.

        A single path is read as one stack (all representatives stored in the
        same file). Otherwise each path is read as one image."""
        if len(img_paths) == 1:
            stack = np.squeeze(ImageHandler().read(img_paths[0]).getData())
            self.images = self._asImageStack(stack)
        else:
            self.images = [np.squeeze(ImageHandler().read(img_path).getData())
                           for img_path in img_paths]

    def imScatter(self, image, x, y, imid, edge_color=None):
        """Draw ``image`` as a framed thumbnail at (x, y).

        Selected classes get a green face. The edge uses ``edge_color`` when
        given; otherwise it is dark green for selected classes and left at
        the default for unselected ones."""
        image = OffsetImage(image, zoom=self.zoom, cmap=plt.cm.gray)
        x, y = np.atleast_1d(x, y)
        selected = imid in self.selected_ids
        for x0, y0 in zip(x, y):
            ab = AnnotationBbox(image, (x0, y0), xycoords='data', frameon=True)
            ab.patch.set_boxstyle("Round, pad={}".format(self.pad))
            ab.patch.set_facecolor('palegreen' if selected else "lightgray")
            ab.patch.set_alpha(0.5)
            ab.patch.set_linewidth(2)
            if edge_color is not None:
                ab.patch.set_edgecolor(edge_color)
            elif selected:
                ab.patch.set_edgecolor('darkgreen')
            self.artists.append(self.ax.add_artist(ab))

    def is_window_closed(self, event):
        """'close_event' handler: stop the wait loop in initializePlot."""
        self.running = False

    def _paintPatch(self, patch, selected):
        """Color a thumbnail frame as selected (green) or unselected (gray).

        Edge colors are only changed when occupancy coloring is not in use,
        since otherwise the edge encodes the occupancy."""
        patch.set_facecolor('palegreen' if selected else "lightgray")
        if self.occupancy is None:
            patch.set_edgecolor('darkgreen' if selected else 'black')

    def updatePatch(self, ind):
        """Toggle the selection state of the class at position ``ind``."""
        selected_id = self.ids[ind]
        patch = self.artists[ind].patch
        if selected_id in self.selected_ids:
            self.selected_ids.remove(selected_id)
            self._paintPatch(patch, False)
        else:
            self.selected_ids.append(selected_id)
            self._paintPatch(patch, True)
        self.fig.canvas.draw_idle()

    def setPickCallback(self):
        """Toggle a class when its (invisible) scatter point is left-clicked."""
        def onPickImage(event):
            # Only the left button toggles single classes; the right button
            # is reserved for the rectangle selector.
            if event.mouseevent.button != 1:
                return
            self.updatePatch(event.ind[0])

        self.fig.canvas.mpl_connect('pick_event', onPickImage)

    def setSliderCallback(self):
        """Add the 'Image zoom' slider that rescales all thumbnails."""
        slider = plt.axes([0.28, 0.01, 0.3, 0.03], facecolor='lavender') if self.occupancy is not None else \
            plt.axes([0.355, 0.01, 0.3, 0.03], facecolor='lavender')
        self.size_slider = Slider(ax=slider, label='Image zoom', valstep=0.1, valmin=0, valmax=2,
                                  valinit=self.zoom, color='springgreen')

        def updateSize(val):
            if val <= 0:
                # A zero zoom would make the pad rescaling below divide by zero
                return
            # Scale the frame padding proportionally to the zoom change
            self.pad = val * self.pad / self.zoom
            self.zoom = val
            for image, artist in zip(self.images, self.artists):
                artist.offsetbox = OffsetImage(image, zoom=self.zoom, cmap=plt.cm.gray)
                artist.patch.set_boxstyle("Round, pad={}".format(self.pad))
            self.fig.canvas.draw_idle()

        self.size_slider.on_changed(updateSize)

    def _indicesInRect(self, x1, y1, x2, y2):
        """Return the indices of the classes whose position lies inside the
        axis-aligned rectangle with corners (x1, y1) and (x2, y2), borders
        included. The corners may be given in any order."""
        x_low, x_high = sorted((x1, x2))
        y_low, y_high = sorted((y1, y2))
        inside = ((self.pos[:, 0] >= x_low) & (self.pos[:, 0] <= x_high) &
                  (self.pos[:, 1] >= y_low) & (self.pos[:, 1] <= y_high))
        return np.flatnonzero(inside)

    def setRectangleSelector(self):
        """Toggle every class inside a rectangle drawn with the right button."""
        def images_select(eclick, erelease):
            for ind in self._indicesInRect(eclick.xdata, eclick.ydata,
                                           erelease.xdata, erelease.ydata):
                self.updatePatch(ind)

        def toggle_selector(event):
            """ This method must be left empty. It only holds the selector (as
            toggle_selector.RS) in a callback registered with Matplotlib.
            """
            pass

        rectprops = dict(facecolor='cyan', edgecolor='gray', alpha=0.2, fill=True)
        common = dict(useblit=True, button=3,  # use only right click
                      minspanx=5, minspany=5, spancoords='data', interactive=False)
        try:
            # Matplotlib >= 3.5 uses 'props'; 'drawtype' was removed in 3.7
            selector = RectangleSelector(self.ax, images_select, props=rectprops, **common)
        except TypeError:
            # Older Matplotlib only knows 'rectprops' / 'drawtype'
            selector = RectangleSelector(self.ax, images_select, drawtype='box',
                                         rectprops=rectprops, **common)
        toggle_selector.RS = selector
        self.rect_selector = selector  # keep a reference so it stays alive
        self.fig.canvas.mpl_connect('key_press_event', toggle_selector)

    def setSelectionButtons(self):
        """Add the 'Select all' and 'Select none' buttons."""
        def selectAll(event):
            self.selected_ids = list(self.ids)
            # Artists are indexed by position, not by class id
            for artist in self.artists:
                self._paintPatch(artist.patch, True)
            self.fig.canvas.draw_idle()

        def selectNone(event):
            self.selected_ids = []
            for artist in self.artists:
                self._paintPatch(artist.patch, False)
            self.fig.canvas.draw_idle()

        axprev = plt.axes([0.65, 0.015, 0.15, 0.035], facecolor='lavender')
        axnext = plt.axes([0.81, 0.015, 0.15, 0.035], facecolor='lavender')
        self.bnall = Button(axnext, 'Select all')
        self.bnall.on_clicked(selectAll)
        self.bnone = Button(axprev, 'Select none')
        self.bnone.on_clicked(selectNone)

    def initializePlot(self):
        """Draw all thumbnails, wire up the interactive widgets and block
        until the figure window is closed."""
        if self.occupancy is not None:
            cmap = plt.get_cmap('cool')
            for x, y, image, occupancy, imid in zip(self.pos[:, 0], self.pos[:, 1], self.images,
                                                    self.occupancy, self.ids):
                self.imScatter(image, x, y, imid, cmap(occupancy))
            cb = self.fig.colorbar(ScalarMappable(cmap=cmap), ax=self.ax, extend='both')
            cb.set_label("Class occupancy", fontweight="bold", labelpad=15)
        else:
            for x, y, image, imid in zip(self.pos[:, 0], self.pos[:, 1], self.images, self.ids):
                self.imScatter(image, x, y, imid)
        # Invisible scatter points act as pick targets for single-class toggling
        self.ax.scatter(self.pos[:, 0], self.pos[:, 1], alpha=0, picker=True)

        # Callback initialization
        self.setPickCallback()
        self.setSliderCallback()
        self.setRectangleSelector()
        self.setSelectionButtons()

        # Wait until interactive plot is closed
        self.fig.canvas.mpl_connect('close_event', self.is_window_closed)
        while self.running:
            plt.pause(.001)