Source code for roman_datamodels._stnode._node

"""
Base node classes for all STNode classes.
    These are the base classes for the data objects used by the datamodels package.
"""

from __future__ import annotations

import datetime
from collections.abc import MutableMapping, MutableSequence
from typing import TYPE_CHECKING

import numpy as np
from asdf.lazy_nodes import AsdfDictNode, AsdfListNode
from asdf.tags.core import ndarray
from astropy.time import Time

# Public API of this module: only the two base node classes are exported by
# ``from ... import *``; the wrap/unwrap helpers and the mixin stay private.
__all__ = ["DNode", "LNode"]


def _wrap(value):
    """
    Wrap a raw container in the matching node class.

    ``dict`` / ``AsdfDictNode`` values become a `DNode` and ``list`` /
    ``AsdfListNode`` values become an `LNode`. Any other value (including
    values that are already nodes) is handed back unchanged.
    """
    if isinstance(value, dict | AsdfDictNode):
        node_class = DNode
    elif isinstance(value, list | AsdfListNode):
        node_class = LNode
    else:
        # Not a raw container: nothing to wrap
        return value

    return node_class(value)


def _unwrap(value):
    """
    Strip a plain node wrapper back down to the container it wraps.

    Only objects whose type is exactly `DNode` or `LNode` are unwrapped
    (to ``_data`` and ``data`` respectively); subclasses of either and all
    other values are returned unchanged.
    """
    # Exact type comparison on purpose: isinstance() would also match
    # subclasses, which must keep their node identity.
    node_type = type(value)
    if node_type is DNode:
        return value._data

    return value.data if node_type is LNode else value


class _NodeMixin:
    """
    Mixin class to provide the common API for all Node objects
    """

    # This is a hack to avoid mypy and __slots__ inheritance issues concerning `_read_tag`
    #    ideally we would just define `_read_tag` like we did below, but mypy gets upset because
    #    __slots__ is defined so that the subclasses will be fully slotted. You can't have the
    #    same slot attributed defined in both parent classes when they are mixed together.
    if TYPE_CHECKING:
        __slots__ = ("_read_tag",)
    else:
        __slots__ = ()

    _read_tag: str | None

    def __init__(self, *args, **kwargs):
        self._read_tag = None


