diff --git a/src/snapred/backend/dao/request/CreateArtificialNormalizationRequest.py b/src/snapred/backend/dao/request/CreateArtificialNormalizationRequest.py new file mode 100644 index 000000000..7c94d89de --- /dev/null +++ b/src/snapred/backend/dao/request/CreateArtificialNormalizationRequest.py @@ -0,0 +1,17 @@ +from pydantic import BaseModel + +from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceName + + +class CreateArtificialNormalizationRequest(BaseModel): + runNumber: str + useLiteMode: bool + peakWindowClippingSize: int + smoothingParameter: float + decreaseParameter: bool = True + lss: bool = True + diffractionWorkspace: WorkspaceName + + class Config: + arbitrary_types_allowed = True # Allow arbitrary types like WorkspaceName + extra = "forbid" # Forbid extra fields diff --git a/src/snapred/backend/dao/request/ReductionRequest.py b/src/snapred/backend/dao/request/ReductionRequest.py index dd2c1099d..cb82e272b 100644 --- a/src/snapred/backend/dao/request/ReductionRequest.py +++ b/src/snapred/backend/dao/request/ReductionRequest.py @@ -23,6 +23,7 @@ class ReductionRequest(BaseModel): versions: Versions = Versions(None, None) pixelMasks: List[WorkspaceName] = [] + artificialNormalization: Optional[WorkspaceName] = None # TODO: Move to SNAPRequest continueFlags: Optional[ContinueWarning.Type] = ContinueWarning.Type.UNSET diff --git a/src/snapred/backend/dao/response/ArtificialNormResponse.py b/src/snapred/backend/dao/response/ArtificialNormResponse.py new file mode 100644 index 000000000..78c0f6612 --- /dev/null +++ b/src/snapred/backend/dao/response/ArtificialNormResponse.py @@ -0,0 +1,13 @@ +from pydantic import BaseModel, ConfigDict + +from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceName + + +class ArtificialNormResponse(BaseModel): + diffractionWorkspace: WorkspaceName + + model_config = ConfigDict( + extra="forbid", + # required in order to use 'WorkspaceName' + arbitrary_types_allowed=True, + ) diff --git a/src/snapred/backend/error/StateValidationException.py b/src/snapred/backend/error/StateValidationException.py index ba1b6f28e..3e3d13097 100644 --- a/src/snapred/backend/error/StateValidationException.py +++ b/src/snapred/backend/error/StateValidationException.py @@ -11,9 +11,14 @@ class StateValidationException(Exception): "Raised when an Instrument State is invalid" - def __init__(self, exception: Exception): - exceptionStr = str(exception) - tb = exception.__traceback__ + def __init__(self, exception): + # Handle both string and Exception types for 'exception' + if isinstance(exception, Exception): + exceptionStr = str(exception) + tb = exception.__traceback__ + else: + exceptionStr = str(exception) + tb = None if tb is not None: tb_info = traceback.extract_tb(tb) @@ -24,7 +29,7 @@ def __init__(self, exception: Exception): else: filePath, lineNumber, functionName = None, lineNumber, functionName else: - filePath, lineNumber, functionName = None, None, None + filePath, lineNumber, functionName = None, None, None # noqa: F841 doesFileExist, hasWritePermission = self._checkFileAndPermissions(filePath) diff --git a/src/snapred/backend/recipe/GenericRecipe.py b/src/snapred/backend/recipe/GenericRecipe.py index 437ba3ebf..22284c4ce 100644 --- a/src/snapred/backend/recipe/GenericRecipe.py +++ b/src/snapred/backend/recipe/GenericRecipe.py @@ -7,6 +7,7 @@ from snapred.backend.log.logger import snapredLogger from snapred.backend.recipe.algorithm.BufferMissingColumnsAlgo import BufferMissingColumnsAlgo from snapred.backend.recipe.algorithm.CalibrationMetricExtractionAlgorithm import CalibrationMetricExtractionAlgorithm +from snapred.backend.recipe.algorithm.CreateArtificialNormalizationAlgo import CreateArtificialNormalizationAlgo from snapred.backend.recipe.algorithm.DetectorPeakPredictor import DetectorPeakPredictor from snapred.backend.recipe.algorithm.FitMultiplePeaksAlgorithm import FitMultiplePeaksAlgorithm from snapred.backend.recipe.algorithm.FocusSpectraAlgorithm import FocusSpectraAlgorithm @@ -104,3 +105,7 @@ class ConvertTableToMatrixWorkspaceRecipe(GenericRecipe[ConvertTableToMatrixWork class BufferMissingColumnsRecipe(GenericRecipe[BufferMissingColumnsAlgo]): pass + + +class ArtificialNormalizationRecipe(GenericRecipe[CreateArtificialNormalizationAlgo]): + pass diff --git a/src/snapred/backend/recipe/algorithm/CreateArtificialNormalizationAlgo.py b/src/snapred/backend/recipe/algorithm/CreateArtificialNormalizationAlgo.py index 3d6263f14..f4a428314 100644 --- a/src/snapred/backend/recipe/algorithm/CreateArtificialNormalizationAlgo.py +++ b/src/snapred/backend/recipe/algorithm/CreateArtificialNormalizationAlgo.py @@ -36,7 +36,7 @@ def PyInit(self): "OutputWorkspace", "", Direction.Output, - PropertyMode.Mandatory, + PropertyMode.Optional, validator=WorkspaceUnitValidator("dSpacing"), ), doc="Workspace that contains artificial normalization.", @@ -58,7 +58,9 @@ def chopInredients(self, ingredientsStr: str): def unbagGroceries(self): self.inputWorkspaceName = self.getPropertyValue("InputWorkspace") - self.outputWorkspaceName = self.getPropertyValue("OutputWorkspace") + self.outputWorkspaceName = ( + self.getPropertyValue("OutputWorkspace") or self.inputWorkspaceName + "_artificial_norm" + ) def peakClip(self, data, winSize: int, decrese: bool, LLS: bool, smoothing: float): # Clipping peaks from the data with optional smoothing and transformations @@ -135,6 +137,7 @@ def PyExec(self): smoothing=self.smoothingParameter, ) self.outputWorkspace.setY(i, clippedData) + self.outputWorkspace.setDistribution(True) # Set the output workspace property self.setProperty("OutputWorkspace", self.outputWorkspaceName) diff --git a/src/snapred/backend/service/ReductionService.py b/src/snapred/backend/service/ReductionService.py index 84d09cdbf..13f7856d1 100644 --- a/src/snapred/backend/service/ReductionService.py +++ b/src/snapred/backend/service/ReductionService.py @@ -3,14 +3,20 @@ from pathlib import Path from typing import Any, Dict, List -from snapred.backend.dao.ingredients import GroceryListItem, ReductionIngredients +from snapred.backend.dao.ingredients import ( + ArtificialNormalizationIngredients, + GroceryListItem, + ReductionIngredients, +) from snapred.backend.dao.reduction.ReductionRecord import ReductionRecord from snapred.backend.dao.request import ( + CreateArtificialNormalizationRequest, FarmFreshIngredients, ReductionExportRequest, ReductionRequest, ) from snapred.backend.dao.request.ReductionRequest import Versions +from snapred.backend.dao.response.ArtificialNormResponse import ArtificialNormResponse from snapred.backend.dao.response.ReductionResponse import ReductionResponse from snapred.backend.dao.SNAPRequest import SNAPRequest from snapred.backend.data.DataExportService import DataExportService @@ -20,6 +26,9 @@ from snapred.backend.error.StateValidationException import StateValidationException from snapred.backend.log.logger import snapredLogger from snapred.backend.recipe.algorithm.MantidSnapper import MantidSnapper +from snapred.backend.recipe.GenericRecipe import ( + ArtificialNormalizationRecipe, +) from snapred.backend.recipe.ReductionRecipe import ReductionRecipe from snapred.backend.service.Service import Service from snapred.backend.service.SousChef import SousChef @@ -72,6 +81,7 @@ def __init__(self): self.registerPath("checkWritePermissions", self.checkWritePermissions) self.registerPath("getSavePath", self.getSavePath) self.registerPath("getStateIds", self.getStateIds) + self.registerPath("artificialNormalization", self.artificialNormalization) return @staticmethod @@ -85,42 +95,67 @@ def validateReduction(self, request: ReductionRequest): :param request: a reduction request :type request: ReductionRequest """ - continueFlags = ContinueWarning.Type.UNSET - # check if a normalization is present - if not self.dataFactoryService.normalizationExists(request.runNumber, request.useLiteMode): - continueFlags |= ContinueWarning.Type.MISSING_NORMALIZATION - # check if a diffraction calibration is present - if not self.dataFactoryService.calibrationExists(request.runNumber, request.useLiteMode): - continueFlags |= ContinueWarning.Type.MISSING_DIFFRACTION_CALIBRATION - - # remove any continue flags that are present in the request by xor-ing with the flags - if request.continueFlags: - continueFlags = continueFlags ^ (request.continueFlags & continueFlags) - - if continueFlags: - raise ContinueWarning( - "Reduction is missing calibration data, continue in uncalibrated mode?", continueFlags - ) + if request.artificialNormalization is not None: + continueFlags = ContinueWarning.Type.UNSET + + # check that the user has write permissions to the save directory + if not self.checkWritePermissions(request.runNumber): + continueFlags |= ContinueWarning.Type.NO_WRITE_PERMISSIONS + + # remove any continue flags that are present in the request by xor-ing with the flags + if request.continueFlags: + continueFlags = continueFlags ^ (request.continueFlags & continueFlags) + + if continueFlags: + raise ContinueWarning( + f"

