Files
AIclinicalresearch/r-statistics-service/utils/guardrails.R
HaHafeng 3446909ff7 feat(ssa): Complete Phase I-IV intelligent dialogue and tool system development
Phase I - Session Blackboard + READ Layer:
- SessionBlackboardService with Postgres-Only cache
- DataProfileService for data overview generation
- PicoInferenceService for LLM-driven PICO extraction
- Frontend DataContextCard and VariableDictionaryPanel
- E2E tests: 31/31 passed

Phase II - Conversation Layer LLM + Intent Router:
- ConversationService with SSE streaming
- IntentRouterService (rule-first + LLM fallback, 6 intents)
- SystemPromptService with 6-segment dynamic assembly
- TokenTruncationService for context management
- ChatHandlerService as unified chat entry
- Frontend SSAChatPane and useSSAChat hook
- E2E tests: 38/38 passed

Phase III - Method Consultation + AskUser Standardization:
- ToolRegistryService with Repository Pattern
- MethodConsultService with DecisionTable + LLM enhancement
- AskUserService with global interrupt handling
- Frontend AskUserCard component
- E2E tests: 13/13 passed

Phase IV - Dialogue-Driven Analysis + QPER Integration:
- ToolOrchestratorService (plan/execute/report)
- analysis_plan SSE event for WorkflowPlan transmission
- Dual-channel confirmation (ask_user card + workspace button)
- PICO as optional hint for LLM parsing
- E2E tests: 25/25 passed

R Statistics Service:
- 5 new R tools: anova_one, baseline_table, fisher, linear_reg, wilcoxon
- Enhanced guardrails and block helpers
- Comprehensive test suite (run_all_tools_test.js)

Documentation:
- Updated system status document (v5.9)
- Updated SSA module status and development plan (v1.8)

Total E2E: 107/107 passed (Phase I: 31, Phase II: 38, Phase III: 13, Phase IV: 25)

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-02-22 18:53:39 +08:00

