Files
AIclinicalresearch/backend/src/modules/agent/protocol/services/ProtocolContextService.ts
HaHafeng 303dd78c54 feat(aia): Protocol Agent MVP complete with one-click generation and Word export
- Add one-click research protocol generation with streaming output

- Implement Word document export via Pandoc integration

- Add dynamic dual-panel layout with resizable split pane

- Implement collapsible content for StatePanel stages

- Add conversation history management with title auto-update

- Fix scroll behavior, markdown rendering, and UI layout issues

- Simplify conversation creation logic for reliability
2026-01-25 19:16:36 +08:00

320 lines
8.2 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/**
* Protocol Context Service
* 管理研究方案的上下文数据
*
* @module agent/protocol/services/ProtocolContextService
*/
import { PrismaClient } from '@prisma/client';
import {
ProtocolContextData,
ProtocolStageCode,
ScientificQuestionData,
PICOData,
StudyDesignData,
SampleSizeData,
EndpointsData,
} from '../../types/index.js';
/**
 * Manages the research-protocol context of a conversation: the current stage,
 * the list of completed stages, and the data collected for each stage. The
 * context is persisted in the `protocolContext` table through Prisma and is
 * keyed by `conversationId`.
 *
 * The async methods read/write the database; the synchronous helpers
 * (`getProgress`, `getStagesStatus`, ...) only inspect an already-loaded
 * `ProtocolContextData`.
 */
export class ProtocolContextService {
  /**
   * Single source of truth for the protocol stages: their display order,
   * display name, and the context/DB field that holds each stage's data.
   * Every stage-related helper below derives from this list so they cannot
   * drift apart.
   */
  private static readonly stageDefinitions: ReadonlyArray<{
    code: ProtocolStageCode;
    name: string;
    dataKey: 'scientificQuestion' | 'pico' | 'studyDesign' | 'sampleSize' | 'endpoints';
  }> = [
    { code: 'scientific_question', name: '科学问题梳理', dataKey: 'scientificQuestion' },
    { code: 'pico', name: 'PICO要素', dataKey: 'pico' },
    { code: 'study_design', name: '研究设计', dataKey: 'studyDesign' },
    { code: 'sample_size', name: '样本量计算', dataKey: 'sampleSize' },
    { code: 'endpoints', name: '观察指标', dataKey: 'endpoints' },
  ];

  /**
   * Stages that must be completed before a protocol can be generated
   * (4 of the 5 stages; sample size is optional).
   */
  private static readonly requiredForGeneration: readonly ProtocolStageCode[] = [
    'scientific_question',
    'pico',
    'study_design',
    'endpoints',
  ];

  private readonly prisma: PrismaClient;

  constructor(prisma: PrismaClient) {
    this.prisma = prisma;
  }

  /**
   * Returns the context for a conversation, creating a fresh one if none exists.
   *
   * NOTE: the lookup and the create are two separate queries, so two concurrent
   * calls for the same new conversation may both attempt `create`; since
   * `conversationId` is used with `findUnique`, the second create would be
   * rejected by the unique constraint.
   *
   * @param conversationId - Conversation the context belongs to.
   * @param userId - Owner recorded on a newly created context.
   * @returns The existing or newly created context.
   */
  async getOrCreateContext(
    conversationId: string,
    userId: string
  ): Promise<ProtocolContextData> {
    const existing = await this.getContext(conversationId);
    if (existing) {
      return existing;
    }
    return this.createContext(conversationId, userId);
  }

  /**
   * Loads the context for a conversation.
   *
   * @param conversationId - Conversation to look up.
   * @returns The context, or `null` when none has been created yet.
   */
  async getContext(conversationId: string): Promise<ProtocolContextData | null> {
    const row = await this.prisma.protocolContext.findUnique({
      where: { conversationId },
    });
    return row ? this.mapToContextData(row) : null;
  }