Remeber, you don't have permissions to write to " + f"
{self.getSavePath(request.runNumber)},
" + + "but you can still save using the workbench tools.

" + + "

Would you like to continue anyway?

", + continueFlags, + ) + else: + continueFlags = ContinueWarning.Type.UNSET + warningMessages = [] + + # check if a normalization is present + if not self.dataFactoryService.normalizationExists(request.runNumber, request.useLiteMode): + continueFlags |= ContinueWarning.Type.MISSING_NORMALIZATION + warningMessages.append("Normalization is missing, continuing with artificial normalization step.") + # check if a diffraction calibration is present + if not self.dataFactoryService.calibrationExists(request.runNumber, request.useLiteMode): + continueFlags |= ContinueWarning.Type.MISSING_DIFFRACTION_CALIBRATION + warningMessages.append("Diffraction calibration is missing, continuing with uncalibrated mode.") + + # remove any continue flags that are present in the request by xor-ing with the flags + if request.continueFlags: + continueFlags = continueFlags ^ (request.continueFlags & continueFlags) + + if continueFlags: + detailedMessage = "\n".join(warningMessages) + raise ContinueWarning( + f"The reduction cannot proceed due to missing data:\n{detailedMessage}\n", continueFlags + ) - # ... ensure separate continue warnings ... - continueFlags = ContinueWarning.Type.UNSET + # ... ensure separate continue warnings ... + continueFlags = ContinueWarning.Type.UNSET - # check that the user has write permissions to the save directory - if not self.checkWritePermissions(request.runNumber): - continueFlags |= ContinueWarning.Type.NO_WRITE_PERMISSIONS + # check that the user has write permissions to the save directory + if not self.checkWritePermissions(request.runNumber): + continueFlags |= ContinueWarning.Type.NO_WRITE_PERMISSIONS - # remove any continue flags that are present in the request by xor-ing with the flags - if request.continueFlags: - continueFlags = continueFlags ^ (request.continueFlags & continueFlags) + # remove any continue flags that are present in the request by xor-ing with the flags + if request.continueFlags: + continueFlags = continueFlags ^ (request.continueFlags & continueFlags) - if continueFlags: - raise ContinueWarning( - f"

It looks like you don't have permissions to write to " - f"
{self.getSavePath(request.runNumber)},
" - + "but you can still save using the workbench tools.

" - + "

Would you like to continue anyway?

", - continueFlags, - ) + if continueFlags: + raise ContinueWarning( + f"

It looks like you don't have permissions to write to " + f"
{self.getSavePath(request.runNumber)},
" + + "but you can still save using the workbench tools.

" + + "

Would you like to continue anyway?

