- Add common/ layer for shared capabilities (LLM, RAG, document, middleware) - Add legacy/ layer for existing business code - Move files to new structure (controllers, routes, services) - Update index.ts for new route registration - System remains fully functional
429 lines
11 KiB
TypeScript
429 lines
11 KiB
TypeScript
/**
 * Phase 3: 批处理模式 - 批处理控制器
 *
 * API路由:
 * - POST /api/v1/batch/execute - 执行批处理任务
 * - GET /api/v1/batch/tasks/:taskId - 获取任务状态
 * - GET /api/v1/batch/tasks/:taskId/results - 获取任务结果
 * - POST /api/v1/batch/tasks/:taskId/retry-failed - 重试失败项
 * - GET /api/v1/batch/templates - 获取所有预设模板
 */
|
||
|
||
import { FastifyRequest, FastifyReply } from 'fastify';
|
||
import { executeBatchTask, retryFailedDocuments, BatchProgress } from '../services/batchService.js';
|
||
import { prisma } from '../../config/database.js';
|
||
import { ModelType } from '../../common/llm/adapters/types.js';
|
||
|
||
// ==================== 类型定义 ====================
|
||
|
||
/**
 * Request body of POST /api/v1/batch/execute.
 *
 * Keys are snake_case because they are the wire format sent by the client.
 */
interface ExecuteBatchBody {
  /** ID of the knowledge base every listed document must belong to. */
  kb_id: string;
  /** IDs of the documents to process. */
  document_ids: string[];
  /** `preset` selects a stored template via `template_id`; `custom` uses `custom_prompt`. */
  template_type: 'preset' | 'custom';
  /** Model that processes each document. */
  model_type: ModelType;

  /** Preset template ID; expected when `template_type` is `preset`. */
  template_id?: string;
  /** Free-form prompt; expected when `template_type` is `custom`. */
  custom_prompt?: string;
  /** Optional display name for the task. */
  task_name?: string;
}
|
||
|
||
/** Route params shared by every `/api/v1/batch/tasks/:taskId...` endpoint. */
interface TaskIdParams {
  /** ID of the batch task, matched against `batchTask.id`. */
  taskId: string;
}
|
||
|
||
// ==================== API处理器 ====================
|
||
|
||
/** Minimum number of documents accepted in one batch task. */
const MIN_BATCH_DOCUMENTS = 3;

/** Maximum number of documents accepted in one batch task. */
const MAX_BATCH_DOCUMENTS = 50;

/** Model types this endpoint accepts for batch processing. */
const SUPPORTED_BATCH_MODELS: readonly ModelType[] = ['deepseek-v3', 'qwen3-72b', 'qwen-long'];

/**
 * Validate the body of an execute-batch request without touching the database.
 *
 * Checks run in a fixed order and the first failure wins, so a given mistake
 * always yields the same message.
 *
 * @param body - The parsed request body. It may be missing or malformed,
 *   because no request schema is visible here.
 * @returns A user-facing error message, or `null` when the body is valid.
 */
function validateExecuteBatchBody(
  body: Partial<ExecuteBatchBody> | null | undefined
): string | null {
  if (!body || !body.kb_id || !Array.isArray(body.document_ids) || body.document_ids.length === 0) {
    return '缺少必要参数:kb_id 或 document_ids';
  }

  const documentCount = body.document_ids.length;
  if (documentCount < MIN_BATCH_DOCUMENTS) {
    return `文献数量不能少于${MIN_BATCH_DOCUMENTS}篇`;
  }
  if (documentCount > MAX_BATCH_DOCUMENTS) {
    return `文献数量不能超过${MAX_BATCH_DOCUMENTS}篇`;
  }

  // Without this check, duplicated IDs would fail the count-based existence
  // check in executeBatch with the misleading "some documents do not exist" message.
  if (new Set(body.document_ids).size !== documentCount) {
    return '文献列表包含重复的文档ID';
  }

  // The declared union is a compile-time type only; client input can carry anything.
  if (body.template_type !== 'preset' && body.template_type !== 'custom') {
    return `不支持的模板类型: ${body.template_type}`;
  }
  if (body.template_type === 'preset' && !body.template_id) {
    return '预设模板类型需要提供 template_id';
  }
  if (body.template_type === 'custom' && !body.custom_prompt) {
    return '自定义模板需要提供 custom_prompt';
  }

  if (body.model_type === undefined || !SUPPORTED_BATCH_MODELS.includes(body.model_type)) {
    return `不支持的模型类型: ${body.model_type}`;
  }

  return null;
}

