247 lines
10 KiB
Python
247 lines
10 KiB
Python
"""Download task routes."""
|
|
import uuid
|
|
import os
|
|
import re
|
|
import logging
|
|
from datetime import datetime
|
|
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks, Request
|
|
from fastapi.responses import FileResponse, StreamingResponse
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select
|
|
from app.schemas import DownloadRequest, DownloadResponse, TaskStatus
|
|
from app.database import get_db, async_session
|
|
from app.models import Video, DownloadLog
|
|
from app.auth import get_current_user, optional_auth
|
|
from app.services.downloader import (
|
|
download_video, get_video_path, detect_platform,
|
|
register_task, get_progress, request_cancel, cleanup_task,
|
|
)
|
|
|
|
# Every endpoint in this module is registered on this router and therefore
# served under the "/api" prefix, grouped under the "download" OpenAPI tag.
router = APIRouter(prefix="/api", tags=["download"])
# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ── UA parsing ──────────────────────────────────────────────────────────────
|
|
|
|
def _parse_ua(ua: str) -> tuple[str, str]:
|
|
"""Return (browser, device) from User-Agent string."""
|
|
ua_lower = ua.lower()
|
|
|
|
# Device
|
|
if any(k in ua_lower for k in ("bot", "crawler", "spider", "slurp", "curl", "wget", "python", "axios")):
|
|
device = "bot"
|
|
elif "tablet" in ua_lower or "ipad" in ua_lower:
|
|
device = "tablet"
|
|
elif any(k in ua_lower for k in ("mobile", "android", "iphone", "ipod", "windows phone")):
|
|
device = "mobile"
|
|
else:
|
|
device = "desktop"
|
|
|
|
# Browser
|
|
if "edg/" in ua_lower or "edghtml" in ua_lower:
|
|
browser = "Edge"
|
|
elif "opr/" in ua_lower or "opera" in ua_lower:
|
|
browser = "Opera"
|
|
elif "samsungbrowser" in ua_lower:
|
|
browser = "Samsung"
|
|
elif "chrome/" in ua_lower:
|
|
browser = "Chrome"
|
|
elif "firefox/" in ua_lower:
|
|
browser = "Firefox"
|
|
elif "safari/" in ua_lower:
|
|
browser = "Safari"
|
|
else:
|
|
m = re.search(r"(\w+)/[\d.]+$", ua)
|
|
browser = m.group(1).capitalize() if m else "Unknown"
|
|
|
|
return browser, device
|
|
|
|
|
|
def _client_ip(request: Request) -> str:
|
|
forwarded = request.headers.get("x-forwarded-for")
|
|
if forwarded:
|
|
return forwarded.split(",")[0].strip()
|
|
if request.client:
|
|
return request.client.host
|
|
return ""
|
|
|
|
|
|
async def _geo_lookup(ip: str) -> tuple[str, str, str]:
|
|
"""Return (country_code, country, city) via ip-api.com. Falls back to empty strings."""
|
|
if not ip or ip in ("127.0.0.1", "::1"):
|
|
return "", "", ""
|
|
try:
|
|
import httpx
|
|
async with httpx.AsyncClient(timeout=5) as client:
|
|
res = await client.get(
|
|
f"http://ip-api.com/json/{ip}",
|
|
params={"fields": "status,countryCode,country,city"},
|
|
)
|
|
data = res.json()
|
|
if data.get("status") == "success":
|
|
return data.get("countryCode", ""), data.get("country", ""), data.get("city", "")
|
|
except Exception as e:
|
|
logger.debug(f"Geo lookup failed for {ip}: {e}")
|
|
return "", "", ""
|
|
|
|
|
|
async def _log_download(video_id: int, request: Request):
    """Record one ``DownloadLog`` row for a served file (best effort).

    Scheduled as a background task by the file/stream endpoints. Derives
    browser and device from the User-Agent, resolves the client IP to a
    location, and inserts the row in a fresh session. Any error is logged at
    WARNING level and swallowed, so a logging failure never propagates out of
    the background task.

    Args:
        video_id: Primary key of the ``Video`` that was served.
        request: The originating request; only its headers / client address
            are read.
    """
    try:
        raw_ua = request.headers.get("user-agent", "")
        ip = _client_ip(request)
        browser, device = _parse_ua(raw_ua)
        country_code, country, city = await _geo_lookup(ip)

        async with async_session() as db:
            entry = DownloadLog(
                video_id=video_id,
                ip=ip,
                user_agent=raw_ua[:512],  # cap the stored UA length
                browser=browser,
                device=device,
                country_code=country_code,
                country=country,
                city=city,
                downloaded_at=datetime.utcnow(),
            )
            db.add(entry)
            await db.commit()
    except Exception as e:
        logger.warning(f"Failed to log download: {e}")
