Files
HaHafeng 6edfad032f feat(ssa): finalize strict stepwise agent execution flow
Align Agent mode to strict stepwise generation and execution, add deterministic and safety hardening, and sync deployment/module documentation for Phase 5A.5/5B/5C rollout.

- implement strict stepwise execution path and dependency short-circuiting
- persist step-level errors/results and stream step_* progress events
- add agent plan params patch route and schema/migration support
- improve R sanitizer/security checks and step result rendering in workspace
- update SSA module guide and deployment change checklist

Made-with: Cursor
2026-03-11 22:49:05 +08:00

498 lines
16 KiB
R
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# plumber.R
# SSA-Pro R Statistics Service 入口文件
#
# 安全与性能优化:
# - 生产环境预加载所有工具脚本
# - tool_code 白名单正则校验(防止路径遍历攻击)
library(plumber)
library(jsonlite)
# Environment configuration: development mode is on only when the DEV_MODE
# environment variable is exactly "true" (it defaults to "false").
DEV_MODE <- identical(Sys.getenv("DEV_MODE", "false"), "true")

# Null-coalescing operator, defined locally so the execute-code endpoint does
# not fail when `%||%` is otherwise undefined.
`%||%` <- function(x, y) {
  if (is.null(x)) {
    return(y)
  }
  x
}

# Load the shared helper scripts, in the same order as before.
invisible(lapply(
  c(
    "utils/error_codes.R",
    "utils/data_loader.R",
    "utils/guardrails.R",
    "utils/result_formatter.R",
    "utils/block_helpers.R"
  ),
  source
))

# Tool directory and the R scripts found in it.
tools_dir <- "tools"
tool_files <- list.files(tools_dir, pattern = "\\.R$", full.names = TRUE)
# ========== Production preload optimisation ==========
# Tool scripts are loaded once at service start-up, each into its own
# environment, so requests do not have to read and parse them from disk.

# Tool cache: maps a tool name to that tool's `run_analysis` function.
TOOL_CACHE <- new.env(parent = emptyenv())

