diff --git a/tools/ARIAtools/extractProduct.py b/tools/ARIAtools/extractProduct.py index 1d226168..14bda594 100755 --- a/tools/ARIAtools/extractProduct.py +++ b/tools/ARIAtools/extractProduct.py @@ -1295,7 +1295,12 @@ def finalize_metadata(outname, bbox_bounds, dem_bounds, prods_TOTbbox, dem, from scipy.interpolate import RegularGridInterpolator # get final shape - arrres = gdal.Open(dem.GetDescription()) + # MG: add option to pass dem path as string + if isinstance(dem, str): + arrres = gdal.Open(dem) + else: + # for gdal instance + arrres = gdal.Open(dem.GetDescription()) arrshape = [arrres.RasterYSize, arrres.RasterXSize] ref_geotrans = arrres.GetGeoTransform() arrres = [abs(ref_geotrans[1]), abs(ref_geotrans[-1])] diff --git a/tools/ARIAtools/sequential_stitching.py b/tools/ARIAtools/sequential_stitching.py index 5a0e9e2d..6042fab7 100644 --- a/tools/ARIAtools/sequential_stitching.py +++ b/tools/ARIAtools/sequential_stitching.py @@ -91,6 +91,7 @@ def stitch_unwrapped_frames(input_unw_files: List[str], # Loop through sorted frames, and stitch neighboring frames for i, (ix1, ix2) in enumerate(zip(sorted_ix[:-1], sorted_ix[1:])): if verbose: + print(50*'*') print( 'Frame-1:', unw_attr_dicts[ix1]['PATH'].split('"')[1].split('/')[-1]) @@ -230,11 +231,12 @@ def stitch_unw2frames(unw_data1: NDArray, conn_data1: NDArray, rdict1: dict, # Correction methods: mean difference, 2pi integer cycles if correction_method == 'cycle2pi': correction = cycles2pi - elif correction_method == 'meandiff': + elif correction_method == 'meanoff': correction = diff + range_correction = False else: raise ValueError(f'Wrong correction method {correction_method}, ', - 'Select one of available: "cycle2pi", "meandiff"') + 'Select one of available: "cycle2pi", "meanoff"') # add range correction if range_correction: @@ -252,6 +254,7 @@ def stitch_unw2frames(unw_data1: NDArray, conn_data1: NDArray, rdict1: dict, conn_reverse = conn_reverse[ik] for pair in conn_reverse: + print('Going backward!') if verbose else None diff, cycles2pi, range_corr = _integer_2pi_cycles( unw1=unw_data1[box_1], concom1=conn_data1[box_1], @@ -267,6 +270,7 @@ def stitch_unw2frames(unw_data1: NDArray, conn_data1: NDArray, rdict1: dict, correction = cycles2pi elif correction_method == 'meanoff': correction = diff + range_correction = False else: raise ValueError(f'Wrong correction method {correction_method}, ', 'Select one of available: "cycle2pi", "meanoff"') @@ -443,14 +447,20 @@ def _integer_2pi_cycles(unw1: NDArray, concom1: NDArray, ix1: np.float32, # dimensions for comparison idx = np.where((concom1 == ix1) & (concom2 == ix2)) - n_points = np.count_nonzero(unw1[idx] - unw2[idx]) - diff_value = np.nanmean(unw1[idx] - unw2[idx]) - std_value = np.nanstd(unw1[idx] - unw2[idx]) + diff = unw1[idx] - unw2[idx] + + # Masked array to array + if np.ma.isMaskedArray(diff): + diff = diff.data + + median_diff = np.nanmedian(diff) + std_value = np.nanstd(diff) + n_points = np.count_nonzero(diff) # Number of 2pi integer jumps - num_jump = (np.abs(diff_value) + np.pi) // (2.*np.pi) + num_jump = (np.abs(median_diff) + np.pi) // (2.*np.pi) - if diff_value < 0: + if median_diff < 0: num_jump *= -1 correction2pi = 2. * np.pi * num_jump @@ -461,19 +471,25 @@ def _integer_2pi_cycles(unw1: NDArray, concom1: NDArray, ix1: np.float32, else: range_corr = 0 + # Note: range correctio sometimes gives oposite sign of + # correction, and add half or one cycle more. + # not sure, why that happens?? below is a hardcoded solution + if np.abs(median_diff - (correction2pi + range_corr)) > 3.14: + range_corr *= -1 + if print_msg: print( - f' Frame-1 component: {ix1} - Frame-2 component: {ix2}\n' + f' Frame-1 component: {ix1} - Frame-2 component: {ix2}\n' f' Number of points: {n_points}\n' - f' Mean diff: {diff_value:.2f}, std: {std_value:.2f} rad\n' + f' Median diff: {median_diff:.2f}, std: {std_value:.2f} rad\n' f' Number of 2pi cycles: {num_jump}\n' f' Correction2pi: {correction2pi:.2f}') if range_correction: print( - f'Range Corr: {range_corr:.2f} \n', - f'2piCorr + RangeCorr: {correction2pi + range_corr:.2f}\n') + f' Range Corr: {range_corr:.2f} \n', + f' 2piCorr + RangeCorr: {correction2pi + range_corr:.2f}\n') - return diff_value, correction2pi, range_corr + return median_diff, correction2pi, range_corr def _range_correction(unw1: NDArray, @@ -497,8 +513,8 @@ def _range_correction(unw1: NDArray, """ # Wrap unwrapped Phase in Frame-1 and Frame-2 - unw1_wrapped = np.mod(unw1, (2*np.pi))-np.pi - unw2_wrapped = np.mod(unw2, (2*np.pi))-np.pi + unw1_wrapped = np.mod(unw1, (2*np.pi)) - np.pi + unw2_wrapped = np.mod(unw2, (2*np.pi)) - np.pi # Get the difference between wrapped images arr = unw1_wrapped - unw2_wrapped @@ -697,7 +713,7 @@ def product_stitch_sequential(input_unw_files: List[str], input_conncomp_files, correction_method=correction_method, range_correction=range_correction, - direction_N_S=False, + direction_N_S=True, verbose=verbose) # Write @@ -745,12 +761,19 @@ def product_stitch_sequential(input_unw_files: List[str], str(output), format="VRT") # Remove temp files [ii.unlink() for ii in [input, input.with_suffix('.vrt'), + input.with_suffix('.xml'), input.with_suffix('.hdr'), input.with_suffix('.aux.xml')] if ii.exists()] # Mask if mask_file: - mask_array = mask_file.ReadAsArray() + if isinstance(mask_file, str): + mask = gdal.Open(mask_file) + else: + # for gdal instance, from prep_mask + mask = mask_file + + mask_array = mask.ReadAsArray() array = get_GUNW_array(str(output.with_suffix('.vrt'))) if output == output_conn: