From f2ca01ffe0a972cacf5450e234c14755cf4b3cda Mon Sep 17 00:00:00 2001 From: nrslib <38722970+nrslib@users.noreply.github.com> Date: Mon, 23 Feb 2026 15:28:38 +0900 Subject: [PATCH] =?UTF-8?q?refactor:=20provider/model=20resolution=20prece?= =?UTF-8?q?dence=E3=82=92=E4=B8=80=E5=85=83=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../engine-persona-providers.test.ts | 47 +++++++++++-- src/__tests__/provider-resolution.test.ts | 39 ++++++++--- src/agents/runner.ts | 68 +++++++++---------- src/core/piece/provider-resolution.ts | 43 ++++++++++-- 4 files changed, 142 insertions(+), 55 deletions(-) diff --git a/src/__tests__/engine-persona-providers.test.ts b/src/__tests__/engine-persona-providers.test.ts index fd61047..97b86e0 100644 --- a/src/__tests__/engine-persona-providers.test.ts +++ b/src/__tests__/engine-persona-providers.test.ts @@ -2,9 +2,14 @@ * Tests for persona_providers config-level provider/model override. * * Verifies movement-level provider/model resolution for stepProvider/stepModel: - * 1. Movement YAML provider (highest) - * 2. persona_providers[personaDisplayName].provider / .model - * 3. CLI provider / model (lowest) + * 1. persona_providers[personaDisplayName].provider (highest) + * 2. Movement YAML provider + * 3. CLI/global provider (lowest in movement resolution) + * + * Model resolution remains: + * 1. persona_providers[personaDisplayName].model + * 2. Movement YAML model + * 3. CLI/global model */ import { describe, it, expect, beforeEach, vi } from 'vitest'; @@ -106,7 +111,7 @@ describe('PieceEngine persona_providers override', () => { expect(options.stepProvider).toBe('claude'); }); - it('should prioritize movement provider over persona_providers', async () => { + it('should prioritize persona_providers provider over movement provider', async () => { const movement = makeMovement('implement', { personaDisplayName: 'coder', provider: 'claude', @@ -134,7 +139,7 @@ describe('PieceEngine persona_providers override', () => { const options = vi.mocked(runAgent).mock.calls[0][2]; expect(options.provider).toBe('mock'); - expect(options.stepProvider).toBe('claude'); + expect(options.stepProvider).toBe('codex'); }); it('should work without persona_providers (undefined)', async () => { @@ -269,4 +274,36 @@ describe('PieceEngine persona_providers override', () => { expect(options.stepProvider).toBe('codex'); expect(options.stepModel).toBe('global-model'); }); + + it('should prioritize persona_providers.model over movement model', async () => { + const movement = makeMovement('implement', { + personaDisplayName: 'coder', + model: 'movement-model', + rules: [makeRule('done', 'COMPLETE')], + }); + const config: PieceConfig = { + name: 'persona-model-over-movement', + movements: [movement], + initialMovement: 'implement', + maxMovements: 1, + }; + + mockRunAgentSequence([ + makeResponse({ persona: movement.persona, content: 'done' }), + ]); + mockDetectMatchedRuleSequence([{ index: 0, method: 'phase1_tag' }]); + + const engine = new PieceEngine(config, '/tmp/project', 'test task', { + projectCwd: '/tmp/project', + provider: 'claude', + model: 'global-model', + personaProviders: { coder: { provider: 'codex', model: 'persona-model' } }, + }); + + await engine.run(); + + const options = vi.mocked(runAgent).mock.calls[0][2]; + expect(options.stepProvider).toBe('codex'); + expect(options.stepModel).toBe('persona-model'); + }); }); diff --git a/src/__tests__/provider-resolution.test.ts b/src/__tests__/provider-resolution.test.ts index bec21ac..84c8051 100644 --- a/src/__tests__/provider-resolution.test.ts +++ b/src/__tests__/provider-resolution.test.ts @@ -1,9 +1,32 @@ import { describe, expect, it } from 'vitest'; -import { resolveMovementProviderModel } from '../core/piece/provider-resolution.js'; +import { resolveMovementProviderModel, resolveProviderModelCandidates } from '../core/piece/provider-resolution.js'; + +describe('resolveProviderModelCandidates', () => { + it('should resolve first defined provider and model independently', () => { + const result = resolveProviderModelCandidates([ + { provider: undefined, model: 'model-1' }, + { provider: 'codex', model: undefined }, + { provider: 'claude', model: 'model-2' }, + ]); + + expect(result.provider).toBe('codex'); + expect(result.model).toBe('model-1'); + }); + + it('should return undefined fields when all candidates are undefined', () => { + const result = resolveProviderModelCandidates([ + {}, + { provider: undefined, model: undefined }, + ]); + + expect(result.provider).toBeUndefined(); + expect(result.model).toBeUndefined(); + }); +}); describe('resolveMovementProviderModel', () => { - it('should prefer step.provider when step provider is defined', () => { - // Given: step.provider が指定されている + it('should prefer personaProviders.provider over step.provider when both are defined', () => { + // Given: step.provider と personaProviders.provider が両方指定されている const result = resolveMovementProviderModel({ step: { provider: 'codex', model: undefined, personaDisplayName: 'coder' }, provider: 'claude', @@ -11,8 +34,8 @@ describe('resolveMovementProviderModel', () => { }); // When: provider/model を解決する - // Then: step.provider が最優先になる - expect(result.provider).toBe('codex'); + // Then: personaProviders.provider が step.provider を上書きする + expect(result.provider).toBe('opencode'); }); it('should use personaProviders.provider when step.provider is undefined', () => { @@ -54,7 +77,7 @@ describe('resolveMovementProviderModel', () => { expect(result.provider).toBeUndefined(); }); - it('should prefer step.model over personaProviders.model and input.model', () => { + it('should prefer personaProviders.model over step.model and input.model', () => { // Given: step.model と personaProviders.model と input.model が指定されている const result = resolveMovementProviderModel({ step: { provider: undefined, model: 'step-model', personaDisplayName: 'coder' }, @@ -63,8 +86,8 @@ describe('resolveMovementProviderModel', () => { }); // When: provider/model を解決する - // Then: step.model が最優先になる - expect(result.model).toBe('step-model'); + // Then: personaProviders.model が step.model を上書きする + expect(result.model).toBe('persona-model'); }); it('should use personaProviders.model when step.model is undefined', () => { diff --git a/src/agents/runner.ts b/src/agents/runner.ts index 43c2e32..c62f2c7 100644 --- a/src/agents/runner.ts +++ b/src/agents/runner.ts @@ -7,6 +7,7 @@ import { basename, dirname } from 'node:path'; import { loadCustomAgents, loadAgentPrompt, resolveConfigValues } from '../infra/config/index.js'; import { getProvider, type ProviderType, type ProviderCallOptions } from '../infra/providers/index.js'; import type { AgentResponse, CustomAgentConfig } from '../core/models/index.js'; +import { resolveProviderModelCandidates } from '../core/piece/provider-resolution.js'; import { createLogger } from '../shared/utils/index.js'; import { loadTemplate } from '../shared/prompts/index.js'; import type { RunAgentOptions } from './types.js'; @@ -22,40 +23,33 @@ const log = createLogger('runner'); * delegates execution to the appropriate provider. */ export class AgentRunner { - /** Resolve provider type from options, agent config, project config, global config */ - private static resolveProvider( + private static resolveProviderAndModel( cwd: string, options?: RunAgentOptions, agentConfig?: CustomAgentConfig, - ): ProviderType { - if (options?.provider) return options.provider; - if (options?.stepProvider) return options.stepProvider; - const config = resolveConfigValues(cwd, ['provider']); - if (config.provider) return config.provider; - if (agentConfig?.provider) return agentConfig.provider; - return 'claude'; - } + ): { provider: ProviderType; model: string | undefined } { + const config = resolveConfigValues(cwd, ['provider', 'model']); + const resolvedProvider = resolveProviderModelCandidates([ + { provider: options?.provider }, + { provider: options?.stepProvider }, + { provider: config.provider }, + { provider: agentConfig?.provider }, + ]).provider ?? 'claude'; - /** - * Resolve model from options, agent config, global config. - * Global config model is only used when its provider matches the resolved provider, - * preventing cross-provider model mismatches (e.g., 'opus' sent to Codex). - */ - private static resolveModel( - resolvedProvider: ProviderType, - options?: RunAgentOptions, - agentConfig?: CustomAgentConfig, - ): string | undefined { - if (options?.model) return options.model; - if (options?.stepModel) return options.stepModel; - if (agentConfig?.model) return agentConfig.model; - if (!options?.cwd) return undefined; - const config = resolveConfigValues(options.cwd, ['provider', 'model']); - if (config.model) { - const defaultProvider = config.provider ?? 'claude'; - if (defaultProvider === resolvedProvider) return config.model; - } - return undefined; + const configModel = (config.provider ?? 'claude') === resolvedProvider + ? config.model + : undefined; + const resolvedModel = resolveProviderModelCandidates([ + { model: options?.model }, + { model: options?.stepModel }, + { model: agentConfig?.model }, + { model: configModel }, + ]).model; + + return { + provider: resolvedProvider, + model: resolvedModel, + }; } /** Load persona prompt from file path */ @@ -87,7 +81,7 @@ export class AgentRunner { /** Build ProviderCallOptions from RunAgentOptions */ private static buildCallOptions( - resolvedProvider: ProviderType, + resolvedModel: string | undefined, options: RunAgentOptions, agentConfig?: CustomAgentConfig, ): ProviderCallOptions { @@ -98,7 +92,7 @@ export class AgentRunner { allowedTools: options.allowedTools ?? agentConfig?.allowedTools, mcpServers: options.mcpServers, maxTurns: options.maxTurns, - model: AgentRunner.resolveModel(resolvedProvider, options, agentConfig), + model: resolvedModel, permissionMode: options.permissionMode, providerOptions: options.providerOptions, onStream: options.onStream, @@ -115,7 +109,8 @@ export class AgentRunner { task: string, options: RunAgentOptions, ): Promise { - const providerType = AgentRunner.resolveProvider(options.cwd, options, agentConfig); + const resolved = AgentRunner.resolveProviderAndModel(options.cwd, options, agentConfig); + const providerType = resolved.provider; const provider = getProvider(providerType); const agent = provider.setup({ @@ -127,7 +122,7 @@ export class AgentRunner { claudeSkill: agentConfig.claudeSkill, }); - return agent.call(task, AgentRunner.buildCallOptions(providerType, options, agentConfig)); + return agent.call(task, AgentRunner.buildCallOptions(resolved.model, options, agentConfig)); } /** Run an agent by name, path, inline prompt string, or no agent at all */ @@ -147,9 +142,10 @@ export class AgentRunner { permissionMode: options.permissionMode, }); - const providerType = AgentRunner.resolveProvider(options.cwd, options); + const resolved = AgentRunner.resolveProviderAndModel(options.cwd, options); + const providerType = resolved.provider; const provider = getProvider(providerType); - const callOptions = AgentRunner.buildCallOptions(providerType, options); + const callOptions = AgentRunner.buildCallOptions(resolved.model, options); // 1. If personaPath is provided (resolved file exists), load prompt from file // and wrap it through the perform_agent_system_prompt template diff --git a/src/core/piece/provider-resolution.ts b/src/core/piece/provider-resolution.ts index 5364b10..60ee330 100644 --- a/src/core/piece/provider-resolution.ts +++ b/src/core/piece/provider-resolution.ts @@ -14,12 +14,43 @@ export interface MovementProviderModelOutput { model?: string; } +export interface ProviderModelCandidate { + provider?: ProviderType; + model?: string; +} + +export function resolveProviderModelCandidates( + candidates: readonly ProviderModelCandidate[], +): MovementProviderModelOutput { + let provider: ProviderType | undefined; + let model: string | undefined; + + for (const candidate of candidates) { + if (provider === undefined && candidate.provider !== undefined) { + provider = candidate.provider; + } + if (model === undefined && candidate.model !== undefined) { + model = candidate.model; + } + if (provider !== undefined && model !== undefined) { + break; + } + } + + return { provider, model }; +} + export function resolveMovementProviderModel(input: MovementProviderModelInput): MovementProviderModelOutput { const personaEntry = input.personaProviders?.[input.step.personaDisplayName]; - return { - provider: input.step.provider - ?? personaEntry?.provider - ?? input.provider, - model: input.step.model ?? personaEntry?.model ?? input.model, - }; + const provider = resolveProviderModelCandidates([ + { provider: personaEntry?.provider }, + { provider: input.step.provider }, + { provider: input.provider }, + ]).provider; + const model = resolveProviderModelCandidates([ + { model: personaEntry?.model }, + { model: input.step.model }, + { model: input.model }, + ]).model; + return { provider, model }; }