Files
AIclinicalresearch/backend/src/modules/iit-manager/engines/SkillRunner.ts

756 lines
21 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/**
* SkillRunner - 规则调度器
*
* 功能:
* - 根据触发类型加载和执行 Skills
* - 协调 HardRuleEngine 和 SoftRuleEngine
* 实现漏斗式执行策略:Blocking → Hard → Soft
* - 聚合质控结果
*
* 设计原则:
* - 可插拔:通过 Skill 配置动态加载规则
* - 成本控制:阻断性检查优先,失败则跳过 AI 检查
* - 统一入口:所有触发类型使用相同的执行逻辑
*/
import { PrismaClient } from '@prisma/client';
import { logger } from '../../../common/logging/index.js';
import { HardRuleEngine, createHardRuleEngine, QCResult, QCRule } from './HardRuleEngine.js';
import { SoftRuleEngine, createSoftRuleEngine, SoftRuleCheck, SoftRuleEngineResult } from './SoftRuleEngine.js';
import { RedcapAdapter } from '../adapters/RedcapAdapter.js';
import jsonLogic from 'json-logic-js';
// Module-local Prisma client used by SkillRunner to read iitProject / iitSkill rows
// and to write iitQcLog rows.
// NOTE(review): this instantiates a separate PrismaClient (with its own connection
// pool) for this module; if the backend exposes a shared Prisma instance, consider
// reusing it instead — confirm.
const prisma = new PrismaClient();
// ============================================================
// 类型定义
// ============================================================
/**
 * Source that initiated a QC run.
 *
 * SkillRunner uses it to choose which skills to load: `webhook` runs only
 * webhook skills, `cron` runs cron and webhook skills, and `manual` runs all
 * active skills.
 */
export type TriggerType =
  | 'webhook'
  | 'cron'
  | 'manual';
/**
 * Evaluation strategy of a skill.
 *
 * - `HARD_RULE`: deterministic JSON Logic rules from the skill config.
 * - `LLM_CHECK`: checks delegated to the SoftRuleEngine.
 * - `HYBRID`: intended to combine both; not yet implemented by SkillRunner.
 */
export type RuleType =
  | 'HARD_RULE'
  | 'LLM_CHECK'
  | 'HYBRID';
/**
 * Outcome of executing one skill against a single record/event unit.
 */
export interface SkillResult {
  /** Identity and classification of the executed skill. */
  skillId: string;
  skillName: string;
  skillType: string;
  ruleType: RuleType;
  /** Verdict reported for this skill. */
  status: 'PASS' | 'FAIL' | 'WARNING' | 'UNCERTAIN';
  /** Wall-clock duration of the skill execution, in milliseconds. */
  executionTimeMs: number;
  /** Problems reported by the skill. */
  issues: SkillIssue[];
}
/**
 * A single problem reported by a skill.
 */
export interface SkillIssue {
  /** Identifier and display name of the rule or check that raised the issue. */
  ruleId: string;
  ruleName: string;
  /** How serious the problem is. */
  severity: 'critical' | 'warning' | 'info';
  /** Human-facing message. */
  message: string;
  /** V2.1: self-contained message formatted for LLM consumption. */
  llmMessage?: string;
  /** Field(s) the rule inspects. */
  field?: string | string[];
  /** Value(s) observed in the record for `field`. */
  actualValue?: any;
  /** V2.1: human-readable description of the expected value. */
  expectedValue?: string;
  /** V2.1: structured evidence supporting the issue. */
  evidence?: Record<string, any>;
  /** Confidence score reported by LLM checks, when available. */
  confidence?: number;
}
/**
 * Aggregated QC outcome for one record + event unit.
 *
 * V3.1: every REDCap record/event pair is checked as an independent unit; the
 * optional event fields identify which event the result belongs to.
 */
export interface SkillRunResult {
  /** Project and record the unit belongs to. */
  projectId: string;
  recordId: string;
  /** What initiated the run and when the result was produced (ISO-8601 string). */
  triggerType: TriggerType;
  timestamp: string;
  /** Total time spent running all skills for this unit, in milliseconds. */
  executionTimeMs: number;
  /** V3.1: REDCap unique event name. */
  eventName?: string;
  /** V3.1: display label of the event. */
  eventLabel?: string;
  /** V3.1: forms contained in the event. */
  forms?: string[];
  /** Overall verdict for the unit, derived from the skill results and issues. */
  overallStatus: 'PASS' | 'FAIL' | 'WARNING' | 'UNCERTAIN';
  /** Per-status counts over the skills that were actually executed. */
  summary: {
    totalSkills: number;
    passed: number;
    failed: number;
    warnings: number;
    uncertain: number;
    /** True when a blocking-level skill failed, causing later LLM checks to be skipped. */
    blockedByLevel1: boolean;
  };
  /** Result of every executed skill. */
  skillResults: SkillResult[];
  /** Every issue, followed by the critical- and warning-severity subsets. */
  allIssues: SkillIssue[];
  criticalIssues: SkillIssue[];
  warningIssues: SkillIssue[];
}
/**
 * Optional filters and switches for a SkillRunner run.
 */
