Hàm tf.nn.embpping_lookup làm gì?


158
tf.nn.embedding_lookup(params, ids, partition_strategy='mod', name=None)

Tôi không thể hiểu nhiệm vụ của chức năng này. Nó giống như một bảng tra cứu? Có nghĩa là trả về các tham số tương ứng với mỗi id (tính bằng id)?

Chẳng hạn, trong skip-grammô hình nếu chúng ta sử dụng tf.nn.embedding_lookup(embeddings, train_inputs), thì với mỗi mô hình train_inputnó có tìm thấy sự nhúng tương ứng không?


"Nó giống như một bảng tra cứu?" tldr - Có. Với mỗi x (id) hãy cho tôi liên kết y (params).
David Refaeli

Câu trả lời:


147

embedding_lookuphàm lấy các hàng của paramstenxơ. Hành vi tương tự như sử dụng lập chỉ mục với các mảng trong numpy. Ví dụ

matrix = np.random.random([1024, 64])  # 64-dimensional embeddings
ids = np.array([0, 5, 17, 33])
print matrix[ids]  # prints a matrix of shape [4, 64] 

paramsđối số cũng có thể là một danh sách các tenxơ trong trường hợp đó idssẽ được phân phối giữa các tenxơ. Ví dụ, đưa ra một danh sách 3 tensors [2, 64], hành vi mặc định là họ sẽ đại diện ids: [0, 3], [1, 4], [2, 5].

partition_strategykiểm soát cách thức idsphân phối giữa các danh sách. Việc phân vùng rất hữu ích cho các vấn đề quy mô lớn hơn khi ma trận có thể quá lớn để giữ thành một mảnh.


21
Tại sao họ gọi nó theo cách này mà không phải select_rows?
Lenar Hoyt

12
@LenarHoyt vì ý tưởng tra cứu này xuất phát từ Word Embeddings. và "hàng" là các đại diện (nhúng) của các từ, vào một không gian vectơ - và rất hữu ích trong chính chúng. Thường là nhiều hơn so với mạng thực tế.
Lyndon White

2
Làm thế nào để tenorflow tìm hiểu cấu trúc nhúng? Liệu chức năng này có quản lý quá trình đó quá không?
vgoklani

19
@vgoklani, không, embedding_lookupchỉ đơn giản là cung cấp một cách thuận tiện (và song song) để truy xuất các nhúng tương ứng với id trong ids. Các paramstensor thường là một biến tf được học như một phần của quá trình đào tạo - một biến tf mà các thành phần được sử dụng trực tiếp hoặc gián tiếp, trong một chức năng mất (như tf.l2_loss) được tối ưu hóa bởi một ưu (như tf.train.AdamOptimizer).
Shobhit

5
@ Rafał Józefowicz Tại sao "hành vi mặc định là họ sẽ đại diện cho id: [0, 3], [1, 4], [2, 5]."? Bạn có thể giải thích?
Aerin

219

Vâng, chức năng này là khó hiểu, cho đến khi bạn nhận được điểm.

Ở dạng đơn giản nhất, nó tương tự như tf.gather. Nó trả về các phần tử paramstheo các chỉ mục được chỉ định bởi ids.

Ví dụ (giả sử bạn đang ở trong tf.InteractiveSession())

params = tf.constant([10,20,30,40])
ids = tf.constant([0,1,2,3])
print tf.nn.embedding_lookup(params,ids).eval()

sẽ trả về [10 20 30 40], bởi vì phần tử đầu tiên (chỉ số 0) của params là 10, phần tử thứ hai của params (index 1) là 20, v.v.

Tương tự

params = tf.constant([10,20,30,40])
ids = tf.constant([1,1,3])
print tf.nn.embedding_lookup(params,ids).eval()

sẽ quay trở lại [20 20 40].

Nhưng embedding_lookupcòn hơn thế nữa. Đối paramssố có thể là một danh sách các tenxơ, chứ không phải là một tenxơ đơn.

params1 = tf.constant([1,2])
params2 = tf.constant([10,20])
ids = tf.constant([2,0,2,1,2,3])
result = tf.nn.embedding_lookup([params1, params2], ids)