class DNode(MutableMapping, _NodeMixin):
    """
    Base class describing all "object" (dict-like) data nodes for STNode classes.

    The node wraps, without copying, a ``dict`` or ``AsdfDictNode`` held in
    ``_data``. Public keys can be reached as items (``node["key"]``) or as
    attributes (``node.key``):

    * attribute reads wrap ``dict``/``list`` values in `DNode`/`LNode`;
      item reads return the stored value as-is;
    * writes through either interface store plain `DNode`/`LNode` values
      unwrapped, so ``_data`` holds the underlying containers.
    """

    __slots__ = ("_data", "_read_tag")

    def __init__(self, node=None):
        """
        Parameters
        ----------
        node : dict, AsdfDictNode, or None
            Mapping to wrap (it is not copied). ``None`` creates a node around
            a new empty dict.

        Raises
        ------
        ValueError
            If ``node`` is not ``None``, a ``dict`` or an ``AsdfDictNode``.
        """
        # Resolves to _NodeMixin.__init__, which resets _read_tag to None
        super().__init__(node)

        # Handle if we are passed different data types
        if node is None:
            self._data = {}
        elif isinstance(node, dict | AsdfDictNode):
            self._data = node
        else:
            raise ValueError("Initializer only accepts dicts")

    def __getattr__(self, key):
        """
        Permit accessing dict keys as attributes, assuming they are legal Python
        variable names.

        Raises
        ------
        AttributeError
            If ``key`` is private or is not a key of the node.
        """
        # Private values should have already been handled by the normal
        # attribute lookup; bail out if we are falling back on this method.
        # This also prevents infinite recursion when ``_data`` itself has not
        # been set yet (e.g. on an instance created with ``__new__``).
        if key.startswith("_"):
            raise AttributeError(f"No attribute {key}")

        if key in self._data:
            # Return objects as node classes, if applicable
            return _wrap(self._data[key])

        # Raise the correct error for the attribute not being found
        raise AttributeError(f"No such attribute ({key}) found in node: {type(self)}")

    def __setattr__(self, key, value):
        """
        Permit assigning dict keys as attributes.

        Public names are stored in ``_data`` (plain `DNode`/`LNode` values are
        unwrapped first). Private names are only accepted if they are one of
        ``DNode.__slots__``.

        Raises
        ------
        AttributeError
            If ``key`` is private and not one of ``DNode.__slots__``.
        """
        if not key.startswith("_"):
            self._data[key] = _unwrap(value)
        elif key in DNode.__slots__:
            # The class is fully slotted (there is no instance __dict__), so
            # private state goes through DNode's own slot descriptors; this
            # works for instances of subclasses as well.
            DNode.__dict__[key].__set__(self, value)
        else:
            raise AttributeError(f"Cannot set private attribute {key}, only allowed are {DNode.__slots__}")

    def __delattr__(self, name):
        """
        Permit deleting dict keys as attributes.

        Raises
        ------
        AttributeError
            If ``name`` is private and not a slot.
        KeyError
            If ``name`` is public but not a key of the node.
        """
        # ``self.__slots__`` alone resolves to the most-derived class's slots,
        # which need not list DNode's own; check both so that everything
        # __setattr__ accepts can also be deleted.
        if name in DNode.__slots__ or name in self.__slots__:
            super().__delattr__(name)
        elif not name.startswith("_"):
            del self[name]
        else:
            raise AttributeError(f"No such attribute ({name}) found in node")

    def _recursive_items(self):
        """
        Yield ``(dotted_path, value)`` pairs for every leaf below this node.

        Mappings are descended by key and lists/tuples by index; the path
        components are joined with ``"."``. ``None`` leaves are skipped and
        empty containers produce no entries.
        """

        def recurse(tree, path=None):
            path = path or []  # Avoid mutable default arguments
            if isinstance(tree, DNode | dict | AsdfDictNode):
                for key, val in tree.items():
                    yield from recurse(val, [*path, key])
            elif isinstance(tree, LNode | list | tuple | AsdfListNode):
                for i, val in enumerate(tree):
                    yield from recurse(val, [*path, i])
            elif tree is not None:
                yield (".".join(str(x) for x in path), tree)

        yield from recurse(self)

    def to_flat_dict(self, include_arrays=True, recursive=False):
        """
        Returns a dictionary of the node's items.

        With ``recursive=True`` the whole tree is flattened and each key is a
        dot-separated path. For example, the schema element
        ``meta.observation.date`` will end up in the dictionary as::

            { "meta.observation.date": "2012-04-22T03:22:05.432" }

        With ``recursive=False`` (the default) only the top-level items are
        returned, keyed by their own names.

        Parameters
        ----------
        include_arrays : bool
            If False, values that are ``numpy.ndarray`` or asdf ``NDArrayType``
            instances are left out.
        recursive : bool
            Flatten nested nodes/containers into dot-separated keys.

        Returns
        -------
        dict
            ``datetime.datetime`` values are converted with ``isoformat()`` and
            astropy ``Time`` values with ``str()``; all other values are
            returned unchanged.
        """

        def convert_val(val):
            """Render datetimes/Times as strings; pass everything else through."""
            if isinstance(val, datetime.datetime):
                return val.isoformat()
            if isinstance(val, Time):
                return str(val)
            return val

        item_getter = self._recursive_items if recursive else self.items
        items = item_getter()
        if not include_arrays:
            items = ((key, val) for key, val in items if not isinstance(val, np.ndarray | ndarray.NDArrayType))

        return {key: convert_val(val) for key, val in items}

    def __asdf_traverse__(self):
        """Asdf traverse method for things like info/search"""
        return dict(self)

    def __len__(self):
        """Define length of the node"""
        return len(self._data)

    def __getitem__(self, key):
        """Dictionary style access data (the stored value is returned as-is)"""
        if key in self._data:
            return self._data[key]

        raise KeyError(f"No such key ({key}) found in node")

    def __setitem__(self, key, value):
        """
        Dictionary style access set data.

        Plain `DNode`/`LNode` values are unwrapped, matching attribute
        assignment, so ``node["x"] = v`` and ``node.x = v`` store the same data.
        """
        self._data[key] = _unwrap(value)

    def __delitem__(self, key):
        """Dictionary style access delete data"""
        del self._data[key]

    def __dir__(self):
        # Advertise the public keys so they tab-complete as attributes
        return set(super().__dir__()) | set(self._data.keys())

    def __iter__(self):
        """Define iteration (over the keys of the node)"""
        return iter(self._data)

    def __repr__(self):
        """Define a representation (that of the underlying mapping)"""
        return repr(self._data)

    def copy(self):
        """
        Handle copying of the node.

        The copy is built with ``__new__`` (bypassing ``__init__``) so that it
        has the same concrete class as ``self``. ``_read_tag`` is carried over
        and ``_data`` is copied with its own ``copy()`` method, i.e. nested
        values are shared with the original.
        """
        instance = self.__class__.__new__(self.__class__)
        instance._read_tag = self._read_tag
        instance._data = self._data.copy()
        return instance
