Source code for dae.query_variants.sql.schema2.base_query_builder

import logging
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Any, Optional, Union, cast

import dae.utils.regions
from dae.genomic_resources.gene_models import GeneModels
from dae.pedigrees.families_data import FamiliesData
from dae.query_variants.attributes_query import (
    LeafNode,
    QueryTransformerMatcher,
    QueryTreeToSQLBitwiseTransformer,
    TreeNode,
    inheritance_query,
    role_query,
    sex_query,
    variant_type_query,
)
from dae.query_variants.attributes_query_inheritance import (
    InheritanceTransformer,
    inheritance_parser,
)
from dae.utils.regions import Region
from dae.variants.attributes import Inheritance

logger = logging.getLogger(__name__)
RealAttrFilterType = list[tuple[str, tuple[Optional[float], Optional[float]]]]


[docs]class Dialect(ABC): """Caries info about a SQL dialect.""" def __init__(self, namespace: Optional[str] = None): # namespace, self.namespace = namespace
[docs] @staticmethod def use_bit_and_function() -> bool: return True
[docs] @staticmethod def add_unnest_in_join() -> bool: return False
[docs] @staticmethod def float_type() -> str: return "float"
[docs] @staticmethod def array_item_suffix() -> str: return ".item"
[docs] @staticmethod def int_type() -> str: return "int"
[docs] @staticmethod def escape_char() -> str: return "`"
[docs] @staticmethod def escape_quote_char() -> str: return "\\"
[docs] def build_table_name(self, table: str, db: Optional[str]) -> str: return f"`{self.namespace}`.{db}.{table}" if self.namespace else \ f"{db}.{table}"
[docs] def build_array_join(self, column: str, allias: str) -> str: return f"\n JOIN\n {column} AS {allias}"
# A type describing a schema as expected by the query builders TableSchema = dict[str, str] # family_variant_table & summary_allele_table are mandatory # - no reliance on a variants table as in impala
[docs]class BaseQueryBuilder(ABC): """Class that abstracts away the process of building a query.""" # pylint: disable=too-many-instance-attributes QUOTE = "'" WHERE = """ WHERE {where} """ GENE_REGIONS_HEURISTIC_CUTOFF = 20 GENE_REGIONS_HEURISTIC_EXTEND = 20000 MAX_CHILD_NUMBER = 9999 def __init__( self, dialect: Dialect, db: Optional[str], family_variant_table: Optional[str], summary_allele_table: str, pedigree_table: str, family_variant_schema: Optional[TableSchema], summary_allele_schema: TableSchema, partition_descriptor: Optional[dict], pedigree_schema: TableSchema, families: FamiliesData, gene_models: Optional[GeneModels] = None, ): # pylint: disable=too-many-arguments assert summary_allele_table is not None self.dialect = dialect self.db = db self.family_variant_table = family_variant_table self.summary_allele_table = summary_allele_table self.pedigree_table = pedigree_table self.partition_descriptor = partition_descriptor if not family_variant_schema: family_variant_schema = {} self.family_columns = family_variant_schema.keys() self.summary_columns = summary_allele_schema.keys() self.combined_columns = { **family_variant_schema, **summary_allele_schema, } self.pedigree_columns = pedigree_schema self.families = families self.has_extra_attributes = "extra_attributes" in self.combined_columns self._product = "" self.gene_models = gene_models self.query_columns = self._query_columns() self.where_accessors = self._where_accessors() def _where_accessors(self) -> dict[str, str]: cols = list(self.family_columns) + list(self.summary_columns) accessors = dict(zip(cols, cols)) family_keys = set(self.family_columns) summary_keys = set(self.summary_columns) for key, value in accessors.items(): if value in summary_keys: accessors[key] = f"sa.{value}" elif value in family_keys: accessors[key] = f"fa.{value}" return accessors
[docs] def build_query( self, regions: Optional[list[Region]] = None, genes: Optional[list[str]] = None, effect_types: Optional[list[str]] = None, family_ids: Optional[Iterable[str]] = None, person_ids: Optional[Iterable[str]] = None, inheritance: Optional[Union[str, list[str]]] = None, roles: Optional[str] = None, sexes: Optional[str] = None, variant_type: Optional[str] = None, real_attr_filter: Optional[RealAttrFilterType] = None, ultra_rare: Optional[bool] = None, frequency_filter: Optional[RealAttrFilterType] = None, return_reference: Optional[bool] = None, return_unknown: Optional[bool] = None, limit: Optional[int] = None, pedigree_fields: Optional[tuple] = None, ) -> str: # pylint: disable=too-many-arguments,too-many-locals,unused-argument """Build an SQL query in the correct order.""" self._product = "" self._build_select() self._build_from() self._build_join(genes=genes, effect_types=effect_types) self._build_where( regions=regions, genes=genes, effect_types=effect_types, family_ids=family_ids, person_ids=person_ids, inheritance=inheritance, roles=roles, sexes=sexes, variant_type=variant_type, real_attr_filter=real_attr_filter, ultra_rare=ultra_rare, frequency_filter=frequency_filter, return_reference=return_reference, return_unknown=return_unknown, pedigree_fields=pedigree_fields, ) self._build_group_by() self._build_having() self._build_limit(limit) return self._product
def _build_select(self) -> None: columns = ", ".join(self.query_columns) select_clause = f"SELECT {columns}" self._add_to_product(select_clause) def _build_from(self) -> None: pass def _build_join( self, genes: Optional[list[str]], effect_types: Optional[list[str]], ) -> None: pass def _build_where_pedigree_fields( self, pedigree_fields: Optional[tuple], ) -> str: # pylint: disable=unused-argument return "" def _build_where( self, regions: Optional[list[Region]] = None, genes: Optional[list[str]] = None, effect_types: Optional[list[str]] = None, family_ids: Optional[Iterable[str]] = None, person_ids: Optional[Iterable[str]] = None, inheritance: Optional[Union[str, list[str]]] = None, roles: Optional[str] = None, sexes: Optional[str] = None, variant_type: Optional[str] = None, real_attr_filter: Optional[RealAttrFilterType] = None, ultra_rare: Optional[bool] = None, frequency_filter: Optional[RealAttrFilterType] = None, return_reference: Optional[bool] = None, return_unknown: Optional[bool] = None, pedigree_fields: Optional[tuple] = None, **kwargs: Any, ) -> None: # pylint: disable=too-many-arguments,too-many-locals,unused-argument where_clause = self._build_where_string( regions=regions, genes=genes, effect_types=effect_types, family_ids=family_ids, person_ids=person_ids, inheritance=inheritance, roles=roles, sexes=sexes, variant_type=variant_type, real_attr_filter=real_attr_filter, ultra_rare=ultra_rare, frequency_filter=frequency_filter, return_reference=return_reference, return_unknown=return_unknown, pedigree_fields=pedigree_fields, ) self._add_to_product(where_clause) def _build_where_string( self, regions: Optional[list[Region]] = None, genes: Optional[list[str]] = None, effect_types: Optional[list[str]] = None, family_ids: Optional[Iterable[str]] = None, person_ids: Optional[Iterable[str]] = None, inheritance: Optional[Union[str, list[str]]] = None, roles: Optional[str] = None, sexes: Optional[str] = None, variant_type: Optional[str] = None, real_attr_filter: Optional[RealAttrFilterType] = None, ultra_rare: Optional[bool] = None, frequency_filter: Optional[RealAttrFilterType] = None, return_reference: Optional[bool] = None, return_unknown: Optional[bool] = None, pedigree_fields: Optional[tuple] = None, **kwargs: Any, ) -> str: # pylint: disable=too-many-arguments,too-many-branches,unused-argument where = [] if genes is not None and effect_types is not None: regions = self._build_gene_regions_heuristic(genes, regions) clause = self._build_iterable_struct_string_attr_where( ["eg.effect_gene_symbols", "eg.effect_types"], [genes, effect_types], ) where.append(f"({clause})") if genes is not None and effect_types is None: regions = self._build_gene_regions_heuristic(genes, regions) # effect gene is a struct under affect gene in V2 schema where.append( self._build_iterable_struct_string_attr_where( ["eg.effect_gene_symbols"], [genes], ), ) if effect_types is not None and genes is None: # effect gene is a struct under affect gene in V2 schema where.append( self._build_iterable_struct_string_attr_where( ["eg.effect_types"], [effect_types], ), ) if regions is not None: where.append(self._build_regions_where(regions)) if family_ids is not None: where.append( self._build_iterable_string_attr_where( self.where_accessors["family_id"], family_ids, ), ) if person_ids is not None: person_ids = set(person_ids) where.append( self._build_iterable_string_attr_where( "pi" + self.dialect.array_item_suffix(), person_ids), ) if inheritance is not None: where.extend( self._build_inheritance_where( self.where_accessors["inheritance_in_members"], inheritance, self.dialect.use_bit_and_function(), ), ) if roles is not None: where.append( self._build_bitwise_attr_where( self.where_accessors["allele_in_roles"], roles, role_query, ), ) if sexes is not None: where.append( self._build_bitwise_attr_where( self.where_accessors["allele_in_sexes"], sexes, sex_query, ), ) if variant_type is not None: where.append( self._build_bitwise_attr_where( self.where_accessors["variant_type"], variant_type, variant_type_query, ), ) if real_attr_filter is not None: where.append(self._build_real_attr_where(real_attr_filter)) if frequency_filter is not None: where.append( self._build_real_attr_where( frequency_filter, is_frequency=True, ), ) if ultra_rare: where.append(self._build_ultra_rare_where(ultra_rare)) where.append( self._build_return_reference_and_return_unknown( return_reference, return_unknown, ), ) where.append( self._build_frequency_bin_heuristic( inheritance, ultra_rare, frequency_filter, ), ) where.append(self._build_family_bin_heuristic(family_ids, person_ids)) where.append(self._build_coding_heuristic(effect_types)) where.append(self._build_region_bin_heuristic(regions)) where.append(self._build_where_pedigree_fields(pedigree_fields)) where = [w for w in where if w] where_clause = "" if where: where_clause = self.WHERE.format( where=(" AND \n" + " " * 4).join( [f"( {w} )" for w in where], ), ) return where_clause @abstractmethod def _build_group_by(self) -> None: pass def _build_limit(self, limit: Optional[int]) -> None: if limit is not None: self._add_to_product(f"LIMIT {limit}") @abstractmethod def _build_having(self, **kwargs: Any) -> None: pass def _add_to_product(self, query_part: Optional[str]) -> None: if query_part is None or query_part == "": return if self._product == "": self._product += query_part else: self._product += f" {query_part}" @abstractmethod def _query_columns(self) -> list[str]: pass def _build_real_attr_where( self, real_attr_filter: RealAttrFilterType, is_frequency: bool = False, ) -> str: query = [] for attr_name, attr_range in real_attr_filter: if attr_name not in self.combined_columns: query.append("false") continue assert attr_name in self.combined_columns assert ( self.combined_columns[attr_name] == self.dialect.float_type() or self.combined_columns[attr_name].startswith( self.dialect.int_type()) ), f"{attr_name} - {self.combined_columns}" left, right = attr_range attr_name = self.where_accessors[attr_name] if left is None and right is None: if not is_frequency: query.append(f"({attr_name} is not null)") elif left is None: assert right is not None if is_frequency: query.append( f"({attr_name} <= {right} or {attr_name} is null)", ) else: query.append( f"({attr_name} <= {right})", ) elif right is None: assert left is not None query.append(f"({attr_name} >= {left})") else: query.append( f"({attr_name} >= {left} AND {attr_name} <= {right})", ) return " AND ".join(query) def _build_ultra_rare_where(self, ultra_rare: bool) -> str: assert ultra_rare return self._build_real_attr_where( real_attr_filter=[("af_allele_count", (None, 1))], is_frequency=True, ) def _build_regions_where(self, regions: list[Region]) -> str: assert isinstance(regions, list), regions where = [] for region in regions: assert isinstance(region, Region) esc = self.dialect.escape_char() end_position = f"COALESCE(sa.{esc}end_position{esc}, -1)" query = "( sa.{esc}chromosome{esc} = {q}{chrom}{q}" if region.start is None and region.end is None: query += " )" query = query.format( q=self.QUOTE, chrom=region.chrom, esc=esc, ) else: query += ( " AND " "(" "(sa.{esc}position{esc} >= {start} AND " "sa.{esc}position{esc} <= {stop}) " "OR " "({end_position} >= {start} AND " "{end_position} <= {stop}) " "OR " "({start} >= sa.{esc}position{esc} AND " "{stop} <= {end_position})" "))" ) query = query.format( q=self.QUOTE, chrom=region.chrom, start=region.start, stop=region.stop, end_position=end_position, esc=esc, ) where.append(query) return " OR ".join(where) def _build_iterable_struct_string_attr_where( self, key_names: Optional[Iterable[str]] = None, query_values: Optional[Iterable[Iterable[str]]] = None, ) -> str: key_names = key_names if key_names else [] query_values = query_values if query_values else [] inner_clauses = [ self._build_iterable_string_attr_where(tup[0], tup[1]) for tup in zip(key_names, query_values) ] where_clause = " AND ".join(inner_clauses) return where_clause def _build_iterable_string_attr_where( self, column_name: str, query_values: Iterable[str], ) -> str: assert query_values is not None assert isinstance(query_values, (list, set)), type(query_values) if not query_values: return f" {column_name} IS NULL" values = [ " {q}{val}{q} ".format( q=self.QUOTE, val=val.replace("'", self.dialect.escape_quote_char() + "'"), ) for val in query_values ] where: list[str] = [] for i in range(0, len(values), self.MAX_CHILD_NUMBER): chunk_values = values[i: i + self.MAX_CHILD_NUMBER] in_expr = f" {column_name} in ( {','.join(chunk_values)} ) " where.append(in_expr) where_clause = " OR ".join([f"( {w} )" for w in where]) return where_clause def _build_bitwise_attr_where( self, column_name: str, query_value: str, query_transformer: QueryTransformerMatcher, ) -> str: assert query_value is not None parsed: Union[str, LeafNode, TreeNode] = query_value if isinstance(query_value, str): parsed = query_transformer.transform_query_string_to_tree( query_value, ) transformer = QueryTreeToSQLBitwiseTransformer( column_name, self.dialect.use_bit_and_function(), ) return cast(str, transformer.transform(parsed)) @staticmethod def _build_inheritance_where( column_name: str, query_value: Union[str, list[str]], use_bit_and_function: bool, ) -> list[str]: trees = [] if isinstance(query_value, str): tree = inheritance_parser.parse(query_value) trees.append(tree) elif isinstance(query_value, list): for qval in query_value: tree = inheritance_parser.parse(qval) trees.append(tree) # raise ValueError() else: tree = query_value trees.append(tree) result = [] for tree in trees: transformer = InheritanceTransformer( column_name, use_bit_and_function=use_bit_and_function) res = transformer.transform(tree) result.append(res) return result def _build_gene_regions_heuristic( self, genes: list[str], regions: Optional[list[Region]], ) -> Optional[list[Region]]: assert genes is not None assert self.gene_models is not None if len(genes) == 0 or len(genes) > self.GENE_REGIONS_HEURISTIC_CUTOFF: return regions gene_regions = [] for gene_name in genes: gene_model = self.gene_models.gene_models_by_gene_name(gene_name) if gene_model is None: logger.warning("gene model for %s not found", gene_name) continue for gm in gene_model: gene_regions.append( Region( gm.chrom, max(1, gm.tx[0] - self.GENE_REGIONS_HEURISTIC_EXTEND), gm.tx[1] + self.GENE_REGIONS_HEURISTIC_EXTEND, ), ) gene_regions = dae.utils.regions.collapse(gene_regions) if not regions: regions = gene_regions else: result = [] for gene_region in gene_regions: for region in regions: intersection = gene_region.intersection(region) if intersection: result.append(intersection) result = dae.utils.regions.collapse(result) logger.info("original regions: %s; result: %s", regions, result) regions = result return regions def _build_partition_bin_heuristic_where( self, bin_column: str, bins: Union[list[str], set[str]], number_of_possible_bins: Optional[int] = None, str_bins: bool = False, ) -> str: if len(bins) == 0: return "" if number_of_possible_bins is not None and \ len(bins) == number_of_possible_bins: return "" cols = [] if bin_column in self.family_columns: cols.append("fa." + bin_column) if bin_column in self.summary_columns: cols.append("sa." + bin_column) if str_bins: bins_str = ",".join([f"'{rb}'" for rb in bins]) else: bins_str = ",".join([f"{rb}" for rb in bins]) parts = [f"{col} IN ({bins_str})" for col in cols] return " AND ".join(parts) def _build_frequency_bin_heuristic_compute_bins( self, inheritance: Optional[Union[str, list[str]]], ultra_rare: Optional[bool], frequency_filter: Optional[RealAttrFilterType], rare_boundary: float, ) -> set[str]: frequency_bins: set[str] = set([]) matchers = [] if inheritance is not None: logger.debug( "frequence_bin_heuristic inheritance: %s (%s)", inheritance, type(inheritance), ) if isinstance(inheritance, str): inheritance = [inheritance] matchers = [ inheritance_query.transform_tree_to_matcher( inheritance_query.transform_query_string_to_tree(inh), ) for inh in inheritance ] if any(m.match([Inheritance.denovo]) for m in matchers): frequency_bins.add("0") has_frequency_filter = False if frequency_filter: for name, (begin, end) in frequency_filter: if name == "af_allele_freq": has_frequency_filter = True break if inheritance is None or any( m.match( [ Inheritance.mendelian, Inheritance.possible_denovo, Inheritance.possible_omission, ], ) for m in matchers ): if ultra_rare: frequency_bins |= set(["0", "1"]) elif has_frequency_filter: assert frequency_filter is not None for name, (begin, end) in frequency_filter: if name == "af_allele_freq": if end and end < rare_boundary: frequency_bins |= set(["0", "1", "2"]) elif begin and begin >= rare_boundary: frequency_bins |= set(["0", "1", "3"]) elif end is not None and end >= rare_boundary: frequency_bins |= set(["0", "1", "2", "3"]) elif inheritance is not None: frequency_bins |= set(["0", "1", "2", "3"]) return frequency_bins def _build_frequency_bin_heuristic( self, inheritance: Union[None, str, list[str]], ultra_rare: Optional[bool], frequency_filter: Optional[RealAttrFilterType], ) -> str: # pylint: disable=too-many-branches assert self.partition_descriptor is not None if "frequency_bin" not in self.combined_columns: return "" rare_boundary = self.partition_descriptor["rare_boundary"] frequency_bins = self._build_frequency_bin_heuristic_compute_bins( inheritance, ultra_rare, frequency_filter, rare_boundary) return self._build_partition_bin_heuristic_where( "frequency_bin", frequency_bins, 4) # cols = [] # if "frequency_bin" in self.family_columns: # cols.append("fa.frequency_bin") # if "frequency_bin" in self.summary_columns: # cols.append("sa.frequency_bin") # bins_str = ",".join([f"{rb}" for rb in frequency_bins]) # parts = [f"{col} IN ({bins_str})" for col in cols] # return " AND ".join(parts) def _build_coding_heuristic( self, effect_types: Union[None, set[str], list[str]], ) -> str: assert self.partition_descriptor is not None if effect_types is None: return "" if "coding_bin" not in self.combined_columns: return "" effect_types = set(effect_types) intersection = effect_types & set( self.partition_descriptor["coding_effect_types"], ) logger.debug( "coding bin heuristic: query effect types: %s; " "coding_effect_types: %s; => %s", effect_types, self.partition_descriptor["coding_effect_types"], intersection == effect_types, ) coding_bins = set([]) if intersection == effect_types: coding_bins.add("1") elif not intersection: coding_bins.add("0") return self._build_partition_bin_heuristic_where( "coding_bin", coding_bins, 2) def _build_region_bin_heuristic( self, regions: Optional[list[Region]], ) -> str: assert self.partition_descriptor is not None if not regions or self.partition_descriptor["region_length"] == 0: return "" chroms = set(self.partition_descriptor["chromosomes"]) region_length = self.partition_descriptor["region_length"] region_bins = [] for region in regions: if region.chrom in chroms: chrom_bin = region.chrom else: chrom_bin = "other" if region.start is None and region.end is None: continue start = region.start // region_length stop = region.stop // region_length for position_bin in range(start, stop + 1): region_bins.append(f"{chrom_bin}_{position_bin}") return self._build_partition_bin_heuristic_where( "region_bin", region_bins, str_bins=True) # region_bin_col = self.where_accessors["region_bin"] # cols = [] # if "region_bin" in self.family_columns: # cols.append("fa.region_bin") # if "region_bin" in self.summary_columns: # cols.append("sa.region_bin") # bins_str = ",".join([f"'{rb}'" for rb in region_bins]) # parts = [f"{col} IN ({bins_str})" for col in cols] # return " AND ".join(parts) # return f"{region_bin_col} IN ({bins_str})" def _build_family_bin_heuristic( self, family_ids: Optional[Iterable[str]], person_ids: Optional[Iterable[str]], ) -> str: assert self.partition_descriptor is not None if "family_bin" not in self.combined_columns: return "" if "family_bin" not in self.pedigree_columns: return "" family_bins: set[str] = set() if family_ids: family_ids = set(family_ids) family_bins = family_bins.union( set( p.family_bin # type: ignore for p in self.families.persons.values() if p.family_id in family_ids ), ) if person_ids: person_ids = set(person_ids) family_bins = family_bins.union( set( p.family_bin # type: ignore for p in self.families.persons.values() if p.person_id in person_ids ), ) return self._build_partition_bin_heuristic_where( "family_bin", family_bins, self.partition_descriptor["family_bin_size"]) def _build_return_reference_and_return_unknown( self, return_reference: Optional[bool] = None, return_unknown: Optional[bool] = None, ) -> str: # pylint: disable=unused-argument allele_index_col = self.where_accessors["allele_index"] if not return_reference: return f"{allele_index_col} > 0" # return_unknown basically means return all so no specific where # expression is required return ""