# Source code for roman_datamodels.stnode._node

"""
Base node classes for all STNode classes.

These are the base classes for the data objects used by the datamodels package.
"""

import datetime
from collections import UserList
from collections.abc import MutableMapping

import numpy as np
from asdf.lazy_nodes import AsdfDictNode, AsdfListNode
from asdf.tags.core import ndarray
from astropy.time import Time

from ._registry import SCALAR_NODE_CLASSES_BY_KEY

# Public API of this module: the generic dict-like and list-like base nodes.
__all__ = [
    "DNode",
    "LNode",
]


class DNode(MutableMapping):
    """
    Base class describing all "object" (dict-like) data nodes for STNode classes.

    The node wraps an underlying mapping held in ``_data`` and exposes its
    keys both as items (``node["key"]``) and as attributes (``node.key``).
    Attribute names starting with ``_`` are ordinary instance attributes and
    are never stored in ``_data``.
    """

    _pattern = None
    _latest_manifest = None

    def __init__(self, node=None, parent=None, name=None):
        """
        Parameters
        ----------
        node : dict or AsdfDictNode, optional
            Backing data. It is wrapped directly, not copied. ``None`` starts
            with an empty dict.
        parent : optional
            Node this one was reached from. `__getattr__` passes ``self``
            here when it wraps a nested dict.
        name : str, optional
            Key under which this node was reached in ``parent``.

        Raises
        ------
        ValueError
            If ``node`` is neither ``None`` nor a dict / ``AsdfDictNode``.
        """
        # Handle if we are passed different data types
        if node is None:
            self.__dict__["_data"] = {}
        elif isinstance(node, dict | AsdfDictNode):
            self.__dict__["_data"] = node
        else:
            raise ValueError("Initializer only accepts dicts")

        # Set the metadata tracked by the node (private names land in __dict__)
        self._parent = parent
        self._name = name

    @staticmethod
    def _convert_to_scalar(key, value, ref=None):
        """
        Find and wrap scalars in the appropriate class, if it's a tagged one.

        Parameters
        ----------
        key : str
            Key the value is stored under. It is used to look up a tagged
            scalar class in ``SCALAR_NODE_CLASSES_BY_KEY``.
        value : object
            Value to (possibly) wrap.
        ref : object, optional
            Value currently stored under ``key``. If it is a tagged scalar,
            ``value`` is coerced to exactly the same class.

        Returns
        -------
        object
            ``value`` wrapped in a tagged scalar class when one applies,
            otherwise ``value`` unchanged.
        """
        # Local import, presumably to avoid an import cycle with ._tagged
        from ._tagged import TaggedScalarNode

        if isinstance(ref, TaggedScalarNode):
            # we want the exact class (not possible subclasses)
            if type(value) == type(ref):  # noqa: E721
                return value
            return type(ref)(value)

        if isinstance(value, TaggedScalarNode):
            return value

        if key in SCALAR_NODE_CLASSES_BY_KEY:
            value = SCALAR_NODE_CLASSES_BY_KEY[key](value)

        return value

    def __getattr__(self, key):
        """
        Permit accessing dict keys as attributes, assuming they are legal
        Python variable names.

        Scalars are wrapped in their tagged scalar class where one applies.
        Dicts are returned as a `DNode` whose parent is this node, and lists
        as an `LNode`. The converted value is not written back to the data.

        Raises
        ------
        AttributeError
            If ``key`` is private or not present in the node's data.
        """
        # Private values should have already been handled by the
        # __getattribute__ method; bail out if we are falling back on this
        # method. This also stops infinite recursion on ``self._data`` when
        # ``_data`` has not been set on the instance yet.
        if key.startswith("_"):
            raise AttributeError(f"No attribute {key}")

        # If the key is present in the node's data, return its value
        if key in self._data:
            # Cast the value into the appropriate tagged scalar class
            value = self._convert_to_scalar(key, self._data[key])

            # Return objects as node classes, if applicable
            if isinstance(value, dict | AsdfDictNode):
                return DNode(value, parent=self, name=key)
            elif isinstance(value, list | AsdfListNode):
                return LNode(value)
            else:
                return value

        # Raise the correct error for the attribute not being found
        raise AttributeError(f"No such attribute ({key}) found in node")

    def __setattr__(self, key, value):
        """
        Permit assigning dict keys as attributes.

        Public names are stored in the node's data, wrapped in a tagged scalar
        class where applicable. Names starting with ``_`` are stored as
        ordinary instance attributes.
        """
        # Private keys should just be in the normal __dict__
        if key[0] != "_":
            # Wrap things in the tagged scalar classes if necessary
            value = self._convert_to_scalar(key, value, self._data.get(key))

            # Finally set the value
            self._data[key] = value
        else:
            self.__dict__[key] = value

    def _recursive_items(self):
        """
        Yield ``(dotted_path, value)`` pairs for every non-``None`` leaf.

        Mappings are descended by key and sequences (`LNode`, list, tuple,
        ``AsdfListNode``) by integer index. Path components are joined with
        ``"."``, so a nested value may appear as ``"a.b.0.c"``. ``None``
        values and empty containers yield nothing.
        """

        def recurse(tree, path=None):
            path = path or []  # Avoid mutable default arguments
            if isinstance(tree, DNode | dict | AsdfDictNode):
                for key, val in tree.items():
                    yield from recurse(val, [*path, key])
            elif isinstance(tree, LNode | list | tuple | AsdfListNode):
                for i, val in enumerate(tree):
                    yield from recurse(val, [*path, i])
            elif tree is not None:
                yield (".".join(str(x) for x in path), tree)

        yield from recurse(self)

    def to_flat_dict(self, include_arrays=True, recursive=False):
        """
        Returns a dictionary of the node's items as a flat dictionary.

        With ``recursive=True``, nested containers are walked and each
        dictionary key is a dot-separated path. For example, the schema
        element ``meta.observation.date`` will end up in the dictionary as::

            { "meta.observation.date": "2012-04-22T03:22:05.432" }

        With ``recursive=False`` (the default), only this node's top-level
        items are returned and nested containers are left as-is.

        Parameters
        ----------
        include_arrays : bool, optional
            If False, drop values that are ``numpy.ndarray`` or ASDF
            ``NDArrayType`` instances.
        recursive : bool, optional
            Flatten nested containers into dotted keys.

        Returns
        -------
        dict
            ``datetime.datetime`` values are converted with ``isoformat()``
            and ``astropy.time.Time`` values with ``str()``. Everything else
            is returned unchanged.
        """

        def convert_val(val):
            """
            Unwrap the tagged scalars if necessary.
            """
            if isinstance(val, datetime.datetime):
                return val.isoformat()
            elif isinstance(val, Time):
                return str(val)
            return val

        item_getter = self._recursive_items if recursive else self.items

        return {
            key: convert_val(val)
            for (key, val) in item_getter()
            if include_arrays or not isinstance(val, np.ndarray | ndarray.NDArrayType)
        }

    def __asdf_traverse__(self):
        """Asdf traverse method for things like info/search"""
        return dict(self)

    def __len__(self):
        """Define length of the node"""
        return len(self._data)

    def __getitem__(self, key):
        """
        Dictionary style access data.

        Unlike attribute access, the raw stored value is returned without any
        tagged-scalar or node wrapping.
        """
        try:
            return self._data[key]
        except KeyError:
            raise KeyError(f"No such key ({key}) found in node") from None

    def __setitem__(self, key, value):
        """
        Dictionary style access set data.

        The value is wrapped in a tagged scalar class when applicable. If it
        is a dict, each of its top-level entries is converted the same way,
        in place.
        """
        # Convert the value to a tagged scalar if necessary
        value = self._convert_to_scalar(key, value, self._data.get(key))

        # If the value is a dictionary, convert its entries to tagged scalars.
        # NOTE: this mutates the caller's dict in place. Iterate over a
        # snapshot so reassigning entries never disturbs a live items view.
        if isinstance(value, dict | AsdfDictNode):
            for sub_key, sub_value in list(value.items()):
                value[sub_key] = self._convert_to_scalar(sub_key, sub_value)

        self._data[key] = value

    def __delitem__(self, key):
        """Dictionary style access delete data"""
        del self._data[key]

    def __dir__(self):
        # Advertise data keys so tab-completion sees attribute-style access
        return set(super().__dir__()) | set(self._data.keys())

    def __iter__(self):
        """Define iteration"""
        return iter(self._data)

    def __repr__(self):
        """Define a representation"""
        return repr(self._data)

    def copy(self):
        """
        Handle copying of the node.

        Returns a shallow copy: a new instance of the same class sharing the
        same instance attributes (e.g. ``_parent``, ``_name``), with a shallow
        copy of the underlying data. Nested values are shared, not copied.
        """
        instance = self.__class__.__new__(self.__class__)
        instance.__dict__.update(self.__dict__)
        instance.__dict__["_data"] = self.__dict__["_data"].copy()
        return instance