export interface SkillRunnerOptions {
  /** Skip all LLM_CHECK skills (fast, rules-only run). */
  skipSoftRules?: boolean;
  /** Restrict processing to a single REDCap record. */
  recordId?: string;
  /** V3.1: restrict processing to a single REDCap event. */
  eventName?: string;
  /**
   * When loading skills, drop rule-based skills none of whose rules apply to
   * this form (rules without a `formName` always apply).
   */
  formName?: string;
}
// ============================================================
// SkillRunner 实现
// ============================================================
/**
 * Shape of an `iitSkill` row as selected by `SkillRunner.loadSkills`.
 */
type LoadedSkill = {
  id: string;
  skillType: string;
  name: string;
  ruleType: string;
  level: string;
  priority: number;
  config: any;
  requiredTags: string[];
};

/** Verdict used for both per-skill and per-record/event results. */
type QcStatus = 'PASS' | 'FAIL' | 'WARNING' | 'UNCERTAIN';

/**
 * Runs the QC skills configured for one IIT project.
 *
 * For every record+event unit fetched from REDCap, the active skills are executed
 * in funnel order: blocking-level skills first, then by ascending priority. Once a
 * blocking-level skill fails, the remaining `LLM_CHECK` skills are skipped for that
 * unit to save AI cost. One QC log row is written per unit. Log-write failures are
 * logged and do not abort the run.
 */
export class SkillRunner {
  private readonly projectId: string;
  /** Created lazily on first use and reused for the lifetime of this runner. */
  private redcapAdapter?: RedcapAdapter;

  constructor(projectId: string) {
    this.projectId = projectId;
  }

  /**
   * Returns the cached REDCap adapter, creating it from the project's stored
   * REDCap URL and API token on first use.
   *
   * @throws Error when the project does not exist.
   */
  private async initRedcapAdapter(): Promise<RedcapAdapter> {
    if (this.redcapAdapter) {
      return this.redcapAdapter;
    }

    const project = await prisma.iitProject.findUnique({
      where: { id: this.projectId },
      select: { redcapUrl: true, redcapApiToken: true },
    });
    if (!project) {
      throw new Error(`项目不存在: ${this.projectId}`);
    }

    this.redcapAdapter = new RedcapAdapter(project.redcapUrl, project.redcapApiToken);
    return this.redcapAdapter;
  }

  /**
   * Executes all applicable skills for the given trigger type.
   *
   * @param triggerType What initiated the run; decides which skills are loaded.
   * @param options Optional record/event/form filters and the soft-rule switch.
   * @returns One result per processed record+event unit, or `[]` when no active
   *   skills match.
   */
  async runByTrigger(
    triggerType: TriggerType,
    options?: SkillRunnerOptions
  ): Promise<SkillRunResult[]> {
    const startTime = Date.now();

    logger.info('[SkillRunner] Starting execution', {
      projectId: this.projectId,
      triggerType,
      options,
    });

    // 1. Load the active skills for this trigger.
    const skills = await this.loadSkills(triggerType, options?.formName);
    if (skills.length === 0) {
      logger.warn('[SkillRunner] No active skills found', {
        projectId: this.projectId,
        triggerType,
      });
      return [];
    }

    // 2. Funnel order: blocking-level skills first, then ascending priority
    //    (a smaller number runs earlier).
    skills.sort((a, b) => {
      if (a.level === 'blocking' && b.level !== 'blocking') return -1;
      if (a.level !== 'blocking' && b.level === 'blocking') return 1;
      return a.priority - b.priority;
    });

    // 3. Fetch the units to check (V3.1: one per record + event).
    const records = await this.getRecordsToProcess(options);

    // 4. Run every skill on each unit sequentially and persist a QC log for each.
    const results: SkillRunResult[] = [];
    for (const record of records) {
      const result = await this.executeSkillsForRecord(
        record.recordId,
        record.eventName,
        record.eventLabel,
        record.forms,
        record.data,
        skills,
        triggerType,
        options
      );
      results.push(result);
      await this.saveQcLog(result);
    }

    const totalTime = Date.now() - startTime;
    logger.info('[SkillRunner] Execution completed', {
      projectId: this.projectId,
      triggerType,
      recordEventCount: records.length,
      totalTimeMs: totalTime,
    });

    return results;
  }

