"""
Mixin classes for additional functionality for STNode classes
"""
from __future__ import annotations
import re
from copy import deepcopy
from typing import TYPE_CHECKING
from asdf.tags.core.ndarray import asdf_datatype_to_numpy_dtype
from ._schema import Builder, _get_keyword, _get_properties
from ._tagged import _get_schema_from_tag
# This is a workaround for MyPy to understand the Mixin classes: for type
# checking the mixins derive from the real node types, while at runtime they
# derive from plain ``object``.
if TYPE_CHECKING:
    from typing import ClassVar, TypeAlias

    from astropy.time import Time

    from ._tagged import TaggedObjectNode, TaggedScalarNode

    class _TimeNode(Time, TaggedScalarNode):
        pass

    _ObjectBase: TypeAlias = TaggedObjectNode
    _ScalarBase: TypeAlias = TaggedScalarNode
    _TimeBase: TypeAlias = _TimeNode
else:
    # Type checkers treat TYPE_CHECKING as true, so this branch is only seen at
    # runtime.  It is deliberately left unannotated: ``TypeAlias`` is imported
    # only for type checking, so annotating with it here would make
    # ``typing.get_type_hints`` on this module fail with NameError.
    _ObjectBase = object
    _ScalarBase = object
    _TimeBase = object
# Public names exported by ``from <this module> import *``; keep this list in
# sync with the mixin classes defined in this module.
__all__ = [
    "CalibrationSoftwareNameMixin",
    "FileDateMixin",
    "ForcedImageSourceCatalogMixin",
    "ForcedMosaicSourceCatalogMixin",
    "FpsFileDateMixin",
    "ImageSourceCatalogMixin",
    "L2CalStepMixin",
    "L3CalStepMixin",
    "MosaicSourceCatalogMixin",
    "MultibandSourceCatalogMixin",
    "OriginMixin",
    "PrdVersionMixin",
    "RefFileMixin",
    "SdfSoftwareVersionMixin",
    "TelescopeMixin",
    "TvacFileDateMixin",
    "WfiModeMixin",
]
[docs]
class WfiModeMixin:
    """
    Extensions to the WfiMode class.

    Adds two indication properties, ``filter`` and ``grating``, derived from
    the ``optical_element`` attribute.
    """

    __slots__ = ()

    # Every optical element is either a grating or a filter.  There are fewer
    # gratings than filters, so it is easier to list the gratings.  A frozenset
    # keeps this class-level constant, shared by every instance, immutable.
    _GRATING_OPTICAL_ELEMENTS: ClassVar[frozenset[str]] = frozenset({"GRISM", "PRISM"})

    @property
    def filter(self):
        """
        Return the optical element if it is a filter, otherwise None.
        """
        element = self.optical_element
        return None if element in self._GRATING_OPTICAL_ELEMENTS else element

    @property
    def grating(self):
        """
        Return the optical element if it is a grating, otherwise None.
        """
        element = self.optical_element
        return element if element in self._GRATING_OPTICAL_ELEMENTS else None
[docs]
class FileDateMixin(_TimeBase):
    """
    Default-value construction for file date nodes.

    Each constructor uses ``defaults`` when it is truthy and otherwise falls
    back to a generated time value.
    """

    @classmethod
    def _create_minimal(cls, defaults=None, builder=None, *, tag=None):
        """
        Build a minimal node.

        Returns ``cls(defaults)`` when ``defaults`` is truthy and ``cls.now()``
        otherwise.  ``builder`` is accepted for interface compatibility and
        unused.  A truthy ``tag`` is stored on the node as ``_read_tag``.
        """
        if not defaults:
            node = cls.now()
        else:
            node = cls(defaults)
        if tag:
            node._read_tag = tag
        return node

    @classmethod
    def _create_fake_data(cls, defaults=None, shape=None, builder=None, *, tag=None):
        """
        Build a fake-data node.

        Returns ``cls(defaults)`` when ``defaults`` is truthy, otherwise the
        fixed time ``2020-01-01T00:00:00.0`` (``isot`` format, ``utc`` scale).
        ``shape`` and ``builder`` are accepted for interface compatibility and
        unused.  A truthy ``tag`` is stored on the node as ``_read_tag``.
        """
        if not defaults:
            node = cls("2020-01-01T00:00:00.0", format="isot", scale="utc")
        else:
            node = cls(defaults)
        if tag:
            node._read_tag = tag
        return node
