import logging
import os
from subprocess import Popen
logger = logging.getLogger(__name__)
from datetime import datetime
import sqlite3
import pyworkflow as pw
import metadataviewer
from metadataviewer.dao.model import IDAO
from metadataviewer.model import Table, Column, BoolRenderer, ImageRenderer, StrRenderer
from metadataviewer.model.renderers import ImageReader, ExternalProgram
from functools import lru_cache
from PIL import Image
import mrcfile
import numpy as np
# Names of the environment variables used to communicate with a running
# Scipion GUI (set by Scipion when it launches the metadata viewer).
SCIPION_OBJECT_ID = "SCIPION_OBJECT_ID"
SCIPION_PORT = "SCIPION_PORT"
class MRCImageReader(ImageReader):
    """ImageReader for MRC-family files (mrc, mrcs, st, ali, rec, em, ...).

    Paths may carry an 'N@' prefix selecting image N of a stack and/or a
    ':mrc' suffix forcing the file to be interpreted as a volume.
    """

    @classmethod
    def open(cls, path: str):
        """Open *path* and return a normalized 8-bit PIL image.

        :param path: image location, optionally 'index@file' and/or
                     suffixed with ':mrc' to force volume interpretation.
        :return: a PIL.Image.Image in grayscale (uint8).
        """
        isVol = path.endswith(":mrc")
        path = path.replace(":mrc", "")
        if "@" not in path:
            path = "1@" + path  # default to the first image in the file
        filePath = path.split('@')
        index = int(filePath[0])
        fileName = filePath[-1]
        mrcImg = cls.getMrcImage(fileName)
        if mrcImg.is_volume() or isVol:
            # For a volume, show the central slice along the first axis
            dim = mrcImg.data.shape
            imfloat = mrcImg.data[dim[0] // 2, :, :]
        elif mrcImg.is_image_stack():
            imfloat = mrcImg.data[index - 1]  # 1-based index in the path
        else:
            imfloat = mrcImg.data
        iMax = imfloat.max()
        iMin = imfloat.min()
        if iMax == iMin:
            # Constant image: avoid a division by zero when normalizing
            im255 = np.zeros(imfloat.shape, dtype=np.uint8)
        else:
            im255 = ((imfloat - iMin) / (iMax - iMin) * 255).astype(np.uint8)
        return Image.fromarray(im255)

    @classmethod
    @lru_cache
    def getMrcImage(cls, fileName):
        """Return a (cached) memory-mapped handle for *fileName*.

        Opened read-only: the viewer never writes, and mode 'r+' (the
        previous value) fails on files without write permission.
        """
        logger.info("Reading %s" % fileName)
        return mrcfile.mmap(fileName, mode='r')

    @classmethod
    def getCompatibleFileTypes(cls) -> list:
        """Return the file extensions this reader can handle."""
        return ['mrc', 'mrc:mrc', 'mrcs', 'em', 'rec', 'ali', 'st', 'mrcs:mrc',
                'mrcs:mrcs', 'mrc:mrcs']
class STKImageReader(ImageReader):
    """ImageReader for SPIDER stack (.stk) and volume (.vol) files.

    The file header is parsed by ``readHeader`` (defined elsewhere) and the
    resulting geometry is kept in class-level attributes between reads.
    """
    IMG_BYTES = None        # size in bytes of one image payload
    stk_handler = None      # binary file handle of the stack being read
    header_info = None      # dict produced by readHeader()
    HEADER_OFFSET = 1024
    FLOAT32_BYTES = 4
    # NOTE(review): TYPE is never assigned in this module, so the "stack"
    # branch of readImage only triggers if it is set externally -- confirm.
    TYPE = None

    @classmethod
    def open(cls, path):
        """Open an 'index@file' path and return the selected image.

        Returns None when the path carries no '@' index prefix (original
        behavior, kept as-is).
        """
        stk = path.split('@')
        if len(stk) > 1:
            return cls.read(stk[-1], int(stk[0]))

    @classmethod
    def read(cls, filename, id):
        """Read image number *id* (1-based) from *filename*.

        :param filename: stack file to read from
        :param id: 1-based image index within the stack
        :return: normalized 8-bit grayscale PIL image
        """
        # Use a context manager so the handle is closed when done
        # (previously the file was left open forever).
        with open(filename, "rb") as fileHandler:
            cls.stk_handler = fileHandler
            cls.header_info = cls.readHeader()
            cls.IMG_BYTES = cls.FLOAT32_BYTES * cls.header_info["n_columns"] ** 2
            image = cls.readImage(id - 1)
        iMax = image.max()
        iMin = image.min()
        if iMax == iMin:
            # Constant image: avoid a division by zero when normalizing
            image = np.zeros(image.shape, dtype='uint8')
        else:
            image = ((image - iMin) / (iMax - iMin) * 255).astype('uint8')
        return Image.fromarray(image)

    @classmethod
    def readNumpy(cls, start, end):
        """
        Read *end* bytes starting at byte *start* and decode them as a
        float32 Numpy array.
        :param start (int) --> Start byte
        :param end (int) --> Number of bytes to read
        :returns decoded bytes as Numpy array
        """
        return np.frombuffer(cls.readBinary(start, end), dtype=np.float32)

    @classmethod
    def readBinary(cls, start, end):
        """
        Read *end* bytes starting at byte *start*.
        (Despite its name, *end* is a byte count, not an absolute offset.)
        :param start (int) --> Start byte
        :param end (int) --> Number of bytes to read
        :returns the bytes read
        """
        cls.seek(start)
        return cls.stk_handler.read(end)

    @classmethod
    def readImage(cls, iid):
        """
        Read a given image in the stack according to its 0-based ID.
        :param iid (int) --> Image id to be read
        :returns Image as a square Numpy array
        """
        if cls.TYPE == "stack":
            # Each image in a stack carries its own header
            start = 2 * cls.header_info["offset"] + iid * (
                    cls.IMG_BYTES + cls.header_info["offset"])
        else:
            start = cls.header_info["offset"] + iid * cls.IMG_BYTES
        img_size = cls.header_info["n_columns"]
        return cls.readNumpy(start, cls.IMG_BYTES).reshape([img_size, img_size])

    @classmethod
    def seek(cls, pos):
        """
        Move the file pointer to a given position.
        :param pos (int) --> Byte to move the pointer to
        """
        cls.stk_handler.seek(pos)

    @classmethod
    def getCompatibleFileTypes(cls) -> list:
        """Return the file extensions this reader can handle."""
        return ['stk', 'vol']
# 'class_name' values (from a Scipion 'classes' table) that the viewer
# knows how to display; anything else is excluded from the column mapping.
ALLOWED_COLUMNS_TYPES = ['String', 'Float', 'Integer', 'Boolean', 'Matrix',
                         'CsvList']
# Columns shown as additional info (see getTableWithAdditionalInfo)
ADITIONAL_INFO_DISPLAY_COLUMN_LIST = ['_size', 'id']
# Column names never exposed in the viewer
EXCLUDED_COLUMNS = ['label', 'comment', 'creation', '_streamState']
# Columns always present in a Scipion set table
PERMANENT_COLUMNS = ['id', 'enabled']
# Subset-kind identifiers (NOTE(review): not referenced in this module)
CLASS_OBJECT = 1
REPRESENTATIVE_OBJECT = 2
CLASS_ELEMENTS = 3
# Name given to the synthesized 'index@filename' column
EXTENDED_COLUMN_NAME = 'stack'
ENABLED_COLUMN = 'enabled'
# Well-known table names inside a Scipion sqlite file
PROPERTIES_TABLE = 'Properties'
OBJECT_TABLE = 'objects'
class SqliteFile(IDAO):
    """ Class to manipulate Scipion Sqlite files. """

    def __init__(self, sqliteFile):
        """
        :param sqliteFile: path to the sqlite file (opened read-only)
        """
        self._names = []
        self._file = sqliteFile
        self._con = self.__loadDB(sqliteFile)
        # Guard: __loadDB returns None when the file cannot be opened
        if self._con is not None:
            self._con.row_factory = self._dictFactory
        self._tableCount = {}
        self._tables = {}
        self._labels = {}
        self._labelsTypes = {}
        self._aliases = {}
        self._columnsMap = {}
        # Filled by iterTable when reading a mapped ('classes') table
        self._excludedColumns = {}
        self._extendedColumn = None
        self._tableWithAdditionalInfo = None
        self._objectsType = {}

    def __loadDB(self, sqliteFile):
        """Load a sqlite file in read-only mode; return the connection or
        None if it cannot be opened."""
        try:
            return sqlite3.connect(f"file:{sqliteFile}?mode=ro", uri=True)
        except Exception as e:
            logger.error("The file could not be opened. Make sure the path is "
                         "correct: \n %s" % e)
            return None

    def hasExtendedColumn(self):
        """Return if the table needs an extended column. That column is used
        to render an image that is composed from two other columns
        (index + filename)."""
        return self._extendedColumn is not None

    def composeDataTables(self, tablesNames):
        """Generate a dictionary with the principal tables, mapping the
        dependencies with other tables."""
        tablesNames = sorted(tablesNames)
        for tableName in tablesNames:
            divTable = tableName.split('_')
            if len(divTable) > 1:
                if divTable[-1].startswith('Class') and tableName not in self._tables:
                    objectTable = tableName.replace(divTable[-1], '') + 'Objects'
                    self._tables[objectTable] = tableName
                    self._names.append(objectTable)
        self._tables[PROPERTIES_TABLE] = PROPERTIES_TABLE
        self._names.append(PROPERTIES_TABLE)

    def composeObjectType(self):
        """Define the different object types."""
        # General type defined in the Properties table
        firstRow = self.getTableRow(PROPERTIES_TABLE, 0)
        objectType = firstRow['value']
        self._objectsType[self._aliases[OBJECT_TABLE]] = objectType
        for alias in self._aliases.values():
            objectTypeAux = alias.split('_')
            if len(objectTypeAux) == 2:
                # Pluralize unless the name already ends in 's' (equivalent
                # to the original vowel-or-not-'s' test).
                # NOTE(review): the suffix is decided from objectTypeAux[0]
                # but applied to objectTypeAux[1] -- confirm this is intended.
                sufix = '' if objectTypeAux[0][-1] == 's' else 's'
                objectType = 'SetOf%s%s' % (objectTypeAux[1], sufix)
                if objectTypeAux[1] not in self._objectsType:
                    self._objectsType[objectTypeAux[1]] = objectType

    def composeTableAlias(self, tableName):
        """Create an alias for the given table (based on its class_name)."""
        if tableName != PROPERTIES_TABLE:
            firstRow = self.getTableRow(tableName, 0)
            className = firstRow['class_name']
            if '_' in tableName:
                tableSplit = tableName.split('_')
                if len(tableSplit) > 1:
                    alias = tableName.replace(tableSplit[-1], '') + className
            else:
                alias = className
        else:
            alias = PROPERTIES_TABLE
        return alias

    def getTableNames(self):
        """ Return all the table names found in the database. """
        if not self._names:
            self._tables = {OBJECT_TABLE: 'classes'}
            self._names = [OBJECT_TABLE]
            res = self._con.execute("SELECT name FROM sqlite_master WHERE type='table'")
            tablesNames = [row['name'] for row in res.fetchall()]
            self.composeDataTables(tablesNames)
            for tableName in self._names:
                # Getting the first row to identify the labels and their types
                firstRow = self.getTableRow(tableName, 0,
                                            classes=self._tables[tableName])
                self._labels[tableName] = [key for key in firstRow.keys()
                                           if key not in EXCLUDED_COLUMNS]
                alias = self.composeTableAlias(self._tables[tableName])
                self._aliases[tableName] = alias
                labelsTypes = []
                self._tableCount[tableName] = self.getRowsCount(tableName)
                for key, value in firstRow.items():
                    if key not in EXCLUDED_COLUMNS:
                        labelsTypes.append(_guessType(value))
                self._labelsTypes[tableName] = labelsTypes
            self.composeObjectType()
            if len(self._tables) > 2:  # Assuming that there are more tables than just Object and Properties
                self._tableWithAdditionalInfo = OBJECT_TABLE
        return self._names

    def findColbyName(self, colNames, colName):
        """Return a column index given a column name, or None if absent."""
        for i, col in enumerate(colNames):
            if colName == col:
                return i
        return None

    def updateExtendColumn(self, table):
        """Find the columns that need to be extended and keep their indexes."""
        tableName = table.getName()
        colNames = self._labels[tableName]
        indexCol = self.findColbyName(colNames, '_index')
        fileNameCol = self.findColbyName(colNames, '_filename')
        # 'is not None': the previous truthiness test silently ignored a
        # matching column located at index 0.
        if indexCol is not None and fileNameCol is not None:
            logger.debug("The columns _index and _filename have been found. "
                         "We will proceed to create a new column with the "
                         "values of these columns.")
            self._extendedColumn = indexCol, fileNameCol
        else:
            indexCol = self.findColbyName(colNames, '_representative._index')
            fileNameCol = self.findColbyName(colNames,
                                             '_representative._filename')
            if indexCol is not None and fileNameCol is not None:
                logger.debug("The columns _representative._index and "
                             "_representative._filename have been found. "
                             "We will proceed to create a new column with the "
                             "values of these columns.")
                self._extendedColumn = indexCol, fileNameCol

    def generateTableActions(self, table, objectManager):
        """Generate actions for a given table in order to create subsets."""
        if self.getScipionPort() and table.getName() != PROPERTIES_TABLE:
            alias = table.getAlias()
            labels = list(self._objectsType.keys())
            objectTypes = list(self._objectsType.values())
            aliasSplit = alias.split('_')
            if alias.startswith('Class') and len(aliasSplit) == 1:
                table.addAction(labels[0], lambda: self.createSubsetCallback(table, objectTypes[0], objectManager))
                table.addAction(labels[1], lambda: self.createSubsetCallback(table, objectTypes[1], objectManager))
                if alias == 'Class2D':
                    table.addAction('Averages', lambda: self.createSubsetCallback(table, 'SetOfAverages', objectManager))
                else:
                    table.addAction('Volumes', lambda: self.createSubsetCallback(table, 'SetOfVolumes', objectManager))
            elif alias.startswith('Class'):
                table.addAction(aliasSplit[1], lambda: self.createSubsetCallback(table, self._objectsType[aliasSplit[1]], objectManager))
            elif alias in self._objectsType and self._objectsType[alias].startswith('SetOf'):
                table.addAction(alias, lambda: self.createSubsetCallback(table, self._objectsType[alias], objectManager))

    def fillTable(self, table, objectManager):
        """Create the table structure (columns) and set the table alias."""
        tableName = table.getName()
        colNames = self._labels[tableName]
        self.updateExtendColumn(table)
        row = self.getTableRow(tableName, 0, classes=self._tables[tableName])
        if 'id' not in row:
            table.setHasColumnId(False)
        values = [value for key, value in row.items()
                  if key not in EXCLUDED_COLUMNS]
        for index, colName in enumerate(colNames):
            isFileNameCol = (self.hasExtendedColumn()
                             and index == self._extendedColumn[1])
            if colName == ENABLED_COLUMN:
                renderer = BoolRenderer()
            elif isFileNameCol:
                renderer = StrRenderer()
            else:
                renderer = table.guessRenderer(str(values[index]))
                if isinstance(renderer, ImageRenderer):
                    imageExt = str(values[index]).split('.')[-1]
                    self.addExternalProgram(renderer, imageExt)
            newCol = Column(colName, renderer)
            newCol.setIsSorteable(True)
            newCol.setIsVisible(objectManager.isLabelVisible(colName))
            table.addColumn(newCol)
            if isFileNameCol:
                logger.debug("Creating an extended column: %s" % EXTENDED_COLUMN_NAME)
                imageExt = str(values[index]).split('.')[-1]
                # Register the external program on the SAME renderer the
                # extra column uses (previously it was registered on a
                # discarded instance).
                extraRenderer = ImageRenderer()
                self.addExternalProgram(extraRenderer, imageExt)
                extraCol = Column(colName, extraRenderer)
                extraCol.setIsVisible(newCol.isVisible())
                extraCol.setIsSorteable(False)
                table.addColumn(extraCol)
                # Hide the raw filename column; the extended one replaces it
                newCol.setIsVisible(False)
                newCol.setName(EXTENDED_COLUMN_NAME)
        table.setAlias(self._aliases[tableName])
        self.generateTableActions(table, objectManager)

    def addExternalProgram(self, renderer: ImageRenderer, imageExt: str):
        """Attach every known external visualization program to *renderer*."""
        self.addChimera(renderer, imageExt)

    def addChimera(self, renderer: ImageRenderer, imageExt: str):
        """Attach a ChimeraX launcher to *renderer* if CHIMERA_HOME is set
        and the extension is not a plain stack."""
        chimeraPath = os.environ.get('CHIMERA_HOME', None)
        if chimeraPath is not None:
            if imageExt not in ['st', 'stk']:
                icon = pw.findResource('chimera.png')

                def openChimeraCallback(path):
                    program = os.path.join(chimeraPath, 'bin', 'ChimeraX')
                    # Argument list + shell=False: robust against spaces
                    # and shell metacharacters in the path.
                    Popen([program, path], cwd=os.getcwd())

                renderer.addProgram(ExternalProgram('ChX', icon, 'ChimeraX', openChimeraCallback))

    def fillPage(self, page, actualColumn=0, orderAsc=True):
        """
        Read the given table from the sqlite and fill the page (add rows).
        """
        tableName = page.getTable().getName()
        # moving to the first row of the page
        pageNumber = page.getPageNumber()
        pageSize = page.getPageSize()
        firstRow = pageNumber * pageSize - pageSize
        limit = pageSize
        column = self._labels[tableName][actualColumn]
        mode = 'ASC' if orderAsc else 'DESC'
        self.updateExtendColumn(page.getTable())
        for rowcount, row in enumerate(self.iterTable(tableName, start=firstRow,
                                                      limit=limit,
                                                      classes=self._tables[tableName],
                                                      orderBy=column, mode=mode)):
            if row:
                values = [value for key, value in row.items()
                          if key not in EXCLUDED_COLUMNS]
                # Tables without an 'id' column fall back to the row counter
                rowId = row['id'] if 'id' in row else rowcount
                # Checking if exists an extended column: insert the composed
                # 'index@filename' value right after the filename column
                if self.hasExtendedColumn() and tableName != PROPERTIES_TABLE:
                    values.insert(self._extendedColumn[1] + 1,
                                  str(values[self._extendedColumn[0]]) + '@' +
                                  values[self._extendedColumn[1]])
                page.addRow((int(rowId), values))

    def getRowsCount(self, tableName):
        """ Return the number of elements in the given table. """
        logger.debug("Reading the table %s" % tableName)
        # tableName comes from sqlite_master, not from user input
        return self._con.execute(f"SELECT COUNT(*) FROM {tableName}").fetchone()['COUNT(*)']

    def getTableRowCount(self, tableName):
        """Return the cached row count for *tableName*."""
        return self._tableCount[tableName]

    def getSelectedRangeRowsIds(self, tableName, startRow, numberOfRows, column, reverse=True):
        """Return the ids of a range of rows starting at 'startRow', for an
        amount of 'numberOfRows'."""
        logger.debug("Reading the table %s and selected a range of rows %d - %d" % (tableName, startRow, numberOfRows + 1))
        mode = 'ASC' if reverse else 'DESC'
        col = self._getColumnMap(tableName, column)
        if col is None:
            col = column
        query = "SELECT id FROM %s ORDER BY %s %s LIMIT %d , %d" % (tableName, col, mode, startRow - 1, numberOfRows + 1)
        rowsList = self._con.execute(query).fetchall()
        rowsIds = [row['id'] for row in rowsList]
        return rowsIds

    def getColumnsValues(self, tableName, columns, xAxis, selection, limit,
                         useSelection, reverse=True):
        """Get the values of the selected columns in order to plot them."""
        logger.debug("Reading the table %s and selected some columns values...",
                     tableName)
        cols = columns
        if xAxis and xAxis not in cols:
            cols.append(xAxis)
        columnNames = []
        for column in cols:
            col = self._getColumnMap(tableName, column) or column
            columnNames.append(col)
        columnNames = ", ".join(columnNames)
        col = self._getColumnMap(tableName, xAxis)
        if col is not None:
            xAxis = col
        mode = 'ASC' if reverse else 'DESC'
        orderBy = ' ORDER BY %s %s' % (xAxis, mode) if xAxis else ''
        limit = ' LIMIT %d' % limit if limit is not None else ''
        where = f" WHERE id in ({', '.join(map(str, selection.getSelection().keys()))})" if selection.getCount() > 1 and useSelection else ''
        query = "SELECT %s FROM %s %s %s %s" % (columnNames, tableName, where,
                                                orderBy, limit)
        selectedColumns = self._con.execute(query).fetchall()
        columnsValues = {}
        firstValue = selectedColumns[0]
        for colName in columns:
            col = self._getColumnMap(tableName, colName) or colName
            columnsValues[colName] = [firstValue[col]]
        # NOTE(review): subsequent values are cast to int while the first is
        # kept raw -- preserved as-is; confirm whether the cast is intended.
        for pos, value in enumerate(selectedColumns):
            if pos > 0:
                for colName in columns:
                    col = self._getColumnMap(tableName, colName) or colName
                    columnsValues[colName].append(int(value[col]))
        return columnsValues

    def iterTable(self, tableName, **kwargs):
        """
        Method to iterate over the table's rows
        :param tableName: the name of the table
        :param kwargs:
                limit: integer value to limit the number of elements
                start: start from a given element
                classes: read column names from a 'classes' table
                orderBy: clause to sort given a column name
                mode: sort direction ASC or DESC
        """
        query = f"SELECT * FROM {tableName}"
        if 'mode' in kwargs:
            if 'orderBy' in kwargs:
                if kwargs['orderBy']:
                    column = self._getColumnMap(tableName, kwargs['orderBy'])
                    if not column:
                        column = kwargs['orderBy']
                    query += f" ORDER BY {column}"
            if kwargs['mode']:
                query += f" {kwargs['mode']}"
        if 'start' in kwargs and 'limit' not in kwargs:
            # sqlite requires a LIMIT clause before OFFSET; -1 means "all"
            kwargs['limit'] = -1
        if 'limit' in kwargs:
            query += f" LIMIT {kwargs['limit']}"
        if 'start' in kwargs:
            query += f" OFFSET {kwargs['start']}"
        if 'classes' not in kwargs or kwargs['classes'] == PROPERTIES_TABLE:
            res = self._con.execute(query)
            while row := res.fetchone():
                yield row
        else:  # Mapping the column names and including only the allowed columns
            self._columnsMap[tableName] = {row['column_name']: row['label_property']
                                           for row in self.iterTable(kwargs['classes'])
                                           if row['class_name'] in ALLOWED_COLUMNS_TYPES}
            self._excludedColumns = {row['column_name']: row['label_property']
                                     for row in self.iterTable(kwargs['classes'])
                                     if row['class_name'] not in ALLOWED_COLUMNS_TYPES}

            def _row_factory(cursor, row):
                fields = [column[0] for column in cursor.description]
                return {self._columnsMap[tableName].get(k, k): v
                        for k, v in zip(fields, row)
                        if k not in self._excludedColumns}

            # Modify row factory to rename column names
            self._con.row_factory = _row_factory
            try:
                res = self._con.execute(query)
                while row := res.fetchone():
                    yield row
            finally:
                # Restore the row factory even when the consumer stops
                # iterating early (the previous code only restored it after
                # a full iteration).
                self._con.row_factory = self._dictFactory

    def getTableAliases(self):
        """Return the tables aliases"""
        return self._aliases

    def _getColumnMap(self, tableName, column):
        """Get the original column name for a mapped label, or None."""
        if tableName in self._columnsMap:
            for key, value in self._columnsMap[tableName].items():
                if value == column:
                    return key
        return None

    def getTableRow(self, tableName, rowIndex, **kwargs):
        """ Get a given row by index. Extra args are passed to iterTable. """
        kwargs['start'] = rowIndex
        kwargs['limit'] = 1
        for row in self.iterTable(tableName, **kwargs):
            return row

    def getTableWithAdditionalInfo(self):
        """Return a tuple with the table that needs to show additional info
        and the columns that we need to show."""
        return self._tableWithAdditionalInfo, ADITIONAL_INFO_DISPLAY_COLUMN_LIST

    def createSubsetCallback(self, table: Table, objectType: str, objectManager):
        """Create a subset from the current selection (or the whole table
        when nothing is selected) and notify the Scipion GUI."""
        selection = table.getSelection().getSelection()
        tableName = table.getName()
        elementsCount = len(selection)
        if not elementsCount:
            elementsCount = self._tableCount[tableName]
        subsetName = objectManager.getGui().getSubsetName(objectType, elementsCount)
        if subsetName:
            # 'timeFmt' instead of 'format' to avoid shadowing the builtin
            timeFmt = '%Y%m%d%H%M%S'
            timestamp = datetime.now().strftime(timeFmt)
            path = 'Logs/selection_%s.txt' % timestamp
            self.writeSelection(table, path)
            path += ","  # Always add a comma, it is expected by the user subset protocol
            if tableName != OBJECT_TABLE:
                path += tableName.split(OBJECT_TABLE)[0]
            self.sendSubsetCreationMessage(path, objectType, subsetName)

    def sendSubsetCreationMessage(self, selectionFile, outputClassName, label):
        """Send the subset-creation command to the Scipion GUI over a local
        TCP socket."""
        import socket
        # Context manager guarantees the socket is closed (it leaked before)
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as clientSocket:
            # Connect to the local server where Scipion GUI is listening
            clientSocket.connect(("127.0.0.1", int(self.getScipionPort())))
            # We should create this message:
            # run protocol ProtUserSubSet inputObject=380 sqliteFile='...','' outputClassName=SetOfTiltSeries other='' label='create subset'
            data = f"run protocol ProtUserSubSet inputObject={self.getScipionObjectId()} " \
                   f"sqliteFile='{selectionFile}' outputClassName='{outputClassName}' other='' label='{label}'"
            # sendall: plain send() may transmit only part of the message
            clientSocket.sendall(data.encode())

    def getScipionPort(self):
        """ Return the Scipion port from the environment.
        NOTE(review): falls back to '1300' when unset (looks like a debug
        default) -- confirm whether None would be the correct fallback."""
        return os.getenv(SCIPION_PORT, '1300')

    def getScipionObjectId(self):
        """ Returns Scipion object id"""
        return os.getenv(SCIPION_OBJECT_ID)

    def writeSelection(self, table: Table, path):
        """ Create a file with the selected rows ids"""
        tableName = table.getName()
        rowsIds = table.getSelection().getSelection().keys()
        if not rowsIds:
            # Empty selection means "everything": ids are assumed 1..count
            rowsIds = [i + 1 for i in range(self._tableCount[tableName])]
        try:
            with open(path, 'w') as file:
                for rowId in rowsIds:
                    file.write(str(rowId) + ' ')
            logger.debug(f"The file: {path} was created correctly.")
        except Exception as e:
            logger.error(f"Error creating the file: {e}")

    def close(self):
        """Close the database connection (idempotent)."""
        if getattr(self, '_con', None):
            self._con.close()
            self._con = None

    def _dictFactory(self, cursor, row):
        """sqlite row factory: map each row to a {column: value} dict."""
        fields = [column[0] for column in cursor.description]
        return {key: value for key, value in zip(fields, row)}

    @classmethod
    def getCompatibleFileTypes(cls):
        """Return a list of compatible extension of files"""
        logger.debug("Selected SqliteFile DAO")
        return ['sqlite']

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def __del__(self):
        self.close()
# --------- Helper functions ------------------------
def _guessType(strValue):
if strValue is None:
return str('None')
try:
int(strValue)
return int
except ValueError:
try:
float(strValue)
return float
except ValueError:
return str
def extendMDViewer(om: metadataviewer.model.ObjectManager):
    """Plug the Scipion sqlite DAO and the MRC/STK image readers into the
    given object manager."""
    om.registerDAO(SqliteFile)
    for imageReader in (MRCImageReader, STKImageReader):
        om.registerReader(imageReader)