  /**
   * Creates a new context starting at the scientific-question stage, with
   * status `in_progress` and no completed stages.
   *
   * @param conversationId - Conversation the context belongs to.
   * @param userId - Owner of the context.
   * @returns The newly created context.
   */
  async createContext(
    conversationId: string,
    userId: string
  ): Promise<ProtocolContextData> {
    const row = await this.prisma.protocolContext.create({
      data: {
        conversationId,
        userId,
        currentStage: 'scientific_question',
        status: 'in_progress',
        completedStages: [],
      },
    });
    return this.mapToContextData(row);
  }

  /**
   * Stores the data collected for one stage and refreshes `lastActiveAt`.
   * The stage's data replaces whatever was stored for it before. A stage code
   * without a data field only refreshes `lastActiveAt`.
   *
   * @param conversationId - Conversation whose context is updated.
   * @param stageCode - Stage the data belongs to.
   * @param data - Stage payload written to the stage's column.
   * @returns The updated context.
   */
  async updateStageData(
    conversationId: string,
    stageCode: ProtocolStageCode,
    data: Record<string, unknown>
  ): Promise<ProtocolContextData> {
    const updateData: Record<string, unknown> = {};
    const dataKey = ProtocolContextService.stageDefinitions.find(
      (stage) => stage.code === stageCode
    )?.dataKey;
    if (dataKey !== undefined) {
      updateData[dataKey] = data;
    }

    const row = await this.prisma.protocolContext.update({
      where: { conversationId },
      data: {
        ...updateData,
        lastActiveAt: new Date(),
      },
    });
    return this.mapToContextData(row);
  }

  /**
   * Marks a stage as completed (at most once) and optionally advances the
   * current stage.
   *
   * NOTE: this is a read-modify-write of `completedStages` (load, then update),
   * not a single atomic statement.
   *
   * @param conversationId - Conversation whose context is updated.
   * @param stageCode - Stage to mark as completed.
   * @param nextStage - New current stage; when omitted the current stage is kept.
   * @returns The updated context.
   * @throws Error `'Context not found'` when the conversation has no context.
   */
  async completeStage(
    conversationId: string,
    stageCode: ProtocolStageCode,
    nextStage?: ProtocolStageCode
  ): Promise<ProtocolContextData> {
    const context = await this.getContext(conversationId);
    if (!context) {
      throw new Error('Context not found');
    }

    // Copy before appending so the loaded context object is not mutated.
    const completedStages = [...context.completedStages];
    if (!completedStages.includes(stageCode)) {
      completedStages.push(stageCode);
    }

    const row = await this.prisma.protocolContext.update({
      where: { conversationId },
      data: {
        completedStages,
        currentStage: nextStage ?? context.currentStage,
        lastActiveAt: new Date(),
      },
    });
    return this.mapToContextData(row);
  }

  /**
   * Moves the context to the given stage without changing completion state.
   *
   * @param conversationId - Conversation whose context is updated.
   * @param stageCode - New current stage.
   * @returns The updated context.
   */
  async updateCurrentStage(
    conversationId: string,
    stageCode: ProtocolStageCode
  ): Promise<ProtocolContextData> {
    const row = await this.prisma.protocolContext.update({
      where: { conversationId },
      data: {
        currentStage: stageCode,
        lastActiveAt: new Date(),
      },
    });
    return this.mapToContextData(row);
  }

  /**
   * Sets the context status to `completed` and refreshes `lastActiveAt`.
   *
   * @param conversationId - Conversation whose context is updated.
   * @returns The updated context.
   */
  async markCompleted(conversationId: string): Promise<ProtocolContextData> {
    const row = await this.prisma.protocolContext.update({
      where: { conversationId },
      data: {
        status: 'completed',
        lastActiveAt: new Date(),
      },
    });
    return this.mapToContextData(row);
  }

  /**
   * Whether every protocol stage (including the optional sample size) has
   * been completed.
   */
  isAllStagesCompleted(context: ProtocolContextData): boolean {
    return ProtocolContextService.stageDefinitions.every((stage) =>
      context.completedStages.includes(stage.code)
    );
  }

  /**
   * Whether a research protocol can be generated (4 of 5 stages required).
   * Required: scientific question, PICO, study design, and endpoints.
   * Optional: sample size.
   */
  canGenerateProtocol(context: ProtocolContextData): boolean {
    return this.getMissingRequiredStages(context).length === 0;
  }

