diff --git a/include/neural-graphics-primitives/common.h b/include/neural-graphics-primitives/common.h index f933c2b99..1d6d8b6ec 100644 --- a/include/neural-graphics-primitives/common.h +++ b/include/neural-graphics-primitives/common.h @@ -81,6 +81,14 @@ enum class ERenderMode : int { }; static constexpr const char* RenderModeStr = "AO\0Shade\0Normals\0Positions\0Depth\0Distortion\0Cost\0Slice\0\0"; +enum class ECameraMode : int { + Perspective, + Orthographic, + Environment +}; + +static constexpr const char* CameraModeStr = "Perspective\0Orthographic\0Environment\0\0"; + enum class ERandomMode : int { Random, Halton, diff --git a/include/neural-graphics-primitives/common_device.cuh b/include/neural-graphics-primitives/common_device.cuh index 540d45496..9ff5be99f 100644 --- a/include/neural-graphics-primitives/common_device.cuh +++ b/include/neural-graphics-primitives/common_device.cuh @@ -270,40 +270,75 @@ inline __host__ __device__ Ray pixel_to_ray( bool snap_to_pixel_centers = false, float focus_z = 1.0f, float dof = 0.0f, + const ECameraMode camera_mode = ECameraMode::Perspective, const CameraDistortion& camera_distortion = {}, const float* __restrict__ distortion_data = nullptr, - const Eigen::Vector2i distortion_resolution = Eigen::Vector2i::Zero() + const Eigen::Vector2i distortion_resolution = Eigen::Vector2i::Zero(), + const float dataset_scale = 1.f ) { Eigen::Vector2f offset = ld_random_pixel_offset(snap_to_pixel_centers ? 
0 : spp); Eigen::Vector2f uv = (pixel.cast() + offset).cwiseQuotient(resolution.cast()); + const Eigen::Vector3f shift = {parallax_shift.x(), parallax_shift.y(), 0.f}; Eigen::Vector3f dir; - if (camera_distortion.mode == ECameraDistortionMode::FTheta) { - dir = f_theta_undistortion(uv - screen_center, camera_distortion.params, {1000.f, 0.f, 0.f}); - if (dir.x() == 1000.f) { - return {{1000.f, 0.f, 0.f}, {0.f, 0.f, 1.f}}; // return a point outside the aabb so the pixel is not rendered - } - } else if (camera_distortion.mode == ECameraDistortionMode::LatLong) { - dir = latlong_to_dir(uv); - } else { - dir = { + + Eigen::Vector3f head_pos; + if(camera_mode == ECameraMode::Orthographic){ + // 'dataset_scale' argument is only required by the orthographic camera. + // The focal length of Environment and Perspective cameras isn't affected by the change of dataset_scale, + // because all rays originate from the same point + dir = {0.f, 0.f, 1.f}; // Camera forward + head_pos = { (uv.x() - screen_center.x()) * (float)resolution.x() / focal_length.x(), (uv.y() - screen_center.y()) * (float)resolution.y() / focal_length.y(), - 1.0f + 0.0f }; - if (camera_distortion.mode == ECameraDistortionMode::Iterative) { - iterative_camera_undistortion(camera_distortion.params, &dir.x(), &dir.y()); - } + head_pos *= dataset_scale; + head_pos += shift; + dir -= shift / parallax_shift.z(); // we could use focus_z here in the denominator. for now, we pack m_scale in here. 
} - if (distortion_data) { - dir.head<2>() += read_image<2>(distortion_data, distortion_resolution, uv); + else if(camera_mode == ECameraMode::Environment){ + // Camera convention: XYZ <-> Right Down Front + head_pos = {0.f, 0.f, 0.f}; + const float phi = (uv.y() - 0.5f) * (float)M_PI; /* elevation angle from uv.y; float constants keep this per-ray math out of double precision */ + const float theta = (uv.x() - 0.5f) * 2.0f * (float)M_PI; /* azimuth angle from uv.x, covering a full turn across the image width */ + const float cos_phi = std::cos(phi); + dir = { + cos_phi*std::sin(theta), + std::sin(phi), + cos_phi*std::cos(theta) + }; + // Parallax isn't handled + } + else { // Perspective + head_pos = {0.f, 0.f, 0.f}; + if (camera_distortion.mode == ECameraDistortionMode::FTheta) { + dir = f_theta_undistortion(uv - screen_center, camera_distortion.params, {1000.f, 0.f, 0.f}); + if (dir.x() == 1000.f) { + return {{1000.f, 0.f, 0.f}, {0.f, 0.f, 1.f}}; // return a point outside the aabb so the pixel is not rendered + } + } else if (camera_distortion.mode == ECameraDistortionMode::LatLong) { + dir = latlong_to_dir(uv); + } else { + dir = { + (uv.x() - screen_center.x()) * (float)resolution.x() / focal_length.x(), + (uv.y() - screen_center.y()) * (float)resolution.y() / focal_length.y(), + 1.0f + }; + if (camera_distortion.mode == ECameraDistortionMode::Iterative) { + iterative_camera_undistortion(camera_distortion.params, &dir.x(), &dir.y()); + } + } + if (distortion_data) { + dir.head<2>() += read_image<2>(distortion_data, distortion_resolution, uv); + } + head_pos += shift; + dir -= shift / parallax_shift.z(); // we could use focus_z here in the denominator. for now, we pack m_scale in here. } - Eigen::Vector3f head_pos = {parallax_shift.x(), parallax_shift.y(), 0.f}; - dir -= head_pos / parallax_shift.z(); // we could use focus_z here in the denominator. for now, we pack m_scale in here. 
dir = camera_matrix.block<3, 3>(0, 0) * dir; - Eigen::Vector3f origin = camera_matrix.block<3, 3>(0, 0) * head_pos + camera_matrix.col(3); + if (dof == 0.0f) { return {origin, dir}; } @@ -323,16 +358,51 @@ inline __host__ __device__ Eigen::Vector2f pos_to_pixel( const Eigen::Matrix& camera_matrix, const Eigen::Vector2f& screen_center, const Eigen::Vector3f& parallax_shift, - const CameraDistortion& camera_distortion = {} + const ECameraMode camera_mode, + const CameraDistortion& camera_distortion = {}, + const float dataset_scale = 1.f ) { - // Express ray in terms of camera frame - Eigen::Vector3f head_pos = {parallax_shift.x(), parallax_shift.y(), 0.f}; - Eigen::Vector3f origin = camera_matrix.block<3, 3>(0, 0) * head_pos + camera_matrix.col(3); + // We get 'pos' as an input. We have pos = origin + alpha*dir, with unknown alpha + // tmp_dir = R^-1*(pos-t) + Eigen::Vector3f tmp_dir = camera_matrix.block<3, 3>(0, 0).inverse() * (pos - camera_matrix.col(3)); + const Eigen::Vector3f shift = {parallax_shift.x(), parallax_shift.y(), 0.f}; + + if(camera_mode == ECameraMode::Orthographic){ + // head_pos = {..., ..., 0} + // head_dir = {0,0,1} + // dir = R*(head_dir-shift/z) + // origin = R*(head_pos+shift) + t + tmp_dir -= shift; + const Eigen::Vector3f head_dir_minus_shift = Eigen::Vector3f(0.f, 0.f, 1.f) - shift/parallax_shift.z(); + Eigen::Vector3f head_pos = tmp_dir - tmp_dir.z() * head_dir_minus_shift; // Gives head_pos.z=0 since head_dir_minus_shift.z=1 + head_pos /= dataset_scale; + return { + head_pos.x() * focal_length.x() + screen_center.x() * resolution.x(), + head_pos.y() * focal_length.y() + screen_center.y() * resolution.y(), + }; + } + + if(camera_mode == ECameraMode::Environment){ + // Parallax isn't handled + // head_dir = {..., ..., ...} with ||head_dir|| = 1 + // dir = R*head_dir + // origin = t + tmp_dir = tmp_dir.normalized(); + const float phi = std::asin(tmp_dir.y()); /* elevation angle, inverse of dir.y = sin(phi) in pixel_to_ray */ + const float theta = std::atan2(tmp_dir.x(), tmp_dir.z()); /* azimuth angle, inverse of dir.x / dir.z = tan(theta) in pixel_to_ray */ + return { + (0.5f + 
theta / (2.0f * (float)M_PI)) * resolution.x(), + (0.5f + phi / (float)M_PI) * resolution.y(), /* vertical coordinate scales with the image height, mirroring uv.y = pixel.y / resolution.y in pixel_to_ray */ + }; + } - Eigen::Vector3f dir = pos - origin; - dir = camera_matrix.block<3, 3>(0, 0).inverse() * dir; - dir /= dir.z(); - dir += head_pos / parallax_shift.z(); + // Perspective + // head_dir = {..., ..., 1} + // dir = R*(head_dir-shift/z) + // origin = R*shift + t + tmp_dir -= shift; + tmp_dir /= tmp_dir.z(); + Eigen::Vector3f dir = tmp_dir + shift / parallax_shift.z(); // Maintains dir.z=1 because shift.z=0 if (camera_distortion.mode == ECameraDistortionMode::Iterative) { float du, dv; @@ -362,7 +432,9 @@ inline __host__ __device__ Eigen::Vector2f motion_vector_3d( const Eigen::Vector3f& parallax_shift, const bool snap_to_pixel_centers, const float depth, - const CameraDistortion& camera_distortion = {} + const ECameraMode camera_mode, + const CameraDistortion& camera_distortion = {}, + const float dataset_scale = 1.f ) { Ray ray = pixel_to_ray( sample_index, @@ -375,9 +447,11 @@ inline __host__ __device__ Eigen::Vector2f motion_vector_3d( snap_to_pixel_centers, 1.0f, 0.0f, + camera_mode, camera_distortion, nullptr, - Eigen::Vector2i::Zero() + Eigen::Vector2i::Zero(), + dataset_scale ); Eigen::Vector2f prev_pixel = pos_to_pixel( @@ -387,7 +461,9 @@ inline __host__ __device__ Eigen::Vector2f motion_vector_3d( prev_camera, screen_center, parallax_shift, - camera_distortion + camera_mode, + camera_distortion, + dataset_scale ); return prev_pixel - (pixel.cast() + ld_random_pixel_offset(sample_index)); diff --git a/include/neural-graphics-primitives/testbed.h b/include/neural-graphics-primitives/testbed.h index d8d830642..0bca066d8 100644 --- a/include/neural-graphics-primitives/testbed.h +++ b/include/neural-graphics-primitives/testbed.h @@ -156,7 +156,9 @@ class Testbed { int show_accel, float cone_angle_constant, ERenderMode render_mode, - cudaStream_t stream + ECameraMode camera_mode, + cudaStream_t stream, + float dataset_scale ); uint32_t trace( @@ -469,6 +471,7 @@ class Testbed { 
float m_bounding_radius = 1; float m_exposure = 0.f; + ECameraMode m_camera_mode = ECameraMode::Perspective; ERenderMode m_render_mode = ERenderMode::Shade; EMeshRenderMode m_mesh_render_mode = EMeshRenderMode::VertexNormals; diff --git a/src/python_api.cu b/src/python_api.cu index 02bcbec4e..8554afec3 100644 --- a/src/python_api.cu +++ b/src/python_api.cu @@ -238,6 +238,13 @@ PYBIND11_MODULE(pyngp, m) { .value("Slice", ERenderMode::Slice) .export_values(); + py::enum_(m, "CameraMode") + .value("Perspective", ECameraMode::Perspective) + .value("Orthographic", ECameraMode::Orthographic) + .value("Environment", ECameraMode::Environment) + .export_values(); + + py::enum_(m, "RandomMode") .value("Random", ERandomMode::Random) .value("Halton", ERandomMode::Halton) @@ -423,6 +430,7 @@ PYBIND11_MODULE(pyngp, m) { .def_readwrite("shall_train_network", &Testbed::m_train_network) .def_readwrite("render_groundtruth", &Testbed::m_render_ground_truth) .def_readwrite("render_mode", &Testbed::m_render_mode) + .def_readwrite("camera_mode", &Testbed::m_camera_mode) .def_readwrite("slice_plane_z", &Testbed::m_slice_plane_z) .def_readwrite("dof", &Testbed::m_dof) .def_readwrite("autofocus", &Testbed::m_autofocus) diff --git a/src/testbed.cu b/src/testbed.cu index 271a166fe..c569e5724 100644 --- a/src/testbed.cu +++ b/src/testbed.cu @@ -845,6 +845,7 @@ void Testbed::imgui() { ImGui::Checkbox("Autofocus", &m_autofocus); if (ImGui::TreeNode("Advanced camera settings")) { + accum_reset |= ImGui::Combo("Camera mode", (int*)&m_camera_mode, CameraModeStr); accum_reset |= ImGui::SliderFloat2("Screen center", &m_screen_center.x(), 0.f, 1.f); accum_reset |= ImGui::SliderFloat2("Parallax shift", &m_parallax_shift.x(), -1.f, 1.f); accum_reset |= ImGui::SliderFloat("Slice / focus depth", &m_slice_plane_z, -m_bounding_radius, m_bounding_radius); @@ -2505,7 +2506,9 @@ __global__ void dlss_prep_kernel( const float prev_view_dist, const Vector2f image_pos, const Vector2f prev_image_pos, - const 
Vector2i image_resolution + const Vector2i image_resolution, + const ECameraMode camera_mode, + const float dataset_scale = 1.f ) { uint32_t x = threadIdx.x + blockDim.x * blockIdx.x; uint32_t y = threadIdx.y + blockDim.y * blockIdx.y; @@ -2543,7 +2546,9 @@ __global__ void dlss_prep_kernel( parallax_shift, snap_to_pixel_centers, depth, - camera_distortion + camera_mode, + camera_distortion, + dataset_scale ); surf2Dwrite(make_float2(mvec.x(), mvec.y()), mvec_surface, x_orig * sizeof(float2), y_orig); @@ -2705,7 +2710,9 @@ void Testbed::render_frame(const Matrix& camera_matrix0, const Matr m_prev_scale, m_image.pos, m_image.prev_pos, - m_image.resolution + m_image.resolution, + m_camera_mode, + m_nerf.training.dataset.scale ); render_buffer.set_dlss_sharpening(m_dlss_sharpening); diff --git a/src/testbed_nerf.cu b/src/testbed_nerf.cu index 222bfd8c6..64b4324b2 100644 --- a/src/testbed_nerf.cu +++ b/src/testbed_nerf.cu @@ -1790,7 +1790,9 @@ __global__ void init_rays_with_payload_kernel_nerf( float* __restrict__ depthbuffer, const float* __restrict__ distortion_data, const Vector2i distortion_resolution, - ERenderMode render_mode + ERenderMode render_mode, + ECameraMode camera_mode, + float dataset_scale ) { uint32_t x = threadIdx.x + blockDim.x * blockIdx.x; uint32_t y = threadIdx.y + blockDim.y * blockIdx.y; @@ -1821,9 +1823,11 @@ __global__ void init_rays_with_payload_kernel_nerf( snap_to_pixel_centers, plane_z, dof, + camera_mode, camera_distortion, distortion_data, - distortion_resolution + distortion_resolution, + dataset_scale ); NerfPayload& payload = payloads[idx]; @@ -1970,7 +1974,9 @@ void Testbed::NerfTracer::init_rays_from_camera( int show_accel, float cone_angle_constant, ERenderMode render_mode, - cudaStream_t stream + ECameraMode camera_mode, + cudaStream_t stream, + float dataset_scale ) { // Make sure we have enough memory reserved to render at the requested resolution size_t n_pixels = (size_t)resolution.x() * resolution.y(); @@ -2000,7 +2006,9 @@ 
void Testbed::NerfTracer::init_rays_from_camera( depth_buffer, distortion_data, distortion_resolution, - render_mode + render_mode, + camera_mode, + dataset_scale ); m_n_rays_initialized = resolution.x() * resolution.y(); @@ -2263,7 +2271,9 @@ void Testbed::render_nerf(CudaRenderBuffer& render_buffer, const Vector2i& max_r m_nerf.show_accel, m_nerf.cone_angle_constant, render_mode, - stream + m_camera_mode, + stream, + m_nerf.training.dataset.scale ); uint32_t n_hit; @@ -2440,7 +2450,7 @@ void Testbed::Nerf::Training::export_camera_extrinsics(const std::string& filena trajectory.emplace_back(frame); } std::ofstream file(filename); - file << std::setw(2) << trajectory << std::endl; + file << std::setw(2) << trajectory << std::endl; } Eigen::Matrix Testbed::Nerf::Training::get_camera_extrinsics(int frame_idx) {