|
|
|
|
|
|
async def _do_download(task_id: str, url: str, format_id: str):
    """Execute a queued download and persist its outcome on the ``Video`` row.

    Scheduled through ``BackgroundTasks`` by ``start_download``. The steps are:

    1. Look up the row by ``task_id``; return silently if it is gone.
    2. Mark the row ``downloading`` and run the downloader.
    3. On success, store the resulting metadata with status ``done`` and
       progress 100.
    4. On any exception, set the row to ``error`` with progress 0. The
       message is a fixed retry hint when the failure looks like a
       cancellation, otherwise the truncated exception text.

    The downloader's task registration is always cleaned up.

    Args:
        task_id: Task id of the ``Video`` row created by ``start_download``.
        url: Source URL handed to ``download_video``.
        format_id: Format selector handed to ``download_video``.
    """
    import asyncio

    async with async_session() as db:
        video = (await db.execute(select(Video).where(Video.task_id == task_id))).scalar_one_or_none()
        if not video:
            return

        try:
            # Register before the row becomes "downloading". cancel_download only
            # accepts tasks in that state, so a cancel that arrives right after
            # the commit below always targets an already-registered task.
            register_task(task_id)
            video.status = "downloading"
            await db.commit()

            # download_video is a blocking call: its return value is used
            # directly, not awaited. Called inline it would freeze the event
            # loop for the whole download, stalling status polling and cancel
            # requests. Run it in a worker thread instead.
            # NOTE(review): confirm the downloader's progress/cancel registry is
            # safe to read from the event loop while the worker thread updates it.
            result = await asyncio.to_thread(download_video, url, format_id, task_id=task_id)

            video.title = result["title"]
            video.thumbnail = result["thumbnail"]
            video.duration = result["duration"]
            video.filename = result["filename"]
            video.file_path = result["file_path"]
            video.file_size = result["file_size"]
            video.platform = result["platform"]
            video.status = "done"
            video.progress = 100
            await db.commit()
        except Exception as e:
            logger.error(f"Download failed for {task_id}: {e}")
            # Cancellation surfaces as an exception from the downloader; detect
            # it by message text or exception class name.
            is_cancel = "Cancelled" in str(e) or "DownloadCancelled" in type(e).__name__
            video.status = "error"
            # User-facing text: "Download cancelled, please retry".
            video.error_message = "下载已取消,请重试" if is_cancel else str(e)[:500]
            video.progress = 0
            await db.commit()
        finally:
            cleanup_task(task_id)
|
|
|
|
|
|
@router.post("/download", response_model=DownloadResponse)
async def start_download(req: DownloadRequest, background_tasks: BackgroundTasks, db: AsyncSession = Depends(get_db)):
    # Queue a download for (req.url, req.format_id) and return its task id.
    # A finished download of the same URL+format whose file is still on disk
    # is reused (status "done"). Otherwise a "pending" Video row is inserted
    # and _do_download is scheduled as a background task.
    newest_done = (
        select(Video)
        .where(
            Video.url == req.url,
            Video.format_id == req.format_id,
            Video.status == "done",
        )
        .order_by(Video.created_at.desc())
        .limit(1)
    )
    reusable = (await db.execute(newest_done)).scalar_one_or_none()
    if reusable and os.path.exists(reusable.file_path):
        logger.info(f"Reusing existing download task_id={reusable.task_id} for url={req.url} format={req.format_id}")
        return DownloadResponse(task_id=reusable.task_id, status="done")

    # NOTE(review): 8 hex characters = 32 random bits. This id also acts as the
    # only credential for the public /file/task/{task_id} endpoint; consider a
    # longer id if the task_id column allows it.
    new_task_id = uuid.uuid4().hex[:8]
    db.add(
        Video(
            task_id=new_task_id,
            url=req.url,
            quality=req.quality,
            format_id=req.format_id,
            status="pending",
            platform=detect_platform(req.url),
        )
    )
    await db.commit()

    background_tasks.add_task(_do_download, new_task_id, req.url, req.format_id)
    return DownloadResponse(task_id=new_task_id, status="pending")
|
|
|
|
|
|
@router.get("/download/{task_id}", response_model=TaskStatus)
async def get_download_status(task_id: str, db: AsyncSession = Depends(get_db)):
    # Report status/progress for a task; 404 if the task id is unknown.
    row = (await db.execute(select(Video).where(Video.task_id == task_id))).scalar_one_or_none()
    if not row:
        raise HTTPException(status_code=404, detail="Task not found")

    is_active = row.status == "downloading"
    is_done = row.status == "done"
    return TaskStatus(
        task_id=row.task_id,
        status=row.status,
        # While downloading, ask the downloader's tracker for live progress;
        # otherwise the persisted value is authoritative.
        progress=get_progress(task_id) if is_active else row.progress,
        title=row.title,
        error_message=row.error_message or "",
        # video_id is only exposed once the file is ready to fetch.
        video_id=row.id if is_done else None,
    )
