# Source code for relion.viewers.viewer_ctfrefine

# ******************************************************************************
# *
# * Authors:    Roberto Marabini       (roberto@cnb.csic.es) [1]
# *             J.M. De la Rosa Trevin (delarosatrevin@scilifelab.se) [2]
# *
# * [1] Unidad de  Bioinformatica of Centro Nacional de Biotecnologia , CSIC
# * [2] SciLifeLab, Stockholm University
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 3 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307  USA
# *
# *  All comments concerning this program package may be sent to the
# *  e-mail address 'scipion@cnb.csic.es'
# *
# ******************************************************************************

import sys
import matplotlib as mpl
import numpy as np

from pwem.viewers.plotter import plt
from pwem.emlib.image import ImageHandler
from pyworkflow.viewer import ProtocolViewer

from .viewer_base import *
from ..objects import CtfRefineGlobalInfo
from ..protocols import ProtRelionCtfRefinement


class ProtCtfRefineViewer(ProtocolViewer):
    """ Viewer for Relion CTF refine results.

    Shows the images produced by the enabled estimations (anisotropic
    magnification, beam tilt, trefoil, 4th-order aberrations), an
    interactive per-micrograph defocus plot and the refined particles.
    """
    _label = 'ctf refine viewer'
    _environments = [DESKTOP_TKINTER, WEB_DJANGO]
    _targets = [ProtRelionCtfRefinement]

    def __init__(self, **kwargs):
        ProtocolViewer.__init__(self, **kwargs)
        self.protocol._initialize()
        self._micInfoList = None
        self.xMax = None
        self.yMax = None
        # Position (in self._micInfoList) of the micrograph currently shown
        # in the defocus plot, and of the one shown before it.
        self._currentMicIndex = 0
        self._oldCurrentMicIndex = 0
        self._currentMicId = 1
        # Red marker for the current micrograph on the stdev plot; it is
        # created lazily by show() and moved on later calls.
        self._stdevHighlight = None
        self._loadAnalyzeInfo()

    def _defineParams(self, form):
        """Build the "Results" form section.

        Each display entry is only shown (via its ``condition`` string) when
        the matching estimation flag was enabled in the protocol.
        """
        self._env = os.environ.copy()
        showBeamTilt = self.protocol.doBeamtiltEstimation.get()
        showTrefoil = self.protocol.doEstimateTrefoil.get()
        showTetrafoil = self.protocol.doEstimate4thOrder.get()
        showDefocus = self.protocol.doCtfFitting.get()
        showAnisoMag = self.protocol.estimateAnisoMag.get()

        form.addSection(label="Results")
        form.addParam('useMatplotlib', params.BooleanParam, default=True,
                      label='Use matplotlib for display',
                      condition="not {}".format(showDefocus),
                      help='If False, images will be displayed with ImageJ')
        form.addParam('displayAnisoMag', params.LabelParam,
                      label="Show X/Y mag. anisotropy estimation",
                      condition="{}".format(showAnisoMag),
                      help="Display four images (X and Y): (1) phase "
                           "differences from which this estimate was "
                           "derived and\n(2) the model fitted through it.")
        # The key/shift pairs below mirror the table used in press().
        form.addParam('displayDefocus', params.LabelParam,
                      label="Show defocus estimation",
                      condition="{}".format(showDefocus),
                      help="Display the defocus estimation.\n "
                           "Plot defocus difference (Angstroms) vs "
                           "position in micrograph\n"
                           "You may move between micrographs by using: \n\n"
                           "Left/Right keys (move -1/+1 micrograph)\n"
                           "Up/Down keys (move +10/-10 micrographs)\n"
                           "Page_up/Page_down keys (move -100/+100 "
                           "micrographs)\n"
                           "Home/End keys (move -1000/+1000 micrographs)")
        form.addParam('displayBeamTilt', params.LabelParam,
                      label="Show beam tilt estimation",
                      condition="{} and not {}".format(showBeamTilt,
                                                       showTrefoil),
                      help="Display two images: (1) phase differences from "
                           "which this estimate was derived and\n"
                           "(2) the model fitted through it.")
        form.addParam('displayTrefoil', params.LabelParam,
                      label="Show beam tilt and 3-fold astigmatism "
                            "estimation",
                      condition="{}".format(showTrefoil),
                      help="Display two images: (1) phase differences from "
                           "which this estimate was derived and\n"
                           "(2) the model fitted through it.")
        form.addParam('displayTetrafoil', params.LabelParam,
                      label="Show 4-fold aberrations estimation",
                      condition="{}".format(showTetrafoil),
                      help="Display two images: (1) phase differences from "
                           "which this estimate was derived and\n"
                           "(2) the model fitted through it.")
        form.addParam('displayParticles', params.LabelParam,
                      label="Display particles",
                      help="See the particles with the new CTF "
                           "values")

    def _getVisualizeDict(self):
        """Map each form display parameter to its handler."""
        return {
            'displayAnisoMag': self._displayAnisoMag,
            'displayDefocus': self._displayDefocus,
            'displayBeamTilt': self._displayBeamTilt,
            'displayTrefoil': self._displayTrefoil,
            'displayTetrafoil': self._displayTetrafoil,
            'displayParticles': self._displayParticles
        }

    def _displayAnisoMag(self, param=None):
        """Show observed and fitted X/Y magnification images (group 1)."""
        obs_x = self.protocol._getFileName("mag_obs_x", og=1)
        obs_y = self.protocol._getFileName("mag_obs_y", og=1)
        fit_x = self.protocol._getFileName("mag_fit_x", og=1)
        fit_y = self.protocol._getFileName("mag_fit_y", og=1)
        return self._showImages(obs_x, obs_y, fit_x, fit_y,
                                title="Anisotropy magnification")

    def _displayDefocus(self, e=None):
        """Show matplotlib with defocus values.

        Opens a two-panel figure: defocus stdev per micrograph (clickable,
        see onClick) and the defocus difference of the current micrograph
        (keyboard-navigable, see press).
        """
        micInfo = self._micInfoList[self._currentMicIndex]
        self._currentMicId = micInfo.micId.get()

        # Free the left/right arrows from matplotlib's default toolbar
        # back/forward bindings so they can navigate between micrographs.
        # Each key is handled separately: if one was already removed (e.g.
        # the viewer was opened before), the other is still removed.
        for keymap, key in (('keymap.back', 'left'),
                            ('keymap.forward', 'right')):
            try:
                mpl.rcParams[keymap].remove(key)
            except (KeyError, ValueError):
                pass

        self.plotter = EmPlotter(windowTitle="CTF Refinement", x=1, y=2)
        self._plotDefocusStdev()
        self.fig = self.plotter.getFigure()
        self.ax2 = self.plotter.createSubPlot(
            self._getTitle(micInfo), "Mic-Xdim", "Mic-Ydim", xpos=1, ypos=2)
        # call self.press after pressing any key
        self.fig.canvas.mpl_connect('key_press_event', self.press)
        self._maximizeWindow()
        self.show()

    def _maximizeWindow(self):
        """Best-effort maximization of the current figure window."""
        # mpl.get_backend() returns names such as 'QtAgg', 'Qt5Agg',
        # 'TkAgg' or 'WXAgg', so match on the lower-cased prefix instead of
        # comparing against an exact string.
        backend = mpl.get_backend().lower()
        manager = plt.get_current_fig_manager()
        try:
            if backend.startswith('qt'):
                manager.window.showMaximized()
            elif backend.startswith('tk'):
                manager.resize(*manager.window.maxsize())
            elif backend.startswith('wx'):
                manager.frame.Maximize(True)
        except AttributeError:
            # Maximizing is cosmetic; skip managers lacking these hooks.
            pass

    def _showImages(self, *imgs, **kwargs):
        """Display image files with matplotlib or, if disabled, DataView."""
        title = kwargs.get('title', '')
        if self.useMatplotlib.get():
            return self._showImagesMatplotlib(title, *imgs)
        return [DataView(img) for img in imgs]

    def _showImagesMatplotlib(self, title, *imgs):
        """Plot up to four images in a 2x1 or 2x2 grid with a jet colormap."""
        ih = ImageHandler()
        xdim = 2 if len(imgs) > 2 else 1
        ydim = 2
        plotter = EmPlotter(windowTitle=title, x=xdim, y=ydim,
                            figsize=(8, 6))
        positions = [(1, 1), (1, 2), (2, 1), (2, 2)]
        for (x, y), imgFn in zip(positions, imgs):
            ax = plotter.createSubPlot("", "x", "y", x, y)
            img = ih.read(imgFn)
            ax.imshow(img.getData(), cmap='jet')
        return [plotter]

    def _displayBeamTilt(self, param=None):
        """Show observed and fitted beam-tilt phase images (group 1)."""
        beamtilt_obs = self.protocol._getFileName("beamtilt_obs", og=1)
        beamtilt_fit = self.protocol._getFileName("beamtilt_fit", og=1)
        return self._showImages(beamtilt_obs, beamtilt_fit,
                                title="Beam Tilt")

    def _displayTrefoil(self, param=None):
        """Show observed phase image next to the trefoil fit (group 1)."""
        # The observed image is the same "beamtilt_obs" file as for beam
        # tilt; only the fit differs.
        trefoil_obs = self.protocol._getFileName("beamtilt_obs", og=1)
        trefoil_fit = self.protocol._getFileName("trefoil_fit", og=1)
        return self._showImages(trefoil_obs, trefoil_fit, title="Trefoil")

    def _displayTetrafoil(self, param=None):
        """Show observed and fitted 4th-order aberration images (group 1)."""
        tetrafoil_obs = self.protocol._getFileName("tetrafoil_obs", og=1)
        tetrafoil_fit = self.protocol._getFileName("tetrafoil_fit", og=1)
        return self._showImages(tetrafoil_obs, tetrafoil_fit,
                                title="Tetrafoil")

    def createScipionPartView(self, filename):
        """Build an ObjectView of the output particles with CTF columns."""
        inputParticlesId = self.protocol.inputParticles.get().strId()

        labels = 'enabled id _size _filename '
        labels += ' _ctfModel._defocusU _ctfModel._defocusV '

        if self.protocol.doBeamtiltEstimation.get():
            labels += ' _rlnBeamTiltX _rlnBeamTiltY'

        viewParams = {showj.ORDER: labels,
                      showj.VISIBLE: labels,
                      showj.MODE: showj.MODE_MD,
                      showj.RENDER: '_filename',
                      'labels': 'id',
                      }
        return ObjectView(self._project,
                          self.protocol.strId(),
                          filename,
                          other=inputParticlesId,
                          env=self._env,
                          viewParams=viewParams)

    def _displayParticles(self, param=None):
        """Return a one-element list with the output particles view."""
        fn = self.protocol.outputParticles.getFileName()
        return [self.createScipionPartView(fn)]

    # ------------------- UTILS functions -------------------------
    def _loadAnalyzeInfo(self):
        """Load per-micrograph CTF refine info (only on the first call).

        Creates the "ctf_sqlite" file through the protocol if it does not
        exist yet, then caches cloned micrograph entries, the max X/Y
        extent and a micId -> list-position dictionary.
        """
        if self._micInfoList is None:
            ctfInfoFn = self.protocol._getFileName("ctf_sqlite")
            if not os.path.exists(ctfInfoFn):
                ctfInfo = self.protocol.createGlobalInfo(ctfInfoFn)
            else:
                ctfInfo = CtfRefineGlobalInfo(ctfInfoFn)
            # clone so the entries remain usable after the set is closed
            self._micInfoList = [mi.clone() for mi in ctfInfo]
            self.xMax, self.yMax = ctfInfo.getMaxXY()
            # NOTE(review): the stored object is closed right below.
            self.ctfInfoMapper = ctfInfo
            ctfInfo.close()
            self.len_micInfoList = len(self._micInfoList)
            micList = [mi.micId.get() for mi in self._micInfoList]
            # NOTE: this dict could presumably be replaced by a query on the
            # mapper/database.
            self.micDict = {k: v for v, k in enumerate(micList)}

    def onClick(self, event):
        """Jump to the micrograph closest to a click on the stdev plot."""
        # The handler is connected to the whole canvas, so it also receives
        # clicks on the defocus plot (whose axes are in pixels, not
        # micrograph ids) and outside any axes (xdata/ydata are None).
        if (event.inaxes is not self.ax1 or event.xdata is None
                or event.ydata is None):
            return
        ix = int(round(event.xdata))
        iy = event.ydata
        if ix > self.maxMicId:
            return
        # Ids may have gaps: advance to the next existing micrograph id.
        # This terminates because maxMicId itself is a key of micDict.
        while ix not in self.micDict:
            ix += 1
        iix = self.micDict[ix]
        # Search for the closest point among the neighbouring micrographs
        # (bounds are list positions, not micrograph ids).
        start = max(0, iix - 40)
        end = min(iix + 40, len(self.x))
        dist = np.sqrt((np.array(self.x[start:end]) - ix) ** 2 +
                       (np.array(self.y[start:end]) - iy) ** 2)
        self._currentMicId = self.x[start + int(np.argmin(dist))]
        self._oldCurrentMicIndex = self._currentMicIndex
        self._currentMicIndex = self.micDict[self._currentMicId]
        self.show()

    def _plotDefocusStdev(self, e=None):
        """Draw the defocus stdev per micrograph on the first panel."""
        self.fig = self.plotter.getFigure()
        self.ax1 = self.plotter.createSubPlot("Defocus stdev per Micrograph\n"
                                              "Click on any point to get the "
                                              "corresponding micrograph\n "
                                              "in the defocus plot",
                                              "# Micrograph", "stdev",
                                              xpos=1, ypos=1)
        self.fig.canvas.mpl_connect('button_press_event', self.onClick)
        self.ax1.grid(True)
        self.x = [mi.micId.get() for mi in self._micInfoList]
        self.y = [mi.stdev.get() for mi in self._micInfoList]
        self.maxMicId = max(self.x)
        self.ax1.scatter(self.x, self.y, s=50, marker='o', c='blue')
        # New axes: the current-micrograph marker must be recreated.
        self._stdevHighlight = None

    def _getTitle(self, micInfo):
        """Title of the defocus plot: navigation hint + micrograph name/id."""
        return ("Use arrows or Page up/Down or Home/End to navigate.\n"
                "Mic = %s (%d)\nColorBar indicates defocus difference"
                % (micInfo.micName.get()[-40:], micInfo.micId.get()))

    def press(self, event):
        """ Change the currently shown micrograph when a key is pressed
        (increment/decrement); 'q' closes all figures. """
        sys.stdout.flush()
        if event.key == 'q':
            plt.close('all')
            return

        shiftDict = {
            'left': -1,
            'right': 1,
            'up': 10,
            'down': -10,
            'pageup': -100,
            'pagedown': 100,
            'home': -1000,
            'end': 1000
        }
        # if pressed key is not left, up, etc, do nothing
        shift = shiftDict.get(event.key)
        if shift is None:
            return
        # Clamp the new index between the first and last micrograph
        newIndex = self._currentMicIndex + shift
        self._oldCurrentMicIndex = self._currentMicIndex
        self._currentMicIndex = min(max(0, newIndex),
                                    self.len_micInfoList - 1)
        self.show()

    def show(self, event=None):
        """ Draw plot: highlight the current micrograph on the stdev plot
        and redraw its defocus differences. """
        # stdev plot: a single red marker is moved instead of adding new
        # scatter artists on every redraw.
        curX = self.x[self._currentMicIndex]
        curY = self.y[self._currentMicIndex]
        if self._stdevHighlight is None:
            self._stdevHighlight = self.ax1.scatter(
                [curX], [curY], s=50, marker='o', c='red', zorder=3)
        else:
            self._stdevHighlight.set_offsets([[curX, curY]])

        # defocus plot
        micInfo = self._micInfoList[self._currentMicIndex]
        if event is None:
            # clear the plot, otherwise points of the previously shown
            # micrograph are not removed
            self.ax2.clear()
        self.ax2.margins(0.05)
        self.ax2.set_title(self._getTitle(micInfo))
        newFontSize = self.plotter.plot_axis_fontsize + 2
        self.ax2.set_xlabel("Mic Xdim (px)", fontsize=newFontSize)
        self.ax2.set_ylabel("Mic Ydim (px)", fontsize=newFontSize)
        # fixed limits (global max over all micrographs) so that plots of
        # different micrographs are directly comparable
        self.ax2.set_xlim(0, self.xMax)
        self.ax2.set_ylim(0, self.yMax)
        self.ax2.grid(True)
        # without subplots_adjust the window shrinks after each redraw
        plt.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1)
        sc2 = self.ax2.scatter(micInfo.x, micInfo.y, c=micInfo.defocusDiff,
                               s=100, marker='o')
        # NOTE(review): getColorBar may add a new colorbar on every redraw —
        # confirm against EmPlotter.
        self.plotter.getColorBar(sc2)
        self.plotter.show()