Docker 部署 rerank 模型 BAAI/bge-reranker-base

FROM python:3.9-slim

# Use USTC mirrors for both apt and pip to speed up builds on CN networks.
# FIX: the original ran `apt-get install tzdata` BEFORE any `apt-get update`;
# slim images ship with no package lists, so that install fails outright
# (and install-after-stale-lists is the classic DL3009 cache bug anyway).
# All installs now happen after a single `update`, with cleanup in the same layer.
RUN sed -i 's|http://deb.debian.org|https://mirrors.ustc.edu.cn|g' /etc/apt/sources.list.d/debian.sources \
    && apt-get update \
    && apt-get install -y --no-install-recommends \
        gcc \
        tzdata \
    && ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \
    && echo "Asia/Shanghai" > /etc/timezone \
    && rm -rf /var/lib/apt/lists/* \
    && pip install --no-cache-dir -i https://pypi.mirrors.ustc.edu.cn/simple \
        numpy==1.26.4 torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 \
        sentence-transformers==4.0.1 fastapi==0.115.12 uvicorn==0.34.0

WORKDIR /app

COPY rerank/app.py /app/

# Documentation only — operators still publish with `-p 8000:8000`.
EXPOSE 8000

# Probe the app's own /health endpoint (no curl in slim, so use python).
# Generous start-period: the model (~420 MB) is downloaded on first start.
HEALTHCHECK --interval=30s --timeout=5s --start-period=120s --retries=3 \
    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1

# NOTE(review): container runs as root and the HF model cache lands in /root;
# adding a dedicated non-root USER (with a writable HF_HOME) is recommended.
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "2"]


import os
import time
from typing import Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import CrossEncoder

app = FastAPI(title="BGE-Reranker-Base CPU服务")

# Force CPU-only execution.
# NOTE(review): torch has already been imported (via sentence_transformers)
# by the time this line runs, so hiding CUDA devices here may have no effect;
# the explicit device='cpu' below is what actually guarantees CPU inference.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Load the base cross-encoder reranker at import time (~420 MB, downloaded
# from the Hugging Face hub on first run; blocks startup until ready).
model = CrossEncoder('BAAI/bge-reranker-base', device='cpu')

class QueryDocumentPair(BaseModel):
    """A single (query, document) pair.

    NOTE(review): not referenced by any route visible in this file —
    confirm whether external clients rely on this schema before removing.
    """
    query: str
    document: str

class RerankRequest(BaseModel):
    """Request body for the /rerank endpoint.

    query:      the search query to score documents against.
    documents:  candidate documents to be reranked (max 100, enforced in the route).
    top_k:      return only the k best matches; None means "return all".
    batch_size: how many pairs to score per model.predict() call.
    """
    query: str
    documents: list[str]
    # FIX: was `top_k: int = None` — the annotation lied about the type.
    # Optional[int] is the honest (and pydantic-correct) declaration; the
    # default stays None, so the interface is unchanged for callers.
    top_k: Optional[int] = None
    batch_size: int = 16  # default batch size for CrossEncoder.predict

@app.post("/rerank")
async def rerank_texts(request: RerankRequest):
    """Score each document against the query and return them ranked.

    Documents are scored in batches by the cross-encoder, sorted by
    descending relevance score, and optionally truncated to top_k.

    Raises:
        HTTPException(400): more than MAX_DOCS documents, or a
            non-positive batch_size.
    """
    start_time = time.time()

    # Safety limit: cap documents per request to bound CPU time and memory.
    MAX_DOCS = 100
    if len(request.documents) > MAX_DOCS:
        raise HTTPException(
            status_code=400,
            detail=f"超过最大处理文档数({MAX_DOCS}),请减少文档数量或分批处理"
        )
    # FIX: batch_size <= 0 previously reached range(0, n, 0), which raises
    # ValueError and surfaced as an opaque 500; reject it explicitly instead.
    if request.batch_size <= 0:
        raise HTTPException(status_code=400, detail="batch_size 必须为正整数")

    # Pair the query with every candidate document for the cross-encoder.
    model_inputs = [[request.query, doc] for doc in request.documents]

    # Predict in batches to keep peak memory bounded on CPU.
    scores = []
    for i in range(0, len(model_inputs), request.batch_size):
        batch = model_inputs[i:i + request.batch_size]
        scores.extend(model.predict(batch))

    # Sort documents by descending relevance score.
    results = sorted(
        zip(request.documents, scores),
        key=lambda x: x[1],
        reverse=True
    )

    # Truncate to top_k; None / zero / negative means "return everything".
    if request.top_k is not None and request.top_k > 0:
        results = results[:request.top_k]

    processing_time = time.time() - start_time

    return {
        "model": "bge-reranker-base",
        "device": "cpu",
        "processing_time_seconds": round(processing_time, 3),
        "documents_processed": len(request.documents),
        "results": [
            # scores may be numpy floats — cast for clean JSON serialization
            {"document": doc, "score": float(score), "rank": idx + 1}
            for idx, (doc, score) in enumerate(results)
        ]
    }

@app.get("/model-info")
async def get_model_info():
    """Expose static metadata about the loaded reranker model."""
    info = {}
    info["model_name"] = "BAAI/bge-reranker-base"
    info["max_sequence_length"] = 512
    info["recommended_batch_size"] = 16
    info["device"] = "cpu"
    return info

@app.get("/health")
async def health_check():
    """Liveness probe: the process is up and the model loaded at import time."""
    payload = {"status": "healthy"}
    payload["model_loaded"] = True
    return payload

发表评论

电子邮件地址不会被公开。 必填项已用*标注

Captcha Code