  /**
   * Loads the project's active skills for a trigger type.
   *
   * `webhook` loads webhook skills only. `cron` loads cron and webhook skills.
   * `manual` loads every active skill.
   *
   * @param formName When given, rule-based skills are kept only if at least one
   *   rule has no `formName` or has a matching one. Skills without `config.rules`
   *   are always kept.
   */
  private async loadSkills(
    triggerType: TriggerType,
    formName?: string
  ): Promise<LoadedSkill[]> {
    const where: any = {
      projectId: this.projectId,
      isActive: true,
    };

    if (triggerType === 'webhook') {
      where.triggerType = 'webhook';
    } else if (triggerType === 'cron') {
      // Scheduled runs also re-apply the webhook rules.
      where.triggerType = { in: ['cron', 'webhook'] };
    }
    // 'manual' adds no trigger filter, so every active skill is loaded.

    const skills = await prisma.iitSkill.findMany({
      where,
      select: {
        id: true,
        skillType: true,
        name: true,
        ruleType: true,
        level: true,
        priority: true,
        config: true,
        requiredTags: true,
      },
    });

    if (formName) {
      return skills.filter(skill => {
        const config = skill.config as any;
        if (config?.rules) {
          return config.rules.some((rule: any) =>
            !rule.formName || rule.formName === formName
          );
        }
        return true;
      });
    }

    return skills;
  }

  /**
   * Fetches the record+event units to check.
   *
   * V3.1: event data is not merged across events, so each visit is checked on
   * its own. Note that `options.formName` is not applied here.
   */
  private async getRecordsToProcess(
    options?: SkillRunnerOptions
  ): Promise<Array<{
    recordId: string;
    eventName: string;
    eventLabel: string;
    forms: string[];
    data: Record<string, any>;
  }>> {
    const adapter = await this.initRedcapAdapter();

    const eventRecords = await adapter.getAllRecordsByEvent({
      recordId: options?.recordId,
      eventName: options?.eventName,
    });

    return eventRecords.map(r => ({
      recordId: r.recordId,
      eventName: r.eventName,
      eventLabel: r.eventLabel,
      forms: r.forms,
      data: r.data,
    }));
  }

  /**
   * Runs the (already ordered) skills against one record+event unit and
   * aggregates their issues and statuses.
   *
   * Overall status precedence: any critical issue → FAIL; otherwise any
   * UNCERTAIN skill → UNCERTAIN; otherwise any warning issue → WARNING;
   * otherwise PASS.
   */
  private async executeSkillsForRecord(
    recordId: string,
    eventName: string,
    eventLabel: string,
    forms: string[],
    data: Record<string, any>,
    skills: LoadedSkill[],
    triggerType: TriggerType,
    options?: SkillRunnerOptions
  ): Promise<SkillRunResult> {
    const startTime = Date.now();
    const skillResults: SkillResult[] = [];
    const allIssues: SkillIssue[] = [];
    const criticalIssues: SkillIssue[] = [];
    const warningIssues: SkillIssue[] = [];
    let blockedByLevel1 = false;

    for (const skill of skills) {
      const ruleType = skill.ruleType as RuleType;

      // After a blocking-level failure, every remaining LLM_CHECK skill is skipped
      // (whatever its own level) to avoid spending on AI checks.
      if (blockedByLevel1 && ruleType === 'LLM_CHECK') {
        logger.debug('[SkillRunner] Skipping LLM check due to blocking failure', {
          skillId: skill.id,
          recordId,
          eventName,
        });
        continue;
      }

      if (options?.skipSoftRules && ruleType === 'LLM_CHECK') {
        continue;
      }

      const result = await this.executeSkill(skill, recordId, eventName, forms, data);
      skillResults.push(result);

      for (const issue of result.issues) {
        allIssues.push(issue);
        if (issue.severity === 'critical') {
          criticalIssues.push(issue);
        } else if (issue.severity === 'warning') {
          warningIssues.push(issue);
        }
      }

      if (skill.level === 'blocking' && result.status === 'FAIL') {
        blockedByLevel1 = true;
        logger.info('[SkillRunner] Blocking check failed, skipping AI checks', {
          skillId: skill.id,
          recordId,
          eventName,
        });
      }
    }

    let overallStatus: QcStatus = 'PASS';
    if (criticalIssues.length > 0) {
      overallStatus = 'FAIL';
    } else if (skillResults.some(r => r.status === 'UNCERTAIN')) {
      overallStatus = 'UNCERTAIN';
    } else if (warningIssues.length > 0) {
      overallStatus = 'WARNING';
    }

    const executionTimeMs = Date.now() - startTime;

    return {
      projectId: this.projectId,
      recordId,
      eventName,
      eventLabel,
      forms,
      triggerType,
      timestamp: new Date().toISOString(),
      overallStatus,
      summary: {
        totalSkills: skillResults.length,
        passed: skillResults.filter(r => r.status === 'PASS').length,
        failed: skillResults.filter(r => r.status === 'FAIL').length,
        warnings: skillResults.filter(r => r.status === 'WARNING').length,
        uncertain: skillResults.filter(r => r.status === 'UNCERTAIN').length,
        blockedByLevel1,
      },
      skillResults,
      allIssues,
      criticalIssues,
      warningIssues,
      executionTimeMs,
    };
  }