|
|
|
|
|
|
@router.post("/download/{task_id}/cancel")
async def cancel_download(task_id: str, db: AsyncSession = Depends(get_db)):
    # Signal an in-flight download to stop: 404 for unknown tasks, 400 unless
    # the task is currently "downloading". This only raises the cancel flag;
    # the background task presumably observes it and records the final state.
    target = (await db.execute(select(Video).where(Video.task_id == task_id))).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="Task not found")
    if target.status == "downloading":
        request_cancel(task_id)
        return {"ok": True, "message": "Cancel requested"}
    raise HTTPException(status_code=400, detail="Task is not downloading")
|
|
|
|
|
|
@router.get("/file/{video_id}")
async def download_file(video_id: int, request: Request, background_tasks: BackgroundTasks, user: dict = Depends(get_current_user), db: AsyncSession = Depends(get_db)):
    # Serve a finished video as a file download, by Video primary key.
    # `user` is unused in the body; the get_current_user dependency presumably
    # rejects unauthenticated callers before this runs.
    stmt = select(Video).where(Video.id == video_id)
    video = (await db.execute(stmt)).scalar_one_or_none()
    if not (video and video.status == "done"):
        raise HTTPException(status_code=404, detail="Video not found")

    path = video.file_path
    if not os.path.exists(path):
        raise HTTPException(status_code=404, detail="File not found on disk")

    # Analytics row is written after the response is sent.
    background_tasks.add_task(_log_download, video.id, request)
    return FileResponse(path, filename=video.filename, media_type="video/mp4")
|
|
|
|
|
|
@router.get("/stream/{video_id}")
async def stream_video(video_id: int, request: Request, background_tasks: BackgroundTasks, token: str = None, user: dict = Depends(optional_auth), db: AsyncSession = Depends(get_db)):
    # Stream a finished video inline (for an in-page video player).
    #
    # Auth: the user resolved by optional_auth, or else a `token` query
    # parameter checked with verify_token (a <video src=...> URL cannot carry
    # an Authorization header). Responds 401 without a valid user, and 404
    # when the video is missing or unfinished or its file is gone.
    #
    # NOTE(review): HTTP Range requests are not supported, so seeking may not
    # work in some players; consider a range-aware response.
    from urllib.parse import quote

    if not user and token:
        from app.auth import verify_token

        user = verify_token(token)
    if not user:
        # Use the module-level HTTPException. Importing it inside this function
        # would make the name function-local, and the 404 raises below would
        # then fail with UnboundLocalError whenever the user was already set.
        raise HTTPException(status_code=401, detail="Authentication required")

    video = (await db.execute(select(Video).where(Video.id == video_id))).scalar_one_or_none()
    if not video or video.status != "done":
        raise HTTPException(status_code=404, detail="Video not found")

    file_path = video.file_path
    try:
        # Content-Length must equal the bytes actually sent, and the stored
        # file_size may be stale or NULL, so measure the file itself.
        file_size = os.path.getsize(file_path)
    except OSError:
        raise HTTPException(status_code=404, detail="File not found on disk") from None

    background_tasks.add_task(_log_download, video.id, request)

    # Header values are encoded as latin-1, so a raw filename containing
    # characters outside Latin-1 (e.g. a CJK title) would raise
    # UnicodeEncodeError. Such names use the RFC 6266 / RFC 5987
    # `filename*=utf-8''...` form. Percent-encoding also neutralises quotes,
    # spaces and control characters.
    filename = video.filename or os.path.basename(file_path)
    quoted_name = quote(filename)
    if quoted_name == filename:
        content_disposition = f'inline; filename="{filename}"'
    else:
        content_disposition = f"inline; filename*=utf-8''{quoted_name}"

    def iter_file():
        # Read in 1 MiB chunks so large files are never held in memory.
        with open(file_path, "rb") as f:
            while chunk := f.read(1024 * 1024):
                yield chunk

    return StreamingResponse(
        iter_file(),
        media_type="video/mp4",
        headers={
            "Content-Disposition": content_disposition,
            "Content-Length": str(file_size),
        },
    )
|
|
|
|
|
|
@router.get("/file/task/{task_id}")
async def download_file_by_task(task_id: str, request: Request, background_tasks: BackgroundTasks, db: AsyncSession = Depends(get_db)):
    """Download file by task_id - no auth required (public download)."""
    # The docstring above is published by FastAPI as this endpoint's OpenAPI
    # description; internal notes live in comments instead.
    # Knowing the task id is the only access check here.
    # NOTE(review): task ids carry 32 random bits; confirm that is acceptable
    # for an unauthenticated download handle.
    found = (await db.execute(select(Video).where(Video.task_id == task_id))).scalar_one_or_none()
    if not (found and found.status == "done"):
        raise HTTPException(status_code=404, detail="Video not found")

    path_on_disk = found.file_path
    if not os.path.exists(path_on_disk):
        raise HTTPException(status_code=404, detail="File not found on disk")

    background_tasks.add_task(_log_download, found.id, request)
    return FileResponse(path_on_disk, filename=found.filename, media_type="video/mp4")
|