refactor: provider/model resolution precedenceを一元化
This commit is contained in:
parent
3970b6bcf9
commit
f2ca01ffe0
@ -2,9 +2,14 @@
|
|||||||
* Tests for persona_providers config-level provider/model override.
|
* Tests for persona_providers config-level provider/model override.
|
||||||
*
|
*
|
||||||
* Verifies movement-level provider/model resolution for stepProvider/stepModel:
|
* Verifies movement-level provider/model resolution for stepProvider/stepModel:
|
||||||
* 1. Movement YAML provider (highest)
|
* 1. persona_providers[personaDisplayName].provider (highest)
|
||||||
* 2. persona_providers[personaDisplayName].provider / .model
|
* 2. Movement YAML provider
|
||||||
* 3. CLI provider / model (lowest)
|
* 3. CLI/global provider (lowest in movement resolution)
|
||||||
|
*
|
||||||
|
* Model resolution remains:
|
||||||
|
* 1. persona_providers[personaDisplayName].model
|
||||||
|
* 2. Movement YAML model
|
||||||
|
* 3. CLI/global model
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { describe, it, expect, beforeEach, vi } from 'vitest';
|
import { describe, it, expect, beforeEach, vi } from 'vitest';
|
||||||
@ -106,7 +111,7 @@ describe('PieceEngine persona_providers override', () => {
|
|||||||
expect(options.stepProvider).toBe('claude');
|
expect(options.stepProvider).toBe('claude');
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should prioritize movement provider over persona_providers', async () => {
|
it('should prioritize persona_providers provider over movement provider', async () => {
|
||||||
const movement = makeMovement('implement', {
|
const movement = makeMovement('implement', {
|
||||||
personaDisplayName: 'coder',
|
personaDisplayName: 'coder',
|
||||||
provider: 'claude',
|
provider: 'claude',
|
||||||
@ -134,7 +139,7 @@ describe('PieceEngine persona_providers override', () => {
|
|||||||
|
|
||||||
const options = vi.mocked(runAgent).mock.calls[0][2];
|
const options = vi.mocked(runAgent).mock.calls[0][2];
|
||||||
expect(options.provider).toBe('mock');
|
expect(options.provider).toBe('mock');
|
||||||
expect(options.stepProvider).toBe('claude');
|
expect(options.stepProvider).toBe('codex');
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should work without persona_providers (undefined)', async () => {
|
it('should work without persona_providers (undefined)', async () => {
|
||||||
@ -269,4 +274,36 @@ describe('PieceEngine persona_providers override', () => {
|
|||||||
expect(options.stepProvider).toBe('codex');
|
expect(options.stepProvider).toBe('codex');
|
||||||
expect(options.stepModel).toBe('global-model');
|
expect(options.stepModel).toBe('global-model');
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should prioritize persona_providers.model over movement model', async () => {
|
||||||
|
const movement = makeMovement('implement', {
|
||||||
|
personaDisplayName: 'coder',
|
||||||
|
model: 'movement-model',
|
||||||
|
rules: [makeRule('done', 'COMPLETE')],
|
||||||
|
});
|
||||||
|
const config: PieceConfig = {
|
||||||
|
name: 'persona-model-over-movement',
|
||||||
|
movements: [movement],
|
||||||
|
initialMovement: 'implement',
|
||||||
|
maxMovements: 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
mockRunAgentSequence([
|
||||||
|
makeResponse({ persona: movement.persona, content: 'done' }),
|
||||||
|
]);
|
||||||
|
mockDetectMatchedRuleSequence([{ index: 0, method: 'phase1_tag' }]);
|
||||||
|
|
||||||
|
const engine = new PieceEngine(config, '/tmp/project', 'test task', {
|
||||||
|
projectCwd: '/tmp/project',
|
||||||
|
provider: 'claude',
|
||||||
|
model: 'global-model',
|
||||||
|
personaProviders: { coder: { provider: 'codex', model: 'persona-model' } },
|
||||||
|
});
|
||||||
|
|
||||||
|
await engine.run();
|
||||||
|
|
||||||
|
const options = vi.mocked(runAgent).mock.calls[0][2];
|
||||||
|
expect(options.stepProvider).toBe('codex');
|
||||||
|
expect(options.stepModel).toBe('persona-model');
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@ -1,9 +1,32 @@
|
|||||||
import { describe, expect, it } from 'vitest';
|
import { describe, expect, it } from 'vitest';
|
||||||
import { resolveMovementProviderModel } from '../core/piece/provider-resolution.js';
|
import { resolveMovementProviderModel, resolveProviderModelCandidates } from '../core/piece/provider-resolution.js';
|
||||||
|
|
||||||
|
describe('resolveProviderModelCandidates', () => {
|
||||||
|
it('should resolve first defined provider and model independently', () => {
|
||||||
|
const result = resolveProviderModelCandidates([
|
||||||
|
{ provider: undefined, model: 'model-1' },
|
||||||
|
{ provider: 'codex', model: undefined },
|
||||||
|
{ provider: 'claude', model: 'model-2' },
|
||||||
|
]);
|
||||||
|
|
||||||
|
expect(result.provider).toBe('codex');
|
||||||
|
expect(result.model).toBe('model-1');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return undefined fields when all candidates are undefined', () => {
|
||||||
|
const result = resolveProviderModelCandidates([
|
||||||
|
{},
|
||||||
|
{ provider: undefined, model: undefined },
|
||||||
|
]);
|
||||||
|
|
||||||
|
expect(result.provider).toBeUndefined();
|
||||||
|
expect(result.model).toBeUndefined();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe('resolveMovementProviderModel', () => {
|
describe('resolveMovementProviderModel', () => {
|
||||||
it('should prefer step.provider when step provider is defined', () => {
|
it('should prefer personaProviders.provider over step.provider when both are defined', () => {
|
||||||
// Given: step.provider が指定されている
|
// Given: step.provider と personaProviders.provider が両方指定されている
|
||||||
const result = resolveMovementProviderModel({
|
const result = resolveMovementProviderModel({
|
||||||
step: { provider: 'codex', model: undefined, personaDisplayName: 'coder' },
|
step: { provider: 'codex', model: undefined, personaDisplayName: 'coder' },
|
||||||
provider: 'claude',
|
provider: 'claude',
|
||||||
@ -11,8 +34,8 @@ describe('resolveMovementProviderModel', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
// When: provider/model を解決する
|
// When: provider/model を解決する
|
||||||
// Then: step.provider が最優先になる
|
// Then: personaProviders.provider が step.provider を上書きする
|
||||||
expect(result.provider).toBe('codex');
|
expect(result.provider).toBe('opencode');
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should use personaProviders.provider when step.provider is undefined', () => {
|
it('should use personaProviders.provider when step.provider is undefined', () => {
|
||||||
@ -54,7 +77,7 @@ describe('resolveMovementProviderModel', () => {
|
|||||||
expect(result.provider).toBeUndefined();
|
expect(result.provider).toBeUndefined();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should prefer step.model over personaProviders.model and input.model', () => {
|
it('should prefer personaProviders.model over step.model and input.model', () => {
|
||||||
// Given: step.model と personaProviders.model と input.model が指定されている
|
// Given: step.model と personaProviders.model と input.model が指定されている
|
||||||
const result = resolveMovementProviderModel({
|
const result = resolveMovementProviderModel({
|
||||||
step: { provider: undefined, model: 'step-model', personaDisplayName: 'coder' },
|
step: { provider: undefined, model: 'step-model', personaDisplayName: 'coder' },
|
||||||
@ -63,8 +86,8 @@ describe('resolveMovementProviderModel', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
// When: provider/model を解決する
|
// When: provider/model を解決する
|
||||||
// Then: step.model が最優先になる
|
// Then: personaProviders.model が step.model を上書きする
|
||||||
expect(result.model).toBe('step-model');
|
expect(result.model).toBe('persona-model');
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should use personaProviders.model when step.model is undefined', () => {
|
it('should use personaProviders.model when step.model is undefined', () => {
|
||||||
|
|||||||
@ -7,6 +7,7 @@ import { basename, dirname } from 'node:path';
|
|||||||
import { loadCustomAgents, loadAgentPrompt, resolveConfigValues } from '../infra/config/index.js';
|
import { loadCustomAgents, loadAgentPrompt, resolveConfigValues } from '../infra/config/index.js';
|
||||||
import { getProvider, type ProviderType, type ProviderCallOptions } from '../infra/providers/index.js';
|
import { getProvider, type ProviderType, type ProviderCallOptions } from '../infra/providers/index.js';
|
||||||
import type { AgentResponse, CustomAgentConfig } from '../core/models/index.js';
|
import type { AgentResponse, CustomAgentConfig } from '../core/models/index.js';
|
||||||
|
import { resolveProviderModelCandidates } from '../core/piece/provider-resolution.js';
|
||||||
import { createLogger } from '../shared/utils/index.js';
|
import { createLogger } from '../shared/utils/index.js';
|
||||||
import { loadTemplate } from '../shared/prompts/index.js';
|
import { loadTemplate } from '../shared/prompts/index.js';
|
||||||
import type { RunAgentOptions } from './types.js';
|
import type { RunAgentOptions } from './types.js';
|
||||||
@ -22,40 +23,33 @@ const log = createLogger('runner');
|
|||||||
* delegates execution to the appropriate provider.
|
* delegates execution to the appropriate provider.
|
||||||
*/
|
*/
|
||||||
export class AgentRunner {
|
export class AgentRunner {
|
||||||
/** Resolve provider type from options, agent config, project config, global config */
|
private static resolveProviderAndModel(
|
||||||
private static resolveProvider(
|
|
||||||
cwd: string,
|
cwd: string,
|
||||||
options?: RunAgentOptions,
|
options?: RunAgentOptions,
|
||||||
agentConfig?: CustomAgentConfig,
|
agentConfig?: CustomAgentConfig,
|
||||||
): ProviderType {
|
): { provider: ProviderType; model: string | undefined } {
|
||||||
if (options?.provider) return options.provider;
|
const config = resolveConfigValues(cwd, ['provider', 'model']);
|
||||||
if (options?.stepProvider) return options.stepProvider;
|
const resolvedProvider = resolveProviderModelCandidates([
|
||||||
const config = resolveConfigValues(cwd, ['provider']);
|
{ provider: options?.provider },
|
||||||
if (config.provider) return config.provider;
|
{ provider: options?.stepProvider },
|
||||||
if (agentConfig?.provider) return agentConfig.provider;
|
{ provider: config.provider },
|
||||||
return 'claude';
|
{ provider: agentConfig?.provider },
|
||||||
}
|
]).provider ?? 'claude';
|
||||||
|
|
||||||
/**
|
const configModel = (config.provider ?? 'claude') === resolvedProvider
|
||||||
* Resolve model from options, agent config, global config.
|
? config.model
|
||||||
* Global config model is only used when its provider matches the resolved provider,
|
: undefined;
|
||||||
* preventing cross-provider model mismatches (e.g., 'opus' sent to Codex).
|
const resolvedModel = resolveProviderModelCandidates([
|
||||||
*/
|
{ model: options?.model },
|
||||||
private static resolveModel(
|
{ model: options?.stepModel },
|
||||||
resolvedProvider: ProviderType,
|
{ model: agentConfig?.model },
|
||||||
options?: RunAgentOptions,
|
{ model: configModel },
|
||||||
agentConfig?: CustomAgentConfig,
|
]).model;
|
||||||
): string | undefined {
|
|
||||||
if (options?.model) return options.model;
|
return {
|
||||||
if (options?.stepModel) return options.stepModel;
|
provider: resolvedProvider,
|
||||||
if (agentConfig?.model) return agentConfig.model;
|
model: resolvedModel,
|
||||||
if (!options?.cwd) return undefined;
|
};
|
||||||
const config = resolveConfigValues(options.cwd, ['provider', 'model']);
|
|
||||||
if (config.model) {
|
|
||||||
const defaultProvider = config.provider ?? 'claude';
|
|
||||||
if (defaultProvider === resolvedProvider) return config.model;
|
|
||||||
}
|
|
||||||
return undefined;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Load persona prompt from file path */
|
/** Load persona prompt from file path */
|
||||||
@ -87,7 +81,7 @@ export class AgentRunner {
|
|||||||
|
|
||||||
/** Build ProviderCallOptions from RunAgentOptions */
|
/** Build ProviderCallOptions from RunAgentOptions */
|
||||||
private static buildCallOptions(
|
private static buildCallOptions(
|
||||||
resolvedProvider: ProviderType,
|
resolvedModel: string | undefined,
|
||||||
options: RunAgentOptions,
|
options: RunAgentOptions,
|
||||||
agentConfig?: CustomAgentConfig,
|
agentConfig?: CustomAgentConfig,
|
||||||
): ProviderCallOptions {
|
): ProviderCallOptions {
|
||||||
@ -98,7 +92,7 @@ export class AgentRunner {
|
|||||||
allowedTools: options.allowedTools ?? agentConfig?.allowedTools,
|
allowedTools: options.allowedTools ?? agentConfig?.allowedTools,
|
||||||
mcpServers: options.mcpServers,
|
mcpServers: options.mcpServers,
|
||||||
maxTurns: options.maxTurns,
|
maxTurns: options.maxTurns,
|
||||||
model: AgentRunner.resolveModel(resolvedProvider, options, agentConfig),
|
model: resolvedModel,
|
||||||
permissionMode: options.permissionMode,
|
permissionMode: options.permissionMode,
|
||||||
providerOptions: options.providerOptions,
|
providerOptions: options.providerOptions,
|
||||||
onStream: options.onStream,
|
onStream: options.onStream,
|
||||||
@ -115,7 +109,8 @@ export class AgentRunner {
|
|||||||
task: string,
|
task: string,
|
||||||
options: RunAgentOptions,
|
options: RunAgentOptions,
|
||||||
): Promise<AgentResponse> {
|
): Promise<AgentResponse> {
|
||||||
const providerType = AgentRunner.resolveProvider(options.cwd, options, agentConfig);
|
const resolved = AgentRunner.resolveProviderAndModel(options.cwd, options, agentConfig);
|
||||||
|
const providerType = resolved.provider;
|
||||||
const provider = getProvider(providerType);
|
const provider = getProvider(providerType);
|
||||||
|
|
||||||
const agent = provider.setup({
|
const agent = provider.setup({
|
||||||
@ -127,7 +122,7 @@ export class AgentRunner {
|
|||||||
claudeSkill: agentConfig.claudeSkill,
|
claudeSkill: agentConfig.claudeSkill,
|
||||||
});
|
});
|
||||||
|
|
||||||
return agent.call(task, AgentRunner.buildCallOptions(providerType, options, agentConfig));
|
return agent.call(task, AgentRunner.buildCallOptions(resolved.model, options, agentConfig));
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Run an agent by name, path, inline prompt string, or no agent at all */
|
/** Run an agent by name, path, inline prompt string, or no agent at all */
|
||||||
@ -147,9 +142,10 @@ export class AgentRunner {
|
|||||||
permissionMode: options.permissionMode,
|
permissionMode: options.permissionMode,
|
||||||
});
|
});
|
||||||
|
|
||||||
const providerType = AgentRunner.resolveProvider(options.cwd, options);
|
const resolved = AgentRunner.resolveProviderAndModel(options.cwd, options);
|
||||||
|
const providerType = resolved.provider;
|
||||||
const provider = getProvider(providerType);
|
const provider = getProvider(providerType);
|
||||||
const callOptions = AgentRunner.buildCallOptions(providerType, options);
|
const callOptions = AgentRunner.buildCallOptions(resolved.model, options);
|
||||||
|
|
||||||
// 1. If personaPath is provided (resolved file exists), load prompt from file
|
// 1. If personaPath is provided (resolved file exists), load prompt from file
|
||||||
// and wrap it through the perform_agent_system_prompt template
|
// and wrap it through the perform_agent_system_prompt template
|
||||||
|
|||||||
@ -14,12 +14,43 @@ export interface MovementProviderModelOutput {
|
|||||||
model?: string;
|
model?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface ProviderModelCandidate {
|
||||||
|
provider?: ProviderType;
|
||||||
|
model?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function resolveProviderModelCandidates(
|
||||||
|
candidates: readonly ProviderModelCandidate[],
|
||||||
|
): MovementProviderModelOutput {
|
||||||
|
let provider: ProviderType | undefined;
|
||||||
|
let model: string | undefined;
|
||||||
|
|
||||||
|
for (const candidate of candidates) {
|
||||||
|
if (provider === undefined && candidate.provider !== undefined) {
|
||||||
|
provider = candidate.provider;
|
||||||
|
}
|
||||||
|
if (model === undefined && candidate.model !== undefined) {
|
||||||
|
model = candidate.model;
|
||||||
|
}
|
||||||
|
if (provider !== undefined && model !== undefined) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return { provider, model };
|
||||||
|
}
|
||||||
|
|
||||||
export function resolveMovementProviderModel(input: MovementProviderModelInput): MovementProviderModelOutput {
|
export function resolveMovementProviderModel(input: MovementProviderModelInput): MovementProviderModelOutput {
|
||||||
const personaEntry = input.personaProviders?.[input.step.personaDisplayName];
|
const personaEntry = input.personaProviders?.[input.step.personaDisplayName];
|
||||||
return {
|
const provider = resolveProviderModelCandidates([
|
||||||
provider: input.step.provider
|
{ provider: personaEntry?.provider },
|
||||||
?? personaEntry?.provider
|
{ provider: input.step.provider },
|
||||||
?? input.provider,
|
{ provider: input.provider },
|
||||||
model: input.step.model ?? personaEntry?.model ?? input.model,
|
]).provider;
|
||||||
};
|
const model = resolveProviderModelCandidates([
|
||||||
|
{ model: personaEntry?.model },
|
||||||
|
{ model: input.step.model },
|
||||||
|
{ model: input.model },
|
||||||
|
]).model;
|
||||||
|
return { provider, model };
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user