GradienTape hội tụ chậm hơn nhiều so với Keras.model.fit


8

Tôi hiện đang cố gắng để có được một api TF2.0 , nhưng khi tôi so sánh GradientTape với một máy ảnh thông thường.Model.fit tôi nhận thấy:

  1. Nó chạy chậm hơn (có thể là do Thực thi háo hức)

  2. Nó hội tụ chậm hơn nhiều (và tôi không chắc tại sao).

+--------+--------------+--------------+------------------+
|  Epoch | GradientTape | GradientTape | keras.Model.fit  |
|        |              |  shuffling   |                  |
+--------+--------------+--------------+------------------+
|    1   |     0.905    |     0.918    |      0.8793      |
+--------+--------------+--------------+------------------+
|    2   |     0.352    |     0.634    |      0.2226      |
+--------+--------------+--------------+------------------+
|    3   |     0.285    |     0.518    |      0.1192      |
+--------+--------------+--------------+------------------+
|    4   |     0.282    |     0.458    |      0.1029      |
+--------+--------------+--------------+------------------+
|    5   |     0.275    |     0.421    |      0.0940      |
+--------+--------------+--------------+------------------+

Đây là vòng lặp đào tạo tôi đã sử dụng với GradientTape :


optimizer = keras.optimizers.Adam()
glove_model = GloveModel(vocab_size=len(labels))
train_loss = keras.metrics.Mean(name='train_loss')

@tf.function
def train_step(examples, labels):
    with tf.GradientTape() as tape:
        predictions = glove_model(examples)
        loss = glove_model.glove_loss(labels, predictions)

    gradients = tape.gradient(loss, glove_model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, glove_model.trainable_variables))

    train_loss(loss)



total_step = 0
for epoch in range(epochs_number):

    pbar = tqdm(train_ds.enumerate(), total=int(len(index_data) / batch_size) + 1)

    for ix, (examples, labels) in pbar:

        train_step(examples, labels)


    print(f"Epoch {epoch + 1}, Loss {train_loss.result()}")

    # Reset the metrics for the next epoch
    train_loss.reset_states()

Và đây là khóa đào tạo Keras.Model.fit :

glove_model.compile(optimizer, glove_model.glove_loss)
glove_model.fit(train_ds, epochs=epochs_number)

Đây là nguồn tf.data.Dataset

train_ds = data.Dataset.from_tensor_slices(
    (np.hstack([index_rows.reshape(-1, 1), index_cols.reshape(-1, 1)]), index_data)
).shuffle(100000).batch(batch_size, drop_remainder=True)

Và đây là mô hình.

class GloveModel(keras.Model):

    def __init__(self, vocab_size, dim=100, a=3/4, x_max=100):
        super(GloveModel, self).__init__()

        self.vocab_size = vocab_size
        self.dim = dim
        self.a = a
        self.x_max = x_max

        self.target_embedding = layers.Embedding(
            input_dim=self.vocab_size, output_dim=self.dim, input_length=1, name="target_embedding"
        )
        self.target_bias = layers.Embedding(
            input_dim=self.vocab_size, output_dim=1, input_length=1, name="target_bias"
        )

        self.context_embedding = layers.Embedding(
            input_dim=self.vocab_size, output_dim=self.dim, input_length=1, name="context_embedding"
        )
        self.context_bias = layers.Embedding(
            input_dim=self.vocab_size, output_dim=1, input_length=1, name="context_bias"
        )

        self.dot_product = layers.Dot(axes=-1, name="dot")

        self.prediction = layers.Add(name="add")
        self.step = 0

    def call(self, inputs):

        target_ix = inputs[:, 0]
        context_ix = inputs[:, 1]

        target_embedding = self.target_embedding(target_ix)
        target_bias = self.target_bias(target_ix)

        context_embedding = self.context_embedding(context_ix)
        context_bias = self.context_bias(context_ix)

        dot_product = self.dot_product([target_embedding, context_embedding])
        prediction = self.prediction([dot_product, target_bias, context_bias])

        return prediction

    def glove_loss(self, y_true, y_pred):

        weight = tf.math.minimum(
            tf.math.pow(y_true/self.x_max, self.a), 1.0
        )
        loss_value = tf.math.reduce_mean(weight * tf.math.pow(y_pred - tf.math.log(y_true), 2.0))

        return loss_value


Tôi đã thử nhiều cấu hình và tối ưu hóa nhưng dường như không có gì thay đổi tốc độ hội tụ.


1
Một điều cần xem xét là xáo trộn dữ liệu trước mỗi kỷ nguyên.
THN

Tôi có chính xác sự xáo trộn giống nhau giữa phương thức fit và GradientTape vì tôi sử dụng api tf.Data.
Benjamin Breton

