# **************************************************************************
# *
# * Authors: Josue Gomez Blanco (josue.gomez-blanco@mcgill.ca)
# *
# * Unidad de Bioinformatica of Centro Nacional de Biotecnologia , CSIC
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 3 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307 USA
# *
# * All comments concerning this program package may be sent to the
# * e-mail address 'scipion@cnb.csic.es'
# *
# **************************************************************************
import logging
from pwem.convert.trigonometry import FibonacciSphere
logger = logging.getLogger(__name__)
from math import radians, degrees
import numpy as np
import matplotlib.cm as cm
from scipy.ndimage.filters import gaussian_filter
from pwem.convert.transformations import euler_from_matrix
from pyworkflow.gui.plotter import Plotter, plt
import pwem.emlib.metadata as md
import numbers
from math import atan2, sqrt, pi
from pwem import emlib
# Plot types accepted by the ``type`` argument of the angular distribution
# plotting methods (see EmPlotter.plotAngularDistributionBase).
PLOT_EULER_ANGLES = 1  # 2D histogram (heat map) of rot vs. tilt
PLOT_PROJ_ANGLES = 2  # 2D polar plot of the weighted (grouped) rot/tilt pairs
PLOT_PROJ_DIR = 3  # 3D scatter plot of the projection directions
class EmPlotter(Plotter):
    """ Plotter specialization offering several ready-made plots
    (angular distributions, histograms, scatter plots, matrices, ...). """

    def __init__(self, x=1, y=1, mainTitle="", **kwargs):
        # Nothing is customized here: every argument is forwarded as
        # received to the base Plotter constructor.
        Plotter.__init__(self, x, y, mainTitle, **kwargs)
def plotAngularDistribution(self, title, rot,
                            tilt, weight=None, max_p=40,
                            min_p=5, color='blue', colormap=None, subtitle=None):
    """ Create a special type of subplot, representing the angular
    distribution in 2d of weighted projections.

    :param title: title of the subplot.
    :param rot: rot angle of every projection. It is used as the angular
        coordinate of a matplotlib polar axis (which works in radians).
    :param tilt: tilt angle of every projection, used as the radial coordinate.
    :param weight: optional sequence (list, tuple, numpy array, ...) with one
        weight per projection. When it is None or empty a non weighted plot
        is created.
    :param max_p: value assigned to the projection having the maximum weight.
    :param min_p: value assigned to a projection having weight 0.
    :param color: color of the points when no colormap is given.
    :param colormap: optional matplotlib colormap. When given, all the points
        have the same size and are colored according to their scaled weight
        (a color bar is added). Otherwise the scaled weight is the point size.
    :param subtitle: subtitle for the weighted plot. By default, it shows the
        minimum and maximum weights.
    :return: the axis returned by createSubPlot.
    """
    # "if weight:" would raise with multi-element numpy arrays, so the
    # presence of weights is checked explicitly.
    hasWeights = weight is not None and len(weight) > 0

    if not hasWeights:
        a = self.createSubPlot(title, 'Non weighted plot', '', projection='polar')
        a.scatter(rot, tilt, s=10, c=color, marker='.')
        return a

    max_w = max(weight)
    min_w = 0  # weights are scaled from 0, not from the actual minimum

    if subtitle is None:
        subtitle = 'Min weight=%.2f, Max weight=%.2f' % (min_w, max_w)

    a = self.createSubPlot(title, subtitle, '', projection='polar')

    # Label the radial axis next to its tick labels
    label_position = a.get_rlabel_position()
    a.text(np.radians(label_position + 10), a.get_rmax() / 2., 'Tilt',
           rotation=label_position, ha='center', va='center')

    # Map every weight into the [min_p, max_p] range. The small epsilon
    # avoids a division by zero when all the weights are 0.
    # zip keeps one value per (rot, tilt, weight) triplet.
    weightRange = max_w - min_w + 0.001
    pointSizes = [int((w - min_w) / weightRange * (max_p - min_p) + min_p)
                  for _rot, _tilt, w in zip(rot, tilt, weight)]

    if colormap:
        sc = a.scatter(rot, tilt, s=20, c=pointSizes, cmap=colormap, marker='.')
        plt.colorbar(sc)
    else:
        a.scatter(rot, tilt, s=pointSizes, c=color, marker='.')

    return a
