242 lines
6.6 KiB
Python
242 lines
6.6 KiB
Python
"""
|
||
Nougat提取服务
|
||
|
||
使用Nougat OCR提取学术PDF的高质量文本
|
||
保留表格、公式等结构信息
|
||
"""
|
||
|
||
import subprocess
|
||
import os
|
||
from pathlib import Path
|
||
from typing import Dict, Any, Optional, Callable
|
||
from loguru import logger
|
||
|
||
|
||
def check_nougat_available() -> bool:
    """
    Report whether the ``nougat`` Python package can be imported.

    Importing the package is the whole probe. The outcome is logged:
    info with the version on success, a warning when the package is
    missing, and an error for any other failure raised during import.

    Returns:
        True if ``import nougat`` succeeds, otherwise False.
    """
    try:
        import nougat  # the import itself is the availability probe
        version_label = getattr(nougat, '__version__', 'unknown')
        logger.info(f"Nougat module is available (version: {version_label})")
    except ImportError:
        logger.warning("Nougat module not found")
        available = False
    except Exception as exc:
        # Importing can fail for reasons other than absence (e.g. a broken dependency).
        logger.error(f"检查Nougat失败: {str(exc)}")
        available = False
    else:
        available = True
    return available
|
||
|
||
|
||
def extract_pdf_nougat(
    file_path: str,
    output_dir: Optional[str] = None,
    progress_callback: Optional[Callable[[int, int], None]] = None,
    timeout: Optional[float] = 300,
) -> Dict[str, Any]:
    """
    Extract the text of a PDF with the Nougat OCR command-line tool.

    Runs ``nougat <pdf> -o <output_dir> --markdown --no-skipping`` in a
    subprocess, reads the ``<pdf stem>.mmd`` file it writes into
    ``output_dir``, and scores the text with ``evaluate_nougat_quality``.

    Args:
        file_path: Path to the PDF file.
        output_dir: Directory Nougat writes its output into. Defaults to a
            ``nougat_output`` directory next to the PDF. It is created if
            missing.
        progress_callback: Accepted for API compatibility with
            ``(current_page, total_pages)`` callbacks. It is currently NOT
            invoked, because the CLI run is not monitored page by page.
        timeout: Seconds to wait for the Nougat process (default 300). When
            it expires the child process is killed and a timeout result is
            returned. ``None`` waits indefinitely.

    Returns:
        On success::

            {
                "success": True,
                "method": "nougat",
                "text": "<extracted Markdown>",
                "format": "markdown",
                "metadata": {
                    "char_count": 50000,
                    "quality_score": 0.95,
                    "has_tables": True,
                    "has_formulas": True,
                    "has_structure": True,
                    "has_garbled": False,
                },
            }

        On any failure (this function does not raise)::

            {"success": False, "error": "<message>", "method": "nougat"}
    """
    import shutil

    try:
        if not check_nougat_available():
            raise Exception("Nougat未安装,请先安装:pip install nougat-ocr")

        # Being able to import the package does not guarantee that the `nougat`
        # console script executed below is on PATH. Resolve it up front so
        # that case gets a clear message instead of a bare FileNotFoundError.
        nougat_bin = shutil.which('nougat')
        if nougat_bin is None:
            raise Exception("未找到nougat命令行工具,请确认nougat-ocr的脚本目录已加入PATH")

        # Fail early with an explicit message rather than running the CLI and
        # then reporting a missing .mmd output file.
        if not os.path.isfile(file_path):
            raise Exception(f"PDF文件不存在: {file_path}")

        logger.info(f"开始使用Nougat提取: {file_path}")

        # Default output location: a "nougat_output" directory beside the PDF.
        if output_dir is None:
            output_dir = os.path.join(os.path.dirname(file_path), "nougat_output")
        Path(output_dir).mkdir(parents=True, exist_ok=True)

        # nougat <pdf_path> -o <output_dir> --markdown --no-skipping
        #   --markdown:    Markdown-compatible output
        #   --no-skipping: do not skip any pages
        # NOTE(review): the Nougat CLI may skip a PDF whose .mmd already exists
        # in output_dir unless --recompute is passed, in which case an earlier
        # result is read below. Confirm that this reuse is intended.
        cmd = [nougat_bin, file_path, '-o', output_dir, '--markdown', '--no-skipping']
        logger.info(f"执行命令: {' '.join(cmd)}")

        # Nougat can be slow (roughly 1-2 minutes per 20 pages).
        # subprocess.run kills the child when the timeout expires, whereas
        # Popen.communicate(timeout=...) leaves it running in the background.
        # errors='replace' stops undecodable bytes in the tool's console output
        # (used only for error messages) from failing an otherwise good run.
        completed = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            errors='replace',
            timeout=timeout,
        )

        if completed.returncode != 0:
            logger.error(f"Nougat执行失败: {completed.stderr}")
            raise Exception(f"Nougat执行失败: {completed.stderr}")

        # Nougat writes <pdf stem>.mmd into the output directory.
        output_file = Path(output_dir) / f"{Path(file_path).stem}.mmd"
        if not output_file.exists():
            raise Exception(f"Nougat输出文件不存在: {output_file}")

        markdown_text = output_file.read_text(encoding='utf-8')

        quality_result = evaluate_nougat_quality(markdown_text)
        logger.info(f"Nougat提取完成: 质量={quality_result['quality_score']:.2f}")

        return {
            "success": True,
            "method": "nougat",
            "text": markdown_text,
            "format": "markdown",
            "metadata": {
                "char_count": len(markdown_text),
                "quality_score": quality_result['quality_score'],
                "has_tables": quality_result['has_tables'],
                "has_formulas": quality_result['has_formulas'],
                "has_structure": quality_result['has_structure'],
                "has_garbled": quality_result['has_garbled'],
            },
        }

    except subprocess.TimeoutExpired:
        # subprocess.run has already killed the child process at this point.
        logger.error(f"Nougat处理超时(>{timeout}秒)")
        return {
            "success": False,
            "error": "处理超时",
            "method": "nougat",
        }

    except Exception as e:
        logger.error(f"Nougat提取失败: {str(e)}")
        return {
            "success": False,
            "error": str(e),
            "method": "nougat",
        }
