Source code for xmipp3.protocols.protocol_cl2d_clustering

# ******************************************************************************
# *
# * Authors:     Daniel Marchan Torres (da.marchan@cnb.csic.es)
# *
# * Unidad de  Bioinformatica of Centro Nacional de Biotecnologia , CSIC
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 2 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307  USA
# *
# *  All comments concerning this program package may be sent to the
# *  e-mail address 'scipion@cnb.csic.es'
# *
# ******************************************************************************

import os.path
from pwem.objects.data import Class2D, Particle, SetOfClasses2D, SetOfAverages
from pwem.protocols import ProtAnalysis2D
from pyworkflow.protocol.params import (PointerParam, IntParam,
                                        EnumParam, LEVEL_ADVANCED, LT, GT)
from pyworkflow import NEW, BETA
from xmipp3.base import XmippProtocol

FN = "class_representatives"
RESULT_FILE = 'best_clusters_with_names.txt'
OUTPUT_CLASSES = 'outputClasses'
OUTPUT_AVERAGES = 'outputAverages'


class XmippProtCL2DClustering(ProtAnalysis2D, XmippProtocol):
    """ Groups similar 2D class averages using clustering. This process helps
    identify homogeneous subsets within the dataset, improving classification
    and downstream analysis.

    AI Generated:

    This protocol takes a collection of 2D class representatives (either a
    SetOfClasses2D or a SetOfAverages) and groups them into a smaller number
    of clusters of similar-looking averages. The idea is not to re-align
    particles or to re-run 2D classification, but to organize the existing 2D
    summaries into homogeneous groups. This is particularly handy when you
    have many 2D classes and you want to quickly identify families (similar
    views, similar structural states, or recurring artifacts) and extract a
    concise set of “representative” images.

    What you provide as input

    You provide inputSet2D, which can be either:
      - SetOfClasses2D: the protocol will use the representatives of each
        class (i.e., the class averages) as the images to cluster.
      - SetOfAverages: the protocol will use those averages directly.
    Internally, it records the sampling rate from the input: for classes it
    reads it from the first representative; for averages it reads it from the
    averages set.

    How the protocol prepares data (convert step)

    Before clustering, the protocol exports the input representatives into a
    single stack:
        class_representatives.mrcs (stored under the protocol extra folder)
    At the same time, it writes a simple text file:
        class_representatives.txt
    This second file is just the list of the original reference IDs (the
    image indices / ids associated with each representative in the stack).
    That file is crucial later because the clustering program will work on
    the exported stack, and then Scipion needs a way to map “cluster members”
    back to the original class/average IDs.

    Choosing the search range for the number of clusters

    The clustering program searches for an “optimal” number of clusters, but
    you can constrain the search:
      - Minimum number of clusters (min_cluster): by default 10. This is the
        lower bound of the scan; it prevents the program from collapsing
        everything into a very small number of groups too early.
      - Maximum number of clusters (max_cluster): by default -1, meaning “use
        the internal default”, which here is intended to behave like
        N_classes − 2 (i.e., search up to almost the full number of images).
        If you set a smaller maximum, you restrict how fine-grained
        clustering is allowed to be.
    There is also a computational control:
      - Number of computational threads (compute_threads): how many CPU
        threads the clustering executable uses (default 8). More threads
        typically means faster clustering.

    The clustering run (xmipp_cl2d_clustering)

    Once the representative stack exists, the protocol launches:
        xmipp_cl2d_clustering -i class_representatives.mrcs -o <extraDir>
                              -m <min> -M <max> -j <threads>
    The program writes its results into the output directory, including a key
    text file:
        best_clusters_with_names.txt
    This file is the authoritative description of which original
    representatives ended up in each cluster.

    What you can ask the protocol to produce

    In the Output section you choose an “Extraction option”:
      - Classes: The protocol builds a new SetOfClasses2D where each output
        class corresponds to one cluster. Importantly, this output is only
        possible if the input is actually a SetOfClasses2D, because it needs
        to pull the underlying particles from each original class and regroup
        them.
      - Averages: The protocol extracts one “representative” per cluster and
        outputs a SetOfAverages. In this implementation the representative is
        simply the first class/average ID listed for that cluster in the
        results file (treated as the “centroid” by convention in the output
        text). So you get one average per cluster, which is a compact summary
        of the diversity in your original set.
      - Both: It creates the clustered SetOfClasses2D (clusters as classes)
        and also the SetOfAverages (one representative per cluster).

    How output objects are built (and what that implies)

    Output: SetOfAverages (cluster representatives)
    For each cluster, the protocol reads the list of member IDs from
    best_clusters_with_names.txt and takes the first one as the cluster
    representative. If the input was classes, it clones the representative
    image of that class. If the input was averages, it clones that average.
    It then creates a Particle object (used here as an image container) and
    sets its objId and classId to the chosen ID. Finally it stores all these
    into a SetOfAverages and sets the sampling rate.
    Practical implication: the “representative” is not a newly computed
    centroid; it is a selected existing average/class representative.

    Output: SetOfClasses2D (clusters as new classes)
    This is only allowed when the input is a SetOfClasses2D (the protocol
    explicitly validates this). For each cluster:
      - It creates one new Class2D, using the first member’s class as the
        template (copying metadata/info).
      - It then iterates over all particles inside each member class,
        reassigns their classId to the new cluster-class id, clones them, and
        accumulates them.
    After creating all new classes, it fills each class with its collected
    particles and writes the set.
    Practical implication: the resulting clustered classes contain the union
    of particles from all original classes that were grouped together. This
    is a re-grouping operation; it does not re-align or recompute averages
    inside this protocol.

    Output bookkeeping and summary

    The protocol sets a summary string of the form:
        “Classify original input set of N images into K groups…”
    where N is the number of input classes/averages and K is the number of
    clusters found (i.e., the number of “Cluster X” blocks parsed in the
    result file). It also defines source relations so provenance is
    preserved: clustered classes are linked to the original particle images
    pointer; cluster representatives are linked to the original input 2D set.

    Validation rules and common gotchas

    If you request Classes (or Both), the input must be a SetOfClasses2D.
    If your input is a SetOfAverages, you can only extract Averages (cluster
    representatives), because there are no underlying per-class particle
    memberships to regroup.
    The protocol relies on the format of best_clusters_with_names.txt (lines
    like Cluster 0: followed by numeric IDs). If that file is missing or
    malformed, output creation will fail because it cannot reconstruct the
    cluster dictionary.

    Extra visualization hooks

    The protocol exposes paths to two PNGs (presumably generated by the
    clustering executable) that a viewer can show:
        best_cluster_visualization_with_images.png
        all_clusters_with_labels.png
    They’re meant to help you visually assess cluster separation and
    membership at a glance.
    """

    _label = 'clustering 2d classes'
    _devStatus = BETA
    _possibleOutputs = {OUTPUT_CLASSES: SetOfClasses2D,
                        OUTPUT_AVERAGES: SetOfAverages}

    # Values of the 'extractOption' enum parameter
    CLASSES = 0
    AVERAGES = 1
    BOTH = 2

    def __init__(self, **args):
        ProtAnalysis2D.__init__(self, **args)

    #--------------------------- DEFINE param functions ------------------------
    def _defineParams(self, form):
        """ Declare the input set, the cluster-search bounds, the thread count
        and the kind of output to extract. """
        form.addSection(label='Input')
        form.addParam('inputSet2D', PointerParam,
                      label="Input 2D images",
                      important=True,
                      pointerClass='SetOfClasses2D, SetOfAverages',
                      help='Select the input classes or input averages to be clustered.')
        # NOTE: adjacent literals are concatenated verbatim, so each fragment
        # must carry its own separating space.
        form.addParam('min_cluster', IntParam,
                      label='Minimum number of clusters',
                      default=10,
                      expertLevel=LEVEL_ADVANCED,
                      validators=[GT(1, 'Error must be greater than 1')],
                      help=' This number will limit the search for the optimum number of clusters. '
                           'By default, the 2D averages will start searching for the optimum number of clusters '
                           'with a minimum number of 10 classes.')
        form.addParam('max_cluster', IntParam,
                      label='Maximum number of clusters',
                      default=-1,
                      expertLevel=LEVEL_ADVANCED,
                      validators=[LT(50, 'Error must be smaller than the number of classes - 2.')],
                      help='This number will limit the search for the optimum number of clusters. '
                           'If -1 then it will act as default. By default, the 2D averages will end searching '
                           'for the optimum number of clusters until a maximum number of N_classes - 2.')
        form.addParam('compute_threads', IntParam,
                      label='Number of computational threads',
                      default=8,
                      expertLevel=LEVEL_ADVANCED,
                      validators=[GT(0, 'Error must be greater than 0.')],
                      help=' By default, the program will use 8 threads for computation. '
                           'The higher the number the fastest the computation will be')

        form.addSection(label='Output')
        form.addParam('extractOption', EnumParam,
                      choices=['Classes', 'Averages', 'Both'],
                      default=self.CLASSES,
                      label="Extraction option",
                      display=EnumParam.DISPLAY_COMBO,
                      help='Select an option to extract from the 2D Classes: \n '
                           '_Classes_: Create a new set of 2D classes with the respective cluster distribution. \n '
                           '_Averages_: Extract the representatives of each cluster. This are the most representative averages. \n'
                           '_Both_: Create a new set of 2D classes and extract their representatives.')

    # --------------------------- INSERT steps functions ------------------------
    def _insertAllSteps(self):
        """ Chain the three steps: convert -> cluster -> create output. """
        convertStep = self._insertFunctionStep(self.convertStep)
        clusterStep = self._insertFunctionStep(self.clusterClasses,
                                               prerequisites=convertStep)
        self._insertFunctionStep(self.createOutputStep,
                                 prerequisites=clusterStep)

    def convertStep(self):
        """ Export the input representatives as a single .mrcs stack and write
        their original reference ids, one per line, to a companion .txt file.
        """
        self.info('Writting class representatives')
        # These attributes are still assigned for backward compatibility, but
        # the later steps no longer depend on them (see the path helpers).
        self.directoryPath = self._getExtraPath()
        self.imgsFn = self._getImgsFn()
        self.refIdsFn = self._getRefIdsFn()

        inputSet2D = self.inputSet2D.get()
        self.samplingRate = self._getInputSamplingRate(inputSet2D)

        if isinstance(inputSet2D, SetOfClasses2D):
            representatives = inputSet2D.iterRepresentatives()
        else:  # In case the input is a SetOfAverages
            representatives = inputSet2D.iterItems()

        classes_refIds = []
        for rep in representatives:
            idClass, _ = rep.getLocation()
            classes_refIds.append(idClass)

        # Save the corresponding .mrcs file
        # (the same method for SetOfClasses and SetOfAverages)
        inputSet2D.writeStack(self.imgsFn)
        # Save the original ref ids
        with open(self.refIdsFn, "w") as file:
            for item in classes_refIds:
                file.write(f"{item}\n")

    def clusterClasses(self):
        """ Run the clustering executable on the exported stack, e.g.:
        xmipp_cl2d_clustering -i path/to/inputAverages.mrcs -o path/to/outputDir -m 10 -M 20 -j 8
        """
        # Paths are recomputed here rather than read from attributes set in
        # convertStep, so the step also works when convertStep was not run in
        # this same process.
        args = (" -i %s -o %s -m %d -M %d -j %d"
                % (self._getImgsFn(), self._getExtraPath(),
                   self.min_cluster.get(), self.max_cluster.get(),
                   self.compute_threads.get()))
        self.runJob("xmipp_cl2d_clustering", args)

    def createOutputStep(self):
        """ Parse the clustering result file, build the requested output
        set(s), register them and define their source relations. """
        output_dict = {}
        inputSet2D = self.inputSet2D.get()
        inputSet2D.loadAllProperties()

        result_dict = self.read_clusters_from_txt(self._getResultFn())

        message = ("Classify original input set of %d images into %d groups of structural different images"
                   % (inputSet2D.getSize(), len(result_dict)))
        self.summaryVar.set(message)

        option = self.extractOption.get()
        wantClasses = option in (self.CLASSES, self.BOTH)
        wantAverages = option in (self.AVERAGES, self.BOTH)

        if wantClasses:
            output_dict = self.createOutputSetOfClasses(inputSet2D, result_dict, output_dict)
        if wantAverages:
            output_dict = self.createOutputSetOfAverages(inputSet2D, result_dict, output_dict)

        self._defineOutputs(**output_dict)

        if wantClasses:
            self._defineSourceRelation(inputSet2D.getImagesPointer(),
                                       output_dict[OUTPUT_CLASSES])
        if wantAverages:
            self._defineSourceRelation(self.inputSet2D,
                                       output_dict[OUTPUT_AVERAGES])

        self._store()

    def createOutputSetOfAverages(self, inputSet2D, result_dict, output_dict):
        """ Build a SetOfAverages holding one representative per cluster.

        The representative is the first member id listed for the cluster.

        :param inputSet2D: input SetOfClasses2D or SetOfAverages.
        :param result_dict: {cluster number: [member ids]} as returned by
            read_clusters_from_txt.
        :param output_dict: outputs collected so far; updated and returned.
        :return: output_dict with the OUTPUT_AVERAGES entry set.
        """
        # We need to create always an empty set since we need to rebuild it
        outputRefs = self._createSetOfAverages()
        inputIsClasses = isinstance(inputSet2D, SetOfClasses2D)
        # Read from the input itself instead of relying on convertStep state
        self.samplingRate = self._getInputSamplingRate(inputSet2D)

        for cluster, classesRef in result_dict.items():
            self.info('For cluster %d' % cluster)
            self.info('We have the following ref classes: %s' % classesRef)

            if not classesRef:
                continue  # nothing to pick a representative from

            classRef = classesRef[0]  # Just want to get the first ref
            if inputIsClasses:
                classTmp = inputSet2D.getItem("id", classRef).clone()
                rep = classTmp.getRepresentative().clone()
            else:
                rep = inputSet2D.getItem("id", classRef).clone()

            self.info('Using centroid to create new Average %s' % classRef)
            newAvg = Particle()
            newAvg.copyInfo(rep)
            newAvg.setObjId(int(classRef))
            newAvg.setClassId(int(classRef))
            outputRefs.append(newAvg)

        outputRefs.setSamplingRate(self.samplingRate)
        output_dict[OUTPUT_AVERAGES] = outputRefs

        return output_dict

    def createOutputSetOfClasses(self, inputSet2D, result_dict, output_dict):
        """ Build a SetOfClasses2D where each class gathers the particles of
        every input class assigned to the same cluster.

        The new class copies its info and takes its id from the first member
        of the cluster.

        :param inputSet2D: input SetOfClasses2D.
        :param result_dict: {cluster number: [member ids]} as returned by
            read_clusters_from_txt.
        :param output_dict: outputs collected so far; updated and returned.
        :return: output_dict with the OUTPUT_CLASSES entry set.
        """
        classes2DSet = self._createSetOfClasses2D(inputSet2D.getImagesPointer())
        dictClasses = {}

        for cluster, classesRef in result_dict.items():
            self.info('For cluster %d' % cluster)
            self.info('We have the following ref classes: %s' % classesRef)

            if not classesRef:
                # Without members there is no template class; appending would
                # reuse the class built for the previous cluster.
                self.info('Cluster %d has no members, skipping it' % cluster)
                continue

            firstTime = True
            newParticles = []
            for classRef in classesRef:
                classTmp = inputSet2D.getItem("id", classRef)
                if firstTime:
                    self.info('First iter, using centroid to create new Class: %s' % classRef)
                    newClass = Class2D()
                    newClass.copyInfo(classTmp)
                    newClass.setObjId(int(classRef))
                    newClassId = newClass.getObjId()
                    firstTime = False

                for particle in classTmp.iterItems():
                    particle.setClassId(newClassId)
                    newParticles.append(particle.clone())

            dictClasses[newClassId] = newParticles
            self.info('Class particles size: %d' % len(newParticles))
            classes2DSet.append(newClass)

        # Classes must exist in the set before they can be filled
        for classId, particles in dictClasses.items():
            self.info('Filling Class with ID %d with %d particles' % (classId, len(particles)))
            class2D = classes2DSet[classId]
            class2D.enableAppend()
            for particle in particles:
                class2D.append(particle)

            classes2DSet.update(class2D)

        classes2DSet.write()
        output_dict[OUTPUT_CLASSES] = classes2DSet

        return output_dict

    # --------------------------- INFO functions --------------------------------------------
    def _summary(self):
        """ Return the stored summary message once any output exists. """
        summary = []
        if not hasattr(self, OUTPUT_CLASSES) and not hasattr(self, OUTPUT_AVERAGES):
            summary.append("Output set not ready yet.")
        else:
            summary.append(self.summaryVar.get())
        return summary

    def _validate(self):
        """ Reject the 'Classes' and 'Both' options unless the input is a
        SetOfClasses2D. """
        errors = []
        wantClasses = self.extractOption.get() in (self.CLASSES, self.BOTH)
        if wantClasses and not isinstance(self.inputSet2D.get(), SetOfClasses2D):
            errors.append("The input 2D must be a SetOfClasses2D to generate a SetOfClasses2D.")
        return errors

    # ------------------------------------ Utils ----------------------------------------------
    def _getImgsFn(self):
        """ Path of the exported representatives stack (.mrcs). """
        return os.path.join(self._getExtraPath(), FN + ".mrcs")

    def _getRefIdsFn(self):
        """ Path of the text file listing the original reference ids. """
        return os.path.join(self._getExtraPath(), FN + ".txt")

    def _getResultFn(self):
        """ Path of the cluster description written by the clustering program. """
        return os.path.join(self._getExtraPath(), RESULT_FILE)

    def _getInputSamplingRate(self, inputSet2D):
        """ Sampling rate of the input: taken from the first representative
        for a SetOfClasses2D, from the set itself otherwise. """
        if isinstance(inputSet2D, SetOfClasses2D):
            return inputSet2D.getFirstItem().getRepresentative().getSamplingRate()
        return inputSet2D.getSamplingRate()

    def read_clusters_from_txt(self, file_path):
        """
        Reads a cluster dictionary from a .txt file formatted as:
            Cluster 0:
            6
            Cluster 1:
            36
            34
            44
            ...

        Args:
            file_path (str): The path to the .txt file.

        Returns:
            dict: A dictionary where the key is the cluster number (int) and
            the value is the list of associated numbers, kept as the strings
            read from the file.
        """
        clusters = {}
        current_cluster = None

        with open(file_path, 'r') as file:
            for line in file:
                line = line.strip()
                if line.startswith("Cluster"):
                    # Extract the cluster number
                    current_cluster = int(line.split()[1].replace(':', ''))
                    clusters[current_cluster] = []
                elif current_cluster is not None and line.isdigit():
                    # Append numbers to the current cluster's list
                    clusters[current_cluster].append(line)

        return clusters

    # --------------------------------- Viewer functions ---------------------------------
    def getClusterPlot(self):
        """ Path of the best-cluster visualization image. """
        return self._getExtraPath('best_cluster_visualization_with_images.png')

    def getClusterImagesPlot(self):
        """ Path of the labelled all-clusters image. """
        return self._getExtraPath('all_clusters_with_labels.png')