from __future__ import annotations
from collections.abc import Iterator
from pathlib import Path
from typing import Any, Dict, Optional, Union, cast
import pandas as pd
from box import Box
from sqlalchemy import (
Column,
Float,
Integer,
MetaData,
String,
Table,
create_engine,
func,
or_,
)
from sqlalchemy.sql import select
from sqlalchemy.sql.schema import PrimaryKeyConstraint, UniqueConstraint
from dae.pheno.common import MeasureType
from dae.variants.attributes import Status
class PhenoDb:  # pylint: disable=too-many-instance-attributes
    """Class that manages access to phenotype databases.

    Wraps a DuckDB database file behind an SQLAlchemy engine and keeps the
    SQLAlchemy ``Table`` definitions for the family, person, instrument,
    measure, per-instrument values and phenotype browser tables in a
    single ``MetaData`` object (``pheno_metadata``).
    """

    # Number of rows fetched per ``fetchmany`` call in ``search_measures``.
    STREAMING_CHUNK_SIZE = 25

    def __init__(
        self, dbfile: str, read_only: bool = True,
    ) -> None:
        """Create the engine for *dbfile*; no tables are declared yet.

        Args:
            dbfile: path of the DuckDB database file.
            read_only: forwarded to the driver as the ``read_only``
                connect argument.
        """
        self.dbfile = dbfile
        self.engine = create_engine(
            f"duckdb:///{dbfile}", connect_args={"read_only": read_only},
        )
        self.pheno_metadata = MetaData()
        # The table attributes below are only declared, not assigned: the
        # ``build*`` methods assign them, and some methods rely on
        # ``getattr(self, name, None)`` to detect a not-yet-declared table.
        self.variable_browser: Table
        self.regressions: Table
        self.regression_values: Table
        self.family: Table
        self.person: Table
        self.measure: Table
        self.instrument: Table
        self.instrument_values_tables: dict[str, Table] = {}

    @staticmethod
    def verify_pheno_folder(folder: Path) -> None:
        """Verify integrity of a pheno db folder.

        Checks that ``<folder>/parquet`` contains ``family.parquet``,
        ``person.parquet``, ``instrument.parquet``, ``measure.parquet``
        and a non-empty ``instruments`` directory.

        Raises:
            AssertionError: if any of the expected paths is missing.
        """
        parquet_folder = folder / "parquet"
        assert parquet_folder.exists(), f"missing {parquet_folder}"
        for table_name in ("family", "person", "instrument", "measure"):
            table_file = parquet_folder / f"{table_name}.parquet"
            assert table_file.exists(), f"missing {table_file}"
        instruments_dir = parquet_folder / "instruments"
        assert instruments_dir.exists(), f"missing {instruments_dir}"
        assert instruments_dir.is_dir(), f"{instruments_dir} is not a dir"
        assert next(instruments_dir.glob("*"), None) is not None, \
            f"{instruments_dir} is empty"

    def build_browser(self) -> None:
        """Declare the phenotype browser and regression tables."""
        self._build_browser_tables()

    def build(self, create: bool = False) -> None:
        """Construct all needed table connections.

        Args:
            create: when true, also emit ``CREATE TABLE`` for the declared
                tables via ``MetaData.create_all``.
        """
        self._build_person_tables()
        self.build_instruments_and_measures_table()
        if create:
            # The instrument and measure tables are queried by
            # build_instrument_values_tables(), so they must exist first.
            self.pheno_metadata.create_all(self.engine)
        self.build_instrument_values_tables()
        self.build_browser()
        if create:
            # Second pass creates the tables declared after the first one.
            self.pheno_metadata.create_all(self.engine)

    def create_all_tables(self) -> None:
        """Create every table declared in ``pheno_metadata``."""
        self.pheno_metadata.create_all(self.engine)

    def build_instruments_and_measures_table(self) -> None:
        """Create tables for instruments and measures.

        Each table is declared only once; repeated calls are no-ops for a
        table that is already set on the instance.
        """
        # The attribute is ``instrument`` (singular); checking the wrong
        # name would re-declare the table on every call.
        if getattr(self, "instrument", None) is None:
            self.instrument = Table(
                "instrument",
                self.pheno_metadata,
                Column(
                    "instrument_name", String(64), nullable=False, index=True,
                ),
                Column("table_name", String(64), nullable=False),
            )

        if getattr(self, "measure", None) is None:
            self.measure = Table(
                "measure",
                self.pheno_metadata,
                Column(
                    "measure_id",
                    String(128),
                    nullable=False,
                    index=True,
                    unique=True,
                ),
                Column(
                    "db_column_name",
                    String(128),
                    nullable=False,
                ),
                Column("measure_name", String(64), nullable=False, index=True),
                Column("instrument_name", String(64), nullable=False),
                Column("description", String(255)),
                Column("measure_type", Integer(), index=True),
                Column("individuals", Integer()),
                Column("default_filter", String(255)),
                Column("min_value", Float(), nullable=True),
                Column("max_value", Float(), nullable=True),
                Column("values_domain", String(255), nullable=True),
                Column("rank", Integer(), nullable=True),
            )

    def build_instrument_values_tables(self) -> None:
        """
        Create instrument values tables.

        Each row is basically a list of every measure value in the instrument
        for a certain person.

        The instrument and measure tables are read from the database to
        find the table name of each instrument and the name and type of
        each measure column.
        """
        instruments_query = select(
            self.instrument.c.instrument_name,
            self.instrument.c.table_name,
        )
        instrument_table_names: dict[str, str] = {}
        instrument_measures: dict[str, list[str]] = {}
        with self.engine.connect() as connection:
            for row in connection.execute(instruments_query):
                instrument_table_names[row.instrument_name] = row.table_name
                instrument_measures[row.instrument_name] = []

        measures_query = select(
            self.measure.c.measure_id,
            self.measure.c.measure_type,
            self.measure.c.db_column_name,
            self.instrument.c.instrument_name,
        ).join(
            self.instrument,
            self.measure.c.instrument_name
            == self.instrument.c.instrument_name,
        )
        measure_columns: dict[str, Column] = {}
        with self.engine.connect() as connection:
            for row in connection.execute(measures_query):
                instrument_measures[row.instrument_name].append(
                    row.measure_id,
                )
                column_type: Union[Float, String]
                if MeasureType.is_numeric(row.measure_type):
                    column_type = Float()
                else:
                    column_type = String(127)
                measure_columns[row.measure_id] = Column(
                    f"{row.db_column_name}", column_type, nullable=True,
                )

        for instrument_name, table_name in instrument_table_names.items():
            if instrument_name in self.instrument_values_tables:
                continue
            cols = [
                measure_columns[m_id]
                for m_id in instrument_measures[instrument_name]
            ]
            self.instrument_values_tables[instrument_name] = Table(
                table_name,
                self.pheno_metadata,
                Column(
                    "person_id",
                    String(16),
                    nullable=False,
                    index=True,
                    unique=True,
                    primary_key=True,
                ),
                Column(
                    "family_id", String(64), nullable=False, index=True,
                ),
                Column("role", String(64), nullable=False, index=True),
                Column(
                    "status",
                    Integer(),
                    nullable=False,
                    default=Status.unaffected,
                ),
                Column("sex", Integer(), nullable=False),
                *cols,
                extend_existing=True,
            )

    def _split_measures_into_groups(
        self, measure_ids: list[str], group_size: int = 60,
    ) -> list[list[str]]:
        """Split *measure_ids* into consecutive chunks of *group_size*.

        A list shorter than *group_size* (including an empty one) is
        returned as the single group, without copying.
        """
        if len(measure_ids) < group_size:
            return [measure_ids]
        return [
            measure_ids[begin:begin + group_size]
            for begin in range(0, len(measure_ids), group_size)
        ]

    def clear_instruments_table(self, drop: bool = False) -> None:
        """Clear the instruments table.

        Does nothing when the instrument table has not been declared yet.

        Args:
            drop: also drop the table after deleting its rows.
        """
        # Check the real attribute name; ``instruments`` never exists, so
        # testing it turned this method into a permanent no-op.
        if getattr(self, "instrument", None) is None:
            return
        # engine.begin() commits on successful exit of the block.
        with self.engine.begin() as connection:
            connection.execute(self.instrument.delete())
            if drop:
                self.instrument.drop(connection, checkfirst=False)

    def clear_measures_table(self, drop: bool = False) -> None:
        """Clear the measures table.

        Does nothing when the measure table has not been declared yet.

        Args:
            drop: also drop the table after deleting its rows.
        """
        # Check the real attribute name (``measure``, not ``measures``).
        if getattr(self, "measure", None) is None:
            return
        with self.engine.begin() as connection:
            connection.execute(self.measure.delete())
            if drop:
                self.measure.drop(connection, checkfirst=False)

    def clear_instrument_values_tables(self, drop: bool = False) -> None:
        """Clear all instrument values tables.

        Args:
            drop: also drop each table after deleting its rows.
        """
        if not getattr(self, "instrument_values_tables", None):
            return
        with self.engine.begin() as connection:
            for instrument_table in self.instrument_values_tables.values():
                connection.execute(instrument_table.delete())
                if drop:
                    instrument_table.drop(connection, checkfirst=False)

    def get_instrument_column_names(self) -> dict[str, list[str]]:
        """Return a map of instruments and their measure column names."""
        # No ForeignKey links measure to instrument, so the join condition
        # has to be spelled out (same condition as in
        # build_instrument_values_tables).
        query = select(
            self.measure.c.db_column_name,
            self.instrument.c.instrument_name,
        ).join(
            self.instrument,
            self.measure.c.instrument_name
            == self.instrument.c.instrument_name,
        )
        instrument_col_names: dict[str, list[str]] = {}
        with self.engine.connect() as connection:
            for row in connection.execute(query):
                instrument_col_names.setdefault(
                    row.instrument_name, [],
                ).append(row.db_column_name)
        return instrument_col_names

    def _fetch_measure_column_pairs(
        self, measure_ids: Optional[list[str]] = None,
    ) -> list[tuple[str, str]]:
        """Return ``(measure_id, db_column_name)`` pairs from ``measure``.

        When *measure_ids* is not None, only those measures are returned.
        """
        query = select(
            self.measure.c.measure_id,
            self.measure.c.db_column_name,
        )
        if measure_ids is not None:
            query = query.where(self.measure.c.measure_id.in_(measure_ids))
        with self.engine.connect() as connection:
            return [
                (row.measure_id, row.db_column_name)
                for row in connection.execute(query)
            ]

    def get_measure_column_names(
        self, measure_ids: Optional[list[str]] = None,
    ) -> dict[str, str]:
        """Return measure IDs mapped to their database column names."""
        return dict(self._fetch_measure_column_pairs(measure_ids))

    def get_measure_column_names_reverse(
        self, measure_ids: Optional[list[str]] = None,
    ) -> dict[str, str]:
        """Return database column names mapped to their measure IDs."""
        return {
            column_name: measure_id
            for measure_id, column_name
            in self._fetch_measure_column_pairs(measure_ids)
        }

    def _build_browser_tables(self) -> None:
        """Declare the variable browser and regression tables."""
        self.variable_browser = Table(
            "variable_browser",
            self.pheno_metadata,
            Column(
                "measure_id",
                String(128),
                nullable=False,
                index=True,
                unique=True,
                primary_key=True,
            ),
            Column("instrument_name", String(64), nullable=False, index=True),
            Column("measure_name", String(64), nullable=False, index=True),
            Column("measure_type", Integer(), nullable=False),
            Column("description", String(256)),
            Column("values_domain", String(256)),
            Column("figure_distribution_small", String(256)),
            Column("figure_distribution", String(256)),
        )

        self.regressions = Table(
            "regression",
            self.pheno_metadata,
            Column(
                "regression_id",
                String(128),
                nullable=False,
                index=True,
                primary_key=True,
            ),
            Column("instrument_name", String(128)),
            Column("measure_name", String(128), nullable=False),
            Column("display_name", String(256)),
        )

        self.regression_values = Table(
            "regression_values",
            self.pheno_metadata,
            Column("regression_id", String(128), nullable=False, index=True),
            Column("measure_id", String(128), nullable=False, index=True),
            Column("figure_regression", String(256)),
            Column("figure_regression_small", String(256)),
            Column("pvalue_regression_male", Float()),
            Column("pvalue_regression_female", Float()),
            PrimaryKeyConstraint(
                "regression_id", "measure_id", name="regression_pkey",
            ),
        )

    def _build_person_tables(self) -> None:
        """Declare the family and person tables."""
        self.family = Table(
            "family",
            self.pheno_metadata,
            Column(
                "family_id",
                String(64),
                nullable=False,
                unique=True,
                index=True,
            ),
        )
        self.person = Table(
            "person",
            self.pheno_metadata,
            Column("family_id", String(64), nullable=False),
            Column("person_id", String(16), nullable=False, index=True),
            Column("role", Integer(), nullable=False),
            Column(
                "status",
                Integer(),
                nullable=False,
                default=Status.unaffected,
            ),
            Column("sex", Integer(), nullable=False),
            Column("sample_id", String(16), nullable=True),
            UniqueConstraint("family_id", "person_id", name="person_key"),
        )

    def save(self, v: Dict[str, Optional[str]]) -> None:
        """Insert a row into ``variable_browser``, replacing on failure.

        If the plain insert fails, the row with the same ``measure_id``
        is deleted and the insert is retried.

        Args:
            v: column name to value mapping; must contain ``measure_id``.
        """
        # Built outside the ``try`` so it is always bound in the fallback.
        insert = self.variable_browser.insert().values(**v)
        try:
            with self.engine.begin() as connection:
                connection.execute(insert)
        except Exception:  # pylint: disable=broad-except
            delete = (
                self.variable_browser.delete()
                .where(self.variable_browser.c.measure_id == v["measure_id"])
            )
            # Delete and re-insert in one transaction so a failed retry
            # does not leave the old row removed.
            with self.engine.begin() as connection:
                connection.execute(delete)
                connection.execute(insert)

    def save_regression(self, reg: Dict[str, str]) -> None:
        """Save regressions into the database.

        Tries to insert *reg*; if that fails, updates the row with the
        same ``regression_id`` instead. *reg* itself is not modified.
        """
        try:
            insert = self.regressions.insert().values(reg)
            with self.engine.begin() as connection:
                connection.execute(insert)
        except Exception:  # pylint: disable=broad-except
            # Work on a copy: the caller's dict must stay intact.
            values = {
                key: value for key, value in reg.items()
                if key != "regression_id"
            }
            update = (
                self.regressions.update()
                .values(values)
                .where(
                    self.regressions.c.regression_id == reg["regression_id"],
                )
            )
            with self.engine.begin() as connection:
                connection.execute(update)

    def save_regression_values(self, reg: Dict[str, str]) -> None:
        """Save regression values into the databases.

        Tries to insert *reg*; if that fails, updates the row with the
        same ``regression_id`` and ``measure_id`` instead. *reg* itself
        is not modified.
        """
        try:
            insert = self.regression_values.insert().values(reg)
            with self.engine.begin() as connection:
                connection.execute(insert)
        except Exception:  # pylint: disable=broad-except
            key_columns = ("regression_id", "measure_id")
            # Work on a copy: the caller's dict must stay intact.
            values = {
                key: value for key, value in reg.items()
                if key not in key_columns
            }
            table = self.regression_values
            update = (
                table.update()
                .values(values)
                .where(
                    (table.c.regression_id == reg["regression_id"])
                    & (table.c.measure_id == reg["measure_id"]),
                )
            )
            with self.engine.begin() as connection:
                connection.execute(update)

    def get_browser_measure(self, measure_id: str) -> Optional[dict]:
        """Get measure description from phenotype browser database.

        Returns:
            A ``Box`` built from the matching ``variable_browser`` row,
            or None when no row has this *measure_id*.
        """
        sel = select(self.variable_browser).where(
            self.variable_browser.c.measure_id == measure_id,
        )
        with self.engine.connect() as connection:
            row = connection.execute(sel).first()
        if row is None:
            return None
        return Box(cast(dict, row._asdict()))

    @staticmethod
    def _like_pattern(keyword: str) -> str:
        """Turn *keyword* into a "contains" LIKE pattern escaped with ``/``.

        The escape character itself is doubled first; otherwise a ``/`` in
        the keyword would be taken as escaping the character after it.
        """
        escaped = (
            keyword
            .replace("/", "//")
            .replace("%", r"/%")
            .replace("_", r"/_")
        )
        return f"%{escaped}%"

    def _build_search_query(
        self, instrument_name: Optional[str], keyword: Optional[str],
        *, case_insensitive: bool,
    ) -> Any:
        """Build the ``variable_browser`` select used by the searches.

        With a *keyword*, rows match when the measure ID, measure name or
        description contains it; the instrument name is searched too when
        no *instrument_name* is given. With an *instrument_name*, rows
        are additionally restricted to that instrument.
        """
        browser = self.variable_browser
        query = browser.select()
        if keyword:
            pattern = self._like_pattern(keyword)
            searched = [
                browser.c.measure_id,
                browser.c.measure_name,
                browser.c.description,
            ]
            if not instrument_name:
                searched.insert(0, browser.c.instrument_name)
            clauses = []
            for column in searched:
                if case_insensitive:
                    clauses.append(column.ilike(pattern, escape="/"))
                else:
                    clauses.append(column.like(pattern, escape="/"))
            query = query.where(or_(*clauses))
        if instrument_name:
            query = query.where(browser.c.instrument_name == instrument_name)
        return query

    def search_measures(
        self, instrument_name: Optional[str] = None,
        keyword: Optional[str] = None,
    ) -> Iterator[dict[str, Any]]:
        """Find measures by keyword search.

        The keyword match is case-insensitive (``ILIKE``). Rows are
        fetched ``STREAMING_CHUNK_SIZE`` at a time and the connection
        stays open while the generator is being consumed.

        Yields:
            One dict per matching ``variable_browser`` row.
        """
        query = self._build_search_query(
            instrument_name, keyword, case_insensitive=True,
        )
        with self.engine.connect() as connection:
            cursor = connection.execution_options(stream_results=True)\
                .execute(query)
            while True:
                rows = cursor.fetchmany(self.STREAMING_CHUNK_SIZE)
                if not rows:
                    break
                for row in rows:
                    # Positions follow the column order declared in
                    # _build_browser_tables().
                    yield {
                        "measure_id": row[0],
                        "instrument_name": row[1],
                        "measure_name": row[2],
                        "measure_type": MeasureType(row[3]),
                        "description": row[4],
                        "values_domain": row[5],
                        "figure_distribution_small": row[6],
                        "figure_distribution": row[7],
                    }

    def search_measures_df(
        self, instrument_name: Optional[str] = None,
        keyword: Optional[str] = None,
    ) -> pd.DataFrame:
        """Find measures and return a dataframe with values."""
        # NOTE(review): this variant uses LIKE while search_measures()
        # uses ILIKE; the difference is preserved - confirm it is intended.
        query = self._build_search_query(
            instrument_name, keyword, case_insensitive=False,
        )
        return pd.read_sql(query, self.engine)

    def get_regression(self, regression_id: str) -> Any:
        """Return the regression row mapping for *regression_id*, or None."""
        selector = select(self.regressions).where(
            self.regressions.c.regression_id == regression_id,
        )
        with self.engine.connect() as connection:
            row = connection.execute(selector).first()
        if row is None:
            return None
        return row._mapping  # pylint: disable=protected-access

    def get_regression_values(self, measure_id: str) -> list[Box]:
        """Return all ``regression_values`` rows for *measure_id*."""
        selector = select(self.regression_values).where(
            self.regression_values.c.measure_id == measure_id,
        )
        with self.engine.connect() as connection:
            return [
                Box(row._asdict())
                for row in connection.execute(selector).fetchall()
            ]

    @property
    def regression_ids(self) -> list[str]:
        """Return the IDs of all stored regressions."""
        selector = select(self.regressions.c.regression_id)
        with self.engine.connect() as connection:
            return [row[0] for row in connection.execute(selector)]

    @property
    def regression_display_names(self) -> Dict[str, str]:
        """Return regressions display name."""
        selector = select(
            self.regressions.c.regression_id, self.regressions.c.display_name,
        )
        with self.engine.connect() as connection:
            return {row[0]: row[1] for row in connection.execute(selector)}

    @property
    def regression_display_names_with_ids(self) -> dict[str, Any]:
        """Return regression display names with measure IDs."""
        selector = select(
            self.regressions.c.regression_id,
            self.regressions.c.display_name,
            self.regressions.c.instrument_name,
            self.regressions.c.measure_name,
        )
        with self.engine.connect() as connection:
            return {
                row[0]: {
                    "display_name": row[1],
                    "instrument_name": row[2],
                    "measure_name": row[3],
                }
                for row in connection.execute(selector)
            }

    @property
    def has_descriptions(self) -> bool:
        """Check if the database has a description data."""
        query = (
            select(func.count())  # pylint: disable=not-callable
            .select_from(self.variable_browser)
            # Use the table's own column rather than an unbound Column.
            .where(self.variable_browser.c.description.isnot(None))
        )
        with self.engine.connect() as connection:
            return bool(connection.execute(query).scalar())

    def get_families(self) -> dict:
        """Return families in the phenotype database."""
        selector = select(self.family)
        with self.engine.connect() as connection:
            families = connection.execute(selector).fetchall()
        return {family.family_id: family for family in families}

    def get_persons(self) -> dict:
        """Return individuals in the phenotype database."""
        selector = select(
            self.person.c.person_id,
            self.person.c.family_id,
            self.person.c.role,
            self.person.c.status,
            self.person.c.sex,
        )
        with self.engine.connect() as connection:
            persons = connection.execute(selector).fetchall()
        return {person.person_id: person for person in persons}

    def get_measures(self) -> dict:
        """Return measures in the phenotype database."""
        selector = select(
            self.measure.c.measure_id,
            self.measure.c.instrument_name,
            self.measure.c.measure_name,
            self.measure.c.measure_type,
        ).select_from(self.measure)
        # Read-only query: a plain connection is enough, as in the other
        # getters.
        with self.engine.connect() as connection:
            measures = connection.execute(selector).fetchall()
        return {measure.measure_id: measure for measure in measures}
def safe_db_name(name: str) -> str:
    """Normalize *name* into a lowercase, underscore-separated identifier.

    Dots, dashes, spaces and slashes are replaced with underscores and the
    result is lowercased. A result starting with a digit is prefixed with
    an underscore. An empty *name* is returned unchanged.
    """
    name = name.replace(".", "_").replace("-", "_").replace(" ", "_").lower()
    name = name.replace("/", "_")
    # Slice instead of indexing so an empty string does not raise.
    if name[:1].isdigit():
        name = f"_{name}"
    return name
def generate_instrument_table_name(instrument_name: str) -> str:
    """Return the values table name derived from *instrument_name*.

    The name is the ``safe_db_name`` form of the instrument name followed
    by the ``_measure_values`` suffix.
    """
    return f"{safe_db_name(instrument_name)}_measure_values"