- FastAPI 后端: REST API + Bearer Token 鉴权 + PDF 代理 - 180 篇论文数据 (data/papers.json): 9 模块、32 子领域 - 前端: 数据驱动、卡片径向渐变光效、PDF 页面内阅读 - 底部状态栏: arXiv/HF 连通性检测 - PDF 加载: arXiv 优先(5s超时) → HK 本地兜底 - Docker 化部署 (Dockerfile + start.sh + nginx.conf) - arXiv + HF 批量下载器 (api/downloader.py)
199 lines
6.5 KiB
Python
199 lines
6.5 KiB
Python
"""
LLM Paper Library — PDF downloader

Downloads paper PDFs from arXiv and HuggingFace into the local cache.
"""
|
|
|
|
import os
|
|
import json
|
|
import time
|
|
import logging
|
|
from pathlib import Path
|
|
|
|
import httpx
|
|
from tqdm import tqdm
|
|
|
|
# Repository root: the parent of the directory that contains this module.
ROOT = Path(__file__).resolve().parent.parent
# Paper metadata that collect_urls() reads.
DATA_FILE = ROOT / "data" / "papers.json"
# Local PDF cache, one sub-directory per source.
ARXIV_DIR = ROOT / "papers" / "arxiv"
HF_DIR = ROOT / "papers" / "hf"
# Download log, kept next to the cached PDFs.
LOG_FILE = ROOT / "papers" / "download.log"

log = logging.getLogger("downloader")

# logging.FileHandler opens its file as soon as it is constructed, so the
# directory has to exist before basicConfig() runs; otherwise importing this
# module on a fresh checkout fails with FileNotFoundError.
LOG_FILE.parent.mkdir(parents=True, exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
    handlers=[
        # Explicit UTF-8: log lines include paper titles and an em dash,
        # which a non-UTF-8 locale default may be unable to encode.
        logging.FileHandler(LOG_FILE, encoding="utf-8"),
        logging.StreamHandler(),
    ],
)
|
|
|
|
|
|
def collect_urls(
    data_file: "str | os.PathLike | None" = None,
) -> tuple[list[tuple[str, str]], list[tuple[str, str]]]:
    """Collect every PDF that needs to be downloaded from the paper metadata.

    The JSON document is expected to be a mapping whose values each carry an
    ``"areas"`` list; every area may hold paper lists under the keys
    ``"mainline"``, ``"branches"`` and ``"forward"``.  A paper contributes an
    arXiv entry when it has a truthy ``"arxiv"`` field and an HF entry when it
    has a truthy ``"pdf"`` field.  Duplicates are dropped; the first
    occurrence wins and first-seen order is preserved.

    Args:
        data_file: Path of the metadata JSON.  Defaults to ``DATA_FILE``.

    Returns:
        ``(arxiv_list, hf_list)`` where ``arxiv_list`` is
        ``[(arxiv_id, title), ...]`` and ``hf_list`` is
        ``[(url, filename), ...]``.  ``filename`` is the last URL path
        segment with ``".pdf"`` removed.
    """
    source = DATA_FILE if data_file is None else data_file
    # JSON is UTF-8; relying on the locale default breaks non-ASCII titles
    # on platforms whose default encoding is something else.
    with open(source, encoding="utf-8") as f:
        data = json.load(f)

    # dicts keep insertion order, so each one serves as both the "seen" set
    # and the ordered result.
    arxiv: dict[str, str] = {}
    hf: dict[str, str] = {}

    for mod in data.values():
        for area in mod.get("areas", []):
            for section in ("mainline", "branches", "forward"):
                for paper in area.get(section, []):
                    arxiv_id = paper.get("arxiv")
                    if arxiv_id and arxiv_id not in arxiv:
                        arxiv[arxiv_id] = paper.get("title", "")

                    pdf_url = paper.get("pdf")
                    if pdf_url and pdf_url not in hf:
                        # Derive the cache filename from the URL.
                        # NOTE(review): two different URLs that share a last
                        # path segment map to the same filename.
                        hf[pdf_url] = pdf_url.split("/")[-1].replace(".pdf", "")

    return list(arxiv.items()), list(hf.items())