def scatter3DPlot(self, x, y, z, title="3D scatter plot", drawsphere=True, markerSize=1, colormap=cm.jet, subtitle=None):
    """ Create a 3D scatter subplot.

    :param x: x coordinate of every point.
    :param y: y coordinate of every point.
    :param z: z coordinate of every point.
    :param title: title of the subplot.
    :param drawsphere: if True, a translucent unit sphere and the three
        coordinate axes (x red, y green, z blue) are drawn as a reference.
    :param markerSize: despite its name, value(s) mapped to the color of the
        points through the colormap (the size is fixed). Either one value per
        point or a single number applied to all of them.
    :param colormap: matplotlib colormap used to color the points.
    :param subtitle: subtitle of the subplot.
    :return: the 3D axis returned by createSubPlot.
    """
    ax = self.createSubPlot(title, xlabel=None, ylabel=None, projection='3d', subtitle=subtitle)
    ax.set_box_aspect(aspect=(1, 1, 1))

    # A single number is not a valid "c" value for scatter: expand it to
    # one value per point.
    if isinstance(markerSize, numbers.Number):
        markerSize = np.full(np.shape(x), markerSize)

    # NOTE: the 4th positional argument of Axes3D.scatter is "zdir", so the
    # color values must be passed only through the "c" keyword.
    sc = ax.scatter(x, y, z, c=markerSize, s=60, cmap=colormap, alpha=1)
    plt.colorbar(sc)

    ax.set_xlabel('X Label')
    ax.set_ylabel('Y Label')
    ax.set_zlabel('Z Label')

    if drawsphere:
        # Unit sphere sampled on a 21 x 11 grid of (u, v) angles
        u, v = np.mgrid[0:2 * np.pi:21j, 0:np.pi:11j]
        x1 = np.cos(u) * np.sin(v)
        y1 = np.sin(u) * np.sin(v)
        z1 = np.cos(v)
        ax.plot_wireframe(x1, y1, z1, color="black", alpha=0.1)
        ax.plot_surface(x1, y1, z1, color="red", alpha=0.05)

        # 3d axes, drawn slightly longer than the sphere radius
        ax.plot([0, 1.25], [0, 0], [0, 0], color="red")
        ax.plot([0, 0], [0, 1.25], [0, 0], color="green")
        ax.plot([0, 0], [0, 0], [0, 1.25], color="blue")

    return ax
def plotAngularDistribution3D(self, title, x, y, z, weights, subtitle, colormap=cm.jet):
    """ Plot an angular distribution as a 3D scatter plot.

    Thin wrapper around scatter3DPlot where the weights are passed as its
    markerSize argument.

    :return: whatever scatter3DPlot returns.
    """
    scatterOptions = dict(title=title,
                          markerSize=weights,
                          colormap=colormap,
                          subtitle=subtitle)
    return self.scatter3DPlot(x, y, z, **scatterOptions)
def plotAngularDistributionHistogram(self, title, data, eulerAnglesGetterCallback, colormap=cm.jet, subtitle=None):
    """ Create a subplot with a smoothed 2D histogram (heat map) of the
    angular distribution of projections.

    :param title: title of the subplot.
    :param data: iterable with the items to plot.
    :param eulerAnglesGetterCallback: callable receiving an item and
        returning its (rot, tilt, psi). Only rot and tilt are used.
    :param colormap: matplotlib colormap for the heat map.
    :param subtitle: subtitle of the subplot.
    :return: the object returned by imshow (not the axis).
    """
    rotValues, tiltValues = [], []
    for rot, tilt, _psi in map(eulerAnglesGetterCallback, data):
        rotValues.append(rot)
        tiltValues.append(tilt)

    # Two extra points at opposite corners, so the histogram always
    # spans at least rot in [-180, 180] and tilt in [0, 180].
    rotValues.extend((-180, 180))
    tiltValues.extend((0, 180))

    counts, rotEdges, tiltEdges = np.histogram2d(rotValues, tiltValues, bins=1000)

    # Smooth the counts, sigma being 1/20 of the smallest covered range
    rotSpan = max(rotEdges) - min(rotEdges)
    tiltSpan = max(tiltEdges) - min(tiltEdges)
    smoothed = gaussian_filter(counts, sigma=min(rotSpan, tiltSpan) / 20)

    axis = self.createSubPlot(title, 'Angular distribution', '', subtitle=subtitle)
    # histogram2d puts the first variable (rot) along the rows: transpose it
    # so that rot is the horizontal axis of the image.
    mappable = axis.imshow(smoothed.T,
                           extent=[rotEdges[0], rotEdges[-1], tiltEdges[0], tiltEdges[-1]],
                           origin='lower', cmap=colormap, aspect='auto')
    axis.set_xlabel('Rotational angle')
    axis.set_ylabel('Tilt angle')
    plt.colorbar(mappable)

    return mappable