[docs]
class FpsFileDateMixin(FileDateMixin):
    """File date mixin for FPS file date nodes; identical to ``FileDateMixin``."""
[docs]
class TvacFileDateMixin(FileDateMixin):
    """File date mixin for TVAC file date nodes; identical to ``FileDateMixin``."""
[docs]
class CalibrationSoftwareNameMixin(_ScalarBase):
    """Give minimal calibration-software-name nodes a fixed default value."""

    @classmethod
    def _create_minimal(cls, defaults=None, builder=None, *, tag=None):
        """
        Build a minimal node from ``defaults``, or from ``"RomanCAL"`` when
        ``defaults`` is falsy.

        ``builder`` is accepted for interface compatibility and unused.  A
        truthy ``tag`` is stored on the node as ``_read_tag``.
        """
        node = cls(defaults or "RomanCAL")
        if tag:
            node._read_tag = tag
        return node
[docs]
class PrdVersionMixin(_ScalarBase):
    """Give fake PRD version nodes a fixed placeholder value."""

    @classmethod
    def _create_fake_data(cls, defaults=None, shape=None, builder=None, *, tag=None):
        """
        Build a fake-data node from ``defaults``, or from the placeholder
        ``"8.8.8"`` when ``defaults`` is falsy.

        ``shape`` and ``builder`` are accepted for interface compatibility and
        unused.  A truthy ``tag`` is stored on the node as ``_read_tag``.
        """
        node = cls(defaults or "8.8.8")
        if tag:
            node._read_tag = tag
        return node
[docs]
class SdfSoftwareVersionMixin(_ScalarBase):
    """Give fake SDF software version nodes a fixed placeholder value."""

    @classmethod
    def _create_fake_data(cls, defaults=None, shape=None, builder=None, *, tag=None):
        """
        Build a fake-data node from ``defaults``, or from the placeholder
        ``"7.7.7"`` when ``defaults`` is falsy.

        ``shape`` and ``builder`` are accepted for interface compatibility and
        unused.  A truthy ``tag`` is stored on the node as ``_read_tag``.
        """
        node = cls(defaults or "7.7.7")
        if tag:
            node._read_tag = tag
        return node
[docs]
class OriginMixin(_ScalarBase):
    """Give minimal origin nodes a fixed default value."""

    @classmethod
    def _create_minimal(cls, defaults=None, builder=None, *, tag=None):
        """
        Build a minimal node from ``defaults``, or from ``"STSCI/SOC"`` when
        ``defaults`` is falsy.

        ``builder`` is accepted for interface compatibility and unused.  A
        truthy ``tag`` is stored on the node as ``_read_tag``.
        """
        node = cls(defaults or "STSCI/SOC")
        if tag:
            node._read_tag = tag
        return node
[docs]
class TelescopeMixin(_ScalarBase):
    """Give minimal telescope nodes a fixed default value."""

    @classmethod
    def _create_minimal(cls, defaults=None, builder=None, *, tag=None):
        """
        Build a minimal node from ``defaults``, or from ``"ROMAN"`` when
        ``defaults`` is falsy.

        ``builder`` is accepted for interface compatibility and unused.  A
        truthy ``tag`` is stored on the node as ``_read_tag``.
        """
        node = cls(defaults or "ROMAN")
        if tag:
            node._read_tag = tag
        return node