/**
 * POST /api/v1/batch/execute
 *
 * Starts a batch task over documents of one knowledge base.
 *
 * The handler proceeds in order:
 * 1. Validates the body.
 * 2. Checks that the knowledge base exists and that every requested document belongs to it.
 * 3. Creates the `batchTask` row.
 * 4. Starts the batch in the background and replies immediately with the new task ID.
 *
 * Background progress, completion and failure are pushed to the user's
 * WebSocket room as `batch-progress`, `batch-completed` and `batch-failed`.
 *
 * Replies 400 on invalid input or on unknown/foreign documents, 404 when the
 * knowledge base does not exist, and 500 on unexpected errors.
 */
export async function executeBatch(
  request: FastifyRequest<{ Body: ExecuteBatchBody }>,
  reply: FastifyReply
) {
  try {
    // TODO: take userId from the JWT
    const userId = 'user-mock-001';
    const body = request.body;

    // Logged before validation so rejected requests are visible too. Every access
    // is guarded because the body may still be missing or malformed here.
    console.log('📦 [BatchController] 收到批处理请求', {
      userId,
      kbId: body?.kb_id,
      documentCount: Array.isArray(body?.document_ids) ? body.document_ids.length : 0,
      templateType: body?.template_type,
      modelType: body?.model_type,
    });

    const validationError = validateExecuteBatchBody(body);
    if (validationError !== null) {
      return reply.code(400).send({
        success: false,
        message: validationError,
      });
    }

    const { kb_id, document_ids, template_type, template_id, custom_prompt, model_type, task_name } =
      body;

    const kb = await prisma.knowledgeBase.findUnique({
      where: { id: kb_id },
    });
    if (!kb) {
      return reply.code(404).send({
        success: false,
        message: `知识库不存在: ${kb_id}`,
      });
    }

    // Every requested document must exist in this knowledge base. The IDs are
    // known to be unique at this point, so comparing counts is exact; count()
    // avoids loading the document rows themselves.
    const matchedDocuments = await prisma.document.count({
      where: {
        id: { in: document_ids },
        kbId: kb_id,
      },
    });
    if (matchedDocuments !== document_ids.length) {
      return reply.code(400).send({
        success: false,
        message: '部分文档不存在或不属于该知识库',
      });
    }

    // Presumably a Socket.IO server decorated onto the Fastify instance elsewhere; may be absent.
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const io = (request.server as any).io;
    const emitToUser = (event: string, payload: unknown): void => {
      if (io) {
        io.to(userId).emit(event, payload);
      }
    };

    // Create the task row up front so its ID can be returned before processing finishes.
    const task = await prisma.batchTask.create({
      data: {
        userId,
        kbId: kb_id,
        name: task_name || `批处理任务_${new Date().toLocaleString('zh-CN')}`,
        templateType: template_type,
        templateId: template_id || null,
        prompt: custom_prompt || template_id || '',
        status: 'processing',
        totalDocuments: document_ids.length,
        modelType: model_type,
        concurrency: 3,
        startedAt: new Date(),
      },
    });

    const taskId = task.id;
    console.log(`✅ [BatchController] 创建任务: ${taskId}`);

    // Fire-and-forget: the HTTP request does not wait for the batch to finish.
    void executeBatchTask({
      userId,
      kbId: kb_id,
      documentIds: document_ids,
      templateType: template_type,
      templateId: template_id,
      customPrompt: custom_prompt,
      modelType: model_type,
      taskName: task_name,
      existingTaskId: taskId, // reuse the row created above instead of creating another
      onProgress: (progress: BatchProgress) => {
        emitToUser('batch-progress', progress);
      },
    })
      .then((result) => {
        console.log(`🎉 [BatchController] 批处理任务完成: ${result.taskId}`);
        emitToUser('batch-completed', {
          task_id: result.taskId,
          status: result.status,
        });
      })
      .catch((error: unknown) => {
        console.error(`❌ [BatchController] 批处理任务失败:`, error);
        // NOTE(review): the task row was created above with status 'processing';
        // confirm that executeBatchTask marks it as failed itself when it rejects.
        emitToUser('batch-failed', {
          task_id: taskId,
          error: error instanceof Error ? error.message : String(error),
        });
      });

    return reply.send({
      success: true,
      message: '批处理任务已开始',
      data: {
        task_id: taskId,
        status: 'processing',
        websocket_event: 'batch-progress',
      },
    });
  } catch (error: unknown) {
    console.error('❌ [BatchController] 执行批处理失败:', error);
    return reply.code(500).send({
      success: false,
      message: (error instanceof Error && error.message) || '执行批处理任务失败',
    });
  }
}
|
||
|
||
/**
 * GET /api/v1/batch/tasks/:taskId
 *
 * Returns the progress summary of one batch task (counts, model, timestamps)
 * with snake_case keys. Replies 404 when no task has the given ID and 500 when
 * the lookup throws.
 */
