Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-22802: Changes requested by the DRAGONS team #76

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion python/lsst/pex/config/callStack.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class StackFrame:
getStackFrame
"""

_STRIP = "/python/lsst/"
_STRIP = "/DRAGONS/"
"""String to strip from the ``filename`` in the constructor."""

def __init__(self, filename, lineno, function, content=None):
Expand Down
2 changes: 2 additions & 0 deletions python/lsst/pex/config/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ def compareScalars(name, v1, v2, output, rtol=1e-8, atol=1e-8, dtype=None):
-----
Floating point comparisons are performed by `numpy.allclose`.
"""
if isinstance(dtype, tuple):
dtype = type(v1)
if v1 is None or v2 is None:
result = v1 == v2
elif dtype in (float, complex):
Expand Down
110 changes: 83 additions & 27 deletions python/lsst/pex/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import copy
import importlib
import io
import itertools
import math
import os
import re
Expand All @@ -56,6 +57,14 @@
except ImportError:
yaml = None

try:
from astrodata import AstroData
except ImportError:

class AstroData:
pass


from .callStack import getCallStack, getStackFrame
from .comparison import compareConfigs, compareScalars, getComparisonName

Expand Down Expand Up @@ -135,8 +144,17 @@ def _autocast(x, dtype):
``dtype``. If the cast cannot be performed the original value of
``x`` is returned.
"""
if dtype is float and isinstance(x, int):
if isinstance(x, int) and (
dtype is float or (isinstance(dtype, tuple) and float in dtype and int not in dtype)
):
return float(x)
if isinstance(x, str):
for type in (int, float, bool):
if dtype == type or (isinstance(dtype, tuple) and type in dtype):
try:
return type(x)
except ValueError: # Carry on and try a different coercion
pass
return x


Expand Down Expand Up @@ -228,9 +246,9 @@ def getFields(classtype):
for b in bases:
fields.update(getFields(b))

for k, v in classtype.__dict__.items():
if isinstance(v, Field):
fields[k] = v
field_dict = {k: v for k, v in classtype.__dict__.items() if isinstance(v, Field)}
for k, v in sorted(field_dict.items(), key=lambda x: x[1]._creation_order):
fields[k] = v
return fields

fields = getFields(cls)
Expand Down Expand Up @@ -288,7 +306,7 @@ def __init__(self, field, config, msg):

self.configSource = config._source
error = (
f"{self.fieldType.__name__} '{self.fullname}' failed validation: {msg}\n"
f"{self.fieldType.__name__} '{self.fullname}' ({field.doc}) failed validation: {msg}\n"
f"For more information see the Field definition at:\n{self.fieldSource.format()}"
f" and the Config definition at:\n{self.configSource.format()}"
)
Expand Down Expand Up @@ -396,10 +414,12 @@ class Field(Generic[FieldTypeVar]):
Class.
"""

supportedTypes = {str, bool, float, int, complex}
supportedTypes = {str, bool, float, int, complex, tuple, AstroData}
"""Supported data types for field values (`set` of types).
"""

_counter = itertools.count()

