Làm thế nào là máy phát điện trong một GAN được đào tạo?


9

Bài viết trên GAN cho biết người phân biệt đối xử sử dụng gradient sau để huấn luyện:

θd1mΣTôi= =1m[đăng nhậpD(x(Tôi))+đăng nhập(1-D(G(z(Tôi))))]

Các giá trị được lấy mẫu, được chuyển qua bộ tạo để tạo các mẫu dữ liệu và sau đó bộ phân biệt được sao lưu bằng cách sử dụng các mẫu dữ liệu được tạo. Một khi trình tạo tạo dữ liệu, nó không đóng vai trò gì nữa trong việc đào tạo người phân biệt đối xử. Nói cách khác, trình tạo có thể được loại bỏ hoàn toàn khỏi số liệu bằng cách tạo ra các mẫu dữ liệu và sau đó chỉ làm việc với các mẫu.z

Tôi hơi bối rối hơn về cách máy phát điện được đào tạo mặc dù. Nó sử dụng gradient sau:

θg1mΣTôi= =1m[đăng nhập(1-D(G(z(Tôi))))]

Trong trường hợp này, người phân biệt đối xử là một phần của số liệu. Nó không thể được gỡ bỏ như trường hợp trước. Những thứ như bình phương tối thiểu hoặc khả năng đăng nhập trong các mô hình phân biệt đối xử thông thường có thể dễ dàng được phân biệt bởi vì chúng có một định nghĩa gần đúng, đẹp. Tuy nhiên, tôi hơi bối rối về cách bạn backpropogate khi số liệu phụ thuộc vào mạng thần kinh khác. Về cơ bản, bạn có gắn các đầu ra của máy phát điện vào các đầu vào của bộ phân biệt đối xử và sau đó coi toàn bộ mọi thứ giống như một mạng khổng lồ trong đó các trọng số trong phần phân biệt không đổi?

Câu trả lời:


10

Nó giúp nghĩ về quá trình này trong mã giả. Hãy generator(z)là một hàm lấy một vectơ nhiễu được lấy mẫu đồng đều zvà trả về một vectơ có cùng kích thước với vectơ đầu vào X; Hãy gọi chiều dài này d. Hãy discriminator(x)là một hàm lấy một dvectơ chiều và trả về xác suất vô hướng xthuộc về phân phối dữ liệu thực. Cho tập huấn:

G_sample = generator(Z)
D_real = discriminator(X)
D_fake = discriminator(G_sample)

D_loss = maximize mean of (log(D_real) + log(1 - D_fake))
G_loss = maximize mean of log(D_fake)

# Only update D(X)'s parameters
D_solver = Optimizer().minimize(D_loss, theta_D)
# Only update G(X)'s parameters
G_solver = Optimizer().minimize(G_loss, theta_G)

# theta_D and theta_G are the weights and biases of D and G respectively
Repeat the above for a number of epochs

Vì vậy, vâng, bạn đúng khi về cơ bản chúng tôi nghĩ về trình tạo và phân biệt đối xử như một mạng khổng lồ để xen kẽ các xe buýt nhỏ khi chúng tôi sử dụng dữ liệu giả mạo. Hàm mất của máy phát sẽ chăm sóc độ dốc cho nửa này. Nếu bạn nghĩ đến việc đào tạo mạng này một cách cô lập, thì nó được đào tạo giống như bạn thường huấn luyện MLP với đầu vào là đầu ra của lớp cuối cùng của mạng máy phát.

Bạn có thể theo dõi một lời giải thích chi tiết với mã trong Tensorflow tại đây (trong số nhiều nơi): http://wiseodd.github.io/techblog/2016/09/17/gan-tensorflow/

Nó sẽ dễ dàng để làm theo một khi bạn nhìn vào mã.


1
Bạn có thể giải thích trên D_lossG_loss? Tối đa hóa không gian gì? IIUC D_realD_fakemỗi đợt là một đợt, vì vậy chúng tôi sẽ tối đa hóa theo đợt ??
P i

@Pi Vâng, chúng tôi đang tối đa hóa một đợt.
tejaskhot

1

Về cơ bản, bạn có gắn các đầu ra của máy phát điện vào các đầu vào của bộ phân biệt đối xử không?> Và sau đó coi toàn bộ mọi thứ giống như một mạng khổng lồ trong đó các trọng số trong phần> phân biệt đối xử là không đổi?

Ngắn gọn: Có. (Tôi đã đào một số nguồn của GAN để kiểm tra lại điều này)

Ngoài ra còn có nhiều hơn nữa về đào tạo GAN như: chúng ta nên cập nhật D và G mỗi lần hoặc D trên các lần lặp lẻ và G trên chẵn, và nhiều hơn nữa. Ngoài ra còn có một bài viết rất hay về chủ đề này:

"Cải tiến kỹ thuật cho GAN đào tạo"

Tim Salimans, Ian Goodfellow, Wojciech Zaremba, Vicki Cheung, Alec Radford, Xi Chen

https://arxiv.org/abs/1606.03498


Bạn có thể vui lòng cung cấp các liên kết đến các nguồn bạn nhìn vào? Nó sẽ hữu ích cho tôi để đọc chúng.
Vivek Subramanian

0

Gần đây tôi đã tải lên bộ sưu tập các mô hình GAN khác nhau trên repo github. Nó dựa trên Torch7, và rất dễ chạy. Mã này đủ đơn giản để hiểu với kết quả thử nghiệm. Hy vọng điều này sẽ giúp

https://github.com/nashory/gans-collection.torch

Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.