def plotAngularDistributionFromSet(self, mdSet, title, type=PLOT_PROJ_ANGLES, colormap=cm.jet, subtitle="", **kwargs):
    """ Read the values of the transformation matrix of every item
    and plot its histogram or its angular distribution.

    :param mdSet: Set with alignment information, available through item.getTransform()
    :param title: Title of the plot
    :param type: Type of plot 1=histogram, 2=2D polar, 3=3D
    :param colormap: matplotlib color map
    :param subtitle: subtitle of the plot
    :param kwargs: extra arguments passed to plotAngularDistributionBase
    :return: whatever plotAngularDistributionBase returns for the chosen type
    """

    def eulerAnglesGetter(item):
        """ Return the (rot, tilt, psi) of the item, in degrees. """
        matrix = item.getTransform().getRotationMatrix()
        # The angles are extracted from the inverse of the stored rotation
        matrixI = np.linalg.inv(matrix)
        rot, tilt, psi = euler_from_matrix(matrix=matrixI, axes='szyz')

        # Keep tilt non negative, flipping the sign of rot accordingly
        if tilt < 0:
            tilt = -tilt
            rot = -rot

        return degrees(rot), degrees(tilt), degrees(psi)

    # Return the created plot, as the rest of the plotting methods do
    # (previously the result was silently discarded).
    return self.plotAngularDistributionBase(mdSet, eulerAnglesGetter, title, type, colormap,
                                            subtitle=subtitle, **kwargs)
def plotAngularDistributionBase(self, data, eulerAnglesGetterCallback, title,
                                type=PLOT_PROJ_ANGLES, colormap=cm.jet, subtitle="", **kwargs):
    """ Plot the angular distribution of the items in data, either as a
    histogram, as a 2D polar plot or as a 3D plot.

    :param data: any particles iterator containing particles with alignment information. sqlites or star files , ...
    :param eulerAnglesGetterCallback: a callback to extract rot, tilt, psi IN DEGREES from each row, receiving the row/item.
    :param title: Title of the plot.
    :param type: 1 for histogram, 2 for polar plot, any other value for 3d plot
    :param colormap: matplotlib color map
    :param subtitle: subtitle of the plot
    :param kwargs: extra arguments, only forwarded to the polar plot
    :return: the result of the specific plotting method used
    """
    # Histogram: the items are handed over without being iterated here
    if type == PLOT_EULER_ANGLES:
        return self.plotAngularDistributionHistogram(title, data, eulerAnglesGetterCallback,
                                                     colormap=colormap, subtitle=subtitle)

    # The other plots need the rot and tilt of every item
    rots, tilts = [], []
    for rot, tilt, _psi in map(eulerAnglesGetterCallback, data):
        rots.append(rot)
        tilts.append(tilt)

    # Polar plot: group close directions; rot is converted to radians
    if type == PLOT_PROJ_ANGLES:
        rots, tilts, weights = self.weightEulerAngles(rots, tilts, rotInRadians=True)
        return self.plotAngularDistribution(title, rots, tilts, weight=weights,
                                            colormap=colormap, subtitle=subtitle, **kwargs)

    # 3D plot: accumulate every direction into a fibonacci sphere
    sphere = FibonacciSphere()
    for x, y, z in zip(*self._anglesToSphereCoords(rots, tilts)):
        sphere.append(x, y, z)
    sphere.cleanWeights()

    return self.plotAngularDistribution3D(title, sphere.sphX, sphere.sphY, sphere.sphZ,
                                          sphere.weights, subtitle, colormap=colormap)
