feat(agent-model): add per-agent model override with default-reset UX and runtime sync (#651)

This commit is contained in:
Felix
2026-03-25 10:13:11 +08:00
committed by GitHub
Unverified
parent 9d40e1fa05
commit ab8fe760ef
16 changed files with 871 additions and 26 deletions

View File

@@ -7,10 +7,11 @@ import {
listAgentsSnapshot,
removeAgentWorkspaceDirectory,
resolveAccountIdForAgent,
updateAgentModel,
updateAgentName,
} from '../../utils/agent-config';
import { deleteChannelAccountConfig } from '../../utils/channel-config';
import { syncAllProviderAuthToRuntime } from '../../services/providers/provider-runtime-sync';
import { syncAgentModelOverrideToRuntime, syncAllProviderAuthToRuntime } from '../../services/providers/provider-runtime-sync';
import type { HostApiContext } from '../context';
import { parseJsonBody, sendJson } from '../route-utils';
@@ -151,6 +152,26 @@ export async function handleAgentRoutes(
return true;
}
if (parts.length === 2 && parts[1] === 'model') {
try {
const body = await parseJsonBody<{ modelRef?: string | null }>(req);
const agentId = decodeURIComponent(parts[0]);
const snapshot = await updateAgentModel(agentId, body.modelRef ?? null);
try {
await syncAllProviderAuthToRuntime();
// Ensure this agent's runtime model registry reflects the new model override.
await syncAgentModelOverrideToRuntime(agentId);
} catch (syncError) {
console.warn('[agents] Failed to sync runtime after updating agent model:', syncError);
}
scheduleGatewayReload(ctx, 'update-agent-model');
sendJson(res, 200, { success: true, ...snapshot });
} catch (error) {
sendJson(res, 500, { success: false, error: String(error) });
}
return true;
}
if (parts.length === 3 && parts[1] === 'channels') {
try {
const agentId = decodeURIComponent(parts[0]);

View File

@@ -12,8 +12,10 @@ import {
setOpenClawDefaultModelWithOverride,
syncProviderConfigToOpenClaw,
updateAgentModelProvider,
updateSingleAgentModelProvider,
} from '../../utils/openclaw-auth';
import { logger } from '../../utils/logger';
import { listAgentsSnapshot } from '../../utils/agent-config';
const GOOGLE_OAUTH_RUNTIME_PROVIDER = 'google-gemini-cli';
const GOOGLE_OAUTH_DEFAULT_MODEL_REF = `${GOOGLE_OAUTH_RUNTIME_PROVIDER}/gemini-3-pro-preview`;
@@ -336,6 +338,116 @@ async function syncProviderToRuntime(
return context;
}
function parseModelRef(modelRef: string): { providerKey: string; modelId: string } | null {
const trimmed = modelRef.trim();
const separatorIndex = trimmed.indexOf('/');
if (separatorIndex <= 0 || separatorIndex >= trimmed.length - 1) {
return null;
}
return {
providerKey: trimmed.slice(0, separatorIndex),
modelId: trimmed.slice(separatorIndex + 1),
};
}
async function buildRuntimeProviderConfigMap(): Promise<Map<string, ProviderConfig>> {
const configs = await getAllProviders();
const runtimeMap = new Map<string, ProviderConfig>();
for (const config of configs) {
const runtimeKey = await resolveRuntimeProviderKey(config);
runtimeMap.set(runtimeKey, config);
}
return runtimeMap;
}
async function buildAgentModelProviderEntry(
config: ProviderConfig,
modelId: string,
): Promise<{
baseUrl?: string;
api?: string;
models?: Array<{ id: string; name: string }>;
apiKey?: string;
authHeader?: boolean;
} | null> {
const meta = getProviderConfig(config.type);
const api = config.apiProtocol || (config.type === 'custom' ? 'openai-completions' : meta?.api);
const baseUrl = normalizeProviderBaseUrl(config, config.baseUrl || meta?.baseUrl, api);
if (!api || !baseUrl) {
return null;
}
let apiKey: string | undefined;
let authHeader: boolean | undefined;
if (config.type === 'custom') {
apiKey = (await getApiKey(config.id)) || undefined;
} else if (config.type === 'minimax-portal' || config.type === 'minimax-portal-cn') {
const accountApiKey = await getApiKey(config.id);
if (accountApiKey) {
apiKey = accountApiKey;
} else {
authHeader = true;
apiKey = 'minimax-oauth';
}
} else if (config.type === 'qwen-portal') {
const accountApiKey = await getApiKey(config.id);
if (accountApiKey) {
apiKey = accountApiKey;
} else {
apiKey = 'qwen-oauth';
}
}
return {
baseUrl,
api,
models: [{ id: modelId, name: modelId }],
apiKey,
authHeader,
};
}
async function syncAgentModelsToRuntime(agentIds?: Set<string>): Promise<void> {
const snapshot = await listAgentsSnapshot();
const runtimeProviderConfigs = await buildRuntimeProviderConfigMap();
const targets = snapshot.agents.filter((agent) => {
if (!agent.modelRef) return false;
if (!agentIds) return true;
return agentIds.has(agent.id);
});
for (const agent of targets) {
const parsed = parseModelRef(agent.modelRef || '');
if (!parsed) {
continue;
}
const providerConfig = runtimeProviderConfigs.get(parsed.providerKey);
if (!providerConfig) {
logger.warn(
`[provider-runtime] No provider account mapped to runtime key "${parsed.providerKey}" for agent "${agent.id}"`,
);
continue;
}
const entry = await buildAgentModelProviderEntry(providerConfig, parsed.modelId);
if (!entry) {
continue;
}
await updateSingleAgentModelProvider(agent.id, parsed.providerKey, entry);
}
}
export async function syncAgentModelOverrideToRuntime(agentId: string): Promise<void> {
await syncAgentModelsToRuntime(new Set([agentId]));
}
export async function syncSavedProviderToRuntime(
config: ProviderConfig,
apiKey: string | undefined,
@@ -346,6 +458,12 @@ export async function syncSavedProviderToRuntime(
return;
}
try {
await syncAgentModelsToRuntime();
} catch (err) {
logger.warn('[provider-runtime] Failed to sync per-agent model registries after provider save:', err);
}
scheduleGatewayRefresh(
gatewayManager,
`Scheduling Gateway reload after saving provider "${context.runtimeProviderKey}" config`,
@@ -388,6 +506,12 @@ export async function syncUpdatedProviderToRuntime(
}
}
try {
await syncAgentModelsToRuntime();
} catch (err) {
logger.warn('[provider-runtime] Failed to sync per-agent model registries after provider update:', err);
}
scheduleGatewayRefresh(
gatewayManager,
`Scheduling Gateway reload after updating provider "${ock}" config`,
@@ -496,6 +620,11 @@ export async function syncDefaultProviderToRuntime(
await setOpenClawDefaultModel(browserOAuthRuntimeProvider, modelOverride, fallbackModels);
logger.info(`Configured openclaw.json for browser OAuth provider "${provider.id}"`);
try {
await syncAgentModelsToRuntime();
} catch (err) {
logger.warn('[provider-runtime] Failed to sync per-agent model registries after browser OAuth switch:', err);
}
scheduleGatewayRefresh(
gatewayManager,
`Scheduling Gateway reload after provider switch to "${browserOAuthRuntimeProvider}"`,
@@ -557,6 +686,12 @@ export async function syncDefaultProviderToRuntime(
});
}
try {
await syncAgentModelsToRuntime();
} catch (err) {
logger.warn('[provider-runtime] Failed to sync per-agent model registries after default provider switch:', err);
}
scheduleGatewayRefresh(
gatewayManager,
`Scheduling Gateway reload after provider switch to "${ock}"`,

View File

@@ -81,6 +81,8 @@ export interface AgentSummary {
name: string;
isDefault: boolean;
modelDisplay: string;
modelRef: string | null;
overrideModelRef: string | null;
inheritedModel: boolean;
workspace: string;
agentDir: string;
@@ -91,29 +93,38 @@ export interface AgentSummary {
export interface AgentsSnapshot {
agents: AgentSummary[];
defaultAgentId: string;
defaultModelRef: string | null;
configuredChannelTypes: string[];
channelOwners: Record<string, string>;
channelAccountOwners: Record<string, string>;
}
function formatModelLabel(model: unknown): string | null {
function resolveModelRef(model: unknown): string | null {
if (typeof model === 'string' && model.trim()) {
const trimmed = model.trim();
const parts = trimmed.split('/');
return parts[parts.length - 1] || trimmed;
return model.trim();
}
if (model && typeof model === 'object') {
const primary = (model as AgentModelConfig).primary;
if (typeof primary === 'string' && primary.trim()) {
const parts = primary.trim().split('/');
return parts[parts.length - 1] || primary.trim();
return primary.trim();
}
}
return null;
}
function formatModelLabel(model: unknown): string | null {
const modelRef = resolveModelRef(model);
if (modelRef) {
const trimmed = modelRef;
const parts = trimmed.split('/');
return parts[parts.length - 1] || trimmed;
}
return null;
}
function normalizeAgentName(name: string): string {
return name.trim() || 'Agent';
}
@@ -487,10 +498,13 @@ async function buildSnapshotFromConfig(config: AgentConfigDocument): Promise<Age
channelOwners[channelType] = primaryOwner;
}
const defaultModelLabel = formatModelLabel((config.agents as AgentsConfig | undefined)?.defaults?.model);
const defaultModelConfig = (config.agents as AgentsConfig | undefined)?.defaults?.model;
const defaultModelLabel = formatModelLabel(defaultModelConfig);
const defaultModelRef = resolveModelRef(defaultModelConfig);
const agents: AgentSummary[] = entries.map((entry) => {
const explicitModelRef = resolveModelRef(entry.model);
const modelLabel = formatModelLabel(entry.model) || defaultModelLabel || 'Not configured';
const inheritedModel = !formatModelLabel(entry.model) && Boolean(defaultModelLabel);
const inheritedModel = !explicitModelRef && Boolean(defaultModelLabel);
const entryIdNorm = normalizeAgentIdForBinding(entry.id);
const ownedChannels = agentChannelSets.get(entryIdNorm) ?? new Set<string>();
return {
@@ -498,6 +512,8 @@ async function buildSnapshotFromConfig(config: AgentConfigDocument): Promise<Age
name: entry.name || (entry.id === MAIN_AGENT_ID ? MAIN_AGENT_NAME : entry.id),
isDefault: entry.id === defaultAgentId,
modelDisplay: modelLabel,
modelRef: explicitModelRef || defaultModelRef || null,
overrideModelRef: explicitModelRef,
inheritedModel,
workspace: entry.workspace || (entry.id === MAIN_AGENT_ID ? getDefaultWorkspacePath(config) : `~/.openclaw/workspace-${entry.id}`),
agentDir: entry.agentDir || getDefaultAgentDirPath(entry.id),
@@ -511,6 +527,7 @@ async function buildSnapshotFromConfig(config: AgentConfigDocument): Promise<Age
return {
agents,
defaultAgentId,
defaultModelRef,
configuredChannelTypes: configuredChannels.map((channelType) => toUiChannelType(channelType)),
channelOwners,
channelAccountOwners,
@@ -598,6 +615,44 @@ export async function updateAgentName(agentId: string, name: string): Promise<Ag
});
}
function isValidModelRef(modelRef: string): boolean {
const firstSlash = modelRef.indexOf('/');
return firstSlash > 0 && firstSlash < modelRef.length - 1;
}
export async function updateAgentModel(agentId: string, modelRef: string | null): Promise<AgentsSnapshot> {
return withConfigLock(async () => {
const config = await readOpenClawConfig() as AgentConfigDocument;
const { agentsConfig, entries } = normalizeAgentsConfig(config);
const index = entries.findIndex((entry) => entry.id === agentId);
if (index === -1) {
throw new Error(`Agent "${agentId}" not found`);
}
const normalizedModelRef = typeof modelRef === 'string' ? modelRef.trim() : '';
const nextEntry: AgentListEntry = { ...entries[index] };
if (!normalizedModelRef) {
delete nextEntry.model;
} else {
if (!isValidModelRef(normalizedModelRef)) {
throw new Error('modelRef must be in "provider/model" format');
}
nextEntry.model = { primary: normalizedModelRef };
}
entries[index] = nextEntry;
config.agents = {
...agentsConfig,
list: entries,
};
await writeOpenClawConfig(config);
logger.info('Updated agent model', { agentId, modelRef: normalizedModelRef || null });
return buildSnapshotFromConfig(config);
});
}
export async function deleteAgentConfig(agentId: string): Promise<{ snapshot: AgentsSnapshot; removedEntry: AgentListEntry }> {
return withConfigLock(async () => {
if (agentId === MAIN_AGENT_ID) {

View File

@@ -918,18 +918,20 @@ export async function syncSessionIdleMinutesToOpenClaw(): Promise<void> {
/**
* Update a provider entry in every discovered agent's models.json.
*/
export async function updateAgentModelProvider(
type AgentModelProviderEntry = {
baseUrl?: string;
api?: string;
models?: Array<{ id: string; name: string }>;
apiKey?: string;
/** When true, pi-ai sends Authorization: Bearer instead of x-api-key */
authHeader?: boolean;
};
async function updateModelsJsonProviderEntriesForAgents(
agentIds: string[],
providerType: string,
entry: {
baseUrl?: string;
api?: string;
models?: Array<{ id: string; name: string }>;
apiKey?: string;
/** When true, pi-ai sends Authorization: Bearer instead of x-api-key */
authHeader?: boolean;
}
entry: AgentModelProviderEntry,
): Promise<void> {
const agentIds = await discoverAgentIds();
for (const agentId of agentIds) {
const modelsPath = join(homedir(), '.openclaw', 'agents', agentId, 'agent', 'models.json');
let data: Record<string, unknown> = {};
@@ -975,6 +977,26 @@ export async function updateAgentModelProvider(
}
}
export async function updateAgentModelProvider(
providerType: string,
entry: AgentModelProviderEntry,
): Promise<void> {
const agentIds = await discoverAgentIds();
await updateModelsJsonProviderEntriesForAgents(agentIds, providerType, entry);
}
export async function updateSingleAgentModelProvider(
agentId: string,
providerType: string,
entry: AgentModelProviderEntry,
): Promise<void> {
const normalizedAgentId = agentId.trim();
if (!normalizedAgentId) {
throw new Error('agentId is required');
}
await updateModelsJsonProviderEntriesForAgents([normalizedAgentId], providerType, entry);
}
/**
* Sanitize ~/.openclaw/openclaw.json before Gateway start.
*