  /**
   * Executes a single skill against one record+event unit.
   *
   * Only rules/checks whose `applicableEvents` / `applicableForms` match the
   * unit are evaluated. Any thrown error is converted into an `EXECUTION_ERROR`
   * warning issue and an UNCERTAIN status, so one failing skill cannot abort the run.
   */
  private async executeSkill(
    skill: Omit<LoadedSkill, 'level' | 'priority'>,
    recordId: string,
    eventName: string,
    forms: string[],
    data: Record<string, any>
  ): Promise<SkillResult> {
    const startTime = Date.now();
    const ruleType = skill.ruleType as RuleType;
    const config = skill.config as any;
    const issues: SkillIssue[] = [];
    let status: QcStatus = 'PASS';

    try {
      if (ruleType === 'HARD_RULE') {
        // Rules come straight from the skill config and are evaluated in-process,
        // so no HardRuleEngine instance (and its async initialization) is needed.
        if (config?.rules) {
          const applicableRules = this.filterApplicableRules(
            config.rules as QCRule[],
            eventName,
            forms
          );
          if (applicableRules.length > 0) {
            const result = this.executeHardRulesDirectly(applicableRules, recordId, data);
            issues.push(...result.issues);
            status = result.status;
          }
        }
      } else if (ruleType === 'LLM_CHECK') {
        const engine = createSoftRuleEngine(this.projectId, {
          model: config?.model || 'deepseek-v3',
        });

        const rawChecks = config?.checks || [];
        const applicableChecks = this.filterApplicableRules(rawChecks, eventName, forms);
        const checks: SoftRuleCheck[] = applicableChecks.map((check: any) => ({
          id: check.id,
          name: check.name || check.desc,
          description: check.desc,
          promptTemplate: check.promptTemplate || check.prompt,
          requiredTags: check.requiredTags || skill.requiredTags || [],
          category: check.category || 'medical_logic',
          severity: check.severity || 'warning',
          applicableEvents: check.applicableEvents,
          applicableForms: check.applicableForms,
        }));

        if (checks.length > 0) {
          const result = await engine.execute(recordId, data, checks);

          for (const checkResult of result.results) {
            if (checkResult.status !== 'PASS') {
              issues.push({
                ruleId: checkResult.checkId,
                ruleName: checkResult.checkName,
                message: checkResult.reason,
                severity: checkResult.severity,
                evidence: checkResult.evidence,
                confidence: checkResult.confidence,
              });
            }
          }

          if (result.overallStatus === 'FAIL') {
            status = 'FAIL';
          } else if (result.overallStatus === 'UNCERTAIN') {
            status = 'UNCERTAIN';
          }
        }
      } else if (ruleType === 'HYBRID') {
        // TODO: hybrid mode (hard rules first, then soft rules) is not implemented.
        logger.warn('[SkillRunner] Hybrid rules not yet implemented', {
          skillId: skill.id,
        });
      }
    } catch (error: unknown) {
      const message = SkillRunner.errorMessage(error);
      logger.error('[SkillRunner] Skill execution error', {
        skillId: skill.id,
        error: message,
      });
      status = 'UNCERTAIN';
      issues.push({
        ruleId: 'EXECUTION_ERROR',
        ruleName: '执行错误',
        message: `Skill 执行出错: ${message}`,
        severity: 'warning',
      });
    }

    const executionTimeMs = Date.now() - startTime;

    return {
      skillId: skill.id,
      skillName: skill.name,
      skillType: skill.skillType,
      ruleType,
      status,
      issues,
      executionTimeMs,
    };
  }

