From 7825e054def83b9cd0824cedf64c3c7b43e0d8aa Mon Sep 17 00:00:00 2001
From: Tim Salzmann
Date: Sun, 5 Nov 2023 03:53:56 -0800
Subject: [PATCH] Updated dataset pre-processing to include JRDB Challenge Dataset.

PiperOrigin-RevId: 579608449
---
 README.md                                     | 17 +++
 .../config/jrdb_challenge/dataset_params.gin  | 111 ++++++++++++++++++
 .../config/jrdb_challenge/metrics.gin         | 68 +++++++++++
 .../config/jrdb_challenge/model_params.gin    | 32 +++++
 .../config/jrdb_challenge/training_params.gin | 10 ++
 human_scene_transformer/data/README.md        | 51 +++++---
 .../data/jrdb_preprocess_test.py              | 73 ++++++++++--
 .../data/jrdb_preprocess_train.py             | 53 +++++++--
 .../data/jrdb_train_detections_to_tracks.py   | 31 ++++-
 .../jrdb/eval_challenge.py                    | 20 ++--
 human_scene_transformer/model/model_params.py | 1 +
 11 files changed, 419 insertions(+), 48 deletions(-)
 create mode 100644 human_scene_transformer/config/jrdb_challenge/dataset_params.gin
 create mode 100644 human_scene_transformer/config/jrdb_challenge/metrics.gin
 create mode 100644 human_scene_transformer/config/jrdb_challenge/model_params.gin
 create mode 100644 human_scene_transformer/config/jrdb_challenge/training_params.gin

diff --git a/README.md b/README.md
index 5bc046f..843e857 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,5 @@
+:trophy: Winner of the [2023 JRDB Trajectory Prediction Challenge](https://jrdb.erc.monash.edu/leaderboards/trajectory)
+
 # Human Scene Transformer
 
 The (Human) Scene Transformer architecture (as described [here](https://arxiv.org/pdf/2309.17209.pdf) and [here](https://arxiv.org/pdf/2106.08417.pdf)) is a general and extendable trajectory prediction framework which treats trajectory prediction as a sequence-to-sequence problem and models it in a Transformer architecture.
@@ -82,6 +84,21 @@ python train.py --model_base_dir=./models/pedestrians_eth --gin_files=..config/
 
 ---
 
+## JRDB Trajectory Prediction Challenge Results
+To reproduce our winning results in the [2023 JRDB Trajectory Prediction Challenge](https://jrdb.erc.monash.edu/leaderboards/trajectory):
+
+- Make sure that you follow the [data pre-processing instructions](/human_scene_transformer/data) and pay special attention to where the instructions differentiate between the JRDB Challenge dataset and the original paper dataset.
+ +- Download the trained challenge model [here](https://storage.googleapis.com/gresearch/human_scene_transformer/challenge_checkpoint.zip) + +- Run + +``` +python jrdb/eval_challenge.py --model_path= --checkpoint_path=/ckpts/ckpt-20 --dataset_path= --output_path= +``` + +--- + ## Evaluation ### JRDB diff --git a/human_scene_transformer/config/jrdb_challenge/dataset_params.gin b/human_scene_transformer/config/jrdb_challenge/dataset_params.gin new file mode 100644 index 0000000..dc6d310 --- /dev/null +++ b/human_scene_transformer/config/jrdb_challenge/dataset_params.gin @@ -0,0 +1,111 @@ +TRAIN_SCENES = ['bytes-cafe-2019-02-07_0', + 'clark-center-2019-02-28_0', + 'clark-center-intersection-2019-02-28_0', + 'cubberly-auditorium-2019-04-22_0', + 'gates-159-group-meeting-2019-04-03_0', + 'gates-ai-lab-2019-02-08_0', + 'gates-to-clark-2019-02-28_1', + 'hewlett-packard-intersection-2019-01-24_0', + 'huang-basement-2019-01-25_0', + 'huang-lane-2019-02-12_0', + 'memorial-court-2019-03-16_0', + 'meyer-green-2019-03-16_0', + 'packard-poster-session-2019-03-20_0', + 'packard-poster-session-2019-03-20_1', + 'stlc-111-2019-04-19_0', + 'svl-meeting-gates-2-2019-04-08_0', + 'tressider-2019-03-16_0', + 'tressider-2019-03-16_1', + 'cubberly-auditorium-2019-04-22_1_test', + 'discovery-walk-2019-02-28_0_test', + 'food-trucks-2019-02-12_0_test', + 'gates-ai-lab-2019-04-17_0_test', + 'gates-foyer-2019-01-17_0_test', + 'gates-to-clark-2019-02-28_0_test', + 'hewlett-class-2019-01-23_1_test', + 'huang-2-2019-01-25_1_test', + 'indoor-coupa-cafe-2019-02-06_0_test', + 'lomita-serra-intersection-2019-01-30_0_test', + 'nvidia-aud-2019-01-25_0_test', + 'nvidia-aud-2019-04-18_1_test', + 'outdoor-coupa-cafe-2019-02-06_0_test', + 'quarry-road-2019-02-28_0_test', + 'stlc-111-2019-04-19_1_test', + 'stlc-111-2019-04-19_2_test', + 'tressider-2019-04-26_0_test', + 'tressider-2019-04-26_1_test', + 'clark-center-2019-02-28_1', + 'forbes-cafe-2019-01-22_0', + 'gates-basement-elevators-2019-01-17_1', + 'huang-2-2019-01-25_0', + 'jordan-hall-2019-04-22_0', + 'nvidia-aud-2019-04-18_0', + 'packard-poster-session-2019-03-20_2', + 'svl-meeting-gates-2-2019-04-08_1', + 'tressider-2019-04-26_2', + 'discovery-walk-2019-02-28_1_test', + 'gates-basement-elevators-2019-01-17_0_test', + 'hewlett-class-2019-01-23_0_test', + 'huang-intersection-2019-01-22_0_test', + 'meyer-green-2019-03-16_1_test', + 'nvidia-aud-2019-04-18_2_test', + 'serra-street-2019-01-30_0_test', + 'tressider-2019-03-16_2_test', + 'tressider-2019-04-26_3_test'] + +TEST_SCENES = [ + 'cubberly-auditorium-2019-04-22_1_test', + 'discovery-walk-2019-02-28_0_test', + 'discovery-walk-2019-02-28_1_test', + 'food-trucks-2019-02-12_0_test', + 'gates-ai-lab-2019-04-17_0_test', + 'gates-basement-elevators-2019-01-17_0_test', + 'gates-foyer-2019-01-17_0_test', + 'gates-to-clark-2019-02-28_0_test', + 'hewlett-class-2019-01-23_0_test', + 'hewlett-class-2019-01-23_1_test', + 'huang-2-2019-01-25_1_test', + 'huang-intersection-2019-01-22_0_test', + 'indoor-coupa-cafe-2019-02-06_0_test', + 'lomita-serra-intersection-2019-01-30_0_test', + 'meyer-green-2019-03-16_1_test', + 'nvidia-aud-2019-01-25_0_test', + 'nvidia-aud-2019-04-18_1_test', + 'nvidia-aud-2019-04-18_2_test', + 'outdoor-coupa-cafe-2019-02-06_0_test', + 'quarry-road-2019-02-28_0_test', + 'serra-street-2019-01-30_0_test', + 'stlc-111-2019-04-19_1_test', + 'stlc-111-2019-04-19_2_test', + 'tressider-2019-03-16_2_test', + 'tressider-2019-04-26_0_test', + 'tressider-2019-04-26_1_test', + 'tressider-2019-04-26_3_test', +] + + 
+JRDBDatasetParams.path = + +JRDBDatasetParams.train_scenes = %TRAIN_SCENES +JRDBDatasetParams.eval_scenes = %TEST_SCENES +JRDBDatasetParams.features = [ + 'agents/position', + 'agents/keypoints', + 'robot/position', + 'robot/orientation', + 'scene/pc' + ] + +JRDBDatasetParams.train_split = (0., 1.0) +JRDBDatasetParams.eval_split = (0., 1.0) + + +JRDBDatasetParams.num_history_steps = 11 +JRDBDatasetParams.num_steps = 24 +JRDBDatasetParams.num_agents = 16 +JRDBDatasetParams.timestep = 0.4 + +JRDBDatasetParams.subsample = 6 +JRDBDatasetParams.num_pointcloud_points = 512 + +JRDBDatasetParams.min_distance_to_robot = 50.0 \ No newline at end of file diff --git a/human_scene_transformer/config/jrdb_challenge/metrics.gin b/human_scene_transformer/config/jrdb_challenge/metrics.gin new file mode 100644 index 0000000..1779619 --- /dev/null +++ b/human_scene_transformer/config/jrdb_challenge/metrics.gin @@ -0,0 +1,68 @@ +# All available metrics. +min_ade/metrics.ade.MinADE.cutoff_seconds = None +min_ade1s/metrics.ade.MinADE.cutoff_seconds = 1.0 +min_ade2s/metrics.ade.MinADE.cutoff_seconds = 2.0 +min_ade3s/metrics.ade.MinADE.cutoff_seconds = 3.0 +min_ade4s/metrics.ade.MinADE.cutoff_seconds = 4.0 + +ml_ade/metrics.ade.MLADE.cutoff_seconds = None +ml_ade1s/metrics.ade.MLADE.cutoff_seconds = 1.0 +ml_ade2s/metrics.ade.MLADE.cutoff_seconds = 2.0 +ml_ade3s/metrics.ade.MLADE.cutoff_seconds = 3.0 +ml_ade4s/metrics.ade.MLADE.cutoff_seconds = 4.8 + +pos_nll/metrics.pos_nll.PositionNegativeLogLikelihood.cutoff_seconds = None +pos_nll1s/metrics.pos_nll.PositionNegativeLogLikelihood.cutoff_seconds = 1.0 +pos_nll2s/metrics.pos_nll.PositionNegativeLogLikelihood.cutoff_seconds = 2.0 +pos_nll3s/metrics.pos_nll.PositionNegativeLogLikelihood.cutoff_seconds = 3.0 +pos_nll4s/metrics.pos_nll.PositionNegativeLogLikelihood.cutoff_seconds = 4.8 + +# Training metrics. +get_metrics.train_metrics = { + 'loss': @metrics.Mean, + 'loss_position': @metrics.Mean, + 'loss_orientation': @metrics.Mean, + + 'min_ade': @min_ade/metrics.ade.MinADE, + 'min_ade1s': @min_ade1s/metrics.ade.MinADE, + 'min_ade2s': @min_ade2s/metrics.ade.MinADE, + 'min_ade3s': @min_ade3s/metrics.ade.MinADE, + 'min_ade4s': @min_ade4s/metrics.ade.MinADE, + + 'ml_ade': @ml_ade/metrics.ade.MLADE, + 'ml_ade1s': @ml_ade1s/metrics.ade.MLADE, + 'ml_ade2s': @ml_ade2s/metrics.ade.MLADE, + 'ml_ade3s': @ml_ade3s/metrics.ade.MLADE, + 'ml_ade4s': @ml_ade4s/metrics.ade.MLADE, + + 'pos_nll': @pos_nll/metrics.pos_nll.PositionNegativeLogLikelihood, + 'pos_nll1s': @pos_nll1s/metrics.pos_nll.PositionNegativeLogLikelihood, + 'pos_nll2s': @pos_nll2s/metrics.pos_nll.PositionNegativeLogLikelihood, + 'pos_nll3s': @pos_nll3s/metrics.pos_nll.PositionNegativeLogLikelihood, + 'pos_nll4s': @pos_nll4s/metrics.pos_nll.PositionNegativeLogLikelihood, +} + +# Eval metrics. 
+get_metrics.eval_metrics = { + 'loss': @metrics.Mean, + 'loss_position': @metrics.Mean, + 'loss_orientation': @metrics.Mean, + + 'min_ade': @min_ade/metrics.ade.MinADE, + 'min_ade1s': @min_ade1s/metrics.ade.MinADE, + 'min_ade2s': @min_ade2s/metrics.ade.MinADE, + 'min_ade3s': @min_ade3s/metrics.ade.MinADE, + 'min_ade4s': @min_ade4s/metrics.ade.MinADE, + + 'ml_ade': @ml_ade/metrics.ade.MLADE, + 'ml_ade1s': @ml_ade1s/metrics.ade.MLADE, + 'ml_ade2s': @ml_ade2s/metrics.ade.MLADE, + 'ml_ade3s': @ml_ade3s/metrics.ade.MLADE, + 'ml_ade4s': @ml_ade4s/metrics.ade.MLADE, + + 'pos_nll': @pos_nll/metrics.pos_nll.PositionNegativeLogLikelihood, + 'pos_nll1s': @pos_nll1s/metrics.pos_nll.PositionNegativeLogLikelihood, + 'pos_nll2s': @pos_nll2s/metrics.pos_nll.PositionNegativeLogLikelihood, + 'pos_nll3s': @pos_nll3s/metrics.pos_nll.PositionNegativeLogLikelihood, + 'pos_nll4s': @pos_nll4s/metrics.pos_nll.PositionNegativeLogLikelihood, +} \ No newline at end of file diff --git a/human_scene_transformer/config/jrdb_challenge/model_params.gin b/human_scene_transformer/config/jrdb_challenge/model_params.gin new file mode 100644 index 0000000..d1100f2 --- /dev/null +++ b/human_scene_transformer/config/jrdb_challenge/model_params.gin @@ -0,0 +1,32 @@ +ModelParams.agents_position_key = 'agents/position' +ModelParams.agents_feature_config = { + 'agents/position': @AgentPositionEncoder, + 'agents/keypoints': @AgentKeypointsEncoder, + #'agents/gaze': @Agent2DOrientationEncoder, +} +ModelParams.hidden_size = 128 +ModelParams.feature_embedding_size = 128 +ModelParams.transformer_ff_dim = 128 + +ModelParams.num_heads = 4 +ModelParams.num_modes = 4 +ModelParams.scene_encoder = @PointCloudEncoderLayer +ModelParams.attn_architecture = ( + 'self-attention', + 'self-attention', + 'cross-attention', + 'multimodality_induction', + 'self-attention', + 'self-attention-mode', + 'self-attention', + 'self-attention-mode', + ) +ModelParams.mask_style = "has_historic_data" +ModelParams.drop_prob = 0.1 +ModelParams.prediction_head = @Prediction2DPositionHeadLayer + +ModelParams.num_history_steps = 11 +ModelParams.num_steps = 24 +ModelParams.timestep = 0.4 +# Must be one of the classes in is_hidden_generators.py. +ModelParams.is_hidden_generator = @BPIsHiddenGenerator \ No newline at end of file diff --git a/human_scene_transformer/config/jrdb_challenge/training_params.gin b/human_scene_transformer/config/jrdb_challenge/training_params.gin new file mode 100644 index 0000000..65be7b0 --- /dev/null +++ b/human_scene_transformer/config/jrdb_challenge/training_params.gin @@ -0,0 +1,10 @@ +TrainingParams.batch_size = 64 +TrainingParams.shuffle_buffer_size = 10000 +TrainingParams.total_train_steps = 2e6 +TrainingParams.warmup_steps = 5e4 +TrainingParams.peak_learning_rate = 1e-4 +#TrainingParams.global_clipnorm = 1. +TrainingParams.batches_per_train_step = 25000 +TrainingParams.batches_per_eval_step = 2000 +TrainingParams.eval_every_n_step = 1e4 +TrainingParams.loss = @MultimodalPositionNLLLoss \ No newline at end of file diff --git a/human_scene_transformer/data/README.md b/human_scene_transformer/data/README.md index 9748c18..9817aa2 100644 --- a/human_scene_transformer/data/README.md +++ b/human_scene_transformer/data/README.md @@ -11,23 +11,32 @@ 5. Download and extract [Train Detections](https://jrdb.erc.monash.edu/static/downloads/train_detections.zip) from the JRDB 2019 section to `/detections`. 
 ## Get the Leaderboard Test Set Tracks
 
-Download and extract this leaderboard [3D tracking result](https://jrdb.erc.monash.edu/leaderboards/download/1605) to `/test_dataset/labels/raw_leaderboard/`. Such that you have `/test_dataset/labels/raw_leaderboard/00XX.txt` This is the best available leaderboard tracker at the time the code was developed.
 
-## Get the Robot Odometry Preprocessed Keypoints
+### For the JRDB Challenge Dataset
+Download and extract this leaderboard [3D tracking result](https://jrdb.erc.monash.edu/leaderboards/download/1762) to `/test_dataset/labels/PiFeNet/` such that you have `/test_dataset/labels/PiFeNet/00XX.txt`.
 
-Download the compressed data file [here](https://storage.googleapis.com/gresearch/human_scene_transformer/data.zip).
+### For the Original Dataset used in the Paper
+Download and extract this leaderboard [3D tracking result](https://jrdb.erc.monash.edu/leaderboards/download/1605) to `/test_dataset/labels/ss3d_mot/` such that you have `/test_dataset/labels/ss3d_mot/00XX.txt`. This was the best available leaderboard tracker at the time the method was developed.
 
-Extract the files and move them to `/processed/` such that you have `/processed/odoemtry_train`, `/processed/odoemtry_test` and `/processed/labels/labels_3d_keypoints_train/`, `/processed/labels/labels_3d_keypoints_test/`.
+## Get the Robot Odometry
+
+Download the compressed odometry data file [here](https://storage.googleapis.com/gresearch/human_scene_transformer/odometry.zip).
+
+Extract the files and move them to `/processed/` such that you have `/processed/odometry/train`, `/processed/odometry/test`.
 
 Alternatively you can extract the robot odometry from the raw rosbags yourself via `extract_robot_odometry_from_rosbag.py`.
 
-## Create Real-World Tracks for Test Data
+## Get the Preprocessed Keypoints
+
+Download the compressed keypoints data file [here](https://storage.googleapis.com/gresearch/human_scene_transformer/keypoints.zip).
+
+Extract the files and move them to `/processed/` such that you have `/processed/labels/labels_3d_keypoints/train/`, `/processed/labels/labels_3d_keypoints/test/`.
 
-Adapt `` in `jrdb_train_detections_to_tracks.py`
+## Create Real-World Tracks for Train Data
 
-Then run
+Run
 
-```python jrdb_train_detections_to_tracks.py```
+```python jrdb_train_detections_to_tracks.py --input_path=```
 
 ## Dataset Folder
 
@@ -48,22 +57,30 @@ You should end up with a dataset folder of the following structure
   - pointclouds
   - processed
     - labels
-      - labels_3d_keypoints_test
-      - labels_3d_keypoints_train
+      - labels_3d_keypoints
+        - train
+        - test
      - labels_detections_3d
-    - odoemtry_test
-    - odoemetry_train
+    - odometry
+      - train
+      - test
 ```
 
 ## Generate the Tensorflow Dataset
 
-Adapt `` in `jrdb_preprocess_train.py` and `jrdb_preprocess_test.py`.
+### For the JRDB Challenge Dataset
+```python jrdb_preprocess_train.py --input_path= --output_path= --max_distance_to_robot=50.0```
 
-Set `` in `jrdb_preprocess_train.py` and `jrdb_preprocess_test.py` to where you want to store the processed tensorflow dataset.
+```python jrdb_preprocess_test.py --input_path= --output_path= --max_distance_to_robot=50.0 --tracking_method=PiFeNet --tracking_confidence_threshold=0.01```
+
+Please note that this can take multiple hours due to the processing of the scene's
+pointclouds. If you do not need the pointclouds you can speed up the processing
+by passing `--process_pointclouds=False` for both.
-```python jrdb_preprocess_train.py```
+### For the Original Dataset used in the Paper
+```python jrdb_preprocess_train.py --input_path= --output_path= --max_distance_to_robot=15.0```
 
-```python jrdb_preprocess_test.py```
+```python jrdb_preprocess_test.py --input_path= --output_path= --max_distance_to_robot=15.0 --tracking_method=ss3d_mot```
 
 Please note that this can take multiple hours due to the processing of the scene's
 pointclouds. If you do not need the pointclouds you can speed up the processing
-by setting `POINTCLOUD=False` in both files.
\ No newline at end of file
+by passing `--process_pointclouds=False` for both.
\ No newline at end of file
diff --git a/human_scene_transformer/data/jrdb_preprocess_test.py b/human_scene_transformer/data/jrdb_preprocess_test.py
index 22a8ffd..e61de9a 100644
--- a/human_scene_transformer/data/jrdb_preprocess_test.py
+++ b/human_scene_transformer/data/jrdb_preprocess_test.py
@@ -17,16 +17,54 @@
 
 import os
 
+from absl import app
+from absl import flags
+
 from human_scene_transformer.data import utils
 import numpy as np
 import pandas as pd
 import tensorflow as tf
 import tqdm
 
-INPUT_PATH = ''
-OUTPUT_PATH = ''
-POINTCLOUD = True
+_INPUT_PATH = flags.DEFINE_string(
+    'input_path',
+    default=None,
+    help='Path to jrdb2022 dataset.'
+)
+
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
+    default=None,
+    help='Path to output folder.'
+)
+
+_PROCESS_POINTCLOUDS = flags.DEFINE_bool(
+    'process_pointclouds',
+    default=True,
+    help='Whether to process pointclouds.'
+)
+
+_MAX_DISTANCE_TO_ROBOT = flags.DEFINE_float(
+    'max_distance_to_robot',
+    default=15.,
+    help=('Maximum distance of agent to the robot to be included'
+          ' in the processed dataset.')
+)
+
+_TRACKING_METHOD = flags.DEFINE_string(
+    'tracking_method',
+    default='ss3d_mot',
+    help='Name of tracking method to use.'
+) + +_TRACKING_CONFIDENCE_THRESHOLD = flags.DEFINE_float( + 'tracking_confidence_threshold', + default=.0, + help=('Confidence threshold for tracked agent instance to be included' + ' in the processed dataset.') +) + AGENT_KEYPOINTS = True FROM_DETECTIONS = True @@ -63,7 +101,8 @@ def get_agents_features_df_with_box( ] scene_data_file = utils.get_file_handle( os.path.join( - input_path, 'labels', 'raw_leaderboard', f'{scene_id:04}' + '.txt' + input_path, 'labels', _TRACKING_METHOD.value, + f'{scene_id:04}' + '.txt' ) ) df = pd.read_csv(scene_data_file, sep=' ', names=jrdb_header) @@ -73,7 +112,7 @@ def camera_to_lower_velodyne(p): [p[..., 2], -p[..., 0], -p[..., 1] + (0.742092 - 0.606982)], axis=-1 ) - df = df[df['score'] >= 0.01] + df = df[df['score'] >= _TRACKING_CONFIDENCE_THRESHOLD.value] df['p'] = df[['x', 'y', 'z']].apply( lambda s: camera_to_lower_velodyne(s.to_numpy()), axis=1 @@ -95,6 +134,10 @@ def camera_to_lower_velodyne(p): def jrdb_preprocess_test(input_path, output_path): + """Preprocesses the raw test split of JRDB.""" + + tf.keras.utils.set_random_seed(123) + scenes = list_test_scenes(os.path.join(input_path, 'test_dataset')) subsample = 1 for scene in tqdm.tqdm(scenes): @@ -102,17 +145,18 @@ def jrdb_preprocess_test(input_path, output_path): agents_df = get_agents_features_df_with_box( os.path.join(input_path, 'test_dataset'), scenes.index(scene), - max_distance_to_robot=15.0, + max_distance_to_robot=_MAX_DISTANCE_TO_ROBOT.value, ) robot_odom = utils.get_robot( - os.path.join(input_path, 'processed', 'odometry_test'), scene + os.path.join(input_path, 'processed', 'odometry', 'test'), scene ) if AGENT_KEYPOINTS: keypoints = utils.get_agents_keypoints( os.path.join( - input_path, 'processed', 'labels', 'labels_3d_keypoints_test' + input_path, 'processed', 'labels', + 'labels_3d_keypoints', 'test', _TRACKING_METHOD.value ), scene, ) @@ -202,7 +246,7 @@ def jrdb_preprocess_test(input_path, output_path): os.path.join(output_path, scene_save_name, 'robot', 'orientation') ) - if POINTCLOUD: + if _PROCESS_POINTCLOUDS.value: scene_pointcloud_dict = utils.get_scene_poinclouds( os.path.join(input_path, 'test_dataset'), scene, subsample=subsample ) @@ -231,5 +275,14 @@ def jrdb_preprocess_test(input_path, output_path): compression='GZIP', ) + +def main(argv): + if len(argv) > 1: + raise app.UsageError('Too many command-line arguments.') + jrdb_preprocess_test(_INPUT_PATH.value, _OUTPUT_PATH.value) + if __name__ == '__main__': - jrdb_preprocess_test(INPUT_PATH, OUTPUT_PATH) + flags.mark_flags_as_required([ + 'input_path', 'output_path' + ]) + app.run(main) diff --git a/human_scene_transformer/data/jrdb_preprocess_train.py b/human_scene_transformer/data/jrdb_preprocess_train.py index f75015f..4397906 100644 --- a/human_scene_transformer/data/jrdb_preprocess_train.py +++ b/human_scene_transformer/data/jrdb_preprocess_train.py @@ -19,16 +19,41 @@ import json import os +from absl import app +from absl import flags + from human_scene_transformer.data import utils import numpy as np import pandas as pd import tensorflow as tf import tqdm -INPUT_PATH = '' -OUTPUT_PATH = '' +_INPUT_PATH = flags.DEFINE_string( + 'input_path', + default=None, + help='Path to jrdb2022 dataset.' +) + +_OUTPUT_PATH = flags.DEFINE_string( + 'output_path', + default=None, + help='Path to output folder.' +) + +_PROCESS_POINTCLOUDS = flags.DEFINE_bool( + 'process_pointclouds', + default=True, + help='Whether to process pointclouds.' 
+) + +_MAX_DISTANCE_TO_ROBOT = flags.DEFINE_float( + 'max_distance_to_robot', + default=15., + help=('Maximum distance of agent to the robot to be included' + ' in the processed dataset.') +) + -POINTCLOUD = True AGENT_KEYPOINTS = True FROM_DETECTIONS = True @@ -88,6 +113,8 @@ def get_agents_features(agents_dict, max_distance_to_robot=10): def jrdb_preprocess_train(input_path, output_path): """Preprocesses the raw train split of JRDB.""" + tf.keras.utils.set_random_seed(123) + subsample = 1 scenes = utils.list_scenes( @@ -104,11 +131,11 @@ def jrdb_preprocess_train(input_path, output_path): ) agents_features = utils.get_agents_features_with_box( - agents_dict, max_distance_to_robot=15.0 + agents_dict, max_distance_to_robot=_MAX_DISTANCE_TO_ROBOT.value ) robot_odom = utils.get_robot( - os.path.join(input_path, 'processed', 'odometry_train'), scene + os.path.join(input_path, 'processed', 'odometry', 'train'), scene ) agents_df = pd.DataFrame.from_dict( @@ -118,7 +145,8 @@ def jrdb_preprocess_train(input_path, output_path): if AGENT_KEYPOINTS: keypoints = utils.get_agents_keypoints( os.path.join( - input_path, 'processed', 'labels', 'labels_3d_keypoints_train'), + input_path, 'processed', 'labels', + 'labels_3d_keypoints', 'train'), scene, ) keypoints_df = pd.DataFrame.from_dict( @@ -207,7 +235,7 @@ def jrdb_preprocess_train(input_path, output_path): os.path.join(output_path, scene, 'robot', 'orientation') ) - if POINTCLOUD: + if _PROCESS_POINTCLOUDS.value: scene_pointcloud_dict = utils.get_scene_poinclouds( os.path.join(input_path, 'train_dataset'), scene, @@ -238,5 +266,14 @@ def jrdb_preprocess_train(input_path, output_path): os.path.join(output_path, scene, 'scene', 'pc'), compression='GZIP' ) + +def main(argv): + if len(argv) > 1: + raise app.UsageError('Too many command-line arguments.') + jrdb_preprocess_train(_INPUT_PATH.value, _OUTPUT_PATH.value) + if __name__ == '__main__': - jrdb_preprocess_train(INPUT_PATH, OUTPUT_PATH) + flags.mark_flags_as_required([ + 'input_path', 'output_path' + ]) + app.run(main) diff --git a/human_scene_transformer/data/jrdb_train_detections_to_tracks.py b/human_scene_transformer/data/jrdb_train_detections_to_tracks.py index 0b33a80..761fad7 100644 --- a/human_scene_transformer/data/jrdb_train_detections_to_tracks.py +++ b/human_scene_transformer/data/jrdb_train_detections_to_tracks.py @@ -19,6 +19,9 @@ import json import os +from absl import app +from absl import flags + from human_scene_transformer.data import box_utils from human_scene_transformer.data import utils import numpy as np @@ -28,9 +31,17 @@ import tqdm -INPUT_PATH = '' -OUTPUT_PATH = os.path.join( - INPUT_PATH, 'processed/labels/labels_detections_3d') +_INPUT_PATH = flags.DEFINE_string( + 'input_path', + default=None, + help='Path to jrdb2022 dataset.' +) + +_OUTPUT_PATH = flags.DEFINE_string( + 'output_path', + default=None, + help='Path to output folder.' 
+)
 
 
 def get_agents_3d_bounding_box_dict(input_path, scene):
@@ -176,5 +187,17 @@ def jrdb_train_detections_to_tracks(input_path, output_path):
 
     with open(f"{output_path}/{scene}.json", 'w') as write_file:
       json.dump(labels_dict, write_file, indent=2, ensure_ascii=True)
 
+
+def main(argv):
+  if len(argv) > 1:
+    raise app.UsageError('Too many command-line arguments.')
+  if _OUTPUT_PATH.value is None:
+    output_path = os.path.join(_INPUT_PATH.value,
+                               'processed/labels/labels_detections_3d')
+  else:
+    output_path = _OUTPUT_PATH.value
+  jrdb_train_detections_to_tracks(_INPUT_PATH.value, output_path)
+
 if __name__ == '__main__':
-  jrdb_train_detections_to_tracks(INPUT_PATH, OUTPUT_PATH)
+  flags.mark_flags_as_required(['input_path'])
+  app.run(main)
diff --git a/human_scene_transformer/jrdb/eval_challenge.py b/human_scene_transformer/jrdb/eval_challenge.py
index a24c1e5..e1c479d 100644
--- a/human_scene_transformer/jrdb/eval_challenge.py
+++ b/human_scene_transformer/jrdb/eval_challenge.py
@@ -27,7 +27,6 @@
 from human_scene_transformer.jrdb import input_fn
 from human_scene_transformer.model import model as hst_model
 from human_scene_transformer.model import model_params
-from human_scene_transformer.model import scene_encoder  # pylint: disable=unused-import
 import pandas as pd
 import tensorflow as tf
 import tqdm
@@ -66,26 +65,26 @@
 
 _MODEL_PATH = flags.DEFINE_string(
     'model_path',
-    '',
-    'Path to model directory.',
+    None,
+    help='Path to model directory.',
 )
 
 _CKPT_PATH = flags.DEFINE_string(
     'checkpoint_path',
-    '',
-    'Path to model checkpoint.',
+    None,
+    help='Path to model checkpoint.',
 )
 
 _DATASET_PATH = flags.DEFINE_string(
     'dataset_path',
-    '',
-    'Path to model checkpoint.',
+    None,
+    help='Path to dataset.',
 )
 
 _OUTPUT_PATH = flags.DEFINE_string(
     'output_path',
-    '',
-    'Path to output.',
+    None,
+    help='Path to output.',
 )
 
 
@@ -374,5 +373,8 @@ def main(argv: Sequence[str]) -> None:
 
 
 if __name__ == '__main__':
+  flags.mark_flags_as_required([
+      'model_path', 'checkpoint_path', 'dataset_path', 'output_path'
+  ])
   logging.set_verbosity(logging.ERROR)
   app.run(main)
diff --git a/human_scene_transformer/model/model_params.py b/human_scene_transformer/model/model_params.py
index 8edb0c2..525acbc 100644
--- a/human_scene_transformer/model/model_params.py
+++ b/human_scene_transformer/model/model_params.py
@@ -22,6 +22,7 @@
 from human_scene_transformer import is_hidden_generators
 from human_scene_transformer.model import agent_feature_encoder
 from human_scene_transformer.model import head
+from human_scene_transformer.model import scene_encoder as _
 
 
 @gin.configurable
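
For reference, below is a minimal sketch of the absl entry-point pattern that the updated preprocessing scripts share: module-level constants are replaced by `flags.DEFINE_*` holders, `main(argv)` rejects stray positional arguments, and required flags are enforced before `app.run`. The flag names mirror the patch; the `run` body is a placeholder for illustration only and is not part of the actual preprocessing code.

```python
# Illustrative sketch (not part of the patch) of the shared absl flag pattern
# used by jrdb_preprocess_train.py, jrdb_preprocess_test.py and
# jrdb_train_detections_to_tracks.py.
from absl import app
from absl import flags

_INPUT_PATH = flags.DEFINE_string(
    'input_path', default=None, help='Path to jrdb2022 dataset.')
_OUTPUT_PATH = flags.DEFINE_string(
    'output_path', default=None, help='Path to output folder.')
_PROCESS_POINTCLOUDS = flags.DEFINE_bool(
    'process_pointclouds', default=True, help='Whether to process pointclouds.')
_MAX_DISTANCE_TO_ROBOT = flags.DEFINE_float(
    'max_distance_to_robot', default=15.,
    help='Maximum distance of agent to the robot to be included.')


def run(input_path, output_path):
  # Placeholder for the actual preprocessing (see the scripts in the patch).
  print(f'Processing {input_path} -> {output_path}, '
        f'pointclouds={_PROCESS_POINTCLOUDS.value}, '
        f'max_dist={_MAX_DISTANCE_TO_ROBOT.value}')


def main(argv):
  # absl passes through unparsed positional arguments; the scripts accept none.
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  run(_INPUT_PATH.value, _OUTPUT_PATH.value)


if __name__ == '__main__':
  # Fail fast if the caller forgets the mandatory paths.
  flags.mark_flags_as_required(['input_path', 'output_path'])
  app.run(main)
```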