# ******************************************************************************
# *
# * Authors: Daniel Marchan Torres (da.marchan@cnb.csic.es)
# *
# * Unidad de Bioinformatica of Centro Nacional de Biotecnologia , CSIC
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 2 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307 USA
# *
# * All comments concerning this program package may be sent to the
# * e-mail address 'scipion@cnb.csic.es'
# *
# ******************************************************************************
import os.path
from pwem.objects.data import Class2D, Particle, SetOfClasses2D, SetOfAverages
from pwem.protocols import ProtAnalysis2D
from pyworkflow.protocol.params import (PointerParam, IntParam,
EnumParam, LEVEL_ADVANCED, LT, GT)
from pyworkflow import NEW, BETA
from xmipp3.base import XmippProtocol
# Base name (no extension) shared by the exported stack of representatives
# (".mrcs") and the text file listing their reference ids (".txt").
FN = "class_representatives"
# Name of the text file, inside the extra folder, from which the cluster
# membership is parsed when the outputs are created.
RESULT_FILE = 'best_clusters_with_names.txt'
# Keys under which the two possible outputs are registered.
OUTPUT_CLASSES = 'outputClasses'
OUTPUT_AVERAGES = 'outputAverages'
[docs]class XmippProtCL2DClustering(ProtAnalysis2D, XmippProtocol):
"""Groups similar 2D class averages using clustering. This process helps
identify homogeneous subsets within the dataset, improving classification
and downstream analysis.
AI Generated:
This protocol takes a collection of 2D class representatives (either a SetOfClasses2D or a SetOfAverages) and groups them into a smaller number of clusters of similar-looking averages. The idea is not to re-align particles or to re-run 2D classification, but to organize the existing 2D summaries into homogeneous groups. This is particularly handy when you have many 2D classes and you want to quickly identify families (similar views, similar structural states, or recurring artifacts) and extract a concise set of “representative” images.
What you provide as input
You provide inputSet2D, which can be either:
SetOfClasses2D: the protocol will use the representatives of each class
(i.e., the class averages) as the images to cluster.
SetOfAverages: the protocol will use those averages directly.
Internally, it records the sampling rate from the input: for classes it
reads it from the first representative; for averages it reads it from the
averages set.
How the protocol prepares data (convert step)
Before clustering, the protocol exports the input representatives into a
single stack:
class_representatives.mrcs (stored under the protocol extra folder)
At the same time, it writes a simple text file:
class_representatives.txt
This second file is just the list of the original reference IDs (the image
indices / ids associated with each representative in the stack). That file
is crucial later because the clustering program will work on the exported
stack, and then Scipion needs a way to map “cluster members” back to the
original class/average IDs.
Choosing the search range for the number of clusters
The clustering program searches for an “optimal” number of clusters, but
you can constrain the search:
Minimum number of clusters (min_cluster): by default 10. This is the lower
bound of the scan; it prevents the program from collapsing everything into
a very small number of groups too early.
Maximum number of clusters (max_cluster): by default -1, meaning “use the
internal default”, which here is intended to behave like N_classes − 2
(i.e., search up to almost the full number of images). If you set a smaller
maximum, you restrict how fine-grained clustering is allowed to be.
There is also a computational control:
Number of computational threads (compute_threads): how many CPU threads the
clustering executable uses (default 8). More threads typically means faster
clustering.
The clustering run (xmipp_cl2d_clustering)
Once the representative stack exists, the protocol launches:
xmipp_cl2d_clustering -i class_representatives.mrcs -o <extraDir> -m <min>
-M <max> -j <threads>
The program writes its results into the output directory, including a key
text file:
best_clusters_with_names.txt
This file is the authoritative description of which original representatives
ended up in each cluster.
What you can ask the protocol to produce
In the Output section you choose an “Extraction option”:
Classes
The protocol builds a new SetOfClasses2D where each output class corresponds
to one cluster. Importantly, this output is only possible if the input is
actually a SetOfClasses2D, because it needs to pull the underlying particles
from each original class and regroup them.
Averages
The protocol extracts one “representative” per cluster and outputs a
SetOfAverages. In this implementation the representative is simply the first
class/average ID listed for that cluster in the results file (treated as
the “centroid” by convention in the output text). So you get one average
per cluster, which is a compact summary of the diversity in your original
set.
Both
It creates the clustered SetOfClasses2D (clusters as classes) and also the
SetOfAverages (one representative per cluster).
How output objects are built (and what that implies)
Output: SetOfAverages (cluster representatives)
For each cluster, the protocol reads the list of member IDs from
best_clusters_with_names.txt and takes the first one as the cluster
representative. If the input was classes, it clones the representative
image of that class. If the input was averages, it clones that average. It
then creates a Particle object (used here as an image container) and sets
its objId and classId to the chosen ID. Finally it stores all these into a
SetOfAverages and sets the sampling rate.
Practical implication: the “representative” is not a newly computed
centroid; it is a selected existing average/class representative.
Output: SetOfClasses2D (clusters as new classes)
This is only allowed when the input is a SetOfClasses2D (the protocol
explicitly validates this). For each cluster:
It creates one new Class2D, using the first member’s class as the template
(copying metadata/info).
It then iterates over all particles inside each member class, reassigns
their classId to the new cluster-class id, clones them, and accumulates
them.
After creating all new classes, it fills each class with its collected
particles and writes the set.
Practical implication: the resulting clustered classes contain the union of
particles from all original classes that were grouped together. This is a
re-grouping operation; it does not re-align or recompute averages inside
this protocol.
Output bookkeeping and summary
The protocol sets a summary string of the form:
“Classify original input set of N images into K groups…”
where N is the number of input classes/averages and K is the number of
clusters found (i.e., the number of “Cluster X” blocks parsed in the result
file).
It also defines source relations so provenance is preserved: clustered
classes are linked to the original particle images pointer; cluster
representatives are linked to the original input 2D set.
Validation rules and common gotchas
If you request Classes (or Both), the input must be a SetOfClasses2D. If
your input is a SetOfAverages, you can only extract Averages (cluster
representatives), because there are no underlying per-class particle
memberships to regroup.
The protocol relies on the format of best_clusters_with_names.txt (lines
like Cluster 0: followed by numeric IDs). If that file is missing or
malformed, output creation will fail because it cannot reconstruct the
cluster dictionary.
Extra visualization hooks
The protocol exposes paths to two PNGs (presumably generated by the
clustering executable) that a viewer can show:
best_cluster_visualization_with_images.png
all_clusters_with_labels.png
They’re meant to help you visually assess cluster separation and membership
at a glance.
"""
# Protocol label (presumably the name shown in the GUI -- confirm).
_label = 'clustering 2d classes'
# Development status marker imported from pyworkflow.
_devStatus = BETA
# Output name -> type of the object registered under that name.
_possibleOutputs = {OUTPUT_CLASSES: SetOfClasses2D,
OUTPUT_AVERAGES: SetOfAverages}
# Values of the 'extractOption' enum; they follow the order of its
# choices list: ['Classes', 'Averages', 'Both'].
CLASSES = 0
AVERAGES = 1
BOTH = 2
def __init__(self, **args):
    """Forward every keyword argument to the ProtAnalysis2D initializer."""
    # The base class is called explicitly (not through super()).
    ProtAnalysis2D.__init__(self, **args)
