[BUG] tensordict.TensorDict and tensordict.nn.make_tensordict can't handle dictionaries with non-string keys #746

Open
3 tasks done
Bhartendu-Kumar opened this issue Apr 24, 2024 · 4 comments
Assignees
Labels
bug Something isn't working

Comments

@Bhartendu-Kumar

Bhartendu-Kumar commented Apr 24, 2024

Describe the bug

Both tensordict.TensorDict and tensordict.nn.make_tensordict expect a dictionary as input.
Passing a dictionary with non-string keys to either of them raises an error: IndexError: tuple index out of range

To Reproduce

import torch
from tensordict import TensorDict
d = {1: torch.randn(2), 2: torch.randn(2)}
d = TensorDict(d, batch_size=2)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/_td.py", line 236, in __init__
    self.set(key, value, non_blocking=non_blocking)
  File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/base.py", line 2315, in set
    return self._set_tuple(
  File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/_td.py", line 1615, in _set_tuple
    td = self._get_str(key[0], None)
IndexError: tuple index out of range
>>> from tensordict.nn import make_tensordict
>>> d = make_tensordict(d)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/functional.py", line 379, in make_tensordict
    return TensorDict.from_dict(kwargs, batch_size=batch_size, device=device)
  File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/_td.py", line 1332, in from_dict
    out = cls(
  File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/_td.py", line 236, in __init__
    self.set(key, value, non_blocking=non_blocking)
  File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/base.py", line 2315, in set
    return self._set_tuple(
  File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/_td.py", line 1615, in _set_tuple
    td = self._get_str(key[0], None)
IndexError: tuple index out of range

Expected behavior

When the dictionary has string keys, a Python dictionary is converted to a TensorDict, e.g.

d = {"1": torch.randn(2), "2": torch.randn(2)}
d = TensorDict(d, batch_size=2)

This works as expected, but when the keys are not strings, as in

d = {1: torch.randn(2), 2: torch.randn(2)}
d = TensorDict(d, batch_size=2)

it raises an error.
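
Until the keys are handled differently, one possible workaround is to convert the keys to strings before building the TensorDict. The helper below is only a sketch: stringify_keys is a hypothetical name, not part of tensordict, and plain lists stand in for tensors so that it runs without torch. It converts every non-string key with str(), so tuple keys are not treated specially.

```python
from typing import Any, Dict


def stringify_keys(data: Dict[Any, Any]) -> Dict[str, Any]:
    """Return a copy of `data` with every key converted to str, recursing into nested dicts."""
    out: Dict[str, Any] = {}
    for key, value in data.items():
        new_key = key if isinstance(key, str) else str(key)
        if new_key in out:
            # e.g. {1: ..., "1": ...} would otherwise silently drop one entry
            raise KeyError(f"key collision after conversion: {new_key!r}")
        out[new_key] = stringify_keys(value) if isinstance(value, dict) else value
    return out


# Plain lists stand in for tensors so the sketch runs without torch.
converted = stringify_keys({1: [0.1, 0.2], 2: {3: [0.3, 0.4]}})
print(converted)
```

The converted dictionary can then be passed to TensorDict(converted, batch_size=2) in the same way as the string-key example above.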

System info

Describe the characteristic of your environment:

  • Describe how the library was installed (pip, source, ...): python -m pip install tensordict==0.3.2
  • Python version: Python 3.8.13
  • Versions of any other relevant libraries: pytorch:2.2.2+cu121
import tensordict, numpy, sys, torch
print(tensordict.__version__, numpy.__version__, sys.version, sys.platform, torch.__version__)
0.3.2 1.22.4 3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:04:10)
[GCC 10.3.0] linux 2.2.2+cu121

Additional context

Reason and Possible fixes

I think the code at an abstract level works in 2 steps:

  1. Step 1: Get the length of keys of the given input dictionary
  2. Step 2: Get the string keys and construct tensordict object from these keys

Thus, the culprit might be

tensordict/_td.py:1615, in TensorDict._set_tuple(self, key, value, inplace, validated, non_blocking)
    if len(key) == 1:
           return self._set_str(

which calls

td = self._get_str(key[0], None)

So what is happening is a lookup that assumes string keys, while the keys might not be strings.

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
Bhartendu-Kumar added the bug label Apr 24, 2024
@vmoens
Contributor

vmoens commented Apr 24, 2024

Hello
Thanks for posting this!

TensorDict requires keys to be strings, tuples of strings, tuples of tuples of strings, etc.; no other key type is allowed.

The main reason is that tensordicts can also be indexed along the "shape" dimension, and allowing other key-types (e.g. ints) would lead to undefined behaviours.
Example

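import torch
from tensordict import TensorDict

data = TensorDict({"a": torch.arange(3)}, batch_size=[3])
data[1]  # second element along the batch dimension, where "a" is tensor(1)
data = TensorDict({1: torch.arange(3)}, batch_size=[3])  # hypothetical: int keys are not accepted
data[1]  # should this take the second element along the shape dimension, or the '1' key?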
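
The ambiguity can be seen in a toy container that dispatches on the index type. This is only an illustrative sketch in plain Python: ToyBatchDict is a made-up class and does not reflect how tensordict is implemented. String indices select an entry and integer indices select a batch element, so an integer key would leave no way to tell the two apart.

```python
from typing import Dict, List, Union


class ToyBatchDict:
    """Toy container: string keys select an entry, integers select a batch element."""

    def __init__(self, entries: Dict[str, List[int]]) -> None:
        self.entries = entries

    def __getitem__(self, index: Union[str, int]):
        if isinstance(index, str):
            # Key lookup: return the whole entry stored under this name.
            return self.entries[index]
        if isinstance(index, int):
            # Batch indexing: take element `index` of every entry.
            return {k: v[index] for k, v in self.entries.items()}
        raise TypeError(f"unsupported index type: {type(index).__name__}")


data = ToyBatchDict({"a": [0, 1, 2]})
print(data["a"])  # whole entry stored under "a"
print(data[1])    # element 1 of every entry
```

If an entry could be stored under the integer key 1, data[1] would match both meanings, which is why the index type has to decide the behaviour.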

That being said we should probably capture this error to make things clearer for our users!

Hope that helps

@Bhartendu-Kumar
Author

Oh!
Makes sense. Thanks for the reply.
But the error

IndexError: tuple index out of range

is still not informative enough to tell that the problem lies with the types of the dictionary keys.

So I think there should be a check that raises an appropriate error message about the expected key types, rather than an index-out-of-range error.

@Bhartendu-Kumar
Author

Earlier, when a value was anything other than a tensordict, dictionary, scalar or tensor, the library explicitly raised an error saying that the data type of the value is outside this set.

So I think something similar for keys would be beneficial.

Should I go ahead and add the type checking for this, if you confirm that the keys can only be strings, tuples of strings, and so on?
Thanks
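
A check of this kind could look roughly like the sketch below. It assumes the rule stated earlier in the thread (a key is a string, or a tuple whose elements are strings or further tuples of strings) and additionally assumes that an empty tuple should be rejected. check_key is a hypothetical name, not an existing tensordict function, and the wording of the message is only a suggestion.

```python
from typing import Any


def check_key(key: Any) -> None:
    """Raise TypeError unless `key` is a str or a non-empty (possibly nested) tuple of str."""
    if isinstance(key, str):
        return
    if isinstance(key, tuple) and len(key) > 0:
        # Every element must itself be a valid key (str or nested tuple).
        for part in key:
            check_key(part)
        return
    raise TypeError(
        "expected a key of type str or a (nested) tuple of str, "
        f"got {key!r} of type {type(key).__name__}"
    )


# Try a few valid and invalid candidates and report the outcome for each.
for candidate in ("a", ("a", "b"), (("a", "b"), "c"), 1, ("a", 2), ()):
    try:
        check_key(candidate)
        print(f"{candidate!r}: ok")
    except TypeError as err:
        print(f"{candidate!r}: {err}")
```

Running such a check on each key before it is set would replace the IndexError with a message that names the offending key and its type.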

@vmoens
Contributor

vmoens commented Jun 27, 2024

I think #826 could be a workaround (allows you to store data as tensordicts using any kind of key - even another tensordict)