", + continueFlags, + ) @FromString def reduction(self, request: ReductionRequest): @@ -137,6 +172,9 @@ def reduction(self, request: ReductionRequest): ingredients = self.prepReductionIngredients(request) groceries = self.fetchReductionGroceries(request) + if isinstance(groceries, ArtificialNormResponse): + return groceries + # attach the list of grouping workspaces to the grocery dictionary groceries["groupingWorkspaces"] = groupingResults["groupingWorkspaces"] @@ -308,8 +346,9 @@ def fetchReductionGroceries(self, request: ReductionRequest) -> Dict[str, Any]: :rtype: Dict[str, Any] """ # Fetch pixel masks - residentMasks = {} combinedMask = None + residentMasks = {} + # Check for existing pixel masks if request.pixelMasks: for mask in request.pixelMasks: match mask.tokens("workspaceType"): @@ -341,30 +380,61 @@ def fetchReductionGroceries(self, request: ReductionRequest) -> Dict[str, Any]: # As an interim solution: set the request "versions" field to the latest calibration and normalization versions. # TODO: set these when the request is initially generated. - calVersion = None - normVersion = None - calVersion = self.dataFactoryService.getThisOrLatestCalibrationVersion(request.runNumber, request.useLiteMode) - self.groceryClerk.name("diffcalWorkspace").diffcal_table(request.runNumber, calVersion).useLiteMode( - request.useLiteMode - ).add() - - if ContinueWarning.Type.MISSING_NORMALIZATION not in request.continueFlags: - normVersion = self.dataFactoryService.getThisOrLatestNormalizationVersion( + if request.artificialNormalization is not None: + calVersion = None + normVersion = 0 + calVersion = self.dataFactoryService.getThisOrLatestCalibrationVersion( request.runNumber, request.useLiteMode ) - self.groceryClerk.name("normalizationWorkspace").normalization(request.runNumber, normVersion).useLiteMode( + self.groceryClerk.name("diffcalWorkspace").diffcal_table(request.runNumber, calVersion).useLiteMode( request.useLiteMode ).add() + return self.groceryService.fetchGroceryDict( + groceryDict=self.groceryClerk.buildDict(), + normalizationWorkspace=request.artificialNormalization, + **({"maskWorkspace": combinedMask} if combinedMask else {}), + ) - request.versions = Versions( - calVersion, - normVersion, - ) + else: + calVersion = None + normVersion = None + calVersion = self.dataFactoryService.getThisOrLatestCalibrationVersion( + request.runNumber, request.useLiteMode + ) + self.groceryClerk.name("diffcalWorkspace").diffcal_table(request.runNumber, calVersion).useLiteMode( + request.useLiteMode + ).add() - return self.groceryService.fetchGroceryDict( - groceryDict=self.groceryClerk.buildDict(), - **({"maskWorkspace": combinedMask} if combinedMask else {}), - ) + if ContinueWarning.Type.MISSING_NORMALIZATION not in request.continueFlags: + normVersion = self.dataFactoryService.getThisOrLatestNormalizationVersion( + request.runNumber, request.useLiteMode + ) + self.groceryClerk.name("normalizationWorkspace").normalization( + request.runNumber, normVersion + ).useLiteMode(request.useLiteMode).add() + elif calVersion and normVersion is None: + groceryList = ( + self.groceryClerk.name("diffractionWorkspace") + .diffcal_output(request.runNumber, calVersion) + .useLiteMode(request.useLiteMode) + .unit(wng.Units.DSP) + .group("column") + .buildDict() + ) + + groceries = self.groceryService.fetchGroceryDict(groceryList) + diffractionWorkspace = groceries.get("diffractionWorkspace") + return ArtificialNormResponse(diffractionWorkspace=diffractionWorkspace) + + request.versions = Versions( + calVersion, + normVersion, + ) + + return self.groceryService.fetchGroceryDict( + groceryDict=self.groceryClerk.buildDict(), + **({"maskWorkspace": combinedMask} if combinedMask else {}), + ) def saveReduction(self, request: ReductionExportRequest): self.dataExportService.exportReductionRecord(request.record) @@ -424,3 +494,16 @@ def _groupByVanadiumVersion(self, requests: List[SNAPRequest]): def getCompatibleMasks(self, request: ReductionRequest) -> List[WorkspaceName]: runNumber, useLiteMode = request.runNumber, request.useLiteMode return self.dataFactoryService.getCompatibleReductionMasks(runNumber, useLiteMode) + + def artificialNormalization(self, request: CreateArtificialNormalizationRequest): + ingredients = ArtificialNormalizationIngredients( + peakWindowClippingSize=request.peakWindowClippingSize, + smoothingParameter=request.smoothingParameter, + decreaseParameter=request.decreaseParameter, + lss=request.lss, + ) + artificialNormWorkspace = ArtificialNormalizationRecipe().executeRecipe( + InputWorkspace=request.diffractionWorkspace, + Ingredients=ingredients, + ) + return artificialNormWorkspace diff --git a/src/snapred/ui/view/BackendRequestView.py b/src/snapred/ui/view/BackendRequestView.py index 6def779e2..1644b17f0 100644 --- a/src/snapred/ui/view/BackendRequestView.py +++ b/src/snapred/ui/view/BackendRequestView.py @@ -6,6 +6,7 @@ from snapred.ui.widget.LabeledField import LabeledField from snapred.ui.widget.MultiSelectDropDown import MultiSelectDropDown from snapred.ui.widget.SampleDropDown import SampleDropDown +from snapred.ui.widget.TrueFalseDropDown import TrueFalseDropDown class BackendRequestView(QWidget): @@ -33,6 +34,9 @@ def _labeledCheckBox(self, label): def _sampleDropDown(self, label, items=[]): return SampleDropDown(label, items, self) + def _trueFalseDropDown(self, label): + return TrueFalseDropDown(label, self) + def _multiSelectDropDown(self, label, items=[]): return MultiSelectDropDown(label, items, self) diff --git a/src/snapred/ui/view/reduction/ArtificialNormalizationView.py b/src/snapred/ui/view/reduction/ArtificialNormalizationView.py new file mode 100644 index 000000000..7f5ad4ff3 --- /dev/null +++ b/src/snapred/ui/view/reduction/ArtificialNormalizationView.py @@ -0,0 +1,172 @@ +import matplotlib.pyplot as plt +from mantid.plots.datafunctions import get_spectrum +from mantid.simpleapi import mtd +from qtpy.QtCore import Signal, Slot +from qtpy.QtWidgets import ( + QHBoxLayout, + QLineEdit, + QMessageBox, + QPushButton, +) +from snapred.meta.Config import Config +from snapred.meta.decorators.Resettable import Resettable +from snapred.ui.view.BackendRequestView import BackendRequestView +from snapred.ui.widget.SmoothingSlider import SmoothingSlider +from workbench.plotting.figuremanager import MantidFigureCanvas +from workbench.plotting.toolbar import WorkbenchNavigationToolbar + + +@Resettable +class ArtificialNormalizationView(BackendRequestView): + signalRunNumberUpdate = Signal(str) + signalValueChanged = Signal(float, bool, bool, int) + signalUpdateRecalculationButton = Signal(bool) + signalUpdateFields = Signal(float, bool, bool) + + def __init__(self, parent=None): + super().__init__(parent=parent) + + # create the run number fields + self.fieldRunNumber = self._labeledField("Run Number", QLineEdit()) + + # create the graph elements + self.figure = plt.figure(constrained_layout=True) + self.canvas = MantidFigureCanvas(self.figure) + self.navigationBar = WorkbenchNavigationToolbar(self.canvas, self) + + # create the other specification elements + self.lssDropdown = self._trueFalseDropDown("LSS") + self.decreaseParameterDropdown = self._trueFalseDropDown("Decrease Parameter") + + # disable run number + for x in [self.fieldRunNumber]: + x.setEnabled(False) + + # create the adjustment controls + self.smoothingSlider = self._labeledField("Smoothing", SmoothingSlider()) + self.peakWindowClippingSize = self._labeledField( + "Peak Window Clipping Size", + QLineEdit(str(Config["constants.ArtificialNormalization.peakWindowClippingSize"])), + ) + + peakControlLayout = QHBoxLayout() + peakControlLayout.addWidget(self.smoothingSlider, 2) + peakControlLayout.addWidget(self.peakWindowClippingSize) + + # a big ol recalculate button + self.recalculationButton = QPushButton("Recalculate") + self.recalculationButton.clicked.connect(self.emitValueChange) + + # add all elements to the grid layout + self.layout.addWidget(self.fieldRunNumber, 0, 0) + self.layout.addWidget(self.navigationBar, 1, 0) + self.layout.addWidget(self.canvas, 2, 0, 1, -1) + self.layout.addLayout(peakControlLayout, 3, 0, 1, 2) + self.layout.addWidget(self.lssDropdown, 4, 0) + self.layout.addWidget(self.decreaseParameterDropdown, 4, 1) + self.layout.addWidget(self.recalculationButton, 5, 0, 1, 2) + + self.layout.setRowStretch(2, 10) + + # store the initial layout without graphs + self.initialLayoutHeight = self.size().height() + + self.signalUpdateRecalculationButton.connect(self.setEnableRecalculateButton) + self.signalUpdateFields.connect(self._updateFields) + + @Slot(str) + def _updateRunNumber(self, runNumber): + self.fieldRunNumber.setText(runNumber) + + def updateRunNumber(self, runNumber): + self.signalRunNumberUpdate.emit(runNumber) + + @Slot(float, bool, bool) + def _updateFields(self, smoothingParameter, lss, decreaseParameter): + self.smoothingSlider.field.setValue(smoothingParameter) + self.lssDropdown.setCurrentIndex(lss) + self.decreaseParameterDropdown.setCurrentIndex(decreaseParameter) + + def updateFields(self, smoothingParameter, lss, decreaseParameter): + self.signalUpdateFields.emit(smoothingParameter, lss, decreaseParameter) + + @Slot() + def emitValueChange(self): + # verify the fields before recalculation + try: + smoothingValue = self.smoothingSlider.field.value() + lss = self.lssDropdown.currentIndex() == "True" + decreaseParameter = self.decreaseParameterDropdown.currentIndex == "True" + peakWindowClippingSize = int(self.peakWindowClippingSize.field.text()) + except ValueError as e: + QMessageBox.warning( + self, + "Invalid Peak Parameters", + f"Smoothing or peak window clipping size is invalid: {str(e)}", + QMessageBox.Ok, + ) + return + self.signalValueChanged.emit(smoothingValue, lss, decreaseParameter, peakWindowClippingSize) + + def updateWorkspaces(self, diffractionWorkspace, artificialNormWorkspace): + self.diffractionWorkspace = diffractionWorkspace + self.artificialNormWorkspace = artificialNormWorkspace + self._updateGraphs() + + def _updateGraphs(self): + # get the updated workspaces and optimal graph grid + diffractionWorkspace = mtd[self.diffractionWorkspace] + artificialNormWorkspace = mtd[self.artificialNormWorkspace] + numGraphs = diffractionWorkspace.getNumberHistograms() + nrows, ncols = self._optimizeRowsAndCols(numGraphs) + + # now re-draw the figure + self.figure.clear() + for i in range(numGraphs): + ax = self.figure.add_subplot(nrows, ncols, i + 1, projection="mantid") + ax.plot(diffractionWorkspace, wkspIndex=i, label="Diffcal Data", normalize_by_bin_width=True) + ax.plot( + artificialNormWorkspace, + wkspIndex=i, + label="Artificial Normalization Data", + normalize_by_bin_width=True, + linestyle="--", + ) + ax.legend() + ax.tick_params(direction="in") + ax.set_title(f"Group ID: {i + 1}") + # fill in the discovered peaks for easier viewing + x, y, _, _ = get_spectrum(diffractionWorkspace, i, normalize_by_bin_width=True) + # for each detected peak in this group, shade in the peak region + + # resize window and redraw + self.setMinimumHeight(self.initialLayoutHeight + int(self.figure.get_size_inches()[1] * self.figure.dpi)) + self.canvas.draw() + + def _optimizeRowsAndCols(self, numGraphs): + # Get best size for layout + sqrtSize = int(numGraphs**0.5) + if sqrtSize == numGraphs**0.5: + rowSize = sqrtSize + colSize = sqrtSize + elif numGraphs <= ((sqrtSize + 1) * sqrtSize): + rowSize = sqrtSize + colSize = sqrtSize + 1 + else: + rowSize = sqrtSize + 1 + colSize = sqrtSize + 1 + return rowSize, colSize + + @Slot(bool) + def setEnableRecalculateButton(self, enable): + self.recalculationButton.setEnabled(enable) + + def disableRecalculateButton(self): + self.signalUpdateRecalculationButton.emit(False) + + def enableRecalculateButton(self): + self.signalUpdateRecalculationButton.emit(True) + + def verify(self): + # TODO what needs to be verified? + return True diff --git a/src/snapred/ui/view/reduction/ReductionRequestView.py b/src/snapred/ui/view/reduction/ReductionRequestView.py index be181e2b4..03d1e40b6 100644 --- a/src/snapred/ui/view/reduction/ReductionRequestView.py +++ b/src/snapred/ui/view/reduction/ReductionRequestView.py @@ -136,6 +136,10 @@ def clearRunNumbers(self): def verify(self): currentText = self.runNumberDisplay.toPlainText() runNumbers = [num.strip() for num in currentText.split("\n") if num.strip()] + + if not runNumbers: + raise ValueError("Please enter at least one run number.") + for runNumber in runNumbers: if not runNumber.isdigit(): raise ValueError( diff --git a/src/snapred/ui/widget/TrueFalseDropDown.py b/src/snapred/ui/widget/TrueFalseDropDown.py new file mode 100644 index 000000000..9a0fe6594 --- /dev/null +++ b/src/snapred/ui/widget/TrueFalseDropDown.py @@ -0,0 +1,33 @@ +from qtpy.QtWidgets import QComboBox, QVBoxLayout, QWidget + + +class TrueFalseDropDown(QWidget): + def __init__(self, label, parent=None): + super(TrueFalseDropDown, self).__init__(parent) + self.setStyleSheet("background-color: #F5E9E2;") + self._label = label + + self.dropDown = QComboBox() + self._initItems() + + layout = QVBoxLayout() + layout.addWidget(self.dropDown) + self.setLayout(layout) + + def _initItems(self): + self.dropDown.clear() + self.dropDown.addItem(self._label) + self.dropDown.addItems(["True", "False"]) + self.dropDown.model().item(0).setEnabled(False) + self.dropDown.setCurrentIndex(1) + + def currentIndex(self): + # Subtract 1 because the label is considered an index + return self.dropDown.currentIndex() - 1 + + def setCurrentIndex(self, index): + # Add 1 to skip the label + self.dropDown.setCurrentIndex(index + 1) + + def currentText(self): + return self.dropDown.currentText() diff --git a/src/snapred/ui/workflow/ReductionWorkflow.py b/src/snapred/ui/workflow/ReductionWorkflow.py index 2d28bbc2b..62f4e841e 100644 --- a/src/snapred/ui/workflow/ReductionWorkflow.py +++ b/src/snapred/ui/workflow/ReductionWorkflow.py @@ -3,14 +3,18 @@ from qtpy.QtCore import Slot from snapred.backend.dao.request import ( + CreateArtificialNormalizationRequest, ReductionExportRequest, ReductionRequest, ) +from snapred.backend.dao.response.ArtificialNormResponse import ArtificialNormResponse +from snapred.backend.dao.response.ReductionResponse import ReductionResponse from snapred.backend.dao.SNAPResponse import ResponseCode, SNAPResponse from snapred.backend.error.ContinueWarning import ContinueWarning from snapred.backend.log.logger import snapredLogger from snapred.meta.decorators.ExceptionToErrLog import ExceptionToErrLog from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceName +from snapred.ui.view.reduction.ArtificialNormalizationView import ArtificialNormalizationView from snapred.ui.view.reduction.ReductionRequestView import ReductionRequestView from snapred.ui.view.reduction.ReductionSaveView import ReductionSaveView from snapred.ui.workflow.WorkflowBuilder import WorkflowBuilder @@ -22,7 +26,6 @@ class ReductionWorkflow(WorkflowImplementer): def __init__(self, parent=None): super().__init__(parent) - self._reductionRequestView = ReductionRequestView( parent=parent, populatePixelMaskDropdown=self._populatePixelMaskDropdown, @@ -33,6 +36,10 @@ def __init__(self, parent=None): self._reductionRequestView.enterRunNumberButton.clicked.connect(lambda: self._populatePixelMaskDropdown()) self._reductionRequestView.pixelMaskDropdown.dropDown.view().pressed.connect(self._onPixelMaskSelection) + self._artificialNormalizationView = ArtificialNormalizationView( + parent=parent, + ) + self._reductionSaveView = ReductionSaveView( parent=parent, ) @@ -50,11 +57,17 @@ def __init__(self, parent=None): "Reduction", continueAnywayHandler=self._continueAnywayHandler, ) + .addNode( + self._continueWithNormalization, + self._artificialNormalizationView, + "Artificial Normalization", + ) .addNode(self._nothing, self._reductionSaveView, "Save") .build() ) self._reductionRequestView.retainUnfocusedDataCheckbox.checkedChanged.connect(self._enableConvertToUnits) + self._artificialNormalizationView.signalValueChanged.connect(self.onArtificialNormalizationValueChange) def _enableConvertToUnits(self): state = self._reductionRequestView.retainUnfocusedDataCheckbox.isChecked() @@ -149,7 +162,7 @@ def _triggerReduction(self, workflowPresenter): ) response = self.request(path="reduction/", payload=request_) - if response.code == ResponseCode.OK: + if isinstance(response.data, ReductionResponse): record, unfocusedData = response.data.record, response.data.unfocusedData # .. update "save" panel message: @@ -171,6 +184,10 @@ def _triggerReduction(self, workflowPresenter): if unfocusedData is not None: self.outputs.append(unfocusedData) + elif isinstance(response.data, ArtificialNormResponse): + self._artificialNormalizationView.updateRunNumber(runNumber) + self._artificialNormalization(workflowPresenter, response.data, runNumber) + return self.responses[-1] # Note that the run number is deliberately not deleted from the run numbers list. # Almost certainly it should be moved to a "completed run numbers" list. @@ -178,7 +195,102 @@ def _triggerReduction(self, workflowPresenter): # _before_ transitioning to the "save" panel. # TODO: make '_clearWorkspaces' a public method (i.e make this combination a special `cleanup` method). self._clearWorkspaces(exclude=self.outputs, clearCachedWorkspaces=True) + workflowPresenter.advanceWorkflow() + return self.responses[-1] + + def _artificialNormalization(self, workflowPresenter, responseData, runNumber): + view = workflowPresenter.widget.tabView # noqa: F841 + try: + # Handle artificial normalization request here + request_ = CreateArtificialNormalizationRequest( + runNumber=runNumber, + useLiteMode=self._reductionRequestView.liteModeToggle.field.getState(), + peakWindowClippingSize=int(self._artificialNormalizationView.peakWindowClippingSize.field.text()), + smoothingParameter=self._artificialNormalizationView.smoothingSlider.field.value(), + decreaseParameter=self._artificialNormalizationView.decreaseParameterDropdown.currentIndex() == 1, + lss=self._artificialNormalizationView.lssDropdown.currentIndex() == 1, + diffractionWorkspace=responseData.diffractionWorkspace, + ) + response = self.request(path="reduction/artificialNormalization", payload=request_) + # Update workspaces in the artificial normalization view + diffractionWorkspace = responseData.diffractionWorkspace + artificialNormWorkspace = response.data + + if diffractionWorkspace and artificialNormWorkspace: + self._artificialNormalizationView.updateWorkspaces(diffractionWorkspace, artificialNormWorkspace) + else: + print(f"Error: Workspaces not found in the response: {response.data}") + except Exception as e: # noqa: BLE001 + print(f"Error during artificial normalization request: {e}") + + @Slot(float, bool, bool, int) + def onArtificialNormalizationValueChange(self, smoothingValue, lss, decreaseParameter, peakWindowClippingSize): + self._artificialNormalizationView.disableRecalculateButton() + # Recalculate normalization based on updated values + runNumber = self._artificialNormalizationView.fieldRunNumber.text() + diffractionWorkspace = self._artificialNormalizationView.diffractionWorkspace + try: + request_ = CreateArtificialNormalizationRequest( + runNumber=runNumber, + useLiteMode=self._reductionRequestView.liteModeToggle.field.getState(), + peakWindowClippingSize=peakWindowClippingSize, + smoothingParameter=smoothingValue, + decreaseParameter=decreaseParameter, + lss=lss, + diffractionWorkspace=diffractionWorkspace, + ) + + response = self.request(path="reduction/artificialNormalization", payload=request_) + artificialNormWorkspace = response.data + # Update the view with new workspaces + self._artificialNormalizationView.updateWorkspaces(diffractionWorkspace, artificialNormWorkspace) + + except Exception as e: # noqa: BLE001 + print(f"Error during recalculation: {e}") + + self._artificialNormalizationView.enableRecalculateButton() + + def _continueWithNormalization(self, workflowPresenter): + # Get the updated normalization workspace from the ArtificialNormalizationView + view = workflowPresenter.widget.tabView # noqa: F841 + artificialNormWorkspace = self._artificialNormalizationView.artificialNormWorkspace + + # Now modify the request to use the artificial normalization workspace and continue the workflow + pixelMasks = self._reconstructPixelMaskNames(self._reductionRequestView.getPixelMasks()) + timestamp = self.request(path="reduction/getUniqueTimestamp").data + + request_ = ReductionRequest( + runNumber=str(self._artificialNormalizationView.fieldRunNumber.text()), + useLiteMode=self._reductionRequestView.liteModeToggle.field.getState(), + timestamp=timestamp, + continueFlags=self.continueAnywayFlags, + pixelMasks=pixelMasks, + keepUnfocused=self._reductionRequestView.retainUnfocusedDataCheckbox.isChecked(), + convertUnitsTo=self._reductionRequestView.convertUnitsDropdown.currentText(), + normalizationWorkspace=artificialNormWorkspace, + ) + + # Re-trigger reduction with the artificial normalization workspace + response = self.request(path="reduction/", payload=request_) + + if response.code == ResponseCode.OK: + # Continue to the save step as before + record, unfocusedData = response.data.record, response.data.unfocusedData + savePath = self.request(path="reduction/getSavePath", payload=record.runNumber).data + self._reductionSaveView.updateContinueAnyway(self.continueAnywayFlags) + self._reductionSaveView.updateSavePath(savePath) + + if ContinueWarning.Type.NO_WRITE_PERMISSIONS not in self.continueAnywayFlags: + self.request(path="reduction/save", payload=ReductionExportRequest(record=record)) + + # Handle output workspaces + self.outputs.extend(record.workspaceNames) + if unfocusedData is not None: + self.outputs.append(unfocusedData) + + # Clear workspaces except the output ones before transitioning to the save panel + self._clearWorkspaces(exclude=self.outputs, clearCachedWorkspaces=True) return self.responses[-1] @property diff --git a/tests/unit/backend/service/test_ReductionService.py b/tests/unit/backend/service/test_ReductionService.py index c8a157f9c..511015810 100644 --- a/tests/unit/backend/service/test_ReductionService.py +++ b/tests/unit/backend/service/test_ReductionService.py @@ -7,6 +7,8 @@ import pydantic import pytest from mantid.simpleapi import ( + ConvertUnits, + CreateWorkspace, DeleteWorkspace, mtd, ) @@ -14,13 +16,16 @@ from snapred.backend.dao.ingredients.ReductionIngredients import ReductionIngredients from snapred.backend.dao.reduction.ReductionRecord import ReductionRecord from snapred.backend.dao.request import ( + CreateArtificialNormalizationRequest, ReductionExportRequest, ReductionRequest, ) from snapred.backend.dao.request.ReductionRequest import Versions +from snapred.backend.dao.response.ArtificialNormResponse import ArtificialNormResponse from snapred.backend.dao.SNAPRequest import SNAPRequest from snapred.backend.dao.state import DetectorState from snapred.backend.dao.state.FocusGroup import FocusGroup +from snapred.backend.dao.state.PixelGroupingParameters import PixelGroupingParameters from snapred.backend.error.ContinueWarning import ContinueWarning from snapred.backend.error.StateValidationException import StateValidationException from snapred.backend.service.ReductionService import ReductionService @@ -350,6 +355,295 @@ def test_validateReduction_no_permissions_and_no_calibrations_second_reentry(sel # and in addition, re-entry for the second continue-anyway check. self.instance.validateReduction(self.request) + def test_validateReduction_with_continueFlags(self): + self.request.continueFlags = None + + self.instance.dataFactoryService.normalizationExists = mock.Mock(return_value=True) + self.instance.dataFactoryService.calibrationExists = mock.Mock(return_value=True) + self.instance.checkWritePermissions = mock.Mock(return_value=False) + + with pytest.raises(ContinueWarning) as excInfo: + self.instance.validateReduction(self.request) + + assert excInfo.value.model.flags == ContinueWarning.Type.NO_WRITE_PERMISSIONS + + def test_validateReduction_with_warningMessages(self): + self.instance.dataFactoryService.normalizationExists = mock.Mock(return_value=False) + self.instance.dataFactoryService.calibrationExists = mock.Mock(return_value=False) + + with pytest.raises(ContinueWarning) as excInfo: + self.instance.validateReduction(self.request) + + assert "Normalization is missing" in str(excInfo.value) + assert "Diffraction calibration is missing" in str(excInfo.value) + + def test_artificialNormalization(self): + mockAlgo = mock.Mock() + len_wksp = 6 + input_ws_name = mtd.unique_name(prefix="input_ws_") + + mockAlgo.executeRecipe.return_value = input_ws_name + "_artificial_norm" + CreateWorkspace( + OutputWorkspace=input_ws_name, + DataX=[1] * len_wksp, + DataY=[1] * len_wksp, + NSpec=len_wksp, + UnitX="dSpacing", + ) + ConvertUnits( + InputWorkspace=input_ws_name, + OutputWorkspace=input_ws_name, + Target="dSpacing", + ) + with mock.patch( + "snapred.backend.service.ReductionService.ArtificialNormalizationRecipe", return_value=mockAlgo + ): + request = CreateArtificialNormalizationRequest( + runNumber="12345", + useLiteMode=True, + peakWindowClippingSize=1.0, + smoothingParameter=0.5, + decreaseParameter=True, + lss=True, + diffractionWorkspace=input_ws_name, + ) + + response = self.instance.artificialNormalization(request) + + assert response == input_ws_name + "_artificial_norm" + + mockAlgo.executeRecipe.assert_called_once_with(InputWorkspace=input_ws_name, Ingredients=mock.ANY) + + def test_loadAllGroupings_with_exception(self): + mock_grouping_map = mock.Mock() + mock_grouping_map.getMap.return_value = {} + + self.instance.dataFactoryService.getGroupingMap = mock.Mock( + side_effect=StateValidationException("Invalid State") + ) + + self.instance.dataFactoryService.getDefaultGroupingMap = mock.Mock(return_value=mock_grouping_map) + + result = self.instance.loadAllGroupings(self.request.runNumber, self.request.useLiteMode) + + mock_grouping_map.getMap.assert_called_once_with(self.request.useLiteMode) + + assert result == {"focusGroups": [], "groupingWorkspaces": []} + + def test_saveReductionPath(self): + mockPixelGroupingParams = { + "group1": [mock.Mock(spec=PixelGroupingParameters), mock.Mock(spec=PixelGroupingParameters)] + } + + mockRecord = ReductionRecord( + runNumber="123456", + useLiteMode=True, + timestamp=123456.789, + pixelGroupingParameters=mockPixelGroupingParams, + workspaceNames=["workspace1", "workspace2"], + ) + + mockRequest = ReductionExportRequest(record=mockRecord) + + with ( + mock.patch.object(self.instance.dataExportService, "exportReductionRecord") as mockExportRecord, + mock.patch.object(self.instance.dataExportService, "exportReductionData") as mockExportData, + ): + self.instance.saveReduction(mockRequest) + + mockExportRecord.assert_called_once_with(mockRecord) + mockExportData.assert_called_once_with(mockRecord) + + def test_loadReductionPath(self): + with pytest.raises(NotImplementedError): + self.instance.loadReduction("someState", 123456.789) + + def test_prepCombinedMask(self): + self.maskWS1 = mock.Mock() + self.maskWS2 = mock.Mock() + + masks = [self.maskWS1, self.maskWS2] + + combinedMaskName = ( # noqa: F841 + wng.reductionPixelMask() + .runNumber(self.request.runNumber) + .timestamp(self.instance.getUniqueTimestamp()) + .build() + ) + + with mock.patch.object(self.instance.mantidSnapper, "BinaryOperateMasks") as mockBinaryOperateMasks: + self.instance.prepCombinedMask( + self.request.runNumber, self.request.useLiteMode, self.request.timestamp, masks + ) + + assert mockBinaryOperateMasks.call_count == len(masks) + + def test_checkWritePermissions_fails(self): + with mock.patch.object(self.instance.dataExportService, "checkWritePermissions", return_value=False): + assert not self.instance.checkWritePermissions("123456") + + def test_createReductionRecord_missing_normalization_and_calibration(self): + self.request.continueFlags = ( + ContinueWarning.Type.MISSING_DIFFRACTION_CALIBRATION | ContinueWarning.Type.MISSING_NORMALIZATION + ) + + mockIngredients = ReductionIngredients( + runNumber="12345", + useLiteMode=True, + timestamp=123456789.0, + pixelGroups=[], + smoothingParameter=0.1, + calibrantSamplePath="path/to/calibrant", + peakIntensityThreshold=0.05, + keepUnfocused=True, + convertUnitsTo="TOF", + ) + mockWorkspaceNames = ["ws1", "ws2"] + + result = self.instance._createReductionRecord(self.request, mockIngredients, mockWorkspaceNames) + + assert result.normalization is None + assert result.calibration is None + assert result.workspaceNames == mockWorkspaceNames + + def test_validateReduction_with_artificialNormalization_and_no_permissions(self): + self.request.artificialNormalization = mock.Mock() + self.instance.checkWritePermissions = mock.Mock(return_value=False) + + with pytest.raises(ContinueWarning) as excInfo: + self.instance.validateReduction(self.request) + + assert excInfo.value.model.flags == ContinueWarning.Type.NO_WRITE_PERMISSIONS + + def test_validateReduction_with_continueFlags_xor_operation(self): + self.request.artificialNormalization = mock.Mock() + self.request.continueFlags = ContinueWarning.Type.NO_WRITE_PERMISSIONS + + self.instance.checkWritePermissions = mock.Mock(return_value=True) + + self.instance.validateReduction(self.request) + + def test_fetchReductionGroceries_with_artificialNormalization(self): + self.request.artificialNormalization = "artificial_norm_ws" + + self.instance.dataFactoryService.getThisOrLatestCalibrationVersion = mock.Mock(return_value=1) + + mock_grocery_clerk = mock.Mock() + self.instance.groceryClerk = mock_grocery_clerk + + mock_grocery_clerk.name.return_value = mock_grocery_clerk + mock_grocery_clerk.diffcal_table.return_value = mock_grocery_clerk + mock_grocery_clerk.useLiteMode.return_value = mock_grocery_clerk + mock_grocery_clerk.add.return_value = mock_grocery_clerk + mock_grocery_clerk.buildDict.return_value = {"key1": "value1"} + + mock_grocery_service = mock.Mock() + self.instance.groceryService = mock_grocery_service + mock_grocery_service.fetchGroceryList.return_value = ["workspace1"] + + self.instance.fetchReductionGroceries(self.request) + + mock_grocery_clerk.name.assert_any_call("inputWorkspace") + mock_grocery_clerk.name.assert_any_call("diffcalWorkspace") + mock_grocery_clerk.diffcal_table.assert_called_once_with(self.request.runNumber, 1) + mock_grocery_clerk.useLiteMode.assert_called_once_with(self.request.useLiteMode) + mock_grocery_clerk.add.assert_called_once() + + def test_fetchReductionGroceries_with_missing_normalization(self): + self.request.continueFlags = ContinueWarning.Type.UNSET + self.request.artificialNormalization = None + + self.instance.dataFactoryService.getThisOrLatestCalibrationVersion = mock.Mock(return_value=1) + self.instance.dataFactoryService.getThisOrLatestNormalizationVersion = mock.Mock(return_value=2) + + mock_grocery_clerk = mock.Mock() + self.instance.groceryClerk = mock_grocery_clerk + mock_grocery_clerk.name.return_value = mock_grocery_clerk + mock_grocery_clerk.normalization.return_value = mock_grocery_clerk + mock_grocery_clerk.useLiteMode.return_value = mock_grocery_clerk + mock_grocery_clerk.add.return_value = mock_grocery_clerk + mock_grocery_clerk.buildDict.return_value = {"key1": "value1"} + + mock_grocery_service = mock.Mock() + self.instance.groceryService = mock_grocery_service + mock_grocery_service.fetchGroceryDict.return_value = {"inputWorkspace": "workspace1"} + + self.instance.fetchReductionGroceries(self.request) + + self.instance.dataFactoryService.getThisOrLatestNormalizationVersion.assert_called_once_with( + self.request.runNumber, self.request.useLiteMode + ) + + mock_grocery_clerk.name.assert_any_call("normalizationWorkspace") + mock_grocery_clerk.normalization.assert_called_once_with(self.request.runNumber, 2) + mock_grocery_clerk.useLiteMode.assert_called_once_with(self.request.useLiteMode) + mock_grocery_clerk.add.assert_called_once() + + def test_fetchReductionGroceries_with_calibration_and_missing_normalization(self): + self.request.continueFlags = ContinueWarning.Type.MISSING_NORMALIZATION + self.request.artificialNormalization = None + + self.instance.dataFactoryService.getThisOrLatestCalibrationVersion = mock.Mock(return_value=1) + self.instance.dataFactoryService.getThisOrLatestNormalizationVersion = mock.Mock(return_value=None) + + mock_grocery_clerk = mock.Mock() + self.instance.groceryClerk = mock_grocery_clerk + mock_grocery_clerk.name.return_value = mock_grocery_clerk + mock_grocery_clerk.diffcal_output.return_value = mock_grocery_clerk + mock_grocery_clerk.useLiteMode.return_value = mock_grocery_clerk + mock_grocery_clerk.unit.return_value = mock_grocery_clerk + mock_grocery_clerk.group.return_value = mock_grocery_clerk + mock_grocery_clerk.buildDict.return_value = {"key1": "value1"} + + mock_grocery_service = mock.Mock() + self.instance.groceryService = mock_grocery_service + mock_grocery_service.fetchGroceryDict.return_value = {"diffractionWorkspace": "diffraction_ws"} + result = self.instance.fetchReductionGroceries(self.request) + + mock_grocery_clerk.name.assert_called_with("diffractionWorkspace") + mock_grocery_clerk.diffcal_output.assert_called_once_with(self.request.runNumber, 1) + mock_grocery_clerk.unit.assert_called_once_with(wng.Units.DSP) + mock_grocery_clerk.group.assert_called_once_with("column") + mock_grocery_clerk.buildDict.assert_called_once() + + assert isinstance(result, ArtificialNormResponse) + assert result.diffractionWorkspace == "diffraction_ws" + + def test_fetchReductionGroceries_creates_versions(self): + self.request.continueFlags = ContinueWarning.Type.UNSET + self.request.artificialNormalization = None + + self.instance.dataFactoryService.getThisOrLatestCalibrationVersion = mock.Mock(return_value=1) + self.instance.dataFactoryService.getThisOrLatestNormalizationVersion = mock.Mock(return_value=2) + + mock_grocery_clerk = mock.Mock() + self.instance.groceryClerk = mock_grocery_clerk + mock_grocery_clerk.name.return_value = mock_grocery_clerk + mock_grocery_clerk.useLiteMode.return_value = mock_grocery_clerk + mock_grocery_clerk.add.return_value = mock_grocery_clerk + mock_grocery_clerk.buildDict.return_value = {"key1": "value1"} + + mock_grocery_service = mock.Mock() + self.instance.groceryService = mock_grocery_service + mock_grocery_service.fetchGroceryDict.return_value = {"inputWorkspace": "workspace1"} + + self.instance.fetchReductionGroceries(self.request) + + assert self.request.versions.calibration == 1 + assert self.request.versions.normalization == 2 + + def test_reduction_with_artificial_norm_response(self): + artificial_response = ArtificialNormResponse(diffractionWorkspace="mock_diffraction_ws") + self.instance.fetchReductionGroceries = mock.Mock(return_value=artificial_response) + + self.instance.dataFactoryService.calibrationExists = mock.Mock(return_value=True) + self.instance.dataFactoryService.normalizationExists = mock.Mock(return_value=True) + + result = self.instance.reduction(self.request) + + assert result == artificial_response + self.instance.fetchReductionGroceries.assert_called_once_with(self.request) + class TestReductionServiceMasks: @pytest.fixture(autouse=True, scope="class") diff --git a/tests/unit/meta/test_Decorators.py b/tests/unit/meta/test_Decorators.py index 551a73e65..5588bbf0f 100644 --- a/tests/unit/meta/test_Decorators.py +++ b/tests/unit/meta/test_Decorators.py @@ -204,3 +204,46 @@ def testFunc(): assert mockLogger.debug.call_count == 2 assert mockLogger.debug.call_args_list[0][0][0] == "Entering testFunc" assert mockLogger.debug.call_args_list[1][0][0] == "Exiting testFunc" + + +def test_stateValidationExceptionNoTracebackDetails(): + exception = Exception("Test Exception without valid traceback") + + with patch("snapred.backend.error.StateValidationException.logger") as logger_mock: + with patch("traceback.extract_tb", return_value=None): + with pytest.raises(StateValidationException) as excinfo: + raise StateValidationException(exception) + + assert str(excinfo.value) == "Instrument State for given Run Number is invalid! (see logs for details.)" + + logger_mock.error.assert_called_with("Test Exception without valid traceback") + + +def test_stateValidationExceptionWithPartialTracebackDetails(): + exception = Exception("Test Exception with incomplete traceback") + + mock_tb_info = [traceback.FrameSummary(filename=None, lineno=42, name="testFunction")] + + with patch("snapred.backend.error.StateValidationException.logger") as logger_mock: + with patch("traceback.extract_tb", return_value=mock_tb_info): + with pytest.raises(StateValidationException) as excinfo: + raise StateValidationException(exception) + + assert str(excinfo.value) == "Instrument State for given Run Number is invalid! (see logs for details.)" + + logger_mock.error.assert_called_with("Test Exception with incomplete traceback") + + +def test_stateValidationExceptionWithMissingFilePath(): + exception = Exception("Test Exception with missing file path") + + mock_tb_info = [traceback.FrameSummary(filename=None, lineno=42, name="testFunction")] + + with patch("snapred.backend.error.StateValidationException.logger") as logger_mock: + with patch("traceback.extract_tb", return_value=mock_tb_info): + with pytest.raises(StateValidationException) as excinfo: + raise StateValidationException(exception) + + assert str(excinfo.value) == "Instrument State for given Run Number is invalid! (see logs for details.)" + + logger_mock.error.assert_called_with("Test Exception with missing file path")