Trong trường hợp như vậy, các chỉ mục, được chỉ định trong ids, tương ứng với các phần tử của tenxơ theo chiến lược phân vùng , trong đó chiến lược phân vùng mặc định là 'mod'.

Trong chiến lược 'mod', chỉ số 0 tương ứng với phần tử đầu tiên của tenxơ đầu tiên trong danh sách. Chỉ số 1 tương ứng với phần tử đầu tiên của tenxơ thứ hai . Chỉ số 2 tương ứng với phần tử đầu tiên của tenxơ thứ ba , v.v. Chỉ mục đơn giản itương ứng với phần tử đầu tiên của tenor thứ (i + 1), đối với tất cả các chỉ mục 0..(n-1), giả sử params là một danh sách các ntenxơ.

Bây giờ, chỉ mục nkhông thể tương ứng với tenor n + 1, vì danh sách paramschỉ chứa ncác tenxơ. Vì vậy, chỉ số ntương ứng với các yếu tố thứ hai của tenor đầu tiên. Tương tự, chỉ số n+1tương ứng với phần tử thứ hai của tenxơ thứ hai, v.v.

Vì vậy, trong mã

params1 = tf.constant([1,2])
params2 = tf.constant([10,20])
ids = tf.constant([2,0,2,1,2,3])
result = tf.nn.embedding_lookup([params1, params2], ids)

chỉ số 0 tương ứng với phần tử đầu tiên của tenxơ thứ nhất: 1

chỉ số 1 tương ứng với phần tử thứ nhất của tenxơ thứ hai: 10

chỉ số 2 tương ứng với phần tử thứ hai của tenxơ thứ nhất: 2

chỉ số 3 tương ứng với phần tử thứ hai của tenxơ thứ hai: 20

Do đó, kết quả sẽ là:

[ 2  1  2 10  2 20]

8
một lưu ý: bạn có thể sử dụng partition_strategy='div'và sẽ nhận được [10, 1, 10, 2, 10, 20], tức id=1là phần tử thứ hai của param đầu tiên. Về cơ bản: partition_strategy=mod(mặc định) id%len(params): chỉ mục của param trong params id//len(params): chỉ mục của phần tử theo thông số trên partition_strategy=*div*theo cách khác
Mario Alemi

3
@ asher-stern bạn có thể giải thích tại sao chiến lược "mod" là mặc định không? dường như chiến lược "div" tương tự như cắt lát căng tiêu chuẩn (hàng chọn theo chỉ số đã cho). Có một số vấn đề hiệu suất trong trường hợp "div"?
svetlov.vsevolod

46

Có, mục đích của tf.nn.embedding_lookup()chức năng là thực hiện tra cứu trong ma trận nhúng và trả về các nhúng (hoặc theo thuật ngữ đơn giản là biểu diễn vectơ) của các từ.

Một ma trận nhúng đơn giản (có hình dạng vocabulary_size x embedding_dimension:) sẽ trông như dưới đây. (tức là mỗi từ sẽ được biểu thị bằng một vectơ số; do đó tên word2vec )


Ma trận nhúng

the 0.418 0.24968 -0.41242 0.1217 0.34527 -0.044457 -0.49688 -0.17862
like 0.36808 0.20834 -0.22319 0.046283 0.20098 0.27515 -0.77127 -0.76804
between 0.7503 0.71623 -0.27033 0.20059 -0.17008 0.68568 -0.061672 -0.054638
did 0.042523 -0.21172 0.044739 -0.19248 0.26224 0.0043991 -0.88195 0.55184
just 0.17698 0.065221 0.28548 -0.4243 0.7499 -0.14892 -0.66786 0.11788
national -1.1105 0.94945 -0.17078 0.93037 -0.2477 -0.70633 -0.8649 -0.56118
day 0.11626 0.53897 -0.39514 -0.26027 0.57706 -0.79198 -0.88374 0.30119
country -0.13531 0.15485 -0.07309 0.034013 -0.054457 -0.20541 -0.60086 -0.22407
under 0.13721 -0.295 -0.05916 -0.59235 0.02301 0.21884 -0.34254 -0.70213
such 0.61012 0.33512 -0.53499 0.36139 -0.39866 0.70627 -0.18699 -0.77246
second -0.29809 0.28069 0.087102 0.54455 0.70003 0.44778 -0.72565 0.62309 

