# Source code for roman_datamodels.datamodels._core

"""
This module provides the same interface as the datamodels for JWST, so that they can be
    used in a common pipeline structure. Unlike the JWST datamodels, these models are
    backed by an ASDF file and the schema structure is defined by the ASDF schema.

This provides the abstract base class ``Datamodel`` for all the specific datamodels
    used for Roman. This dataclass is intended to be subclassed to form all of the actual
    working datamodels.
"""

from __future__ import annotations

import abc
import copy
import datetime
import functools
import sys
from pathlib import Path, PurePath
from typing import TYPE_CHECKING

import asdf
import numpy as np
from asdf.exceptions import ValidationError
from asdf.tags.core.ndarray import NDArrayType
from astropy.time import Time

from roman_datamodels import stnode

if TYPE_CHECKING:
    from collections.abc import Mapping
    from typing import Any, Self

# Public names exported by this module.
__all__ = ["MODEL_REGISTRY", "DataModel"]

# Maps each tagged node class to the DataModel subclass registered for it.
# Filled in by ``DataModel.__init_subclass__`` as subclasses are defined.
MODEL_REGISTRY: dict[type[stnode.TaggedObjectNode], type[DataModel]] = {}


def _set_default_asdf(func):
    """
    Decorate a DataModel method so that ``self._asdf`` exists before it runs.

    When the model has no ASDF file yet, a new ``asdf.AsdfFile`` is created,
    the model's node is stored under its ``"roman"`` key, and the file is
    attached to the model. The wrapped method is then called unchanged.
    """

    @functools.wraps(func)
    def _with_asdf(self, *args, **kwargs):
        # Fast path: the model already carries an ASDF file.
        if self._asdf is not None:
            return func(self, *args, **kwargs)

        # Lazily build the backing file around the current node.
        default_file = asdf.AsdfFile()
        default_file["roman"] = self._instance
        self._asdf = default_file
        return func(self, *args, **kwargs)

    return _with_asdf