class LNode(UserList):
    """
    Base class describing all "array" (list-like) data nodes for STNode classes.

    The backing list is stored, not copied, in ``data``. Items that are dicts
    or lists are returned wrapped in `DNode` / `LNode` when indexed.
    """

    _pattern = None
    _latest_manifest = None

    def __init__(self, node=None):
        """
        Parameters
        ----------
        node : list, AsdfListNode, LNode, or None, optional
            Backing data. A list / ``AsdfListNode`` is used directly. Any
            `LNode` (including subclasses) shares its ``data`` with the new
            node. ``None`` starts with an empty list.

        Raises
        ------
        ValueError
            If ``node`` is of any other type.
        """
        # UserList.__init__ is not called: it would copy ``node``, whereas
        # this class wraps the given list in place.
        if node is None:
            self.data = []
        elif isinstance(node, LNode):
            # Accept any LNode, so e.g. a tagged subclass can be built from a
            # plain LNode (previously only exact-or-subclass of self.__class__).
            self.data = node.data
        elif isinstance(node, list | AsdfListNode):
            self.data = node
        else:
            raise ValueError("Initializer only accepts lists")

    def __getitem__(self, index):
        """
        Return the item at ``index``, wrapping containers in node classes.

        Dicts are wrapped in `DNode` and lists (e.g. a slice taken from a
        plain-list backing store) in `LNode`. Other values are returned
        unchanged.
        """
        value = self.data[index]
        if isinstance(value, dict | AsdfDictNode):
            return DNode(value)
        elif isinstance(value, list | AsdfListNode):
            return LNode(value)
        else:
            return value

    def __asdf_traverse__(self):
        """Asdf traverse method for things like info/search."""
        return list(self)