/**
 * Phase I — PICO inference service.
 *
 * Calls the LLM (SSA_PICO_INFERENCE prompt) to infer a PICO structure from a
 * data overview, then writes it to SessionBlackboard.picoInference marked as
 * ai_inferred.
 *
 * Safety measures:
 * - Zod validation of the LLM output
 * - jsonrepair fault tolerance for malformed JSON
 * - H3: observational studies may have intervention/comparison set to null
 */
import { logger } from '../../../common/logging/index.js';
import { LLMFactory } from '../../../common/llm/adapters/LLMFactory.js';
import { getPromptService } from '../../../common/prompt/index.js';
import { prisma } from '../../../config/database.js';
import { jsonrepair } from 'jsonrepair';
import type { Message } from '../../../common/llm/adapters/types.js';
import { sessionBlackboardService } from './SessionBlackboardService.js';
import {
  PicoInferenceSchema,
  type PicoInference,
  type DataOverview,
  type VariableDictEntry,
} from '../types/session-blackboard.types.js';

/** Number of additional attempts after the first LLM call fails (total calls = MAX_RETRIES + 1). */
const MAX_RETRIES = 1;

export class PicoInferenceService {
  /**
   * Infer a PICO structure from a DataOverview via the LLM and persist it to
   * the session blackboard.
   *
   * @param sessionId  - Session whose blackboard receives the inferred PICO.
   * @param overview   - Dataset profiling result used to build the prompt summary.
   * @param dictionary - Variable dictionary; ID-like variables are excluded
   *                     from the prompt's variable list.
   * @returns The validated PicoInference (with status 'ai_inferred'), or null
   *          when all attempts failed. Errors are logged and swallowed — this
   *          method never throws.
   */
  async inferFromOverview(
    sessionId: string,
    overview: DataOverview,
    dictionary: VariableDictEntry[],
  ): Promise<PicoInference | null> {
    try {
      logger.info('[SSA:PICO] Starting inference', { sessionId });

      const promptService = getPromptService(prisma);
      const dataOverviewSummary = this.buildOverviewSummary(overview);
      const variableList = this.buildVariableList(dictionary);

      const rendered = await promptService.get('SSA_PICO_INFERENCE', {
        dataOverviewSummary,
        variableList,
      });

      // NOTE(review): the cast exists because rendered.modelConfig is loosely
      // typed; confirm the model id type against LLMFactory.getAdapter's
      // signature rather than widening with `any`.
      const adapter = LLMFactory.getAdapter(
        (rendered.modelConfig?.model as any) || 'deepseek-v3'
      );

      const messages: Message[] = [
        { role: 'system', content: rendered.content },
        { role: 'user', content: '请根据以上数据概览推断 PICO 结构。' },
      ];

      let pico: PicoInference | null = null;

      // Retry loop: up to MAX_RETRIES + 1 attempts. A parse/validation
      // failure on the final attempt is rethrown to the outer catch.
      for (let attempt = 0; attempt <= MAX_RETRIES; attempt++) {
        try {
          const response = await adapter.chat(messages, {
            temperature: rendered.modelConfig?.temperature ?? 0.3,
            maxTokens: rendered.modelConfig?.maxTokens ?? 1024,
          });

          const raw = this.robustJsonParse(response.content);
          // Force status to 'ai_inferred' regardless of what the model emitted,
          // then validate the combined object against the schema.
          const validated = PicoInferenceSchema.parse({
            ...raw,
            status: 'ai_inferred',
          });
          pico = validated;
          break;
        } catch (err: unknown) {
          logger.warn('[SSA:PICO] LLM attempt failed', {
            attempt,
            error: this.errorMessage(err),
          });
          if (attempt === MAX_RETRIES) throw err;
        }
      }

      if (pico) {
        await sessionBlackboardService.confirmPico(sessionId, {
          population: pico.population,
          intervention: pico.intervention,
          comparison: pico.comparison,
          outcome: pico.outcome,
        });
        logger.info('[SSA:PICO] Inference complete', {
          sessionId,
          confidence: pico.confidence,
          hasIntervention: pico.intervention !== null,
        });
      }

      return pico;
    } catch (error: unknown) {
      logger.error('[SSA:PICO] Inference failed', {
        sessionId,
        error: this.errorMessage(error),
      });
      return null;
    }
  }

  /** Extract a loggable message from an unknown catch value. */
  private errorMessage(e: unknown): string {
    return e instanceof Error ? e.message : String(e);
  }

  /**
   * Build a human-readable, newline-joined summary of the dataset profile
   * (row/column counts, type distribution, missingness, complete cases,
   * and any non-normally-distributed variables) for prompt interpolation.
   */
  private buildOverviewSummary(overview: DataOverview): string {
    const s = overview.profile.summary;
    const lines = [
      `数据集: ${s.totalRows} 行, ${s.totalColumns} 列`,
      `类型分布: 数值型 ${s.numericColumns}, 分类型 ${s.categoricalColumns}, 日期型 ${s.datetimeColumns}, 文本型 ${s.textColumns}`,
      `整体缺失率: ${s.overallMissingRate}%`,
      `完整病例数: ${overview.completeCaseCount}`,
    ];

    const nonNormal = overview.normalityTests
      ?.filter(t => !t.isNormal)
      .map(t => t.variable);
    if (nonNormal && nonNormal.length > 0) {
      lines.push(`非正态分布变量: ${nonNormal.join(', ')}`);
    }

    return lines.join('\n');
  }

  /**
   * Render the variable dictionary as a bulleted list for the prompt.
   * ID-like variables are excluded; the confirmed type wins over the
   * inferred type when both exist.
   */
  private buildVariableList(dict: VariableDictEntry[]): string {
    return dict
      .filter(v => !v.isIdLike)
      .map(v => {
        const type = v.confirmedType ?? v.inferredType;
        const label = v.label ? ` (${v.label})` : '';
        return `- ${v.name}: ${type}${label}`;
      })
      .join('\n');
  }

  /**
   * Parse LLM output into JSON, tolerating a ```json fenced block wrapper
   * and falling back to jsonrepair when strict JSON.parse fails.
   *
   * @throws When the text is unparseable even after repair.
   */
  private robustJsonParse(text: string): any {
    let cleaned = text.trim();

    // Strip a Markdown code fence if the model wrapped its JSON in one.
    const fenceMatch = cleaned.match(/```(?:json)?\s*([\s\S]*?)```/);
    if (fenceMatch && fenceMatch[1] !== undefined) {
      cleaned = fenceMatch[1].trim();
    }

    try {
      return JSON.parse(cleaned);
    } catch {
      return JSON.parse(jsonrepair(cleaned));
    }
  }
}

export const picoInferenceService = new PicoInferenceService();