/**
 * Protocol Context Service
 * Manages the per-conversation context of a research protocol: which stage the
 * user is on, which stages are completed, and the data collected per stage.
 *
 * @module agent/protocol/services/ProtocolContextService
 */

import { PrismaClient } from '@prisma/client';
import {
  ProtocolContextData,
  ProtocolStageCode,
  ScientificQuestionData,
  PICOData,
  StudyDesignData,
  SampleSizeData,
  EndpointsData,
} from '../../types/index.js';

/** Fields of ProtocolContextData that hold a stage's collected payload. */
type StageDataKey = 'scientificQuestion' | 'pico' | 'studyDesign' | 'sampleSize' | 'endpoints';

/** Static description of one protocol stage. */
interface StageDefinition {
  /** Stage code as stored in `currentStage` / `completedStages`. */
  code: ProtocolStageCode;
  /** User-facing display name. */
  name: string;
  /** Context field in which this stage's data is stored. */
  dataKey: StageDataKey;
}

/**
 * All protocol stages in display order. This is the single source of truth for
 * the stage list, display names and stage -> data-field mapping.
 */
const STAGE_DEFINITIONS: readonly StageDefinition[] = [
  { code: 'scientific_question', name: '科学问题梳理', dataKey: 'scientificQuestion' },
  { code: 'pico', name: 'PICO要素', dataKey: 'pico' },
  { code: 'study_design', name: '研究设计', dataKey: 'studyDesign' },
  { code: 'sample_size', name: '样本量计算', dataKey: 'sampleSize' },
  { code: 'endpoints', name: '观察指标', dataKey: 'endpoints' },
];

/** Every stage code; all must be completed for the protocol to be fully done. */
const ALL_STAGES: readonly ProtocolStageCode[] = STAGE_DEFINITIONS.map(stage => stage.code);

/**
 * Stages that must be completed before a protocol can be generated (4 of 5).
 * Sample size is optional.
 */
const REQUIRED_FOR_GENERATION: readonly ProtocolStageCode[] = [
  'scientific_question',
  'pico',
  'study_design',
  'endpoints',
];

/** Looks up the static definition of a stage, or undefined for an unknown code. */
function findStage(code: ProtocolStageCode): StageDefinition | undefined {
  return STAGE_DEFINITIONS.find(stage => stage.code === code);
}

export class ProtocolContextService {
  private readonly prisma: PrismaClient;

  constructor(prisma: PrismaClient) {
    this.prisma = prisma;
  }

  /**
   * Returns the context of a conversation, creating a fresh one if none exists.
   *
   * NOTE(review): the lookup and the create are two separate queries, not one
   * atomic operation; two concurrent calls for the same conversation may both
   * attempt the create.
   *
   * @param conversationId - Conversation the context belongs to.
   * @param userId - Owner recorded on a newly created context.
   * @returns The existing or newly created context.
   */
  async getOrCreateContext(
    conversationId: string,
    userId: string
  ): Promise<ProtocolContextData> {
    const existing = await this.getContext(conversationId);
    if (existing) {
      return existing;
    }
    return this.createContext(conversationId, userId);
  }

  /**
   * Loads the context of a conversation.
   *
   * @param conversationId - Conversation to look up (unique key).
   * @returns The context, or null if none has been created yet.
   */
  async getContext(conversationId: string): Promise<ProtocolContextData | null> {
    const result = await this.prisma.protocolContext.findUnique({
      where: { conversationId },
    });
    return result ? this.mapToContextData(result) : null;
  }

  /**
   * Creates a new context that starts at the scientific-question stage, with
   * status 'in_progress' and no completed stages.
   *
   * @param conversationId - Conversation the context belongs to.
   * @param userId - Owner of the context.
   * @returns The newly created context.
   */
  async createContext(
    conversationId: string,
    userId: string
  ): Promise<ProtocolContextData> {
    const result = await this.prisma.protocolContext.create({
      data: {
        conversationId,
        userId,
        currentStage: 'scientific_question',
        status: 'in_progress',
        completedStages: [],
      },
    });
    return this.mapToContextData(result);
  }