  /**
   * Lists the stages that are required for generation but not yet completed,
   * in stage order.
   */
  getMissingRequiredStages(context: ProtocolContextData): ProtocolStageCode[] {
    return ProtocolContextService.requiredForGeneration.filter(
      (stage) => !context.completedStages.includes(stage)
    );
  }

  /**
   * Completion percentage (0-100, rounded to an integer).
   *
   * Only known stages are counted, each at most once, so duplicate or
   * unrecognised entries in `completedStages` cannot push the value past 100.
   */
  getProgress(context: ProtocolContextData): number {
    const stages = ProtocolContextService.stageDefinitions;
    const completedCount = stages.filter((stage) =>
      context.completedStages.includes(stage.code)
    ).length;
    return Math.round((completedCount / stages.length) * 100);
  }

  /**
   * Builds the per-stage status list for display, in stage order.
   *
   * A stage is `completed` if listed in `completedStages`, otherwise `current`
   * if it equals `currentStage`, otherwise `pending`. `data` is the stage's
   * stored payload, or `null` when nothing has been stored.
   */
  getStagesStatus(context: ProtocolContextData): Array<{
    stageCode: ProtocolStageCode;
    stageName: string;
    status: 'completed' | 'current' | 'pending';
    data: Record<string, unknown> | null;
  }> {
    return ProtocolContextService.stageDefinitions.map((stage) => {
      let status: 'completed' | 'current' | 'pending' = 'pending';
      if (context.completedStages.includes(stage.code)) {
        status = 'completed';
      } else if (context.currentStage === stage.code) {
        status = 'current';
      }

      // Stage payloads are typed interfaces; expose them as plain records.
      const rawData: unknown = context[stage.dataKey];
      return {
        stageCode: stage.code,
        stageName: stage.name,
        status,
        data: (rawData ?? null) as Record<string, unknown> | null,
      };
    });
  }

  /**
   * Collects the stage data used to generate the protocol; missing stages
   * (null or undefined) are normalised to `null`.
   */
  getGenerationData(context: ProtocolContextData): {
    scientificQuestion: ScientificQuestionData | null;
    pico: PICOData | null;
    studyDesign: StudyDesignData | null;
    sampleSize: SampleSizeData | null;
    endpoints: EndpointsData | null;
  } {
    return {
      scientificQuestion: context.scientificQuestion ?? null,
      pico: context.pico ?? null,
      studyDesign: context.studyDesign ?? null,
      sampleSize: context.sampleSize ?? null,
      endpoints: context.endpoints ?? null,
    };
  }

  /**
   * Maps a `protocolContext` database row to `ProtocolContextData`.
   *
   * NOTE(review): the stage-data columns are cast to `X | undefined`, but a
   * Prisma JSON column that was never written presumably comes back as `null`,
   * so callers should treat these fields as nullish (`?? null` / `== null`)
   * rather than testing `=== undefined`. Confirm against the schema before
   * tightening the casts.
   */
  private mapToContextData(result: {
    id: string;
    conversationId: string;
    userId: string;
    currentStage: string;
    status: string;
    scientificQuestion: unknown;
    pico: unknown;
    studyDesign: unknown;
    sampleSize: unknown;
    endpoints: unknown;
    completedStages: string[];
    lastActiveAt: Date;
    createdAt: Date;
    updatedAt: Date;
  }): ProtocolContextData {
    return {
      id: result.id,
      conversationId: result.conversationId,
      userId: result.userId,
      currentStage: result.currentStage as ProtocolStageCode,
      status: result.status as ProtocolContextData['status'],
      scientificQuestion: result.scientificQuestion as ScientificQuestionData | undefined,
      pico: result.pico as PICOData | undefined,
      studyDesign: result.studyDesign as StudyDesignData | undefined,
      sampleSize: result.sampleSize as SampleSizeData | undefined,
      endpoints: result.endpoints as EndpointsData | undefined,
      completedStages: result.completedStages as ProtocolStageCode[],
      lastActiveAt: result.lastActiveAt,
      createdAt: result.createdAt,
      updatedAt: result.updatedAt,
    };
  }
}