Source code for pyobo.ner.api

"""NER functionality."""

from __future__ import annotations

import logging
from collections.abc import Iterable
from subprocess import CalledProcessError
from typing import TYPE_CHECKING

import ssslm
from ssslm import LiteralMapping
from tqdm import tqdm
from typing_extensions import Unpack

from pyobo.api import get_literal_mappings
from pyobo.constants import GetOntologyKwargs, check_should_use_tqdm
from pyobo.getters import NoBuildError
from pyobo.struct import Reference

if TYPE_CHECKING:
    import gilda

__all__ = [
    "get_grounder",
]

logger = logging.getLogger(__name__)


def get_grounder(
    prefixes: str | Iterable[str],
    *,
    grounder_cls: type[gilda.Grounder] | None = None,
    versions: None | str | Iterable[str | None] | dict[str, str] = None,
    skip_obsolete: bool = False,
    raise_on_missing: bool = False,
    **kwargs: Unpack[GetOntologyKwargs],
) -> ssslm.Grounder[Reference]:
    """Get a grounder for the given prefix(es).

    :param prefixes: A single prefix or an iterable of prefixes whose literal
        mappings are combined into one grounder.
    :param grounder_cls: Passed through to :func:`ssslm.make_grounder` as
        ``grounder_cls``.
    :param versions: The version(s) to use, given as ``None``, a single
        string, an iterable aligned with ``prefixes``, or a dictionary keyed
        by prefix.
    :param skip_obsolete: Passed through to
        :func:`pyobo.api.get_literal_mappings`.
    :param raise_on_missing: If true, raise when a prefix yields no literal
        mappings instead of only logging a warning.
    :param kwargs: Passed through to :func:`pyobo.api.get_literal_mappings`.
        The ``version`` key is overwritten for each prefix.
    :returns: The grounder built by :func:`ssslm.make_grounder` using the
        ``"gilda"`` implementation.
    :raises ValueError: If ``raise_on_missing`` is true and a prefix yields
        no literal mappings.
    """
    prefix_version_pairs = _clean_prefix_versions(prefixes, versions=versions)

    # A progress bar is pointless for a single resource; only consult the
    # keyword arguments when there is more than one.
    if len(prefix_version_pairs) == 1:
        hide_progress = True
    else:
        hide_progress = not check_should_use_tqdm(kwargs)

    collected: list[LiteralMapping[Reference]] = []
    progress = tqdm(prefix_version_pairs, leave=False, disable=hide_progress)
    for prefix, version in progress:
        # Each resource is looked up with its own version.
        kwargs["version"] = version
        try:
            loaded = get_literal_mappings(prefix, skip_obsolete=skip_obsolete, **kwargs)
        except (NoBuildError, CalledProcessError) as error:
            # Best effort: a resource that can't be obtained is skipped.
            logger.warning("[%s] unable to get literal mappings: %s", prefix, error)
            continue

        if not loaded:
            if raise_on_missing:
                raise ValueError(f"no literal mappings were loaded for {prefix}")
            logger.warning("[%s] no literal mappings loaded", prefix)
        collected.extend(loaded)

    return ssslm.make_grounder(
        collected,
        implementation="gilda",
        grounder_cls=grounder_cls,
    )
def _clean_prefix_versions( prefixes: str | Iterable[str], versions: None | str | Iterable[str | None] | dict[str, str] = None, ) -> list[tuple[str, str | None]]: if isinstance(prefixes, str): prefixes = [prefixes] else: prefixes = list(prefixes) if versions is None: versions = [None] * len(prefixes) elif isinstance(versions, str): versions = [versions] elif isinstance(versions, dict): versions = [versions.get(prefix) for prefix in prefixes] else: versions = list(versions) if len(prefixes) != len(versions): raise ValueError return list(zip(prefixes, versions, strict=True))