feat: LLM 论文图书馆 — 初始提交
- FastAPI 后端: REST API + Bearer Token 鉴权 + PDF 代理 - 180 篇论文数据 (data/papers.json): 9 模块、32 子领域 - 前端: 数据驱动、卡片径向渐变光效、PDF 页面内阅读 - 底部状态栏: arXiv/HF 连通性检测 - PDF 加载: arXiv 优先(5s超时) → HK 本地兜底 - Docker 化部署 (Dockerfile + start.sh + nginx.conf) - arXiv + HF 批量下载器 (api/downloader.py)
This commit is contained in:
484
api/server.py
Normal file
484
api/server.py
Normal file
@@ -0,0 +1,484 @@
|
||||
"""
|
||||
LLM 论文图书馆 — FastAPI 后端
|
||||
提供 REST API 进行论文查询、管理、PDF 代理服务
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import hashlib
|
||||
import secrets
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Query, Depends, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from pydantic import BaseModel
|
||||
|
||||
# ─── Config ────────────────────────────────────────────
|
||||
ROOT = Path(__file__).resolve().parent.parent
|
||||
DATA_FILE = ROOT / "data" / "papers.json"
|
||||
PAPERS_DIR = ROOT / "papers"
|
||||
API_KEY = os.environ.get("LLM_LIB_API_KEY", "change-me")
|
||||
|
||||
log = logging.getLogger("llm-library")
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
|
||||
# ─── App ───────────────────────────────────────────────
|
||||
app = FastAPI(
|
||||
title="LLM 论文图书馆",
|
||||
description="大模型论文知识库 API — 查询、搜索、管理论文",
|
||||
version="0.1.0",
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# ─── Auth ──────────────────────────────────────────────
|
||||
def verify_api_key(request: Request):
|
||||
"""简单的 API Key 鉴权 — 用于写操作 (POST/PUT/DELETE)"""
|
||||
auth = request.headers.get("Authorization", "")
|
||||
if auth.startswith("Bearer "):
|
||||
token = auth[7:]
|
||||
else:
|
||||
token = request.query_params.get("api_key", "")
|
||||
if not token or token != API_KEY:
|
||||
raise HTTPException(status_code=401, detail="Invalid or missing API key")
|
||||
return True
|
||||
|
||||
# ─── Data loading ──────────────────────────────────────
|
||||
def load_data():
|
||||
if not DATA_FILE.exists():
|
||||
return {}
|
||||
with open(DATA_FILE, 'r') as f:
|
||||
return json.load(f)
|
||||
|
||||
def save_data(data):
|
||||
with open(DATA_FILE, 'w') as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# ─── Paper CRUD helpers ────────────────────────────────
|
||||
def find_paper(data, module_id, area_id, title):
|
||||
"""Find a paper index by title within module/area"""
|
||||
mod = data.get(module_id)
|
||||
if not mod:
|
||||
return None, None, None, None
|
||||
for area in mod.get("areas", []):
|
||||
if area["id"] == area_id:
|
||||
for section in ("mainline", "branches", "forward"):
|
||||
for i, p in enumerate(area.get(section, [])):
|
||||
if p["title"] == title:
|
||||
return mod, area, section, i
|
||||
return None, None, None, None
|
||||
|
||||
# ─── Routes: Query ─────────────────────────────────────
|
||||
@app.get("/api/stats")
|
||||
def get_stats():
|
||||
"""获取图书馆统计信息"""
|
||||
data = load_data()
|
||||
mods = len(data)
|
||||
areas = 0
|
||||
papers = 0
|
||||
sections = {"mainline": 0, "branches": 0, "forward": 0}
|
||||
for mod in data.values():
|
||||
areas += len(mod.get("areas", []))
|
||||
for area in mod.get("areas", []):
|
||||
for s in ("mainline", "branches", "forward"):
|
||||
n = len(area.get(s, []))
|
||||
papers += n
|
||||
sections[s] += n
|
||||
return {
|
||||
"modules": mods,
|
||||
"areas": areas,
|
||||
"papers": papers,
|
||||
"sections": sections,
|
||||
"data_file": str(DATA_FILE),
|
||||
}
|
||||
|
||||
@app.get("/api/modules")
|
||||
def list_modules():
|
||||
"""列出所有模块 (不含论文详情)"""
|
||||
data = load_data()
|
||||
return [
|
||||
{
|
||||
"id": mid,
|
||||
"name": m["name"],
|
||||
"icon": m["icon"],
|
||||
"desc": m["desc"],
|
||||
"area_count": len(m.get("areas", [])),
|
||||
"paper_count": sum(
|
||||
len(a.get("mainline", [])) + len(a.get("branches", [])) + len(a.get("forward", []))
|
||||
for a in m.get("areas", [])
|
||||
),
|
||||
}
|
||||
for mid, m in data.items()
|
||||
]
|
||||
|
||||
@app.get("/api/modules/{module_id}")
|
||||
def get_module(module_id: str):
|
||||
"""获取单个模块的完整论文数据"""
|
||||
data = load_data()
|
||||
mod = data.get(module_id)
|
||||
if not mod:
|
||||
raise HTTPException(status_code=404, detail=f"Module '{module_id}' not found")
|
||||
return mod
|
||||
|
||||
@app.get("/api/papers")
|
||||
def search_papers(
|
||||
q: str = Query(default="", description="搜索关键词: 标题/作者"),
|
||||
module: Optional[str] = Query(default=None),
|
||||
tag: Optional[str] = Query(default=None, description="起点/关键节点/前沿/前瞻/支线"),
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
):
|
||||
"""搜索论文 (全文/按模块/按标签)"""
|
||||
data = load_data()
|
||||
results = []
|
||||
q = q.lower()
|
||||
for mid, mod in data.items():
|
||||
if module and mid != module:
|
||||
continue
|
||||
for area in mod.get("areas", []):
|
||||
for section in ("mainline", "branches", "forward"):
|
||||
for p in area.get(section, []):
|
||||
# Filter by tag
|
||||
if tag and tag not in p.get("tags", []):
|
||||
continue
|
||||
# Filter by query
|
||||
if q:
|
||||
if q not in (p.get("title", "") + p.get("authors", "")).lower():
|
||||
continue
|
||||
results.append({
|
||||
"module_id": mid,
|
||||
"module_name": mod["name"],
|
||||
"area_id": area["id"],
|
||||
"area_name": area["name"],
|
||||
"section": section,
|
||||
**p,
|
||||
})
|
||||
if len(results) >= limit:
|
||||
break
|
||||
if len(results) >= limit:
|
||||
break
|
||||
if len(results) >= limit:
|
||||
break
|
||||
if len(results) >= limit:
|
||||
break
|
||||
return results
|
||||
|
||||
# ─── Routes: Management (写操作, 需 API Key) ────────────
|
||||
class PaperCreate(BaseModel):
|
||||
module_id: str
|
||||
area_id: str
|
||||
section: str = "mainline" # mainline / branches / forward
|
||||
title: str
|
||||
authors: str = ""
|
||||
year: int
|
||||
venue: str = ""
|
||||
arxiv: Optional[str] = None
|
||||
pdf: Optional[str] = None
|
||||
tags: list[str] = []
|
||||
|
||||
class PaperUpdate(BaseModel):
|
||||
authors: Optional[str] = None
|
||||
year: Optional[int] = None
|
||||
venue: Optional[str] = None
|
||||
arxiv: Optional[str] = None
|
||||
pdf: Optional[str] = None
|
||||
tags: Optional[list[str]] = None
|
||||
section: Optional[str] = None # move to different section
|
||||
|
||||
@app.post("/api/papers", dependencies=[Depends(verify_api_key)])
|
||||
def add_paper(paper: PaperCreate):
|
||||
"""添加一篇新论文"""
|
||||
data = load_data()
|
||||
mod = data.get(paper.module_id)
|
||||
if not mod:
|
||||
raise HTTPException(status_code=404, detail="Module not found")
|
||||
|
||||
area = next((a for a in mod["areas"] if a["id"] == paper.area_id), None)
|
||||
if not area:
|
||||
raise HTTPException(status_code=404, detail="Area not found")
|
||||
|
||||
section = paper.section
|
||||
if section not in ("mainline", "branches", "forward"):
|
||||
raise HTTPException(status_code=400, detail="section must be mainline/branches/forward")
|
||||
|
||||
entry = {
|
||||
"title": paper.title,
|
||||
"authors": paper.authors,
|
||||
"year": paper.year,
|
||||
"venue": paper.venue,
|
||||
"tags": paper.tags,
|
||||
}
|
||||
if paper.arxiv:
|
||||
entry["arxiv"] = paper.arxiv
|
||||
if paper.pdf:
|
||||
entry["pdf"] = paper.pdf
|
||||
|
||||
area.setdefault(section, []).append(entry)
|
||||
save_data(data)
|
||||
log.info(f"Added paper: {paper.title}")
|
||||
return {"ok": True, "title": paper.title}
|
||||
|
||||
@app.put("/api/papers")
|
||||
def update_paper(
|
||||
module_id: str,
|
||||
area_id: str,
|
||||
title: str,
|
||||
update: PaperUpdate,
|
||||
_=Depends(verify_api_key),
|
||||
):
|
||||
"""更新一篇论文"""
|
||||
data = load_data()
|
||||
mod, area, section, idx = find_paper(data, module_id, area_id, title)
|
||||
if mod is None:
|
||||
raise HTTPException(status_code=404, detail="Paper not found")
|
||||
|
||||
paper = area[section][idx]
|
||||
for field in ("authors", "year", "venue", "arxiv", "pdf", "tags"):
|
||||
val = getattr(update, field)
|
||||
if val is not None:
|
||||
paper[field] = val
|
||||
|
||||
# Move to different section?
|
||||
if update.section and update.section != section:
|
||||
if update.section not in ("mainline", "branches", "forward"):
|
||||
raise HTTPException(status_code=400, detail="Invalid section")
|
||||
area[section].pop(idx)
|
||||
area.setdefault(update.section, []).append(paper)
|
||||
|
||||
save_data(data)
|
||||
log.info(f"Updated paper: {title}")
|
||||
return {"ok": True, "title": title}
|
||||
|
||||
@app.delete("/api/papers")
|
||||
def delete_paper(
|
||||
module_id: str,
|
||||
area_id: str,
|
||||
title: str,
|
||||
_=Depends(verify_api_key),
|
||||
):
|
||||
"""删除一篇论文"""
|
||||
data = load_data()
|
||||
mod, area, section, idx = find_paper(data, module_id, area_id, title)
|
||||
if mod is None:
|
||||
raise HTTPException(status_code=404, detail="Paper not found")
|
||||
|
||||
area[section].pop(idx)
|
||||
save_data(data)
|
||||
log.info(f"Deleted paper: {title}")
|
||||
return {"ok": True, "title": title}
|
||||
|
||||
# ─── Routes: PDF proxy ──────────────────────────────────
|
||||
@app.get("/papers/arxiv/{arxiv_id}.pdf")
|
||||
@app.get("/papers/arxiv/{arxiv_id}")
|
||||
def serve_arxiv_pdf(arxiv_id: str):
|
||||
"""从本地缓存提供 arXiv PDF(无 .pdf 后缀路由防 IDM 拦截)"""
|
||||
pdf_path = PAPERS_DIR / "arxiv" / f"{arxiv_id}.pdf"
|
||||
if not pdf_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"PDF not in local cache: {arxiv_id}")
|
||||
return FileResponse(
|
||||
pdf_path, media_type="application/pdf",
|
||||
headers={"Cache-Control": "public, max-age=86400"},
|
||||
)
|
||||
|
||||
@app.get("/papers/hf/{filename}.pdf")
|
||||
@app.get("/papers/hf/{filename}")
|
||||
def serve_hf_pdf(filename: str):
|
||||
"""从本地缓存提供 HuggingFace PDF(无 .pdf 后缀路由防 IDM 拦截)"""
|
||||
safe_name = filename.replace("..", "").replace("/", "_").removesuffix(".pdf")
|
||||
pdf_path = PAPERS_DIR / "hf" / f"{safe_name}.pdf"
|
||||
if not pdf_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"PDF not in local cache: {filename}")
|
||||
return FileResponse(
|
||||
pdf_path, media_type="application/pdf",
|
||||
headers={"Cache-Control": "public, max-age=86400"},
|
||||
)
|
||||
|
||||
# ─── Routes: Translation ───────────────────────────────
|
||||
TRANSLATE_CACHE = ROOT / "data" / "translations"
|
||||
TRANSLATE_CACHE.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def extract_pdf_text_with_pages(pdf_path: Path, max_chars: int = 12000) -> list[dict]:
|
||||
"""从 PDF 提取文本和页码信息,使用 pdftotext (Poppler) 避免 PyMuPDF GPU 依赖"""
|
||||
import subprocess, tempfile
|
||||
|
||||
# Strip arXiv stamp (first page header)
|
||||
stamp = f"{pdf_path.stem}.pdf" # e.g. "1706.03762.pdf"
|
||||
|
||||
result = subprocess.run(
|
||||
["pdftotext", "-layout", "-q", str(pdf_path), "-"],
|
||||
capture_output=True, text=True, timeout=30
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
log.error(f"pdftotext failed: {result.stderr}")
|
||||
raise HTTPException(status_code=500, detail="PDF text extraction failed")
|
||||
|
||||
text = result.stdout
|
||||
|
||||
# Remove arXiv stamp line
|
||||
import re
|
||||
text = re.sub(r'arXiv:' + re.escape(stamp.split('.pdf')[0]) + r'.*?\n\n', '', text, flags=re.DOTALL)
|
||||
text = re.sub(r'arXiv:' + re.escape(stamp) + r'.*?\n\n', '', text, flags=re.DOTALL)
|
||||
|
||||
# Split by form-feed (page break)
|
||||
pages = text.split('\f')
|
||||
|
||||
result_pages = []
|
||||
total = 0
|
||||
for i, page_text in enumerate(pages):
|
||||
pt = page_text.strip()
|
||||
if not pt: continue
|
||||
result_pages.append({"page": i + 1, "text": pt})
|
||||
total += len(pt)
|
||||
if total >= max_chars: break
|
||||
|
||||
if not result_pages:
|
||||
raise HTTPException(status_code=500, detail="No text extracted from PDF")
|
||||
|
||||
return result_pages
|
||||
|
||||
|
||||
def split_text_with_pages(page_texts: list[dict], max_len: int = 400) -> list[dict]:
|
||||
"""将按页拆分的文本进一步拆为段落,保留页码"""
|
||||
chunks = []
|
||||
for pt in page_texts:
|
||||
page = pt["page"]
|
||||
text = pt["text"]
|
||||
raw_paras = [p.strip() for p in text.split("\n\n") if p.strip()]
|
||||
for para in raw_paras:
|
||||
if len(para) <= max_len:
|
||||
chunks.append({"page": page, "text": para})
|
||||
else:
|
||||
sentences = para.replace(". ", ".|").replace("? ", "?|").replace("! ", "!|").split("|")
|
||||
current = ""
|
||||
for s in sentences:
|
||||
s = s.strip()
|
||||
if not s: continue
|
||||
if len(current) + len(s) + 1 <= max_len:
|
||||
current = (current + " " + s).strip()
|
||||
else:
|
||||
if current: chunks.append({"page": page, "text": current})
|
||||
current = s
|
||||
if current: chunks.append({"page": page, "text": current})
|
||||
return chunks
|
||||
|
||||
|
||||
def translate_text(text: str, source: str = "en", target: str = "zh") -> str:
|
||||
"""使用 MyMemory 免费 API 翻译文本"""
|
||||
import urllib.request
|
||||
import urllib.parse
|
||||
|
||||
url = "https://api.mymemory.translated.net/get"
|
||||
params = urllib.parse.urlencode({
|
||||
"q": text,
|
||||
"langpair": f"{source}|{target}",
|
||||
"mt": "1", # Force machine translation, not memory
|
||||
"de": "me@llm-library.local",
|
||||
})
|
||||
full_url = f"{url}?{params}"
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(full_url, timeout=15) as resp:
|
||||
data = json.loads(resp.read())
|
||||
except Exception as e:
|
||||
log.warning(f"Translation API error: {e}")
|
||||
return text
|
||||
|
||||
if data.get("responseStatus") == 200 and data.get("responseData"):
|
||||
return data["responseData"]["translatedText"]
|
||||
return text
|
||||
|
||||
|
||||
@app.get("/api/translate/{arxiv_id}")
|
||||
def translate_paper(arxiv_id: str):
|
||||
"""翻译论文正文 (从本地 PDF 提取文本,每段带页码)"""
|
||||
pdf_path = PAPERS_DIR / "arxiv" / f"{arxiv_id}.pdf"
|
||||
if not pdf_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"PDF not cached: {arxiv_id}")
|
||||
|
||||
cache_file = TRANSLATE_CACHE / f"{arxiv_id}.json"
|
||||
if cache_file.exists():
|
||||
with open(cache_file) as f:
|
||||
return json.load(f)
|
||||
|
||||
# Extract text with page numbers
|
||||
log.info(f"Extracting text from {arxiv_id}")
|
||||
page_texts = extract_pdf_text_with_pages(pdf_path)
|
||||
chunks = split_text_with_pages(page_texts)
|
||||
log.info(f"Translating {len(chunks)} paragraphs for {arxiv_id}")
|
||||
|
||||
translated = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
if i % 10 == 0:
|
||||
log.info(f" [{arxiv_id}] translating paragraph {i+1}/{len(chunks)}")
|
||||
zh = translate_text(chunk["text"])
|
||||
translated.append({
|
||||
"page": chunk["page"],
|
||||
"en": chunk["text"],
|
||||
"zh": zh,
|
||||
})
|
||||
|
||||
result = {"arxiv_id": arxiv_id, "paragraphs": translated, "count": len(translated)}
|
||||
|
||||
with open(cache_file, "w") as f:
|
||||
json.dump(result, f, ensure_ascii=False)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@app.get("/api/translate/{arxiv_id}/status")
|
||||
def translate_status(arxiv_id: str):
|
||||
"""检查翻译缓存状态"""
|
||||
cache_file = TRANSLATE_CACHE / f"{arxiv_id}.json"
|
||||
return {
|
||||
"arxiv_id": arxiv_id,
|
||||
"cached": cache_file.exists(),
|
||||
"pdf_exists": (PAPERS_DIR / "arxiv" / f"{arxiv_id}.pdf").exists(),
|
||||
}
|
||||
|
||||
# ─── Routes: PDF download on-demand ────────────────────
|
||||
@app.post("/api/download/{arxiv_id}")
|
||||
def download_single_pdf(arxiv_id: str):
|
||||
"""按需下载单篇 arXiv PDF"""
|
||||
import subprocess, sys
|
||||
pdf_path = PAPERS_DIR / "arxiv" / f"{arxiv_id}.pdf"
|
||||
if pdf_path.exists():
|
||||
return {"ok": True, "arxiv_id": arxiv_id, "status": "cached"}
|
||||
|
||||
cmd = [sys.executable, str(ROOT / "api" / "downloader.py"), "--limit", "1", "--delay", "0"]
|
||||
# We need a way to download specific arxiv IDs — for now, just run the downloader
|
||||
# It will try all uncached papers, but the specific one will be among them
|
||||
try:
|
||||
subprocess.run(cmd, cwd=str(ROOT), timeout=60, capture_output=True)
|
||||
if pdf_path.exists():
|
||||
return {"ok": True, "arxiv_id": arxiv_id, "status": "downloaded"}
|
||||
return {"ok": False, "arxiv_id": arxiv_id, "status": "failed"}
|
||||
except subprocess.TimeoutExpired:
|
||||
return {"ok": False, "arxiv_id": arxiv_id, "status": "timeout"}
|
||||
|
||||
# ─── Health ─────────────────────────────────────────────
|
||||
@app.get("/api/health")
|
||||
def health():
|
||||
return {"status": "ok", "version": "0.1.0"}
|
||||
|
||||
# ─── Mount static frontend (at /) ──────────────────────
|
||||
# Static files mounted after API routes to avoid conflicts
|
||||
static_dir = ROOT / "static"
|
||||
if static_dir.exists() and any(static_dir.iterdir()):
|
||||
app.mount("/", StaticFiles(directory=str(static_dir), html=True), name="static")
|
||||
|
||||
# ─── Main ───────────────────────────────────────────────
|
||||
def main():
|
||||
import uvicorn
|
||||
uvicorn.run("api.server:app", host="0.0.0.0", port=8000, reload=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user