|
||
|
||
|
||
def evaluate_nougat_quality(text: str) -> Dict[str, Any]:
    r"""
    Heuristically score the quality of Nougat's Markdown (.mmd) output.

    Scoring (the final score is clamped to [0, 1]):
      - base score: 0.5
      - section structure, i.e. at least two Markdown heading lines: +0.2
      - tables, i.e. a Markdown pipe-table separator row or a LaTeX
        ``\begin{table}`` / ``\begin{tabular}`` environment: +0.15
      - formulas, i.e. ``\(``, ``\[``, ``$$`` or an inline ``$...$`` pair: +0.15
      - text longer than 5000 characters: +0.1
      - garbled text, i.e. more than 5% of characters are U+FFFD
        replacement characters or outside the Basic Multilingual Plane: -0.3

    Empty or whitespace-only text scores 0.0, with every flag False.

    Args:
        text: Markdown text extracted by Nougat.

    Returns:
        {
            "quality_score": 0.92,
            "has_structure": True,
            "has_tables": True,
            "has_formulas": True,
            "has_garbled": False
        }
    """
    import re

    # Nothing was extracted, so do not award the base score.
    if not text or not text.strip():
        return {
            "quality_score": 0.0,
            "has_structure": False,
            "has_tables": False,
            "has_formulas": False,
            "has_garbled": False,
        }

    score = 0.5  # base score

    # Section structure: count real heading lines ("# ...", "## ..."), not
    # every '#' character anywhere in the text.
    heading_lines = re.findall(r"^#{1,6}[ \t]+\S", text, flags=re.MULTILINE)
    has_structure = len(heading_lines) >= 2
    if has_structure:
        score += 0.2

    # Tables: Nougat may emit LaTeX table environments. A Markdown table is
    # recognised by its separator row (e.g. "|---|:---:|"), so a lone "---"
    # horizontal rule or a stray '|' does not count.
    has_tables = bool(
        re.search(r"\\begin\{(?:table|tabular)", text)
        or re.search(
            r"^[ \t]*\|?[ \t]*:?-{3,}:?[ \t]*(?:\|[ \t]*:?-{3,}:?[ \t]*)+\|?[ \t]*$",
            text,
            flags=re.MULTILINE,
        )
    )
    if has_tables:
        score += 0.15

    # Formulas: LaTeX inline \( \) and display \[ \] delimiters, $$ blocks, or
    # a single-line $...$ pair. A lone '$' (e.g. a price) is not enough.
    has_formulas = bool(re.search(r"\\\(|\\\[|\$\$|\$[^$\n]+\$", text))
    if has_formulas:
        score += 0.15

    # Text length: at least 5000 characters.
    if len(text) > 5000:
        score += 0.1

    # Garbled text (simple heuristic): U+FFFD replacement characters and
    # characters outside the BMP.
    garbled_chars = sum(1 for c in text if c == "\ufffd" or ord(c) > 0xFFFF)
    has_garbled = garbled_chars > len(text) * 0.05  # more than 5%
    if has_garbled:
        score -= 0.3

    score = max(0.0, min(1.0, score))

    return {
        "quality_score": score,
        "has_structure": has_structure,
        "has_tables": has_tables,
        "has_formulas": has_formulas,
        "has_garbled": has_garbled,
    }
|
||
|
||
|
||
def get_nougat_info() -> Dict[str, Any]:
    """
    Describe the availability of the ``nougat`` Python package.

    Returns:
        ``{"available": True, "version": <__version__ or 'unknown'>}`` when the
        package imports. Otherwise ``{"available": False, "error": <message>}``,
        where the message is "Nougat未安装" for a missing package, or the
        exception text for any other import-time failure.
    """
    try:
        import nougat
        nougat_version = getattr(nougat, '__version__', 'unknown')
    except ImportError:
        return {"available": False, "error": "Nougat未安装"}
    except Exception as exc:
        # Import can fail for reasons other than absence (e.g. broken dependencies).
        return {"available": False, "error": str(exc)}

    return {"available": True, "version": nougat_version}
|
||
|