Tôi chia ma trận nhúng ở trên và chỉ tải các từ trong vocabđó sẽ là từ vựng của chúng tôi và các vectơ tương ứng trong embmảng.

vocab = ['the','like','between','did','just','national','day','country','under','such','second']

emb = np.array([[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.044457, -0.49688, -0.17862],
   [0.36808, 0.20834, -0.22319, 0.046283, 0.20098, 0.27515, -0.77127, -0.76804],
   [0.7503, 0.71623, -0.27033, 0.20059, -0.17008, 0.68568, -0.061672, -0.054638],
   [0.042523, -0.21172, 0.044739, -0.19248, 0.26224, 0.0043991, -0.88195, 0.55184],
   [0.17698, 0.065221, 0.28548, -0.4243, 0.7499, -0.14892, -0.66786, 0.11788],
   [-1.1105, 0.94945, -0.17078, 0.93037, -0.2477, -0.70633, -0.8649, -0.56118],
   [0.11626, 0.53897, -0.39514, -0.26027, 0.57706, -0.79198, -0.88374, 0.30119],
   [-0.13531, 0.15485, -0.07309, 0.034013, -0.054457, -0.20541, -0.60086, -0.22407],
   [ 0.13721, -0.295, -0.05916, -0.59235, 0.02301, 0.21884, -0.34254, -0.70213],
   [ 0.61012, 0.33512, -0.53499, 0.36139, -0.39866, 0.70627, -0.18699, -0.77246 ],
   [ -0.29809, 0.28069, 0.087102, 0.54455, 0.70003, 0.44778, -0.72565, 0.62309 ]])


emb.shape
# (11, 8)

Tra cứu nhúng trong TensorFlow

Bây giờ chúng ta sẽ xem làm thế nào chúng ta có thể thực hiện tra cứu nhúng cho một số câu đầu vào tùy ý.

In [54]: from collections import OrderedDict

# embedding as TF tensor (for now constant; could be tf.Variable() during training)
In [55]: tf_embedding = tf.constant(emb, dtype=tf.float32)

# input for which we need the embedding
In [56]: input_str = "like the country"

# build index based on our `vocabulary`
In [57]: word_to_idx = OrderedDict({w:vocab.index(w) for w in input_str.split() if w in vocab})

# lookup in embedding matrix & return the vectors for the input words
In [58]: tf.nn.embedding_lookup(tf_embedding, list(word_to_idx.values())).eval()
Out[58]: 
array([[ 0.36807999,  0.20834   , -0.22318999,  0.046283  ,  0.20097999,
         0.27515   , -0.77126998, -0.76804   ],
       [ 0.41800001,  0.24968   , -0.41242   ,  0.1217    ,  0.34527001,
        -0.044457  , -0.49687999, -0.17862   ],
       [-0.13530999,  0.15485001, -0.07309   ,  0.034013  , -0.054457  ,
        -0.20541   , -0.60086   , -0.22407   ]], dtype=float32)

Quan sát cách chúng tôi có các phần nhúng từ ma trận nhúng ban đầu (có từ) bằng cách sử dụng các chỉ số của từ trong từ vựng của chúng tôi.

Thông thường, việc tra cứu nhúng như vậy được thực hiện bởi lớp đầu tiên (được gọi là lớp Nhúng ) sau đó chuyển các nhúng này sang các lớp RNN / LSTM / GRU để xử lý thêm.


Lưu ý bên lề : Thông thường từ vựng cũng sẽ có unkmã thông báo đặc biệt . Vì vậy, nếu mã thông báo từ câu đầu vào của chúng tôi không có trong từ vựng của chúng tôi, thì chỉ mục tương ứng unksẽ được tra cứu trong ma trận nhúng.


PS Lưu ý rằng đó embedding_dimensionlà một siêu tham số mà người ta phải điều chỉnh cho ứng dụng của mình nhưng các mô hình phổ biến như Word2VecGloVe sử dụng 300vectơ kích thước để thể hiện mỗi từ.

Thưởng đọc mô hình bỏ qua word2vec


17

Đây là một hình ảnh mô tả quá trình nhúng tra cứu.