#--------------------------- DEFINE param functions ------------------------
def _defineParams(self, form):
    """Declare the protocol form: input set, cluster search range, threads
    and the kind of output to extract.

    Args:
        form: form object on which sections and parameters are registered
            through ``addSection`` and ``addParam``.
    """
    form.addSection(label='Input')
    form.addParam('inputSet2D', PointerParam,
                  label="Input 2D images",
                  important=True, pointerClass='SetOfClasses2D, SetOfAverages',
                  help='Select the input classes or input averages to be clustered.')
    # Adjacent string literals are concatenated verbatim, so every fragment
    # that continues a sentence must end with an explicit space.
    form.addParam('min_cluster', IntParam, label='Minimum number of clusters',
                  default=10, expertLevel=LEVEL_ADVANCED,
                  validators=[GT(1, 'Error must be greater than 1')],
                  help=' This number will limit the search for the optimum number of clusters. '
                       'By default, the 2D averages will start searching for the optimum number of clusters '
                       'with a minimum number of 10 classes.')
    # NOTE(review): the validator enforces a fixed upper bound of 50 while
    # its message talks about "number of classes - 2" -- confirm which one
    # is intended.
    form.addParam('max_cluster', IntParam, label='Maximum number of clusters',
                  default=-1, expertLevel=LEVEL_ADVANCED,
                  validators=[LT(50, 'Error must be smaller than the number of classes - 2.')],
                  help='This number will limit the search for the optimum number of clusters. '
                       'If -1 then it will act as default. By default, the 2D averages will end searching '
                       'for the optimum number of clusters until a maximum number of N_classes - 2.')
    form.addParam('compute_threads', IntParam, label='Number of computational threads',
                  default=8, expertLevel=LEVEL_ADVANCED,
                  validators=[
                      GT(0, 'Error must be greater than 0.')],
                  help=' By default, the program will use 8 threads for computation. '
                       'The higher the number the fastest the computation will be')

    form.addSection(label='Output')
    form.addParam('extractOption', EnumParam,
                  choices=['Classes', 'Averages', 'Both'],
                  default=self.CLASSES,
                  label="Extraction option", display=EnumParam.DISPLAY_COMBO,
                  help='Select an option to extract from the 2D Classes: \n '
                       '_Classes_: Create a new set of 2D classes with the respective cluster distribution. \n '
                       '_Averages_: Extract the representatives of each cluster. This are the most representative averages. \n'
                       '_Both_: Create a new set of 2D classes and extract their representatives.')
