feat(ssa): Complete Phase I-IV intelligent dialogue and tool system development
Phase I - Session Blackboard + READ Layer: - SessionBlackboardService with Postgres-Only cache - DataProfileService for data overview generation - PicoInferenceService for LLM-driven PICO extraction - Frontend DataContextCard and VariableDictionaryPanel - E2E tests: 31/31 passed Phase II - Conversation Layer LLM + Intent Router: - ConversationService with SSE streaming - IntentRouterService (rule-first + LLM fallback, 6 intents) - SystemPromptService with 6-segment dynamic assembly - TokenTruncationService for context management - ChatHandlerService as unified chat entry - Frontend SSAChatPane and useSSAChat hook - E2E tests: 38/38 passed Phase III - Method Consultation + AskUser Standardization: - ToolRegistryService with Repository Pattern - MethodConsultService with DecisionTable + LLM enhancement - AskUserService with global interrupt handling - Frontend AskUserCard component - E2E tests: 13/13 passed Phase IV - Dialogue-Driven Analysis + QPER Integration: - ToolOrchestratorService (plan/execute/report) - analysis_plan SSE event for WorkflowPlan transmission - Dual-channel confirmation (ask_user card + workspace button) - PICO as optional hint for LLM parsing - E2E tests: 25/25 passed R Statistics Service: - 5 new R tools: anova_one, baseline_table, fisher, linear_reg, wilcoxon - Enhanced guardrails and block helpers - Comprehensive test suite (run_all_tools_test.js) Documentation: - Updated system status document (v5.9) - Updated SSA module status and development plan (v1.8) Total E2E: 107/107 passed (Phase I: 31, Phase II: 38, Phase III: 13, Phase IV: 25) Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -31,7 +31,12 @@ RUN R -e "install.packages(c( \
|
||||
'base64enc', \
|
||||
'yaml', \
|
||||
'car', \
|
||||
'httr' \
|
||||
'httr', \
|
||||
'scales', \
|
||||
'gridExtra', \
|
||||
'gtsummary', \
|
||||
'gt', \
|
||||
'broom' \
|
||||
), repos='https://cloud.r-project.org/', Ncpus=2)"
|
||||
|
||||
# ===== 安全加固:创建非特权用户 =====
|
||||
|
||||
299
r-statistics-service/tests/run_all_tools_test.js
Normal file
299
r-statistics-service/tests/run_all_tools_test.js
Normal file
@@ -0,0 +1,299 @@
|
||||
/**
|
||||
* SSA R 统计引擎 — 全工具端到端测试
|
||||
*
|
||||
* 覆盖范围:12 个统计工具 + JIT 护栏 + report_blocks 协议验证
|
||||
*
|
||||
* 运行方式:
|
||||
* node r-statistics-service/tests/run_all_tools_test.js
|
||||
*
|
||||
* 前置条件:R 服务容器已启动(docker-compose up -d)
|
||||
*/
|
||||
|
||||
const http = require('http');
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
|
||||
const R_URL = process.env.R_SERVICE_URL || 'http://localhost:8082';
|
||||
const TIMEOUT = 60000;
|
||||
|
||||
// ==================== HTTP ====================
|
||||
|
||||
/**
 * POST a JSON payload to the R service.
 * Resolves with { status, body }; the body is JSON-parsed when possible,
 * otherwise returned as the raw response text. Rejects on socket error
 * or when TIMEOUT elapses.
 */
function post(endpoint, body) {
  const target = new URL(endpoint, R_URL);
  const payload = JSON.stringify(body);
  const options = {
    hostname: target.hostname,
    port: target.port,
    path: target.pathname,
    method: 'POST',
    headers: {
      'Content-Type': 'application/json',
      'Content-Length': Buffer.byteLength(payload),
    },
    timeout: TIMEOUT,
  };
  return new Promise((resolve, reject) => {
    const req = http.request(options, (res) => {
      const chunks = [];
      res.on('data', (chunk) => chunks.push(chunk));
      res.on('end', () => {
        const raw = chunks.join('');
        let parsed;
        try {
          parsed = JSON.parse(raw);
        } catch {
          parsed = raw; // non-JSON response: hand back the raw text
        }
        resolve({ status: res.statusCode, body: parsed });
      });
    });
    req.on('error', reject);
    req.on('timeout', () => {
      req.destroy();
      reject(new Error('timeout'));
    });
    req.write(payload);
    req.end();
  });
}
|
||||
|
||||
/**
 * GET an endpoint on the R service.
 * Resolves with { status, body } — body JSON-parsed when possible,
 * otherwise the raw response text. Rejects on socket error.
 */
function get(endpoint) {
  const target = new URL(endpoint, R_URL);
  return new Promise((resolve, reject) => {
    const req = http.get(target, { timeout: TIMEOUT }, (res) => {
      const chunks = [];
      res.on('data', (chunk) => chunks.push(chunk));
      res.on('end', () => {
        const raw = chunks.join('');
        let parsed;
        try {
          parsed = JSON.parse(raw);
        } catch {
          parsed = raw;
        }
        resolve({ status: res.statusCode, body: parsed });
      });
    });
    req.on('error', reject);
  });
}
|
||||
|
||||
// ==================== 测试数据 ====================
|
||||
|
||||
/**
 * Load the shared CSV test fixture and convert it into an array of row
 * objects keyed by the header line.
 *
 * Coercion rules (per cell): '' / missing -> null, numeric-looking -> Number,
 * anything else -> the raw string.
 *
 * Fix: split on /\r?\n/ instead of '\n' — a CRLF-encoded fixture would
 * otherwise leave a trailing '\r' on the header names and on the last
 * column of every row, silently breaking key lookups.
 */
function loadCSV() {
  const csvPath = path.join(__dirname, '..', '..', 'docs', '03-业务模块', 'SSA-智能统计分析', '05-测试文档', 'test.csv');
  const lines = fs.readFileSync(csvPath, 'utf-8').trim().split(/\r?\n/);
  const headers = lines[0].split(',');
  return lines.slice(1).map(line => {
    const vals = line.split(',');
    const row = {};
    headers.forEach((h, i) => {
      const v = vals[i];
      if (v === '' || v === undefined) row[h] = null;
      // Number('') is 0, but the empty case is handled above, so this
      // only converts genuinely numeric-looking cells.
      else if (!Number.isNaN(Number(v))) row[h] = Number(v);
      else row[h] = v;
    });
    return row;
  });
}
|
||||
|
||||
/** Read and parse a JSON fixture that lives next to this test script. */
function loadJSON(name) {
  const fixturePath = path.join(__dirname, name);
  const raw = fs.readFileSync(fixturePath, 'utf-8');
  return JSON.parse(raw);
}
|
||||
|
||||
// ==================== 校验 ====================
|
||||
|
||||
/**
 * Validate a tool's report_blocks payload against the block protocol.
 * Returns a list of human-readable issue strings; an empty list means
 * the payload conforms. (`toolName` is accepted for call-site symmetry
 * but is not used in any message.)
 */
function validateBlocks(blocks, toolName) {
  if (!Array.isArray(blocks)) {
    return ['report_blocks 不是数组'];
  }
  if (blocks.length === 0) {
    return ['report_blocks 为空'];
  }
  const validTypes = ['markdown', 'table', 'image', 'key_value'];
  const issues = [];
  for (const [i, block] of blocks.entries()) {
    if (!validTypes.includes(block.type)) {
      issues.push(`blocks[${i}].type 非法: ${block.type}`);
    }
    if (block.type === 'table') {
      if (!Array.isArray(block.headers)) {
        issues.push(`blocks[${i}] table 缺少 headers`);
      }
      if (!Array.isArray(block.rows)) {
        issues.push(`blocks[${i}] table 缺少 rows`);
      }
    }
    if (block.type === 'markdown' && !block.content) {
      issues.push(`blocks[${i}] markdown 缺少 content`);
    }
    if (block.type === 'image' && !block.data) {
      issues.push(`blocks[${i}] image 缺少 data`);
    }
    if (block.type === 'key_value' && !Array.isArray(block.items)) {
      issues.push(`blocks[${i}] key_value 缺少 items`);
    }
  }
  return issues;
}
|
||||
|
||||
// ==================== 主测试 ====================
|
||||
|
||||
/**
 * End-to-end driver for the R statistics service.
 *
 * Flow: health check → tool inventory → 12 statistical tools
 * (7 Phase-2A + 5 Phase-Deploy) → 4 JIT guardrail probes → summary.
 * Exits with code 1 when the service is unreachable or any probe fails.
 *
 * NOTE(review): integration script, not a unit test — it needs a live
 * R service (R_SERVICE_URL, default localhost:8082) and the CSV/JSON
 * fixtures on disk.
 */