[docs]
class RefFileMixin(_ObjectBase):
    """
    Minimal-node construction that fills string-typed schema properties with
    the placeholder ``"N/A"``.
    """

    __slots__ = ()

    @classmethod
    def _create_minimal(cls, defaults=None, builder=None, *, tag=None):
        """
        Create a minimal node from the schema for ``tag`` (or the class's
        default tag).

        Every top-level schema property whose ``type`` is ``"string"`` and
        that is not already in ``defaults`` is set to ``"N/A"``.  The result
        is then passed through ``builder.from_object``.

        Parameters
        ----------
        defaults : mapping, optional
            Values to use.  It is deep-copied, so the caller's object is not
            modified.
        builder : Builder, optional
            Builder used to construct the data; a new ``Builder()`` is used
            when it is falsy.
        tag : optional
            Tag whose schema is used; when truthy it is also recorded on the
            new node as ``_read_tag``.

        Returns
        -------
        cls
            The newly constructed node.
        """
        # Copy defaults as they are modified below.
        defaults = deepcopy(defaults) if defaults else {}
        schema = _get_schema_from_tag(tag or cls._default_tag)
        for key, subschema in schema["properties"].items():
            # Use .get so a property described without a top-level "type"
            # (e.g. via "$ref", "tag" or a combiner) is skipped and left to the
            # builder instead of raising KeyError.
            if subschema.get("type") != "string":
                continue
            if key not in defaults:
                defaults[key] = "N/A"
        if not builder:
            builder = Builder()
        new = cls(builder.from_object(schema, defaults))
        if tag:
            new._read_tag = tag
        return new
[docs]
class L2CalStepMixin(_ObjectBase):
    """
    Minimal-node construction for the L2 cal-step node.

    Every property named in the schema is populated.  Properties absent from
    ``defaults`` get the string ``"INCOMPLETE"``.
    """

    __slots__ = ()

    @classmethod
    def _create_minimal(cls, defaults=None, builder=None, *, tag=None):
        """
        Create a node holding one entry per property of the schema for
        ``tag`` (or the class's default tag).

        Keys in ``defaults`` that are not schema properties are ignored.
        ``builder`` is accepted for interface compatibility and unused.  A
        truthy ``tag`` is recorded on the node as ``_read_tag``.
        """
        supplied = defaults if defaults else {}
        property_names = _get_schema_from_tag(tag or cls._default_tag)["properties"]
        values = {}
        for prop in property_names:
            values[prop] = supplied.get(prop, "INCOMPLETE")
        node = cls(values)
        if tag:
            node._read_tag = tag
        return node
[docs]
class L3CalStepMixin(L2CalStepMixin):
    """Mixin for the L3 cal-step node; behaves exactly like ``L2CalStepMixin``."""

    __slots__ = ()