1
Tôi nghĩ rằng chúng không giống hệt nhau. Bạn có thể hiển thị mã của bạn tfds? Lưu ý rằng máy ảnh .fitmặc định để xáo trộn trước mỗi kỷ nguyên. Bạn có thể kiểm tra bằng cách tắt xáo trộn trong máy ảnh và so sánh tốc độ hội tụ của chúng.
THN

@THN Tôi sẽ gửi nó cho bạn, nhưng tôi đã thực hiện một shuffle với tf.Dataset api để nó không thay đổi gì cả phải không?
Benjamin Breton

@THN Tôi đã thêm tf.data.Dataset
Benjamin Breton

Câu trả lời:


2

Dataset.shuffle()chỉ xáo trộn mỗi xe buýt nhỏ, vì vậy mỗi kỷ nguyên có cùng một thứ tự. Keras .fit()sử dụng một số phép thuật để xáo trộn toàn bộ dữ liệu trước mỗi kỷ nguyên. Để thực hiện điều này trong TF, bạn cần sử dụng Bộ dữ liệu .repeat(epochs_number).shuffle(..., reshuffle_each_iteration=True):

train_ds = data.Dataset.from_tensor_slices(
    (np.hstack([index_rows.reshape(-1, 1), index_cols.reshape(-1, 1)]), index_data)
    ).shuffle(100000, reshuffle_each_iteration=True
    ).batch(batch_size, drop_remainder=True
    ).repeat(epochs_number)

for ix, (examples, labels) in train_ds.enumerate():
    train_step(examples, labels)
    current_epoch = ix // (len(index_data) // batch_size)

Cách giải quyết này không đẹp cũng không tự nhiên, hiện tại bạn có thể sử dụng điều này để xáo trộn từng kỷ nguyên. Đây là một vấn đề đã biết và sẽ được khắc phục, trong tương lai bạn có thể sử dụng for epoch in range(epochs_number)thay vì .repeat().


Tôi xin lỗi, tôi đã thêm mã của bạn, nhưng sự hội tụ thậm chí còn chậm hơn. Tôi đã thêm các kết quả trong cột GradientTape shuffle. Nó không có ý nghĩa với tôi ...
Benjamin Breton

@BenjaminBreton Tại thời điểm này, tôi nghi ngờ có một số lỗi khác ẩn trong mã của bạn. Có lẽ tốt nhất là liên kết với repo của bạn để hiển thị mã đầy đủ. Nếu bạn chắc chắn các thí nghiệm của mình được tiến hành chính xác, bạn nên mở một vấn đề về repo dòng chảy.
THN

Cảm ơn bạn rất nhiều vì sự giúp đỡ của bạn @THN Tôi đã đăng vấn đề trên repo TF2.0 github.com/tensorflow/tensorflow/issues/33898 . Tôi sẽ cố gắng tái tạo lỗi với một mô hình khác.
Benjamin Breton

1
Hóa ra bạn đã đúng @THN Tôi đã sử dụng numpy và nó đã giải quyết được vấn đề. Tôi sẽ đăng một câu trả lời toàn diện
Benjamin Breton

0

Vấn đề xuất phát từ việc xáo trộn bằng phương pháp tf.Dataset . Nó chỉ xáo trộn thông qua bộ dữ liệu một xô tại thời điểm đó. Sử dụng Keras.Model.fit mang lại kết quả tốt hơn bởi vì nó có thể thêm một sự xáo trộn khác.

Tôi đã thêm một sự xáo trộn với numpy.random.shufflevà nó đã cải thiện hiệu suất với cả hai phương pháp đào tạo:

Thế hệ của bộ dữ liệu bây giờ là:

numpy_data = np.hstack([index_rows.reshape(-1, 1), index_cols.reshape(-1, 1), index_data.reshape(-1, 1)])

np.random.shuffle(numpy_data)

indexes = np.array(numpy_data[:, :2], dtype=np.uint32)
labels = np.array(numpy_data[:, 2].reshape(-1, 1), dtype=np.float32)

train_ds = data.Dataset.from_tensor_slices(
    (indexes, labels)
).shuffle(100000).batch(batch_size, drop_remainder=True)

Và kết quả là:

+--------+--------------+------------------+
|  Epoch | GradientTape |  keras.Model.fit |
+--------+--------------+------------------+
|    1   |     0.294    |      0.294       |
+--------+--------------+------------------+
|    2   |     0.111    |      0.110       |
+--------+--------------+------------------+
|    3   |     0.089    |      0.089       |
+--------+--------------+------------------+
|    4   |     0.074    |      0.075       |
+--------+--------------+------------------+
|    5   |     0.063    |      0.063       |
+--------+--------------+------------------+

Loại đào tạo trên mỗi epoch gần như giống nhau ở mức 2 phút mỗi epoch .

Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.