FROM python:3.9-slim

# Asia/Shanghai timezone for logs and the app's time.time() bookkeeping.
ENV TZ=Asia/Shanghai

# Switch apt to the USTC mirror, then update BEFORE installing: slim images
# ship with empty apt lists, so `apt-get install` without a preceding
# `apt-get update` fails. Install + cleanup stay in one layer so removing
# /var/lib/apt/lists actually shrinks the image.
RUN sed -i 's|http://deb.debian.org|https://mirrors.ustc.edu.cn|g' /etc/apt/sources.list.d/debian.sources \
    && apt-get update \
    && apt-get install -y --no-install-recommends \
        gcc \
        tzdata \
    && rm -rf /var/lib/apt/lists/* \
    && ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \
    && echo "Asia/Shanghai" > /etc/timezone \
    && pip install --no-cache-dir -i https://pypi.mirrors.ustc.edu.cn/simple \
        numpy==1.26.4 \
        torch==2.1.1 \
        torchvision==0.16.1 \
        torchaudio==2.1.1 \
        sentence-transformers==4.0.1 \
        fastapi==0.115.12 \
        uvicorn==0.34.0

WORKDIR /app
COPY rerank/app.py /app/

# Run as a non-root user; HF_HOME gives the model download (done at import
# time by sentence_transformers) a writable cache directory.
RUN useradd --create-home --uid 10001 appuser
USER appuser
ENV HF_HOME=/home/appuser/.cache/huggingface

EXPOSE 8000
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "2"]
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
import time

# Force CPU-only execution. CUDA_VISIBLE_DEVICES must be set BEFORE torch is
# imported (sentence_transformers imports torch), otherwise torch may already
# have initialized/enumerated GPUs and the variable has no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

from sentence_transformers import CrossEncoder

app = FastAPI(title="BGE-Reranker-Base CPU服务")

# Load the base reranker model (~420MB) onto the CPU at import time so the
# first request doesn't pay the load cost.
model = CrossEncoder('BAAI/bge-reranker-base', device='cpu')
class QueryDocumentPair(BaseModel):
    # A single (query, document) pair.
    # NOTE(review): not referenced by any endpoint visible in this file —
    # presumably kept for a future API shape; confirm before removing.
    query: str
    document: str
from typing import Optional


class RerankRequest(BaseModel):
    """Request payload for POST /rerank: one query scored against many documents."""
    query: str
    documents: list[str]
    # None means "return all documents". Must be Optional[int]: the original
    # `top_k: int = None` default contradicts its own annotation and is
    # rejected/mis-validated by pydantic when a client sends null.
    top_k: Optional[int] = None
    # Default batch size for model.predict (memory/throughput trade-off).
    batch_size: int = 16
@app.post("/rerank")
async def rerank_texts(request: RerankRequest):
    """Score each document against the query with the cross-encoder and
    return all pairs sorted by relevance score, highest first.

    Returns a dict with model/device metadata, timing, and a ranked
    `results` list of {document, score, rank} entries.

    Raises:
        HTTPException(400): more than MAX_DOCS documents were submitted,
            or batch_size is not a positive integer.
    """
    start_time = time.time()
    # Safety limit: cap the per-request document count.
    MAX_DOCS = 100
    if len(request.documents) > MAX_DOCS:
        raise HTTPException(
            status_code=400,
            detail=f"超过最大处理文档数({MAX_DOCS}),请减少文档数量或分批处理"
        )
    # A non-positive batch_size would make range() raise ValueError (step 0)
    # or step backwards — reject client input explicitly instead of 500ing.
    if request.batch_size <= 0:
        raise HTTPException(
            status_code=400,
            detail="batch_size必须为正整数"
        )
    # Build [query, document] pairs — the cross-encoder's expected input.
    model_inputs = [[request.query, doc] for doc in request.documents]
    # Predict in batches to bound peak memory usage.
    scores = []
    for i in range(0, len(model_inputs), request.batch_size):
        batch = model_inputs[i:i + request.batch_size]
        scores.extend(model.predict(batch))
    # Pair each document with its score and sort by score, descending.
    results = sorted(
        zip(request.documents, scores),
        key=lambda x: x[1],
        reverse=True
    )
    # Truncate to the top_k best matches when requested.
    if request.top_k is not None and request.top_k > 0:
        results = results[:request.top_k]
    processing_time = time.time() - start_time
    return {
        "model": "bge-reranker-base",
        "device": "cpu",
        "processing_time_seconds": round(processing_time, 3),
        "documents_processed": len(request.documents),
        # float() unwraps numpy scalars so the payload is JSON-serializable.
        "results": [
            {"document": doc, "score": float(score), "rank": idx + 1}
            for idx, (doc, score) in enumerate(results)
        ]
    }
@app.get("/model-info")
async def get_model_info():
    """Return static metadata describing the loaded reranker model."""
    info = {
        "model_name": "BAAI/bge-reranker-base",
        "max_sequence_length": 512,
        "recommended_batch_size": 16,
        "device": "cpu",
    }
    return info
@app.get("/health")
async def health_check():
    """Liveness probe: if this module imported, the model finished loading."""
    status = {"status": "healthy", "model_loaded": True}
    return status