Skip to content

Commit

Permalink
Add compatibility with old dataset name convention
Browse files Browse the repository at this point in the history
  • Loading branch information
tobias-liaudat committed Sep 30, 2022
1 parent 72bf515 commit 899a6e3
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions wf_psf/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def compute_shape_metrics(
predictions = tf_semiparam_field.predict(x=pred_inputs, batch_size=batch_size)

# GT data preparation
if dataset_dict is None or 'super_res_stars' not in dataset_dict:
if dataset_dict is None or 'super_res_stars' not in dataset_dict or 'SR_stars' not in dataset_dict:
print('Generating GT super resolved stars from the GT model.')
# Change interpolation parameters for the GT simPSF
interp_pts_per_bin = simPSF_np.interp_pts_per_bin
Expand All @@ -438,7 +438,10 @@ def compute_shape_metrics(

else:
print('Using super resolved stars from dataset.')
GT_predictions = dataset_dict['super_res_stars']
if 'super_res_stars' in dataset_dict:
GT_predictions = dataset_dict['super_res_stars']
elif 'SR_stars' in dataset_dict:
GT_predictions = dataset_dict['SR_stars']

# Calculate residuals
residuals = np.sqrt(np.mean((GT_predictions - predictions)**2, axis=(1, 2)))
Expand Down

0 comments on commit 899a6e3

Please sign in to comment.