556 lines
20 KiB
Python
556 lines
20 KiB
Python
"""
|
||
LLM 论文图书馆 — FastAPI 后端
|
||
提供 REST API 进行论文查询、管理、PDF 代理服务
|
||
"""
|
||
|
||
import json
|
||
import os
|
||
import hashlib
|
||
import secrets
|
||
import logging
|
||
from pathlib import Path
|
||
from typing import Optional
|
||
|
||
from fastapi import FastAPI, HTTPException, Query, Depends, Request
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.responses import FileResponse, JSONResponse
|
||
from fastapi.staticfiles import StaticFiles
|
||
from pydantic import BaseModel
|
||
|
||
# ─── Config ────────────────────────────────────────────
|
||
ROOT = Path(__file__).resolve().parent.parent
DATA_FILE = ROOT / "data" / "papers.json"
PAPERS_DIR = ROOT / "papers"

# Fallback used when LLM_LIB_API_KEY is unset. It is public (it is in this
# source file), so running with it leaves every write endpoint open.
_DEFAULT_API_KEY = "change-me"
API_KEY = os.environ.get("LLM_LIB_API_KEY", _DEFAULT_API_KEY)

log = logging.getLogger("llm-library")
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")

if API_KEY == _DEFAULT_API_KEY:
    log.warning("LLM_LIB_API_KEY is not set; write endpoints accept the default key 'change-me'")
|
||
|
||
# ─── App ───────────────────────────────────────────────
|
||
app = FastAPI(
    title="LLM 论文图书馆",
    description="大模型论文知识库 API — 查询、搜索、管理论文",
    version="0.1.0",
)