  /**
   * Stores the data collected for a stage and bumps `lastActiveAt`.
   *
   * The data is written to the stage's field (see STAGE_DEFINITIONS). For a
   * stage code without a data field, only `lastActiveAt` is updated.
   *
   * @param conversationId - Conversation whose context is updated.
   * @param stageCode - Stage the data belongs to.
   * @param data - Stage payload; replaces any previously stored payload.
   * @returns The updated context.
   */
  async updateStageData(
    conversationId: string,
    stageCode: ProtocolStageCode,
    data: Record<string, unknown>
  ): Promise<ProtocolContextData> {
    const updateData: Record<string, unknown> = {};
    const dataKey = findStage(stageCode)?.dataKey;
    if (dataKey) {
      updateData[dataKey] = data;
    }

    const result = await this.prisma.protocolContext.update({
      where: { conversationId },
      data: {
        ...updateData,
        lastActiveAt: new Date(),
      },
    });
    return this.mapToContextData(result);
  }

  /**
   * Marks a stage as completed and optionally advances the current stage.
   *
   * Completing an already-completed stage does not add a duplicate entry.
   *
   * NOTE(review): this is a read-modify-write of `completedStages` across two
   * queries; concurrent completions for the same conversation could overwrite
   * each other's entry.
   *
   * @param conversationId - Conversation whose context is updated.
   * @param stageCode - Stage to mark as completed.
   * @param nextStage - New current stage; when omitted the current stage is kept.
   * @returns The updated context.
   * @throws Error('Context not found') if the conversation has no context.
   */
  async completeStage(
    conversationId: string,
    stageCode: ProtocolStageCode,
    nextStage?: ProtocolStageCode
  ): Promise<ProtocolContextData> {
    const context = await this.getContext(conversationId);
    if (!context) {
      throw new Error('Context not found');
    }

    const completedStages = [...context.completedStages];
    if (!completedStages.includes(stageCode)) {
      completedStages.push(stageCode);
    }

    const result = await this.prisma.protocolContext.update({
      where: { conversationId },
      data: {
        completedStages,
        currentStage: nextStage ?? context.currentStage,
        lastActiveAt: new Date(),
      },
    });
    return this.mapToContextData(result);
  }

  /**
   * Sets the current stage without touching the completed-stage list.
   *
   * @param conversationId - Conversation whose context is updated.
   * @param stageCode - New current stage.
   * @returns The updated context.
   */
  async updateCurrentStage(
    conversationId: string,
    stageCode: ProtocolStageCode
  ): Promise<ProtocolContextData> {
    const result = await this.prisma.protocolContext.update({
      where: { conversationId },
      data: {
        currentStage: stageCode,
        lastActiveAt: new Date(),
      },
    });
    return this.mapToContextData(result);
  }

  /**
   * Sets the protocol status to 'completed'.
   * Does not check which stages have been completed.
   *
   * @param conversationId - Conversation whose context is updated.
   * @returns The updated context.
   */
  async markCompleted(conversationId: string): Promise<ProtocolContextData> {
    const result = await this.prisma.protocolContext.update({
      where: { conversationId },
      data: {
        status: 'completed',
        lastActiveAt: new Date(),
      },
    });
    return this.mapToContextData(result);
  }

  /**
   * @returns true when all five stages are in `completedStages`.
   */
  isAllStagesCompleted(context: ProtocolContextData): boolean {
    return ALL_STAGES.every(stage => context.completedStages.includes(stage));
  }

  /**
   * Whether a protocol can be generated (4 of 5 stages required).
   * Required: scientific question, PICO, study design, endpoints.
   * Optional: sample size.
   */
  canGenerateProtocol(context: ProtocolContextData): boolean {
    return REQUIRED_FOR_GENERATION.every(stage => context.completedStages.includes(stage));
  }