#' Preload every script in `tool_files` into `TOOL_CACHE`
#'
#' Each script is sourced into a fresh environment whose parent is the global
#' environment. If the script itself defines a function named `run_analysis`,
#' that function is cached under the script's base name (without extension);
#' otherwise a warning is raised and the script is not cached.
#'
#' @return Called for its side effects (fills `TOOL_CACHE`, emits log messages).
preload_tools <- function() {
  message("[Init] 预加载工具脚本...")
  for (f in tool_files) {
    tool_name <- tools::file_path_sans_ext(basename(f))
    # Load the tool into an isolated environment.
    tool_env <- new.env(parent = globalenv())
    source(f, local = tool_env)
    # inherits = FALSE: only accept a run_analysis defined by this script.
    # Without it, a run_analysis living in the global environment would make
    # the check pass and a NULL entry would be written to the cache.
    has_entry_point <- exists(
      "run_analysis",
      envir = tool_env,
      mode = "function",
      inherits = FALSE
    )
    if (has_entry_point) {
      TOOL_CACHE[[tool_name]] <- tool_env$run_analysis
      message(paste("[Init] 已加载:", tool_name))
    } else {
      warning(paste("[Init] 工具缺少 run_analysis 函数:", tool_name))
    }
  }
  message(paste("[Init] 预加载完成,共", length(ls(TOOL_CACHE)), "个工具"))
}
# Start-up loading strategy:
# - development: skip the cache (tools are re-sourced per request), but still
#   source every tool once now;
# - production: preload all tools into the cache.
if (DEV_MODE) {
  message("[Init] DEV_MODE 启用,跳过预加载(支持热重载)")
  for (f in tool_files) {
    source(f)
  }
} else {
  preload_tools()
}
# ========== Security validation helpers ==========
#' Validate the format of a tool code (guards against path traversal)
#'
#' Only an upper-case letter followed by upper-case letters, digits or
#' underscores is accepted.
#' Valid examples: ST_T_TEST_IND, ST_ANOVA, T_TEST_IND
#' Invalid examples: ../etc/passwd, ST_TEST;rm -rf
#'
#' @param tool_code Tool code to check.
#' @return A single TRUE if `tool_code` is one non-missing string in the
#'   allowed format, otherwise a single FALSE.
validate_tool_code <- function(tool_code) {
  # Anything that is not exactly one non-NA string is rejected outright, so
  # callers can always use the result directly inside `if ()`. A bare grepl()
  # would return logical(0) for NULL and a vector for vector input.
  if (!is.character(tool_code) || length(tool_code) != 1L || is.na(tool_code)) {
    return(FALSE)
  }
  pattern <- "^[A-Z][A-Z0-9_]*$"
  grepl(pattern, tool_code)
}
#' Convert a tool code into a tool name
#'
#' Strips a leading "ST_" prefix and lower-cases the remainder.
#'
#' @param tool_code For example "ST_T_TEST_IND".
#' @return For example "t_test_ind".
normalize_tool_name <- function(tool_code) {
  without_prefix <- gsub("^ST_", "", tool_code)
  tolower(without_prefix)
}
# ========== API 定义 ==========
#* @apiTitle SSA-Pro R Statistics Service
#* @apiDescription 严谨型统计分析 R 引擎
#* 健康检查
#* @get /health
function() {
  # Health check: reports a fixed "ok" status, the current time, the service
  # version, the DEV_MODE flag and how many tools are available. In DEV_MODE
  # the count is the number of tool files found; otherwise it is the number
  # of entries in the preload cache.
  loaded_count <- if (DEV_MODE) {
    length(tool_files)
  } else {
    length(ls(TOOL_CACHE))
  }
  list(
    status = "ok",
    timestamp = Sys.time(),
    version = "1.0.1",
    dev_mode = DEV_MODE,
    tools_loaded = loaded_count
  )
}
#* 列出已加载的工具
#* @get /api/v1/tools
function() {
  # List the available tools. In DEV_MODE the names come from the tool file
  # names (".R" suffix removed); otherwise from the keys of the preload cache.
  tool_names <- if (DEV_MODE) {
    gsub("\\.R$", "", basename(tool_files))
  } else {
    ls(TOOL_CACHE)
  }
  list(
    status = "ok",
    tools = tool_names,
    count = length(tool_names)
  )
}
#* 诊断:返回 R 运行时包清单(只读)
#* @get /api/v1/debug/packages
#* @serializer unboxedJSON
function() {
  # Read-only diagnostics: compares a fixed list of required packages with
  # the packages installed in the current library paths.
  required_packages <- c(
    "plumber", "jsonlite", "ggplot2", "glue", "dplyr", "tidyr",
    "base64enc", "yaml", "car", "httr", "scales", "gridExtra",
    "gtsummary", "gt", "broom", "meta"
  )
  installed <- rownames(installed.packages())
  missing <- setdiff(required_packages, installed)

  # "complete" only when every required package is installed.
  required_status <- "incomplete"
  if (length(missing) == 0) {
    required_status <- "complete"
  }

  # Only the first 120 installed package names (sorted) are returned.
  sample_installed <- head(sort(installed), 120)

  list(
    status = "ok",
    r_version = R.version.string,
    dev_mode = DEV_MODE,
    lib_paths = .libPaths(),
    required_count = length(required_packages),
    installed_count = length(installed),
    missing_required = missing,
    required_status = required_status,
    sample_installed = sample_installed
  )
}
#* JIT Guardrails Check
#* @post /api/v1/guardrails/jit
#* @serializer unboxedJSON
function(req) {
  # Runs the just-in-time guardrail checks for a tool against the request's
  # data. Expects a JSON body with `tool_code` (required) and `params`; the
  # whole parsed body is handed to load_input_data() to obtain the data.
  # Returns a list with status "success" plus the guardrail outcome, or a
  # list with status "error" (E400: tool_code missing, E100: data could not
  # be loaded); any other failure is mapped through map_r_error().
  tryCatch({
    input <- jsonlite::fromJSON(req$postBody, simplifyVector = FALSE)

    # Required parameters
    tool_code <- input$tool_code
    params <- input$params
    if (is.null(tool_code)) {
      return(list(
        status = "error",
        error_code = "E400",
        message = "Missing tool_code parameter"
      ))
    }

    # Load the data. A load failure is reported to the client as E100; the
    # underlying reason is logged so the failure can be diagnosed instead of
    # being silently discarded.
    df <- tryCatch(
      load_input_data(input),
      error = function(e) {
        message("[Guardrails] Failed to load data: ", conditionMessage(e))
        NULL
      }
    )
    if (is.null(df)) {
      return(list(
        status = "error",
        error_code = "E100",
        message = "Failed to load data for guardrail checks"
      ))
    }

    # Run the JIT guardrail checks
    result <- run_jit_guardrails(df, tool_code, params)
    list(
      status = "success",
      checks = result$checks,
      suggested_tool = result$suggested_tool,
      can_proceed = result$can_proceed,
      all_checks_passed = result$all_checks_passed
    )
  }, error = function(e) {
    # Log unexpected failures, as the other endpoints in this file do.
    message("[Guardrails] ERROR: ", conditionMessage(e))
    map_r_error(e$message)
  })
}
#* Agent 通道:执行任意 R 代码(沙箱模式)
#* @post /api/v1/execute-code
#* @serializer unboxedJSON
function(req) {
  # Agent channel: parses, statically screens and then evaluates
  # caller-supplied R code in a child environment of globalenv(), capturing
  # printed output, warnings and messages.
  #
  # Request body (JSON): `code` (required string), `session_id` (optional),
  # `timeout` in seconds (optional; anything missing, non-numeric,
  # non-positive or above 120 is treated as 120).
  #
  # Returns a list with `status` "success" or "error". Error responses carry
  # `error_code` E400 (bad input), E_SYNTAX (code does not parse),
  # E_SECURITY (forbidden call detected) or the code produced by
  # format_agent_error() for runtime failures.
  #
  # NOTE(review): this endpoint evaluates untrusted code. The static pattern
  # and the sandbox overrides below only cover the functions they list (for
  # example system2(), unlink() and file.create() are not listed), and the
  # sandbox inherits from globalenv(). Treat this as a best-effort filter,
  # not as an isolation boundary.
  tryCatch({
    input <- jsonlite::fromJSON(req$postBody, simplifyVector = FALSE)
    code <- input$code
    session_id <- input$session_id

    # `code` must be exactly one non-blank string.
    code_is_usable <- is.character(code) && length(code) == 1L &&
      !is.na(code) && nchar(trimws(code)) > 0
    if (!code_is_usable) {
      return(list(
        status = "error",
        error_code = "E400",
        message = "Missing 'code' parameter",
        user_hint = "R 代码不能为空"
      ))
    }

    # Safety limit: at most 120 seconds. setTimeLimit() treats non-positive
    # values as "no limit", so those (and unparsable values) must not be
    # allowed through either.
    timeout_sec <- suppressWarnings(as.numeric(input$timeout %||% 120))
    timeout_is_valid <- length(timeout_sec) == 1L && !is.na(timeout_sec) &&
      timeout_sec > 0 && timeout_sec <= 120
    if (!timeout_is_valid) {
      timeout_sec <- 120
    }

    # glue() collapses the whole string to length zero when a value is NULL,
    # so log an empty label when no session id was supplied.
    session_label <- session_id %||% ""
    message(glue::glue("[ExecuteCode] session={session_label}, code_length={nchar(code)}, timeout={timeout_sec}s"))

    # ── Pre-check layer 1: syntax ──
    # The parse result is returned from tryCatch() and assigned locally.
    # A `<<-` here would skip this function's frame and leave the local
    # variable NULL, so the submitted code would never be evaluated.
    parse_outcome <- tryCatch(
      list(exprs = parse(text = code), error = NULL),
      error = function(e) list(exprs = NULL, error = conditionMessage(e))
    )
    parsed_code <- parse_outcome$exprs
    ast_check <- parse_outcome$error

    if (!is.null(ast_check)) {
      # Pull the "line:column" position out of the parser message and build
      # a small numbered excerpt (two lines either side) around that line.
      line_match <- regmatches(ast_check, regexpr("\\d+:\\d+", ast_check))
      error_line <- if (length(line_match) > 0) {
        as.integer(sub(":.*", "", line_match))
      } else {
        NULL
      }
      code_lines <- strsplit(code, "\n")[[1]]
      line_in_range <- !is.null(error_line) && error_line > 0 &&
        error_line <= length(code_lines)
      context_lines <- if (line_in_range) {
        start_l <- max(1, error_line - 2)
        end_l <- min(length(code_lines), error_line + 2)
        paste(
          sprintf("%3d| %s", start_l:end_l, code_lines[start_l:end_l]),
          collapse = "\n"
        )
      } else {
        NULL
      }
      return(list(
        status = "error",
        error_code = "E_SYNTAX",
        error_type = "syntax",
        message = paste0("R 语法错误(代码无法解析): ", ast_check),
        user_hint = "代码存在语法错误,请检查括号/引号是否匹配、运算符是否正确",
        error_line = error_line,
        error_context = context_lines,
        console_output = list(),
        duration_ms = 0
      ))
    }

    # ── Pre-check layer 2: static security scan (MVP) ──
    # To reduce false positives, whole-line comments are stripped first.
    code_for_scan <- gsub("(?m)^\\s*#.*$", "", code, perl = TRUE)
    forbidden_pattern <- "(^|[^[:alnum:]_\\.])((base::)?system|(base::)?eval|(base::)?parse|(base::)?source|file\\.remove|setwd|download\\.file|readLines|writeLines)\\s*\\("
    security_hit <- regexpr(forbidden_pattern, code_for_scan, perl = TRUE, ignore.case = TRUE)
    if (security_hit[1] != -1) {
      hit_text <- regmatches(code_for_scan, security_hit)[1]
      return(list(
        status = "error",
        error_code = "E_SECURITY",
        error_type = "security",
        message = paste0("Security Violation: Detected forbidden function call: ", hit_text),
        user_hint = "代码包含高风险函数调用(如 system/eval/source/file.remove/setwd已被系统拦截",
        console_output = list(),
        duration_ms = 0
      ))
    }

    sandbox_env <- new.env(parent = globalenv())
    # Runtime protection: even if the static scan misses a call, these
    # high-risk functions are shadowed inside the sandbox environment.
    sandbox_env$system <- function(...) stop("Security Violation: function 'system' is forbidden.")
    sandbox_env$eval <- function(...) stop("Security Violation: function 'eval' is forbidden.")
    sandbox_env$source <- function(...) stop("Security Violation: function 'source' is forbidden.")
    sandbox_env$setwd <- function(...) stop("Security Violation: function 'setwd' is forbidden.")
    sandbox_env$file.remove <- function(...) stop("Security Violation: function 'file.remove' is forbidden.")
    if (!is.null(session_id) && nchar(session_id) > 0) {
      sandbox_env$SESSION_ID <- session_id
    }

    start_time <- proc.time()
    collected_warnings <- list()
    collected_messages <- character(0)

    # Console lines are assembled the same way on success and on failure:
    # printed output first, then warnings, then messages.
    collect_console <- function(capture) {
      c(
        capture$output,
        if (length(capture$warnings) > 0) paste0("Warning: ", capture$warnings),
        if (length(capture$messages) > 0) capture$messages
      )
    }

    output_capture <- tryCatch(
      withTimeout(
        {
          captured_output <- utils::capture.output({
            result <- withCallingHandlers(
              eval(parsed_code, envir = sandbox_env),
              # The handlers are closures of this function, so `<<-` here
              # does reach the collectors defined above.
              warning = function(w) {
                collected_warnings[[length(collected_warnings) + 1]] <<- w$message
                invokeRestart("muffleWarning")
              },
              message = function(m) {
                collected_messages <<- c(collected_messages, conditionMessage(m))
                invokeRestart("muffleMessage")
              }
            )
          })
          list(
            result = result,
            output = captured_output,
            warnings = collected_warnings,
            messages = collected_messages,
            error = NULL
          )
        },
        timeout = timeout_sec,
        onTimeout = "error"
      ),
      error = function(e) {
        error_info <- format_agent_error(e, code, collected_warnings, collected_messages)
        list(
          result = NULL,
          output = NULL,
          warnings = collected_warnings,
          messages = collected_messages,
          error = error_info$message,
          error_detail = error_info
        )
      }
    )

    # `[[` drops the "elapsed" name so the value serialises as a plain number.
    elapsed_ms <- round((proc.time() - start_time)[["elapsed"]] * 1000)
    console_lines <- collect_console(output_capture)

    if (!is.null(output_capture$error)) {
      detail <- output_capture$error_detail
      message(glue::glue("[ExecuteCode] ERROR after {elapsed_ms}ms: {output_capture$error}"))
      return(list(
        status = "error",
        error_code = if (!is.null(detail)) detail$error_code else "E_EXEC",
        error_type = if (!is.null(detail)) detail$error_type else "runtime",
        message = output_capture$error,
        user_hint = if (!is.null(detail)) detail$user_hint else output_capture$error,
        error_line = if (!is.null(detail)) detail$error_line else NULL,
        error_context = if (!is.null(detail)) detail$error_context else NULL,
        console_output = console_lines,
        duration_ms = elapsed_ms
      ))
    }

    message(glue::glue("[ExecuteCode] SUCCESS in {elapsed_ms}ms"))
    final_result <- output_capture$result

    # A result that already carries `report_blocks` is passed through as is;
    # anything else is wrapped as `data` with an empty `report_blocks`.
    if (is.list(final_result) && !is.null(final_result$report_blocks)) {
      return(list(
        status = "success",
        result = final_result,
        console_output = console_lines,
        duration_ms = elapsed_ms
      ))
    }
    list(
      status = "success",
      result = list(
        data = final_result,
        report_blocks = list()
      ),
      console_output = console_lines,
      duration_ms = elapsed_ms
    )
  }, error = function(e) {
    message(glue::glue("[ExecuteCode] FATAL ERROR: {e$message}"))
    return(map_r_error(e$message))
  })
}
#' Evaluate an expression under a time limit
#'
#' Sets a CPU and elapsed time limit, evaluates `expr`, and clears the limit
#' again when the function exits (normally or through an error).
#'
#' @param expr Expression to evaluate. It is passed lazily and forced inside
#'   this function, so the time limit is active while it runs.
#' @param timeout Limit in seconds, applied to both CPU and elapsed time.
#' @param onTimeout Accepted for interface compatibility; it is not consulted
#'   by this function.
#' @return The value of `expr`.
withTimeout <- function(expr, timeout = 120, onTimeout = "error") {
  setTimeLimit(cpu = timeout, elapsed = timeout, transient = TRUE)
  on.exit(
    setTimeLimit(cpu = Inf, elapsed = Inf, transient = FALSE),
    add = TRUE
  )
  # Force the promise exactly once. Wrapping it in eval() would evaluate the
  # *value* of `expr` a second time, which executes any call, symbol or
  # expression object that `expr` happens to return.
  force(expr)
}
#* 执行统计工具
#* @post /api/v1/skills/<tool_code>
#* @param tool_code:str 工具代码(如 ST_T_TEST_IND
#* @serializer unboxedJSON
function(req, tool_code) {
  # Runs one statistical tool. `tool_code` comes from the URL path and must
  # pass validate_tool_code(); the JSON request body is parsed and handed to
  # the tool's run_analysis() unchanged. Returns whatever run_analysis()
  # returns, or a list with status "error" (E400: bad tool code, E100:
  # unknown tool / missing entry point); any other failure is mapped through
  # map_r_error().
  tryCatch({
    # ===== Security check: tool_code whitelist =====
    if (!validate_tool_code(tool_code)) {
      return(list(
        status = "error",
        error_code = "E400",
        message = "Invalid tool code format",
        user_hint = "工具代码格式错误,只允许大写字母、数字和下划线"
      ))
    }

    # Parse the request body
    input <- jsonlite::fromJSON(req$postBody, simplifyVector = FALSE)

    # Log the incoming parameters (for debugging)
    param_names <- if (!is.null(input$params)) paste(names(input$params), collapse=", ") else "NULL"
    message(glue::glue("[Skill:{tool_code}] params keys: [{param_names}]"))
    if (!is.null(input$params$variables)) {
      message(glue::glue("[Skill:{tool_code}] variables ({length(input$params$variables)}): [{paste(input$params$variables, collapse=', ')}]"))
    }
    if (!is.null(input$params$group_var)) {
      message(glue::glue("[Skill:{tool_code}] group_var: {input$params$group_var}"))
    }

    # Debug mode: report the tool's temporary files for troubleshooting
    debug_mode <- isTRUE(input$debug)

    # Every tool exposes the same entry point name
    func_name <- "run_analysis"

    # Normalised tool name and the script it maps to
    tool_name <- normalize_tool_name(tool_code)
    tool_file <- file.path("tools", paste0(tool_name, ".R"))

    # ===== Loading strategy depends on the environment =====
    if (DEV_MODE) {
      # Development: re-source the script on every request (hot reload)
      if (!file.exists(tool_file)) {
        return(list(
          status = "error",
          error_code = "E100",
          message = paste("Unknown tool:", tool_code),
          user_hint = "请检查工具代码是否正确"
        ))
      }
      # Source into an isolated environment, as the production preload does.
      # Sourcing into the global environment would let a run_analysis left
      # behind by a previously loaded tool satisfy the check below, so the
      # wrong tool could run.
      tool_env <- new.env(parent = globalenv())
      source(tool_file, local = tool_env)
      if (!exists(func_name, envir = tool_env, mode = "function", inherits = FALSE)) {
        return(list(
          status = "error",
          error_code = "E100",
          message = paste("Tool", tool_code, "does not implement run_analysis()"),
          user_hint = "工具脚本格式错误,请联系管理员"
        ))
      }
      # Run the analysis
      tool_func <- get(func_name, envir = tool_env, mode = "function", inherits = FALSE)
      result <- tool_func(input)
    } else {
      # Production: look the tool up in the preload cache
      if (!exists(tool_name, envir = TOOL_CACHE)) {
        return(list(
          status = "error",
          error_code = "E100",
          message = paste("Unknown tool:", tool_code),
          user_hint = "请检查工具代码是否正确,或联系管理员确认工具已部署"
        ))
      }
      # Fetch the cached function and run it
      cached_func <- TOOL_CACHE[[tool_name]]
      result <- cached_func(input)
    }

    # Debug mode: attach the temporary file paths. The is.list() guard keeps
    # a successful non-list result from failing on `$`.
    if (debug_mode && is.list(result) && !is.null(result$tmp_files)) {
      result$debug_files <- result$tmp_files
      message("[DEBUG] 临时文件已保留: ", paste(result$tmp_files, collapse = ", "))
    }
    result
  }, error = function(e) {
    message(glue::glue("[Skill:{tool_code}] ERROR: {e$message}"))
    return(map_r_error(e$message))
  })
}