def weightEulerAngles(self, rots, tilts, delta=3, rotInRadians=False):
    """ Group close (rot, tilt) pairs, counting how many fall in each group.

    A pair joins the first existing group whose representative differs at
    most delta in both rot and tilt; otherwise it starts a new group and
    becomes its representative.

    :param rots: list of rot angles (in deg)
    :param tilts: list of tilt angles (in deg)
    :param delta: maximum difference, in both rot and tilt, to consider
        two pairs as the same projection
    :param rotInRadians: if True the returned rots are converted to
        radians. Tilts are never converted.
    :return: reduced lists of rots, tilts and weights (number of pairs
        in each group)
    """
    representatives = []  # (rot, tilt) in degrees of each group
    groupedRots, groupedTilts, counts = [], [], []

    for rot, tilt in zip(rots, tilts):
        # Index of the first group close enough, if any
        groupIndex = next((i for i, (refRot, refTilt) in enumerate(representatives)
                           if abs(refRot - rot) <= delta
                           and abs(refTilt - tilt) <= delta), None)

        if groupIndex is not None:
            counts[groupIndex] += 1
            continue

        representatives.append((rot, tilt))
        groupedRots.append(np.radians(rot) if rotInRadians else rot)
        groupedTilts.append(tilt)
        counts.append(1)

    return groupedRots, groupedTilts, counts
def _anglesToSphereCoords(self, rots, tilts):
    """ Convert pairs of euler angles (rot and tilt) into direction
    vectors, using emlib.Euler_direction with psi=0.

    :return: three lists with the x, y and z components, respectively.
    """
    xCoords, yCoords, zCoords = [], [], []

    for rot, tilt in zip(rots, tilts):
        x, y, z = emlib.Euler_direction(rot, tilt, 0)
        for coords, value in ((xCoords, x), (yCoords, y), (zCoords, z)):
            coords.append(value)

    return xCoords, yCoords, zCoords
def plotAngularDistributionFromMd(self, mdFile, title, **kwargs):
    """ Read the values of rot, tilt and weights from
    the metadata and plot the angular distribution.
    ANGLES are in DEGREES
    In the metadata:
        rot: MDL_ANGLE_ROT
        tilt: MDL_ANGLE_TILT
        weight: MDL_WEIGHT

    If 'histogram' is one of the keyword arguments, a histogram is
    plotted (the weights and the rest of keyword arguments are not used).
    Otherwise a polar plot is created, receiving all the keyword arguments.
    """
    angMd = md.MetaData(mdFile)

    if 'histogram' not in kwargs:
        rots, tilts, weights = [], [], []
        for row in md.iterRows(angMd):
            # The polar plot needs rot in radians; tilt is kept in degrees
            rots.append(radians(row.getValue(md.MDL_ANGLE_ROT)))
            tilts.append(row.getValue(md.MDL_ANGLE_TILT))
            weights.append(row.getValue(md.MDL_WEIGHT))

        return self.plotAngularDistribution(title, rots, tilts, weights, **kwargs)

    def rowAngles(row):
        """ (rot, tilt, psi) of a row; psi is not read from the metadata. """
        rot = row.getValue(md.MDL_ANGLE_ROT)
        tilt = row.getValue(md.MDL_ANGLE_TILT)
        return rot, tilt, None

    class MdRows:
        """ Iterable over the rows of a metadata. """

        def __init__(self, mdObj):
            self.mdObj = mdObj

        def __iter__(self):
            yield from md.iterRows(self.mdObj)

    return self.plotAngularDistributionHistogram(title, MdRows(angMd), rowAngles)
def plotHist(self, yValues, nbins, color='blue', **kwargs):
    """ Draw a histogram of yValues using nbins bins.

    :param color: face color of the bars.
    :param kwargs: extra arguments forwarded to self.hist.
    """
    # yValues may be a generator (not indexable): turn it into a list first
    values = list(yValues)
    self.hist(values, nbins, facecolor=color, **kwargs)
def plotScatter(self, xValues, yValues, color='blue', **kwargs):
    """ Draw a scatter plot of yValues against xValues.

    :param color: color of the points (passed as 'c').
    :param kwargs: extra arguments forwarded to self.scatterP.
    """
    self.scatterP(xValues,
                  yValues,
                  c=color,
                  **kwargs)
