Skip to content

Commit

Permalink
debug scatter
Browse files Browse the repository at this point in the history
  • Loading branch information
akihironitta committed Nov 26, 2023
1 parent ed490ec commit d335f06
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 7 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/full_gpu_testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ on: # yamllint disable-line rule:truthy
workflow_dispatch:
schedule:
- cron: "0 6 * * *" # Everyday at 6:00am UTC/10:00pm PST
pull_request:

jobs:

Expand All @@ -29,7 +30,7 @@ jobs:
pip install -e .[full,test]
- name: Run tests
timeout-minutes: 20
timeout-minutes: 200
run: |
FULL_TEST=1 pytest
shell: bash
3 changes: 2 additions & 1 deletion .github/workflows/full_testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ on: # yamllint disable-line rule:truthy
workflow_dispatch:
schedule:
- cron: "0 6 * * *" # Everyday at 6:00am UTC/10:00pm PST
pull_request:

jobs:

Expand Down Expand Up @@ -50,7 +51,7 @@ jobs:
pip install -e .[full,test]
- name: Run tests
timeout-minutes: 20
timeout-minutes: 200
run: |
FULL_TEST=1 pytest --cov --cov-report=xml
shell: bash
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
- name: Run tests
if: steps.changed-files-specific.outputs.only_changed != 'true'
timeout-minutes: 10
timeout-minutes: 100
run: |
pytest --cov --cov-report=xml --durations 10
Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
WITH_SEGMM = False
WITH_SAMPLED_OP = hasattr(pyg_lib.ops, 'sampled_add')
WITH_SOFTMAX = hasattr(pyg_lib.ops, 'softmax_csr')
WITH_INDEX_SORT = hasattr(pyg_lib.ops, 'index_sort')
WITH_INDEX_SORT = False
WITH_METIS = hasattr(pyg_lib, 'partition')
WITH_EDGE_TIME_NEIGHBOR_SAMPLE = ('edge_time' in inspect.signature(
pyg_lib.sampler.neighbor_sample).parameters)
Expand All @@ -64,7 +64,7 @@

try:
import torch_scatter # noqa
WITH_TORCH_SCATTER = True
WITH_TORCH_SCATTER = False
except Exception as e:
if not isinstance(e, ImportError): # pragma: no cover
warnings.warn(f"An issue occurred while importing 'torch-scatter'. "
Expand Down
10 changes: 8 additions & 2 deletions torch_geometric/utils/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ def scatter(src: Tensor, index: Tensor, dim: int = 0,
f" package, but it was not found")

index = broadcast(index, src, dim)
return src.new_zeros(size).scatter_reduce_(
dim, index, src, reduce=f'a{reduce}', include_self=False)
return _scatter_min_or_max(src, index, dim, size, reduce)

return torch_scatter.scatter(src, index, dim, dim_size=dim_size,
reduce=reduce)
Expand All @@ -117,6 +116,13 @@ def scatter(src: Tensor, index: Tensor, dim: int = 0,

raise ValueError(f"Encountered invalid `reduce` argument '{reduce}'")


@torch._dynamo.optimize()
def _scatter_min_or_max(src: Tensor, index: Tensor, dim: int,
size: int, reduce: str):
return src.new_zeros(size).scatter_reduce_(
dim, index, src, reduce=f'a{reduce}', include_self=False)

else: # pragma: no cover

def scatter(src: Tensor, index: Tensor, dim: int = 0,
Expand Down

0 comments on commit d335f06

Please sign in to comment.