async function main() {
  console.log('\n╔══════════════════════════════════════════════════════════╗');
  console.log('║ SSA R 统计引擎 — 全工具端到端测试 (12 tools + JIT) ║');
  console.log('║ ' + new Date().toISOString().slice(0, 19) + ' ║');
  console.log('╚══════════════════════════════════════════════════════════╝\n');

  // 0. Health check — abort the whole suite if the service is down.
  let toolsLoaded = 0;
  try {
    const h = await get('/health');
    toolsLoaded = h.body.tools_loaded || 0;
    console.log(`✅ 健康检查通过 version=${h.body.version} tools_loaded=${toolsLoaded} dev_mode=${h.body.dev_mode}\n`);
  } catch (e) {
    console.log(`❌ R 服务不可用: ${e.message}\n`);
    process.exit(1);
  }

  // 0.1 Tool inventory — informational only; failure is ignored.
  try {
    const tl = await get('/api/v1/tools');
    console.log(`📋 已注册工具 (${tl.body.count}): ${tl.body.tools.join(', ')}\n`);
  } catch { /* skip */ }

  // Shared inline data source built from the CSV fixture.
  const csvData = loadCSV();
  const ds = { type: 'inline', data: csvData };

  // Outcome of every probe: { name, status, ms, ... }.
  const results = [];

  // Invoke one tool endpoint, validate its report_blocks payload, and
  // record pass / blocked / fail. `checks` may pull extra fields out of
  // the response for the log line.
  async function run(name, toolCode, body, checks) {
    const t0 = Date.now();
    try {
      const res = await post(`/api/v1/skills/${toolCode}`, body);
      const ms = Date.now() - t0;
      const d = res.body;
      if (d.status === 'success') {
        const blockIssues = validateBlocks(d.report_blocks, toolCode);
        const extra = checks ? checks(d) : {};
        const hasPlots = Array.isArray(d.plots) && d.plots.length > 0;
        const hasCode = !!d.reproducible_code;
        const blocksOk = blockIssues.length === 0;
        const icon = blocksOk ? '✅' : '⚠️';
        console.log(`${icon} ${name} (${ms}ms) blocks=${(d.report_blocks||[]).length} plots=${hasPlots?'✓':'✗'} code=${hasCode?'✓':'✗'} ${JSON.stringify(extra)}`);
        if (!blocksOk) blockIssues.forEach(iss => console.log(` ⚠ ${iss}`));
        results.push({ name, status: 'pass', ms, blocksOk, extra });
      } else if (d.status === 'blocked') {
        // "blocked" = a guardrail refused to run the analysis; counted
        // separately from failures in the summary.
        console.log(`🔒 ${name} (${ms}ms) status=blocked message=${d.message}`);
        results.push({ name, status: 'blocked', ms });
      } else {
        console.log(`❌ ${name} (${ms}ms) error=${d.error_code||''} ${d.message||''}`);
        results.push({ name, status: 'fail', ms, error: d.message });
      }
    } catch (e) {
      console.log(`❌ ${name} EXCEPTION: ${e.message}`);
      results.push({ name, status: 'error', error: e.message });
    }
  }

  // ========== Phase 2A tools (the original 7) ==========
  console.log('─'.repeat(60));
  console.log(' Phase 2A 工具(原有 7 个)');
  console.log('─'.repeat(60));

  await run('ST_DESCRIPTIVE (描述性统计)', 'ST_DESCRIPTIVE', {
    data_source: ds,
    params: { variables: ['age', 'bmi', 'time'], group_var: 'sex' }
  }, d => ({ groups: Object.keys(d.results?.summary || {}).length }));

  await run('ST_T_TEST_IND (独立样本T检验)', 'ST_T_TEST_IND', {
    data_source: ds,
    params: { group_var: 'sex', value_var: 'age' },
    guardrails: { check_normality: true }
  }, d => ({ t: d.results?.statistic, p: d.results?.p_value_fmt }));

  await run('ST_MANN_WHITNEY (Mann-Whitney U)', 'ST_MANN_WHITNEY', {
    data_source: ds,
    params: { group_var: 'sex', value_var: 'bmi' }
  }, d => ({ U: d.results?.statistic_U, p: d.results?.p_value_fmt }));

  await run('ST_CHI_SQUARE (卡方检验)', 'ST_CHI_SQUARE', {
    data_source: ds,
    params: { var1: 'sex', var2: 'smoke' }
  }, d => ({ chi2: d.results?.statistic, p: d.results?.p_value_fmt }));

  await run('ST_CORRELATION (相关分析)', 'ST_CORRELATION', {
    data_source: ds,
    params: { var_x: 'age', var_y: 'bmi', method: 'auto' }
  }, d => ({ r: d.results?.statistic, p: d.results?.p_value_fmt }));

  await run('ST_LOGISTIC_BINARY (Logistic回归)', 'ST_LOGISTIC_BINARY', {
    data_source: ds,
    params: { outcome_var: 'Yqol', predictors: ['age', 'bmi', 'sex', 'smoke'] }
  }, d => ({ aic: d.results?.model_fit?.aic, sig: d.results?.coefficients?.filter(c => c.significant)?.length }));

  await run('ST_T_TEST_PAIRED (配对T检验)', 'ST_T_TEST_PAIRED', {
    data_source: ds,
    params: { before_var: 'mouth_open', after_var: 'bucal_relax' },
    guardrails: { check_normality: true }
  }, d => ({ p: d.results?.p_value_fmt }));

  // ========== Phase Deploy tools (5 new) — each driven by its own JSON fixture ==========
  console.log('\n' + '─'.repeat(60));
  console.log(' Phase Deploy 新工具(5 个)');
  console.log('─'.repeat(60));

  const fisherData = loadJSON('test_fisher.json');
  await run('ST_FISHER (Fisher精确检验)', 'ST_FISHER',
    fisherData,
    d => ({ p: d.results?.p_value_fmt, or: d.results?.odds_ratio }));

  const anovaData = loadJSON('test_anova_one.json');
  await run('ST_ANOVA_ONE (单因素方差分析)', 'ST_ANOVA_ONE',
    anovaData,
    d => ({ stat: d.results?.statistic, p: d.results?.p_value_fmt, method: d.results?.method }));

  const wilcoxData = loadJSON('test_wilcoxon.json');
  await run('ST_WILCOXON (Wilcoxon符号秩)', 'ST_WILCOXON',
    wilcoxData,
    d => ({ V: d.results?.statistic, p: d.results?.p_value_fmt, r: d.results?.effect_size?.r }));

  const linearData = loadJSON('test_linear_reg.json');
  await run('ST_LINEAR_REG (线性回归)', 'ST_LINEAR_REG',
    linearData,
    d => ({ r2: d.results?.model_fit?.r_squared, f: d.results?.model_fit?.f_statistic }));

  const baselineData = loadJSON('test_baseline_table.json');
  await run('ST_BASELINE_TABLE (基线特征表)', 'ST_BASELINE_TABLE',
    baselineData,
    d => ({
      sig_vars: d.results?.significant_vars?.length || 0,
      methods: d.results?.method_info?.length || 0,
      is_baseline: d.report_blocks?.[0]?.metadata?.is_baseline_table
    }));

  // ========== JIT guardrail endpoint — pre-flight checks for 4 tools ==========
  console.log('\n' + '─'.repeat(60));
  console.log(' JIT 护栏检查');
  console.log('─'.repeat(60));

  const jitTests = [
    { name: 'JIT for ST_T_TEST_IND', code: 'ST_T_TEST_IND', body: { data_source: ds, tool_code: 'ST_T_TEST_IND', params: { group_var: 'sex', value_var: 'age' } } },
    { name: 'JIT for ST_ANOVA_ONE', code: 'ST_ANOVA_ONE', body: { data_source: anovaData.data_source, tool_code: 'ST_ANOVA_ONE', params: anovaData.params } },
    { name: 'JIT for ST_FISHER', code: 'ST_FISHER', body: { data_source: fisherData.data_source, tool_code: 'ST_FISHER', params: fisherData.params } },
    { name: 'JIT for ST_LINEAR_REG', code: 'ST_LINEAR_REG', body: { data_source: linearData.data_source, tool_code: 'ST_LINEAR_REG', params: linearData.params } },
  ];

  for (const jt of jitTests) {
    const t0 = Date.now();
    try {
      const res = await post('/api/v1/guardrails/jit', jt.body);
      const ms = Date.now() - t0;
      const d = res.body;
      if (d.status === 'success') {
        console.log(`✅ ${jt.name} (${ms}ms) checks=${d.checks?.length} all_passed=${d.all_checks_passed} suggested=${d.suggested_tool || 'none'}`);
        results.push({ name: jt.name, status: 'pass', ms });
      } else {
        console.log(`❌ ${jt.name} (${ms}ms) ${d.message || ''}`);
        results.push({ name: jt.name, status: 'fail', ms });
      }
    } catch (e) {
      console.log(`❌ ${jt.name} EXCEPTION: ${e.message}`);
      results.push({ name: jt.name, status: 'error' });
    }
  }

  // ========== Summary ==========
  console.log('\n' + '═'.repeat(60));
  console.log(' 测试汇总');
  console.log('═'.repeat(60));

  const pass = results.filter(r => r.status === 'pass').length;
  const blocked = results.filter(r => r.status === 'blocked').length;
  const fail = results.filter(r => r.status === 'fail' || r.status === 'error').length;
  const total = results.length;

  console.log(` 通过: ${pass}/${total} 阻塞: ${blocked} 失败: ${fail}`);
  if (fail > 0) {
    console.log('\n 失败项:');
    results.filter(r => r.status === 'fail' || r.status === 'error').forEach(r => {
      console.log(` ❌ ${r.name}: ${r.error || 'unknown'}`);
    });
  }

  // Mean latency over probes that recorded a duration (exceptions have no ms).
  const avgMs = Math.round(results.filter(r => r.ms).reduce((s, r) => s + r.ms, 0) / results.filter(r => r.ms).length);
  console.log(`\n 平均响应时间: ${avgMs}ms`);
  console.log('═'.repeat(60));

  if (fail === 0) {
    console.log('🎉 全部测试通过!R 统计引擎 12 工具 + JIT 护栏就绪。\n');
  } else {
    console.log('⚠️ 存在失败项,请检查 R 服务日志。\n');
  }

  // Exit code drives CI: 0 = all pass (blocked counts as pass), 1 = any failure.
  process.exit(fail > 0 ? 1 : 0);
}
|
||||
|
||||
// Entry point: any unhandled rejection aborts with a non-zero exit code.
main().catch((err) => {
  console.error('测试脚本异常:', err);
  process.exit(1);
});
|
||||
20
r-statistics-service/tests/test_anova_one.json
Normal file
20
r-statistics-service/tests/test_anova_one.json
Normal file
@@ -0,0 +1,20 @@
|
||||
{
|
||||
"data_source": {
|
||||
"type": "inline",
|
||||
"data": {
|
||||
"group": ["A", "A", "A", "A", "A", "A", "A", "A", "A", "A",
|
||||
"B", "B", "B", "B", "B", "B", "B", "B", "B", "B",
|
||||
"C", "C", "C", "C", "C", "C", "C", "C", "C", "C"],
|
||||
"score": [23, 25, 27, 24, 22, 26, 28, 21, 29, 24,
|
||||
30, 32, 35, 31, 28, 33, 36, 29, 34, 31,
|
||||
18, 20, 22, 19, 17, 21, 23, 16, 24, 19]
|
||||
}
|
||||
},
|
||||
"params": {
|
||||
"group_var": "group",
|
||||
"value_var": "score"
|
||||
},
|
||||
"guardrails": {
|
||||
"check_normality": true
|
||||
}
|
||||
}
|
||||
36
r-statistics-service/tests/test_baseline_table.json
Normal file
36
r-statistics-service/tests/test_baseline_table.json
Normal file
@@ -0,0 +1,36 @@
|
||||
{
|
||||
"data_source": {
|
||||
"type": "inline",
|
||||
"data": {
|
||||
"group": ["Drug", "Drug", "Drug", "Drug", "Drug", "Drug", "Drug", "Drug", "Drug", "Drug",
|
||||
"Drug", "Drug", "Drug", "Drug", "Drug",
|
||||
"Placebo", "Placebo", "Placebo", "Placebo", "Placebo",
|
||||
"Placebo", "Placebo", "Placebo", "Placebo", "Placebo",
|
||||
"Placebo", "Placebo", "Placebo", "Placebo", "Placebo"],
|
||||
"age": [45, 52, 38, 61, 44, 55, 49, 57, 42, 50,
|
||||
48, 53, 41, 59, 46,
|
||||
47, 51, 39, 58, 43, 54, 50, 56, 41, 49,
|
||||
46, 52, 40, 60, 44],
|
||||
"sex": ["M", "F", "M", "F", "M", "F", "M", "F", "M", "F",
|
||||
"M", "M", "F", "F", "M",
|
||||
"F", "M", "F", "M", "F", "M", "F", "M", "F", "M",
|
||||
"F", "F", "M", "M", "F"],
|
||||
"sbp": [130, 142, 125, 155, 128, 148, 135, 152, 127, 140,
|
||||
132, 145, 126, 150, 133,
|
||||
128, 138, 122, 150, 126, 142, 135, 148, 124, 136,
|
||||
130, 140, 120, 153, 127],
|
||||
"bmi": [24.5, 28.1, 22.3, 30.5, 23.8, 29.2, 25.6, 31.0, 22.0, 27.5,
|
||||
24.8, 29.5, 21.8, 30.2, 25.1,
|
||||
23.8, 27.2, 21.5, 29.8, 22.9, 28.5, 26.0, 30.1, 21.2, 26.8,
|
||||
24.0, 28.8, 20.8, 31.2, 23.5],
|
||||
"smoking": ["Yes", "No", "Yes", "No", "Yes", "No", "Yes", "No", "Yes", "No",
|
||||
"Yes", "Yes", "No", "No", "Yes",
|
||||
"No", "Yes", "No", "Yes", "No", "Yes", "No", "Yes", "No", "Yes",
|
||||
"No", "No", "Yes", "Yes", "No"]
|
||||
}
|
||||
},
|
||||
"params": {
|
||||
"group_var": "group",
|
||||
"analyze_vars": ["age", "sex", "sbp", "bmi", "smoking"]
|
||||
}
|
||||
}
|
||||
15
r-statistics-service/tests/test_fisher.json
Normal file
15
r-statistics-service/tests/test_fisher.json
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"data_source": {
|
||||
"type": "inline",
|
||||
"data": {
|
||||
"treatment": ["Drug", "Drug", "Drug", "Drug", "Drug", "Drug", "Drug", "Drug", "Drug", "Drug",
|
||||
"Placebo", "Placebo", "Placebo", "Placebo", "Placebo", "Placebo", "Placebo", "Placebo", "Placebo", "Placebo"],
|
||||
"outcome": ["Improved", "Improved", "Improved", "Improved", "Improved", "Improved", "Improved", "Not improved", "Not improved", "Not improved",
|
||||
"Improved", "Improved", "Improved", "Not improved", "Not improved", "Not improved", "Not improved", "Not improved", "Not improved", "Not improved"]
|
||||
}
|
||||
},
|
||||
"params": {
|
||||
"var1": "treatment",
|
||||
"var2": "outcome"
|
||||
}
|
||||
}
|
||||
24
r-statistics-service/tests/test_linear_reg.json
Normal file
24
r-statistics-service/tests/test_linear_reg.json
Normal file
@@ -0,0 +1,24 @@
|
||||
{
|
||||
"data_source": {
|
||||
"type": "inline",
|
||||
"data": {
|
||||
"sbp": [120, 130, 125, 140, 135, 128, 145, 138, 122, 127,
|
||||
133, 141, 136, 129, 132, 126, 148, 139, 124, 131,
|
||||
137, 143, 134, 128, 150, 142, 123, 130, 136, 144],
|
||||
"age": [25, 35, 30, 45, 40, 32, 50, 42, 28, 33,
|
||||
38, 48, 43, 34, 36, 29, 55, 44, 27, 37,
|
||||
41, 47, 39, 31, 58, 46, 26, 36, 40, 49],
|
||||
"bmi": [22.1, 25.3, 23.5, 28.7, 26.4, 24.0, 30.2, 27.8, 21.5, 23.8,
|
||||
25.6, 29.1, 27.2, 24.5, 25.1, 22.8, 31.5, 28.3, 21.9, 24.9,
|
||||
26.8, 29.5, 26.0, 23.2, 32.1, 28.9, 21.3, 25.0, 26.5, 30.0],
|
||||
"smoke": [0, 1, 0, 1, 1, 0, 1, 1, 0, 0,
|
||||
1, 1, 1, 0, 0, 0, 1, 1, 0, 0,
|
||||
1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
|
||||
}
|
||||
},
|
||||
"params": {
|
||||
"outcome_var": "sbp",
|
||||
"predictors": ["age", "bmi"],
|
||||
"confounders": ["smoke"]
|
||||
}
|
||||
}
|
||||
15
r-statistics-service/tests/test_wilcoxon.json
Normal file
15
r-statistics-service/tests/test_wilcoxon.json
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"data_source": {
|
||||
"type": "inline",
|
||||
"data": {
|
||||
"before": [120, 130, 125, 140, 135, 128, 132, 145, 138, 122,
|
||||
127, 133, 141, 136, 129],
|
||||
"after": [115, 122, 118, 130, 125, 120, 126, 135, 128, 118,
|
||||
121, 125, 132, 128, 122]
|
||||
}
|
||||
},
|
||||
"params": {
|
||||
"before_var": "before",
|
||||
"after_var": "after"
|
||||
}
|
||||
}
|
||||
424
r-statistics-service/tools/anova_one.R
Normal file
424
r-statistics-service/tools/anova_one.R
Normal file
@@ -0,0 +1,424 @@
|
||||
#' @tool_code ST_ANOVA_ONE
|
||||
#' @name 单因素方差分析
|
||||
#' @version 1.0.0
|
||||
#' @description 三组及以上独立样本的均值差异比较(含事后多重比较)
|
||||
#' @author SSA-Pro Team
|
||||
|
||||
library(glue)
|
||||
library(ggplot2)
|
||||
library(base64enc)
|
||||
|
||||
#' Run a one-way group comparison (classic ANOVA / Welch ANOVA / Kruskal-Wallis).
#'
#' Pipeline: load data -> validate params -> clean rows -> guardrail checks
#' (minimum group size, per-group normality, Levene homogeneity) -> choose the
#' test -> post-hoc comparisons -> boxplot -> assemble report_blocks payload.
#'
#' NOTE(review): helpers such as load_input_data(), make_error(), ERROR_CODES,
#' check_sample_size()/check_normality()/check_homogeneity(),
#' run_guardrail_chain(), format_p_value() and the make_*_block() constructors
#' are supplied by the service runtime, not this file — their exact contracts
#' are assumed from usage here.
#'
#' @param input list with $params ($group_var, $value_var), $guardrails
#'   (e.g. $check_normality), optional $original_filename, plus whatever
#'   data_source load_input_data() consumes.
#' @return On success: list(status, message, warnings, results, report_blocks,
#'   plots, trace_log, reproducible_code). May also return a "blocked"
#'   payload or a make_error() result.
run_analysis <- function(input) {
  # ===== Initialisation: timestamped trace log accumulated via closure =====
  logs <- c()
  log_add <- function(msg) { logs <<- c(logs, paste0("[", Sys.time(), "] ", msg)) }

  on.exit({}, add = TRUE)

  # ===== Data loading (errors become a logged NULL, then an E100 payload) =====
  log_add("开始加载输入数据")
  df <- tryCatch(
    load_input_data(input),
    error = function(e) {
      log_add(paste("数据加载失败:", e$message))
      return(NULL)
    }
  )

  if (is.null(df)) {
    return(make_error(ERROR_CODES$E100_INTERNAL_ERROR, details = "数据加载失败"))
  }
  log_add(glue("数据加载成功: {nrow(df)} 行, {ncol(df)} 列"))

  p <- input$params
  guardrails_cfg <- input$guardrails

  group_var <- p$group_var
  value_var <- p$value_var

  # ===== Parameter validation: both columns must exist =====
  if (!(group_var %in% names(df))) {
    return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND, col = group_var))
  }
  if (!(value_var %in% names(df))) {
    return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND, col = value_var))
  }

  # ===== Data cleaning: drop rows with NA/blank group or NA value =====
  original_rows <- nrow(df)
  df <- df[!is.na(df[[group_var]]) & trimws(as.character(df[[group_var]])) != "", ]
  df <- df[!is.na(df[[value_var]]), ]

  removed_rows <- original_rows - nrow(df)
  if (removed_rows > 0) {
    log_add(glue("数据清洗: 移除 {removed_rows} 行缺失值 (剩余 {nrow(df)} 行)"))
  }

  # Coerce the outcome to numeric; rows that fail coercion are dropped.
  if (!is.numeric(df[[value_var]])) {
    df[[value_var]] <- as.numeric(as.character(df[[value_var]]))
    df <- df[!is.na(df[[value_var]]), ]
  }

  # Grouping info — ANOVA needs at least 3 levels (2 groups belongs to t-test tools).
  df[[group_var]] <- as.factor(df[[group_var]])
  groups <- levels(df[[group_var]])
  n_groups <- length(groups)

  if (n_groups < 3) {
    return(make_error(ERROR_CODES$E003_INSUFFICIENT_GROUPS,
                      col = group_var, expected = "3+", actual = n_groups))
  }

  log_add(glue("分组变量 '{group_var}' 有 {n_groups} 个水平: {paste(groups, collapse=', ')}"))

  # ===== Guardrail checks =====
  guardrail_results <- list()
  warnings_list <- c()
  use_kruskal <- FALSE

  # Minimum per-group sample size; BLOCK action stops the analysis outright.
  group_sizes <- table(df[[group_var]])
  min_group_n <- min(group_sizes)
  sample_check <- check_sample_size(min_group_n, min_required = 3, action = ACTION_BLOCK)
  guardrail_results <- c(guardrail_results, list(sample_check))
  log_add(glue("最小组样本量: {min_group_n}, {sample_check$reason}"))

  # Per-group normality (only when requested); any failing group triggers
  # the Kruskal-Wallis switch. The 3..5000 range matches shapiro.test limits.
  if (isTRUE(guardrails_cfg$check_normality)) {
    log_add("执行正态性检验")
    normality_failed <- FALSE
    for (g in groups) {
      vals <- df[df[[group_var]] == g, value_var]
      if (length(vals) >= 3 && length(vals) <= 5000) {
        norm_check <- check_normality(vals, alpha = 0.05, action = ACTION_SWITCH, action_target = "Kruskal-Wallis")
        guardrail_results <- c(guardrail_results, list(norm_check))
        log_add(glue("组[{g}] 正态性: p = {round(norm_check$p_value, 4)}, {norm_check$reason}"))
        if (!norm_check$passed) normality_failed <- TRUE
      }
    }
    if (normality_failed) use_kruskal <- TRUE
  }

  # Homogeneity of variance (Levene) — only meaningful for the parametric path.
  # A failure downgrades to a warning and later selects Welch's ANOVA.
  if (!use_kruskal) {
    tryCatch({
      homo_check <- check_homogeneity(df, group_var, value_var, alpha = 0.05, action = ACTION_WARN)
      guardrail_results <- c(guardrail_results, list(homo_check))
      log_add(glue("方差齐性 (Levene): p = {round(homo_check$p_value, 4)}, {homo_check$reason}"))
      if (!homo_check$passed) {
        warnings_list <- c(warnings_list, "方差不齐性,使用 Welch 校正的 ANOVA")
      }
    }, error = function(e) {
      log_add(paste("方差齐性检验失败:", e$message))
    })
  }

  # Fold all guardrail verdicts into one overall status.
  guardrail_status <- run_guardrail_chain(guardrail_results)

  if (guardrail_status$status == "blocked") {
    return(list(status = "blocked", message = guardrail_status$reason, trace_log = logs))
  }

  if (guardrail_status$status == "switch") {
    use_kruskal <- TRUE
    warnings_list <- c(warnings_list, guardrail_status$reason)
    log_add(glue("正态性不满足,切换为 Kruskal-Wallis 检验"))
  }

  if (length(guardrail_status$warnings) > 0) {
    warnings_list <- c(warnings_list, guardrail_status$warnings)
  }

  # ===== Per-group descriptive statistics =====
  group_stats <- lapply(groups, function(g) {
    vals <- df[df[[group_var]] == g, value_var]
    list(
      group = g,
      n = length(vals),
      mean = round(mean(vals), 3),
      sd = round(sd(vals), 3),
      median = round(median(vals), 3),
      q1 = round(quantile(vals, 0.25), 3),
      q3 = round(quantile(vals, 0.75), 3)
    )
  })

  # ===== Core computation: Kruskal-Wallis branch vs ANOVA branch =====
  if (use_kruskal) {
    log_add("执行 Kruskal-Wallis 检验")
    formula_obj <- as.formula(paste(value_var, "~", group_var))
    result <- kruskal.test(formula_obj, data = df)
    method_used <- "Kruskal-Wallis rank sum test"
    stat_name <- "H"

    # Effect size: eta-squared approximation (H - k + 1) / (n - k), floored at 0.
    eta_sq <- (result$statistic - n_groups + 1) / (nrow(df) - n_groups)
    eta_sq <- max(0, as.numeric(eta_sq))

    output_results <- list(
      method = method_used,
      statistic = jsonlite::unbox(as.numeric(result$statistic)),
      statistic_name = stat_name,
      df = jsonlite::unbox(as.numeric(result$parameter)),
      p_value = jsonlite::unbox(as.numeric(result$p.value)),
      p_value_fmt = format_p_value(result$p.value),
      effect_size = list(
        eta_squared = jsonlite::unbox(round(eta_sq, 4)),
        interpretation = interpret_eta_sq(eta_sq)
      ),
      group_stats = group_stats
    )

    # Post-hoc: pairwise Wilcoxon with Bonferroni correction (Dunn-style).
    posthoc_result <- tryCatch({
      pw <- pairwise.wilcox.test(df[[value_var]], df[[group_var]], p.adjust.method = "bonferroni")
      pw
    }, error = function(e) {
      log_add(paste("Dunn 事后检验失败:", e$message))
      NULL
    })
  } else {
    log_add("执行单因素 ANOVA")
    formula_obj <- as.formula(paste(value_var, "~", group_var))

    # Welch correction is selected by scanning the warning text produced
    # by the Levene check above.
    use_welch <- any(grepl("方差不齐性", warnings_list))

    if (use_welch) {
      result <- oneway.test(formula_obj, data = df, var.equal = FALSE)
      method_used <- "One-way ANOVA (Welch correction)"
    } else {
      # Classic ANOVA: pull F, dfs and p out of summary(aov(...)) into a
      # plain list shaped like an htest object.
      aov_result <- aov(formula_obj, data = df)
      result_summary <- summary(aov_result)
      result <- list(
        statistic = result_summary[[1]]$`F value`[1],
        parameter = c(result_summary[[1]]$Df[1], result_summary[[1]]$Df[2]),
        p.value = result_summary[[1]]$`Pr(>F)`[1]
      )
      method_used <- "One-way ANOVA"
    }

    stat_name <- "F"

    # Effect size: eta-squared = SS_between / SS_total, computed directly.
    ss_between <- sum(tapply(df[[value_var]], df[[group_var]], function(x) length(x) * (mean(x) - mean(df[[value_var]]))^2))
    ss_total <- sum((df[[value_var]] - mean(df[[value_var]]))^2)
    eta_sq <- ss_between / ss_total

    # Normalise statistic/df/p across the two result shapes (htest vs plain list).
    f_val <- if (is.list(result)) result$statistic else as.numeric(result$statistic)
    df_val <- if (is.list(result) && !is.null(result$parameter)) {
      if (length(result$parameter) == 2) result$parameter else as.numeric(result$parameter)
    } else {
      as.numeric(result$parameter)
    }
    p_val <- if (is.list(result)) result$p.value else as.numeric(result$p.value)

    output_results <- list(
      method = method_used,
      statistic = jsonlite::unbox(as.numeric(f_val)),
      statistic_name = stat_name,
      df = if (length(df_val) == 2) as.numeric(df_val) else jsonlite::unbox(as.numeric(df_val)),
      p_value = jsonlite::unbox(as.numeric(p_val)),
      p_value_fmt = format_p_value(p_val),
      effect_size = list(
        eta_squared = jsonlite::unbox(round(eta_sq, 4)),
        interpretation = interpret_eta_sq(eta_sq)
      ),
      group_stats = group_stats
    )

    # Post-hoc: Tukey HSD for classic ANOVA, Bonferroni pairwise t-tests
    # (unpooled SD) for the Welch path.
    posthoc_result <- tryCatch({
      if (use_welch) {
        pairwise.t.test(df[[value_var]], df[[group_var]], p.adjust.method = "bonferroni", pool.sd = FALSE)
      } else {
        TukeyHSD(aov(formula_obj, data = df))
      }
    }, error = function(e) {
      log_add(paste("事后多重比较失败:", e$message))
      NULL
    })
  }

  log_add(glue("{stat_name} = {round(as.numeric(output_results$statistic), 3)}, P = {round(as.numeric(output_results$p_value), 4)}"))

  # ===== Flatten post-hoc results into a uniform list of comparison records =====
  posthoc_pairs <- NULL
  if (!is.null(posthoc_result)) {
    if (inherits(posthoc_result, "TukeyHSD")) {
      # TukeyHSD: one row per pair, with diff + CI + adjusted p.
      tukey_df <- as.data.frame(posthoc_result[[1]])
      posthoc_pairs <- lapply(seq_len(nrow(tukey_df)), function(i) {
        list(
          comparison = rownames(tukey_df)[i],
          diff = round(tukey_df$diff[i], 3),
          ci_lower = round(tukey_df$lwr[i], 3),
          ci_upper = round(tukey_df$upr[i], 3),
          p_adj = round(tukey_df$`p adj`[i], 4),
          p_adj_fmt = format_p_value(tukey_df$`p adj`[i]),
          significant = tukey_df$`p adj`[i] < 0.05
        )
      })
    } else if (inherits(posthoc_result, "pairwise.htest")) {
      # pairwise.*.test: lower-triangular p-value matrix; walk non-NA cells.
      p_matrix <- posthoc_result$p.value
      for (i in seq_len(nrow(p_matrix))) {
        for (j in seq_len(ncol(p_matrix))) {
          if (!is.na(p_matrix[i, j])) {
            if (is.null(posthoc_pairs)) posthoc_pairs <- list()
            posthoc_pairs[[length(posthoc_pairs) + 1]] <- list(
              comparison = paste(rownames(p_matrix)[i], "vs", colnames(p_matrix)[j]),
              p_adj = round(p_matrix[i, j], 4),
              p_adj_fmt = format_p_value(p_matrix[i, j]),
              significant = p_matrix[i, j] < 0.05
            )
          }
        }
      }
    }
  }

  output_results$posthoc <- posthoc_pairs

  # ===== Boxplot (failure is non-fatal; plot simply omitted) =====
  log_add("生成箱线图")
  plot_base64 <- tryCatch({
    generate_anova_boxplot(df, group_var, value_var)
  }, error = function(e) {
    log_add(paste("图表生成失败:", e$message))
    NULL
  })

  # ===== Reproducible code snippet returned to the user =====
  original_filename <- if (!is.null(input$original_filename) && nchar(input$original_filename) > 0) {
    input$original_filename
  } else {
    "data.csv"
  }

  # The template below is runtime output (user-facing R code), left verbatim.
  reproducible_code <- glue('
# SSA-Pro 自动生成代码
# 工具: 单因素方差分析
# 时间: {Sys.time()}
# ================================

library(ggplot2)

# 数据准备
df <- read.csv("{original_filename}")
group_var <- "{group_var}"
value_var <- "{value_var}"

# 单因素 ANOVA
result <- aov(as.formula(paste(value_var, "~", group_var)), data = df)
summary(result)

# 事后多重比较 (Tukey HSD)
TukeyHSD(result)

# 可视化
ggplot(df, aes(x = .data[[group_var]], y = .data[[value_var]], fill = .data[[group_var]])) +
  geom_boxplot(alpha = 0.7) +
  theme_minimal() +
  labs(title = paste("Distribution of", value_var, "by", group_var))
')

  # ===== Assemble report_blocks =====
  blocks <- list()

  # Block 1: per-group descriptives table
  desc_headers <- c("组别", "N", "均值", "标准差", "中位数")
  desc_rows <- lapply(group_stats, function(gs) {
    c(gs$group, as.character(gs$n), as.character(gs$mean), as.character(gs$sd), as.character(gs$median))
  })
  blocks[[length(blocks) + 1]] <- make_table_block(desc_headers, desc_rows, title = "各组描述统计")

  # Block 2: headline test result as key/value pairs
  kv_items <- list(
    "方法" = method_used,
    "统计量" = paste0(stat_name, " = ", round(as.numeric(output_results$statistic), 3)),
    "P 值" = output_results$p_value_fmt,
    "η²" = as.character(output_results$effect_size$eta_squared),
    "效应量解释" = output_results$effect_size$interpretation
  )
  blocks[[length(blocks) + 1]] <- make_kv_block(kv_items, title = "检验结果")

  # Block 3: post-hoc comparison table (only when pairs exist)
  if (!is.null(posthoc_pairs) && length(posthoc_pairs) > 0) {
    ph_headers <- c("比较", "P 值 (校正)", "显著性")
    ph_rows <- lapply(posthoc_pairs, function(pair) {
      sig <- if (pair$significant) "*" else ""
      c(pair$comparison, pair$p_adj_fmt, sig)
    })
    blocks[[length(blocks) + 1]] <- make_table_block(ph_headers, ph_rows,
                                                     title = "事后多重比较",
                                                     footnote = if (use_kruskal) "Bonferroni 校正的 Wilcoxon 检验" else "Tukey HSD / Bonferroni 校正")
  }

  # Block 4: boxplot image (skipped if plotting failed)
  if (!is.null(plot_base64)) {
    blocks[[length(blocks) + 1]] <- make_image_block(plot_base64,
                                                     title = paste(value_var, "by", group_var),
                                                     alt = paste("箱线图:", value_var, "按", group_var, "分组"))
  }

  # Block 5: narrative conclusion, extended with significant pairs when present
  p_val_num <- as.numeric(output_results$p_value)
  sig_text <- if (p_val_num < 0.05) "各组间存在统计学显著差异" else "各组间差异无统计学意义"
  conclusion <- glue("{method_used}: {stat_name} = {round(as.numeric(output_results$statistic), 3)}, P {output_results$p_value_fmt}。{sig_text}(η² = {output_results$effect_size$eta_squared},{output_results$effect_size$interpretation}效应)。")

  if (!is.null(posthoc_pairs) && p_val_num < 0.05) {
    sig_pairs <- Filter(function(x) x$significant, posthoc_pairs)
    if (length(sig_pairs) > 0) {
      pair_names <- sapply(sig_pairs, function(x) x$comparison)
      conclusion <- paste0(conclusion, glue("\n\n事后比较显示以下组间差异显著:{paste(pair_names, collapse = '、')}。"))
    }
  }
  blocks[[length(blocks) + 1]] <- make_markdown_block(conclusion, title = "结论摘要")

  # ===== Return =====
  log_add("分析完成")

  return(list(
    status = "success",
    message = "分析完成",
    warnings = if (length(warnings_list) > 0) warnings_list else NULL,
    results = output_results,
    report_blocks = blocks,
    plots = if (!is.null(plot_base64)) list(plot_base64) else list(),
    trace_log = logs,
    reproducible_code = as.character(reproducible_code)
  ))
}
|
||||
|
||||
# η² 效应量解释
|
||||
# Map an eta-squared effect size onto its qualitative label using the
# conventional 0.01 / 0.06 / 0.14 cut points (left-closed on each threshold,
# matching the original < comparisons).
interpret_eta_sq <- function(eta_sq) {
  cut_points <- c(0.01, 0.06, 0.14)
  labels <- c("微小", "小", "中等", "大")
  labels[findInterval(eta_sq, cut_points) + 1]
}
|
||||
|
||||
# NULL 合并运算符
|
||||
# Null-coalescing operator: yield `y` only when `x` is NULL.
`%||%` <- function(x, y) {
  if (is.null(x)) {
    return(y)
  }
  x
}
|
||||
|
||||
# 辅助函数:ANOVA 箱线图
|
||||
# Render a grouped boxplot (group means overlaid as red diamonds) and return
# it as a "data:image/png;base64,..." URI string. Width grows with the number
# of group levels so wide factors stay readable.
generate_anova_boxplot <- function(df, group_var, value_var) {
  n_levels <- length(unique(df[[group_var]]))

  plot_obj <- ggplot(df, aes(x = .data[[group_var]], y = .data[[value_var]], fill = .data[[group_var]])) +
    geom_boxplot(alpha = 0.7, outlier.shape = 21) +
    stat_summary(fun = mean, geom = "point", shape = 18, size = 3, color = "red") +
    theme_minimal() +
    labs(
      title = paste("Distribution of", value_var, "by", group_var),
      x = group_var,
      y = value_var
    ) +
    scale_fill_brewer(palette = "Set2") +
    theme(legend.position = "none")

  png_path <- tempfile(fileext = ".png")
  ggsave(png_path, plot_obj, width = max(7, n_levels * 1.5), height = 5, dpi = 100)
  encoded <- base64encode(png_path)
  unlink(png_path)

  paste0("data:image/png;base64,", encoded)
}
|
||||
316
r-statistics-service/tools/baseline_table.R
Normal file
316
r-statistics-service/tools/baseline_table.R
Normal file
@@ -0,0 +1,316 @@
|
||||
#' @tool_code ST_BASELINE_TABLE
|
||||
#' @name 基线特征表(复合工具)
|
||||
#' @version 1.0.0
|
||||
#' @description 基于 gtsummary 的一键式基线特征表生成,自动判断变量类型、选择统计方法、输出标准三线表
|
||||
#' @author SSA-Pro Team
|
||||
#' @note 复合工具:一次遍历所有变量,自动选方法(T/Wilcoxon/χ²/Fisher),合并出表
|
||||
|
||||
library(glue)
|
||||
library(ggplot2)
|
||||
library(base64enc)
|
||||
|
||||
#' Baseline-characteristics table tool (ST_BASELINE_TABLE).
#'
#' Reads `input$params$group_var` / `input$params$analyze_vars`, builds a
#' gtsummary tbl_summary with automatic test selection (add_p + add_overall),
#' and returns the standard tool payload: structured results, report_blocks
#' and a reproducible R script.
run_analysis <- function(input) {
  # ===== Initialization =====
  logs <- c()
  log_add <- function(msg) { logs <<- c(logs, paste0("[", Sys.time(), "] ", msg)) }
  warnings_list <- c()

  on.exit({}, add = TRUE)  # placeholder exit hook, kept for symmetry with the other tools

  # ===== Dependency check =====
  required_pkgs <- c("gtsummary", "gt", "broom")
  for (pkg in required_pkgs) {
    if (!requireNamespace(pkg, quietly = TRUE)) {
      return(make_error(ERROR_CODES$E101_PACKAGE_MISSING, package = pkg))
    }
  }

  library(gtsummary)
  library(dplyr)

  # ===== Data loading =====
  log_add("开始加载输入数据")
  df <- tryCatch(
    load_input_data(input),
    error = function(e) {
      log_add(paste("数据加载失败:", e$message))
      return(NULL)
    }
  )

  if (is.null(df)) {
    return(make_error(ERROR_CODES$E100_INTERNAL_ERROR, details = "数据加载失败"))
  }
  log_add(glue("数据加载成功: {nrow(df)} 行, {ncol(df)} 列"))

  p <- input$params
  group_var <- p$group_var
  analyze_vars <- as.character(unlist(p$analyze_vars))

  # ===== Parameter validation =====
  if (is.null(group_var) || !(group_var %in% names(df))) {
    return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND, col = group_var %||% "NULL"))
  }

  # No analyze_vars supplied: default to every column except the group column.
  if (is.null(analyze_vars) || length(analyze_vars) == 0) {
    analyze_vars <- setdiff(names(df), group_var)
    log_add(glue("未指定分析变量,自动选取全部 {length(analyze_vars)} 个变量"))
  }

  missing_vars <- analyze_vars[!(analyze_vars %in% names(df))]
  if (length(missing_vars) > 0) {
    return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND,
                      col = paste(missing_vars, collapse = ", ")))
  }

  # ===== Data cleaning =====
  # Drop rows whose group value is missing or blank.
  original_rows <- nrow(df)
  df <- df[!is.na(df[[group_var]]) & trimws(as.character(df[[group_var]])) != "", ]
  removed_rows <- original_rows - nrow(df)
  if (removed_rows > 0) {
    log_add(glue("分组变量缺失值清洗: 移除 {removed_rows} 行 (剩余 {nrow(df)} 行)"))
  }

  groups <- unique(df[[group_var]])
  n_groups <- length(groups)
  if (n_groups < 2) {
    return(make_error(ERROR_CODES$E003_INSUFFICIENT_GROUPS,
                      col = group_var, expected = "2+", actual = n_groups))
  }

  # Sample-size guardrail: hard block below 10 remaining rows.
  sample_check <- check_sample_size(nrow(df), min_required = 10, action = ACTION_BLOCK)
  if (!sample_check$passed) {
    return(list(status = "blocked", message = sample_check$reason, trace_log = logs))
  }

  # Ensure the grouping variable is a factor.
  df[[group_var]] <- as.factor(df[[group_var]])

  # Subset to the columns under analysis.
  df_analysis <- df[, c(group_var, analyze_vars), drop = FALSE]

  log_add(glue("分组变量: {group_var} ({n_groups} 组: {paste(groups, collapse=', ')})"))
  log_add(glue("分析变量: {length(analyze_vars)} 个"))

  # ===== Core computation: gtsummary =====
  log_add("使用 gtsummary 生成基线特征表")

  # withCallingHandlers collects (and muffles) warnings without aborting the
  # build; tryCatch turns hard errors into NULL for the unified error path.
  tbl <- tryCatch(
    withCallingHandlers(
      {
        tbl_summary(
          df_analysis,
          by = all_of(group_var),
          missing = "ifany",
          statistic = list(
            all_continuous() ~ "{mean} ({sd})",
            all_categorical() ~ "{n} ({p}%)"
          ),
          digits = list(
            all_continuous() ~ 2,
            all_categorical() ~ c(0, 1)
          )
        ) %>%
          add_p() %>%
          add_overall()
      },
      warning = function(w) {
        warnings_list <<- c(warnings_list, w$message)
        log_add(paste("gtsummary 警告:", w$message))
        invokeRestart("muffleWarning")
      }
    ),
    error = function(e) {
      log_add(paste("gtsummary 生成失败:", e$message))
      return(NULL)
    }
  )

  if (is.null(tbl)) {
    return(map_r_error("gtsummary 基线特征表生成失败"))
  }

  log_add("gtsummary 表格生成成功")

  # ===== Structured data extraction =====
  # NOTE(review): tbl_df is computed but not referenced below — candidate for removal.
  tbl_df <- as.data.frame(tbl$table_body)

  # Variables significant at alpha = 0.05.
  significant_vars <- extract_significant_vars(tbl, alpha = 0.05)
  log_add(glue("显著变量 (P < 0.05): {length(significant_vars)} 个"))

  # Which statistical test was applied to each variable.
  method_info <- extract_method_info(tbl)

  # ===== Convert to report_blocks =====
  log_add("转换 gtsummary → report_blocks")
  blocks <- gtsummary_to_blocks(tbl, group_var, groups, analyze_vars, significant_vars)

  # ===== Build structured results =====
  output_results <- list(
    method = "gtsummary::tbl_summary + add_p",
    group_var = group_var,
    n_groups = n_groups,
    groups = lapply(groups, function(g) {
      list(label = as.character(g), n = sum(df[[group_var]] == g))
    }),
    n_variables = length(analyze_vars),
    significant_vars = significant_vars,
    method_info = method_info,
    total_n = nrow(df)
  )

  # ===== Reproducible code generation =====
  original_filename <- if (!is.null(input$original_filename) && nchar(input$original_filename) > 0) {
    input$original_filename
  } else {
    "data.csv"
  }

  vars_str <- paste0('c("', paste(analyze_vars, collapse = '", "'), '")')

  # The template below is emitted verbatim for the user; doubled braces
  # ({{mean}}) escape glue interpolation inside the generated script.
  reproducible_code <- glue('
# SSA-Pro 自动生成代码
# 工具: 基线特征表 (gtsummary)
# 时间: {Sys.time()}
# ================================

# 自动安装依赖
required_packages <- c("gtsummary", "gt", "dplyr")
new_packages <- required_packages[!(required_packages %in% installed.packages()[,"Package"])]
if(length(new_packages)) install.packages(new_packages, repos = "https://cloud.r-project.org")

library(gtsummary)
library(dplyr)

# 数据准备
df <- read.csv("{original_filename}")
group_var <- "{group_var}"
analyze_vars <- {vars_str}

df_analysis <- df[, c(group_var, analyze_vars)]
df_analysis[[group_var]] <- as.factor(df_analysis[[group_var]])

# 生成基线特征表
tbl <- tbl_summary(
  df_analysis,
  by = all_of(group_var),
  missing = "ifany",
  statistic = list(
    all_continuous() ~ "{{mean}} ({{sd}})",
    all_categorical() ~ "{{n}} ({{p}}%)"
  )
) %>%
  add_p() %>%
  add_overall()

# 显示结果
tbl

# 导出为 Word(可选)
# tbl %>% as_gt() %>% gt::gtsave("baseline_table.docx")
')

  # ===== Return payload =====
  log_add("分析完成")

  return(list(
    status = "success",
    message = "基线特征表生成完成",
    warnings = if (length(warnings_list) > 0) warnings_list else NULL,
    results = output_results,
    report_blocks = blocks,
    plots = list(),
    trace_log = logs,
    reproducible_code = as.character(reproducible_code)
  ))
}
|
||||
|
||||
# ===== gtsummary → report_blocks conversion layer =====