class DataModel(abc.ABC):
    """
    Base class for all top level datamodels.

    A model wraps a tagged ``stnode`` node (``_instance``) and, when one is
    available, the ``asdf.AsdfFile`` that holds it (``_asdf``). Attribute and
    item access that the model does not handle itself is forwarded to the
    wrapped node.
    """

    crds_observatory = "roman"

    # Internal state lives in slots; ``__setattr__`` forwards every other
    # attribute name to the wrapped node.
    __slots__ = ("_asdf", "_files_to_close", "_instance", "_iscopy", "_shape")

    # Node class wrapped by a concrete model; set by each subclass.
    _node_type: type[stnode.TaggedObjectNode]

    def __init_subclass__(cls, **kwargs):
        """
        Register each subclass in the MODEL_REGISTRY.

        Raises
        ------
        ValueError
            If ``_node_type`` is not a ``TaggedObjectNode`` subclass, or if
            another model is already registered for the same node type.
        """
        super().__init_subclass__(**kwargs)

        # Allow for sub-registry classes to be defined
        if cls.__name__.startswith("_"):
            return

        # Check the node_type is a tagged object node
        if not issubclass(cls._node_type, stnode.TaggedObjectNode):
            raise ValueError("Subclass must be a TaggedObjectNode subclass")

        # Check for duplicates
        if cls._node_type in MODEL_REGISTRY:
            raise ValueError(f"Duplicate model type {cls._node_type}")

        # Add to registry
        MODEL_REGISTRY[cls._node_type] = cls

    def __new__(cls, init=None, **kwargs):
        """
        Handle the case where one passes in an already instantiated version
        of the model. In this case the constructor should just directly return
        the model.

        The match is made on the class *name* of ``init``.
        """
        if init.__class__.__name__ == cls.__name__:
            return init

        return super().__new__(cls)

    @classmethod
    def create_minimal(cls, defaults: Mapping[str, Any] | None = None, *, tag: str | None = None) -> Self:
        """
        Class method that constructs an "minimal" model.

        The "minimal" model will contain schema-required attributes
        where a default value can be determined:

            * node class defining a default value
            * defined in the schema (for example single item enums)
            * empty container classes (for example a "meta" dict)
            * required items with a corresponding provided default

        Parameters
        ----------
        defaults : None or dict
            If provided, defaults will be used in place of schema
            defined values for required attributes.

        tag: str or None
            If provided, specifically create a model using this tag not
            the default one.

        Returns
        -------
        DataModel
            "Empty" model with optional defaults. This will often
            be incomplete (invalid) as not all required attributes
            can be guessed.
        """
        return cls(cls._node_type.create_minimal(defaults, tag=tag))

    @classmethod
    def create_fake_data(
        cls, defaults: Mapping[str, Any] | None = None, shape: tuple[int, ...] | None = None, *, tag: str | None = None
    ) -> Self:
        """
        Class method that constructs a model filled with fake data.

        Similar to `DataModel.create_minimal` this only creates
        required attributes.

        Fake arrays will have a number of dimensions matching
        the schema requirements. If shape is provided only the
        dimensions matching the schema requirements will be used.
        For example if a 3 dimensional shape is provided but a fake
        array only requires 2 dimensions only the first 2 values
        from shape will be used.

        Parameters
        ----------
        defaults : None or dict
            If provided, defaults will be used in place of schema
            defined or fake values for required attributes.

        shape : None or tuple of int
            When provided use this shape to determine the
            shape used to construct fake arrays.

        tag: str or None
            If provided, specifically create a model using this tag not
            the default one.

        Returns
        -------
        DataModel
            A valid model with fake data.
        """
        return cls(cls._node_type.create_fake_data(defaults, shape, tag=tag))

    @classmethod
    def create_from_model(cls, model: DataModel | stnode.DNode) -> Self:
        """
        Create a new DataModel from an existing model.

        Parameters
        ----------
        model : DataModel or stnode.DNode
            Source of the data. For a ``DataModel`` its wrapped node is
            used; anything else is passed on as the node itself.

        Returns
        -------
        DataModel
            New model built from ``cls._node_type.create_from_node``.
        """
        node = model._instance if isinstance(model, DataModel) else model
        return cls(cls._node_type.create_from_node(node))

    def __init__(self, init=None, **kwargs):
        """
        Initialize the model from ``init``.

        Parameters
        ----------
        init : None, str, bytes, PurePath, asdf.AsdfFile, TaggedObjectNode or DataModel
            * ``None``: wrap a freshly constructed ``_node_type`` node.
            * path-like (``str``/``bytes``/``PurePath``): open the ASDF file
              and wrap its ``"roman"`` node.
            * ``asdf.AsdfFile``: wrap its ``"roman"`` node.
            * ``TaggedObjectNode``: wrap the node in a new ``asdf.AsdfFile``.
            * the model itself (as handed back by ``__new__``): no-op.

        **kwargs
            Forwarded to `open_asdf` when ``init`` is path-like.

        Raises
        ------
        ValidationError
            If a node is given that is not registered for this model class.
        ValueError
            If the opened ASDF file does not hold this model's node type.
        OSError
            If ``init`` is of an unsupported type.
        """
        if init is self:
            # __new__ returned the model that was passed in, so it is already
            # initialized. Checking identity, not isinstance, keeps this in
            # step with __new__: an instance of a *subclass* gets a brand new
            # object from __new__, which must not be left with unset slots.
            return

        self._iscopy = False
        self._shape = None
        self._instance = None
        self._asdf = None
        self._files_to_close = None

        if isinstance(init, stnode.TaggedObjectNode):
            registered_model = MODEL_REGISTRY.get(init.__class__)
            # An unregistered node class gives None here, which isinstance()
            # would reject with a TypeError; treat it as a type mismatch.
            if registered_model is None or not isinstance(self, registered_model):
                node_type = getattr(type(self), "_node_type", None)
                expected = node_type.__name__ if node_type is not None else type(self).__name__
                raise ValidationError(
                    f"TaggedObjectNode: {init.__class__.__name__} is not of the type expected. Expected {expected}"
                )

            self._instance = init
            af = asdf.AsdfFile()
            af["roman"] = self._instance
            self._asdf = af
            return

        if init is None:
            self._instance = self._node_type()

        elif isinstance(init, str | bytes | PurePath):
            if isinstance(init, PurePath):
                init = str(init)
            if isinstance(init, bytes):
                init = init.decode(sys.getfilesystemencoding())

            self._asdf = self.open_asdf(init, **kwargs)
            if not self.check_type(self._asdf):
                raise ValueError(f"ASDF file is not of the type expected. Expected {self.__class__.__name__}")

            self._instance = self._asdf.tree["roman"]

        elif isinstance(init, asdf.AsdfFile):
            self._asdf = init
            self._instance = self._asdf.tree["roman"]

        else:
            raise OSError("Argument does not appear to be an ASDF file or TaggedObjectNode.")

    def check_type(self, asdf_file):
        """
        Check that the ``"roman"`` node of ``asdf_file`` belongs to this model class.

        Parameters
        ----------
        asdf_file : asdf.AsdfFile
            File whose tree is inspected.

        Returns
        -------
        bool
            True when the class registered for the node's type is exactly
            this model's class; False otherwise, including when the node's
            type is not registered at all.

        Raises
        ------
        ValueError
            If the file has no ``"roman"`` entry.
        """
        if "roman" not in asdf_file.tree:
            raise ValueError('ASDF file does not have expected "roman" attribute')

        # .get() so that an unregistered node type is reported as a mismatch
        # instead of escaping as a KeyError.
        return MODEL_REGISTRY.get(asdf_file.tree["roman"].__class__) is self.__class__

    @property
    def _latest_manifest_uri(self):
        """The ``_latest_manifest`` value of this model's node type."""
        return self._node_type._latest_manifest

    @property
    def schema_uri(self):
        """First schema URI of the tag definition matching the wrapped node's tag."""
        # Determine the schema corresponding to this model's tag
        return next(
            t for t in stnode.NODE_EXTENSIONS[self._latest_manifest_uri].tags if t.tag_uri == self._instance._tag
        ).schema_uris[0]

    def close(self):
        """
        Close the underlying ASDF file, unless this model is a copy or has no file.

        Safe to call on a model whose ``__init__`` never ran (as can happen
        from ``__del__``): unset slots are treated as "nothing to close".
        """
        asdf_file = getattr(self, "_asdf", None)
        if asdf_file is None or getattr(self, "_iscopy", False):
            return
        asdf_file.close()

    def __enter__(self):
        """Return the model itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the model on leaving the ``with`` block."""
        self.close()

    def __del__(self):
        """Ensure closure of resources when deleted."""
        self.close()

    def copy(self, deepcopy=True, memo=None):
        """
        Return a copy of this model.

        Parameters
        ----------
        deepcopy : bool
            When True (the default) the node is deep-copied; otherwise the
            copy shares the node and ASDF file with this model.
        memo : dict or None
            Memo dictionary passed on to ``copy.deepcopy``.

        Returns
        -------
        DataModel
            New model of the same class, flagged as a copy.
        """
        result = self.__class__(init=None)
        self.clone(result, self, deepcopy=deepcopy, memo=memo)
        return result

    # NOTE: ``copy`` defaults to ``deepcopy=True``, so ``copy.copy(model)``
    # also produces a deep copy.
    __copy__ = copy

    def __deepcopy__(self, memo=None):
        """Support ``copy.deepcopy`` by delegating to `copy`."""
        return self.copy(deepcopy=True, memo=memo)

    @staticmethod
    def clone(target, source, deepcopy=False, memo=None):
        """
        Copy the state of ``source`` onto ``target``.

        Parameters
        ----------
        target : DataModel
            Model that receives the state; it is flagged as a copy.
        source : DataModel
            Model whose ASDF file, node and cached shape are transferred.
        deepcopy : bool
            When True the node is deep-copied and the ASDF file is copied
            with ``AsdfFile.copy``; otherwise both are shared.
        memo : dict or None
            Memo dictionary passed on to ``copy.deepcopy``.
        """
        if deepcopy:
            # A model built with ``init=None`` has no ASDF file until one is
            # created on demand, so there may be nothing to copy here.
            target._asdf = None if source._asdf is None else source._asdf.copy()
            target._instance = copy.deepcopy(source._instance, memo=memo)
        else:
            target._asdf = source._asdf
            target._instance = source._instance

        target._iscopy = True
        target._files_to_close = []
        target._shape = source._shape

    def save(self, path, dir_path=None, *args, all_array_compression="lz4", all_array_storage="internal", **kwargs):
        """
        Write the model to disk, choosing the writer from the file extension.

        Parameters
        ----------
        path : str, Path or callable
            Output file name. A callable is invoked with ``self.meta.filename``
            and must return the file name.
        dir_path : str, Path or None
            When given, the file is written into this directory using only
            the final component of ``path``.
        *args, **kwargs
            Forwarded to `to_asdf` for ``.asdf`` output.
        all_array_compression, all_array_storage : str
            Forwarded to `to_asdf` for ``.asdf`` output.

        Returns
        -------
        Path
            Location the model was written to.

        Raises
        ------
        ValueError
            If the extension is not ``.asdf`` (or ``.parquet`` on a model
            that provides ``to_parquet``).
        """
        path = Path(path(self.meta.filename) if callable(path) else path)
        output_path = Path(dir_path) / path.name if dir_path else path
        ext = path.suffix

        # TODO: Support gzip-compressed fits
        if ext == ".asdf":
            self.to_asdf(
                output_path, *args, all_array_compression=all_array_compression, all_array_storage=all_array_storage, **kwargs
            )
        elif ext == ".parquet" and hasattr(self, "to_parquet"):
            self.to_parquet(output_path)
        else:
            raise ValueError(f"unknown filetype {ext}")

        return output_path

    def open_asdf(self, init=None, **kwargs):
        """
        Open or create an ASDF file.

        Parameters
        ----------
        init : str or other
            A ``str`` is opened through the package helper ``_open_asdf``;
            anything else (including None) is handed to ``asdf.AsdfFile``.
        **kwargs
            Forwarded to whichever of the two is called.

        Returns
        -------
        asdf.AsdfFile
        """
        from ._utils import _open_asdf

        if isinstance(init, str):
            return _open_asdf(init, **kwargs)

        return asdf.AsdfFile(init, **kwargs)

    def to_asdf(self, init, *args, all_array_compression="lz4", all_array_storage="internal", **kwargs):
        """
        Write the wrapped node to ``init`` as an ASDF file.

        The write happens inside the ``temporary_update_filename`` and
        ``temporary_update_filedate`` context managers, which are given the
        target file name and the current time respectively.

        Parameters
        ----------
        init : str or Path
            Destination passed to ``AsdfFile.write_to``.
        *args
            Forwarded to ``AsdfFile.write_to``.
        all_array_compression, all_array_storage : str
            Forwarded to ``AsdfFile.write_to``.
        **kwargs
            Forwarded both to `open_asdf` and to ``AsdfFile.write_to``.
        """
        from ._utils import temporary_update_filedate, temporary_update_filename

        with (
            temporary_update_filename(self, Path(init).name),
            temporary_update_filedate(self, Time.now()),
        ):
            asdf_file = self.open_asdf(**kwargs)
            asdf_file["roman"] = self._instance
            asdf_file.write_to(
                init, *args, all_array_compression=all_array_compression, all_array_storage=all_array_storage, **kwargs
            )

    def get_primary_array_name(self):
        """
        Returns the name "primary" array for this model, which
        controls the size of other arrays that are implicitly created.
        This is intended to be overridden in the subclasses if the
        primary array's name is not "data".
        """
        return "data" if hasattr(self, "data") else ""

    @property
    def override_handle(self):
        """override_handle identifies in-memory models where a filepath
        would normally be used.
        """
        # Arbitrary choice to look something like crds://
        return f"override://{self.__class__.__name__}"

    @property
    def shape(self):
        """
        Shape of the primary array, or None when the model has none.

        The value is cached in ``_shape`` the first time it is found.
        """
        if self._shape is None:
            primary_array_name = self.get_primary_array_name()
            if primary_array_name and hasattr(self, primary_array_name):
                primary_array = getattr(self, primary_array_name)
                self._shape = primary_array.shape
        return self._shape

    def __setattr__(self, attr, value):
        """Store slot attributes on the model; forward everything else to the node."""
        if attr.startswith("_") and attr in DataModel.__slots__:
            # Write through the slot descriptor directly.
            DataModel.__dict__[attr].__set__(self, value)
        else:
            setattr(self._instance, attr, value)

    def __getattr__(self, attr):
        """
        Forward attribute lookups that failed on the model to the wrapped node.

        Raises
        ------
        AttributeError
            If ``attr`` is one of the model's own slots. This method is only
            reached for a slot when that slot is unset, and forwarding would
            read ``self._instance`` and could recurse without end.
        """
        if attr in DataModel.__slots__:
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {attr!r}")
        return getattr(self._instance, attr)

    def __delattr__(self, attr):
        """Delete slot attributes from the model; forward everything else to the node."""
        if attr.startswith("_") and attr in DataModel.__slots__:
            super().__delattr__(attr)
        else:
            delattr(self._instance, attr)

    def __setitem__(self, key, value):
        """
        Set ``key`` on the wrapped node.

        Raises
        ------
        ValueError
            If ``key`` starts with an underscore.
        """
        if key.startswith("_"):
            raise ValueError("May not specify attributes/keys that start with _")

        self._instance[key] = value

    def __getitem__(self, key):
        """Return ``key`` from the wrapped node."""
        return self._instance[key]

    def __dir__(self):
        """List the names of the model together with those of the wrapped node."""
        return set(super().__dir__()) | set(dir(self._instance))

    def __iter__(self):
        """Iterate over the wrapped node."""
        return iter(self._instance)

    def to_flat_dict(self, include_arrays=True):
        """
        Returns a dictionary of all of the model items as a flat dictionary.

        Each dictionary key is a dot-separated name prefixed with
        ``roman.``. For example, the model element
        ``meta.observation.date`` will end up in the dictionary as::

            { "roman.meta.observation.date": "2012-04-22T03:22:05.432" }

        This differs from the JWST data model in that the schema is not
        directly used.

        Parameters
        ----------
        include_arrays : bool
            When False, ``numpy.ndarray`` and ``NDArrayType`` values are
            left out.

        Returns
        -------
        dict
        """

        def convert_val(val):
            # Render date/time values as strings; pass everything else through.
            if isinstance(val, datetime.datetime):
                return val.isoformat()
            elif isinstance(val, Time):
                return str(val)
            return val

        return {
            f"roman.{key}": convert_val(val)
            for (key, val) in self.items()
            if include_arrays or not isinstance(val, np.ndarray | NDArrayType)
        }

    def items(self):
        """
        Iterates over all of the model items in a flat way.

        Each element is a pair (``key``, ``value``). Each ``key`` is a
        dot-separated name. For example, the schema element
        ``meta.observation.date`` will end up in the result as::

            ("meta.observation.date": "2012-04-22T03:22:05.432")

        Unlike the JWST DataModel implementation, this does not use
        schemas directly.
        """
        yield from self._instance._recursive_items()

    def get_crds_parameters(self):
        """
        Get parameters used by CRDS to select references for this model.

        This will only return items under ``roman.meta``, and only those
        whose value is a ``str``, ``int``, ``float``, ``complex`` or ``bool``.

        Returns
        -------
        dict
        """
        return {
            f"roman.meta.{key}": val
            for key, val in self.meta.to_flat_dict(include_arrays=False, recursive=True).items()
            if isinstance(val, str | int | float | complex | bool)
        }

    @_set_default_asdf
    def validate(self):
        """
        Re-validate the model instance against the tags
        """
        self._asdf.validate()

    @_set_default_asdf
    def info(self, *args, **kwargs):
        """Return ``AsdfFile.info(*args, **kwargs)`` for the backing ASDF file."""
        return self._asdf.info(*args, **kwargs)

    @_set_default_asdf
    def search(self, *args, **kwargs):
        """Return ``AsdfFile.search(*args, **kwargs)`` for the backing ASDF file."""
        return self._asdf.search(*args, **kwargs)

    @_set_default_asdf
    def schema_info(self, *args, **kwargs):
        """Return ``AsdfFile.schema_info(*args, **kwargs)`` for the backing ASDF file."""
        return self._asdf.schema_info(*args, **kwargs)