# Permissive CORS: a browser page on any origin may call any route with any
# method and any request header.
_CORS_OPTIONS = dict(allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
|
||
|
||
# ─── Auth ──────────────────────────────────────────────
|
||
def verify_api_key(request: Request):
    """API-key auth dependency for write endpoints.

    The key is taken from an ``Authorization: Bearer <key>`` header or, when
    that header is absent or not a Bearer header, from the ``api_key`` query
    parameter.

    Returns:
        True when the key matches API_KEY.

    Raises:
        HTTPException(401): the key is missing or wrong.
    """
    auth = request.headers.get("Authorization", "")
    if auth.startswith("Bearer "):
        token = auth[len("Bearer "):]
    else:
        token = request.query_params.get("api_key", "")
    # Constant-time comparison, so response timing does not reveal how much of
    # a guessed key was correct. Compare bytes because compare_digest raises
    # TypeError for str arguments containing non-ASCII characters.
    if not token or not secrets.compare_digest(token.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid or missing API key")
    return True
|
||
|
||
# ─── Data loading ──────────────────────────────────────
|
||
def load_data(path: Optional[Path] = None):
    """Load the library JSON (a mapping of module id -> module record).

    Args:
        path: file to read; defaults to DATA_FILE.

    Returns:
        The parsed JSON, or ``{}`` when the file does not exist.
    """
    path = DATA_FILE if path is None else Path(path)
    if not path.exists():
        return {}
    # The file is written as UTF-8 with ensure_ascii=False, so decode it as
    # UTF-8 explicitly rather than with the platform default (e.g. GBK on
    # Chinese Windows). "utf-8-sig" also accepts a BOM added by some editors.
    try:
        with open(path, "r", encoding="utf-8-sig") as f:
            return json.load(f)
    except UnicodeDecodeError:
        # Files saved by older versions used the platform default encoding.
        with open(path, "r") as f:
            return json.load(f)
|
||
|
||
def save_data(data, path: Optional[Path] = None):
    """Write the library JSON as UTF-8, replacing the file atomically.

    The JSON is written to a unique temp file in the same directory and then
    renamed over the target. A crash, a full disk or a non-serializable value
    therefore leaves the previous file intact instead of a truncated one.

    Args:
        data: JSON-serializable library mapping.
        path: file to write; defaults to DATA_FILE. Parent dirs are created.
    """
    import tempfile

    path = DATA_FILE if path is None else Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp_name = tempfile.mkstemp(dir=path.parent, prefix=path.name + ".", suffix=".tmp")
    try:
        # Explicit UTF-8: ensure_ascii=False emits raw CJK characters that the
        # platform default encoding may not be able to represent.
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        os.replace(tmp_name, path)
    except BaseException:
        Path(tmp_name).unlink(missing_ok=True)
        raise
|
||
|
||
# ─── Paper CRUD helpers ────────────────────────────────
|
||
def find_paper(data, module_id, area_id, title):
    """Locate a paper by exact title inside the area(s) with ``area_id``.

    Sections are searched in the order mainline, branches, forward, and the
    first match wins.

    Returns:
        ``(module, area, section, index)`` for the match, otherwise
        ``(None, None, None, None)`` (also when the module is missing or empty).
    """
    not_found = (None, None, None, None)
    module = data.get(module_id)
    if not module:
        return not_found

    matching_areas = (a for a in module.get("areas", []) if a["id"] == area_id)
    for area in matching_areas:
        for section in ("mainline", "branches", "forward"):
            papers = area.get(section, [])
            hit = next((i for i, paper in enumerate(papers) if paper["title"] == title), None)
            if hit is not None:
                return module, area, section, hit
    return not_found
|
||
|
||
# ─── Routes: Query ─────────────────────────────────────
|
||
@app.get("/api/stats")
def get_stats():
    """Return library totals: modules, areas, papers and papers per section."""
    library = load_data()
    section_names = ("mainline", "branches", "forward")
    per_section = dict.fromkeys(section_names, 0)
    area_total = 0

    for module in library.values():
        module_areas = module.get("areas", [])
        area_total += len(module_areas)
        for area in module_areas:
            for name in section_names:
                per_section[name] += len(area.get(name, []))

    return {
        "modules": len(library),
        "areas": area_total,
        "papers": sum(per_section.values()),
        "sections": per_section,
        "data_file": str(DATA_FILE),
    }
|
||
|
||
@app.get("/api/modules")
def list_modules():
    """List module summaries (id, name, icon, desc, area and paper counts) without papers."""
    summaries = []
    for module_id, module in load_data().items():
        areas = module.get("areas", [])
        paper_total = 0
        for area in areas:
            paper_total += sum(len(area.get(s, [])) for s in ("mainline", "branches", "forward"))
        summaries.append({
            "id": module_id,
            "name": module["name"],
            "icon": module["icon"],
            "desc": module["desc"],
            "area_count": len(areas),
            "paper_count": paper_total,
        })
    return summaries
|
||
|
||
@app.get("/api/modules/{module_id}")
def get_module(module_id: str):
    """Return one module's full record, including its areas and papers.

    Raises HTTPException 404 when the module id is unknown (or its record is empty).
    """
    module = load_data().get(module_id)
    if module:
        return module
    raise HTTPException(status_code=404, detail=f"Module '{module_id}' not found")
|
||
|
||
def _iter_papers(data, module=None):
    """Yield (module_id, module, area, section, paper) for every paper in file order.

    When ``module`` is truthy, only the module with that id is visited.
    """
    for mid, mod in data.items():
        if module and mid != module:
            continue
        for area in mod.get("areas", []):
            for section in ("mainline", "branches", "forward"):
                for p in area.get(section, []):
                    yield mid, mod, area, section, p


@app.get("/api/papers")
def search_papers(
    q: str = Query(default="", description="搜索关键词: 标题/作者"),
    module: Optional[str] = Query(default=None),
    tag: Optional[str] = Query(default=None, description="起点/关键节点/前沿/前瞻/支线"),
    limit: int = Query(default=50, ge=1, le=200),
):
    """Search papers by case-insensitive title/author substring, module id and tag.

    Returns at most ``limit`` papers in file order. Each paper's own fields are
    merged with its location: module_id, module_name, area_id, area_name and section.
    """
    needle = q.lower()
    results = []
    for mid, mod, area, section, p in _iter_papers(load_data(), module):
        # `or []` / `or ""` tolerate JSON nulls in the stored record.
        if tag and tag not in (p.get("tags") or []):
            continue
        # Check title and authors separately. Concatenating them let a query
        # match across the seam (end of title + start of authors).
        if needle and not any(
            needle in str(p.get(field) or "").lower() for field in ("title", "authors")
        ):
            continue
        results.append({
            "module_id": mid,
            "module_name": mod["name"],
            "area_id": area["id"],
            "area_name": area["name"],
            "section": section,
            **p,
        })
        if len(results) >= limit:
            break
    return results
|
||
|
||
# ─── Routes: Management (写操作, 需 API Key) ────────────
|
||
class PaperCreate(BaseModel):
    """Request body for POST /api/papers: a new paper and where to file it."""

    module_id: str  # key of the target module in the library JSON
    area_id: str  # "id" of the target area within that module
    section: str = "mainline"  # mainline / branches / forward (checked by add_paper)
    title: str  # also the lookup key that update/delete use to find the paper
    authors: str = ""
    year: int
    venue: str = ""
    arxiv: Optional[str] = None  # stored only when truthy
    pdf: Optional[str] = None  # stored only when truthy
    tags: list[str] = []
|
||
|
||
class PaperUpdate(BaseModel):
    """Partial-update body for PUT /api/papers.

    A field left as None means "leave unchanged", so this body cannot clear a
    stored value back to null.
    """

    authors: Optional[str] = None
    year: Optional[int] = None
    venue: Optional[str] = None
    arxiv: Optional[str] = None
    pdf: Optional[str] = None
    tags: Optional[list[str]] = None
    section: Optional[str] = None  # move to different section (mainline / branches / forward)
|
||
|
||
@app.post("/api/papers", dependencies=[Depends(verify_api_key)])
def add_paper(paper: PaperCreate):
    """Append a new paper to a module/area section and persist the library.

    Raises:
        HTTPException(404): unknown module or area.
        HTTPException(400): section is not mainline/branches/forward.
        HTTPException(409): the area already holds a paper with this title.
    """
    data = load_data()
    mod = data.get(paper.module_id)
    if not mod:
        raise HTTPException(status_code=404, detail="Module not found")

    # .get: a module without an "areas" list means "area not found", not a 500.
    area = next((a for a in mod.get("areas", []) if a["id"] == paper.area_id), None)
    if not area:
        raise HTTPException(status_code=404, detail="Area not found")

    section = paper.section
    if section not in ("mainline", "branches", "forward"):
        raise HTTPException(status_code=400, detail="section must be mainline/branches/forward")

    # update/delete locate papers by (module, area, title) via find_paper and
    # act on the first match. A second paper with the same title could never
    # be edited or deleted, so reject it here.
    if find_paper(data, paper.module_id, paper.area_id, paper.title)[0] is not None:
        raise HTTPException(status_code=409, detail=f"Paper already exists: {paper.title}")

    entry = {
        "title": paper.title,
        "authors": paper.authors,
        "year": paper.year,
        "venue": paper.venue,
        "tags": paper.tags,
    }
    # Optional links are stored only when provided, keeping records compact.
    if paper.arxiv:
        entry["arxiv"] = paper.arxiv
    if paper.pdf:
        entry["pdf"] = paper.pdf

    area.setdefault(section, []).append(entry)
    save_data(data)
    log.info(f"Added paper: {paper.title}")
    return {"ok": True, "title": paper.title}
|
||
|
||
@app.put("/api/papers")
def update_paper(
    module_id: str,
    area_id: str,
    title: str,
    update: PaperUpdate,
    _=Depends(verify_api_key),
):
    """Apply a partial update to one paper, optionally moving it to another section.

    The paper is found by module id, area id and exact title. Body fields left
    as None are not touched. Raises HTTPException 404 when the paper is not
    found and 400 for an unknown target section.
    """
    data = load_data()
    mod, area, current_section, position = find_paper(data, module_id, area_id, title)
    if mod is None:
        raise HTTPException(status_code=404, detail="Paper not found")

    target_section = update.section
    relocate = bool(target_section) and target_section != current_section
    # Validate before mutating anything; nothing is saved on error either way.
    if relocate and target_section not in ("mainline", "branches", "forward"):
        raise HTTPException(status_code=400, detail="Invalid section")

    paper = area[current_section][position]
    provided = {
        name: getattr(update, name)
        for name in ("authors", "year", "venue", "arxiv", "pdf", "tags")
    }
    paper.update({name: value for name, value in provided.items() if value is not None})

    if relocate:
        del area[current_section][position]
        area.setdefault(target_section, []).append(paper)

    save_data(data)
    log.info(f"Updated paper: {title}")
    return {"ok": True, "title": title}
|
||
|
||
@app.delete("/api/papers")
def delete_paper(
    module_id: str,
    area_id: str,
    title: str,
    _=Depends(verify_api_key),
):
    """Remove one paper (found by module id, area id and exact title) and persist.

    Raises HTTPException 404 when no such paper exists.
    """
    library = load_data()
    found_module, found_area, found_section, position = find_paper(library, module_id, area_id, title)
    if found_module is None:
        raise HTTPException(status_code=404, detail="Paper not found")

    del found_area[found_section][position]
    save_data(library)
    log.info(f"Deleted paper: {title}")
    return {"ok": True, "title": title}
|
||
|
||
# ─── Routes: PDF proxy ──────────────────────────────────
|
||
@app.get("/papers/arxiv/{arxiv_id}.pdf")
@app.get("/papers/arxiv/{arxiv_id}")
def serve_arxiv_pdf(arxiv_id: str):
    """Serve a locally cached arXiv PDF.

    The suffix-less route exists so download managers (e.g. IDM) do not
    intercept the request. Decorators apply bottom-up, so that route is
    registered first and also receives "<id>.pdf" requests with the suffix
    still in ``arxiv_id``. The suffix is therefore stripped here, as
    serve_hf_pdf already does.

    Raises:
        HTTPException(400): the id would resolve outside PAPERS_DIR/arxiv.
        HTTPException(404): the PDF is not cached.
    """
    arxiv_id = arxiv_id.removesuffix(".pdf")
    arxiv_dir = PAPERS_DIR / "arxiv"
    pdf_path = arxiv_dir / f"{arxiv_id}.pdf"
    # Lexical containment check: on Windows a backslash (or drive prefix) in
    # the id would otherwise let the request reach PDFs outside the cache.
    if pdf_path.parent != arxiv_dir:
        raise HTTPException(status_code=400, detail=f"Invalid arXiv id: {arxiv_id}")
    if not pdf_path.exists():
        raise HTTPException(status_code=404, detail=f"PDF not in local cache: {arxiv_id}")
    return FileResponse(
        pdf_path, media_type="application/pdf",
        headers={"Cache-Control": "public, max-age=86400"},
    )
|
||
|
||
@app.get("/papers/hf/{filename}.pdf")
@app.get("/papers/hf/{filename}")
def serve_hf_pdf(filename: str):
    """Serve a locally cached HuggingFace PDF.

    The suffix-less route keeps download managers (e.g. IDM) from intercepting
    the request. A trailing ".pdf" is stripped either way.

    Raises HTTPException 404 when the file is not cached or the sanitized name
    would resolve outside PAPERS_DIR/hf.
    """
    hf_dir = PAPERS_DIR / "hf"
    safe_name = (
        filename.replace("..", "")
        .replace("/", "_")
        .replace("\\", "_")  # Windows path separator was previously passed through
        .removesuffix(".pdf")
    )
    pdf_path = hf_dir / f"{safe_name}.pdf"
    # Lexical containment check: on Windows a drive-qualified name (e.g. "D:x")
    # can make pathlib discard hf_dir entirely.
    if pdf_path.parent != hf_dir or not pdf_path.exists():
        raise HTTPException(status_code=404, detail=f"PDF not in local cache: {filename}")
    return FileResponse(
        pdf_path, media_type="application/pdf",
        headers={"Cache-Control": "public, max-age=86400"},
    )
|
||
|
||
# ─── Routes: Translation ───────────────────────────────
|
||
# Per-paper JSON cache for /api/translate results, created at import time.
TRANSLATE_CACHE = ROOT.joinpath("data", "translations")
TRANSLATE_CACHE.mkdir(exist_ok=True, parents=True)
|
||
|
||
def extract_pdf_text_with_pages(pdf_path: Path, max_chars: int = 12000) -> list[dict]:
    """Extract per-page text from a PDF with ``pdftotext`` (Poppler).

    Poppler is used instead of PyMuPDF to avoid its GPU dependency.

    Args:
        pdf_path: PDF to read. Its stem is treated as the arXiv id whose
            stamp ("arXiv:<id>...") is removed from the text.
        max_chars: stop collecting pages once this many characters have been
            gathered. The page that crosses the limit is still included in full.

    Returns:
        ``[{"page": <1-based page number>, "text": <stripped page text>}, ...]``
        for non-empty pages only.

    Raises:
        HTTPException(500): pdftotext is missing, times out or fails, or no
            text was extracted.
    """
    import re
    import subprocess

    cmd = ["pdftotext", "-layout", "-enc", "UTF-8", "-q", str(pdf_path), "-"]
    try:
        # Ask pdftotext for UTF-8 and decode it as such. text=True alone would
        # decode with the locale encoding (e.g. GBK on Chinese Windows).
        result = subprocess.run(
            cmd, capture_output=True, encoding="utf-8", errors="replace", timeout=30,
        )
    except FileNotFoundError:
        log.error("pdftotext not found on PATH (install Poppler)")
        raise HTTPException(status_code=500, detail="PDF text extraction failed") from None
    except subprocess.TimeoutExpired:
        log.error(f"pdftotext timed out on {pdf_path}")
        raise HTTPException(status_code=500, detail="PDF text extraction failed") from None

    if result.returncode != 0:
        log.error(f"pdftotext failed: {result.stderr}")
        raise HTTPException(status_code=500, detail="PDF text extraction failed")

    # Remove the arXiv stamp: from "arXiv:<id>" up to the next blank line.
    # (Any "arXiv:<id>.pdf" occurrence is already covered by this pattern.)
    text = re.sub(
        r"arXiv:" + re.escape(pdf_path.stem) + r".*?\n\n", "", result.stdout, flags=re.DOTALL
    )

    pages = []
    total = 0
    # pdftotext ends every page with a form feed, so index i is page i + 1.
    for i, page_text in enumerate(text.split("\f")):
        pt = page_text.strip()
        if not pt:
            continue
        pages.append({"page": i + 1, "text": pt})
        total += len(pt)
        if total >= max_chars:
            break

    if not pages:
        raise HTTPException(status_code=500, detail="No text extracted from PDF")
    return pages
|
||
|
||
|
||
def split_text_with_pages(page_texts: list[dict], max_len: int = 400) -> list[dict]:
    """Split per-page text into translation-sized chunks, keeping page numbers.

    Each page is split into paragraphs on blank lines, and paragraphs of up to
    ``max_len`` characters are kept whole. Longer paragraphs are split into
    sentences after ". ", "? " or "! ", which are then greedily re-joined with
    single spaces into chunks of at most ``max_len`` characters. A single
    sentence longer than ``max_len`` is first word-wrapped to that width.

    Args:
        page_texts: ``[{"page": int, "text": str}, ...]``, the shape returned
            by extract_pdf_text_with_pages.
        max_len: maximum chunk length in characters for split paragraphs.

    Returns:
        ``[{"page": int, "text": str}, ...]`` in document order.
    """
    import re
    import textwrap

    # Split on the space that follows sentence punctuation. The previous
    # "|"-placeholder approach also split on, and silently dropped, literal
    # "|" characters, which are common in -layout table text.
    sentence_break = re.compile(r"(?<=[.?!]) ")

    chunks = []
    for pt in page_texts:
        page = pt["page"]
        for para in (p.strip() for p in pt["text"].split("\n\n")):
            if not para:
                continue
            if len(para) <= max_len:
                chunks.append({"page": page, "text": para})
                continue

            current = ""
            for sentence in sentence_break.split(para):
                sentence = sentence.strip()
                if not sentence:
                    continue
                # Keep every chunk within max_len, even for run-on "sentences"
                # (e.g. layout text without punctuation).
                if 0 < max_len < len(sentence):
                    pieces = textwrap.wrap(sentence, max_len)
                else:
                    pieces = [sentence]
                for s in pieces:
                    if len(current) + len(s) + 1 <= max_len:
                        current = (current + " " + s).strip()
                    else:
                        if current:
                            chunks.append({"page": page, "text": current})
                        current = s
            if current:
                chunks.append({"page": page, "text": current})
    return chunks
|
||
|
||
|
||
def translate_text(text: str, source: str = "en", target: str = "zh") -> str:
    """Translate text with the free MyMemory API, best effort.

    Returns the translation, or the original text when the input is blank,
    the request fails, or the API reports a non-200 status or an empty result.
    """
    import urllib.parse
    import urllib.request

    # Nothing to translate: skip the network round-trip (and the API quota).
    if not text or not text.strip():
        return text

    url = "https://api.mymemory.translated.net/get"
    params = urllib.parse.urlencode({
        "q": text,
        "langpair": f"{source}|{target}",
        "mt": "1",  # machine-translation flag; see MyMemory API docs for `mt`
        "de": "me@llm-library.local",
    })
    full_url = f"{url}?{params}"

    try:
        with urllib.request.urlopen(full_url, timeout=15) as resp:
            data = json.loads(resp.read())
    except Exception as e:
        # Deliberately best-effort: one failed paragraph must not abort the paper.
        log.warning(f"Translation API error: {e}")
        return text

    if data.get("responseStatus") == 200 and data.get("responseData"):
        # Guard against a null translatedText so callers always get a str.
        return data["responseData"].get("translatedText") or text
    return text
|
||
|
||
|
||
@app.get("/api/translate/{arxiv_id}")
def translate_paper(arxiv_id: str):
    """Translate a cached arXiv PDF paragraph by paragraph (en -> zh).

    Text is extracted from the local PDF and chunked with page numbers. Each
    chunk goes through translate_text. The result is stored as JSON under
    TRANSLATE_CACHE and served from there on later calls.

    Returns:
        ``{"arxiv_id": str, "paragraphs": [{"page", "en", "zh"}, ...], "count": int}``

    Raises:
        HTTPException(400): the id would resolve outside the PDF or cache dir.
        HTTPException(404): the PDF is not cached locally.
    """
    import tempfile

    arxiv_dir = PAPERS_DIR / "arxiv"
    pdf_path = arxiv_dir / f"{arxiv_id}.pdf"
    cache_file = TRANSLATE_CACHE / f"{arxiv_id}.json"
    # Lexical containment check: on Windows a backslash in the id is a path
    # separator and would let the cache write land outside TRANSLATE_CACHE.
    if pdf_path.parent != arxiv_dir or cache_file.parent != TRANSLATE_CACHE:
        raise HTTPException(status_code=400, detail=f"Invalid arXiv id: {arxiv_id}")
    if not pdf_path.exists():
        raise HTTPException(status_code=404, detail=f"PDF not cached: {arxiv_id}")

    if cache_file.exists():
        try:
            with open(cache_file, encoding="utf-8") as f:
                return json.load(f)
        except UnicodeDecodeError:
            # Cache files written by older versions used the platform default encoding.
            with open(cache_file) as f:
                return json.load(f)

    # Extract text with page numbers
    log.info(f"Extracting text from {arxiv_id}")
    page_texts = extract_pdf_text_with_pages(pdf_path)
    chunks = split_text_with_pages(page_texts)
    log.info(f"Translating {len(chunks)} paragraphs for {arxiv_id}")

    translated = []
    for i, chunk in enumerate(chunks):
        if i % 10 == 0:
            log.info(f"  [{arxiv_id}] translating paragraph {i+1}/{len(chunks)}")
        translated.append({
            "page": chunk["page"],
            "en": chunk["text"],
            "zh": translate_text(chunk["text"]),
        })

    result = {"arxiv_id": arxiv_id, "paragraphs": translated, "count": len(translated)}

    # Write UTF-8 explicitly (ensure_ascii=False emits raw CJK) to a unique
    # temp file, then swap it in atomically. An interrupted write therefore
    # cannot leave a truncated cache that every later request fails to parse.
    fd, tmp_name = tempfile.mkstemp(dir=TRANSLATE_CACHE, suffix=".tmp")
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False)
    os.replace(tmp_name, cache_file)

    return result
|
||
|
||
|
||
@app.get("/api/translate/{arxiv_id}/status")
def translate_status(arxiv_id: str):
    """Report whether the paragraph-translation cache and the source PDF exist."""
    cached = (TRANSLATE_CACHE / f"{arxiv_id}.json").exists()
    pdf_exists = (PAPERS_DIR / "arxiv" / f"{arxiv_id}.pdf").exists()
    return {"arxiv_id": arxiv_id, "cached": cached, "pdf_exists": pdf_exists}
|
||
|
||
# ─── Routes: PDF download on-demand ────────────────────
|
||
@app.post("/api/download/{arxiv_id}")
def download_single_pdf(arxiv_id: str):
    """Try to fetch an arXiv PDF into the local cache by running downloader.py.

    Returns ``{"ok", "arxiv_id", "status"}`` with status "cached" (already
    present), "downloaded", "failed" or "timeout".
    Raises HTTPException 400 for an id that would resolve outside PAPERS_DIR/arxiv.
    """
    import subprocess
    import sys

    arxiv_dir = PAPERS_DIR / "arxiv"
    pdf_path = arxiv_dir / f"{arxiv_id}.pdf"
    # Lexical containment check (a backslash is a path separator on Windows).
    if pdf_path.parent != arxiv_dir:
        raise HTTPException(status_code=400, detail=f"Invalid arXiv id: {arxiv_id}")
    if pdf_path.exists():
        return {"ok": True, "arxiv_id": arxiv_id, "status": "cached"}

    # NOTE(review): downloader.py is not told which id to fetch. Presumably
    # "--limit 1" caps it at a single uncached paper, which need not be this
    # one. Confirm against downloader.py and pass the id if it supports that.
    cmd = [sys.executable, str(ROOT / "api" / "downloader.py"), "--limit", "1", "--delay", "0"]
    try:
        proc = subprocess.run(
            cmd, cwd=str(ROOT), timeout=60, capture_output=True,
            encoding="utf-8", errors="replace",
        )
    except subprocess.TimeoutExpired:
        log.warning(f"downloader.py timed out while fetching {arxiv_id}")
        return {"ok": False, "arxiv_id": arxiv_id, "status": "timeout"}

    if proc.returncode != 0:
        # Surface the captured stderr; otherwise download failures are invisible.
        log.warning(
            f"downloader.py exited with {proc.returncode} for {arxiv_id}: "
            f"{(proc.stderr or '')[-500:].strip()}"
        )

    if pdf_path.exists():
        return {"ok": True, "arxiv_id": arxiv_id, "status": "downloaded"}
    return {"ok": False, "arxiv_id": arxiv_id, "status": "failed"}
|
||
|
||
# ─── Routes: Translated PDFs ────────────────────────────
|
||
TRANSLATED_DIR = PAPERS_DIR.joinpath("translated")  # pdf2zh output PDFs, one per arXiv id
|
||
|
||
@app.get("/api/translated/{arxiv_id}")
def check_translation(arxiv_id: str):
    """Report whether a translated PDF for this arXiv id is present in TRANSLATED_DIR."""
    translated_pdf = TRANSLATED_DIR.joinpath(f"{arxiv_id}.pdf")
    return {"arxiv_id": arxiv_id, "exists": translated_pdf.exists()}
|
||
|
||
@app.get("/papers/translated/{arxiv_id}.pdf")
def serve_translated(arxiv_id: str):
    """Serve a cached translated PDF inline (Content-Disposition: inline).

    Raises:
        HTTPException(400): the id would resolve outside TRANSLATED_DIR.
        HTTPException(404): no translation is cached.
    """
    fp = TRANSLATED_DIR / f"{arxiv_id}.pdf"
    # Lexical containment check: on Windows a backslash in the id is a path
    # separator and could otherwise reach PDFs outside TRANSLATED_DIR.
    if fp.parent != TRANSLATED_DIR:
        raise HTTPException(status_code=400, detail=f"Invalid arXiv id: {arxiv_id}")
    if not fp.exists():
        raise HTTPException(status_code=404, detail="Translation not found")
    return FileResponse(fp, media_type="application/pdf",
                        headers={"Content-Disposition": "inline"})
|
||
|
||
# ─── Routes: Trigger translation ───────────────────────
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
import subprocess, threading
|
||
|
||
_translate_lock = threading.Lock()
|
||
_translating = set()
|
||
|
||
# One shared single-worker pool, so translations run one at a time. Previously
# each request built its own executor and its own worker thread, which made
# max_workers=1 meaningless.
_translate_executor = ThreadPoolExecutor(max_workers=1)


@app.post("/api/translate/{arxiv_id}")
async def trigger_translation(arxiv_id: str):
    """Queue a pdf2zh translation (service='deepseek') of a cached arXiv PDF.

    Returns ``{"arxiv_id", "status"}`` with status "already_translated",
    "in_progress" (already queued/running) or "started".

    Raises:
        HTTPException(400): the id would resolve outside the PDF/output dirs.
        HTTPException(404): the source PDF is not cached.
    """
    arxiv_dir = PAPERS_DIR / "arxiv"
    pdf_path = arxiv_dir / f"{arxiv_id}.pdf"
    out_path = TRANSLATED_DIR / f"{arxiv_id}.pdf"
    # Lexical containment check (a backslash is a path separator on Windows).
    if pdf_path.parent != arxiv_dir or out_path.parent != TRANSLATED_DIR:
        raise HTTPException(status_code=400, detail=f"Invalid arXiv id: {arxiv_id}")
    if not pdf_path.exists():
        raise HTTPException(status_code=404, detail="PDF not found")
    if out_path.exists():
        return {"arxiv_id": arxiv_id, "status": "already_translated"}

    # Check and mark under the lock, before submitting. The id used to be added
    # inside the worker thread, so two quick requests could both start a job.
    with _translate_lock:
        if arxiv_id in _translating:
            return {"arxiv_id": arxiv_id, "status": "in_progress"}
        _translating.add(arxiv_id)

    def do_translate():
        try:
            from pdf2zh.doclayout import OnnxModel
            from pdf2zh.high_level import translate

            TRANSLATED_DIR.mkdir(parents=True, exist_ok=True)
            model = OnnxModel.from_pretrained()
            translate(
                [str(pdf_path)], output=str(TRANSLATED_DIR),
                lang_in='en', lang_out='zh',
                service='deepseek', thread=4, model=model,
            )
            # Keep the monolingual output as <id>.pdf and drop the bilingual one.
            mono = TRANSLATED_DIR / f"{arxiv_id}-mono.pdf"
            dual = TRANSLATED_DIR / f"{arxiv_id}-dual.pdf"
            if mono.exists():
                mono.replace(out_path)  # atomic overwrite
                log.info(f"Translated: {arxiv_id}")
            else:
                log.warning(f"pdf2zh produced no mono output for {arxiv_id}")
            if dual.exists():
                dual.unlink()
        except Exception as e:
            # Background job: log with traceback instead of crashing the worker.
            log.error(f"Translation failed for {arxiv_id}: {e}", exc_info=True)
        finally:
            with _translate_lock:
                _translating.discard(arxiv_id)

    _translate_executor.submit(do_translate)
    return {"arxiv_id": arxiv_id, "status": "started"}
|
||
|
||
def translation_status():
    """List the arXiv ids whose pdf2zh translation is queued or running."""
    with _translate_lock:
        return {"translating": list(_translating)}


# GET /api/translate/{arxiv_id} (translate_paper) is registered earlier in this
# module. It matches "/api/translate/status" with arxiv_id="status", so a plain
# @app.get here was never reached; such requests got a 404 "PDF not cached".
# Register the route explicitly and move it to the front of the routing table
# so the literal path wins.
app.add_api_route("/api/translate/status", translation_status, methods=["GET"])
app.router.routes.insert(0, app.router.routes.pop())
|
||
|
||
# ─── Health ─────────────────────────────────────────────
|
||
@app.get("/api/health")
def health():
    """Liveness probe: report service status and API version."""
    return dict(status="ok", version="0.1.0")
|
||
|
||
# ─── Mount static frontend (at /) ──────────────────────
|
||
# Static files mounted after API routes to avoid conflicts
|
||
# The catch-all "/" mount is added last so it cannot shadow any API route.
# It is only mounted when ROOT/static exists and has at least one entry.
static_dir = ROOT / "static"
_has_static_files = static_dir.exists() and next(static_dir.iterdir(), None) is not None
if _has_static_files:
    app.mount("/", StaticFiles(directory=str(static_dir), html=True), name="static")
|
||
|
||
# ─── Main ───────────────────────────────────────────────
|
||
def main():
    """Run the app under uvicorn on 0.0.0.0:8000 with auto-reload enabled."""
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("api.server:app", **server_options)


if __name__ == "__main__":
    main()
|