# --------------------------- INSERT steps functions ------------------------
def _insertAllSteps(self):
    """Register the three steps in order: convert, cluster, create outputs.

    Each step receives, as its prerequisite, the value returned when the
    previous step was registered.
    """
    previousStep = self._insertFunctionStep(self.convertStep)
    previousStep = self._insertFunctionStep(self.clusterClasses,
                                            prerequisites=previousStep)
    self._insertFunctionStep(self.createOutputStep,
                             prerequisites=previousStep)
def convertStep(self):
    """Export the input representatives and the list of their reference ids.

    Writes, in the extra folder, a ".mrcs" stack with the input images and a
    ".txt" file with one reference id per line (the first element of each
    representative's location). It also stores the produced paths and the
    input sampling rate as attributes of the protocol.
    """
    self.info('Writting class representatives')

    extraDir = self._getExtraPath()
    self.directoryPath = extraDir
    self.imgsFn = os.path.join(extraDir, FN + ".mrcs")
    self.refIdsFn = os.path.join(extraDir, FN + ".txt")

    inputSet2D = self.inputSet2D.get()
    if isinstance(inputSet2D, SetOfClasses2D):
        # Classes: sampling rate comes from the first representative.
        self.samplingRate = inputSet2D.getFirstItem().getRepresentative().getSamplingRate()
        representatives = inputSet2D.iterRepresentatives()
    else:
        # SetOfAverages: the set itself carries the sampling rate.
        self.samplingRate = inputSet2D.getSamplingRate()
        representatives = inputSet2D.iterItems()

    refIds = [refId for refId, _ in (rep.getLocation() for rep in representatives)]

    # writeStack is used for both kinds of input set.
    inputSet2D.writeStack(self.imgsFn)

    # Keep the original reference ids next to the stack, one per line.
    with open(self.refIdsFn, "w") as refIdsFile:
        refIdsFile.writelines(f"{refId}\n" for refId in refIds)
def clusterClasses(self):
    """Run xmipp_cl2d_clustering on the exported stack of representatives.

    Example of the launched command:
        xmipp_cl2d_clustering -i path/to/inputAverages.mrcs -o path/to/outputDir -m 10 -M 20 -j 8

    The input stack and output directory are rebuilt from the extra folder
    rather than read from attributes assigned by a previous step. This way
    the step works even if that earlier step was not executed in the same
    process (plain Python attributes are only kept in memory).
    """
    outputDir = self._getExtraPath()
    stackFn = os.path.join(outputDir, FN + ".mrcs")

    args = (" -i %s -o %s -m %d -M %d -j %d"
            % (stackFn, outputDir,
               self.min_cluster.get(),
               self.max_cluster.get(),
               self.compute_threads.get()))
    self.runJob("xmipp_cl2d_clustering", args)
