From 39432db10a8ed833df439c1918acefa4f8a8210f Mon Sep 17 00:00:00 2001 From: nrs <38722970+nrslib@users.noreply.github.com> Date: Mon, 9 Feb 2026 08:10:57 +0900 Subject: [PATCH] takt: override-persona-provider (#171) --- .../engine-persona-providers.test.ts | 204 ++++++++++++++++++ src/__tests__/globalConfig-defaults.test.ts | 57 +++++ src/core/models/global-config.ts | 2 + src/core/models/schemas.ts | 2 + src/core/piece/engine/OptionsBuilder.ts | 2 +- src/core/piece/types.ts | 2 + src/features/tasks/execute/pieceExecution.ts | 1 + src/features/tasks/execute/taskExecution.ts | 1 + src/features/tasks/execute/types.ts | 2 + src/infra/config/global/globalConfig.ts | 4 + 10 files changed, 276 insertions(+), 1 deletion(-) create mode 100644 src/__tests__/engine-persona-providers.test.ts diff --git a/src/__tests__/engine-persona-providers.test.ts b/src/__tests__/engine-persona-providers.test.ts new file mode 100644 index 0000000..afb371c --- /dev/null +++ b/src/__tests__/engine-persona-providers.test.ts @@ -0,0 +1,204 @@ +/** + * Tests for persona_providers config-level provider override. + * + * Verifies the provider resolution priority: + * 1. Movement YAML provider (highest) + * 2. persona_providers[personaDisplayName] + * 3. CLI/global provider (lowest) + */ + +import { describe, it, expect, beforeEach, vi } from 'vitest'; + +vi.mock('../agents/runner.js', () => ({ + runAgent: vi.fn(), +})); + +vi.mock('../core/piece/evaluation/index.js', () => ({ + detectMatchedRule: vi.fn(), +})); + +vi.mock('../core/piece/phase-runner.js', () => ({ + needsStatusJudgmentPhase: vi.fn(), + runReportPhase: vi.fn(), + runStatusJudgmentPhase: vi.fn(), +})); + +vi.mock('../shared/utils/index.js', async (importOriginal) => ({ + ...(await importOriginal>()), + generateReportDir: vi.fn().mockReturnValue('test-report-dir'), +})); + +import { PieceEngine } from '../core/piece/index.js'; +import { runAgent } from '../agents/runner.js'; +import type { PieceConfig } from '../core/models/index.js'; +import { + makeResponse, + makeRule, + makeMovement, + mockRunAgentSequence, + mockDetectMatchedRuleSequence, + applyDefaultMocks, +} from './engine-test-helpers.js'; + +describe('PieceEngine persona_providers override', () => { + beforeEach(() => { + vi.resetAllMocks(); + applyDefaultMocks(); + }); + + it('should use persona_providers when movement has no provider and persona matches', async () => { + const movement = makeMovement('implement', { + personaDisplayName: 'coder', + rules: [makeRule('done', 'COMPLETE')], + }); + const config: PieceConfig = { + name: 'persona-provider-test', + movements: [movement], + initialMovement: 'implement', + maxIterations: 1, + }; + + mockRunAgentSequence([ + makeResponse({ persona: movement.persona, content: 'done' }), + ]); + mockDetectMatchedRuleSequence([{ index: 0, method: 'phase1_tag' }]); + + const engine = new PieceEngine(config, '/tmp/project', 'test task', { + projectCwd: '/tmp/project', + provider: 'claude', + personaProviders: { coder: 'codex' }, + }); + + await engine.run(); + + const options = vi.mocked(runAgent).mock.calls[0][2]; + expect(options.provider).toBe('codex'); + }); + + it('should use global provider when persona is not in persona_providers', async () => { + const movement = makeMovement('plan', { + personaDisplayName: 'planner', + rules: [makeRule('done', 'COMPLETE')], + }); + const config: PieceConfig = { + name: 'persona-provider-nomatch', + movements: [movement], + initialMovement: 'plan', + maxIterations: 1, + }; + + mockRunAgentSequence([ + makeResponse({ persona: movement.persona, content: 'done' }), + ]); + mockDetectMatchedRuleSequence([{ index: 0, method: 'phase1_tag' }]); + + const engine = new PieceEngine(config, '/tmp/project', 'test task', { + projectCwd: '/tmp/project', + provider: 'claude', + personaProviders: { coder: 'codex' }, + }); + + await engine.run(); + + const options = vi.mocked(runAgent).mock.calls[0][2]; + expect(options.provider).toBe('claude'); + }); + + it('should prioritize movement provider over persona_providers', async () => { + const movement = makeMovement('implement', { + personaDisplayName: 'coder', + provider: 'claude', + rules: [makeRule('done', 'COMPLETE')], + }); + const config: PieceConfig = { + name: 'movement-over-persona', + movements: [movement], + initialMovement: 'implement', + maxIterations: 1, + }; + + mockRunAgentSequence([ + makeResponse({ persona: movement.persona, content: 'done' }), + ]); + mockDetectMatchedRuleSequence([{ index: 0, method: 'phase1_tag' }]); + + const engine = new PieceEngine(config, '/tmp/project', 'test task', { + projectCwd: '/tmp/project', + provider: 'mock', + personaProviders: { coder: 'codex' }, + }); + + await engine.run(); + + const options = vi.mocked(runAgent).mock.calls[0][2]; + expect(options.provider).toBe('claude'); + }); + + it('should work without persona_providers (undefined)', async () => { + const movement = makeMovement('plan', { + personaDisplayName: 'planner', + rules: [makeRule('done', 'COMPLETE')], + }); + const config: PieceConfig = { + name: 'no-persona-providers', + movements: [movement], + initialMovement: 'plan', + maxIterations: 1, + }; + + mockRunAgentSequence([ + makeResponse({ persona: movement.persona, content: 'done' }), + ]); + mockDetectMatchedRuleSequence([{ index: 0, method: 'phase1_tag' }]); + + const engine = new PieceEngine(config, '/tmp/project', 'test task', { + projectCwd: '/tmp/project', + provider: 'claude', + }); + + await engine.run(); + + const options = vi.mocked(runAgent).mock.calls[0][2]; + expect(options.provider).toBe('claude'); + }); + + it('should apply different providers to different personas in a multi-movement piece', async () => { + const planMovement = makeMovement('plan', { + personaDisplayName: 'planner', + rules: [makeRule('done', 'implement')], + }); + const implementMovement = makeMovement('implement', { + personaDisplayName: 'coder', + rules: [makeRule('done', 'COMPLETE')], + }); + const config: PieceConfig = { + name: 'multi-persona-providers', + movements: [planMovement, implementMovement], + initialMovement: 'plan', + maxIterations: 3, + }; + + mockRunAgentSequence([ + makeResponse({ persona: planMovement.persona, content: 'done' }), + makeResponse({ persona: implementMovement.persona, content: 'done' }), + ]); + mockDetectMatchedRuleSequence([ + { index: 0, method: 'phase1_tag' }, + { index: 0, method: 'phase1_tag' }, + ]); + + const engine = new PieceEngine(config, '/tmp/project', 'test task', { + projectCwd: '/tmp/project', + provider: 'claude', + personaProviders: { coder: 'codex' }, + }); + + await engine.run(); + + const calls = vi.mocked(runAgent).mock.calls; + // Plan movement: planner not in persona_providers → claude + expect(calls[0][2].provider).toBe('claude'); + // Implement movement: coder in persona_providers → codex + expect(calls[1][2].provider).toBe('codex'); + }); +}); diff --git a/src/__tests__/globalConfig-defaults.test.ts b/src/__tests__/globalConfig-defaults.test.ts index 484c419..3c62adf 100644 --- a/src/__tests__/globalConfig-defaults.test.ts +++ b/src/__tests__/globalConfig-defaults.test.ts @@ -336,6 +336,63 @@ describe('loadGlobalConfig', () => { expect(config.interactivePreviewMovements).toBe(0); }); + describe('persona_providers', () => { + it('should load persona_providers from config.yaml', () => { + const taktDir = join(testHomeDir, '.takt'); + mkdirSync(taktDir, { recursive: true }); + writeFileSync( + getGlobalConfigPath(), + [ + 'language: en', + 'persona_providers:', + ' coder: codex', + ' reviewer: claude', + ].join('\n'), + 'utf-8', + ); + + const config = loadGlobalConfig(); + + expect(config.personaProviders).toEqual({ + coder: 'codex', + reviewer: 'claude', + }); + }); + + it('should save and reload persona_providers', () => { + const taktDir = join(testHomeDir, '.takt'); + mkdirSync(taktDir, { recursive: true }); + writeFileSync(getGlobalConfigPath(), 'language: en\n', 'utf-8'); + + const config = loadGlobalConfig(); + config.personaProviders = { coder: 'codex' }; + saveGlobalConfig(config); + invalidateGlobalConfigCache(); + + const reloaded = loadGlobalConfig(); + expect(reloaded.personaProviders).toEqual({ coder: 'codex' }); + }); + + it('should have undefined personaProviders by default', () => { + const config = loadGlobalConfig(); + expect(config.personaProviders).toBeUndefined(); + }); + + it('should not save persona_providers when empty', () => { + const taktDir = join(testHomeDir, '.takt'); + mkdirSync(taktDir, { recursive: true }); + writeFileSync(getGlobalConfigPath(), 'language: en\n', 'utf-8'); + + const config = loadGlobalConfig(); + config.personaProviders = {}; + saveGlobalConfig(config); + invalidateGlobalConfigCache(); + + const reloaded = loadGlobalConfig(); + expect(reloaded.personaProviders).toBeUndefined(); + }); + }); + describe('provider/model compatibility validation', () => { it('should throw when provider is codex but model is a Claude alias (opus)', () => { const taktDir = join(testHomeDir, '.takt'); diff --git a/src/core/models/global-config.ts b/src/core/models/global-config.ts index 501d23e..26a7b43 100644 --- a/src/core/models/global-config.ts +++ b/src/core/models/global-config.ts @@ -61,6 +61,8 @@ export interface GlobalConfig { bookmarksFile?: string; /** Path to piece categories file (default: ~/.takt/preferences/piece-categories.yaml) */ pieceCategoriesFile?: string; + /** Per-persona provider overrides (e.g., { coder: 'codex' }) */ + personaProviders?: Record; /** Branch name generation strategy: 'romaji' (fast, default) or 'ai' (slow) */ branchNameStrategy?: 'romaji' | 'ai'; /** Prevent macOS idle sleep during takt execution using caffeinate (default: false) */ diff --git a/src/core/models/schemas.ts b/src/core/models/schemas.ts index b68daf8..9b0dcbe 100644 --- a/src/core/models/schemas.ts +++ b/src/core/models/schemas.ts @@ -318,6 +318,8 @@ export const GlobalConfigSchema = z.object({ bookmarks_file: z.string().optional(), /** Path to piece categories file (default: ~/.takt/preferences/piece-categories.yaml) */ piece_categories_file: z.string().optional(), + /** Per-persona provider overrides (e.g., { coder: 'codex' }) */ + persona_providers: z.record(z.string(), z.enum(['claude', 'codex', 'mock'])).optional(), /** Branch name generation strategy: 'romaji' (fast, default) or 'ai' (slow) */ branch_name_strategy: z.enum(['romaji', 'ai']).optional(), /** Prevent macOS idle sleep during takt execution using caffeinate (default: false) */ diff --git a/src/core/piece/engine/OptionsBuilder.ts b/src/core/piece/engine/OptionsBuilder.ts index ffdcb34..d5866cf 100644 --- a/src/core/piece/engine/OptionsBuilder.ts +++ b/src/core/piece/engine/OptionsBuilder.ts @@ -34,7 +34,7 @@ export class OptionsBuilder { return { cwd: this.getCwd(), personaPath: step.personaPath, - provider: step.provider ?? this.engineOptions.provider, + provider: step.provider ?? this.engineOptions.personaProviders?.[step.personaDisplayName] ?? this.engineOptions.provider, model: step.model ?? this.engineOptions.model, permissionMode: step.permissionMode, language: this.getLanguage(), diff --git a/src/core/piece/types.ts b/src/core/piece/types.ts index af132f1..f72b2b2 100644 --- a/src/core/piece/types.ts +++ b/src/core/piece/types.ts @@ -177,6 +177,8 @@ export interface PieceEngineOptions { language?: Language; provider?: ProviderType; model?: string; + /** Per-persona provider overrides (e.g., { coder: 'codex' }) */ + personaProviders?: Record; /** Enable interactive-only rules and user-input transitions */ interactive?: boolean; /** Rule tag index detector (required for rules evaluation) */ diff --git a/src/features/tasks/execute/pieceExecution.ts b/src/features/tasks/execute/pieceExecution.ts index 481d7fa..d6ada72 100644 --- a/src/features/tasks/execute/pieceExecution.ts +++ b/src/features/tasks/execute/pieceExecution.ts @@ -328,6 +328,7 @@ export async function executePiece( language: options.language, provider: options.provider, model: options.model, + personaProviders: options.personaProviders, interactive: interactiveUserInput, detectRuleIndex, callAiJudge, diff --git a/src/features/tasks/execute/taskExecution.ts b/src/features/tasks/execute/taskExecution.ts index f9dbb6e..460be26 100644 --- a/src/features/tasks/execute/taskExecution.ts +++ b/src/features/tasks/execute/taskExecution.ts @@ -77,6 +77,7 @@ export async function executeTask(options: ExecuteTaskOptions): Promise language: globalConfig.language, provider: agentOverrides?.provider, model: agentOverrides?.model, + personaProviders: globalConfig.personaProviders, interactiveUserInput, interactiveMetadata, startMovement, diff --git a/src/features/tasks/execute/types.ts b/src/features/tasks/execute/types.ts index 9081072..55cf547 100644 --- a/src/features/tasks/execute/types.ts +++ b/src/features/tasks/execute/types.ts @@ -30,6 +30,8 @@ export interface PieceExecutionOptions { language?: Language; provider?: ProviderType; model?: string; + /** Per-persona provider overrides (e.g., { coder: 'codex' }) */ + personaProviders?: Record; /** Enable interactive user input during step transitions */ interactiveUserInput?: boolean; /** Interactive mode result metadata for NDJSON logging */ diff --git a/src/infra/config/global/globalConfig.ts b/src/infra/config/global/globalConfig.ts index 0febec4..32f4280 100644 --- a/src/infra/config/global/globalConfig.ts +++ b/src/infra/config/global/globalConfig.ts @@ -106,6 +106,7 @@ export class GlobalConfigManager { minimalOutput: parsed.minimal_output, bookmarksFile: parsed.bookmarks_file, pieceCategoriesFile: parsed.piece_categories_file, + personaProviders: parsed.persona_providers, branchNameStrategy: parsed.branch_name_strategy, preventSleep: parsed.prevent_sleep, notificationSound: parsed.notification_sound, @@ -172,6 +173,9 @@ export class GlobalConfigManager { if (config.pieceCategoriesFile) { raw.piece_categories_file = config.pieceCategoriesFile; } + if (config.personaProviders && Object.keys(config.personaProviders).length > 0) { + raw.persona_providers = config.personaProviders; + } if (config.branchNameStrategy) { raw.branch_name_strategy = config.branchNameStrategy; }