def plotMatrix(self, img, matrix, vminData, vmaxData,
               cmap='jet',
               xticksLablesMajor=None,
               yticksLablesMajor=None,
               rotationX=90.,
               rotationY=0.,
               **kwargs):
    """ Show a matrix as an image.

    :param img: object providing imshow, where the matrix is drawn.
    :param matrix: values to show.
    :param vminData: passed to imshow as vmin.
    :param vmaxData: passed to imshow as vmax.
    :param cmap: colormap of the image.
    :param xticksLablesMajor: optional labels for the ticks of the x axis.
    :param yticksLablesMajor: optional labels for the ticks of the y axis.
    :param rotationX: rotation of the x tick labels.
    :param rotationY: rotation of the y tick labels.
    :param kwargs: extra arguments for imshow. 'interpolation' defaults
        to "none" when not given.
    :return: the object returned by imshow.
    """
    kwargs.setdefault('interpolation', "none")
    plot = img.imshow(matrix, cmap=cmap, vmin=vminData, vmax=vmaxData, **kwargs)

    # The ticks are set through pyplot (plt.xticks / plt.yticks), one
    # tick per label at positions 0..n-1.
    tickSettings = ((plt.xticks, xticksLablesMajor, rotationX),
                    (plt.yticks, yticksLablesMajor, rotationY))
    for setTicks, labels, angle in tickSettings:
        if labels is None:
            continue
        count = len(labels)
        setTicks(range(count), labels[:count], rotation=angle)

    return plot
def plotData(self, xValues, yValues, color='blue', **kwargs):
    """ Shortcut to plot yValues against xValues.

    :param xValues: list of values to show in x-axis
    :param yValues: list of values to show as values in y-axis
    :param color: color for the plot. It is passed as the third
        positional argument of self.plot.
    :param kwargs: keyword arguments forwarded to self.plot, such as
        marker, linestyle
    """
    plotArgs = (xValues, yValues, color)
    self.plot(*plotArgs, **kwargs)
def plotDataBar(self, xValues, yValues, width, color='blue', **kwargs):
    """ Shortcut to draw a bar plot.

    :param xValues: list of values to show in x-axis
    :param yValues: list of values to show as values in y-axis
    :param width: width of the bars
    :param color: color for the bars.
    :param kwargs: extra keyword arguments forwarded to self.bar
    """
    self.bar(xValues,
             yValues,
             width=width,
             color=color,
             **kwargs)
@classmethod
def createFromFile(cls, dbName, dbPreffix, plotType, columnsStr, colorsStr, linesStr,
                   markersStr, xcolumn, ylabel, xlabel, title, bins, orderColumn,
                   orderDirection):
    """ Create a plotter showing some columns of a table stored in a file.

    :param dbName: file with the data (sqlite or metadata, see PlotData).
    :param dbPreffix: prefix/name of the table inside the file.
    :param plotType: 'Plot' to draw lines; any other value draws a
        scatter plot. Ignored when bins is given.
    :param columnsStr: whitespace separated names of the columns to plot.
    :param colorsStr: whitespace separated colors, one per column.
    :param linesStr: whitespace separated line styles, one per column.
    :param markersStr: whitespace separated markers, one per column
        ('none' means no marker). Only used for plotType 'Plot'.
    :param xcolumn: column with the x values. If empty, the position
        (0, 1, 2, ...) of each value is used.
    :param ylabel: label of the y axis.
    :param xlabel: label of the x axis.
    :param title: title of the window and of the plot.
    :param bins: if not empty, a histogram with this number of bins is
        drawn for each column.
    :param orderColumn: column used to sort the data.
    :param orderDirection: direction of the sorting.
    :return: the Plotter created.
    """
    columns = columnsStr.split()
    colors = colorsStr.split()
    lines = linesStr.split()
    markers = markersStr.split()

    data = PlotData(dbName, dbPreffix, orderColumn, orderDirection)
    plotter = Plotter(windowTitle=title)
    ax = plotter.createSubPlot(title, xlabel, ylabel)
    xvalues = data.getColumnValues(xcolumn) if xcolumn else range(0, data.getSize())

    for i, col in enumerate(columns):
        yvalues = data.getColumnValues(col)
        color = colors[i]
        line = lines[i]
        # matplotlib legends ignore labels starting with an underscore:
        # remove it so these columns are listed too.
        colLabel = col[1:] if col.startswith("_") else col

        if bins:
            yvalues = data._removeInfinites(yvalues)
            ax.hist(yvalues, bins=int(bins), color=color, linestyle=line, label=colLabel)
        elif plotType == 'Plot':
            marker = markers[i] if markers[i] != 'none' else None
            ax.plot(xvalues, yvalues, color, marker=marker, linestyle=line, label=colLabel)
        else:
            # Use the cleaned label here as well; with the raw column name
            # the columns starting with "_" were missing from the legend.
            ax.scatter(xvalues, yvalues, c=color, label=colLabel, alpha=0.5)

    ax.legend()

    return plotter
