Files
AIclinicalresearch/extraction_service/services/nougat_extractor.py

242 lines
6.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Nougat提取服务
使用Nougat OCR提取学术PDF的高质量文本
保留表格、公式等结构信息
"""
import subprocess
import os
from pathlib import Path
from typing import Dict, Any, Optional, Callable
from loguru import logger
def check_nougat_available() -> bool:
    """
    Report whether the ``nougat`` Python module can be imported.

    The outcome is logged: the module version (or 'unknown' when the module
    exposes no ``__version__``) on success, a warning when the module is
    missing, and an error for any other failure.

    Returns:
        True if ``import nougat`` succeeds, False otherwise.
    """
    try:
        # Probe by importing the module; the version lookup and the log call
        # stay inside the try so any failure there is reported the same way.
        import nougat
        version = getattr(nougat, '__version__', 'unknown')
        logger.info(f"Nougat module is available (version: {version})")
        available = True
    except ImportError:
        logger.warning("Nougat module not found")
        available = False
    except Exception as exc:
        logger.error(f"检查Nougat失败: {str(exc)}")
        available = False
    return available
def extract_pdf_nougat(
    file_path: str,
    output_dir: Optional[str] = None,
    progress_callback: Optional[Callable[[int, int], None]] = None
) -> Dict[str, Any]:
    """
    Extract text from a PDF by running the ``nougat`` command-line tool.

    The function never raises: every failure is logged and reported through
    the returned dictionary.

    Args:
        file_path: Path to the PDF file.
        output_dir: Directory that receives Nougat's output. When omitted, a
            "nougat_output" directory next to the PDF is used (and created
            if necessary).
        progress_callback: Progress callback (current_page, total_pages).
            Accepted for interface compatibility; this function does not
            currently invoke it.

    Returns:
        On success:
            {
                "success": True,
                "method": "nougat",
                "text": "<extracted Markdown text>",
                "format": "markdown",
                "metadata": {
                    "char_count": 50000,
                    "quality_score": 0.95,
                    "has_tables": True,
                    "has_formulas": True,
                    "has_structure": True
                }
            }
        On failure:
            {"success": False, "error": "<message>", "method": "nougat"}
    """
    timeout_seconds = 300  # 5-minute limit for the external process

    try:
        # Bail out early when Nougat is not installed.
        if not check_nougat_available():
            raise Exception("Nougat未安装请先安装pip install nougat-ocr")

        logger.info(f"开始使用Nougat提取: {file_path}")

        # Prepare the output directory.
        if output_dir is None:
            output_dir = os.path.join(os.path.dirname(file_path), "nougat_output")
        Path(output_dir).mkdir(parents=True, exist_ok=True)

        # Command format: nougat <pdf_path> -o <output_dir>
        cmd = [
            'nougat',
            file_path,
            '-o', output_dir,
            '--markdown',      # emit Markdown output
            '--no-skipping',   # do not skip any page
        ]
        logger.info(f"执行命令: {' '.join(cmd)}")

        # subprocess.run() kills and reaps the child when the timeout expires
        # before re-raising TimeoutExpired. A bare Popen.communicate(timeout=...)
        # leaves the child running, leaking the process and its pipes.
        completed = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            timeout=timeout_seconds,
        )

        if completed.returncode != 0:
            logger.error(f"Nougat执行失败: {completed.stderr}")
            raise Exception(f"Nougat执行失败: {completed.stderr}")

        # The result is read from <output_dir>/<pdf stem>.mmd
        pdf_name = Path(file_path).stem
        output_file = Path(output_dir) / f"{pdf_name}.mmd"
        if not output_file.exists():
            raise Exception(f"Nougat输出文件不存在: {output_file}")

        with open(output_file, 'r', encoding='utf-8') as f:
            markdown_text = f.read()

        # Heuristic quality assessment of the extracted text.
        quality_result = evaluate_nougat_quality(markdown_text)
        logger.info(f"Nougat提取完成: 质量={quality_result['quality_score']:.2f}")

        return {
            "success": True,
            "method": "nougat",
            "text": markdown_text,
            "format": "markdown",
            "metadata": {
                "char_count": len(markdown_text),
                "quality_score": quality_result['quality_score'],
                "has_tables": quality_result['has_tables'],
                "has_formulas": quality_result['has_formulas'],
                "has_structure": quality_result['has_structure']
            }
        }
    except subprocess.TimeoutExpired:
        logger.error("Nougat处理超时>5分钟")
        return {
            "success": False,
            "error": "处理超时",
            "method": "nougat"
        }
    except Exception as e:
        logger.error(f"Nougat提取失败: {str(e)}")
        return {
            "success": False,
            "error": str(e),
            "method": "nougat"
        }
def evaluate_nougat_quality(text: str) -> Dict[str, Any]:
    """
    Score Nougat-extracted Markdown text with simple heuristics.

    Scoring:
        - base score: 0.5
        - heading structure present: +0.2
        - table present: +0.15
        - formula present: +0.15
        - more than 5000 characters: +0.1
        - garbled text detected: -0.3
    The final score is clamped to the range [0.0, 1.0].

    Args:
        text: Markdown text extracted by Nougat.

    Returns:
        {
            "quality_score": 0.92,
            "has_structure": True,
            "has_tables": True,
            "has_formulas": True,
            "has_garbled": False
        }
    """
    score = 0.5  # base score

    # Heading structure: Markdown heading markers.
    has_structure = bool(text.count('##') >= 2 or text.count('#') >= 3)
    if has_structure:
        score += 0.2

    # Tables: a pipe together with a separator row marker.
    has_tables = '|' in text and '---' in text
    if has_tables:
        score += 0.15

    # Formulas: LaTeX delimiters ('$' also covers '$$').
    has_formulas = '$' in text or '\\(' in text
    if has_formulas:
        score += 0.15

    # Reward sufficiently long text.
    if len(text) > 5000:
        score += 0.1

    # Garbled-text heuristic: count U+FFFD replacement characters and code
    # points outside the Basic Multilingual Plane. The replacement character
    # is written as an escape so the check cannot be corrupted by the file's
    # own encoding (a mis-encoded literal would match ordinary ASCII letters).
    garbled_chars = sum(1 for c in text if c == '\ufffd' or ord(c) > 0xFFFF)
    has_garbled = garbled_chars > len(text) * 0.05  # more than 5% of the text
    if has_garbled:
        score -= 0.3

    # Keep the score within [0, 1].
    score = max(0.0, min(1.0, score))

    return {
        "quality_score": score,
        "has_structure": has_structure,
        "has_tables": has_tables,
        "has_formulas": has_formulas,
        "has_garbled": has_garbled
    }
def get_nougat_info() -> Dict[str, Any]:
    """
    Describe the installed Nougat module.

    Returns:
        {"available": True, "version": <version or 'unknown'>} when the
        ``nougat`` module imports successfully; otherwise
        {"available": False, "error": <message>}.
    """
    try:
        import nougat
        info = {
            "available": True,
            "version": getattr(nougat, '__version__', 'unknown'),
        }
    except ImportError:
        # The module is not installed.
        info = {"available": False, "error": "Nougat未安装"}
    except Exception as exc:
        # Any other failure while importing or inspecting the module.
        info = {"available": False, "error": str(exc)}
    return info