#' Convert a gtsummary table into report_blocks.
#'
#' @param tbl gtsummary object (tbl_summary with add_p/add_overall applied)
#' @param group_var name of the grouping column
#' @param groups vector of group levels (used in the sample-overview block)
#' @param analyze_vars character vector of analyzed variable names
#' @param significant_vars variables with P < alpha (from extract_significant_vars)
#' @return list of blocks: three-line table, sample overview, significance summary
gtsummary_to_blocks <- function(tbl, group_var, groups, analyze_vars, significant_vars) {
  blocks <- list()

  # Flatten to a tibble; col_labels = FALSE keeps machine-readable column names.
  tbl_data <- gtsummary::as_tibble(tbl, col_labels = FALSE)

  # Block 1: the three-line (publication-style) table — core output.
  headers <- colnames(tbl_data)
  rows <- lapply(seq_len(nrow(tbl_data)), function(i) {
    row <- as.list(tbl_data[i, ])
    lapply(row, function(cell) {
      val <- as.character(cell)
      if (is.na(val) || val == "NA") "" else val
    })
  })

  # Locate a P-value column (if add_p produced one) for the metadata flag.
  p_col_idx <- which(grepl("p.value|p_value", headers, ignore.case = TRUE))

  blocks[[length(blocks) + 1]] <- make_table_block(
    headers, rows,
    title = glue("基线特征表 (按 {group_var} 分组)"),
    footnote = "连续变量: Mean (SD); 分类变量: N (%); P 值由自动选择的统计方法计算",
    metadata = list(
      is_baseline_table = TRUE,
      group_var = group_var,
      has_p_values = length(p_col_idx) > 0
    )
  )

  # Block 2: sample overview (key/value card).
  # (Removed dead code: a `group_n_items` list was built here but never used.)
  blocks[[length(blocks) + 1]] <- make_kv_block(
    list("总样本量" = as.character(nrow(tbl$inputs$data)),
         "分组变量" = group_var,
         "分组数" = as.character(length(groups)),
         "分析变量数" = as.character(length(analyze_vars))),
    title = "样本概况"
  )

  # Block 3: summary of significant between-group differences.
  if (length(significant_vars) > 0) {
    conclusion <- glue("在 α = 0.05 水平下,以下变量在组间存在显著差异:**{paste(significant_vars, collapse = '**、**')}**(共 {length(significant_vars)} 个)。")
  } else {
    conclusion <- "在 α = 0.05 水平下,未发现各组间存在显著差异的基线变量。"
  }
  blocks[[length(blocks) + 1]] <- make_markdown_block(conclusion, title = "组间差异摘要")

  return(blocks)
}
|
||||
|
||||
#' Extract the names of variables whose P value falls below `alpha`.
#'
#' @param tbl gtsummary object whose `table_body` carries `variable` and `p.value`
#' @param alpha significance threshold (default 0.05)
#' @return character vector of unique significant variable names (possibly empty)
extract_significant_vars <- function(tbl, alpha = 0.05) {
  body <- tbl$table_body
  p_numeric <- as.numeric(unlist(body$p.value))
  var_names <- as.character(body$variable)
  is_significant <- !is.na(p_numeric) & p_numeric < alpha
  if (!any(is_significant)) {
    return(character(0))
  }
  unique(var_names[is_significant])
}
|
||||
|
||||
#' Extract, for every variable carrying a P value, the test used and the P value.
#'
#' @param tbl gtsummary object; reads `variable`, `p.value` and (when present)
#'   `test_name` from `tbl$table_body`
#' @return list of per-variable entries (variable, test_name, p_value,
#'   p_value_fmt); empty list when no row has a P value
extract_method_info <- function(tbl) {
  body <- tbl$table_body
  p_vals <- as.numeric(unlist(body$p.value))
  has_p <- which(!is.na(p_vals))
  if (length(has_p) == 0) return(list())

  # Fall back to "unknown" when gtsummary did not record a test name column.
  test_names <- if ("test_name" %in% colnames(body)) as.character(unlist(body$test_name)) else rep("unknown", nrow(body))

  lapply(has_p, function(i) {
    # Fix: as.character() yields NA_character_ (not NULL) for missing cells,
    # which `%||%` does not catch — check NA explicitly so "unknown" is used.
    tn <- test_names[i]
    list(
      variable = as.character(body$variable[i]),
      test_name = if (is.na(tn)) "unknown" else tn,
      p_value = round(p_vals[i], 4),
      p_value_fmt = format_p_value(p_vals[i])
    )
  })
}
|
||||
|
||||
# NULL-coalescing helper: return `x` unless it is NULL, in which case `y`.
`%||%` <- function(x, y) if (!is.null(x)) x else y
|
||||
272
r-statistics-service/tools/fisher.R
Normal file
272
r-statistics-service/tools/fisher.R
Normal file
@@ -0,0 +1,272 @@
|
||||
#' @tool_code ST_FISHER
|
||||
#' @name Fisher 精确检验
|
||||
#' @version 1.0.0
|
||||
#' @description 小样本或稀疏列联表的精确独立性检验(卡方检验的替代方法)
|
||||
#' @author SSA-Pro Team
|
||||
|
||||
library(glue)
|
||||
library(ggplot2)
|
||||
library(base64enc)
|
||||
|
||||
#' Fisher's exact test tool (ST_FISHER).
#'
#' Exact test of independence between two categorical variables. Uses the
#' exact computation for 2x2 tables and a Monte-Carlo simulated P value
#' (B = 10000) for larger tables. Returns the standard tool payload.
run_analysis <- function(input) {
  # ===== Initialization =====
  logs <- c()
  log_add <- function(msg) { logs <<- c(logs, paste0("[", Sys.time(), "] ", msg)) }

  on.exit({}, add = TRUE)

  # ===== Data loading =====
  log_add("开始加载输入数据")
  df <- tryCatch(
    load_input_data(input),
    error = function(e) {
      log_add(paste("数据加载失败:", e$message))
      return(NULL)
    }
  )

  if (is.null(df)) {
    return(make_error(ERROR_CODES$E100_INTERNAL_ERROR, details = "数据加载失败"))
  }
  log_add(glue("数据加载成功: {nrow(df)} 行, {ncol(df)} 列"))

  p <- input$params
  var1 <- p$var1
  var2 <- p$var2

  # ===== Parameter validation =====
  if (!(var1 %in% names(df))) {
    return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND, col = var1))
  }
  if (!(var2 %in% names(df))) {
    return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND, col = var2))
  }

  # ===== Data cleaning =====
  # Drop rows where either variable is missing or blank.
  original_rows <- nrow(df)
  df <- df[!is.na(df[[var1]]) & trimws(as.character(df[[var1]])) != "", ]
  df <- df[!is.na(df[[var2]]) & trimws(as.character(df[[var2]])) != "", ]

  removed_rows <- original_rows - nrow(df)
  if (removed_rows > 0) {
    log_add(glue("数据清洗: 移除 {removed_rows} 行缺失值 (剩余 {nrow(df)} 行)"))
  }

  # ===== Guardrail checks =====
  guardrail_results <- list()
  warnings_list <- c()

  # Fisher's exact test works for small samples; only block below 4 rows.
  sample_check <- check_sample_size(nrow(df), min_required = 4, action = ACTION_BLOCK)
  guardrail_results <- c(guardrail_results, list(sample_check))
  log_add(glue("样本量检查: N = {nrow(df)}, {sample_check$reason}"))

  guardrail_status <- run_guardrail_chain(guardrail_results)

  if (guardrail_status$status == "blocked") {
    return(list(
      status = "blocked",
      message = guardrail_status$reason,
      trace_log = logs
    ))
  }

  # ===== Build contingency table =====
  contingency_table <- table(df[[var1]], df[[var2]])
  log_add(glue("列联表维度: {nrow(contingency_table)} x {ncol(contingency_table)}"))

  if (nrow(contingency_table) < 2 || ncol(contingency_table) < 2) {
    return(make_error(ERROR_CODES$E003_INSUFFICIENT_GROUPS,
                      col = paste(var1, "或", var2),
                      expected = 2,
                      actual = min(nrow(contingency_table), ncol(contingency_table))))
  }

  is_2x2 <- nrow(contingency_table) == 2 && ncol(contingency_table) == 2

  # Expected cell counts (for reporting only — Fisher's test needs none).
  # NOTE(review): chisq.test() emits a warning when expected counts are low,
  # which is exactly the situation this tool targets; consider wrapping in
  # suppressWarnings(). TODO confirm.
  expected <- chisq.test(contingency_table)$expected
  low_expected_count <- sum(expected < 5)
  total_cells <- length(expected)
  low_expected_pct <- low_expected_count / total_cells

  if (low_expected_pct > 0) {
    log_add(glue("期望频数 < 5 的格子: {low_expected_count}/{total_cells} ({round(low_expected_pct * 100, 1)}%)"))
  }

  # ===== Core computation =====
  log_add("执行 Fisher 精确检验")

  result <- tryCatch({
    if (is_2x2) {
      fisher.test(contingency_table)
    } else {
      # Larger tables: simulated P value keeps runtime and memory bounded.
      fisher.test(contingency_table, simulate.p.value = TRUE, B = 10000)
    }
  }, error = function(e) {
    log_add(paste("Fisher 检验失败:", e$message))
    return(NULL)
  })

  if (is.null(result)) {
    return(map_r_error("Fisher 精确检验计算失败,列联表可能过大"))
  }

  method_used <- result$method

  output_results <- list(
    method = method_used,
    p_value = jsonlite::unbox(as.numeric(result$p.value)),
    p_value_fmt = format_p_value(result$p.value)
  )

  # Odds ratio and CI are only produced for the 2x2 exact test.
  if (!is.null(result$estimate)) {
    output_results$odds_ratio = jsonlite::unbox(as.numeric(result$estimate))
  }
  if (!is.null(result$conf.int)) {
    output_results$conf_int = as.numeric(result$conf.int)
  }

  observed_matrix <- matrix(
    as.numeric(contingency_table),
    nrow = nrow(contingency_table),
    ncol = ncol(contingency_table),
    dimnames = list(rownames(contingency_table), colnames(contingency_table))
  )

  output_results$contingency_table <- list(
    row_var = var1,
    col_var = var2,
    row_levels = as.character(rownames(contingency_table)),
    col_levels = as.character(colnames(contingency_table)),
    observed = observed_matrix,
    row_totals = as.numeric(rowSums(contingency_table)),
    col_totals = as.numeric(colSums(contingency_table)),
    grand_total = jsonlite::unbox(sum(contingency_table))
  )

  log_add(glue("P = {round(result$p.value, 4)}"))

  # ===== Chart generation =====
  log_add("生成堆叠条形图")
  plot_base64 <- tryCatch({
    generate_stacked_bar(contingency_table, var1, var2)
  }, error = function(e) {
    log_add(paste("图表生成失败:", e$message))
    NULL  # chart is optional; the analysis still succeeds without it
  })

  # ===== Reproducible code generation =====
  original_filename <- if (!is.null(input$original_filename) && nchar(input$original_filename) > 0) {
    input$original_filename
  } else {
    "data.csv"
  }

  reproducible_code <- glue('
# SSA-Pro 自动生成代码
# 工具: Fisher 精确检验
# 时间: {Sys.time()}
# ================================

library(ggplot2)

# 数据准备
df <- read.csv("{original_filename}")
var1 <- "{var1}"
var2 <- "{var2}"

# 数据清洗
df <- df[!is.na(df[[var1]]) & !is.na(df[[var2]]), ]

# 构建列联表
contingency_table <- table(df[[var1]], df[[var2]])
print(contingency_table)

# Fisher 精确检验
result <- fisher.test(contingency_table)
print(result)
')

  # ===== Build report_blocks =====
  blocks <- list()

  # Block 1: the contingency table.
  table_headers <- c(var1, as.character(colnames(contingency_table)))
  table_rows <- lapply(seq_len(nrow(contingency_table)), function(i) {
    c(as.character(rownames(contingency_table)[i]), as.character(contingency_table[i, ]))
  })
  blocks[[length(blocks) + 1]] <- make_table_block(table_headers, table_rows, title = "列联表")

  # Block 2: test results (key/value card).
  kv_items <- list(
    "方法" = method_used,
    "P 值" = output_results$p_value_fmt
  )
  if (!is.null(output_results$odds_ratio)) {
    kv_items[["比值比 (OR)"]] <- as.character(round(as.numeric(output_results$odds_ratio), 4))
  }
  if (!is.null(output_results$conf_int)) {
    kv_items[["95% 置信区间"]] <- sprintf("[%.4f, %.4f]", output_results$conf_int[1], output_results$conf_int[2])
  }
  if (low_expected_count > 0) {
    kv_items[["期望频数 < 5 的格子"]] <- glue("{low_expected_count}/{total_cells}")
  }
  blocks[[length(blocks) + 1]] <- make_kv_block(kv_items, title = "检验结果")

  # Block 3: chart.
  if (!is.null(plot_base64)) {
    blocks[[length(blocks) + 1]] <- make_image_block(plot_base64, title = "堆叠条形图",
                                                     alt = paste("堆叠条形图:", var1, "与", var2, "的关联"))
  }

  # Block 4: conclusion summary.
  p_val <- as.numeric(output_results$p_value)
  conclusion <- if (p_val < 0.05) {
    glue("Fisher 精确检验显示,{var1} 与 {var2} 之间存在显著关联(P {output_results$p_value_fmt})。")
  } else {
    glue("Fisher 精确检验显示,未发现 {var1} 与 {var2} 之间的显著关联(P {output_results$p_value_fmt})。")
  }
  if (!is.null(output_results$odds_ratio)) {
    conclusion <- paste0(conclusion, glue(" 比值比 OR = {round(as.numeric(output_results$odds_ratio), 3)}。"))
  }
  blocks[[length(blocks) + 1]] <- make_markdown_block(conclusion, title = "结论摘要")

  # ===== Return payload =====
  log_add("分析完成")

  return(list(
    status = "success",
    message = "分析完成",
    warnings = if (length(warnings_list) > 0) warnings_list else NULL,
    results = output_results,
    report_blocks = blocks,
    plots = if (!is.null(plot_base64)) list(plot_base64) else list(),
    trace_log = logs,
    reproducible_code = as.character(reproducible_code)
  ))
}
|
||||
|
||||
# Helper: 100% stacked bar chart visualizing the association between two
# categorical variables; renders to a temp PNG and returns a base64 data URI.
generate_stacked_bar <- function(contingency_table, var1, var2) {
  plot_data <- as.data.frame(contingency_table)
  names(plot_data) <- c("Var1", "Var2", "Freq")

  chart <- ggplot(plot_data, aes(x = Var1, y = Freq, fill = Var2)) +
    geom_bar(stat = "identity", position = "fill") +
    scale_y_continuous(labels = scales::percent) +
    theme_minimal() +
    labs(
      title = paste("Association:", var1, "vs", var2),
      x = var1,
      y = "Proportion",
      fill = var2
    ) +
    scale_fill_brewer(palette = "Set2")

  png_path <- tempfile(fileext = ".png")
  ggsave(png_path, chart, width = 7, height = 5, dpi = 100)
  encoded <- base64encode(png_path)
  unlink(png_path)

  paste0("data:image/png;base64,", encoded)
}
|
||||
377
r-statistics-service/tools/linear_reg.R
Normal file
377
r-statistics-service/tools/linear_reg.R
Normal file
@@ -0,0 +1,377 @@
|
||||
#' @tool_code ST_LINEAR_REG
|
||||
#' @name 线性回归
|
||||
#' @version 1.0.0
|
||||
#' @description 连续型结局变量的多因素线性回归分析
|
||||
#' @author SSA-Pro Team
|
||||
|
||||
library(glue)
|
||||
library(ggplot2)
|
||||
library(base64enc)
|
||||
|
||||
#' Multiple linear regression tool (ST_LINEAR_REG).
#'
#' Fits outcome ~ predictors + confounders (OLS), reporting coefficients with
#' 95% CIs, model fit (R², F test), residual normality, VIF collinearity
#' checks and diagnostic plots. Returns the standard tool payload.
run_analysis <- function(input) {
  # ===== Initialization =====
  logs <- c()
  log_add <- function(msg) { logs <<- c(logs, paste0("[", Sys.time(), "] ", msg)) }

  on.exit({}, add = TRUE)

  # ===== Data loading =====
  log_add("开始加载输入数据")
  df <- tryCatch(
    load_input_data(input),
    error = function(e) {
      log_add(paste("数据加载失败:", e$message))
      return(NULL)
    }
  )

  if (is.null(df)) {
    return(make_error(ERROR_CODES$E100_INTERNAL_ERROR, details = "数据加载失败"))
  }
  log_add(glue("数据加载成功: {nrow(df)} 行, {ncol(df)} 列"))

  p <- input$params
  outcome_var <- p$outcome_var
  predictors <- p$predictors
  confounders <- p$confounders

  # ===== Parameter validation =====
  if (!(outcome_var %in% names(df))) {
    return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND, col = outcome_var))
  }

  all_vars <- c(predictors, confounders)
  # NOTE(review): is.null() on a whole vector is always FALSE, so this filter
  # effectively only removes empty strings — confirm NULL entries cannot
  # survive the c() above (c() drops NULLs, so this is likely fine).
  all_vars <- all_vars[!is.null(all_vars) & all_vars != ""]

  for (v in all_vars) {
    if (!(v %in% names(df))) {
      return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND, col = v))
    }
  }

  if (length(predictors) == 0) {
    return(make_error(ERROR_CODES$E100_INTERNAL_ERROR, details = "至少需要一个预测变量"))
  }

  # ===== Data cleaning =====
  # Listwise deletion: drop rows missing the outcome or any predictor.
  original_rows <- nrow(df)
  vars_to_check <- c(outcome_var, all_vars)
  for (v in vars_to_check) {
    df <- df[!is.na(df[[v]]), ]
  }

  # Coerce the outcome to numeric, dropping rows that fail conversion.
  if (!is.numeric(df[[outcome_var]])) {
    df[[outcome_var]] <- as.numeric(as.character(df[[outcome_var]]))
    df <- df[!is.na(df[[outcome_var]]), ]
  }

  removed_rows <- original_rows - nrow(df)
  if (removed_rows > 0) {
    log_add(glue("数据清洗: 移除 {removed_rows} 行缺失值 (剩余 {nrow(df)} 行)"))
  }

  n_predictors <- length(all_vars)

  # ===== Guardrail checks =====
  guardrail_results <- list()
  warnings_list <- c()

  # Require roughly 10 observations beyond the number of predictors.
  sample_check <- check_sample_size(nrow(df), min_required = n_predictors + 10, action = ACTION_BLOCK)
  guardrail_results <- c(guardrail_results, list(sample_check))
  log_add(glue("样本量: N = {nrow(df)}, 预测变量数 = {n_predictors}, {sample_check$reason}"))

  guardrail_status <- run_guardrail_chain(guardrail_results)

  if (guardrail_status$status == "blocked") {
    return(list(status = "blocked", message = guardrail_status$reason, trace_log = logs))
  }

  # ===== Build model formula =====
  formula_str <- paste(outcome_var, "~", paste(all_vars, collapse = " + "))
  formula_obj <- as.formula(formula_str)
  log_add(glue("模型公式: {formula_str}"))

  # ===== Core computation =====
  log_add("拟合线性回归模型")

  # NOTE(review): a tryCatch warning handler runs after the stack has
  # unwound, so invokeRestart("muffleWarning") below cannot find the restart
  # and the fitted model is lost whenever lm() warns. The baseline-table tool
  # uses withCallingHandlers for this — consider the same pattern here.
  model <- tryCatch({
    lm(formula_obj, data = df)
  }, error = function(e) {
    log_add(paste("模型拟合失败:", e$message))
    return(NULL)
  }, warning = function(w) {
    warnings_list <<- c(warnings_list, w$message)
    log_add(paste("模型警告:", w$message))
    invokeRestart("muffleWarning")
  })

  if (is.null(model)) {
    return(map_r_error("线性回归模型拟合失败"))
  }

  model_summary <- summary(model)

  # ===== Extract model results =====
  coef_summary <- model_summary$coefficients
  coef_table <- data.frame(
    variable = rownames(coef_summary),
    estimate = coef_summary[, "Estimate"],
    std_error = coef_summary[, "Std. Error"],
    t_value = coef_summary[, "t value"],
    p_value = coef_summary[, "Pr(>|t|)"],
    stringsAsFactors = FALSE
  )

  # 95% confidence intervals for each coefficient.
  ci <- confint(model)
  coef_table$ci_lower <- ci[, 1]
  coef_table$ci_upper <- ci[, 2]

  coefficients_list <- lapply(1:nrow(coef_table), function(i) {
    row <- coef_table[i, ]
    list(
      variable = row$variable,
      estimate = round(row$estimate, 4),
      std_error = round(row$std_error, 4),
      t_value = round(row$t_value, 3),
      ci_lower = round(row$ci_lower, 4),
      ci_upper = round(row$ci_upper, 4),
      p_value = round(row$p_value, 4),
      p_value_fmt = format_p_value(row$p_value),
      significant = row$p_value < 0.05
    )
  })

  # ===== Model fit =====
  r_squared <- model_summary$r.squared
  adj_r_squared <- model_summary$adj.r.squared
  f_stat <- model_summary$fstatistic
  # fstatistic is NULL for an intercept-only fit; guard the F-test P value.
  f_p_value <- if (!is.null(f_stat)) {
    pf(f_stat[1], f_stat[2], f_stat[3], lower.tail = FALSE)
  } else {
    NA
  }

  if (!is.null(f_stat)) {
    log_add(glue("R² = {round(r_squared, 4)}, Adj R² = {round(adj_r_squared, 4)}, F = {round(f_stat[1], 2)}, P = {round(f_p_value, 4)}"))
  } else {
    log_add(glue("R² = {round(r_squared, 4)}, Adj R² = {round(adj_r_squared, 4)}, F = NA"))
  }

  # ===== Residual diagnostics =====
  residuals_vals <- residuals(model)
  # NOTE(review): fitted_vals is unused below — candidate for removal.
  fitted_vals <- fitted(model)

  # Residual normality; shapiro.test() only supports 3..5000 observations.
  normality_p <- NA
  if (length(residuals_vals) >= 3 && length(residuals_vals) <= 5000) {
    normality_test <- shapiro.test(residuals_vals)
    normality_p <- normality_test$p.value
    if (normality_p < 0.05) {
      warnings_list <- c(warnings_list, glue("残差不满足正态性 (Shapiro-Wilk p = {round(normality_p, 4)})"))
    }
  }

  # ===== Collinearity detection (VIF) =====
  vif_results <- NULL
  if (length(all_vars) > 1) {
    tryCatch({
      if (requireNamespace("car", quietly = TRUE)) {
        vif_values <- car::vif(model)
        # car::vif returns a GVIF matrix when factor terms are present.
        if (is.matrix(vif_values)) {
          vif_values <- vif_values[, "GVIF"]
        }
        vif_results <- lapply(names(vif_values), function(v) {
          list(variable = v, vif = round(vif_values[v], 2))
        })

        # Conventional collinearity warning threshold.
        high_vif <- names(vif_values)[vif_values > 5]
        if (length(high_vif) > 0) {
          warnings_list <- c(warnings_list, paste("VIF > 5 的变量:", paste(high_vif, collapse = ", ")))
        }
      }
    }, error = function(e) {
      log_add(paste("VIF 计算失败:", e$message))
    })
  }

  # ===== Chart generation =====
  log_add("生成诊断图")
  plot_base64 <- tryCatch({
    generate_regression_plots(model, outcome_var)
  }, error = function(e) {
    log_add(paste("图表生成失败:", e$message))
    NULL  # plots are optional; analysis still succeeds without them
  })

  # ===== Reproducible code generation =====
  original_filename <- if (!is.null(input$original_filename) && nchar(input$original_filename) > 0) {
    input$original_filename
  } else {
    "data.csv"
  }

  reproducible_code <- glue('
# SSA-Pro 自动生成代码
# 工具: 线性回归
# 时间: {Sys.time()}
# ================================

# 数据准备
df <- read.csv("{original_filename}")

# 线性回归
model <- lm({formula_str}, data = df)
summary(model)

# 置信区间
confint(model)

# 残差诊断
par(mfrow = c(2, 2))
plot(model)

# VIF(需要 car 包)
# library(car)
# vif(model)
')

  # ===== Build report_blocks =====
  blocks <- list()

  # Block 1: model overview (key/value card).
  kv_model <- list(
    "模型公式" = formula_str,
    "观测数" = as.character(nrow(df)),
    "预测变量数" = as.character(n_predictors),
    "R²" = as.character(round(r_squared, 4)),
    "调整 R²" = as.character(round(adj_r_squared, 4))
  )
  if (!is.null(f_stat)) {
    kv_model[["F 统计量"]] <- as.character(round(f_stat[1], 2))
    kv_model[["模型 P 值"]] <- format_p_value(f_p_value)
  }
  if (!is.na(normality_p)) {
    kv_model[["残差正态性 (Shapiro P)"]] <- format_p_value(normality_p)
  }
  blocks[[length(blocks) + 1]] <- make_kv_block(kv_model, title = "模型概况")

  # Block 2: coefficient table.
  coef_headers <- c("变量", "系数 (B)", "标准误", "t 值", "95% CI", "P 值", "显著性")
  coef_rows <- lapply(coefficients_list, function(row) {
    ci_str <- sprintf("[%.4f, %.4f]", row$ci_lower, row$ci_upper)
    sig <- if (row$significant) "*" else ""
    c(row$variable, as.character(row$estimate), as.character(row$std_error),
      as.character(row$t_value), ci_str, row$p_value_fmt, sig)
  })
  blocks[[length(blocks) + 1]] <- make_table_block(coef_headers, coef_rows,
                                                   title = "回归系数表", footnote = "* P < 0.05")

  # Block 3: VIF table (only when VIF was computable).
  if (!is.null(vif_results) && length(vif_results) > 0) {
    vif_headers <- c("变量", "VIF")
    vif_rows <- lapply(vif_results, function(row) c(row$variable, as.character(row$vif)))
    blocks[[length(blocks) + 1]] <- make_table_block(vif_headers, vif_rows, title = "方差膨胀因子 (VIF)")
  }

  # Block 4: diagnostic plots.
  if (!is.null(plot_base64)) {
    blocks[[length(blocks) + 1]] <- make_image_block(plot_base64,
                                                     title = "回归诊断图", alt = "残差 vs 拟合值 + Q-Q 图")
  }

  # Block 5: conclusion summary.
  # Collect significant non-intercept predictors for the narrative.
  sig_vars <- sapply(coefficients_list, function(r) {
    if (r$variable != "(Intercept)" && r$significant) r$variable else NULL
  })
  sig_vars <- unlist(sig_vars[!sapply(sig_vars, is.null)])

  model_sig <- if (!is.na(f_p_value) && f_p_value < 0.05) "整体具有统计学意义" else "整体不具有统计学意义"
  f_display <- if (!is.null(f_stat)) round(f_stat[1], 2) else "NA"
  p_display <- if (!is.na(f_p_value)) format_p_value(f_p_value) else "NA"
  conclusion <- glue("线性回归模型{model_sig}(F = {f_display},P {p_display})。模型解释了因变量 {round(r_squared * 100, 1)}% 的变异(R² = {round(r_squared, 4)},调整 R² = {round(adj_r_squared, 4)})。")

  if (length(sig_vars) > 0) {
    conclusion <- paste0(conclusion, glue("\n\n在 α = 0.05 水平下,以下预测变量具有统计学意义:**{paste(sig_vars, collapse = '**、**')}**。"))
  } else {
    conclusion <- paste0(conclusion, "\n\n在 α = 0.05 水平下,无预测变量达到统计学意义。")
  }

  if (length(warnings_list) > 0) {
    conclusion <- paste0(conclusion, "\n\n⚠️ 注意:", paste(warnings_list, collapse = ";"), "。")
  }

  blocks[[length(blocks) + 1]] <- make_markdown_block(conclusion, title = "结论摘要")

  # ===== Return payload =====
  log_add("分析完成")

  return(list(
    status = "success",
    message = "分析完成",
    warnings = if (length(warnings_list) > 0) warnings_list else NULL,
    results = list(
      method = "Multiple Linear Regression (OLS)",
      formula = formula_str,
      n_observations = nrow(df),
      n_predictors = n_predictors,
      coefficients = coefficients_list,
      model_fit = list(
        r_squared = jsonlite::unbox(round(r_squared, 4)),
        adj_r_squared = jsonlite::unbox(round(adj_r_squared, 4)),
        f_statistic = if (!is.null(f_stat)) jsonlite::unbox(round(f_stat[1], 2)) else NULL,
        f_df = if (!is.null(f_stat)) as.numeric(f_stat[2:3]) else NULL,
        f_p_value = if (!is.na(f_p_value)) jsonlite::unbox(round(f_p_value, 4)) else NULL,
        f_p_value_fmt = if (!is.na(f_p_value)) format_p_value(f_p_value) else NULL
      ),
      diagnostics = list(
        residual_normality_p = if (!is.na(normality_p)) jsonlite::unbox(round(normality_p, 4)) else NULL
      ),
      vif = vif_results
    ),
    report_blocks = blocks,
    plots = if (!is.null(plot_base64)) list(plot_base64) else list(),
    trace_log = logs,
    reproducible_code = as.character(reproducible_code)
  ))
}
|
||||
|
||||
# Helper: regression diagnostics figure — residuals-vs-fitted next to a
# normal Q-Q plot when gridExtra is available, residual plot alone otherwise.
# Renders to a temp PNG and returns it as a base64 data URI.
generate_regression_plots <- function(model, outcome_var) {
  diagnostics <- data.frame(
    fitted = fitted(model),
    residuals = residuals(model),
    std_residuals = rstandard(model)
  )

  # Residuals vs fitted, with a loess trend line to expose non-linearity.
  resid_plot <- ggplot(diagnostics, aes(x = fitted, y = residuals)) +
    geom_point(alpha = 0.5, color = "#3b82f6") +
    geom_hline(yintercept = 0, linetype = "dashed", color = "red") +
    geom_smooth(method = "loess", se = FALSE, color = "orange", linewidth = 0.8) +
    theme_minimal() +
    labs(title = "Residuals vs Fitted", x = "Fitted values", y = "Residuals")

  # Normal Q-Q plot of standardized residuals.
  qq_plot <- ggplot(diagnostics, aes(sample = std_residuals)) +
    stat_qq(alpha = 0.5, color = "#3b82f6") +
    stat_qq_line(color = "red", linetype = "dashed") +
    theme_minimal() +
    labs(title = "Normal Q-Q Plot", x = "Theoretical Quantiles", y = "Standardized Residuals")

  png_path <- tempfile(fileext = ".png")
  if (requireNamespace("gridExtra", quietly = TRUE)) {
    ggsave(png_path, gridExtra::arrangeGrob(resid_plot, qq_plot, ncol = 2), width = 12, height = 5, dpi = 100)
  } else {
    ggsave(png_path, resid_plot, width = 7, height = 5, dpi = 100)
  }

  encoded <- base64encode(png_path)
  unlink(png_path)

  paste0("data:image/png;base64,", encoded)
}
|
||||
## New file: r-statistics-service/tools/wilcoxon.R (286 lines)
@@ -0,0 +1,286 @@
|
||||
#' @tool_code ST_WILCOXON
|
||||
#' @name Wilcoxon 符号秩检验
|
||||
#' @version 1.0.0
|
||||
#' @description 配对样本的非参数检验(配对 T 检验的替代方法)
|
||||
#' @author SSA-Pro Team
|
||||
|
||||
library(glue)
|
||||
library(ggplot2)
|
||||
library(base64enc)
|
||||
|
||||
#' Run a paired Wilcoxon signed-rank test on two numeric columns.
#'
#' Reads `input$params$before_var` / `input$params$after_var`, validates and
#' cleans the data (pairwise-complete rows, numeric coercion), runs guardrail
#' checks, computes the test with effect size r and descriptive statistics,
#' and assembles the standard tool result list: status/message, `results`,
#' `report_blocks`, `plots`, `trace_log` and `reproducible_code`.
#' Relies on project helpers defined elsewhere in this service
#' (load_input_data, make_error, ERROR_CODES, check_sample_size, ACTION_BLOCK,
#' run_guardrail_chain, map_r_error, format_p_value, make_kv_block,
#' make_image_block, make_markdown_block, generate_paired_plot).
run_analysis <- function(input) {
  # ===== Initialization =====
  # Timestamped trace log; log_add mutates `logs` in the enclosing scope via `<<-`.
  logs <- c()
  log_add <- function(msg) { logs <<- c(logs, paste0("[", Sys.time(), "] ", msg)) }

  # NOTE(review): no-op exit handler — presumably a placeholder for future
  # cleanup; confirm whether it can be removed.
  on.exit({}, add = TRUE)

  # ===== Data loading =====
  log_add("开始加载输入数据")
  df <- tryCatch(
    load_input_data(input),
    error = function(e) {
      log_add(paste("数据加载失败:", e$message))
      return(NULL)
    }
  )

  if (is.null(df)) {
    return(make_error(ERROR_CODES$E100_INTERNAL_ERROR, details = "数据加载失败"))
  }
  log_add(glue("数据加载成功: {nrow(df)} 行, {ncol(df)} 列"))

  p <- input$params
  before_var <- p$before_var
  after_var <- p$after_var

  # ===== Parameter validation =====
  # Both paired columns must exist in the loaded data frame.
  if (!(before_var %in% names(df))) {
    return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND, col = before_var))
  }
  if (!(after_var %in% names(df))) {
    return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND, col = after_var))
  }

  # ===== Data cleaning =====
  # Keep only rows where both members of the pair are present, then coerce
  # non-numeric columns (rows failing coercion are dropped as well).
  original_rows <- nrow(df)
  df <- df[!is.na(df[[before_var]]) & !is.na(df[[after_var]]), ]

  # Ensure numeric type
  if (!is.numeric(df[[before_var]])) {
    df[[before_var]] <- as.numeric(as.character(df[[before_var]]))
    df <- df[!is.na(df[[before_var]]), ]
  }
  if (!is.numeric(df[[after_var]])) {
    df[[after_var]] <- as.numeric(as.character(df[[after_var]]))
    df <- df[!is.na(df[[after_var]]), ]
  }

  removed_rows <- original_rows - nrow(df)
  if (removed_rows > 0) {
    log_add(glue("数据清洗: 移除 {removed_rows} 行缺失值 (剩余 {nrow(df)} 行)"))
  }

  # ===== Guardrail checks =====
  guardrail_results <- list()
  warnings_list <- c()

  # Require at least 5 complete pairs; ACTION_BLOCK makes failure abort below.
  sample_check <- check_sample_size(nrow(df), min_required = 5, action = ACTION_BLOCK)
  guardrail_results <- c(guardrail_results, list(sample_check))
  log_add(glue("配对样本量: N = {nrow(df)}, {sample_check$reason}"))

  guardrail_status <- run_guardrail_chain(guardrail_results)

  if (guardrail_status$status == "blocked") {
    return(list(status = "blocked", message = guardrail_status$reason, trace_log = logs))
  }

  # ===== Compute pairwise differences (after - before; positive = increase) =====
  diff_values <- df[[after_var]] - df[[before_var]]

  # Reject (near-)constant differences — the test is undefined when all
  # pairs are tied (tolerance comparison avoids floating-point issues).
  if (isTRUE(sd(diff_values) < .Machine$double.eps^0.5)) {
    return(make_error(ERROR_CODES$E007_VARIANCE_ZERO, col = paste(after_var, "-", before_var)))
  }

  # ===== Core computation =====
  log_add("执行 Wilcoxon 符号秩检验")

  result <- tryCatch({
    wilcox.test(df[[before_var]], df[[after_var]], paired = TRUE, conf.int = TRUE)
  }, error = function(e) {
    log_add(paste("Wilcoxon 检验失败:", e$message))
    return(NULL)
  })

  if (is.null(result)) {
    return(map_r_error("Wilcoxon 符号秩检验计算失败"))
  }

  method_used <- result$method
  log_add(glue("V = {result$statistic}, P = {round(result$p.value, 4)}"))

  # ===== Effect size: r = Z / sqrt(N) =====
  # Z is back-computed from the two-sided p-value via the normal quantile.
  # NOTE(review): this is an approximation (degenerate when p = 1; ignores
  # continuity/tie corrections) — confirm acceptable for reporting.
  n_pairs <- nrow(df)
  z_approx <- qnorm(result$p.value / 2)
  r_effect <- abs(z_approx) / sqrt(n_pairs)

  r_interpretation <- if (r_effect < 0.1) "微小" else if (r_effect < 0.3) "小" else if (r_effect < 0.5) "中等" else "大"

  # ===== Descriptive statistics =====
  before_vals <- df[[before_var]]
  after_vals <- df[[after_var]]

  desc_stats <- list(
    before = list(
      variable = before_var,
      n = length(before_vals),
      mean = round(mean(before_vals), 3),
      sd = round(sd(before_vals), 3),
      median = round(median(before_vals), 3),
      q1 = round(quantile(before_vals, 0.25), 3),
      q3 = round(quantile(before_vals, 0.75), 3)
    ),
    after = list(
      variable = after_var,
      n = length(after_vals),
      mean = round(mean(after_vals), 3),
      sd = round(sd(after_vals), 3),
      median = round(median(after_vals), 3),
      q1 = round(quantile(after_vals, 0.25), 3),
      q3 = round(quantile(after_vals, 0.75), 3)
    ),
    difference = list(
      mean = round(mean(diff_values), 3),
      sd = round(sd(diff_values), 3),
      median = round(median(diff_values), 3)
    )
  )

  # NOTE(review): wilcox.test(before, after) estimates the pseudomedian of
  # before - after, i.e. opposite in sign to `difference` above (after -
  # before) — confirm the intended sign convention before surfacing both.
  output_results <- list(
    method = method_used,
    statistic_V = jsonlite::unbox(as.numeric(result$statistic)),
    p_value = jsonlite::unbox(as.numeric(result$p.value)),
    p_value_fmt = format_p_value(result$p.value),
    pseudomedian = if (!is.null(result$estimate)) jsonlite::unbox(round(as.numeric(result$estimate), 4)) else NULL,
    conf_int = if (!is.null(result$conf.int)) round(as.numeric(result$conf.int), 4) else NULL,
    effect_size = list(
      r = jsonlite::unbox(round(r_effect, 4)),
      interpretation = r_interpretation
    ),
    descriptive = desc_stats
  )

  # ===== Generate plot (best effort — failure only logs, never aborts) =====
  log_add("生成配对变化图")
  plot_base64 <- tryCatch({
    generate_paired_plot(df, before_var, after_var, diff_values)
  }, error = function(e) {
    log_add(paste("图表生成失败:", e$message))
    NULL
  })

  # ===== Generate reproducible code =====
  original_filename <- if (!is.null(input$original_filename) && nchar(input$original_filename) > 0) {
    input$original_filename
  } else {
    "data.csv"
  }

  reproducible_code <- glue('
# SSA-Pro 自动生成代码
# 工具: Wilcoxon 符号秩检验
# 时间: {Sys.time()}
# ================================

library(ggplot2)

# 数据准备
df <- read.csv("{original_filename}")
before_var <- "{before_var}"
after_var <- "{after_var}"

# 数据清洗
df <- df[!is.na(df[[before_var]]) & !is.na(df[[after_var]]), ]

# Wilcoxon 符号秩检验
result <- wilcox.test(df[[before_var]], df[[after_var]], paired = TRUE, conf.int = TRUE)
print(result)

# 描述统计
cat("Before: median =", median(df[[before_var]]), "\\n")
cat("After: median =", median(df[[after_var]]), "\\n")
cat("Diff: median =", median(df[[after_var]] - df[[before_var]]), "\\n")
')

  # ===== Build report_blocks =====
  blocks <- list()

  # Block 1: descriptive overview
  desc_kv <- list()
  desc_kv[["配对样本量"]] <- as.character(n_pairs)
  desc_kv[[paste0(before_var, " Median [Q1, Q3]")]] <- as.character(glue("{desc_stats$before$median} [{desc_stats$before$q1}, {desc_stats$before$q3}]"))
  desc_kv[[paste0(after_var, " Median [Q1, Q3]")]] <- as.character(glue("{desc_stats$after$median} [{desc_stats$after$q1}, {desc_stats$after$q3}]"))
  desc_kv[["差值 Median"]] <- as.character(desc_stats$difference$median)
  blocks[[length(blocks) + 1]] <- make_kv_block(desc_kv, title = "样本概况")

  # Block 2: test results
  kv_result <- list(
    "方法" = method_used,
    "V 统计量" = as.character(round(as.numeric(result$statistic), 1)),
    "P 值" = output_results$p_value_fmt,
    "效应量 r" = as.character(output_results$effect_size$r),
    "效应量解释" = r_interpretation
  )
  if (!is.null(output_results$pseudomedian)) {
    kv_result[["伪中位数"]] <- as.character(output_results$pseudomedian)
  }
  if (!is.null(output_results$conf_int)) {
    kv_result[["95% 置信区间"]] <- sprintf("[%.4f, %.4f]", output_results$conf_int[1], output_results$conf_int[2])
  }
  blocks[[length(blocks) + 1]] <- make_kv_block(kv_result, title = "Wilcoxon 符号秩检验结果")

  # Block 3: plot (only when generation succeeded)
  if (!is.null(plot_base64)) {
    blocks[[length(blocks) + 1]] <- make_image_block(plot_base64,
      title = paste("配对变化:", before_var, "→", after_var),
      alt = "配对样本前后变化图")
  }

  # Block 4: conclusion summary
  # NOTE(review): `direction` is derived from the MEAN difference while the
  # sentence reports the MEDIAN — these can disagree for skewed data; confirm.
  sig_text <- if (result$p.value < 0.05) "差异具有统计学意义" else "差异无统计学意义"
  direction <- if (mean(diff_values) > 0) "升高" else "降低"
  conclusion <- glue(
    "Wilcoxon 符号秩检验结果:V = {round(as.numeric(result$statistic), 1)},P {output_results$p_value_fmt}。",
    "配对样本从 **{before_var}** 到 **{after_var}** 的变化{sig_text}",
    "(中位数{direction} {abs(desc_stats$difference$median)},效应量 r = {output_results$effect_size$r},{r_interpretation}效应)。"
  )
  blocks[[length(blocks) + 1]] <- make_markdown_block(conclusion, title = "结论摘要")

  # ===== Return result =====
  log_add("分析完成")

  return(list(
    status = "success",
    message = "分析完成",
    warnings = if (length(warnings_list) > 0) warnings_list else NULL,
    results = output_results,
    report_blocks = blocks,
    plots = if (!is.null(plot_base64)) list(plot_base64) else list(),
    trace_log = logs,
    reproducible_code = as.character(reproducible_code)
  ))
}
|
||||
|
||||
# Helper: paired-change plot (per-subject Before→After lines with median overlay).
#
# @param df          Data frame containing the two paired columns.
# @param before_var  Column name of the "before" measurement.
# @param after_var   Column name of the "after" measurement.
# @param diff_values Vector of after-before differences; accepted for
#                    interface compatibility with the caller but not used
#                    in the plot itself.
# @return A data-URI string ("data:image/png;base64,...") of the rendered PNG.
generate_paired_plot <- function(df, before_var, after_var, diff_values) {
  # Long format: one row per subject per time point, linked by `id` so each
  # subject's pair is drawn as a connecting line.
  n <- nrow(df)
  plot_df <- data.frame(
    id = rep(1:n, 2),
    time = rep(c("Before", "After"), each = n),
    value = c(df[[before_var]], df[[after_var]])
  )
  # Fix factor order so "Before" is plotted on the left.
  plot_df$time <- factor(plot_df$time, levels = c("Before", "After"))

  p <- ggplot(plot_df, aes(x = time, y = value)) +
    geom_line(aes(group = id), alpha = 0.3, color = "gray60") +
    geom_point(aes(color = time), size = 2, alpha = 0.6) +
    # Red diamonds and line mark the group medians.
    stat_summary(fun = median, geom = "point", shape = 18, size = 5, color = "red") +
    stat_summary(fun = median, geom = "line", aes(group = 1), color = "red", linewidth = 1.2) +
    theme_minimal() +
    labs(
      title = paste("Paired Change:", before_var, "→", after_var),
      x = "",
      y = "Value"
    ) +
    scale_color_manual(values = c("Before" = "#3b82f6", "After" = "#ef4444")) +
    theme(legend.position = "none")

  # Guarantee the temp file is removed even if ggsave() or base64encode()
  # throws — the original only reached unlink() on the success path and
  # leaked the file on error.
  tmp_file <- tempfile(fileext = ".png")
  on.exit(unlink(tmp_file), add = TRUE)
  ggsave(tmp_file, p, width = 6, height = 5, dpi = 100)
  base64_str <- base64encode(tmp_file)

  return(paste0("data:image/png;base64,", base64_str))
}
|
||||
@@ -21,7 +21,7 @@ make_markdown_block <- function(content, title = NULL) {
|
||||
#' @param title 可选表格标题
|
||||
#' @param footnote 可选脚注(如方法说明)
|
||||
#' @return block list
|
||||
make_table_block <- function(headers, rows, title = NULL, footnote = NULL) {
|
||||
make_table_block <- function(headers, rows, title = NULL, footnote = NULL, metadata = NULL) {
|
||||
block <- list(
|
||||
type = "table",
|
||||
headers = as.list(headers),
|
||||
@@ -29,6 +29,7 @@ make_table_block <- function(headers, rows, title = NULL, footnote = NULL) {
|
||||
)
|
||||
if (!is.null(title)) block$title <- title
|
||||
if (!is.null(footnote)) block$footnote <- footnote
|
||||
if (!is.null(metadata)) block$metadata <- metadata
|
||||
block
|
||||
}
|
||||
|
||||
|
||||
@@ -228,6 +228,98 @@ run_jit_guardrails <- function(df, tool_code, params) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} else if (tool_code == "ST_ANOVA_ONE") {
|
||||
group_var <- params$group_var
|
||||
value_var <- params$value_var
|
||||
|
||||
if (!is.null(group_var) && !is.null(value_var)) {
|
||||
groups <- unique(df[[group_var]])
|
||||
|
||||
for (g in groups) {
|
||||
vals <- df[df[[group_var]] == g, value_var]
|
||||
vals <- vals[!is.na(vals)]
|
||||
if (length(vals) >= 3) {
|
||||
norm_result <- check_normality(vals, alpha = 0.05)
|
||||
checks <- c(checks, list(list(
|
||||
check_name = glue("正态性检验 (组: {g})"),
|
||||
passed = norm_result$passed,
|
||||
p_value = norm_result$p_value,
|
||||
recommendation = if (norm_result$passed) "满足正态性" else "建议使用 Kruskal-Wallis"
|
||||
)))
|
||||
|
||||
if (!norm_result$passed) {
|
||||
suggested_tool <- "Kruskal-Wallis (内置于 ST_ANOVA_ONE)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tryCatch({
|
||||
homo_result <- check_homogeneity(df, group_var, value_var, alpha = 0.05)
|
||||
checks <- c(checks, list(list(
|
||||
check_name = "方差齐性检验 (Levene)",
|
||||
passed = homo_result$passed,
|
||||
p_value = homo_result$p_value,
|
||||
recommendation = if (homo_result$passed) "方差齐性满足" else "建议使用 Welch ANOVA"
|
||||
)))
|
||||
}, error = function(e) {
|
||||
message("方差齐性检验失败: ", e$message)
|
||||
})
|
||||
}
|
||||
|
||||
} else if (tool_code == "ST_WILCOXON") {
|
||||
before_var <- params$before_var
|
||||
after_var <- params$after_var
|
||||
|
||||
if (!is.null(before_var) && !is.null(after_var)) {
|
||||
diff_vals <- df[[after_var]] - df[[before_var]]
|
||||
diff_vals <- diff_vals[!is.na(diff_vals)]
|
||||
|
||||
checks <- c(checks, list(list(
|
||||
check_name = "配对样本量检查",
|
||||
passed = length(diff_vals) >= 5,
|
||||
recommendation = if (length(diff_vals) >= 5) "样本量充足" else "配对样本量不足"
|
||||
)))
|
||||
}
|
||||
|
||||
} else if (tool_code %in% c("ST_FISHER", "ST_CHI_SQUARE")) {
|
||||
var1 <- params$var1
|
||||
var2 <- params$var2
|
||||
|
||||
if (!is.null(var1) && !is.null(var2)) {
|
||||
ct <- table(df[[var1]], df[[var2]])
|
||||
expected <- tryCatch(chisq.test(ct)$expected, error = function(e) NULL)
|
||||
|
||||
if (!is.null(expected)) {
|
||||
low_pct <- sum(expected < 5) / length(expected)
|
||||
checks <- c(checks, list(list(
|
||||
check_name = "期望频数检查",
|
||||
passed = low_pct <= 0.2,
|
||||
recommendation = if (low_pct <= 0.2) "期望频数满足卡方检验条件" else "建议使用 Fisher 精确检验"
|
||||
)))
|
||||
|
||||
if (low_pct > 0.2 && tool_code == "ST_CHI_SQUARE") {
|
||||
suggested_tool <- "ST_FISHER"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} else if (tool_code == "ST_LINEAR_REG") {
|
||||
outcome_var <- params$outcome_var
|
||||
predictors <- params$predictors
|
||||
|
||||
if (!is.null(outcome_var)) {
|
||||
vals <- df[[outcome_var]][!is.na(df[[outcome_var]])]
|
||||
if (length(vals) >= 3) {
|
||||
norm_result <- check_normality(vals, alpha = 0.05)
|
||||
checks <- c(checks, list(list(
|
||||
check_name = glue("结局变量正态性 ({outcome_var})"),
|
||||
passed = norm_result$passed,
|
||||
p_value = norm_result$p_value,
|
||||
recommendation = if (norm_result$passed) "满足正态性" else "结局变量分布偏态,结果需谨慎解读"
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# 汇总
|
||||
|
||||
Reference in New Issue
Block a user