Re-rank 算法: Maximal Marginal Relevance

最近在工作上實作了 Maximal Marginal Relevance (MMR) 的算法,主要用在推薦系統的 rerank,這個算法的結構跟想法都簡單,但效果算是非常不錯(至少在我的任務上),整體非常雋永,特別介紹給大家,尤其在 LLM 當道的時代,RAG 與 搜尋、廣告、推薦其實都是互通有無的。

先來看一個案例,我分別在 PChome 與 momo 上面搜尋「牛奶」,PChome 在前 10 個商品以內有兩組重複的,momo 則在前 10 個有一組重複的商品,當然商品排序需要考慮的因素眾多,但至少就「多樣性」來說,momo在這一方面或許就相對好一點,畢竟消費者可能沒那麼在乎今天到底是要買一箱?還是買兩箱?他們可能更在乎我有多少種類的東西可以選擇。

PChome – milk
momo – milk

像這樣子的排序問題,就直接地影響到消費者對於這個電商的看法。一般來說在使用者搜尋之後,系統會從各路檢索出相關的商品,經過一系列的排序再在頁面上面呈現給使用者,而在這一個過程當中,候選商品會逐漸減少,而 re-rank 通常發生在整個搜尋系統的下游,也就是我們現在手上已經有一批商品候選了,我們應該要如何排序這些商品?

MMR 就是就是一個非常經典且實用的算法,定義如下:

$$ \arg \max_{D_i \in \mathbb{R} \backslash \mathbb{S}} \left[ (1 – \lambda) \overbrace{\text{sim}(D_i, Q) }^\textrm{similarity} – \lambda \overbrace{\max_{D_j \in \mathbb{S}} \text{sim}(D_i, D_j) }^\textrm{diversity} \right]$$

  • $Q$: 表示一個 query
  • $D$: 表示一個文本
  • $\mathbb{R}$: 表示所有的文本的集合
  • $\mathbb{S}$: 表示已經選擇的文本集合
  • $\text{sim}(p, q)$: 用來計算 p, q 的相似分數
  • $\lambda \in [0, 1]$: 多樣性的權重

在這一個 argmax 中,我們會得到一個文本的 index $i$,使得在這一輪中,他與 query 有一定程度的相似,同時也與已經選擇的文本集合有一定程度的差異,這個 similariry 與 diversity 的 trade-off 主要是透過 $\lambda$ 來控制,要是我們執行 k 個 iteration,最後就會得到有 k 個文本的 re-rank 結果。

Python Implementation

完整流程與程式碼請參考 notebook: GitHub.

Dataset

我們利用一個公開資料集「CAS產品查詢」來模擬一下搜尋特定產品,並且利用 MMR 來排序的過程。我們將這個資料集裡面的產品名稱 Product_Name 當作搜尋的標的,這個資料集大概是長這個樣子:

為了實現依照相似度搜尋的功能,我們將 Product_Name 與 query 透過 embedding model 轉換成 1024 維度的向量,並且利用 consine similarity 進行語義搜尋。

  • Embedding: Qwen3-Embedding-0.6B (Dimension: 1024)
  • Package: sentence-transformers
  • Similarity: cosine similarity

Search and Re-rank

模擬使用者搜尋的過程如下:

  1. 使用者輸入一個搜尋的字詞 query
  2. queryProduct_Name 經過 embedding model 投影到一個 1024 維的向量
  3. 利用 cosine similarity 找出前 30 個跟 query 相似的商品
  4. 再利用 MMR 重新排序這 30 個商品,最後回傳 10 個商品結果

至於 MMR,有別於經典的方式,我們實作一個有 sliding window 的版本,在每一個 iteration 中,當考量 diversity 時,我們只考量前 m 個已經選擇的商品。


Python
def mmr(
    query_embedding: np.ndarray,
    document_embeddings: np.ndarray,
    diversity: float = 0.1,
    top_n: int = 10,
    window_size: int | None = None
) -> list[str]:
    """Maximal Marginal Relevance (with sliding window).

    Arguments:
        query_embedding: The query embedding
        document_embeddings: The embeddings of the selected documents
        diversity: The diversity of the selected embeddings. Values between 0 and 1.
        top_n: The top n items to return
        window_size: The size of the sliding window

    Returns:
            list[int]: The indices of the selected documents
    """
    from sklearn.metrics.pairwise import cosine_similarity

    # compute similarity(Q, D) and similarity(D, D)
    query_doc_similarity = cosine_similarity([query_embedding], document_embeddings)[0]
    pair_similarity = cosine_similarity(document_embeddings)

    if window_size is None:
        window_size = min(10, len(document_embeddings))

    # return doc_idx as the result and recode candidates_idx as current candidate set
    doc_idx = [np.argmax(query_doc_similarity)]
    candidates_idx = [i for i in range(len(document_embeddings)) if i != doc_idx[0]]
    for _ in range(top_n - 1):
        # in each iteration, select one documnet within candidates using MMR
        candidate_similarities = query_doc_similarity[candidates_idx]
        target_similarities = np.max(pair_similarity[candidates_idx][:, doc_idx[-window_size:]], axis=1)

        # calculate MMR
        mmr = (1 - diversity) * candidate_similarities - diversity * target_similarities
        mmr_idx = candidates_idx[np.argmax(mmr)]

        # Update doc_idx & candidates
        doc_idx.append(mmr_idx)
        candidates_idx.remove(mmr_idx)

    return doc_idx

Python
def search(
    query: str,
    using_mmr: bool = True,
    window_size: int | None = None,
    diversity: float = 0.1
) -> list[str]:
    query_embedding = model.encode(query, prompt_name="query")
    similarity_scores = model.similarity(query_embedding, document_embeddings)[0]
    
    indices = np.argsort(similarity_scores.tolist())[::-1]
    
    if not using_mmr:   
        return df.Product_Name[indices[:10]]
    
    doc_idx = mmr(
        query_embedding,
        document_embeddings[indices[:30]],
        top_n=10,
        window_size=window_size,
        diversity=diversity
    )
    return df.Product_Name[indices[doc_idx]]

詳細的 code 還請參照 GitHub

Example

  • Different Diversity
    搜尋「鮮乳」,比較不同 diversity 權重的結果
Python
pd.DataFrame({
    "Diversity=0.1": search("鮮乳", diversity=0.1).to_list(),
    "Diversity=0.9": search("鮮乳", diversity=0.9).to_list()
})

很明顯的可以看到當 diversity 低的時候,搜尋到的產品很有可能會重複。

  • Different window size
    比較同樣 diversity 時,啟用 sliding window 與否的結果
Python
pd.DataFrame({
    "window_size=None": search(
        "鮮乳", diversity=0.6, window_size=None
    ).to_list(),
    "window_size=4": search(
        "鮮乳", diversity=0.6, window_size=4
    ).to_list(),
})

window size = 4 時,第 0 與第 5 個產品都是「四方鮮乳全脂鮮乳」;第 4 與第 9 個產品都是「光泉鮮乳-成分無調整」,而這中間的間隔就是我們設定的 window size。當然,diversitywindow_size 都是需要被設計且調整的參數,實務上還是要多觀察才會比較知道要怎麼設定。

Conclusion

實在很喜歡這種小品算法,但算法本身的好壞,完全與商業指標相關,要是衡量的指標全是那種教科書上會出現的,那也真的不是一個很好的實踐。

搜廣推的算法領域實在有太多有趣、需要學習、想要學習的東西了~~