class PlotData:
    """ Small wrapper around table data such as: sqlite or metadata
    files.

    Depending on the extension of the file, getColumnValues and getSize
    are bound to the set based (.db, .sqlite) or to the metadata based
    (any other file) implementation.
    """

    def __init__(self, fileName, tableName, orderColumn, orderDirection):
        """
        :param fileName: sqlite (.db or .sqlite) or metadata file.
        :param tableName: prefix of the tables (sqlite) or name of the
            block (metadata) to read.
        :param orderColumn: column used to sort the values.
        :param orderDirection: direction of the sorting (currently only
            used for sqlite files).
        """
        self._orderColumn = orderColumn
        self._orderDirection = orderDirection

        if fileName.endswith((".db", ".sqlite")):
            self._table = self._loadSet(fileName, tableName)
            self.getColumnValues = self._getValuesFromSet
            self.getSize = self._table.getSize
        else:  # assume a metadata file
            self._table = self._loadMd(fileName, tableName)
            self.getColumnValues = self._getValuesFromMd
            self.getSize = self._table.size

    def _loadSet(self, dbName, dbPreffix):
        """ Load the set object stored in a sqlite file. """
        from pyworkflow.mapper.sqlite import SqliteFlatDb
        db = SqliteFlatDb(dbName=dbName, tablePrefix=dbPreffix)

        if dbPreffix:
            setClassName = "SetOf%ss" % db.getSelfClassName()
        else:
            setClassName = db.getProperty('self')  # get the set class name

        # FIXME: Check why the import is here
        from pwem import Domain
        setObj = Domain.getObjects()[setClassName](filename=dbName, prefix=dbPreffix)
        return setObj

    def _getValuesFromSet(self, columnName):
        """ Values of a column, for every item of the (sorted) set. """
        return [self._getValue(obj, columnName)
                for obj in self._table.iterItems(orderBy=self._orderColumn,
                                                 direction=self._orderDirection)]

    @staticmethod
    def _removeInfinites(values):
        """ Return a list with the finite numbers found in values.

        Non numeric values, NaN and both positive and negative infinity
        are discarded (formerly negative infinity was kept).
        """
        inf = float("Inf")
        # Comparisons (rather than math.isfinite) also work with integers
        # that are too large to be converted to float. NaN fails both.
        return [value for value in values
                if isinstance(value, numbers.Number) and -inf < value < inf]

    def _loadMd(self, fileName, tableName):
        """ Load a metadata block, sorted by the order column. """
        label = md.str2Label(self._orderColumn)
        tableMd = md.MetaData('%s@%s' % (tableName, fileName))
        tableMd.sort(label)  # FIXME: use order direction
        return tableMd

    def _getValuesFromMd(self, columnName):
        """ Values of a column (label name), for every object of the metadata. """
        label = md.str2Label(columnName)
        return [self._table.getValue(label, objId) for objId in self._table]

    def _getValue(self, obj, column):
        """ Value of a column of an item; 'id' stands for the object id. """
        if column == 'id':
            return obj.getObjId()

        return obj.getNestedValue(column)
# Helper functions for the angular distribution plots.
# TODO: consider moving them to a more suitable module.
def magnitude(x, y, z):
    """Returns the magnitude (euclidean length) of the vector (x, y, z)."""
    squaredLength = sum(component * component for component in (x, y, z))
    return sqrt(squaredLength)
def to_spherical(x, y, z):
    """Converts a cartesian coordinate (x, y, z) into a spherical one (radius, theta, phi) in radians.

    theta (angle from the z axis) ranges from 0 to PI
    phi (angle in the xy plane, from the x axis) ranges from -PI to PI
    """
    # Length of the projection of the vector on the xy plane
    planarLength = sqrt(x * x + y * y)

    return (magnitude(x, y, z),
            atan2(planarLength, z),
            atan2(y, x))
def eulerAngles_to_2D(rot, tilt, psi):
    """Converts euler angles to their 2D representation for a polar plot.

    :return: the (theta, phi) pair, in radians, where phi is never negative.
    """
    dirX, dirY, dirZ = emlib.Euler_direction(rot, tilt, psi)

    # The components are passed permuted, as (y, z, x), to the conversion.
    # NOTE(review): the original author was unsure about this order,
    # (x, z, y) was considered as an alternative.
    _radius, theta, phi = to_spherical(dirY, dirZ, dirX)

    # Wrap phi into (-PI, PI]
    if phi > pi:
        phi = phi - 2 * pi

    # A negative phi is mirrored, moving theta half a turn
    if phi < 0:
        phi, theta = -phi, theta + pi

    return theta, phi