[docs]
class ImageSourceCatalogMixin(_ObjectBase):
    """
    Extensions shared by the source catalog nodes.

    Adds a lookup of column definitions from the node's schema and helpers
    that build an empty catalog table for fake data.
    """

    __slots__ = ()

    def get_column_definition(self, name):
        """
        Get the definition of a named column in the catalog table.

        This function parses the "definitions" part of the catalog
        schema and returns the parsed content.

        Parameters
        ----------
        name : str
            Column name.  It may contain an aperture radius or a filter/band,
            or be prefixed with ``forced_``.

        Returns
        -------
        dict or None
            Dictionary containing unit, description, and datatype information,
            or None if the name does not match any definition.
        """
        # A "forced_" prefix is stripped so the name is looked up by its
        # un-prefixed form.
        name = name.removeprefix("forced_")

        # Placeholders used in definition names, with the regex fragment each
        # expands to.  "_~band~" is handled before the bare "~band~" so that
        # the underscore becomes part of the optional band group.
        placeholders = (
            ("~radius~", r"[0-9]{2}"),
            ("_~band~", r"(_f[0-9]{3}|)"),
            ("~band~", r"(f[0-9]{3}|)"),
        )

        definitions = _get_keyword(self.get_schema()["properties"]["source_catalog"], "definitions")
        for def_name, definition in definitions.items():
            pattern = def_name
            for placeholder, fragment in placeholders:
                pattern = pattern.replace(placeholder, fragment)
            if re.fullmatch(pattern, name):
                return {
                    "unit": definition["unit"],
                    "description": definition["description"],
                    "datatype": asdf_datatype_to_numpy_dtype(
                        definition["properties"]["data"]["properties"]["datatype"]["enum"][0]
                    ),
                }
        return None

    @classmethod
    def _create_empty_catalog(cls, aperture_radii=None, filters=None):
        """
        Create an empty table with one column per name allowed by the schema.

        Parameters
        ----------
        aperture_radii : list of str, optional
            Values substituted for the ``[0-9]{2}`` fragment in column-name
            patterns.  Defaults to ``["00"]``.
        filters : list of str, optional
            Values substituted for the parenthesised group in column-name
            patterns.  Defaults to ``["f184"]``.

        Returns
        -------
        astropy.table.Table
            Zero-length table whose columns carry the unit, description and
            dtype taken from the schema.
        """
        from astropy.table import Column, Table

        aperture_radii = aperture_radii or ["00"]
        filters = filters or ["f184"]

        # Each regex locates a fragment of a column-name pattern; every match
        # is expanded into one concrete name per supplied value.
        # NOTE(review): the filter regex is greedy and replaces the whole
        # parenthesised group with the bare filter value, so any literal
        # characters inside the group (such as a leading "_") are dropped.
        # Confirm this yields the intended names for the schema's patterns.
        substitutions = [
            (r"\[0-9]\{2}", aperture_radii),
            (r"\(.*\)", filters),
        ]

        table_schema = _get_schema_from_tag(cls._default_tag)["properties"]["source_catalog"]
        columns = []
        for raw_col_def in dict(_get_properties(table_schema))["columns"]["allOf"]:
            # Each allOf entry wraps the column definition as
            # ``not: {items: {not: <column definition>}}``.
            col_def = raw_col_def["not"]["items"]["not"]
            properties = dict(_get_properties(col_def))
            name_regex = properties["name"]["pattern"]
            unit = _get_keyword(col_def, "unit")
            description = _get_keyword(col_def, "description")
            dtype = asdf_datatype_to_numpy_dtype(properties["data"]["properties"]["datatype"]["enum"][0])

            # Drop the first and last characters (assumed to be the "^" and "$"
            # anchors), then expand fragments until no substitution applies.
            name_queue = [name_regex[1:-1]]
            while name_queue:
                name = name_queue.pop()
                for regex, values in substitutions:
                    if re.search(regex, name):
                        name_queue.extend(re.sub(regex, value, name) for value in values)
                        break
                else:
                    columns.append(Column([], unit=unit, description=description, dtype=dtype, name=name))
        return Table(columns)

    @classmethod
    def _create_fake_data(cls, defaults=None, shape=None, builder=None, *, tag=None):
        """
        Create fake data, supplying an empty ``source_catalog`` table when
        ``defaults`` does not provide one.

        The caller's ``defaults`` mapping is not modified.
        """
        defaults = defaults or {}
        if "source_catalog" not in defaults:
            # Build a new dict rather than inserting into ``defaults``.  This
            # keeps a caller-supplied mapping unmodified, so reusing that
            # mapping does not make several nodes share one catalog table.
            defaults = {**defaults, "source_catalog": cls._create_empty_catalog()}
        return super()._create_fake_data(defaults, shape, builder, tag=tag)
[docs]
class ForcedImageSourceCatalogMixin(ImageSourceCatalogMixin):
    """
    Mixin for forced image source catalog nodes; adds nothing beyond
    ``ImageSourceCatalogMixin``.
    """

    __slots__ = ()
[docs]
class MosaicSourceCatalogMixin(ImageSourceCatalogMixin):
    """
    Mixin for mosaic source catalog nodes; adds nothing beyond
    ``ImageSourceCatalogMixin``.
    """

    __slots__ = ()
[docs]
class ForcedMosaicSourceCatalogMixin(ImageSourceCatalogMixin):
    """
    Mixin for forced mosaic source catalog nodes; adds nothing beyond
    ``ImageSourceCatalogMixin``.
    """

    __slots__ = ()
[docs]
class MultibandSourceCatalogMixin(ImageSourceCatalogMixin):
    """
    Mixin for multiband source catalog nodes; adds nothing beyond
    ``ImageSourceCatalogMixin``.
    """

    __slots__ = ()