def createOutputStep(self):
    """Parse the clustering result file and register the requested outputs.

    Depending on the extraction option, a set of classes, a set of averages
    or both are created, defined as outputs and related to their sources.
    The summary message is stored as well.
    """
    inputSet2D = self.inputSet2D.get()
    inputSet2D.loadAllProperties()

    # Build the path from the extra folder so this step does not depend on
    # in-memory attributes assigned by an earlier step.
    resultsFn = os.path.join(self._getExtraPath(), RESULT_FILE)
    result_dict = self.read_clusters_from_txt(resultsFn)

    message = ("Classify original input set of %d images into %d groups of structural different images"
               % (inputSet2D.getSize(), len(result_dict)))
    self.summaryVar.set(message)

    # Evaluate the option once; 'Both' enables the two branches.
    option = self.extractOption.get()
    wantClasses = option in (self.CLASSES, self.BOTH)
    wantAverages = option in (self.AVERAGES, self.BOTH)

    output_dict = {}
    if wantClasses:
        output_dict = self.createOutputSetOfClasses(inputSet2D, result_dict, output_dict)
    if wantAverages:
        output_dict = self.createOutputSetOfAverages(inputSet2D, result_dict, output_dict)

    self._defineOutputs(**output_dict)

    if wantClasses:
        self._defineSourceRelation(inputSet2D.getImagesPointer(), output_dict[OUTPUT_CLASSES])
    if wantAverages:
        self._defineSourceRelation(self.inputSet2D, output_dict[OUTPUT_AVERAGES])

    self._store()
def createOutputSetOfAverages(self, inputSet2D, result_dict, output_dict):
    """Build a SetOfAverages holding one representative per cluster.

    The representative of a cluster is the first reference listed for it.
    For a SetOfClasses2D input the representative image of that class is
    used; otherwise the referenced item of the input set itself.

    Args:
        inputSet2D: input SetOfClasses2D or SetOfAverages.
        result_dict (dict): cluster number -> list of reference ids.
        output_dict (dict): dictionary of outputs to extend.

    Returns:
        dict: ``output_dict`` with the new set stored under OUTPUT_AVERAGES.
    """
    inputIsClasses = isinstance(inputSet2D, SetOfClasses2D)

    # Read the sampling rate from the input instead of relying on an
    # attribute assigned by a previous step.
    if inputIsClasses:
        samplingRate = inputSet2D.getFirstItem().getRepresentative().getSamplingRate()
    else:
        samplingRate = inputSet2D.getSamplingRate()
    self.samplingRate = samplingRate

    # Always start from an empty set: it is rebuilt from scratch.
    outputRefs = self._createSetOfAverages()

    for cluster, classesRef in result_dict.items():
        self.info('For cluster %d' % cluster)
        self.info('We have the following ref classes: %s' % classesRef)

        if not classesRef:
            # A cluster without members has no representative to extract.
            continue

        centroidRef = classesRef[0]
        if inputIsClasses:
            classTmp = inputSet2D.getItem("id", centroidRef).clone()
            rep = classTmp.getRepresentative().clone()
        else:
            rep = inputSet2D.getItem("id", centroidRef).clone()

        self.info('Using centroid to create new Average %s' % centroidRef)
        newAvg = Particle()
        newAvg.copyInfo(rep)
        newAvg.setObjId(int(centroidRef))
        newAvg.setClassId(int(centroidRef))
        outputRefs.append(newAvg)

    outputRefs.setSamplingRate(samplingRate)
    output_dict[OUTPUT_AVERAGES] = outputRefs

    return output_dict