|
|
|
|
|
|
def download_arxiv(client: httpx.Client, arxiv_id: str, title: str) -> bool:
    """Download a single arXiv PDF into the local cache.

    Args:
        client: HTTP client used for the request.
        arxiv_id: arXiv identifier; the file is stored as
            ``ARXIV_DIR/<arxiv_id>.pdf``.
        title: Paper title, used only in log messages (first 60 characters).

    Returns:
        True if the PDF is already cached or was downloaded and stored;
        False if the request failed, the response is not a PDF, or the file
        could not be written.  Errors are logged, never raised.
    """
    pdf_path = ARXIV_DIR / f"{arxiv_id}.pdf"
    if pdf_path.exists():
        log.debug(f"Skip (exists): {arxiv_id}")
        return True

    url = f"https://arxiv.org/pdf/{arxiv_id}.pdf"
    # The body is staged here and renamed into place, so an interrupted write
    # can never leave a truncated file that later runs treat as cached.
    tmp_path = pdf_path.with_name(pdf_path.name + ".part")
    try:
        resp = client.get(url, follow_redirects=True, timeout=30)
        resp.raise_for_status()

        # Verify it's actually a PDF (arxiv returns HTML for missing papers).
        content_type = resp.headers.get("content-type", "").lower()
        if "pdf" not in content_type and not resp.content.startswith(b"%PDF"):
            log.warning(f"Not a PDF: {arxiv_id} — {title[:60]}")
            return False

        # Works when called outside run(), and for ids that contain "/".
        pdf_path.parent.mkdir(parents=True, exist_ok=True)
        try:
            tmp_path.write_bytes(resp.content)
            tmp_path.replace(pdf_path)
        finally:
            # No-op after a successful rename; removes a partial file otherwise.
            tmp_path.unlink(missing_ok=True)

        size_kb = len(resp.content) / 1024
        log.info(f"OK: {arxiv_id} ({size_kb:.0f} KB) — {title[:60]}")
        return True
    except httpx.HTTPError as e:
        log.error(f"HTTP error {arxiv_id}: {e}")
        return False
    except Exception as e:
        # Boundary handler: one bad paper must not abort a whole batch.
        log.error(f"Error {arxiv_id}: {e}")
        return False
|
|
|
|
|
|
def download_hf(client: httpx.Client, url: str, filename: str) -> bool:
    """Download a single HuggingFace PDF into the local cache.

    Args:
        client: HTTP client used for the request.
        url: Full URL of the PDF.
        filename: Base name (without ``.pdf``) for the cached file.  ``".."``
            is removed and ``"/"`` is replaced by ``"_"`` before it is used,
            and the file is stored as ``HF_DIR/<sanitised name>.pdf``.

    Returns:
        True if the PDF is already cached or was downloaded and stored;
        False if the request failed, the body does not start with ``%PDF``,
        or the file could not be written.  Errors are logged, never raised.
    """
    safe_name = filename.replace("..", "").replace("/", "_")
    pdf_path = HF_DIR / f"{safe_name}.pdf"
    if pdf_path.exists():
        log.debug(f"Skip (exists): {safe_name}")
        return True

    # The body is staged here and renamed into place, so an interrupted write
    # can never leave a truncated file that later runs treat as cached.
    tmp_path = pdf_path.with_name(pdf_path.name + ".part")
    try:
        resp = client.get(url, follow_redirects=True, timeout=60)
        resp.raise_for_status()

        if not resp.content.startswith(b"%PDF"):
            log.warning(f"Not a PDF: {safe_name}")
            return False

        # Works when called outside run(), which normally creates HF_DIR.
        pdf_path.parent.mkdir(parents=True, exist_ok=True)
        try:
            tmp_path.write_bytes(resp.content)
            tmp_path.replace(pdf_path)
        finally:
            # No-op after a successful rename; removes a partial file otherwise.
            tmp_path.unlink(missing_ok=True)

        size_kb = len(resp.content) / 1024
        log.info(f"OK (HF): {safe_name} ({size_kb:.0f} KB)")
        return True
    except httpx.HTTPError as e:
        log.error(f"HTTP error {safe_name}: {e}")
        return False
    except Exception as e:
        # Boundary handler: one bad file must not abort a whole batch.
        log.error(f"Error {safe_name}: {e}")
        return False
|
|
|
|
|
|
def _hf_cache_path(name: str) -> Path:
    """Return the cache path download_hf() uses for *name* (same sanitising)."""
    safe_name = name.replace("..", "").replace("/", "_")
    return HF_DIR / f"{safe_name}.pdf"