class LNode(MutableSequence, _NodeMixin):
    """
    Base class describing all "array" (list-like) data nodes for STNode classes.

    The node wraps, without copying, a ``list`` or ``AsdfListNode`` held in
    ``data``. Items read from the node are wrapped in `DNode`/`LNode` where
    applicable, and plain `DNode`/`LNode` items written to it (by index
    assignment or by ``insert``, and so also ``append``/``extend``) are
    unwrapped back to their underlying containers.
    """

    __slots__ = ("_read_tag", "data")

    def __init__(self, node=None):
        """
        Parameters
        ----------
        node : list, AsdfListNode, LNode, or None
            Sequence to wrap (it is not copied). ``None`` creates a node around
            a new empty list; an `LNode` shares its underlying ``data``.

        Raises
        ------
        ValueError
            If ``node`` is not one of the accepted types.
        """
        # Resolves to _NodeMixin.__init__, which resets _read_tag to None
        super().__init__(node=node)

        if node is None:
            self.data = []
        elif isinstance(node, list | AsdfListNode):
            self.data = node
        elif isinstance(node, LNode):
            # Any LNode is accepted (not only instances of this exact class),
            # just like any plain list is.
            self.data = node.data
        else:
            raise ValueError("Initializer only accepts lists")

    def __getitem__(self, index):
        return _wrap(self.data[index])

    def __setitem__(self, index, value):
        self.data[index] = _unwrap(value)

    def __delitem__(self, index):
        del self.data[index]

    def __len__(self):
        return len(self.data)

    def insert(self, index, value):
        """
        Insert ``value`` before ``index``.

        MutableSequence's ``append``/``extend``/``+=`` are built on this
        method, so it unwraps values exactly like ``__setitem__`` does.
        """
        self.data.insert(index, _unwrap(value))

    def __asdf_traverse__(self):
        """Asdf traverse method; items come back wrapped via __getitem__"""
        return list(self)

    def __setattr__(self, key, value):
        # Only the declared slots may be assigned; there is no instance __dict__
        if key in LNode.__slots__:
            LNode.__dict__[key].__set__(self, value)
        else:
            raise AttributeError(f"Cannot set attribute {key}, only allowed are {LNode.__slots__}")

    def __eq__(self, other):
        """Compare the underlying list with another LNode or list-like node."""
        if isinstance(other, LNode):
            return self.data == other.data
        elif isinstance(other, list | AsdfListNode):
            return self.data == other
        else:
            return False

    def __repr__(self):
        """Define a representation (that of the underlying list), like DNode"""
        return repr(self.data)

    def copy(self):
        """
        Handle copying of the node.

        The copy is built with ``__new__`` (bypassing ``__init__``) so that it
        has the same concrete class as ``self``. ``data`` is copied with its
        own ``copy()`` method (nested values are shared) and ``_read_tag`` is
        carried over.
        """
        instance = self.__class__.__new__(self.__class__)
        instance.data = self.data.copy()
        instance._read_tag = self._read_tag
        return instance