Source code for relion.convert.convert31

# **************************************************************************
# *
# * Authors:     J.M. de la Rosa Trevin [1]
# *              Grigory Sharov [2]
# *
# * [1] SciLifeLab, Stockholm University
# * [2] MRC Laboratory of Molecular Biology, MRC-LMB
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 3 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307  USA
# *
# *  All comments concerning this program package may be sent to the
# *  e-mail address ''
# *
# **************************************************************************
""" New conversion functions dealing with the new Relion 3.1 STAR file format. """
import os
import io
import numpy as np
from collections import OrderedDict
from emtable import Table

from pwem.constants import ALIGN_NONE, ALIGN_PROJ, ALIGN_2D, ALIGN_3D
from pwem.objects import (Micrograph, SetOfMicrographsBase, SetOfMovies,
                          Particle, CTFModel, Acquisition, Transform, Coordinate)
import pwem.convert.transformations as tfs

from .convert_base import WriterBase, ReaderBase
from .convert_utils import (convertBinaryFiles, locationToRelion,
from relion.constants import PARTICLE_EXTRA_LABELS

def getPixelSizeLabel(imageSet):
    """ Return the STAR label that holds the pixel size for this input.

    Micrographs (a single Micrograph or any SetOfMicrographsBase) use
    'rlnMicrographPixelSize'; anything else uses 'rlnImagePixelSize'.
    """
    isMicLike = isinstance(imageSet, (SetOfMicrographsBase, Micrograph))
    return 'rlnMicrographPixelSize' if isMicLike else 'rlnImagePixelSize'
class OpticsGroups:
    """ Store information about optics groups in an indexable way.

    Existing groups can be accessed by number (int key, matching
    rlnOpticsGroup) or by name (str key, matching rlnOpticsGroupName).
    Each group is a row object that must expose those two attributes and
    the namedtuple-style ``_replace`` / ``_asdict`` methods.
    """

    def __init__(self, opticsTable):
        """ Build the indexes from any iterable of optics rows. """
        self.__fromTable(opticsTable)

    def __fromTable(self, opticsTable):
        """ (Re)create both indexes from the rows of opticsTable. """
        self._dict = OrderedDict()
        # Also allow indexing by name
        self._dictName = OrderedDict()
        # Map optics rows both by name and by number
        for og in opticsTable:
            self.__store(og)

    def __store(self, og):
        """ Register (or overwrite) a row under its number and its name. """
        self._dict[og.rlnOpticsGroup] = og
        self._dictName[og.rlnOpticsGroupName] = og

    def __getitem__(self, item):
        """ Return the group for an int (number) or str (name) key.

        Raises KeyError for an unknown key and TypeError for any other
        key type.
        """
        if isinstance(item, int):
            return self._dict[item]
        elif isinstance(item, str):
            return self._dictName[item]
        raise TypeError("Unsupported type '%s' of item '%s'"
                        % (type(item), item))

    def __contains__(self, item):
        return item in self._dict or item in self._dictName

    def __iter__(self):
        """ Iterate over all optics groups. """
        return iter(self._dict.values())

    def __len__(self):
        return len(self._dict)

    def __str__(self):
        return self.toString()

    def first(self):
        """ Return first optics group. """
        return next(iter(self._dict.values()))

    def update(self, ogId, **kwargs):
        """ Replace some values of one optics group and return the new row.

        ogId can be the group number or the group name. If the update
        changes the group's number or name, the entry stored under the old
        key is removed so that lookups, ``in`` and ``len`` do not keep
        reporting the outdated row.
        """
        og = self[ogId]
        newOg = og._replace(**kwargs)

        # Only drop an old key when it still points to the row being
        # replaced (another group could legitimately own that key).
        if (newOg.rlnOpticsGroup != og.rlnOpticsGroup
                and self._dict.get(og.rlnOpticsGroup) is og):
            del self._dict[og.rlnOpticsGroup]
        if (newOg.rlnOpticsGroupName != og.rlnOpticsGroupName
                and self._dictName.get(og.rlnOpticsGroupName) is og):
            del self._dictName[og.rlnOpticsGroupName]

        self.__store(newOg)
        return newOg

    def updateAll(self, **kwargs):
        """ Update all Optics Groups with these values.

        Keys that are not yet columns are added as new columns (with the
        given value as default); existing columns are overwritten in
        every group.
        """
        missing = {}
        existing = {}
        for k, v in kwargs.items():
            (existing if self.hasColumn(k) else missing)[k] = v

        # Rebuilding the table is only needed when columns are added
        if missing:
            self.addColumns(**missing)

        # Iterate over a snapshot: update() writes into the indexes
        for og in list(self):
            self.update(og.rlnOpticsGroup, **existing)

    def add(self, newOg):
        """ Register a new optics group row (overwrites same number/name). """
        self.__store(newOg)

    def hasColumn(self, colName):
        """ Return True if the first group has an attribute colName. """
        return hasattr(self.first(), colName)

    def addColumns(self, **kwargs):
        """ Add new columns with default values (type inferred from it). """
        items = self.first()._asdict().items()
        cols = [Table.Column(k, type(v)) for k, v in items]

        for k, v in kwargs.items():
            cols.append(Table.Column(k, type(v)))

        t = Table(columns=cols)

        for og in self._dict.values():
            values = og._asdict()
            values.update(kwargs)
            t.addRow(**values)

        self.__fromTable(t)

    @staticmethod
    def fromStar(starFilePath):
        """ Create an OpticsGroups from a given STAR file. """
        return OpticsGroups(Table(fileName=starFilePath, tableName='optics'))

    @staticmethod
    def fromString(stringValue):
        """ Create an OpticsGroups from string content (STAR format) """
        f = io.StringIO(stringValue)
        t = Table()
        t.readStar(f, tableName='optics')
        return OpticsGroups(t)

    @staticmethod
    def fromImages(imageSet):
        """ Create an OpticsGroups from the acquisition of imageSet.

        The optics stored in ``acquisition.opticsGroupInfo`` are used when
        they can be read; image size and pixel size are always taken from
        imageSet. Otherwise a single default group is created from the
        acquisition voltage, spherical aberration and amplitude contrast.
        """
        acq = imageSet.getAcquisition()
        params = {'rlnImageSize': imageSet.getXDim(),
                  getPixelSizeLabel(imageSet): imageSet.getSamplingRate()}
        if isinstance(imageSet, SetOfMovies):
            params['rlnMicrographOriginalPixelSize'] = imageSet.getSamplingRate()
        try:
            og = OpticsGroups.fromString(acq.opticsGroupInfo.get())
            # always update sampling and image size from the set
            og.updateAll(**params)
            return og
        except Exception:
            # Deliberate best-effort: any failure to read the stored optics
            # (missing attribute, empty or unparsable value) falls back to
            # defaults. Unlike a bare except, this does not swallow
            # KeyboardInterrupt / SystemExit.
            params.update({
                'rlnVoltage': acq.getVoltage(),
                'rlnSphericalAberration': acq.getSphericalAberration(),
                'rlnAmplitudeContrast': acq.getAmplitudeContrast(),
            })
            return OpticsGroups.create(**params)

    @staticmethod
    def create(**kwargs):
        """ Create an OpticsGroups with one default group ('opticsGroup1').

        kwargs override the default values; keys that are not in the
        default template are added as new columns.
        """
        opticsString1 = """

# version 30001

data_optics

loop_
_rlnOpticsGroupName #1
_rlnOpticsGroup #2
_rlnMicrographOriginalPixelSize #3
_rlnVoltage #4
_rlnSphericalAberration #5
_rlnAmplitudeContrast #6
_rlnImageSize #7
_rlnImageDimensionality #8
opticsGroup1 1 1.000000 300.000000 2.700000 0.100000 256 2
"""
        og = OpticsGroups.fromString(opticsString1)
        fog = og.first()
        newColumns = {k: v for k, v in kwargs.items() if not hasattr(fog, k)}
        og.addColumns(**newColumns)
        og.update(1, **kwargs)
        return og

    def _write(self, f):
        """ Write the 'optics' table to the open file object f. """
        # Create columns from the first row
        items = self.first()._asdict().items()
        cols = [Table.Column(k, type(v)) for k, v in items]
        t = Table(columns=cols)
        for og in self._dict.values():
            t.addRow(*og)
        t.writeStar(f, tableName='optics')

    def toString(self):
        """ Return a string (STAR format) with the current optics groups. """
        with io.StringIO() as f:
            self._write(f)
            return f.getvalue()

    def toStar(self, starFile):
        """ Write current optics groups to a given file.

        starFile is handed straight to the table writer, so it must be an
        open, writable file object rather than a path.
        """
        self._write(starFile)

    def toImages(self, imageSet):
        """ Store the optics groups information in the image acquisition. """
        imageSet.getAcquisition().opticsGroupInfo.set(self.toString())
class Writer(WriterBase):
    """ Helper class to convert from Scipion SetOfImages subclasses
    into Relion>3.1 star files (and binaries if conversion needed).
    """

    def writeSetOfMovies(self, moviesIterable, starFile, **kwargs):
        """ Write the movies to starFile as a 'movies' table. """
        self._writeSetOfMoviesOrMics(moviesIterable, starFile,
                                     'movies', 'rlnMicrographMovieName',
                                     **kwargs)

    def writeSetOfMicrographs(self, micsIterable, starFile, **kwargs):
        """ Write the micrographs to starFile as a 'micrographs' table. """
        self._writeSetOfMoviesOrMics(micsIterable, starFile,
                                     'micrographs', 'rlnMicrographName',
                                     **kwargs)

    def _writeSetOfMoviesOrMics(self, imgIterable, starFile,
                                tableName, imgLabelName, **kwargs):
        """ This function can be used to write either movies or micrographs
        star files. Input can be any iterable of these type of images (e.g
        set, list, etc).

        Keyword Arguments:
            postprocessImageRow: optional callable(img, row) invoked after
                each row is filled.
            extraLabels: labels to write when the first image has the
                matching '_<label>' attribute.

        The first item is consumed with next(), so an empty iterable
        raises StopIteration.
        """
        # Process the first item and create the table based
        # on the generated columns
        self._imgLabelName = imgLabelName
        self._postprocessImageRow = kwargs.get('postprocessImageRow', None)
        self._prefix = tableName[:3]

        micRow = OrderedDict()
        micRow[imgLabelName] = ''  # Just to add label, proper value later
        iterMics = iter(imgIterable)
        mic = next(iterMics)
        if self._optics is None:
            self._optics = OpticsGroups.fromImages(mic)
        self._imageSize = mic.getXDim()
        self._setCtf = mic.hasCTF()
        extraLabels = kwargs.get('extraLabels', [])
        self._extraLabels = [label for label in extraLabels
                             if mic.hasAttribute('_%s' % label)]
        # This first pass only defines the table columns; the first image
        # is processed again (with its converted name) in the loop below.
        self._micToRow(mic, micRow)
        if self._postprocessImageRow:
            self._postprocessImageRow(mic, micRow)

        micsTable = self._createTableFromDict(micRow)

        while mic is not None:
            micRow[imgLabelName] = self._convert(mic)
            self._micToRow(mic, micRow)

            if self._postprocessImageRow:
                self._postprocessImageRow(mic, micRow)

            micsTable.addRow(**micRow)
            mic = next(iterMics, None)

        with open(starFile, 'w') as f:
            f.write("# Star file generated with Scipion\n")
            f.write("# version 30001\n")
            self._optics.toStar(f)
            f.write("# version 30001\n")
            micsTable.writeStar(f, tableName=tableName)

    def _objToRow(self, obj, row, attributes):
        """ Set some attributes from the object to the row.
        For performance reasons, it is not validated that each attribute
        is already in the object, so it should be validated before.
        """
        for attr in attributes:
            row[attr] = obj.getAttributeValue('_%s' % attr)

    def _micToRow(self, mic, row):
        """ Fill row from mic: base values, CTF, extra labels, optics. """
        WriterBase._micToRow(self, mic, row)
        # Set CTF values
        if self._setCtf:
            self._ctfToRow(mic.getCTF(), row)
        # Set additional labels if present
        self._objToRow(mic, row, self._extraLabels)
        row['rlnOpticsGroup'] = mic.getAttributeValue('_rlnOpticsGroup', 1)

    def _align2DToRow(self, alignment, row):
        """ Write the 2D alignment (X/Y origin and psi) into row.

        Shifts are multiplied by self._pixelSize before being stored in
        the '...Angst' labels.
        """
        matrix = alignment.getMatrix()
        shifts = tfs.translation_from_matrix(matrix)
        shifts *= self._pixelSize
        angles = -np.rad2deg(tfs.euler_from_matrix(matrix, axes='szyz'))
        row['rlnOriginXAngst'], row['rlnOriginYAngst'] = shifts[:2]
        row['rlnAnglePsi'] = -(angles[0] + angles[2])

    def _alignProjToRow(self, alignment, row):
        """ Write the projection alignment (3 shifts, 3 angles) into row.

        The values are derived from the inverse of the alignment matrix;
        shifts are negated and multiplied by self._pixelSize.
        """
        matrix = np.linalg.inv(alignment.getMatrix())
        shifts = -tfs.translation_from_matrix(matrix)
        shifts *= self._pixelSize
        angles = -np.rad2deg(tfs.euler_from_matrix(matrix, axes='szyz'))
        (row['rlnOriginXAngst'], row['rlnOriginYAngst'],
         row['rlnOriginZAngst']) = shifts
        row['rlnAngleRot'], row['rlnAngleTilt'], row['rlnAnglePsi'] = angles

    def _partToRow(self, part, row):
        """ Fill row from a particle and increase self._counter.

        When self.outputStack is set, the image name points to position
        self._counter of that stack, and the particle is converted into
        it for every call except the one made with counter == 0.
        """
        row['rlnImageId'] = part.getObjId()

        # Add coordinate information
        coord = part.getCoordinate()
        if coord is not None:
            x, y = coord.getPosition()
            row['rlnCoordinateX'] = x
            row['rlnCoordinateY'] = y
            # Add some specific coordinate attributes
            self._objToRow(coord, row, self._coordLabels)
            micName = coord.getMicName()
            if micName:
                row['rlnMicrographName'] = str(micName.replace(" ", ""))
            else:
                if coord.getMicId():
                    row['rlnMicrographName'] = str(coord.getMicId())

        index, fn = part.getLocation()
        if self.outputStack:
            row['rlnOriginalParticleName'] = locationToRelion(index, fn)
            index, fn = self._counter, self._relOutputStack
            if self._counter > 0:
                self._ih.convert(part, (index, self.outputStack))
        else:
            if self.outputDir is not None:
                fn = self._filesDict.get(fn, fn)

        row['rlnImageName'] = locationToRelion(index, fn)

        # Set CTF values
        if self._setCtf:
            self._ctfToRow(part.getCTF(), row)

        # Set alignment if necessary
        if self._setAlign:
            self._setAlign(part.getTransform(), row)

        # Set additional labels if present
        self._objToRow(part, row, self._extraLabels)

        # Add now the new Optics Group stuff
        row['rlnOpticsGroup'] = part.getAttributeValue('_rlnOpticsGroup', 1)

        self._counter += 1

    def writeSetOfParticles(self, partsSet, starFile, **kwargs):
        """ Write a set of particles to starFile (optics + particles tables).

        Keyword Arguments:
            rootDir, outputDir, outputStack: forwarded to self.update().
            forceConvert: passed to convertBinaryFiles (default False).
            preprocessImageRow: stored in self._preprocessImageRow.
            postprocessImageRow: optional callable(part, row) invoked
                after each row is filled.
            writeCtf: write CTF values if the first particle has CTF
                (default True).
            alignType: alignment to write (default: the set's alignment).
            extraLabels: additional labels to write; the caller's list is
                not modified.

        Raises:
            NotImplementedError: for ALIGN_3D.
            TypeError: for an unknown alignType.
        """
        # Process the first item and create the table based
        # on the generated columns
        self.update(['rootDir', 'outputDir', 'outputStack'], **kwargs)

        self._optics = OpticsGroups.fromImages(partsSet)
        partRow = OrderedDict()
        firstPart = partsSet.getFirstItem()

        # Convert binaries if required
        if self.outputStack:
            self._relOutputStack = os.path.relpath(self.outputStack,
                                                   os.path.dirname(starFile))
        if self.outputDir is not None:
            forceConvert = kwargs.get('forceConvert', False)
            self._filesDict = convertBinaryFiles(partsSet, self.outputDir,
                                                 forceConvert=forceConvert)

        # Compute some flags from the first particle...
        # when flags are True, some operations will be applied to all particles
        self._preprocessImageRow = kwargs.get('preprocessImageRow', None)
        self._setCtf = kwargs.get('writeCtf', True) and firstPart.hasCTF()
        alignType = kwargs.get('alignType', partsSet.getAlignment())

        if alignType == ALIGN_2D:
            self._setAlign = self._align2DToRow
        elif alignType == ALIGN_PROJ:
            self._setAlign = self._alignProjToRow
        elif alignType == ALIGN_3D:
            raise NotImplementedError(
                "3D alignment conversion for Relion not implemented. "
                "It seems the particles were generated with an incorrect "
                "alignment type. You may either re-launch the protocol that "
                "generates the particles with angles or set 'Consider previous"
                " alignment?' to No")
        elif alignType == ALIGN_NONE:
            self._setAlign = None
        else:
            raise TypeError("Invalid value for alignType: %s" % alignType)

        # Work on a copy: extending the list taken from kwargs in place
        # would grow the caller's list on every call.
        extraLabels = list(kwargs.get('extraLabels', []))
        extraLabels.extend(PARTICLE_EXTRA_LABELS)
        self._extraLabels = [label for label in extraLabels
                             if firstPart.hasAttribute('_%s' % label)]

        coord = firstPart.getCoordinate()
        self._coordLabels = []
        if coord is not None:
            self._coordLabels = [label for label in ['rlnClassNumber',
                                                     'rlnAutopickFigureOfMerit',
                                                     'rlnAnglePsi']
                                 if coord.hasAttribute('_%s' % label)]

        self._postprocessImageRow = kwargs.get('postprocessImageRow', None)

        self._imageSize = firstPart.getXDim()
        self._pixelSize = firstPart.getSamplingRate() or 1.0

        self._counter = 0  # Mark first conversion as special one
        firstPart.setAcquisition(partsSet.getAcquisition())
        # This call only serves to discover the columns; the first
        # particle is written again by the loop over partsSet below.
        self._partToRow(firstPart, partRow)

        if self._postprocessImageRow:
            self._postprocessImageRow(firstPart, partRow)

        partsTable = self._createTableFromDict(partRow)
        partsTable.addRow(**partRow)

        with open(starFile, 'w') as f:
            # Write particles table
            f.write("# Star file generated with Scipion\n")
            f.write("\n# version 30001\n")
            self._optics.toStar(f)
            f.write("# version 30001\n")
            # Write header first
            partsWriter = Table.Writer(f)
            partsWriter.writeTableName('particles')
            partsWriter.writeHeader(partsTable.getColumns())
            # Write all rows
            for part in partsSet:
                self._partToRow(part, partRow)
                if self._postprocessImageRow:
                    self._postprocessImageRow(part, partRow)
                partsWriter.writeRowValues(partRow.values())
class Reader(ReaderBase):
    """ Helper class to read particles from Relion 3.1 style star files
    (an 'optics' table plus a 'particles' table) into a particles set.
    """

    ALIGNMENT_LABELS = [
        "rlnOriginXAngst",
        "rlnOriginYAngst",
        "rlnOriginZAngst",
        "rlnAngleRot",
        "rlnAngleTilt",
        "rlnAnglePsi",
    ]

    CTF_LABELS = [
        "rlnDefocusU",
        "rlnDefocusV",
        "rlnDefocusAngle",
        "rlnCtfAstigmatism",
        "rlnCtfFigureOfMerit",
        "rlnCtfMaxResolution"
    ]

    COORD_LABELS = [
        "rlnCoordinateX",
        "rlnCoordinateY",
        "rlnMicrographName",
        # extra labels below
        "rlnAutopickFigureOfMerit",
        "rlnClassNumber",
        "rlnAnglePsi"
    ]

    def __init__(self, **kwargs):
        """ Forward all keyword arguments to ReaderBase. """
        ReaderBase.__init__(self, **kwargs)

    def readSetOfParticles(self, starFile, partSet, **kwargs):
        """ Convert a star file into a set of particles.

        Params:
            starFile: the filename of the star file
            partsSet: output particles set

        Keyword Arguments:
            blockName: The name of the data block (default particles)
            alignType: alignment type
            removeDisabled: Remove disabled items
            preprocessImageRow / postprocessImageRow: optional
                callables(particle, row) run before / after each row is
                converted.
            readAcquisition: if True (default) the acquisition is created
                from the first optics group; if False the acquisition of
                partSet is written into the optics groups instead.
            magnification: used when readAcquisition is True
                (default 10000).
            pixelSize: overrides the pixel size when readAcquisition is
                False.
            extraLabels: additional labels to read.
        """
        self._preprocessImageRow = kwargs.get('preprocessImageRow', None)
        self._alignType = kwargs.get('alignType', ALIGN_NONE)
        self._postprocessImageRow = kwargs.get('postprocessImageRow', None)

        # setParticleTransform() replaces itself on the instance with the
        # implementation chosen for the first row. Drop any override left
        # by a previous read so the choice (and the creation of the
        # Transform on the new template particle) is made again.
        self.__dict__.pop('setParticleTransform', None)

        self._optics = OpticsGroups.fromStar(starFile)

        self._pixelSize = getattr(self._optics.first(),
                                  'rlnImagePixelSize', 1.0)
        # Computed from the pixel size found in the star file; it is not
        # recomputed if 'pixelSize' overrides self._pixelSize below.
        self._invPixelSize = 1. / self._pixelSize

        partsReader = Table.Reader(starFile, tableName='particles')

        firstRow = partsReader.getRow()
        self._setClassId = hasattr(firstRow, 'rlnClassNumber')
        # Only the first three labels of each list are mandatory
        self._setCtf = partsReader.hasAllColumns(self.CTF_LABELS[:3])
        self._setCoord = partsReader.hasAllColumns(self.COORD_LABELS[:3])

        # A single Particle object is reused as template for every row
        particle = Particle()

        if self._setCtf:
            particle.setCTF(CTFModel())

        self._setAcq = kwargs.get("readAcquisition", True)
        if self._setAcq:
            acq = Acquisition()
            self.rowToAcquisition(self._optics.first(), acq)
            acq.setMagnification(kwargs.get('magnification', 10000))
            partSet.setAcquisition(acq)
        else:
            # readAcquisition=False ONLY during import particles
            # overwrite pixel size and optics
            self._pixelSize = kwargs.get('pixelSize', self._pixelSize)
            acq = partSet.getAcquisition()
            self._optics.updateAll(
                rlnVoltage=acq.getVoltage(),
                rlnSphericalAberration=acq.getSphericalAberration(),
                rlnAmplitudeContrast=acq.getAmplitudeContrast(),
                rlnImagePixelSize=self._pixelSize
            )

        extraLabels = kwargs.get('extraLabels', []) + PARTICLE_EXTRA_LABELS
        self.createExtraLabels(particle, firstRow, extraLabels)

        self._rowToPart(firstRow, particle)
        partSet.setSamplingRate(self._pixelSize)
        self._optics.toImages(partSet)
        partSet.append(particle)

        for row in partsReader:
            self._rowToPart(row, particle)
            partSet.append(particle)

        partSet.setHasCTF(self._setCtf)
        partSet.setAlignment(self._alignType)

    def _rowToPart(self, row, particle):
        """ Fill the (reused) particle object from one star file row. """
        particle.setObjId(getattr(row, 'rlnImageId', None))

        if self._preprocessImageRow:
            self._preprocessImageRow(particle, row)

        # Decompose Relion filename
        index, filename = relionToLocation(row.rlnImageName)
        particle.setLocation(index, filename)

        if self._setClassId:
            particle.setClassId(row.rlnClassNumber)

        if self._setCtf:
            self.rowToCtf(row, particle.getCTF())

        self.setParticleTransform(particle, row)
        self.setExtraLabels(particle, row)

        # TODO: coord extra labels, partId, micId,
        if self._setCoord:
            coord = self.rowToCoord(row)
            particle.setCoordinate(coord)

        if self._postprocessImageRow:
            self._postprocessImageRow(particle, row)

    @staticmethod
    def rowToCoord(row):
        """ Create a Coordinate from the row. """
        coord = Coordinate()
        coord.setPosition(row.rlnCoordinateX, row.rlnCoordinateY)
        coord.setMicName(row.rlnMicrographName)

        return coord

    @staticmethod
    def rowToCtf(row, ctf):
        """ Fill the given CTFModel from the row.

        Defocus values are mandatory; resolution and fit quality default
        to 0 when missing; phase shift and PSD file are set only when the
        row has them.
        """
        ctf.setDefocusU(row.rlnDefocusU)
        ctf.setDefocusV(row.rlnDefocusV)
        ctf.setDefocusAngle(row.rlnDefocusAngle)
        ctf.setResolution(row.get('rlnCtfMaxResolution', 0))
        ctf.setFitQuality(row.get('rlnCtfFigureOfMerit', 0))

        if hasattr(row, 'rlnPhaseShift'):
            ctf.setPhaseShift(row.rlnPhaseShift)
        ctf.standardize()

        if hasattr(row, 'rlnCtfImage'):
            ctf.setPsdFile(row.rlnCtfImage)

    @staticmethod
    def rowToAcquisition(optics, acq):
        """ Copy amplitude contrast, Cs and voltage from optics to acq. """
        acq.setAmplitudeContrast(optics.rlnAmplitudeContrast)
        acq.setSphericalAberration(optics.rlnSphericalAberration)
        acq.setVoltage(optics.rlnVoltage)

    def setParticleTransform(self, particle, row):
        """ Set the transform values from the row.

        On the first call this picks the implementation matching
        self._alignType (or the "no transform" one if the alignment is
        ALIGN_NONE or the row has no alignment columns), stores it on the
        instance under this same name so later calls go straight to it,
        and then invokes it.

        Raises TypeError for an alignment type other than ALIGN_NONE,
        ALIGN_2D or ALIGN_PROJ.
        """
        if ((self._alignType == ALIGN_NONE) or
                not row.hasAnyColumn(self.ALIGNMENT_LABELS)):
            self.setParticleTransform = self.__setParticleTransformNone
        else:
            # Ensure the Transform object exists
            self._angles = np.zeros(3)
            self._shifts = np.zeros(3)
            particle.setTransform(Transform())

            if self._alignType == ALIGN_2D:
                self.setParticleTransform = self.__setParticleTransform2D
            elif self._alignType == ALIGN_PROJ:
                self.setParticleTransform = self.__setParticleTransformProj
            else:
                raise TypeError("Unexpected alignment type: %s"
                                % self._alignType)

        # Call again the modified function
        self.setParticleTransform(particle, row)

    def __setParticleTransformNone(self, particle, row):
        """ Clear the particle transform. """
        particle.setTransform(None)

    def __setParticleTransform2D(self, particle, row):
        """ Set a 2D transform (X/Y shifts and psi) from the row.

        Missing labels count as 0; shifts are multiplied by
        self._invPixelSize.
        """
        angles = self._angles
        shifts = self._shifts
        ips = self._invPixelSize

        def _get(label):
            return float(getattr(row, label, 0.))

        shifts[0] = _get('rlnOriginXAngst') * ips
        shifts[1] = _get('rlnOriginYAngst') * ips
        angles[2] = -_get('rlnAnglePsi')
        radAngles = -np.deg2rad(angles)
        M = tfs.euler_matrix(radAngles[0], radAngles[1], radAngles[2], 'szyz')
        M[:3, 3] = shifts[:3]
        particle.getTransform().setMatrix(M)

    def __setParticleTransformProj(self, particle, row):
        """ Set a projection transform (3 shifts, 3 angles) from the row.

        Missing labels count as 0; shifts are multiplied by
        self._invPixelSize. The stored matrix is the inverse of the one
        built from the negated angles and negated shifts.
        """
        angles = self._angles
        shifts = self._shifts
        ips = self._invPixelSize

        def _get(label):
            return float(getattr(row, label, 0.))

        shifts[0] = _get('rlnOriginXAngst') * ips
        shifts[1] = _get('rlnOriginYAngst') * ips
        shifts[2] = _get('rlnOriginZAngst') * ips

        angles[0] = _get('rlnAngleRot')
        angles[1] = _get('rlnAngleTilt')
        angles[2] = _get('rlnAnglePsi')

        radAngles = -np.deg2rad(angles)

        # TODO: jmrt: Maybe we should test performance and consider if keeping
        # TODO: the matrix and not creating one everytime will make things faster
        M = tfs.euler_matrix(radAngles[0], radAngles[1], radAngles[2], 'szyz')
        M[:3, 3] = -shifts[:3]
        M = np.linalg.inv(M)
        particle.getTransform().setMatrix(M)