import functools
import json
import logging
import os
import time
from collections.abc import Iterator
from typing import Any, Optional, Union, cast
import fsspec
import pyarrow as pa
import pyarrow.parquet as pq
from dae.annotation.annotation_pipeline import AttributeInfo
from dae.parquet.helpers import url_to_pyarrow_fs
from dae.parquet.partition_descriptor import PartitionDescriptor
from dae.parquet.schema2.serializers import AlleleParquetSerializer
from dae.utils import fs_utils
from dae.utils.variant_utils import (
is_all_reference_genotype,
is_unknown_genotype,
)
from dae.variants.attributes import Inheritance
from dae.variants.family_variant import FamilyAllele, FamilyVariant
from dae.variants.variant import SummaryAllele, SummaryVariant
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
class ContinuousParquetFileWriter:
    """A continuous parquet writer.

    Class that automatically writes to a given parquet file when supplied
    enough data. Automatically dumps leftover data when closing into the file.

    Rows are buffered column-wise in ``_data``. Every ``BATCH_ROWS`` rows the
    buffer is turned into a ``pyarrow.RecordBatch``. Once roughly
    ``row_group_size`` rows worth of batches have accumulated, they are
    written to the file as a single table.
    """

    BATCH_ROWS = 1_000
    DEFAULT_COMPRESSION = "SNAPPY"

    def __init__(
        self,
        filepath: str,
        annotation_schema: list[AttributeInfo],
        filesystem: Optional[fsspec.AbstractFileSystem] = None,
        row_group_size: int = 50_000,
        schema: str = "schema",
        blob_column: Optional[str] = None,
    ) -> None:
        """Open a parquet writer for ``filepath``.

        :param filepath: destination path (or URL) of the parquet file.
        :param annotation_schema: annotation attributes handed to the
            ``AlleleParquetSerializer``.
        :param filesystem: optional fsspec filesystem; forwarded together
            with ``filepath`` to ``url_to_pyarrow_fs``.
        :param row_group_size: approximate number of rows buffered before
            they are flushed to the file.
        :param schema: name of the serializer attribute that holds the
            arrow schema to write (e.g. ``"schema_family"``).
        :param blob_column: when given, this column is compressed with ZSTD
            while every other column uses ``DEFAULT_COMPRESSION``.
        """
        self.filepath = filepath
        self.annotation_schema = annotation_schema
        self.serializer = AlleleParquetSerializer(
            self.annotation_schema,
        )
        self.schema = getattr(self.serializer, schema)

        dirname = os.path.dirname(filepath)
        # Create missing parent directories only for local paths: for URLs
        # like ``s3://bucket/dir/file.parquet`` ``os.makedirs`` would create
        # a bogus local ``s3:/bucket/dir`` tree instead.
        if dirname and "://" not in dirname \
                and not os.path.exists(dirname):
            os.makedirs(dirname, exist_ok=True)
        self.dirname = dirname

        filesystem, filepath = url_to_pyarrow_fs(filepath, filesystem)

        compression: Union[str, dict[str, str]] = self.DEFAULT_COMPRESSION
        if blob_column is not None:
            # per-column compression: the blob column gets ZSTD
            compression = {
                name: self.DEFAULT_COMPRESSION for name in self.schema.names
            }
            compression[blob_column] = "ZSTD"

        self._writer = pq.ParquetWriter(
            filepath, self.schema,
            compression=compression,
            filesystem=filesystem,
            use_compliant_nested_type=True,
            write_page_index=True,
        )
        self.row_group_size = row_group_size
        self._batches: list[pa.RecordBatch] = []
        self._data: Optional[dict[str, list[Any]]] = None
        self.data_reset()

    def data_reset(self) -> None:
        """Start a fresh, empty column buffer for every schema column."""
        self._data = {name: [] for name in self.schema.names}

    def size(self) -> int:
        """Return the number of rows currently buffered (not yet batched)."""
        assert self._data is not None
        # every column has the same length; ``bucket_index`` is used as
        # the representative column
        return len(self._data["bucket_index"])

    def build_table(self) -> pa.Table:
        """Combine all pending record batches into a single table."""
        logger.info(
            "writing %s rows to parquet %s",
            sum(len(b) for b in self._batches),
            self.filepath)
        return pa.Table.from_batches(self._batches, self.schema)

    def build_batch(self) -> pa.RecordBatch:
        """Convert the buffered columns into a record batch."""
        return pa.RecordBatch.from_pydict(self._data, self.schema)

    def _write_batch(self) -> None:
        """Move buffered rows into a batch; flush once enough accumulated."""
        if self.size() == 0:
            return
        self._batches.append(self.build_batch())
        self.data_reset()
        # Flush when about ``row_group_size`` rows worth of batches are
        # pending; if a row group is smaller than a batch, flush each batch.
        flush_threshold = max(1, self.row_group_size // self.BATCH_ROWS)
        if len(self._batches) >= flush_threshold:
            self._flush_batches()

    def _flush_batches(self) -> None:
        """Write all pending batches to the parquet file as one table."""
        if not self._batches:
            return
        logger.debug(
            "flushing %s batches", len(self._batches))
        self._writer.write_table(self.build_table())
        self._batches = []

    def append_summary_allele(
            self, allele: SummaryAllele, json_data: str) -> None:
        """Append the data for an entire variant to the correct file.

        A summary allele contributes exactly one row to the buffer.
        """
        assert self._data is not None
        data = self.serializer.build_summary_allele_batch_dict(
            allele, json_data,
        )
        for column, values in self._data.items():
            values.append(data[column])

        if self.size() >= self.BATCH_ROWS:
            logger.debug(
                "parquet writer %s create summary batch at len %s",
                self.filepath, self.size())
            self._write_batch()

    def append_family_allele(
            self, allele: FamilyAllele, json_data: str) -> None:
        """Append the data for an entire variant to the correct file.

        The serializer returns a list of values per column for a family
        allele, so the buffer is extended rather than appended to.
        """
        assert self._data is not None
        data = self.serializer.build_family_allele_batch_dict(
            allele, json_data,
        )
        for column, values in self._data.items():
            values.extend(data[column])

        if self.size() >= self.BATCH_ROWS:
            logger.debug(
                "parquet writer %s create family batch at len %s",
                self.filepath, self.size())
            self._write_batch()

    def close(self) -> None:
        """Close the parquet writer and write any remaining data.

        The underlying parquet writer is closed even when flushing the
        remaining rows raises, so the file handle is never leaked.
        """
        logger.debug(
            "closing parquet writer %s with %d rows",
            self.filepath, self.size())
        try:
            self._write_batch()
            self._flush_batches()
        finally:
            self._writer.close()
class VariantsParquetWriter:
    """Provide functions for storing variants into parquet dataset.

    Family alleles are written under ``<out_dir>/family`` and summary
    alleles under ``<out_dir>/summary``. Each allele is routed to a
    partition directory/file computed by the ``PartitionDescriptor``, and
    one ``ContinuousParquetFileWriter`` is kept open per output file until
    :meth:`close` is called.
    """

    def __init__(
        self,
        out_dir: str,
        annotation_schema: list[AttributeInfo],
        partition_descriptor: PartitionDescriptor,
        bucket_index: int = 1,
        row_group_size: int = 50_000,
        include_reference: bool = True,
        filesystem: Optional[fsspec.AbstractFileSystem] = None,
    ) -> None:
        """Configure the writer.

        :param out_dir: root directory of the parquet dataset.
        :param annotation_schema: annotation attributes for serialization.
        :param partition_descriptor: decides partition directories/files.
        :param bucket_index: bucket number; part of every ``sj_index`` and
            of the produced file names. Must be below 1 000 000.
        :param row_group_size: approximate parquet row group size.
        :param include_reference: whether reference alleles (and
            all-reference family variants) are stored.
        :param filesystem: optional fsspec filesystem for the output.
        """
        self.out_dir = out_dir
        self.bucket_index = bucket_index
        assert self.bucket_index < 1_000_000, "bad bucket index"
        self.row_group_size = row_group_size
        self.filesystem = filesystem
        self.include_reference = include_reference
        self.start = time.time()
        self.data_writers: dict[str, ContinuousParquetFileWriter] = {}
        assert isinstance(partition_descriptor, PartitionDescriptor)
        self.partition_descriptor = partition_descriptor
        self.annotation_schema = annotation_schema

    def _build_partition_filename(self, kind: str, partition: Any) -> str:
        """Return the output file path of ``kind`` for ``partition``.

        ``kind`` is ``"family"`` or ``"summary"``; it names both the
        sub-directory of ``out_dir`` and the file name prefix.
        """
        partition_directory = self.partition_descriptor.partition_directory(
            fs_utils.join(self.out_dir, kind), partition)
        partition_filename = self.partition_descriptor.partition_filename(
            kind, partition, self.bucket_index)
        return fs_utils.join(partition_directory, partition_filename)

    def _build_family_filename(
        self, allele: FamilyAllele,
        seen_as_denovo: bool,
    ) -> str:
        """Return the family parquet file path for ``allele``."""
        partition = self.partition_descriptor.family_partition(
            allele, seen_as_denovo)
        return self._build_partition_filename("family", partition)

    def _build_summary_filename(
        self, allele: SummaryAllele,
        seen_as_denovo: bool,
    ) -> str:
        """Return the summary parquet file path for ``allele``."""
        partition = self.partition_descriptor.summary_partition(
            allele, seen_as_denovo)
        return self._build_partition_filename("summary", partition)

    def _get_bin_writer(
        self, filename: str, schema: str, blob_column: str,
    ) -> ContinuousParquetFileWriter:
        """Return the writer for ``filename``, creating it on first use."""
        writer = self.data_writers.get(filename)
        if writer is None:
            writer = ContinuousParquetFileWriter(
                filename,
                self.annotation_schema,
                filesystem=self.filesystem,
                row_group_size=self.row_group_size,
                schema=schema,
                blob_column=blob_column,
            )
            self.data_writers[filename] = writer
        return writer

    def _get_bin_writer_family(
        self, allele: FamilyAllele,
        seen_as_denovo: bool,
    ) -> ContinuousParquetFileWriter:
        """Return the family-table writer responsible for ``allele``."""
        return self._get_bin_writer(
            self._build_family_filename(allele, seen_as_denovo),
            schema="schema_family",
            blob_column="family_variant_data",
        )

    def _get_bin_writer_summary(
        self, allele: SummaryAllele,
        seen_as_denovo: bool,
    ) -> ContinuousParquetFileWriter:
        """Return the summary-table writer responsible for ``allele``."""
        return self._get_bin_writer(
            self._build_summary_filename(allele, seen_as_denovo),
            schema="schema_summary",
            blob_column="summary_variant_data",
        )

    def _calc_sj_index(self, summary_index: int, allele_index: int) -> int:
        """Return the dataset-wide index of a summary allele.

        Layout: ``(bucket_index * 1e9 + summary_index) * 1e4 + allele_index``.
        """
        assert allele_index < 10_000, "too many alleles"
        return self._calc_sj_base_index(summary_index) + allele_index

    def _calc_sj_base_index(self, summary_index: int) -> int:
        """Return the ``sj_index`` of allele 0 of a summary variant."""
        return (
            self.bucket_index * 1_000_000_000
            + summary_index) * 10_000

    def _annotate_family_alleles(
        self, fv: FamilyVariant,
        family_variant_index: int,
        sj_base_index: int,
    ) -> None:
        """Attach bucket, family and sj indexes to every allele of ``fv``."""
        allele_indexes: set[int] = set()
        for fa in fv.alleles:
            # an allele index may appear at most once within a variant
            assert fa.allele_index not in allele_indexes
            allele_indexes.add(fa.allele_index)
            fa.update_attributes({
                "bucket_index": self.bucket_index,
                "family_index": family_variant_index,
                "sj_index": sj_base_index + fa.allele_index,
            })

    def _select_family_alleles(
        self, fv: FamilyVariant,
    ) -> list[FamilyAllele]:
        """Return the alleles of ``fv`` that get a row in the family table.

        For unknown and all-reference genotypes the reference allele is
        always included; otherwise it is included only when
        ``include_reference`` is set. The alt alleles of ``fv`` are always
        appended after it.
        """
        assert fv.gt is not None
        selected: list[Any] = []
        if is_unknown_genotype(fv.gt) or is_all_reference_genotype(fv.gt):
            assert fv.ref_allele.allele_index == 0
            selected.append(fv.ref_allele)
        elif self.include_reference:
            selected.append(fv.ref_allele)
        selected.extend(fv.alt_alleles)
        return [cast(FamilyAllele, allele) for allele in selected]

    def write_dataset(
        self,
        full_variants_iterator: Iterator[
            tuple[SummaryVariant, list[FamilyVariant]]],
    ) -> list[str]:
        """Write variant to partitioned parquet dataset.

        For every ``(summary_variant, family_variants)`` pair, the family
        alleles are written first. While doing so, per-allele statistics
        (seen-in-status bit mask, seen-as-denovo flag, number of stored
        family alleles) are accumulated and attached to the summary
        variant's alt alleles before the summary variant is written.
        Summary variants without any stored family allele are skipped.

        All writers are closed before returning.

        :returns: the names of all parquet files that were written.
        """
        # pylint: disable=too-many-locals
        family_variant_index = 0
        summary_variant_index = 0
        for summary_variant_index, (
            summary_variant,
            family_variants,
        ) in enumerate(full_variants_iterator):
            assert summary_variant_index < 1_000_000_000, \
                "too many summary variants"
            num_fam_alleles_written = 0
            seen_in_status = summary_variant.allele_count * [0]
            seen_as_denovo = summary_variant.allele_count * [False]
            family_variants_count = summary_variant.allele_count * [0]
            sj_base_index = self._calc_sj_base_index(summary_variant_index)

            for fv in family_variants:
                family_variant_index += 1
                assert fv.gt is not None
                if is_all_reference_genotype(fv.gt) and \
                        not self.include_reference:
                    continue

                fv.summary_index = summary_variant_index
                fv.family_index = family_variant_index
                self._annotate_family_alleles(
                    fv, family_variant_index, sj_base_index)
                family_variant_data_json = json.dumps(
                    fv.to_record(), sort_keys=True)

                for fa in self._select_family_alleles(fv):
                    allele_index = fa.allele_index
                    # OR together the status bits of all members that
                    # carry this allele
                    seen_in_status[allele_index] = functools.reduce(
                        lambda acc, status: acc | status.value,
                        filter(None, fa.allele_in_statuses),
                        seen_in_status[allele_index])
                    # None / unknown / missing inheritances can never be
                    # equal to denovo, so no pre-filtering is needed
                    sad = any(
                        inheritance == Inheritance.denovo
                        for inheritance in fa.inheritance_in_members)
                    seen_as_denovo[allele_index] = \
                        sad or seen_as_denovo[allele_index]

                    family_bin_writer = self._get_bin_writer_family(fa, sad)
                    family_bin_writer.append_family_allele(
                        fa, family_variant_data_json,
                    )
                    family_variants_count[allele_index] += 1
                    num_fam_alleles_written += 1

            # don't store summary alleles without family ones
            if num_fam_alleles_written > 0:
                summary_variant.summary_index = summary_variant_index
                summary_variant.ref_allele.update_attributes(
                    {"bucket_index": self.bucket_index})
                # index 0 is the reference allele; the lists below are
                # per alt allele.
                # NOTE(review): family_alleles_count currently mirrors
                # family_variants_count — confirm this is intended.
                summary_variant.update_attributes({
                    "seen_in_status": seen_in_status[1:],
                    "seen_as_denovo": seen_as_denovo[1:],
                    "family_variants_count": family_variants_count[1:],
                    "family_alleles_count": family_variants_count[1:],
                    "bucket_index": [self.bucket_index],
                })
                self.write_summary_variant(
                    summary_variant, sj_base_index=sj_base_index,
                )

            if summary_variant_index % 1000 == 0 and summary_variant_index > 0:
                elapsed = time.time() - self.start
                logger.info(
                    "progress bucket %s; "
                    "summary variants: %s; family variants: %s; "
                    "elapsed time: %0.2f sec",
                    self.bucket_index,
                    summary_variant_index, family_variant_index,
                    elapsed)

        filenames = list(self.data_writers.keys())
        self.close()

        elapsed = time.time() - self.start
        logger.info(
            "finished bucket %s; summary variants: %s; family variants: %s; "
            "elapsed time: %0.2f sec",
            self.bucket_index, summary_variant_index, family_variant_index,
            elapsed)
        return filenames

    def close(self) -> None:
        """Close every open partition writer, flushing buffered rows."""
        for bin_writer in self.data_writers.values():
            bin_writer.close()

    def write_summary_variant(
        self, summary_variant: SummaryVariant,
        attributes: Optional[dict[str, Any]] = None,
        sj_base_index: Optional[int] = None,
    ) -> None:
        """Write a single summary variant to the correct parquet file.

        :param summary_variant: the variant to store.
        :param attributes: optional attributes applied through
            ``summary_variant.update_attributes`` before serialization.
        :param sj_base_index: when given, each allele receives an
            ``sj_index`` attribute equal to ``sj_base_index`` plus its
            allele index.

        All alleles are stored (only the alt alleles when
        ``include_reference`` is false). Every stored row carries the JSON
        record of the whole summary variant.
        """
        if attributes is not None:
            summary_variant.update_attributes(attributes)
        if sj_base_index is not None:
            for summary_allele in summary_variant.alleles:
                summary_allele.update_attributes({
                    "sj_index": sj_base_index + summary_allele.allele_index,
                })

        summary_blobs_json = json.dumps(
            summary_variant.to_record(), sort_keys=True,
        )
        if self.include_reference:
            stored_alleles = summary_variant.alleles
        else:
            stored_alleles = summary_variant.alt_alleles

        for summary_allele in stored_alleles:
            seen_as_denovo = summary_allele.get_attribute("seen_as_denovo")
            summary_writer = self._get_bin_writer_summary(
                summary_allele, seen_as_denovo)
            summary_writer.append_summary_allele(
                summary_allele, summary_blobs_json)