Source code for roman_datamodels._stnode._converters

"""
ASDF converters that handle serialization/deserialization of the STNode classes to and from ASDF.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

from asdf.extension import Converter
from astropy.time import Time

from ._registry import (
    LIST_NODE_CLASSES_BY_PATTERN,
    MANIFEST_TAG_REGISTRY,
    NODE_CLASSES_BY_TAG,
    NODE_CONVERTERS,
    OBJECT_NODE_CLASSES_BY_PATTERN,
    SCALAR_NODE_CLASSES_BY_PATTERN,
    SERIALIZATION_BY_MANIFEST,
    TAG_MANIFEST_REGISTRY,
)

if TYPE_CHECKING:
    from ._tagged import SerializationNode, TaggedListNode, TaggedObjectNode, TaggedScalarNode

# Public converter classes exported by this module. Each of these is a
# _TaggedNodeConverter subclass, so defining it also stores an instance in
# NODE_CONVERTERS (see _TaggedNodeConverter.__init_subclass__).
# NOTE(review): SerializationNodeConverter is not listed here; presumably
# intentional since it is constructed per manifest URI — confirm.
__all__ = [
    "TaggedListNodeConverter",
    "TaggedObjectNodeConverter",
    "TaggedScalarNodeConverter",
]


class _RomanConverter(Converter):
    """
    Common parent class of every converter defined by roman_datamodels.

    Its only job is to set the ``lazy`` flag shared by all of them.
    """

    # Presumably opts these converters into asdf's lazy-tree handling, where
    # from_yaml_tree may receive lazily loaded nodes — confirm against the
    # asdf ``Converter.lazy`` documentation.
    lazy: bool = True


class SerializationNodeConverter(_RomanConverter):
    """
    Converter that the tagged-node converters hand off to.

    One instance exists per manifest URI. It writes the SerializationNode
    produced for that manifest and reads back every tag registered for it.
    This way the extension belonging to the correct manifest is the one
    applied.
    """

    def __init__(self, manifest_uri: str):
        # URI of the manifest whose tags/serialization type this instance owns.
        self._manifest_uri = manifest_uri

    def select_tag(self, obj: SerializationNode, tags, ctx) -> str:
        # The serialization wrapper already carries the exact tag to write.
        chosen = obj.tag
        return chosen

    @property
    def tags(self) -> tuple[str, ...]:
        registered = MANIFEST_TAG_REGISTRY[self._manifest_uri]
        return tuple(registered)

    @property
    def types(self) -> tuple[type[SerializationNode], ...]:
        node_type = SERIALIZATION_BY_MANIFEST[self._manifest_uri]
        return (node_type,)

    def to_yaml_tree(self, obj: SerializationNode, tag, ctx):
        # The wrapper's payload is what gets written under the tag.
        payload = obj.data
        return payload

    def from_yaml_tree(self, node, tag, ctx) -> TaggedObjectNode | TaggedListNode | TaggedScalarNode:
        """
        Rebuild the tagged node registered for ``tag`` from its YAML value.

        For ``file_date`` tags, the raw value is first decoded by the
        converter that asdf has registered for astropy ``Time``. The tag that
        was read is recorded on the returned node as ``_read_tag``.
        """
        value = node
        if "file_date" in tag:
            time_converter = ctx.extension_manager.get_converter_for_type(Time)
            value = time_converter.from_yaml_tree(value, tag, ctx)

        node_class = NODE_CLASSES_BY_TAG[tag]
        result = node_class(value)
        # TODO: Add method for setting read_tag with some checks
        result._read_tag = tag
        return result


class _TaggedNodeConverter(_RomanConverter):
    """
    Shared parent of the public tagged-node converters.

    These converters claim no tag of their own (``select_tag`` returns None,
    ``tags`` is empty). To write a node, they wrap it in the SerializationNode
    class that belongs to the manifest of the node's tag. Reading is not
    implemented here (``from_yaml_tree`` raises).
    """

    def __init_subclass__(cls, **kwargs) -> None:
        """
        Register one instance of every public subclass in NODE_CONVERTERS.

        Subclasses whose names begin with an underscore are skipped. Raises
        ValueError if a converter with the same class name already exists.
        """
        super().__init_subclass__(**kwargs)

        name = cls.__name__
        if name.startswith("_"):
            # Private helper bases are not real converters.
            return

        if name in NODE_CONVERTERS:
            raise ValueError(f"Duplicate converter for {name}")

        NODE_CONVERTERS[name] = cls()

    def select_tag(self, obj, tags, ctx):
        # No tag is chosen here; the returned SerializationNode carries it.
        return None

    @property
    def tags(self) -> tuple:
        return ()

    def to_yaml_tree(self, obj, tag, ctx):
        # tag -> manifest URI -> SerializationNode class for that manifest.
        manifest_uri = TAG_MANIFEST_REGISTRY[tag]
        serialization_cls = SERIALIZATION_BY_MANIFEST[manifest_uri]
        return serialization_cls(obj, tag)

    def from_yaml_tree(self, node, tag, ctx):
        raise NotImplementedError("Converter deserialization deferred")


class TaggedObjectNodeConverter(_TaggedNodeConverter):
    """
    Converter for all subclasses of TaggedObjectNode.
    """

    @property
    def types(self) -> tuple[type[TaggedObjectNode], ...]:
        # Every object-node class registered by pattern is handled here.
        return tuple(OBJECT_NODE_CLASSES_BY_PATTERN.values())

    def to_yaml_tree(self, obj: TaggedObjectNode, tag, ctx):
        """
        Wrap the node's contents for serialization.

        A shallow ``dict`` copy of ``obj._data`` is handed to the base class.
        The base class wraps it in the SerializationNode for the manifest of
        ``obj.tag``. The node's own tag is used rather than the ``tag``
        argument.
        """
        return super().to_yaml_tree(dict(obj._data), obj.tag, ctx)
class TaggedListNodeConverter(_TaggedNodeConverter):
    """
    Converter for all subclasses of TaggedListNode.
    """

    @property
    def types(self) -> tuple[type[TaggedListNode], ...]:
        # Every list-node class registered by pattern is handled here.
        return tuple(LIST_NODE_CLASSES_BY_PATTERN.values())

    def to_yaml_tree(self, obj: TaggedListNode, tag, ctx):
        """
        Wrap the node's items for serialization.

        The items are copied into a plain ``list`` and handed to the base
        class. The base class wraps them in the SerializationNode for the
        manifest of ``obj.tag``. The node's own tag is used rather than the
        ``tag`` argument.
        """
        return super().to_yaml_tree(list(obj), obj.tag, ctx)
class TaggedScalarNodeConverter(_TaggedNodeConverter):
    """
    Converter for all subclasses of TaggedScalarNode.
    """

    @property
    def types(self) -> tuple[type[TaggedScalarNode], ...]:
        # Returned as a tuple for consistency with the object/list converters;
        # asdf only iterates over this value.
        return tuple(SCALAR_NODE_CLASSES_BY_PATTERN.values())

    def to_yaml_tree(self, obj: TaggedScalarNode, tag, ctx):
        """
        Convert a tagged scalar into its serialization wrapper.

        The node is first rebuilt as an instance of its class's first base.
        For ``file_date`` tags, that value is then passed through the
        converter asdf has registered for its type. The result is handed to
        the base class, which wraps it in the SerializationNode for the
        manifest of ``obj.tag``.
        """
        # NOTE(review): assumes the first entry of ``__bases__`` is the wrapped
        # value type (e.g. ``str`` or ``Time``), so this strips the roman node
        # class — confirm against the scalar node definitions.
        node = type(obj).__bases__[0](obj)

        if "file_date" in obj.tag:
            converter = ctx.extension_manager.get_converter_for_type(type(node))
            node = converter.to_yaml_tree(node, tag, ctx)

        return super().to_yaml_tree(node, obj.tag, ctx)