335 lines
10 KiB
R
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# utils/guardrails.R
# Statistical guardrail function library: assumption checks and the action
# labels that describe how a failed check should be handled.
library(glue)
# Large-sample threshold: check_normality() switches to a subsampled test
# when the number of values exceeds this.
LARGE_SAMPLE_THRESHOLD <- 5000
# Guardrail action types
ACTION_BLOCK <- "Block" # block execution
ACTION_WARN <- "Warn" # warn but continue
ACTION_SWITCH <- "Switch" # switch to the fallback method
#' Normality guardrail based on the Shapiro-Wilk test.
#'
#' Missing values are removed first, so the sample-size rules apply to the
#' values that are actually tested. Inputs the test cannot handle (fewer than
#' three values, or values with no variation) are reported as skipped instead
#' of raising an error.
#'
#' @param values Numeric vector to test; NAs are ignored.
#' @param alpha Significance level; the check passes when p >= alpha.
#' @param action Action label reported when the check fails
#'   (defaults to ACTION_SWITCH).
#' @param action_target Optional fallback tool reported together with `action`
#'   when the check fails.
#' @return A list with `passed`, `action`, `action_target` and `reason`.
#'   Skipped checks add `skipped = TRUE`; executed checks add `p_value` and
#'   `sampled`, plus `sample_size` when a subsample was tested.
check_normality <- function(values, alpha = 0.05, action = ACTION_SWITCH, action_target = NULL) {
  values <- values[!is.na(values)]
  n <- length(values)

  # Too few observations: skip the test
  if (n < 3) {
    return(list(
      passed = TRUE,
      action = NULL,
      action_target = NULL,
      reason = "样本量过小,跳过正态性检验",
      skipped = TRUE
    ))
  }

  # Large-sample optimisation: test a fixed-size random subsample
  use_sample <- n > LARGE_SAMPLE_THRESHOLD
  sample_size <- min(1000, n)
  if (use_sample) {
    # A fixed seed keeps the result reproducible; the caller's RNG state is
    # put back on exit so this check does not disturb later random draws.
    had_seed <- exists(".Random.seed", envir = globalenv(), inherits = FALSE)
    old_seed <- if (had_seed) get(".Random.seed", envir = globalenv(), inherits = FALSE) else NULL
    on.exit({
      if (had_seed) {
        assign(".Random.seed", old_seed, envir = globalenv())
      } else if (exists(".Random.seed", envir = globalenv(), inherits = FALSE)) {
        rm(".Random.seed", envir = globalenv())
      }
    }, add = TRUE)
    set.seed(42)
    tested <- sample(values, sample_size)
  } else {
    tested <- values
  }

  # shapiro.test() stops with an error on (near-)constant data; skip instead
  if (diff(range(tested)) < 1e-10) {
    return(list(
      passed = TRUE,
      action = NULL,
      action_target = NULL,
      reason = "数据无变异(所有取值相同),跳过正态性检验",
      skipped = TRUE
    ))
  }

  test <- shapiro.test(tested)
  passed <- test$p.value >= alpha
  verdict <- if (passed) "满足正态性" else "不满足正态性"

  if (use_sample) {
    return(list(
      passed = passed,
      action = if (passed) NULL else action,
      action_target = if (passed) NULL else action_target,
      p_value = test$p.value,
      reason = glue("大样本(N={n})抽样检验,{verdict}"),
      sampled = TRUE,
      sample_size = sample_size
    ))
  }

  list(
    passed = passed,
    action = if (passed) NULL else action,
    action_target = if (passed) NULL else action_target,
    p_value = test$p.value,
    reason = verdict,
    sampled = FALSE
  )
}
#' Homogeneity-of-variance guardrail (Levene's test from the car package).
#'
#' The outcome vector and a grouping factor are passed to the test directly
#' instead of building a formula from column names. This keeps the check
#' working for column names that are not syntactically valid R names and for
#' numerically coded groups (e.g. 0/1), which are converted to a factor here.
#'
#' @param df Data frame holding both columns.
#' @param group_var Name of the grouping column.
#' @param value_var Name of the numeric outcome column.
#' @param alpha Significance level; the check passes when p >= alpha.
#' @param action Action label reported when the check fails
#'   (defaults to ACTION_WARN).
#' @return A list with `passed`, `action` (NULL when passed), `p_value` and
#'   `reason`.
check_homogeneity <- function(df, group_var, value_var, alpha = 0.05, action = ACTION_WARN) {
  grouping <- factor(df[[group_var]])
  outcome <- df[[value_var]]

  # Namespace-qualified call: no package is attached as a side effect
  test <- car::leveneTest(outcome, grouping)
  p_val <- test$`Pr(>F)`[1]
  if (is.na(p_val)) {
    stop("Levene 检验未能得到有效的 p 值", call. = FALSE)
  }

  passed <- p_val >= alpha
  list(
    passed = passed,
    action = if (passed) NULL else action,
    p_value = p_val,
    reason = if (passed) "方差齐性满足" else "方差不齐性"
  )
}
#' Sample-size guardrail.
#'
#' @param n Observed sample size.
#' @param min_required Smallest acceptable sample size.
#' @param action Action label reported when `n` is below `min_required`
#'   (defaults to ACTION_BLOCK).
#' @return A list with `passed`, `action` (NULL when passed), `n` and `reason`.
check_sample_size <- function(n, min_required = 3, action = ACTION_BLOCK) {
  enough <- n >= min_required

  if (enough) {
    outcome_action <- NULL
    outcome_reason <- "样本量充足"
  } else {
    outcome_action <- action
    outcome_reason <- paste0("样本量不足, 需要至少 ", min_required)
  }

  # list() keeps the NULL `action` slot, so the shape is the same either way
  list(
    passed = enough,
    action = outcome_action,
    n = n,
    reason = outcome_reason
  )
}
#' Evaluate a chain of guardrail results in the order given.
#'
#' The chain stops at the first failed result whose action is Block or Switch
#' and returns that outcome; warnings gathered before that point are not part
#' of the returned value. Failed results with a Warn action are collected and
#' returned once every result has been visited.
#'
#' A result counts as passed only when `passed` is exactly TRUE. A failed
#' result that carries no action is treated as a warning, so it is surfaced
#' rather than raising an error. Failed results with an unrecognised action
#' are ignored.
#'
#' @param guardrail_results List of result lists as produced by the check_*
#'   functions (fields read here: `passed`, `action`, `action_target`, `reason`).
#' @return list(status = "blocked", reason), or
#'   list(status = "switch", target_tool, reason), or
#'   list(status = "passed", warnings).
run_guardrail_chain <- function(guardrail_results) {
  warnings <- c()

  for (result in guardrail_results) {
    if (isTRUE(result$passed)) {
      next
    }

    action <- result$action
    if (is.null(action)) {
      action <- ACTION_WARN
    }

    # isTRUE() keeps NA or zero-length comparisons from breaking the `if`
    if (isTRUE(action == ACTION_BLOCK)) {
      return(list(
        status = "blocked",
        reason = result$reason
      ))
    }
    if (isTRUE(action == ACTION_SWITCH)) {
      return(list(
        status = "switch",
        target_tool = result$action_target,
        reason = result$reason
      ))
    }
    if (isTRUE(action == ACTION_WARN)) {
      warnings <- c(warnings, result$reason)
    }
  }

  list(
    status = "passed",
    warnings = warnings
  )
}
# ========== JIT guardrail interface (Phase 2A) ==========
# Called by WorkflowExecutor before a core statistical tool is executed.