export async function getTask(
  request: FastifyRequest<{ Params: TaskIdParams }>,
  reply: FastifyReply
) {
  try {
    const id = request.params.taskId;

    // Only summary columns are read here; per-document results are not loaded.
    const record = await prisma.batchTask.findUnique({
      where: { id },
      select: {
        id: true,
        name: true,
        status: true,
        totalDocuments: true,
        completedCount: true,
        failedCount: true,
        modelType: true,
        startedAt: true,
        completedAt: true,
        durationSeconds: true,
        createdAt: true,
      },
    });

    if (!record) {
      return reply.code(404).send({
        success: false,
        message: `任务不存在: ${id}`,
      });
    }

    const summary = {
      id: record.id,
      name: record.name,
      status: record.status,
      total_documents: record.totalDocuments,
      completed_count: record.completedCount,
      failed_count: record.failedCount,
      model_type: record.modelType,
      started_at: record.startedAt,
      completed_at: record.completedAt,
      duration_seconds: record.durationSeconds,
      created_at: record.createdAt,
    };

    reply.send({ success: true, data: summary });
  } catch (err: any) {
    console.error('❌ [BatchController] 获取任务失败:', err);
    reply.code(500).send({
      success: false,
      message: err.message || '获取任务失败',
    });
  }
}
|
||
|
||
/**
 * GET /api/v1/batch/tasks/:taskId/results
 *
 * Returns a batch task's summary together with all of its per-document
 * results. Results are ordered by creation time (oldest first) and numbered
 * from 1. Replies 404 when the task does not exist and 500 when the query throws.
 */
