Update transformer.py (#210)
* Update transformer.py

I get a NumPy CUDA-to-CPU conversion error because `batches` is moved to `device` before `collate()` is called; `collate()` then tries to convert a CUDA tensor to a NumPy array.

```
Traceback (most recent call last):
  File "/local-scratch-nvme/nigam/ehrshot/lib/python3.10/pdb.py", line 1723, in main
    pdb._runscript(mainpyfile)
  File "/local-scratch-nvme/nigam/ehrshot/lib/python3.10/pdb.py", line 1583, in _runscript
    self.run(statement)
  File "/local-scratch-nvme/nigam/ehrshot/lib/python3.10/bdb.py", line 598, in run
    exec(cmd, globals, locals)
  File "<string>", line 1, in <module>
  File "/share/pi/nigam/mwornow/ehrshot-benchmark/ehrshot/4_generate_clmbr_features.py", line 123, in <module>
    results: Dict[str, Any] = compute_features(dataset, model_name, labels, ontology=None, num_proc=num_threads, tokens_per_batch=tokens_per_batch, device=device)
  File "/share/pi/nigam/mwornow/ehrshot-benchmark/ehrshot/4_generate_clmbr_features.py", line 69, in compute_features
    batch = processor.collate([batch])["batch"]
  File "/local-scratch-nvme/nigam/ehrshot/lib/python3.10/site-packages/femr/models/processor.py", line 401, in collate
    return {"batch": _add_dimension(self.creator.cleanup_batch(batches[0]))}
  File "/local-scratch-nvme/nigam/ehrshot/lib/python3.10/site-packages/femr/models/processor.py", line 322, in cleanup_batch
    batch["transformer"]["patient_lengths"] = np.array(batch["transformer"]["patient_lengths"])
  File "/local-scratch-nvme/nigam/ehrshot/lib/python3.10/site-packages/torch/_tensor.py", line 1062, in __array__
    return self.numpy()
TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
```

This fixes that by moving the batch to `device` only after `collate()` has run.
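The ordering issue can be sketched in isolation. This is a minimal illustration rather than the femr code: `cleanup_on_host` is a hypothetical stand-in for the NumPy conversion that the traceback shows inside `cleanup_batch`, and the batch contents are made up.

```python
import numpy as np
import torch


def cleanup_on_host(batch):
    """Hypothetical stand-in for the NumPy step that collate() performs."""
    # np.array() goes through Tensor.__array__, which calls Tensor.numpy().
    # That only works for tensors in host (CPU) memory.
    batch["patient_lengths"] = np.array(batch["patient_lengths"])
    return batch


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Keep the batch on the CPU while the NumPy-based cleanup runs ...
batch = {"patient_lengths": torch.tensor([3, 5, 2]), "tokens": torch.arange(10)}
batch = cleanup_on_host(batch)

# ... and only then transfer the remaining tensors to the target device.
batch = {
    key: val.to(device) if isinstance(val, torch.Tensor) else val
    for key, val in batch.items()
}
```

If the tensors were already on a CUDA device when `cleanup_on_host` ran, the `np.array()` call would raise the `TypeError` shown in the traceback above.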

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Miking98 and pre-commit-ci[bot] authored Apr 25, 2024
1 parent dd1baba commit 0df7121
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/femr/models/transformer.py
@@ -363,14 +363,23 @@ def compute_features(
         filtered_data, tokens_per_batch=tokens_per_batch, min_patients_per_batch=1, num_proc=num_proc
     )

-    batches.set_format("pt", device=device)
+    batches.set_format("pt")

     all_patient_ids = []
     all_feature_times = []
     all_representations = []

     for batch in tqdm(batches, total=len(batches)):
         batch = processor.collate([batch])["batch"]
+
+        # Move to device
+        for key, val in batch.items():
+            if isinstance(val, torch.Tensor):
+                batch[key] = batch[key].to(device)
+        for key, val in batch["transformer"].items():
+            if isinstance(val, torch.Tensor):
+                batch["transformer"][key] = batch["transformer"][key].to(device)
+
         with torch.no_grad():
             _, result = model(batch, return_reprs=True)
             all_patient_ids.append(result["patient_ids"].cpu().numpy())
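
The two loops in this hunk cover the top-level batch and the nested `batch["transformer"]` dict. As a possible generalization (not part of this commit), a recursive helper could handle any nesting depth. `move_to_device` is a hypothetical name, and the sketch assumes the batch consists only of dicts, tensors, and plain values.

```python
from typing import Any

import torch


def move_to_device(obj: Any, device: torch.device) -> Any:
    """Return a copy of obj with every tensor moved to device.

    Hypothetical helper: recurses through nested dicts and leaves
    non-tensor values (NumPy arrays, ints, strings) untouched.
    """
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, dict):
        return {key: move_to_device(val, device) for key, val in obj.items()}
    return obj


# Made-up batch with the same two-level shape as in the diff.
example = {
    "patient_ids": torch.tensor([1, 2]),
    "transformer": {"tokens": torch.arange(4)},
    "note": "unchanged",
}
moved = move_to_device(example, torch.device("cpu"))
```

Unlike the in-place loops above, this version builds new dicts, so the original batch is left unmodified.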