#' JIT guardrail check: test statistical assumptions before the core analysis.
#'
#' Which checks run depends on `tool_code`; tool codes without a matching
#' branch yield an empty check list. A check that cannot be computed (too few
#' non-missing values, or the underlying test raised an error) is left out of
#' `checks` and never aborts the whole run.
#'
#' @param df Data frame.
#' @param tool_code Code of the tool about to run.
#' @param params Tool parameters (group_var, value_var, before_var, after_var,
#'   var_x, var_y, var1, var2, outcome_var, depending on the tool).
#' @return list(checks, suggested_tool, can_proceed, all_checks_passed).
#'   `can_proceed` is always TRUE: failed checks are advisory, and the
#'   user/LLM decides whether to continue.
run_jit_guardrails <- function(df, tool_code, params) {
  checks <- list()
  suggested_tool <- tool_code

  # Append a check entry to a list; NULL (check not available) is ignored.
  append_check <- function(existing, entry) {
    if (is.null(entry)) existing else c(existing, list(entry))
  }

  # Non-missing values of `value_var` for rows whose group equals `g`.
  # which() drops rows where the group itself is NA.
  group_values <- function(group_var, value_var, g) {
    vals <- df[[value_var]][which(df[[group_var]] == g)]
    vals[!is.na(vals)]
  }

  # Distinct non-missing group levels.
  group_levels <- function(group_var) {
    lv <- unique(df[[group_var]])
    lv[!is.na(lv)]
  }

  # Normality check entry for `vals`, or NULL when there are fewer than three
  # non-missing values or the test itself fails.
  normality_check <- function(vals, label, ok_msg, fail_msg) {
    vals <- vals[!is.na(vals)]
    if (length(vals) < 3) {
      return(NULL)
    }
    res <- tryCatch(
      check_normality(vals, alpha = 0.05),
      error = function(e) {
        message("正态性检验失败: ", conditionMessage(e))
        NULL
      }
    )
    if (is.null(res)) {
      return(NULL)
    }
    list(
      check_name = label,
      passed = res$passed,
      p_value = res$p_value,
      recommendation = if (res$passed) ok_msg else fail_msg
    )
  }

  # Levene check entry, or NULL when the test fails.
  homogeneity_check <- function(group_var, value_var, fail_msg) {
    tryCatch({
      res <- check_homogeneity(df, group_var, value_var, alpha = 0.05)
      list(
        check_name = "方差齐性检验 (Levene)",
        passed = res$passed,
        p_value = res$p_value,
        recommendation = if (res$passed) "方差齐性满足" else fail_msg
      )
    }, error = function(e) {
      message("方差齐性检验失败: ", conditionMessage(e))
      NULL
    })
  }

  # TRUE when a check entry exists and did not pass.
  failed <- function(entry) {
    !is.null(entry) && !isTRUE(entry$passed)
  }

  # Run a different set of checks depending on the tool
  if (tool_code %in% c("ST_T_TEST_IND", "ST_MANN_WHITNEY")) {
    # Independent-sample comparison: per-group normality + variance homogeneity
    group_var <- params$group_var
    value_var <- params$value_var
    if (!is.null(group_var) && !is.null(value_var)) {
      groups <- group_levels(group_var)
      for (g in groups) {
        label <- glue("正态性检验 (组: {g})")
        entry <- normality_check(
          group_values(group_var, value_var, g),
          label, "满足正态性", "建议使用非参数方法"
        )
        checks <- append_check(checks, entry)
        if (failed(entry) && tool_code == "ST_T_TEST_IND") {
          suggested_tool <- "ST_MANN_WHITNEY"
        }
      }
      if (length(groups) == 2) {
        checks <- append_check(
          checks,
          homogeneity_check(group_var, value_var, "建议使用 Welch 校正")
        )
      }
    }
  } else if (tool_code == "ST_T_TEST_PAIRED") {
    # Paired test: normality of the paired differences
    before_var <- params$before_var
    after_var <- params$after_var
    if (!is.null(before_var) && !is.null(after_var)) {
      diff_vals <- df[[after_var]] - df[[before_var]]
      entry <- normality_check(
        diff_vals, "差值正态性检验", "差值满足正态性", "建议使用 Wilcoxon 符号秩检验"
      )
      checks <- append_check(checks, entry)
      if (failed(entry)) {
        # NOTE(review): this is a display name, not a tool code, although an
        # ST_WILCOXON branch exists below - confirm what the consumer expects.
        suggested_tool <- "Wilcoxon signed-rank test"
      }
    }
  } else if (tool_code == "ST_CORRELATION") {
    # Correlation: normality of each variable
    var_x <- params$var_x
    var_y <- params$var_y
    if (!is.null(var_x) && !is.null(var_y)) {
      label_x <- glue("正态性检验 ({var_x})")
      label_y <- glue("正态性检验 ({var_y})")
      entry_x <- normality_check(df[[var_x]], label_x, "满足正态性", "建议使用 Spearman 秩相关")
      entry_y <- normality_check(df[[var_y]], label_y, "满足正态性", "建议使用 Spearman 秩相关")
      checks <- append_check(checks, entry_x)
      checks <- append_check(checks, entry_y)
      # Either variable failing is enough, even if the other could not be tested
      if (failed(entry_x) || failed(entry_y)) {
        suggested_tool <- "ST_CORRELATION (Spearman)"
      }
    }
  } else if (tool_code == "ST_ANOVA_ONE") {
    # One-way ANOVA: per-group normality + variance homogeneity
    group_var <- params$group_var
    value_var <- params$value_var
    if (!is.null(group_var) && !is.null(value_var)) {
      for (g in group_levels(group_var)) {
        label <- glue("正态性检验 (组: {g})")
        entry <- normality_check(
          group_values(group_var, value_var, g),
          label, "满足正态性", "建议使用 Kruskal-Wallis"
        )
        checks <- append_check(checks, entry)
        if (failed(entry)) {
          suggested_tool <- "Kruskal-Wallis (内置于 ST_ANOVA_ONE)"
        }
      }
      checks <- append_check(
        checks,
        homogeneity_check(group_var, value_var, "建议使用 Welch ANOVA")
      )
    }
  } else if (tool_code == "ST_WILCOXON") {
    # Paired non-parametric test: only the number of complete pairs is checked
    before_var <- params$before_var
    after_var <- params$after_var
    if (!is.null(before_var) && !is.null(after_var)) {
      diff_vals <- df[[after_var]] - df[[before_var]]
      n_pairs <- sum(!is.na(diff_vals))
      enough <- n_pairs >= 5
      checks <- append_check(checks, list(
        check_name = "配对样本量检查",
        passed = enough,
        recommendation = if (enough) "样本量充足" else "配对样本量不足"
      ))
    }
  } else if (tool_code %in% c("ST_FISHER", "ST_CHI_SQUARE")) {
    # Contingency tables: share of cells with expected count below 5
    var1 <- params$var1
    var2 <- params$var2
    if (!is.null(var1) && !is.null(var2)) {
      # Only the expected counts are needed, so the approximation warning that
      # chisq.test() emits for sparse tables is deliberately silenced here.
      expected <- tryCatch(
        suppressWarnings(chisq.test(table(df[[var1]], df[[var2]])))$expected,
        error = function(e) NULL
      )
      if (!is.null(expected)) {
        low_pct <- sum(expected < 5) / length(expected)
        cells_ok <- low_pct <= 0.2
        checks <- append_check(checks, list(
          check_name = "期望频数检查",
          passed = cells_ok,
          recommendation = if (cells_ok) "期望频数满足卡方检验条件" else "建议使用 Fisher 精确检验"
        ))
        if (!cells_ok && tool_code == "ST_CHI_SQUARE") {
          suggested_tool <- "ST_FISHER"
        }
      }
    }
  } else if (tool_code == "ST_LINEAR_REG") {
    # Linear regression: normality of the outcome variable
    outcome_var <- params$outcome_var
    if (!is.null(outcome_var)) {
      label <- glue("结局变量正态性 ({outcome_var})")
      checks <- append_check(checks, normality_check(
        df[[outcome_var]], label, "满足正态性", "结局变量分布偏态,结果需谨慎解读"
      ))
    }
  }

  # Summary (vapply keeps the result logical even when no check was run)
  all_passed <- all(vapply(checks, function(chk) isTRUE(chk$passed), logical(1)))

  list(
    checks = checks,
    suggested_tool = suggested_tool,
    can_proceed = TRUE, # failed checks never block; the user/LLM decides
    all_checks_passed = all_passed
  )
}