Text generation with an RNN - error in the model class #1268
Comments
The error in your code lies in the way you are calling the superclass's constructor. Change this line:

    super().__init__(self)

to this:

    super().__init__()

So your corrected class definition should look like this:

    import tensorflow as tf

    class MyModel(tf.keras.Model):
        def __init__(self, vocab_size, embedding_dim, rnn_units):
            # super() already binds the instance, so do not pass self again.
            super().__init__()
            self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
            self.gru = tf.keras.layers.GRU(rnn_units,
                                           return_sequences=True,
                                           return_state=True)
            self.dense = tf.keras.layers.Dense(vocab_size)

        def call(self, inputs, states=None, return_state=False, training=False):
            x = inputs
            x = self.embedding(x, training=training)
            if states is None:
                states = self.gru.get_initial_state(x)
            x, states = self.gru(x, initial_state=states, training=training)
            x = self.dense(x, training=training)
            if return_state:
                return x, states
            else:
                return x
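For anyone wondering why the extra `self` matters, here is a minimal sketch in plain Python, with no Keras or TensorFlow involved. `Base`, `BrokenChild`, `FixedChild` and `try_build` are hypothetical names used only for illustration; they are not part of the tutorial.

```python
class Base:
    def __init__(self):
        self.ready = True


class BrokenChild(Base):
    def __init__(self):
        # super() already binds the instance, so passing self again
        # hands Base.__init__ a second positional argument.
        super().__init__(self)


class FixedChild(Base):
    def __init__(self):
        # Correct form: the instance is supplied implicitly.
        super().__init__()


def try_build(cls):
    """Return the TypeError message raised while constructing cls, or None."""
    try:
        cls()
    except TypeError as exc:
        return str(exc)
    return None


broken_error = try_build(BrokenChild)
fixed_error = try_build(FixedChild)
print(broken_error)
print(fixed_error)
```

The broken variant fails with the same kind of "takes 1 positional argument but 2 were given" message reported in this issue, while the fixed variant constructs normally.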
Even if this is done, an error is produced further down in the code. Right after creating an object of this class, there's this code in the tutorial:
which produces the following error:
This was produced in Google Colab with TensorFlow 2.17. I don't really know how to solve this. It would be nice if the TensorFlow team updated this tutorial with working code.
After some research, I have found 4 different and independent errors in the code present in the tutorial. I hope these will help. For the last two in particular, I hope other people can find a proper solution and integrate a fix into TensorFlow.

Error found by @808vita

Symptom
    TypeError: Layer.__init__() takes 1 positional argument but 2 were given

Cause
A mistake in the Python code related to Python itself, not Keras or TensorFlow. When calling super().__init__(), self must not be passed explicitly, because super() already binds the instance. Incidentally, this mistake may have been present right from the start, with the Python version with which the tutorial was written.

Solution
Replace
    super().__init__(self)
With
    super().__init__()

This error needs an update to the Text generation with an RNN tutorial.

Error found by @alexdrymonitis

Cause
Breaking change in the Keras API introduced in TensorFlow version 2.16 (according to the documentation). Starting from TensorFlow version 2.16, the method get_initial_state takes the batch size rather than the input tensor. The tutorial hasn't been updated to reflect the API change.

Solution
Replace
    states = self.gru.get_initial_state(x)
With
    states = self.gru.get_initial_state(tf.shape(x)[0])

This error needs an update to the Text generation with an RNN tutorial.

Error found by me

Symptom
Occurs when running on a GPU, but not on a CPU. Using tensorflow==2.17.0 and keras==3.5.0 on Python 3.11, CUDA 12.6.1 and cuDNN 8.9.7.

Cause
Unknown. I assume this behavior is a bug in the RNN implementation on GPU (since it is also present with LSTM).

Workaround
Replace
    x, states = self.gru(x, initial_state=states, training=training)
With
    r = self.gru(x, initial_state=states, training=training)
    x, states = r[0], r[1:]

This error needs a fix in the TensorFlow CUDA code.

Error found by me

Symptom
Occurs when running on a GPU, but not on a CPU. Using tensorflow==2.17.0 and keras==3.5.0 on Python 3.11, CUDA 12.6.1 and cuDNN 8.9.7. This error only happens when training, not when running the model over a tensor extracted from the dataset.

Cause
Unknown. I am still investigating, so far without result. Like the error above, I assume it is a bug in the RNN or GRU CUDA code, most likely related to GRU and not RNN in general, since this issue does not happen with the SimpleRNN or LSTM layers.

Solution or workaround
None found so far. I really hope someone finds one because it's blocking me from working on a related project.
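As a side note on the slicing workaround above (`x, states = r[0], r[1:]`): it splits the result by position, so it does not depend on how many state entries the layer call returns. Below is a rough sketch in plain Python. The tuples are hypothetical stand-ins, not real Keras outputs, and `split_output_and_states` is a name made up for this example.

```python
def split_output_and_states(result):
    """Split a layer result into (output, states) purely by position."""
    # result[0] is treated as the output; everything after it as states.
    return result[0], list(result[1:])


# Hypothetical stand-ins for what a recurrent layer call might return.
one_state = ("output", "state_h")
two_states = ("output", "state_h", "state_c")

out1, states1 = split_output_and_states(one_state)
out2, states2 = split_output_and_states(two_states)
print(states1)
print(states2)

# Plain two-name unpacking only works when there are exactly two entries.
try:
    x, s = two_states
    plain_unpack_ok = True
except ValueError:
    plain_unpack_ok = False
print(plain_unpack_ok)
```

Note that with this pattern `states` is always a sequence, even when there is only one state, so any code that later consumes `states` has to accept a list rather than a single tensor.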
TypeError: Layer.__init__() takes 1 positional argument but 2 were given