feat(agent-model): add per-agent model override with default-reset UX and runtime sync (#651)

This commit is contained in:
Felix
2026-03-25 10:13:11 +08:00
committed by GitHub
Unverified
parent 9d40e1fa05
commit ab8fe760ef
16 changed files with 871 additions and 26 deletions

View File

@@ -102,6 +102,101 @@ describe('agent config lifecycle', () => {
);
});
it('exposes effective and override model refs in the snapshot', async () => {
await writeOpenClawJson({
agents: {
defaults: {
model: {
primary: 'moonshot/kimi-k2.5',
},
},
list: [
{ id: 'main', name: 'Main', default: true },
{ id: 'coder', name: 'Coder', model: { primary: 'ark/ark-code-latest' } },
],
},
});
const { listAgentsSnapshot } = await import('@electron/utils/agent-config');
const snapshot = await listAgentsSnapshot();
const main = snapshot.agents.find((agent) => agent.id === 'main');
const coder = snapshot.agents.find((agent) => agent.id === 'coder');
expect(snapshot.defaultModelRef).toBe('moonshot/kimi-k2.5');
expect(main).toMatchObject({
modelRef: 'moonshot/kimi-k2.5',
overrideModelRef: null,
inheritedModel: true,
modelDisplay: 'kimi-k2.5',
});
expect(coder).toMatchObject({
modelRef: 'ark/ark-code-latest',
overrideModelRef: 'ark/ark-code-latest',
inheritedModel: false,
modelDisplay: 'ark-code-latest',
});
});
it('updates and clears per-agent model overrides', async () => {
await writeOpenClawJson({
agents: {
defaults: {
model: {
primary: 'moonshot/kimi-k2.5',
},
},
list: [
{ id: 'main', name: 'Main', default: true },
{ id: 'coder', name: 'Coder' },
],
},
});
const { listAgentsSnapshot, updateAgentModel } = await import('@electron/utils/agent-config');
await updateAgentModel('coder', 'ark/ark-code-latest');
let config = await readOpenClawJson();
let coder = ((config.agents as { list: Array<{ id: string; model?: { primary?: string } }> }).list)
.find((agent) => agent.id === 'coder');
expect(coder?.model?.primary).toBe('ark/ark-code-latest');
let snapshot = await listAgentsSnapshot();
let snapshotCoder = snapshot.agents.find((agent) => agent.id === 'coder');
expect(snapshotCoder).toMatchObject({
modelRef: 'ark/ark-code-latest',
overrideModelRef: 'ark/ark-code-latest',
inheritedModel: false,
});
await updateAgentModel('coder', null);
config = await readOpenClawJson();
coder = ((config.agents as { list: Array<{ id: string; model?: unknown }> }).list)
.find((agent) => agent.id === 'coder');
expect(coder?.model).toBeUndefined();
snapshot = await listAgentsSnapshot();
snapshotCoder = snapshot.agents.find((agent) => agent.id === 'coder');
expect(snapshotCoder).toMatchObject({
modelRef: 'moonshot/kimi-k2.5',
overrideModelRef: null,
inheritedModel: true,
});
});
it('rejects invalid model ref formats when updating agent model', async () => {
await writeOpenClawJson({
agents: {
list: [{ id: 'main', name: 'Main', default: true }],
},
});
const { updateAgentModel } = await import('@electron/utils/agent-config');
await expect(updateAgentModel('main', 'invalid-model-ref')).rejects.toThrow(
'modelRef must be in "provider/model" format',
);
});
it('deletes the config entry, bindings, runtime directory, and managed workspace for a removed agent', async () => {
await writeOpenClawJson({
agents: {

View File

@@ -1,21 +1,31 @@
import React from 'react';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { act, render, waitFor } from '@testing-library/react';
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react';
import { Agents } from '../../src/pages/Agents/index';
const hostApiFetchMock = vi.fn();
const subscribeHostEventMock = vi.fn();
const fetchAgentsMock = vi.fn();
const updateAgentMock = vi.fn();
const updateAgentModelMock = vi.fn();
const refreshProviderSnapshotMock = vi.fn();
const { gatewayState, agentsState } = vi.hoisted(() => ({
const { gatewayState, agentsState, providersState } = vi.hoisted(() => ({
gatewayState: {
status: { state: 'running', port: 18789 },
},
agentsState: {
agents: [] as Array<Record<string, unknown>>,
defaultModelRef: null as string | null,
loading: false,
error: null as string | null,
},
providersState: {
accounts: [] as Array<Record<string, unknown>>,
statuses: [] as Array<Record<string, unknown>>,
vendors: [] as Array<Record<string, unknown>>,
defaultAccountId: '' as string,
},
}));
vi.mock('@/stores/gateway', () => ({
@@ -25,12 +35,16 @@ vi.mock('@/stores/gateway', () => ({
vi.mock('@/stores/agents', () => ({
useAgentsStore: (selector?: (state: typeof agentsState & {
fetchAgents: typeof fetchAgentsMock;
updateAgent: typeof updateAgentMock;
updateAgentModel: typeof updateAgentModelMock;
createAgent: ReturnType<typeof vi.fn>;
deleteAgent: ReturnType<typeof vi.fn>;
}) => unknown) => {
const state = {
...agentsState,
fetchAgents: fetchAgentsMock,
updateAgent: updateAgentMock,
updateAgentModel: updateAgentModelMock,
createAgent: vi.fn(),
deleteAgent: vi.fn(),
};
@@ -38,6 +52,18 @@ vi.mock('@/stores/agents', () => ({
},
}));
vi.mock('@/stores/providers', () => ({
useProviderStore: (selector: (state: typeof providersState & {
refreshProviderSnapshot: typeof refreshProviderSnapshotMock;
}) => unknown) => {
const state = {
...providersState,
refreshProviderSnapshot: refreshProviderSnapshotMock,
};
return selector(state);
},
}));
vi.mock('@/lib/host-api', () => ({
hostApiFetch: (...args: unknown[]) => hostApiFetchMock(...args),
}));
@@ -64,7 +90,16 @@ describe('Agents page status refresh', () => {
beforeEach(() => {
vi.clearAllMocks();
gatewayState.status = { state: 'running', port: 18789 };
agentsState.agents = [];
agentsState.defaultModelRef = null;
providersState.accounts = [];
providersState.statuses = [];
providersState.vendors = [];
providersState.defaultAccountId = '';
fetchAgentsMock.mockResolvedValue(undefined);
updateAgentMock.mockResolvedValue(undefined);
updateAgentModelMock.mockResolvedValue(undefined);
refreshProviderSnapshotMock.mockResolvedValue(undefined);
hostApiFetchMock.mockResolvedValue({
success: true,
channels: [],
@@ -118,4 +153,65 @@ describe('Agents page status refresh', () => {
expect(channelFetchCalls).toHaveLength(2);
});
});
it('uses "Use default model" as form fill only and disables it when already default', async () => {
agentsState.agents = [
{
id: 'main',
name: 'Main',
isDefault: true,
modelDisplay: 'claude-opus-4.6',
modelRef: 'openrouter/anthropic/claude-opus-4.6',
overrideModelRef: null,
inheritedModel: true,
workspace: '~/.openclaw/workspace',
agentDir: '~/.openclaw/agents/main/agent',
mainSessionKey: 'agent:main:desk',
channelTypes: [],
},
];
agentsState.defaultModelRef = 'openrouter/anthropic/claude-opus-4.6';
providersState.accounts = [
{
id: 'openrouter-default',
label: 'OpenRouter',
vendorId: 'openrouter',
authMode: 'api_key',
model: 'openrouter/anthropic/claude-opus-4.6',
enabled: true,
createdAt: '2026-03-24T00:00:00.000Z',
updatedAt: '2026-03-24T00:00:00.000Z',
},
];
providersState.statuses = [{ id: 'openrouter-default', hasKey: true }];
providersState.vendors = [
{ id: 'openrouter', name: 'OpenRouter', modelIdPlaceholder: 'anthropic/claude-opus-4.6' },
];
providersState.defaultAccountId = 'openrouter-default';
render(<Agents />);
await waitFor(() => {
expect(fetchAgentsMock).toHaveBeenCalledTimes(1);
});
fireEvent.click(screen.getByTitle('settings'));
fireEvent.click(screen.getByText('settingsDialog.modelLabel').closest('button') as HTMLButtonElement);
const useDefaultButton = await screen.findByRole('button', { name: 'settingsDialog.useDefaultModel' });
const modelIdInput = screen.getByLabelText('settingsDialog.modelIdLabel');
const saveButton = screen.getByRole('button', { name: 'common:actions.save' });
expect(useDefaultButton).toBeDisabled();
fireEvent.change(modelIdInput, { target: { value: 'anthropic/claude-sonnet-4.5' } });
expect(useDefaultButton).toBeEnabled();
expect(saveButton).toBeEnabled();
fireEvent.click(useDefaultButton);
expect(updateAgentModelMock).not.toHaveBeenCalled();
expect((modelIdInput as HTMLInputElement).value).toBe('anthropic/claude-opus-4.6');
expect(useDefaultButton).toBeDisabled();
});
});

View File

@@ -19,6 +19,8 @@ const mocks = vi.hoisted(() => ({
setOpenClawDefaultModelWithOverride: vi.fn(),
syncProviderConfigToOpenClaw: vi.fn(),
updateAgentModelProvider: vi.fn(),
updateSingleAgentModelProvider: vi.fn(),
listAgentsSnapshot: vi.fn(),
}));
vi.mock('@electron/services/providers/provider-store', () => ({
@@ -50,6 +52,11 @@ vi.mock('@electron/utils/openclaw-auth', () => ({
setOpenClawDefaultModelWithOverride: mocks.setOpenClawDefaultModelWithOverride,
syncProviderConfigToOpenClaw: mocks.syncProviderConfigToOpenClaw,
updateAgentModelProvider: mocks.updateAgentModelProvider,
updateSingleAgentModelProvider: mocks.updateSingleAgentModelProvider,
}));
vi.mock('@electron/utils/agent-config', () => ({
listAgentsSnapshot: mocks.listAgentsSnapshot,
}));
vi.mock('@electron/utils/logger', () => ({
@@ -62,6 +69,7 @@ vi.mock('@electron/utils/logger', () => ({
}));
import {
syncAgentModelOverrideToRuntime,
syncDefaultProviderToRuntime,
syncDeletedProviderToRuntime,
syncSavedProviderToRuntime,
@@ -109,6 +117,8 @@ describe('provider-runtime-sync refresh strategy', () => {
mocks.saveProviderKeyToOpenClaw.mockResolvedValue(undefined);
mocks.removeProviderFromOpenClaw.mockResolvedValue(undefined);
mocks.updateAgentModelProvider.mockResolvedValue(undefined);
mocks.updateSingleAgentModelProvider.mockResolvedValue(undefined);
mocks.listAgentsSnapshot.mockResolvedValue({ agents: [] });
});
it('uses debouncedReload after saving provider config', async () => {
@@ -142,4 +152,48 @@ describe('provider-runtime-sync refresh strategy', () => {
expect(gateway.debouncedReload).not.toHaveBeenCalled();
expect(gateway.debouncedRestart).not.toHaveBeenCalled();
});
it('syncs a targeted agent model override to runtime provider registry', async () => {
mocks.getAllProviders.mockResolvedValue([
createProvider({
id: 'ark',
type: 'ark',
model: 'doubao-pro',
}),
]);
mocks.getProviderConfig.mockImplementation((providerType: string) => {
if (providerType === 'ark') {
return {
api: 'openai-completions',
baseUrl: 'https://ark.cn-beijing.volces.com/api/v3',
apiKeyEnv: 'ARK_API_KEY',
};
}
return {
api: 'openai-completions',
baseUrl: 'https://api.moonshot.cn/v1',
apiKeyEnv: 'MOONSHOT_API_KEY',
};
});
mocks.listAgentsSnapshot.mockResolvedValue({
agents: [
{
id: 'coder',
modelRef: 'ark/ark-code-latest',
},
],
});
await syncAgentModelOverrideToRuntime('coder');
expect(mocks.updateSingleAgentModelProvider).toHaveBeenCalledWith(
'coder',
'ark',
expect.objectContaining({
baseUrl: 'https://ark.cn-beijing.volces.com/api/v3',
api: 'openai-completions',
models: [{ id: 'ark-code-latest', name: 'ark-code-latest' }],
}),
);
});
});