@staticmethod
def _parseTypingArgs(
params: tuple[type, ...] | tuple[str, ...], kwds: Mapping[str, Any]
Expand Down Expand Up @@ -463,7 +483,12 @@ def __init__(self, doc, dtype=None, default=None, check=None, optional=False, de
raise ValueError(
"dtype must either be supplied as an argument or as a type argument to the class"
)
if dtype not in self.supportedTypes:
if isinstance(dtype, list):
dtype = tuple(dtype)
if isinstance(dtype, tuple):
if any(x not in self.supportedTypes for x in dtype):
raise ValueError(f"Unsupported Field dtype in {_typeStr(dtype)}")
elif dtype not in self.supportedTypes:
raise ValueError(f"Unsupported Field dtype {_typeStr(dtype)}")

source = getStackFrame()
Expand Down Expand Up @@ -522,6 +547,8 @@ def _setup(self, doc, dtype, default, check, optional, source, deprecated):
`~lsst.pex.config.callStack.StackFrame`).
"""

self._creation_order = next(Field._counter)

def rename(self, instance):
r"""Rename the field in a `~lsst.pex.config.Config` (for internal use
only).
Expand Down Expand Up @@ -609,9 +636,16 @@ def _validateValue(self, value):
return

if not isinstance(value, self.dtype):
msg = (
f"Value {value} is of incorrect type {_typeStr(value)}. Expected type {_typeStr(self.dtype)}"
)
if isinstance(self.dtype, tuple):
msg = (
f"Value {value} is of incorrect type {_typeStr(value)}. "
f"Expected types {[_typeStr(dt) for dt in self.dtype]}"
)
else:
msg = (
f"Value {value} is of incorrect type {_typeStr(value)}. "
f"Expected type {_typeStr(self.dtype)}"
)
raise TypeError(msg)
if self.check is not None and not self.check(value):
msg = f"Value {value} is not a valid value"
Expand Down Expand Up @@ -778,6 +812,12 @@ def __set__(
if instance._frozen:
raise FieldValidationError(self, instance, "Cannot modify a frozen Config")

if at is None:
at = getCallStack()
# setDefaults() gets a free pass due to our mashing of inheritance
if self.name not in instance._fields:
raise AttributeError(f"{instance.__class__.__name__} has no attribute {self.name}")

history = instance._history.setdefault(self.name, [])
if value is not None:
value = _autocast(value, self.dtype)
Expand All @@ -787,9 +827,9 @@ def __set__(
raise FieldValidationError(self, instance, str(e))

instance._storage[self.name] = value
if at is None:
at = getCallStack()
history.append((value, at, label))
# We don't want to put an actual AD object here, so just the filename
value_to_append = value.filename if isinstance(value, AstroData) else value
history.append((value_to_append, at, label))

def __delete__(self, instance, at=None, label="deletion"):
"""Delete an attribute from a `lsst.pex.config.Config` instance.
Expand Down Expand Up @@ -962,6 +1002,9 @@ class behavior.
_history: dict[str, list[Any]]
_imports: set[Any]

# Only _fields are exposure. _storage retains items that have been
# deleted.

def __iter__(self):
"""Iterate over fields."""
return self._fields.__iter__()
Expand All @@ -974,7 +1017,7 @@ def keys(self):
names : `~collections.abc.KeysView`
List of `lsst.pex.config.Field` names.
"""
return self._storage.keys()
return list(self._fields)

def values(self):
"""Get field values.
Expand All @@ -984,7 +1027,7 @@ def values(self):
values : `~collections.abc.ValuesView`
Iterator of field values.
"""
return self._storage.values()
return self.toDict().values()

def items(self):
"""Get configurations as ``(field name, field value)`` pairs.
Expand All @@ -997,7 +1040,11 @@ def items(self):
0. Field name.
1. Field value.
"""
return self._storage.items()
return self.toDict().items()

def doc(self, field):
"""Return docstring for field."""
return self._fields[field].doc

def __contains__(self, name):
"""Return `True` if the specified field exists in this config.
Expand All @@ -1012,7 +1059,7 @@ def __contains__(self, name):
in : `bool`
`True` if the specified field exists in the config.
"""
return self._storage.__contains__(name)
return self._storage.__contains__(name) and name in self._fields

def __new__(cls, *args, **kw):
"""Allocate a new `lsst.pex.config.Config` object.
Expand All @@ -1038,9 +1085,7 @@ def __new__(cls, *args, **kw):
instance._history = {}
instance._imports = set()
# load up defaults
for field in instance._fields.values():
instance._history[field.name] = []
field.__set__(instance, field.default, at=at + [field.source], label="default")
instance.reset(at=at)
# set custom default-overrides
instance.setDefaults()
# set constructor overrides
Expand All @@ -1060,6 +1105,14 @@ def __reduce__(self):
self.saveToStream(stream)
return (unreduceConfig, (self.__class__, stream.getvalue().encode()))

def reset(self, at=None):
"""Reset all values to their defaults."""
if at is None:
at = getCallStack()
for field in self._fields.values():
self._history[field.name] = []
field.__set__(self, field.default, at=at + [field.source], label="default")

def setDefaults(self):
"""Subclass hook for computing defaults.

Expand Down Expand Up @@ -1131,7 +1184,9 @@ def update(self, **kw):
field = self._fields[name]
field.__set__(self, value, at=at, label=label)
except KeyError:
raise KeyError(f"No field of name {name} exists in config type {_typeStr(self)}")
raise KeyError(
"{} has no field named {}".format(type(self).__name__.replace("Config", ""), name)
)

def load(self, filename, root="config"):
"""Modify this config in place by executing the Python code in a
Expand Down Expand Up @@ -1514,7 +1569,7 @@ def validate(self):
for field in self._fields.values():
field.validate(self)

def formatHistory(self, name, **kwargs):
def formatHistory(self, name=None, **kwargs):
"""Format a configuration field's history to a human-readable format.

Parameters
Expand All @@ -1533,7 +1588,7 @@ def formatHistory(self, name, **kwargs):
--------
lsst.pex.config.history.format
"""
import lsst.pex.config.history as pexHist
from . import history as pexHist

return pexHist.format(self, name, **kwargs)

Expand Down Expand Up @@ -1577,11 +1632,12 @@ def __setattr__(self, attr, value, at=None, label="assignment"):
raise AttributeError(f"{_typeStr(self)} has no attribute {attr}")

def __delattr__(self, attr, at=None, label="deletion"):
# CJS: Hacked to allow setDefaults() to delete non-existent fields
if at is None:
at = getCallStack()
if attr in self._fields:
if at is None:
at = getCallStack()
self._fields[attr].__delete__(self, at=at, label=label)
else:
del self._fields[attr]
elif not any(stk.function == "setDefaults" for stk in at):
object.__delattr__(self, attr)

def __eq__(self, other):
Expand Down
16 changes: 12 additions & 4 deletions python/lsst/pex/config/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def _colorize(text, category):
return str(text)


def format(config, name=None, writeSourceLine=True, prefix="", verbose=False):
def format(config, name=None, writeSourceLine=True, prefix="", verbose=False, debug=False):
"""Format the history record for a configuration, or a specific
configuration field.

Expand All @@ -189,12 +189,17 @@ def format(config, name=None, writeSourceLine=True, prefix="", verbose=False):
even before any source line. The default is an empty string.
verbose : `bool`, optional
Default is `False`.
debug : `bool`, optional
Enable debug detail.
"""
msg = []
verbose |= debug # verbose=False and debug=True seems wrong!
if name is None:
for i, name in enumerate(config.history.keys()):
if i > 0:
print()
print(format(config, name))
msg.append("")
msg.append(format(config, name))
return "\n".join(msg)

outputs = []
for value, stack, label in config.history.get(name, []):
Expand All @@ -207,7 +212,7 @@ def format(config, name=None, writeSourceLine=True, prefix="", verbose=False):
"execfile",
"wrapper",
) or os.path.split(frame.filename)[1] in ("argparse.py", "argumentParser.py"):
if not verbose:
if not debug:
continue

line = []
Expand Down Expand Up @@ -235,6 +240,9 @@ def format(config, name=None, writeSourceLine=True, prefix="", verbose=False):

output.append(line)

if not verbose:
break

outputs.append([value, output])

if outputs:
Expand Down
Loading
Loading