Source code for perturb_lib.data.access

"""Copyright (C) 2025  GlaxoSmithKline plc

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Interface module for all the data-related operations.
"""

from __future__ import annotations

import hashlib
import inspect
import json
import re
import shutil
import sys
from collections.abc import Sequence
from pathlib import Path
from typing import Callable, Literal, Optional, Self, Type, cast

import numpy as np
import polars as pl
from scanpy import AnnData

from perturb_lib.data import ControlSymbol
from perturb_lib.data.plibdata import InMemoryPlibData, OnDiskPlibData, PlibData
from perturb_lib.data.preprocesing import DEFAULT_PREPROCESSING_TYPE, PreprocessingType, preprocessors
from perturb_lib.data.utils import (
    ModelSystemType,
    TechnologyType,
    add_train_test_val_splits,
    anndata_format_verification,
)
from perturb_lib.environment import get_path_to_cache, logger
from perturb_lib.utils import update_symbols

context_catalogue: dict[str, type[AnnData]] = {}

SHARDSIZE = 200_000
PDATA_FORMAT_VERSION = 1  # Increase this number if the data format changes


class Vocabulary:
    """Vocabulary class for managing symbols of contexts, perturbations, and readouts."""

    def __init__(self, context_vocab: pl.DataFrame, perturb_vocab: pl.DataFrame, readout_vocab: pl.DataFrame):
        self._verify_vocabulary(context_vocab, "context_vocab")
        self._verify_vocabulary(perturb_vocab, "perturb_vocab")
        self._verify_vocabulary(readout_vocab, "readout_vocab")

        self.context_vocab: pl.DataFrame = context_vocab
        self.perturb_vocab: pl.DataFrame = perturb_vocab
        self.readout_vocab: pl.DataFrame = readout_vocab

        ctrl_symbol_expr = pl.col("symbol") == ControlSymbol
        self.perturb_control_code = self.perturb_vocab.row(by_predicate=ctrl_symbol_expr, named=True)["code"]

    @classmethod
    def initialize_from_symbols(
        cls: type[Self], context_symbols: list[str], perturb_symbols: list[str], readout_symbols: list[str]
    ) -> Self:
        """Initialize vocabulary using symbols given in the arguments."""
        # ensure existence of the control symbol, and make sure it is placed at the beginning
        if ControlSymbol in perturb_symbols:
            perturb_symbols.remove(ControlSymbol)
        perturb_symbols = [ControlSymbol] + perturb_symbols

        context_vocab = pl.DataFrame(data={"symbol": context_symbols, "code": np.arange(len(context_symbols))})
        perturb_vocab = pl.DataFrame(data={"symbol": perturb_symbols, "code": np.arange(len(perturb_symbols))})
        readout_vocab = pl.DataFrame(data={"symbol": readout_symbols, "code": np.arange(len(readout_symbols))})

        return cls(context_vocab, perturb_vocab, readout_vocab)

    @classmethod
    def initialize_from_data(cls: type[Self], data: AnnData | PlibData | pl.DataFrame) -> Self:
        """Initialize vocabulary using symbols given in the data."""
        if isinstance(data, AnnData):
            context_symbols = sorted(data.obs.context.unique())
            perturb_symbols = sorted(pl.Series(data.obs.perturbation).str.split("+").explode().unique())
            readout_symbols = sorted(data.var.readout.unique())
        elif isinstance(data, pl.DataFrame):
            context_symbols = sorted(data["context"].unique())
            perturb_symbols = sorted(data["perturbation"].str.split("+").explode().unique())
            readout_symbols = sorted(data["readout"].unique())
        elif isinstance(data, InMemoryPlibData):
            context_symbols = sorted(data._data["context"].unique())
            perturb_symbols = sorted(data._data["perturbation"].str.split("+").explode().unique())
            readout_symbols = sorted(data._data["readout"].unique())
        elif isinstance(data, OnDiskPlibData):
            context_symbols = sorted(data._data["context"].unique())
            perturb_symbols = sorted(data._data["perturbations"].explode().unique())
            readout_symbols = sorted(data._data["readouts"].explode().unique())
        else:
            raise ValueError("Wrong data format!")

        return cls.initialize_from_symbols(context_symbols, perturb_symbols, readout_symbols)

    @staticmethod
    def _verify_vocabulary(df: pl.DataFrame, name: str):
        if "symbol" not in df.columns:
            raise ValueError(f"Vocabulary component {name} must contain a column named 'symbol'.")
        if "code" not in df.columns:
            raise ValueError("Vocabulary component {name} must contain a column named 'code'.")
        if not df["symbol"].is_unique().all():
            raise ValueError(f"Vocabulary component {name} must contain unique symbols.")

    def __eq__(self, other) -> bool:
        """Check if two vocabularies are equal."""
        if not isinstance(other, Vocabulary):
            return NotImplemented

        return (
            self.context_vocab.equals(other.context_vocab)
            and self.perturb_vocab.equals(other.perturb_vocab)
            and self.readout_vocab.equals(other.readout_vocab)
        )
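
# Illustrative sketch (not part of the library API; symbol names are hypothetical):
# building a vocabulary from raw symbol lists. The control symbol is always inserted
# at the front of the perturbation vocabulary, so it receives code 0.
#
#   vocab = Vocabulary.initialize_from_symbols(
#       context_symbols=["ContextA", "ContextB"],
#       perturb_symbols=["GENE1", "GENE2"],
#       readout_symbols=["GENE1", "GENE2", "GENE3"],
#   )
#   assert vocab.perturb_vocab.row(0, named=True)["symbol"] == ControlSymbol
#   assert vocab.perturb_control_code == 0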


def list_contexts() -> list[str]:
    """List registered contexts.

    Returns:
        A list of context identifiers that are registered in perturb_lib.
    """
    return sorted(list(context_catalogue.keys()))


def _verify_that_context_exists(context_id: str):
    if context_id not in list_contexts():
        raise ValueError(f"Unavailable context {context_id}!")


def describe_context(context_id: str) -> Optional[str]:
    """Describe specified context.

    Args:
        context_id: Identifier of the context.

    Returns:
        The description of the context given as a string.

    Raises:
        ValueError: If given context is not registered.
    """
    _verify_that_context_exists(context_id)
    return context_catalogue[context_id].__doc__


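# Illustrative sketch: browsing the registry. In practice, iterate over the ids
# returned by list_contexts(); no specific context id is assumed here.
#
#   for context_id in list_contexts():
#       print(context_id, describe_context(context_id))

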
def load_anndata(context_id: str, hgnc_renaming: bool = False) -> AnnData:
    """Load data from given context as ``AnnData``.

    Args:
        context_id: Identifier of the context.
        hgnc_renaming: Whether to apply HGNC renaming.

    Returns:
        AnnData object of the corresponding context.

    Raises:
        ValueError: If given context is not registered.
    """
    _verify_that_context_exists(context_id)
    adata = context_catalogue[context_id]()
    if hgnc_renaming:
        adata = update_symbols(adata)

    perturbation_columns = [x for x in adata.obs.columns if "perturbation" in x]
    adata.obs["perturbation"] = adata.obs[perturbation_columns[0]]
    for column in perturbation_columns[1:]:
        adata.obs["perturbation"] = adata.obs["perturbation"] + "_" + adata.obs[column]
    adata.obs["context"] = context_id

    if "perturbation_type" in set(adata.obs.columns):
        pert_type = list(set(adata.obs.perturbation_type))[0]
        adata.obs["perturbation"] = adata.obs["perturbation"].str.replace("+", f"+{pert_type}_", regex=False)
        adata.obs["perturbation"] = adata.obs["perturbation"].str.replace(f"{pert_type}_{ControlSymbol}", ControlSymbol)

    readout_columns = [x for x in adata.var.columns if "readout" in x]
    adata.var["readout"] = adata.var[readout_columns].agg(lambda row: "_".join(row), axis=1)

    if "split" not in adata.obs.columns:
        add_train_test_val_splits(adata)

    adata.obs.reset_index(inplace=True)
    adata.var.set_index("readout", inplace=True)
    adata.var.reset_index(inplace=True)
    # for some reason, AnnData requires index to be of type str
    adata.var.index = adata.var.index.astype(str)
    adata.obs.index = adata.obs.index.astype(str)
    anndata_format_verification(adata)
    return adata


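# Illustrative sketch (assumes some registered context, hypothetically named
# "SomeContext"): the returned AnnData always carries the standardized
# "context", "perturbation", and "split" obs columns and a "readout" var column.
#
#   adata = load_anndata("SomeContext")
#   print(adata.obs[["context", "perturbation", "split"]].head())
#   print(adata.var["readout"].head())

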
def _get_module_hash_for_context(context: str) -> str:
    """Get hash of the module where the context is defined."""
    func = context_catalogue[context]
    source_file = inspect.getsourcefile(func)
    if source_file is None:
        raise ValueError(f"Could not find source file for context {context}")
    source_code = Path(source_file).read_text()
    return hashlib.md5(source_code.encode()).hexdigest()


def _sanitize_context_name(context: str) -> str:
    """Sanitize context name by removing special characters."""
    sanitized_name = re.sub(r'[ \'\]\[<>:"/\\|?*-]', "", context)
    return sanitized_name


def _get_plib_data_source_name(context: str, preprocessing_type: PreprocessingType) -> str:
    """Get name of the plibdata source."""
    return f"{_sanitize_context_name(context)}-{preprocessing_type}"


def _process_context_and_generate_metadata(
    context: str,
    preprocessing_type: PreprocessingType = DEFAULT_PREPROCESSING_TYPE,
) -> pl.DataFrame:
    plibdata_dir = get_path_to_cache() / "plibdata"
    path_to_context = plibdata_dir / _get_plib_data_source_name(context, preprocessing_type)

    adata = load_anndata(context)
    adata = preprocessors[preprocessing_type](adata)

    logger.info("Casting to standardized DataFrame format..")
    ldf = pl.DataFrame(
        data=adata.X if isinstance(adata.X, np.ndarray) else adata.X.toarray(),
        schema=adata.var.readout.tolist(),
    ).lazy()
    ldf = ldf.with_columns(
        pl.Series("context", adata.obs.context.values).cast(pl.String),
        pl.Series("perturbation", adata.obs.perturbation.values).cast(pl.String),
        pl.Series("split", adata.obs.split.values).cast(pl.String),
    )
    ldf = ldf.unpivot(index=["context", "perturbation", "split"], variable_name="readout")
    # Generate shard id per split
    ldf = ldf.with_columns(shard_id_per_split=pl.arange(pl.len()).over("split") // SHARDSIZE)
    # Generate global shard id
    ldf = ldf.with_row_index().with_columns(shard_id=pl.col("index").first().over("split", "shard_id_per_split"))
    ldf = ldf.with_columns(
        shard_id=(pl.col("shard_id").rank("dense") - 1)
    )  # -1 since rank generates numbers starting from 1
    ldf = ldf.select("context", "perturbation", "readout", "value", "split", "shard_id")
    df = ldf.collect()

    metadata = df.group_by("shard_id", maintain_order=True).agg(
        pl.col("context").first(),
        pl.col("split").first(),
        pl.len().alias("size"),
        pl.col("perturbation").str.split("+").list.explode().unique().alias("perturbations"),
        pl.col("readout").unique().alias("readouts"),
    )
    metadata = metadata.select(
        pl.format(
            "{}/shard_{}.parquet",
            pl.lit(path_to_context.name),
            pl.col("shard_id").cast(pl.String).str.zfill(6),
        ).alias("shard_path"),
        pl.col("size"),
        pl.col("split"),
        pl.col("context"),
        pl.col("perturbations"),
        pl.col("readouts"),
    )

    for group_id, shard_df in df.group_by("shard_id"):
        shard_id = group_id[0]
        shard_path = path_to_context / f"shard_{shard_id:06d}.parquet"
        shard_df.drop("split", "shard_id").write_parquet(shard_path)

    return metadata


def _process_and_cache_context(
    context: str,
    preprocessing_type: PreprocessingType = DEFAULT_PREPROCESSING_TYPE,
):
    plibdata_dir = get_path_to_cache() / "plibdata"
    path_to_context = plibdata_dir / _get_plib_data_source_name(context, preprocessing_type)
    info_file = path_to_context / "info.json"
    metadata_file = path_to_context / "metadata.parquet"
    context_module_hash = _get_module_hash_for_context(context)

    # Some sanity checks for the cache
    expected_info_dict = {
        "PDATA_FORMAT_VERSION": PDATA_FORMAT_VERSION,
        "CONTEXT_MODULE_HASH": context_module_hash,
        "SHARDSIZE": SHARDSIZE,
    }
    try:
        info = json.loads(info_file.read_text())
        should_cache = info != expected_info_dict
    except (FileNotFoundError, json.JSONDecodeError):
        should_cache = True
    if not should_cache:
        try:
            expected_n_shards = len(pl.read_parquet(metadata_file))
            current_n_shards = len(list(path_to_context.glob("shard_*.parquet")))
            should_cache = current_n_shards != expected_n_shards
        except (FileNotFoundError, ValueError):
            should_cache = True

    if should_cache:
        shutil.rmtree(path_to_context, ignore_errors=True)
        path_to_context.mkdir(exist_ok=True, parents=True)
        metadata = _process_context_and_generate_metadata(context, preprocessing_type)
        metadata.write_parquet(metadata_file)
        info_file.write_text(json.dumps(expected_info_dict))


def _unpack_context_ids(contexts_ids: str | Sequence[str]) -> list[str]:
    if isinstance(contexts_ids, str):
        contexts_ids = [contexts_ids]

    expanded_contexts_ids = []
    for context_id in contexts_ids:
        if context_id not in context_catalogue:
            # check if context_id is a substring that matches actual context ids
            context_id_list = [c for c in list_contexts() if context_id in c]
            if not context_id_list:
                # trigger an exception
                _verify_that_context_exists(context_id)
            else:
                # add all contexts that matched the substring
                expanded_contexts_ids.extend(context_id_list)
        else:
            expanded_contexts_ids.append(context_id)

    return sorted(list(set(expanded_contexts_ids)))


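# Illustrative sketch of the substring expansion performed by _unpack_context_ids
# (context names below are hypothetical): an id that is not registered verbatim is
# treated as a substring and expanded to every matching registered context; an id
# that matches nothing raises a ValueError.
#
#   # with "ContextA_Batch1" and "ContextA_Batch2" registered:
#   _unpack_context_ids("ContextA")  # -> ["ContextA_Batch1", "ContextA_Batch2"]
#   _unpack_context_ids("ContextZ")  # -> ValueError if no registered id contains "ContextZ"

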
def load_plibdata(
    contexts_ids: str | Sequence[str],
    preprocessing_type: PreprocessingType = DEFAULT_PREPROCESSING_TYPE,
    plibdata_type: type[PlibData] = InMemoryPlibData,
) -> PlibData:
    """Load data from given context(s) as ``PlibData``.

    Note:
        Triggers data caching.

    Args:
        contexts_ids: Either a list of context identifiers that will be stacked together, or a single context id.
            If an identifier does not exactly match a registered context, all registered identifiers that contain
            it as a substring will be included.
        preprocessing_type: The type of preprocessing to apply.
        plibdata_type: The type of ``PlibData`` to use.

    Returns:
        An instance of a ``PlibData`` that has columns [context, perturbation, readout, value].
    """
    contexts_ids = _unpack_context_ids(contexts_ids)
    plibdata_dir = get_path_to_cache() / "plibdata"

    for context in contexts_ids:
        _process_and_cache_context(context, preprocessing_type)

    data_sources = [_get_plib_data_source_name(context, preprocessing_type) for context in contexts_ids]
    return plibdata_type(data_sources=data_sources, path_to_data_sources=plibdata_dir)


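# Illustrative sketch (hypothetical context ids): loading two contexts into a single
# in-memory PlibData; pass OnDiskPlibData instead to keep the cached shards on disk.
#
#   pdata = load_plibdata(
#       ["ContextA", "ContextB"],
#       plibdata_type=InMemoryPlibData,
#   )
#   # pdata exposes rows with columns [context, perturbation, readout, value]

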
def _split_plibdata[PlibDataT: PlibData](
    pdata: PlibDataT,
    context_ids: str | Sequence[str] | None,
    split_name: Literal["train", "val", "test"],
) -> tuple[PlibDataT, PlibDataT]:
    """Split PlibData into two parts.

    Args:
        pdata: PlibData to split.
        context_ids: Contexts from which the held-out (``split_name``) samples should be selected.
            If None, use all contexts.
        split_name: The name of the split for the second return argument.

    Returns:
        Two ``PlibData`` instances. The second instance will contain all samples belonging to the given
        ``split_name`` while the first instance will contain all remaining samples.
    """
    split_expr = pl.col("split") == split_name
    if context_ids is not None:
        context_ids = _unpack_context_ids(context_ids)
        split_expr = split_expr & (pl.col("context").is_in(context_ids))

    # noinspection PyProtectedMember
    data_or_metadata = pdata._data
    split_data = data_or_metadata.filter(split_expr)
    remaining_data = data_or_metadata.filter(~split_expr)
    split_pdata = type(pdata)(data=split_data)
    remaining_pdata = type(pdata)(data=remaining_data)
    return remaining_pdata, split_pdata


def split_plibdata_2fold[PlibDataT: PlibData](
    pdata: PlibDataT, context_ids: str | Sequence[str] | None
) -> tuple[PlibDataT, PlibDataT]:
    """Split data into training and validation.

    Args:
        pdata: An instance of ``PlibData`` to split.
        context_ids: Contexts from which validation perturbations should be selected. If None, a portion of
            validation data from each context will be taken.

    Returns:
        Two instances of ``PlibData``: training and validation.
    """
    logger.debug("Splitting plibdata instance into train and validation.")
    train_pdata, val_pdata = _split_plibdata(pdata=pdata, context_ids=context_ids, split_name="val")
    return train_pdata, val_pdata


def split_plibdata_3fold[PlibDataT: PlibData](
    pdata: PlibDataT, context_ids: str | Sequence[str] | None
) -> tuple[PlibDataT, PlibDataT, PlibDataT]:
    """Split data into training, validation, and test.

    Args:
        pdata: An instance of ``PlibData`` to split.
        context_ids: Contexts from which test and validation perturbations should be selected.

    Returns:
        Three instances of ``PlibData``: training, validation, and test.
    """
    logger.debug("Splitting plibdata instance into train, val, and test.")
    remaining_pdata, test_pdata = _split_plibdata(pdata=pdata, context_ids=context_ids, split_name="test")
    train_pdata, val_pdata = _split_plibdata(pdata=remaining_pdata, context_ids=context_ids, split_name="val")
    return train_pdata, val_pdata, test_pdata


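# Illustrative sketch: the 3-fold split first peels off the "test" rows, then splits
# the remainder into "train" and "val". Context ids below are hypothetical.
#
#   pdata = load_plibdata(["ContextA", "ContextB"])
#   train_pdata, val_pdata, test_pdata = split_plibdata_3fold(pdata, context_ids="ContextA")
#   # only ContextA rows marked "val"/"test" are held out; everything else stays in train_pdata

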
def _encode_polars_df(data: pl.DataFrame, vocab: Vocabulary) -> pl.DataFrame:
    """Encode data using given vocabulary.

    Args:
        data: data to encode
        vocab: vocabulary object

    Returns:
        encoded data.
    """
    data = data.with_columns(
        context=pl.col("context").replace_strict(vocab.context_vocab["symbol"], vocab.context_vocab["code"]),
        perturbation=pl.col("perturbation")
        .str.split("+")
        .list.eval(
            pl.element().replace_strict(
                vocab.perturb_vocab["symbol"], vocab.perturb_vocab["code"], default=vocab.perturb_control_code
            )
        ),
        readout=pl.col("readout").replace_strict(vocab.readout_vocab["symbol"], vocab.readout_vocab["code"]),
    )
    return data


class _PolarsDFEncoder:
    """Encode data using a given vocabulary."""

    def __init__(self, vocab: Vocabulary):
        self.vocab = vocab

    def __call__(self, data: pl.DataFrame) -> pl.DataFrame:
        return _encode_polars_df(data, self.vocab)


def encode_data[
    T: (
        pl.DataFrame,
        PlibData[pl.DataFrame],
        InMemoryPlibData[pl.DataFrame],
        OnDiskPlibData[pl.DataFrame],
        AnnData,
    )
](data: T, vocab: Vocabulary) -> T:
    """Encode data using given vocabulary.

    Args:
        data: data to encode
        vocab: vocabulary object

    Returns:
        encoded data.
    """
    if isinstance(data, AnnData):
        data_copy = data.copy()  # to ensure we don't modify the original data object
        obs_df: pl.DataFrame = pl.from_pandas(data_copy.obs)
        var_df: pl.DataFrame = pl.from_pandas(data_copy.var)
        obs_df = obs_df.with_columns(
            context=pl.col("context").replace_strict(vocab.context_vocab["symbol"], vocab.context_vocab["code"]),
            perturbation=pl.col("perturbation")
            .str.split("+")
            .list.eval(
                pl.element().replace_strict(
                    vocab.perturb_vocab["symbol"], vocab.perturb_vocab["code"], default=vocab.perturb_control_code
                )
            ),
        )
        var_df = var_df.with_columns(
            readout=pl.col("readout").replace_strict(vocab.readout_vocab["symbol"], vocab.readout_vocab["code"])
        )
        data_copy.obs = obs_df.to_pandas()
        data_copy.var = var_df.to_pandas()
        data = data_copy
    elif isinstance(data, PlibData):
        data = data.apply_transform(_PolarsDFEncoder(vocab))
    elif isinstance(data, pl.DataFrame):
        data = _encode_polars_df(data, vocab)
    else:
        raise ValueError("Wrong data format!")
    return data


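# Illustrative sketch: encoding a raw polars frame with a vocabulary built from the
# same data. Symbols are mapped to integer codes; unknown perturbation tokens fall
# back to the control code. The frame below is a hypothetical toy example.
#
#   df = pl.DataFrame(
#       {
#           "context": ["ContextA"],
#           "perturbation": ["GENE1+GENE2"],
#           "readout": ["GENE3"],
#           "value": [0.5],
#       }
#   )
#   vocab = Vocabulary.initialize_from_data(df)
#   encoded = encode_data(df, vocab)  # "perturbation" becomes a list of integer codes

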
def register_context(context_class: type[AnnData]):
    """Register new context to the collection.

    Example::

        import perturb_lib as plib
        import scanpy as sc
        import numpy as np


        @plib.register_context
        class CoolPerturbationScreen(AnnData):
            def __init__(self):
                super().__init__(
                    X=np.zeros((3, 5), dtype=np.float32),
                    obs={"perturbation": ["P1", "P2", "P3"]},
                    var={"readout": ["R1", "R2", "R3", "R4", "R5"]},
                )

    Args:
        context_class: context class to register

    Raises:
        ValueError: If a context class with the same name exists already.
    """
    context = context_class.__uncut_name__ if hasattr(context_class, "__uncut_name__") else context_class.__name__
    if context in list_contexts() and (
        getattr(context_class, "__module__", "") != getattr(context_catalogue[context], "__module__", "")
        or context_class.__qualname__ != context_catalogue[context].__qualname__
    ):
        raise ValueError(f"Existing id {context} already registered for a different context!")
    context_catalogue[context] = context_class
    return context_class


def create_and_register_context(
    model_system: ModelSystemType,
    model_system_id: str | None,
    technology_info: TechnologyType,
    data_source_info: str,
    batch_info: str | None,
    full_context_description: str,
    anndata_fn: Callable,
    **anndata_fn_kwargs,
) -> Type[AnnData]:
    """Context creation factory.

    The created context class will be inserted in the namespace of the module where the anndata_fn is defined.

    Args:
        model_system: One of the pre-specified descriptions of the model system.
        model_system_id: Community-recognized identifier of the model system, e.g. K562 for a cell line.
        technology_info: Information about technology used to create the data.
        data_source_info: Identifier of the data source.
        batch_info: Either None or some description of a batch.
        full_context_description: Detailed description of the context.
        anndata_fn: Function used to instantiate constructor arguments for creation of the context object.
        anndata_fn_kwargs: Kwargs for anndata_fn.
    """
    model_system_id = f"_{model_system_id}" if model_system_id is not None else ""
    batch_info = f"_{batch_info}" if batch_info is not None else ""
    uncut_name = f"{model_system}{model_system_id}_{technology_info}_{data_source_info}{batch_info}"
    class_name = _sanitize_context_name(uncut_name)
    source_module_name: str = anndata_fn.__module__

    class_dict = {
        "__doc__": full_context_description,
        "__init__": lambda self: super(context, self).__init__(*anndata_fn(**anndata_fn_kwargs)),  # type: ignore
        "__uncut_name__": uncut_name,
        "__module__": source_module_name,
    }
    context: Type[AnnData] = cast(Type[AnnData], type(class_name, (AnnData,), class_dict))
    setattr(sys.modules[context.__module__], context.__name__, context)  # Insert the class into the module's namespace
    register_context(context)
    return context
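

# Illustrative sketch (all names and the loader function below are hypothetical):
# the factory derives the context id from its metadata arguments, builds an AnnData
# subclass on the fly, and registers it via register_context.
#
#   def _my_anndata_fn():
#       X = np.zeros((3, 5), dtype=np.float32)
#       obs = {"perturbation": ["P1", "P2", "P3"]}
#       var = {"readout": ["R1", "R2", "R3", "R4", "R5"]}
#       return X, obs, var
#
#   create_and_register_context(
#       model_system="cell_line",          # hypothetical ModelSystemType value
#       model_system_id="K562",
#       technology_info="CRISPRi",         # hypothetical TechnologyType value
#       data_source_info="MyScreen",
#       batch_info=None,
#       full_context_description="A toy context for illustration.",
#       anndata_fn=_my_anndata_fn,
#   )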