def _fetch_fresh(path: Path, downloader, client, *args) -> bool:
    """Force *downloader* to hit the network even when *path* is cached.

    download_arxiv() and download_hf() return early when their target file
    exists, so the cached copy is moved aside first.  It is restored if the
    download does not succeed, so a failed refresh never loses a file.
    """
    if not path.exists():
        return downloader(client, *args)

    backup = path.with_name(path.name + ".bak")
    path.replace(backup)
    success = False
    try:
        success = downloader(client, *args)
    finally:
        if success and path.exists():
            backup.unlink()
        else:
            backup.replace(path)
    return success


def run(incremental: bool = True, limit: int = 0, delay: float = 1.0) -> None:
    """Download every PDF listed in the paper metadata.

    Args:
        incremental: True skips files that are already cached; False
            re-downloads them (the old copy is kept if the refresh fails).
        limit: 0 processes everything; N > 0 processes only the first N
            entries of each source list (arXiv and HF are capped separately).
        delay: Seconds to sleep after each download attempt.  Negative
            values are treated as 0.
    """
    ARXIV_DIR.mkdir(parents=True, exist_ok=True)
    HF_DIR.mkdir(parents=True, exist_ok=True)

    arxiv_list, hf_list = collect_urls()
    total = len(arxiv_list) + len(hf_list)
    log.info(f"Found {len(arxiv_list)} arXiv + {len(hf_list)} HF = {total} PDFs to download")
    log.info(f"Incremental: {incremental}, Delay: {delay}s")

    if not incremental:
        log.warning("Non-incremental mode: will re-download existing files")

    # Count existing (over the full lists, before any limit is applied).
    arxiv_existing = sum(1 for aid, _ in arxiv_list if (ARXIV_DIR / f"{aid}.pdf").exists())
    hf_existing = sum(1 for _, name in hf_list if _hf_cache_path(name).exists())
    log.info(f"Already cached: {arxiv_existing} arXiv + {hf_existing} HF")

    if limit > 0:
        arxiv_list = arxiv_list[:limit]
        hf_list = hf_list[:limit]

    # One entry per source: (progress label, [(cache path, downloader, args)]).
    batches = [
        (
            "arXiv",
            [
                (ARXIV_DIR / f"{arxiv_id}.pdf", download_arxiv, (arxiv_id, title))
                for arxiv_id, title in arxiv_list
            ],
        ),
        (
            "HF ",
            [(_hf_cache_path(name), download_hf, (url, name)) for url, name in hf_list],
        ),
    ]

    pause = max(delay, 0.0)  # time.sleep() rejects negative values
    ok, fail = 0, 0
    total_size = 0.0  # bytes fetched during this run; cached files are not counted

    with httpx.Client(
        headers={"User-Agent": "LLM-Library-Downloader/0.1"},
        timeout=30,
        follow_redirects=True,
    ) as client:
        for desc, jobs in batches:
            for path, downloader, args in tqdm(jobs, desc=desc):
                if incremental and path.exists():
                    ok += 1
                    continue

                if incremental:
                    success = downloader(client, *args)
                else:
                    success = _fetch_fresh(path, downloader, client, *args)

                if success:
                    ok += 1
                    if path.exists():
                        total_size += path.stat().st_size
                else:
                    fail += 1
                # Rate-limit after every attempt, successful or not.
                time.sleep(pause)

    log.info(f"Done: {ok} OK, {fail} failed, {total_size/1024/1024:.1f} MB total")
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Command-line entry point: map the flags onto run()'s keyword arguments.
    cli = argparse.ArgumentParser(description="下载论文 PDF 到本地缓存")
    cli.add_argument(
        "--no-incremental",
        action="store_true",
        help="重新下载所有 (默认跳过已有)",
    )
    cli.add_argument(
        "--limit",
        type=int,
        default=0,
        help="限制下载数量 (0=全部)",
    )
    cli.add_argument(
        "--delay",
        type=float,
        default=1.0,
        help="请求间延迟 (秒)",
    )
    opts = cli.parse_args()

    # The flag is negative ("--no-incremental"), run() takes the positive form.
    run(
        incremental=not opts.no_incremental,
        limit=opts.limit,
        delay=opts.delay,
    )
|