def createOutputSetOfClasses(self, inputSet2D, result_dict, output_dict):
    """Build a SetOfClasses2D in which every cluster becomes one class.

    The new class copies the info of the first member class of the cluster
    and takes its id; the particles of all member classes are reassigned to
    that id and gathered into the new class.

    Args:
        inputSet2D: input SetOfClasses2D.
        result_dict (dict): cluster number -> list of reference ids.
        output_dict (dict): dictionary of outputs to extend.

    Returns:
        dict: ``output_dict`` with the new set stored under OUTPUT_CLASSES.
    """
    classes2DSet = self._createSetOfClasses2D(inputSet2D.getImagesPointer())
    dictClasses = {}

    for cluster, classesRef in result_dict.items():
        self.info('For cluster %d' % cluster)
        self.info('We have the following ref classes: %s' % classesRef)

        if not classesRef:
            # Without members there is no class to use as template; skipping
            # avoids appending an undefined or stale class object.
            self.info('Cluster %d has no members, skipping it' % cluster)
            continue

        newClass = None
        newClassId = None
        newParticles = []

        for position, classRef in enumerate(classesRef):
            classTmp = inputSet2D.getItem("id", classRef)

            if position == 0:
                self.info('First iter, using centroid to create new Class: %s' % classRef)
                newClass = Class2D()
                newClass.copyInfo(classTmp)
                newClass.setObjId(int(classRef))
                newClassId = newClass.getObjId()

            for particle in classTmp.iterItems():
                particle.setClassId(newClassId)
                # Clone: the collected particles are appended to the new
                # class only after every class has been created.
                newParticles.append(particle.clone())

        dictClasses[newClassId] = newParticles
        self.info('Class particles size: %d' % len(newParticles))
        classes2DSet.append(newClass)

    # Second pass: fill every created class with its collected particles.
    for classId, particles in dictClasses.items():
        self.info('Filling Class with ID %d with %d particles' % (classId, len(particles)))
        class2D = classes2DSet[classId]
        class2D.enableAppend()
        for particle in particles:
            class2D.append(particle)

        classes2DSet.update(class2D)

    classes2DSet.write()
    output_dict[OUTPUT_CLASSES] = classes2DSet

    return output_dict
# --------------------------- INFO functions --------------------------------------------
def _summary(self):
    """Return a one-item list: the stored summary message when any output
    exists, otherwise a notice that the output is not ready."""
    outputAvailable = hasattr(self, OUTPUT_CLASSES) or hasattr(self, OUTPUT_AVERAGES)
    if outputAvailable:
        return [self.summaryVar.get()]
    return ["Output set not ready yet."]
def _validate(self):
    """Return the list of validation errors for the current parameters.

    Extracting classes (options 'Classes' or 'Both') is only accepted when
    the input is a SetOfClasses2D.
    """
    errors = []
    wantsClasses = self.extractOption.get() in (self.CLASSES, self.BOTH)
    # The input is only inspected when classes were actually requested.
    if wantsClasses and not isinstance(self.inputSet2D.get(), SetOfClasses2D):
        errors.append("The input 2D must be a SetOfClasses2D to generate a SetOfClasses2D.")
    return errors
# ------------------------------------ Utils ----------------------------------------------
def read_clusters_from_txt(self, file_path):
    """
    Parse a cluster description stored in a .txt file formatted as:

        Cluster 0:
        6
        Cluster 1:
        36
        34
        44
        ...

    Lines starting with "Cluster" open a new cluster; the following lines
    made only of digits are its members. Any other line, and any digit line
    found before the first header, is ignored.

    Args:
        file_path (str): The path to the .txt file.

    Returns:
        dict: cluster number (int) -> list of member ids, kept as the
        digit strings read from the file, in file order.
    """
    clusters = {}
    active_cluster = None

    with open(file_path, 'r') as handle:
        for raw_line in handle:
            entry = raw_line.strip()

            if entry.startswith("Cluster"):
                # Header: the second token carries the number and a colon.
                header_tokens = entry.split()
                active_cluster = int(header_tokens[1].replace(':', ''))
                clusters[active_cluster] = []
                continue

            if active_cluster is None or not entry.isdigit():
                continue

            clusters[active_cluster].append(entry)

    return clusters
# --------------------------------- Viewer functions ---------------------------------
def getClusterPlot(self):
    """Return the extra-folder path of 'best_cluster_visualization_with_images.png'."""
    plotFn = self._getExtraPath('best_cluster_visualization_with_images.png')
    return plotFn
def getClusterImagesPlot(self):
    """Return the extra-folder path of 'all_clusters_with_labels.png'."""
    plotFn = self._getExtraPath('all_clusters_with_labels.png')
    return plotFn