Hình: Quá trình tra cứu nhúng

Chính xác, nó nhận được các hàng tương ứng của một lớp nhúng, được chỉ định bởi một danh sách ID và cung cấp đó như là một tenxơ. Nó đạt được thông qua quá trình sau đây.

  1. Xác định một giữ chỗ lookup_ids = tf.placeholder([10])
  2. Xác định một lớp nhúng embeddings = tf.Variable([100,10],...)
  3. Xác định hoạt động kéo căng embed_lookup = tf.embedding_lookup(embeddings, lookup_ids)
  4. Nhận kết quả bằng cách chạy lookup = session.run(embed_lookup, feed_dict={lookup_ids:[95,4,14]})

6

Khi tenx params ở kích thước cao, id chỉ đề cập đến kích thước trên cùng. Có thể nó rõ ràng với hầu hết mọi người nhưng tôi phải chạy đoạn mã sau để hiểu rằng:

embeddings = tf.constant([[[1,1],[2,2],[3,3],[4,4]],[[11,11],[12,12],[13,13],[14,14]],
                          [[21,21],[22,22],[23,23],[24,24]]])
ids=tf.constant([0,2,1])
embed = tf.nn.embedding_lookup(embeddings, ids, partition_strategy='div')

with tf.Session() as session:
    result = session.run(embed)
    print (result)

Chỉ cần thử chiến lược 'div' và cho một tenor, nó không có gì khác biệt.

Đây là đầu ra:

[[[ 1  1]
  [ 2  2]
  [ 3  3]
  [ 4  4]]

 [[21 21]
  [22 22]
  [23 23]
  [24 24]]

 [[11 11]
  [12 12]
  [13 13]
  [14 14]]]

3

Một cách khác để xem xét nó là, giả sử rằng bạn làm phẳng các tenxơ thành một mảng một chiều, và sau đó bạn đang thực hiện tra cứu

(ví dụ) Tensor0 = [1,2,3], Tensor1 = [4,5,6], Tensor2 = [7,8,9]

Các tenxơ phẳng sẽ như sau [1,4,7,2,5,8,3,6,9]

Bây giờ khi bạn thực hiện tra cứu [0,3,4,1,7], nó sẽ có hiệu lực [1,2,5,4,6]

(i, e) nếu giá trị tra cứu là 7 chẳng hạn, và chúng ta có 3 thang đo (hoặc một thang đo có 3 hàng) thì,

7/3: (Nhắc nhở là 1, Quotient là 2) Vì vậy, phần tử thứ 2 của Tensor1 sẽ được hiển thị, đó là 6


2

Vì tôi cũng bị hấp dẫn bởi chức năng này, tôi sẽ đưa hai xu của mình.

Cách tôi nhìn thấy nó trong trường hợp 2D chỉ là phép nhân ma trận (thật dễ dàng để khái quát hóa cho các kích thước khác).

Xem xét một từ vựng với N ký hiệu. Sau đó, bạn có thể biểu thị một ký hiệu x là một vectơ có kích thước Nx1, được mã hóa một lần nóng.

Nhưng bạn muốn một đại diện của biểu tượng này không phải là một vectơ của Nx1, mà là một biểu tượng có kích thước Mx1, được gọi là y .

Vì vậy, để chuyển đổi x thành y , bạn có thể sử dụng và nhúng ma trận E , với kích thước MxN:

y = E x .

Đây thực chất là gì tf.nn.embedding_lookup (params, id, ...) đang làm, với sắc thái mà id chỉ là một số đại diện cho vị trí của 1 trong một hot-mã hóa vector x .


0

Thêm vào câu trả lời của Asher Stern, paramsđược hiểu là phân vùng của một tenxơ nhúng lớn. Nó có thể là một tenxơ duy nhất đại diện cho tenxơ nhúng hoàn chỉnh, hoặc một danh sách các tenxơ X có cùng hình dạng ngoại trừ kích thước đầu tiên, đại diện cho các tenxơ nhúng bị cắt.

Hàm tf.nn.embedding_lookupđược viết xem xét thực tế rằng nhúng (params) sẽ lớn. Vì vậy, chúng tôi cần partition_strategy.

Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.