  /**
   * @returns The generation-required stages not yet completed, in stage order.
   */
  getMissingRequiredStages(context: ProtocolContextData): ProtocolStageCode[] {
    return REQUIRED_FOR_GENERATION.filter(stage => !context.completedStages.includes(stage));
  }

  /**
   * Percentage (rounded, 0-100) of the five stages that are completed.
   * Unknown or duplicated codes in `completedStages` are ignored, so stale
   * data cannot push the result above 100.
   */
  getProgress(context: ProtocolContextData): number {
    const completedCount = ALL_STAGES.filter(stage =>
      context.completedStages.includes(stage)
    ).length;
    return Math.round((completedCount / ALL_STAGES.length) * 100);
  }

  /**
   * Per-stage status list in display order.
   *
   * A stage is 'completed' if it is in `completedStages`, otherwise 'current'
   * if it equals `currentStage`, otherwise 'pending'. `data` is the stored
   * stage payload, or null when none is stored.
   */
  getStagesStatus(context: ProtocolContextData): Array<{
    stageCode: ProtocolStageCode;
    stageName: string;
    status: 'completed' | 'current' | 'pending';
    data: Record<string, unknown> | null;
  }> {
    return STAGE_DEFINITIONS.map(stage => {
      let status: 'completed' | 'current' | 'pending';
      if (context.completedStages.includes(stage.code)) {
        status = 'completed';
      } else if (context.currentStage === stage.code) {
        status = 'current';
      } else {
        status = 'pending';
      }

      const stored: unknown = context[stage.dataKey];
      return {
        stageCode: stage.code,
        stageName: stage.name,
        status,
        data: (stored ?? null) as Record<string, unknown> | null,
      };
    });
  }

  /**
   * Collects all stage payloads needed to generate the protocol.
   * Missing payloads (null or undefined) are normalised to null.
   */
  getGenerationData(context: ProtocolContextData): {
    scientificQuestion: ScientificQuestionData | null;
    pico: PICOData | null;
    studyDesign: StudyDesignData | null;
    sampleSize: SampleSizeData | null;
    endpoints: EndpointsData | null;
  } {
    return {
      scientificQuestion: context.scientificQuestion ?? null,
      pico: context.pico ?? null,
      studyDesign: context.studyDesign ?? null,
      sampleSize: context.sampleSize ?? null,
      endpoints: context.endpoints ?? null,
    };
  }

  /**
   * Maps a database row to ProtocolContextData.
   *
   * The casts are unchecked: stage codes, status and JSON payloads are trusted
   * as stored. A NULL payload column is passed through as null even though the
   * field is typed `| undefined`; getGenerationData/getStagesStatus normalise
   * both to null.
   */
  private mapToContextData(result: {
    id: string;
    conversationId: string;
    userId: string;
    currentStage: string;
    status: string;
    scientificQuestion: unknown;
    pico: unknown;
    studyDesign: unknown;
    sampleSize: unknown;
    endpoints: unknown;
    completedStages: string[];
    lastActiveAt: Date;
    createdAt: Date;
    updatedAt: Date;
  }): ProtocolContextData {
    return {
      id: result.id,
      conversationId: result.conversationId,
      userId: result.userId,
      currentStage: result.currentStage as ProtocolStageCode,
      status: result.status as ProtocolContextData['status'],
      scientificQuestion: result.scientificQuestion as ScientificQuestionData | undefined,
      pico: result.pico as PICOData | undefined,
      studyDesign: result.studyDesign as StudyDesignData | undefined,
      sampleSize: result.sampleSize as SampleSizeData | undefined,
      endpoints: result.endpoints as EndpointsData | undefined,
      completedStages: result.completedStages as ProtocolStageCode[],
      lastActiveAt: result.lastActiveAt,
      createdAt: result.createdAt,
      updatedAt: result.updatedAt,
    };
  }
}