step #260

Open · wants to merge 15 commits into base: master
14 changes: 13 additions & 1 deletion .gitignore
@@ -1 +1,13 @@
.idea
.vscode

*.o
*.so
*.cpp
*.egg-info

__pycache__/

data/
results/
test/
*log*
176 changes: 0 additions & 176 deletions LICENSE

This file was deleted.

205 changes: 26 additions & 179 deletions README.md
@@ -1,180 +1,27 @@
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/191111236/semantic-segmentation-on-semantic3d)](https://paperswithcode.com/sota/semantic-segmentation-on-semantic3d?p=191111236)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/191111236/3d-semantic-segmentation-on-semantickitti)](https://paperswithcode.com/sota/3d-semantic-segmentation-on-semantickitti?p=191111236)
[![License CC BY-NC-SA 4.0](https://img.shields.io/badge/license-CC4.0-blue.svg)](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode)

# RandLA-Net: Efficient Semantic Segmentation of Large-Scale Point Clouds (CVPR 2020)

This is the official implementation of **RandLA-Net** (CVPR 2020, oral presentation), a simple and efficient neural architecture for semantic segmentation of large-scale 3D point clouds. For technical details, please refer to:

**RandLA-Net: Efficient Semantic Segmentation of Large-Scale Point Clouds** <br />
[Qingyong Hu](https://www.cs.ox.ac.uk/people/qingyong.hu/), [Bo Yang*](https://yang7879.github.io/), [Linhai Xie](https://www.cs.ox.ac.uk/people/linhai.xie/), [Stefano Rosa](https://www.cs.ox.ac.uk/people/stefano.rosa/), [Yulan Guo](http://yulanguo.me/), [Zhihua Wang](https://www.cs.ox.ac.uk/people/zhihua.wang/), [Niki Trigoni](https://www.cs.ox.ac.uk/people/niki.trigoni/), [Andrew Markham](https://www.cs.ox.ac.uk/people/andrew.markham/). <br />
**[[Paper](https://arxiv.org/abs/1911.11236)] [[Video](https://youtu.be/Ar3eY_lwzMk)] [[Blog](https://zhuanlan.zhihu.com/p/105433460)] [[Project page](http://randla-net.cs.ox.ac.uk/)]** <br />


<p align="center"> <img src="http://randla-net.cs.ox.ac.uk/imgs/Fig3.png" width="100%"> </p>



### (1) Setup
This code has been tested with Python 3.5, TensorFlow 1.11, CUDA 9.0, and cuDNN 7.4.1 on Ubuntu 16.04.

- Clone the repository
```
git clone --depth=1 https://github.com/QingyongHu/RandLA-Net && cd RandLA-Net
```
- Set up the Python environment
```
conda create -n randlanet python=3.5
source activate randlanet
pip install -r helper_requirements.txt
sh compile_op.sh
```

**Update 03/21/2020: pre-trained models and results are now available.**
You can download the pre-trained models and results [here](https://drive.google.com/open?id=1iU8yviO3TP87-IexBXsu13g6NklwEkXB).
Note: to use a pre-trained model and quickly try RandLA-Net, specify the model path in the main function of the corresponding script (e.g., `main_S3DIS.py`).

### (2) S3DIS
The S3DIS dataset can be found
<a href="https://docs.google.com/forms/d/e/1FAIpQLScDimvNMCGhy_rmBA2gHfDu3naktRm6A8BPwAWWDv-Uhm6Shw/viewform?c=0&w=1">here</a>.
Download the file named "Stanford3dDataset_v1.2_Aligned_Version.zip", uncompress it, and move it to
`/data/S3DIS`.

- Prepare the dataset:
```
python utils/data_prepare_s3dis.py
```
- Start 6-fold cross validation:
```
sh jobs_6_fold_cv_s3dis.sh
```
- Move all the generated results (`*.ply`) from the `/test` folder to `/data/S3DIS/results`, then calculate the final mean IoU (see the sketch after this step):
```
python utils/6_fold_cv.py
```
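For reference, here is a minimal sketch of how the 6-fold mean IoU is typically aggregated from the collected prediction files. The per-room array layout and the 13-class assumption are illustrative; this is not the exact logic of `utils/6_fold_cv.py`:
```python
# Sketch only: accumulate per-class statistics over ALL rooms of ALL folds,
# then average the per-class IoUs (assumes parallel integer label arrays).
import numpy as np

def mean_iou(gt_list, pred_list, num_classes=13):
    tp = np.zeros(num_classes)
    fp = np.zeros(num_classes)
    fn = np.zeros(num_classes)
    for gt, pred in zip(gt_list, pred_list):
        for c in range(num_classes):
            tp[c] += np.sum((pred == c) & (gt == c))
            fp[c] += np.sum((pred == c) & (gt != c))
            fn[c] += np.sum((pred != c) & (gt == c))
    iou = tp / np.maximum(tp + fp + fn, 1)  # guard against empty classes
    return iou.mean(), iou
```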

Quantitative results of different approaches on S3DIS dataset (6-fold cross-validation):

![a](http://randla-net.cs.ox.ac.uk/imgs/S3DIS_table.png)

Qualitative results of our RandLA-Net:

| ![2](imgs/S3DIS_area2.gif) | ![z](imgs/S3DIS_area3.gif) |
| ------------------------------ | ---------------------------- |



### (3) Semantic3D
7zip is required to uncompress the raw data of this dataset. To install p7zip:
```
sudo apt-get install p7zip-full
```
- Download and extract the dataset. First, specify the dataset path by changing `BASE_DIR` in `utils/download_semantic3d.sh`:
```
sh utils/download_semantic3d.sh
```
- Prepare the dataset:
```
python utils/data_prepare_semantic3d.py
```
- Start training:
```
python main_Semantic3D.py --mode train --gpu 0
```
- Evaluation:
```
python main_Semantic3D.py --mode test --gpu 0
```
Quantitative results of different approaches on Semantic3D (reduced-8):

![a](http://randla-net.cs.ox.ac.uk/imgs/Semantic3D_table.png)

Qualitative results of our RandLA-Net:

| ![z](imgs/Semantic3D-1.gif) | ![z](http://randla-net.cs.ox.ac.uk/imgs/Semantic3D-2.gif) |
| -------------------------------- | ------------------------------- |
| ![z](imgs/Semantic3D-3.gif) | ![z](imgs/Semantic3D-4.gif) |



**Note:**
- More than 64 GB of RAM is recommended for processing this dataset due to the large volume of the point clouds.


### (4) SemanticKITTI

The SemanticKITTI dataset can be found <a href="http://semantic-kitti.org/dataset.html#download">here</a>. Download the files
related to semantic segmentation, extract everything into the same folder, and move it to
`/data/semantic_kitti/dataset`.

- Prepare the dataset:
```
python utils/data_prepare_semantickitti.py
```

- Start training:
```
python main_SemanticKITTI.py --mode train --gpu 0
```

- Evaluation:
```
sh jobs_test_semantickitti.sh
```

Quantitative results of different approaches on SemanticKITTI dataset:

![s](http://randla-net.cs.ox.ac.uk/imgs/SemanticKITTI_table.png)

Qualitative results of our RandLA-Net:

![zzz](imgs/SemanticKITTI-2.gif)


### (5) Demo

<p align="center"> <a href="https://youtu.be/Ar3eY_lwzMk"><img src="http://randla-net.cs.ox.ac.uk/imgs/demo_cover.png" width="80%"></a> </p>


### Citation
If you find our work useful in your research, please consider citing:

```
@article{hu2019randla,
  title={RandLA-Net: Efficient Semantic Segmentation of Large-Scale Point Clouds},
  author={Hu, Qingyong and Yang, Bo and Xie, Linhai and Rosa, Stefano and Guo, Yulan and Wang, Zhihua and Trigoni, Niki and Markham, Andrew},
  journal={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  year={2020}
}

@article{hu2021learning,
  title={Learning Semantic Segmentation of Large-Scale Point Clouds with Random Sampling},
  author={Hu, Qingyong and Yang, Bo and Xie, Linhai and Rosa, Stefano and Guo, Yulan and Wang, Zhihua and Trigoni, Niki and Markham, Andrew},
  journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
  year={2021},
  publisher={IEEE}
}
```


### Acknowledgment
- Part of our code refers to the <a href="https://github.com/jlblancoc/nanoflann">nanoflann</a> library and the recent work <a href="https://github.com/HuguesTHOMAS/KPConv">KPConv</a>.
- We use <a href="https://www.blender.org/">Blender</a> to make the video demo.


### License
Licensed under the CC BY-NC-SA 4.0 license, see [LICENSE](./LICENSE).


### Updates
* 21/03/2020: Updating all experimental results
* 21/03/2020: Adding pretrained models and results
* 02/03/2020: Code available!
* 15/11/2019: Initial release!

## Related Repos
1. [SoTA-Point-Cloud: Deep Learning for 3D Point Clouds: A Survey](https://github.com/QingyongHu/SoTA-Point-Cloud) ![GitHub stars](https://img.shields.io/github/stars/QingyongHu/SoTA-Point-Cloud.svg?style=flat&label=Star)
2. [SensatUrban: Learning Semantics from Urban-Scale Photogrammetric Point Clouds](https://github.com/QingyongHu/SensatUrban) ![GitHub stars](https://img.shields.io/github/stars/QingyongHu/SensatUrban.svg?style=flat&label=Star)
3. [3D-BoNet: Learning Object Bounding Boxes for 3D Instance Segmentation on Point Clouds](https://github.com/Yang7879/3D-BoNet) ![GitHub stars](https://img.shields.io/github/stars/Yang7879/3D-BoNet.svg?style=flat&label=Star)
4. [SpinNet: Learning a General Surface Descriptor for 3D Point Cloud Registration](https://github.com/QingyongHu/SpinNet) ![GitHub stars](https://img.shields.io/github/stars/QingyongHu/SpinNet.svg?style=flat&label=Star)
5. [SQN: Weakly-Supervised Semantic Segmentation of Large-Scale 3D Point Clouds with 1000x Fewer Labels](https://github.com/QingyongHu/SQN) ![GitHub stars](https://img.shields.io/github/stars/QingyongHu/SQN.svg?style=flat&label=Star)

# Toronto-3D and OpenGF dataset code for RandLA-Net

Code for [Toronto-3D](https://github.com/WeikaiTan/Toronto-3D.git) has been uploaded. Try it as a starting point for building your own network.

Code for OpenGF will be released later.

## Train and test RandLA-Net on Toronto-3D
1. Set up the environment and compile the operations (exactly the same as for the original RandLA-Net).
2. Create a folder called `data` and move the `.ply` files into `data/Toronto_3D/original_ply/`.
3. Adjust the parameters in `data_prepare_toronto3d.py` to your preference and run it to preprocess the point clouds.
4. Adjust the parameters in `helper_tool.py` to build the network (see the feature-flag sketch after this list).
5. Train the network by running `python main_Toronto3D.py --mode train`.
6. Test and evaluate on `L002` by running `python main_Toronto3D.py --mode test --test_eval True`.
7. Modify the code to find a good parameter set, or test on your own data.
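The feature combinations in the results table below correspond to the `use_rgb` and `use_intensity` flags added to `ConfigToronto3D` in `helper_tool.py`. A minimal sketch of toggling them; the flag names come from this PR's diff, while exactly where `main_Toronto3D.py` reads them is an assumption:

```python
# Illustrative only: choose which input features the network consumes on
# Toronto-3D by flipping the ConfigToronto3D class attributes.
from helper_tool import ConfigToronto3D as cfg

cfg.use_rgb = True         # XYZ + RGB -> the "XYZRGB" rows in the table below
cfg.use_intensity = False  # additionally + intensity when True -> the "...I" rows
```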

## Sample results of Toronto-3D
The highest reported results are from RandLA-Net in [Hu et al. (2021)](https://doi.org/10.1109/TPAMI.2021.3083288). Below are some results I obtained with this code and the default parameters. The largest factor in mIoU is the accuracy on *Road Markings*, which cannot be distinguished from the road surface using XYZ coordinates alone. The mIoU column is the unweighted mean of the eight class IoUs; a quick check follows the table.

| Features | OA | mIoU | Road | Road mrk. | Natural | Bldg | Util. line | Pole | Car | Fence |
|----------|----|------|------|-----------|---------|------|------------|------|-----|-------|
| [Hu et al. (2021)](https://doi.org/10.1109/TPAMI.2021.3083288) | 92.95 | 77.71 | 94.61 | 42.62 | 96.89 | 93.01 | 86.51 | 78.07 | 92.85 | 37.12 |
| [Hu et al. (2021)](https://doi.org/10.1109/TPAMI.2021.3083288) with RGB| 94.37 | 81.77 | 96.69 | 64.21 | 96.92 | 94.24 | 88.06 | 77.84 | 93.37 | 42.86 |
|XYZRGBI | 96.57 | 81.00 | 95.61 | 58.04 | 97.22 | 93.45 | 87.58 | 82.64 | 91.06 | 42.39 |
|XYZRGB | 96.71 | 80.89 | 95.88 | 60.75 | 97.02 | 94.04 | 86.71 | 83.30 | 87.66 | 41.80 |
|XYZI | 95.65 | 80.03 | 94.59 | 50.14 | 95.90 | 92.76 | 87.70 | 77.77 | 91.10 | 50.30 |
|XYZ | 94.94 | 74.13 | 93.53 | 12.52 | 96.67 | 92.34 | 86.25 | 80.10 | 88.04 | 43.57 |
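As a quick sanity check on the table, e.g. for the XYZ row:

```python
# mIoU is the unweighted mean of the eight per-class IoUs (XYZ row above).
ious = [93.53, 12.52, 96.67, 92.34, 86.25, 80.10, 88.04, 43.57]
print(round(sum(ious) / len(ious), 2))  # 74.13
```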

2 changes: 2 additions & 0 deletions RandLANet.py
@@ -7,6 +7,8 @@
import helper_tf_util
import time

import warnings
warnings.simplefilter(action="ignore", category=FutureWarning)

def log_out(out_str, f_out):
    f_out.write(out_str + '\n')
54 changes: 54 additions & 0 deletions data_prepare_toronto3d.py
@@ -0,0 +1,54 @@
from sklearn.neighbors import KDTree
from os.path import join, exists, dirname, abspath
import numpy as np
import os, pickle
import sys

BASE_DIR = dirname(abspath(__file__))
ROOT_DIR = dirname(BASE_DIR)
sys.path.append(BASE_DIR)
sys.path.append(ROOT_DIR)
from helper_ply import write_ply, read_ply
from helper_tool import DataProcessing as DP

grid_size = 0.06
dataset_path = 'data/Toronto_3D'
train_files = ['L001', 'L003', 'L004']
val_files = ['L002']
UTM_OFFSET = [627285, 4841948, 0]
original_pc_folder = join(dataset_path, 'original_ply')
sub_pc_folder = join(dataset_path, 'input_{:.3f}'.format(grid_size))
if not exists(sub_pc_folder):
    os.mkdir(sub_pc_folder)

for pc_path in [join(original_pc_folder, fname + '.ply') for fname in train_files + val_files]:
    print(pc_path)
    file_name = pc_path.split('/')[-1][:-4]

    pc = read_ply(pc_path)
    labels = pc['scalar_Label'].astype(np.uint8)
    # Shift coordinates by the UTM offset to keep float32 precision
    xyz = np.vstack((pc['x'] - UTM_OFFSET[0], pc['y'] - UTM_OFFSET[1], pc['z'] - UTM_OFFSET[2])).T.astype(np.float32)
    color = np.vstack((pc['red'], pc['green'], pc['blue'])).T.astype(np.uint8)
    intensity = pc['scalar_Intensity'].astype(np.uint8).reshape(-1, 1)

    # Grid-subsample points, colors and labels to save space
    sub_xyz, sub_colors, sub_labels = DP.grid_sub_sampling(xyz, color, labels, grid_size)
    _, sub_intensity = DP.grid_sub_sampling(xyz, features=intensity, grid_size=grid_size)

    sub_colors = sub_colors / 255.0
    sub_intensity = sub_intensity[:, 0] / 255.0
    sub_ply_file = join(sub_pc_folder, file_name + '.ply')
    write_ply(sub_ply_file, [sub_xyz, sub_colors, sub_intensity, sub_labels],
              ['x', 'y', 'z', 'red', 'green', 'blue', 'intensity', 'class'])

    search_tree = KDTree(sub_xyz, leaf_size=50)
    kd_tree_file = join(sub_pc_folder, file_name + '_KDTree.pkl')
    with open(kd_tree_file, 'wb') as f:
        pickle.dump(search_tree, f)

    # For validation tiles, store each original point's nearest subsampled
    # neighbor so that predictions can be projected back to full resolution
    if file_name not in train_files:
        proj_idx = np.squeeze(search_tree.query(xyz, return_distance=False))
        proj_idx = proj_idx.astype(np.int32)
        proj_save = join(sub_pc_folder, file_name + '_proj.pkl')
        with open(proj_save, 'wb') as f:
            pickle.dump([proj_idx, labels], f)
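For context, here is a minimal sketch (not part of this PR) of how the saved `_proj.pkl` indices are typically consumed at test time: predictions made on the subsampled cloud are mapped back to every original point via its nearest subsampled neighbor. The file path and the placeholder prediction array are assumptions:

```python
# Sketch only: nearest-neighbor upsampling of sub-cloud predictions back to
# the full-resolution cloud, using the projection indices saved above.
import pickle
import numpy as np

with open('data/Toronto_3D/input_0.060/L002_proj.pkl', 'rb') as f:
    proj_idx, full_labels = pickle.load(f)

sub_preds = np.random.randint(0, 8, size=proj_idx.max() + 1)  # placeholder per-point labels
full_preds = sub_preds[proj_idx]  # one predicted label per original point
assert full_preds.shape == full_labels.shape
```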
37 changes: 35 additions & 2 deletions helper_tool.py
@@ -1,4 +1,4 @@
from open3d import linux as open3d
import open3d
from os.path import join
import numpy as np
import colorsys, random, os, sys
@@ -14,7 +14,38 @@
import cpp_wrappers.cpp_subsampling.grid_subsampling as cpp_subsampling
import nearest_neighbors.lib.python.nearest_neighbors as nearest_neighbors

class ConfigToronto3D:
    k_n = 16  # KNN
    num_layers = 5  # Number of layers
    num_points = 65536  # Number of input points
    num_classes = 8  # Number of valid classes
    sub_grid_size = 0.06  # preprocessing grid size
    use_rgb = False  # Use RGB
    use_intensity = False  # Use intensity

    batch_size = 4  # batch size during training
    val_batch_size = 14  # batch size during validation and test
    train_steps = 500  # Number of steps per epoch
    val_steps = 25  # Number of validation steps per epoch

    sub_sampling_ratio = [4, 4, 4, 4, 2]  # ratio of random sampling at each layer
    d_out = [16, 64, 128, 256, 512]  # feature dimension

    noise_init = 3.5  # noise initial parameter
    max_epoch = 100  # maximum epoch during training
    learning_rate = 1e-2  # initial learning rate
    lr_decays = {i: 0.95 for i in range(0, 500)}  # decay rate of learning rate

    train_sum_dir = 'train_log'
    saving = True
    saving_path = None

    augment_scale_anisotropic = True
    augment_symmetries = [True, False, False]
    augment_rotation = 'vertical'
    augment_scale_min = 0.8
    augment_scale_max = 1.2
    augment_noise = 0.001


class ConfigSemanticKITTI:
    k_n = 16  # KNN
    num_layers = 4  # Number of layers
@@ -255,7 +286,9 @@ def get_class_weights(dataset_name):
    elif dataset_name == 'SemanticKITTI':
        num_per_class = np.array([55437630, 320797, 541736, 2578735, 3274484, 552662, 184064, 78858,
                                  240942562, 17294618, 170599734, 6369672, 230413074, 101130274, 476491114,
                                  9833174, 129609852, 4506626, 1168181])
                                  9833174, 129609852, 4506626, 1168181], dtype=np.int32)
    elif dataset_name == 'Toronto3D':
        num_per_class = np.array([35391894, 1449308, 4650919, 18252779, 589856, 743579, 4311631, 356463], dtype=np.int32)
    weight = num_per_class / float(sum(num_per_class))
    ce_label_weight = 1 / (weight + 0.02)
    return np.expand_dims(ce_label_weight, axis=0)
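For intuition, a short sketch of how inverse-frequency weights like `ce_label_weight` rebalance a per-point cross-entropy loss. This uses plain NumPy as a stand-in for the repo's TensorFlow loss, and the helper below is an assumption, not code from this PR:

```python
# Sketch only: inverse-frequency class weights (same formula as
# get_class_weights) applied to a weighted cross-entropy over points.
import numpy as np

num_per_class = np.array([35391894, 1449308, 4650919, 18252779,
                          589856, 743579, 4311631, 356463], dtype=np.int64)
freq = num_per_class / num_per_class.sum()
class_weights = 1 / (freq + 0.02)  # rare classes (e.g., Fence) get the largest weights

def weighted_ce(probs, labels, weights):
    """probs: (N, C) softmax outputs; labels: (N,) ints; weights: (C,)."""
    point_losses = -np.log(probs[np.arange(len(labels)), labels] + 1e-12)
    return np.mean(weights[labels] * point_losses)
```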
Binary file removed imgs/S3DIS_area2.gif
Binary file not shown.
Binary file removed imgs/S3DIS_area3.gif
Binary file not shown.
Binary file removed imgs/Semantic3D-1.gif
Binary file not shown.
Binary file removed imgs/Semantic3D-3.gif
Binary file not shown.
Binary file removed imgs/Semantic3D-4.gif
Binary file not shown.
Binary file removed imgs/SemanticKITTI-2.gif
Binary file not shown.
14 changes: 0 additions & 14 deletions jobs_6_fold_cv_s3dis.sh

This file was deleted.

12 changes: 0 additions & 12 deletions jobs_test_semantickitti.sh

This file was deleted.