export async function getTaskResults(
  request: FastifyRequest<{ Params: TaskIdParams }>,
  reply: FastifyReply
) {
  try {
    const id = request.params.taskId;

    const record = await prisma.batchTask.findUnique({
      where: { id },
      include: {
        results: {
          include: {
            document: {
              select: { filename: true, tokensCount: true },
            },
          },
          orderBy: { createdAt: 'asc' },
        },
      },
    });

    if (!record) {
      return reply.code(404).send({
        success: false,
        message: `任务不存在: ${id}`,
      });
    }

    // One snake_case entry per document; `index` is the 1-based display position.
    const results = record.results.map((row, position) => ({
      id: row.id,
      index: position + 1,
      document_id: row.documentId,
      document_name: row.document.filename,
      status: row.status,
      data: row.data,
      raw_output: row.rawOutput,
      error_message: row.errorMessage,
      processing_time_ms: row.processingTimeMs,
      tokens_used: row.tokensUsed,
      created_at: row.createdAt,
    }));

    const summary = {
      id: record.id,
      name: record.name,
      status: record.status,
      template_type: record.templateType,
      template_id: record.templateId,
      total_documents: record.totalDocuments,
      completed_count: record.completedCount,
      failed_count: record.failedCount,
      duration_seconds: record.durationSeconds,
      created_at: record.createdAt,
      completed_at: record.completedAt,
    };

    reply.send({
      success: true,
      data: { task: summary, results },
    });
  } catch (err: any) {
    console.error('❌ [BatchController] 获取任务结果失败:', err);
    reply.code(500).send({
      success: false,
      message: err.message || '获取任务结果失败',
    });
  }
}
|
||
|
||
/**
 * POST /api/v1/batch/tasks/:taskId/retry-failed
 *
 * Re-runs the failed documents of an existing batch task in the background and
 * replies immediately. Progress is pushed to the user's WebSocket room as
 * `batch-progress`. A background failure is pushed as `batch-failed`, using
 * the same payload shape as executeBatch.
 *
 * Replies 404 when the task does not exist and 500 on unexpected errors.
 */
export async function retryFailed(
  request: FastifyRequest<{ Params: TaskIdParams }>,
  reply: FastifyReply
) {
  try {
    const { taskId } = request.params;
    const userId = 'user-mock-001'; // TODO: take userId from the JWT

    // Reject unknown IDs up front. Otherwise the client is told the retry started,
    // while the background call fails and the error is visible only in the server log.
    const task = await prisma.batchTask.findUnique({
      where: { id: taskId },
      select: { id: true },
    });
    if (!task) {
      return reply.code(404).send({
        success: false,
        message: `任务不存在: ${taskId}`,
      });
    }

    // Presumably a Socket.IO server decorated onto the Fastify instance elsewhere; may be absent.
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const io = (request.server as any).io;
    const emitToUser = (event: string, payload: unknown): void => {
      if (io) {
        io.to(userId).emit(event, payload);
      }
    };

    // Fire-and-forget: the HTTP request does not wait for the retry to finish.
    void retryFailedDocuments(taskId, (progress: BatchProgress) => {
      emitToUser('batch-progress', progress);
    })
      .then((result) => {
        console.log(`✅ [BatchController] 重试完成: ${result.retriedCount}篇`);
      })
      .catch((error: unknown) => {
        console.error(`❌ [BatchController] 重试失败:`, error);
        emitToUser('batch-failed', {
          task_id: taskId,
          error: error instanceof Error ? error.message : String(error),
        });
      });

    return reply.send({
      success: true,
      message: '已开始重试失败的文档',
    });
  } catch (error: unknown) {
    console.error('❌ [BatchController] 重试失败:', error);
    return reply.code(500).send({
      success: false,
      message: (error instanceof Error && error.message) || '重试失败',
    });
  }
}
|
||
|
||
/**
 * GET /api/v1/batch/templates
 *
 * Lists every preset template as `{ id, name, description, output_fields }`.
 * Replies 500 if loading the template module or building the list throws.
 */
export async function getTemplates(
  request: FastifyRequest,
  reply: FastifyReply
) {
  try {
    // Imported inside the handler, so the module is loaded on first request rather than at startup.
    const { getAllTemplates } = await import('../templates/clinicalResearch.js');

    const summaries = getAllTemplates().map(({ id, name, description, outputFields }) => ({
      id,
      name,
      description,
      output_fields: outputFields,
    }));

    reply.send({ success: true, data: summaries });
  } catch (err: any) {
    console.error('❌ [BatchController] 获取模板失败:', err);
    reply.code(500).send({
      success: false,
      message: err.message || '获取模板失败',
    });
  }
}
|
||
|