  /**
   * Evaluates JSON Logic hard rules against a record without going through
   * HardRuleEngine.
   *
   * A rule whose logic evaluates falsy produces an issue. Its severity is
   * `critical` when `rule.severity === 'error'` and `warning` otherwise. V2.1
   * also fills `expectedValue`, `llmMessage` and `evidence`. A rule that throws
   * is logged and skipped.
   */
  private executeHardRulesDirectly(
    rules: QCRule[],
    recordId: string,
    data: Record<string, any>
  ): { status: 'PASS' | 'FAIL' | 'WARNING'; issues: SkillIssue[] } {
    const issues: SkillIssue[] = [];
    let hasFail = false;
    let hasWarning = false;

    for (const rule of rules) {
      try {
        const passed = jsonLogic.apply(rule.logic, data);
        if (passed) {
          continue;
        }

        const severity = rule.severity === 'error' ? 'critical' : 'warning';
        const actualValue = this.getFieldValue(rule.field, data);
        const expectedValue = this.extractExpectedValue(rule.logic);
        const llmMessage = this.buildLlmMessage(rule, actualValue, expectedValue);

        issues.push({
          ruleId: rule.id,
          ruleName: rule.name,
          field: rule.field,
          message: rule.message,
          llmMessage,
          severity,
          actualValue,
          expectedValue,
          evidence: {
            value: actualValue,
            threshold: expectedValue,
            unit: (rule.metadata as any)?.unit,
          },
        });

        if (severity === 'critical') {
          hasFail = true;
        } else {
          hasWarning = true;
        }
      } catch (error: unknown) {
        logger.warn('[SkillRunner] Rule execution error', {
          ruleId: rule.id,
          error: SkillRunner.errorMessage(error),
        });
      }
    }

    let status: 'PASS' | 'FAIL' | 'WARNING' = 'PASS';
    if (hasFail) {
      status = 'FAIL';
    } else if (hasWarning) {
      status = 'WARNING';
    }

    return { status, issues };
  }

  /**
   * V2.1: derives a human-readable expected value from a JSON Logic expression.
   *
   * Handles comparisons with the literal on either side, the three-argument
   * "between" form (`{"<": [min, {"var": x}, max]}` → `"min-max"`), `and` of two
   * bounds (→ `"lo-hi"`), and `!!` (required). Returns `''` when nothing sensible
   * can be extracted, including for non-object input such as `null`.
   */
  private extractExpectedValue(logic: Record<string, any>): string {
    if (!SkillRunner.isLogicExpression(logic)) {
      return '';
    }

    const operator = Object.keys(logic)[0];
    const args = logic[operator];

    switch (operator) {
      case '>=':
      case '<=':
      case '>':
      case '<':
      case '==':
      case '!=':
        return this.describeComparison(args);
      case 'and':
        // Try to express the conjunction as a range.
        if (Array.isArray(args)) {
          const values = args.map((a: any) => this.extractExpectedValue(a)).filter(Boolean);
          if (values.length === 2) {
            return `${values[0]}-${values[1]}`;
          }
          return values.join(', ');
        }
        return '';
      case '!!':
        return '非空/必填';
      default:
        return '';
    }
  }

  /**
   * Describes the "expected" side of a comparison's argument list. The literal
   * operand is preferred, so `{">=": [{var}, 18]}` and `{"<=": [18, {var}]}`
   * both yield `"18"`.
   */
  private describeComparison(args: unknown): string {
    if (!Array.isArray(args) || args.length < 2) {
      return '';
    }
    if (args.length === 3) {
      // JSON Logic "between": [lower, value, upper]
      return `${this.describeOperand(args[0])}-${this.describeOperand(args[2])}`;
    }
    const [left, right] = args;
    if (!SkillRunner.isLogicExpression(right)) {
      return String(right);
    }
    if (!SkillRunner.isLogicExpression(left)) {
      return String(left);
    }
    // Field-to-field comparison: name the referenced field.
    return this.describeOperand(right);
  }

  /**
   * Renders a single comparison operand. A literal is stringified, a
   * `{"var": name}` reference yields `name`, and any other sub-expression yields `''`.
   */
  private describeOperand(value: unknown): string {
    if (!SkillRunner.isLogicExpression(value)) {
      return String(value);
    }
    const ref = value.var;
    if (typeof ref === 'string') {
      return ref;
    }
    if (Array.isArray(ref) && typeof ref[0] === 'string') {
      return ref[0];
    }
    return '';
  }

  /** True for plain objects, i.e. JSON Logic operations such as `{"var": "x"}`. */
  private static isLogicExpression(value: unknown): value is Record<string, unknown> {
    return typeof value === 'object' && value !== null && !Array.isArray(value);
  }

  /** Extracts a printable message from any thrown value. */
  private static errorMessage(error: unknown): string {
    return error instanceof Error ? error.message : String(error);
  }

  /**
   * V2.1: builds a self-contained, markdown-formatted message for LLM consumption.
   * Missing values (`undefined`, `null`, `''`) are rendered as "空".
   */
  private buildLlmMessage(rule: QCRule, actualValue: any, expectedValue: string): string {
    const displayValue = actualValue !== undefined && actualValue !== null && actualValue !== ''
      ? `**${actualValue}**`
      : '**空**';

    if (expectedValue) {
      return `**${rule.name}**: 当前值 ${displayValue} (标准: ${expectedValue})`;
    }
    return `**${rule.name}**: 当前值 ${displayValue}`;
  }

  /**
   * Reads one field, or several fields as an array, from the record data.
   */
  private getFieldValue(field: string | string[], data: Record<string, any>): any {
    if (Array.isArray(field)) {
      return field.map(f => data[f]);
    }
    return data[field];
  }

  /**
   * V3.1: keeps only the rules/checks that apply to the current event and forms.
   *
   * A missing or empty `applicableEvents` matches every event. A missing or empty
   * `applicableForms` matches every form. Otherwise the event must be listed, and
   * at least one of `forms` must be listed.
   *
   * @param rules All configured rules or checks.
   * @param eventName Current REDCap event name.
   * @param forms Forms contained in the current event.
   * @returns The applicable subset, in original order.
   */
  private filterApplicableRules<T extends { applicableEvents?: string[]; applicableForms?: string[] }>(
    rules: T[],
    eventName: string,
    forms: string[]
  ): T[] {
    return rules.filter(rule => {
      const eventMatch = !rule.applicableEvents ||
        rule.applicableEvents.length === 0 ||
        rule.applicableEvents.includes(eventName);
      if (!eventMatch) {
        return false;
      }

      return !rule.applicableForms ||
        rule.applicableForms.length === 0 ||
        rule.applicableForms.some(f => forms.includes(f));
    });
  }

  /**
   * Persists one unit's result as an `iitQcLog` row.
   *
   * V3.1: the row stores the event name as `eventId`, the comma-joined forms as
   * `formName`, and `qcType: 'event'`. Failures are logged and swallowed, so a
   * logging problem never aborts the QC run.
   */
  private async saveQcLog(result: SkillRunResult): Promise<void> {
    try {
      const issuesWithSummary = {
        items: result.allIssues,
        summary: result.summary,
        eventLabel: result.eventLabel,
        forms: result.forms,
      };

      await prisma.iitQcLog.create({
        data: {
          projectId: result.projectId,
          recordId: result.recordId,
          eventId: result.eventName,
          qcType: 'event',
          formName: result.forms?.join(',') || null,
          status: result.overallStatus,
          // Round-trip through JSON to strip undefined/non-JSON values for the Json column.
          issues: JSON.parse(JSON.stringify(issuesWithSummary)),
          ruleVersion: 'v3.1',
          rulesEvaluated: result.summary.totalSkills || 0,
          rulesPassed: result.summary.passed || 0,
          rulesFailed: result.summary.failed || 0,
          rulesSkipped: 0,
          triggeredBy: result.triggerType,
          createdAt: new Date(result.timestamp),
        },
      });
    } catch (error: unknown) {
      logger.error('[SkillRunner] Failed to save QC log', {
        recordId: result.recordId,
        error: SkillRunner.errorMessage(error),
      });
    }
  }
}
// ============================================================
// 工厂函数
// ============================================================
/**
 * Factory for SkillRunner.
 *
 * @param projectId IIT project whose skills and REDCap data the runner uses.
 * @returns A fresh SkillRunner bound to `projectId`.
 */
export function createSkillRunner(projectId: string): SkillRunner {
  const runner = new SkillRunner(projectId);
  return runner;
}