feat(agent-model): add per-agent model override with default-reset UX and runtime sync (#651)
This commit is contained in:
@@ -7,10 +7,11 @@ import {
|
||||
listAgentsSnapshot,
|
||||
removeAgentWorkspaceDirectory,
|
||||
resolveAccountIdForAgent,
|
||||
updateAgentModel,
|
||||
updateAgentName,
|
||||
} from '../../utils/agent-config';
|
||||
import { deleteChannelAccountConfig } from '../../utils/channel-config';
|
||||
import { syncAllProviderAuthToRuntime } from '../../services/providers/provider-runtime-sync';
|
||||
import { syncAgentModelOverrideToRuntime, syncAllProviderAuthToRuntime } from '../../services/providers/provider-runtime-sync';
|
||||
import type { HostApiContext } from '../context';
|
||||
import { parseJsonBody, sendJson } from '../route-utils';
|
||||
|
||||
@@ -151,6 +152,26 @@ export async function handleAgentRoutes(
|
||||
return true;
|
||||
}
|
||||
|
||||
if (parts.length === 2 && parts[1] === 'model') {
|
||||
try {
|
||||
const body = await parseJsonBody<{ modelRef?: string | null }>(req);
|
||||
const agentId = decodeURIComponent(parts[0]);
|
||||
const snapshot = await updateAgentModel(agentId, body.modelRef ?? null);
|
||||
try {
|
||||
await syncAllProviderAuthToRuntime();
|
||||
// Ensure this agent's runtime model registry reflects the new model override.
|
||||
await syncAgentModelOverrideToRuntime(agentId);
|
||||
} catch (syncError) {
|
||||
console.warn('[agents] Failed to sync runtime after updating agent model:', syncError);
|
||||
}
|
||||
scheduleGatewayReload(ctx, 'update-agent-model');
|
||||
sendJson(res, 200, { success: true, ...snapshot });
|
||||
} catch (error) {
|
||||
sendJson(res, 500, { success: false, error: String(error) });
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
if (parts.length === 3 && parts[1] === 'channels') {
|
||||
try {
|
||||
const agentId = decodeURIComponent(parts[0]);
|
||||
|
||||
@@ -12,8 +12,10 @@ import {
|
||||
setOpenClawDefaultModelWithOverride,
|
||||
syncProviderConfigToOpenClaw,
|
||||
updateAgentModelProvider,
|
||||
updateSingleAgentModelProvider,
|
||||
} from '../../utils/openclaw-auth';
|
||||
import { logger } from '../../utils/logger';
|
||||
import { listAgentsSnapshot } from '../../utils/agent-config';
|
||||
|
||||
const GOOGLE_OAUTH_RUNTIME_PROVIDER = 'google-gemini-cli';
|
||||
const GOOGLE_OAUTH_DEFAULT_MODEL_REF = `${GOOGLE_OAUTH_RUNTIME_PROVIDER}/gemini-3-pro-preview`;
|
||||
@@ -336,6 +338,116 @@ async function syncProviderToRuntime(
|
||||
return context;
|
||||
}
|
||||
|
||||
function parseModelRef(modelRef: string): { providerKey: string; modelId: string } | null {
|
||||
const trimmed = modelRef.trim();
|
||||
const separatorIndex = trimmed.indexOf('/');
|
||||
if (separatorIndex <= 0 || separatorIndex >= trimmed.length - 1) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
providerKey: trimmed.slice(0, separatorIndex),
|
||||
modelId: trimmed.slice(separatorIndex + 1),
|
||||
};
|
||||
}
|
||||
|
||||
async function buildRuntimeProviderConfigMap(): Promise<Map<string, ProviderConfig>> {
|
||||
const configs = await getAllProviders();
|
||||
const runtimeMap = new Map<string, ProviderConfig>();
|
||||
|
||||
for (const config of configs) {
|
||||
const runtimeKey = await resolveRuntimeProviderKey(config);
|
||||
runtimeMap.set(runtimeKey, config);
|
||||
}
|
||||
|
||||
return runtimeMap;
|
||||
}
|
||||
|
||||
async function buildAgentModelProviderEntry(
|
||||
config: ProviderConfig,
|
||||
modelId: string,
|
||||
): Promise<{
|
||||
baseUrl?: string;
|
||||
api?: string;
|
||||
models?: Array<{ id: string; name: string }>;
|
||||
apiKey?: string;
|
||||
authHeader?: boolean;
|
||||
} | null> {
|
||||
const meta = getProviderConfig(config.type);
|
||||
const api = config.apiProtocol || (config.type === 'custom' ? 'openai-completions' : meta?.api);
|
||||
const baseUrl = normalizeProviderBaseUrl(config, config.baseUrl || meta?.baseUrl, api);
|
||||
if (!api || !baseUrl) {
|
||||
return null;
|
||||
}
|
||||
|
||||
let apiKey: string | undefined;
|
||||
let authHeader: boolean | undefined;
|
||||
|
||||
if (config.type === 'custom') {
|
||||
apiKey = (await getApiKey(config.id)) || undefined;
|
||||
} else if (config.type === 'minimax-portal' || config.type === 'minimax-portal-cn') {
|
||||
const accountApiKey = await getApiKey(config.id);
|
||||
if (accountApiKey) {
|
||||
apiKey = accountApiKey;
|
||||
} else {
|
||||
authHeader = true;
|
||||
apiKey = 'minimax-oauth';
|
||||
}
|
||||
} else if (config.type === 'qwen-portal') {
|
||||
const accountApiKey = await getApiKey(config.id);
|
||||
if (accountApiKey) {
|
||||
apiKey = accountApiKey;
|
||||
} else {
|
||||
apiKey = 'qwen-oauth';
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
baseUrl,
|
||||
api,
|
||||
models: [{ id: modelId, name: modelId }],
|
||||
apiKey,
|
||||
authHeader,
|
||||
};
|
||||
}
|
||||
|
||||
async function syncAgentModelsToRuntime(agentIds?: Set<string>): Promise<void> {
|
||||
const snapshot = await listAgentsSnapshot();
|
||||
const runtimeProviderConfigs = await buildRuntimeProviderConfigMap();
|
||||
|
||||
const targets = snapshot.agents.filter((agent) => {
|
||||
if (!agent.modelRef) return false;
|
||||
if (!agentIds) return true;
|
||||
return agentIds.has(agent.id);
|
||||
});
|
||||
|
||||
for (const agent of targets) {
|
||||
const parsed = parseModelRef(agent.modelRef || '');
|
||||
if (!parsed) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const providerConfig = runtimeProviderConfigs.get(parsed.providerKey);
|
||||
if (!providerConfig) {
|
||||
logger.warn(
|
||||
`[provider-runtime] No provider account mapped to runtime key "${parsed.providerKey}" for agent "${agent.id}"`,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
const entry = await buildAgentModelProviderEntry(providerConfig, parsed.modelId);
|
||||
if (!entry) {
|
||||
continue;
|
||||
}
|
||||
|
||||
await updateSingleAgentModelProvider(agent.id, parsed.providerKey, entry);
|
||||
}
|
||||
}
|
||||
|
||||
export async function syncAgentModelOverrideToRuntime(agentId: string): Promise<void> {
|
||||
await syncAgentModelsToRuntime(new Set([agentId]));
|
||||
}
|
||||
|
||||
export async function syncSavedProviderToRuntime(
|
||||
config: ProviderConfig,
|
||||
apiKey: string | undefined,
|
||||
@@ -346,6 +458,12 @@ export async function syncSavedProviderToRuntime(
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
await syncAgentModelsToRuntime();
|
||||
} catch (err) {
|
||||
logger.warn('[provider-runtime] Failed to sync per-agent model registries after provider save:', err);
|
||||
}
|
||||
|
||||
scheduleGatewayRefresh(
|
||||
gatewayManager,
|
||||
`Scheduling Gateway reload after saving provider "${context.runtimeProviderKey}" config`,
|
||||
@@ -388,6 +506,12 @@ export async function syncUpdatedProviderToRuntime(
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
await syncAgentModelsToRuntime();
|
||||
} catch (err) {
|
||||
logger.warn('[provider-runtime] Failed to sync per-agent model registries after provider update:', err);
|
||||
}
|
||||
|
||||
scheduleGatewayRefresh(
|
||||
gatewayManager,
|
||||
`Scheduling Gateway reload after updating provider "${ock}" config`,
|
||||
@@ -496,6 +620,11 @@ export async function syncDefaultProviderToRuntime(
|
||||
|
||||
await setOpenClawDefaultModel(browserOAuthRuntimeProvider, modelOverride, fallbackModels);
|
||||
logger.info(`Configured openclaw.json for browser OAuth provider "${provider.id}"`);
|
||||
try {
|
||||
await syncAgentModelsToRuntime();
|
||||
} catch (err) {
|
||||
logger.warn('[provider-runtime] Failed to sync per-agent model registries after browser OAuth switch:', err);
|
||||
}
|
||||
scheduleGatewayRefresh(
|
||||
gatewayManager,
|
||||
`Scheduling Gateway reload after provider switch to "${browserOAuthRuntimeProvider}"`,
|
||||
@@ -557,6 +686,12 @@ export async function syncDefaultProviderToRuntime(
|
||||
});
|
||||
}
|
||||
|
||||
try {
|
||||
await syncAgentModelsToRuntime();
|
||||
} catch (err) {
|
||||
logger.warn('[provider-runtime] Failed to sync per-agent model registries after default provider switch:', err);
|
||||
}
|
||||
|
||||
scheduleGatewayRefresh(
|
||||
gatewayManager,
|
||||
`Scheduling Gateway reload after provider switch to "${ock}"`,
|
||||
|
||||
@@ -81,6 +81,8 @@ export interface AgentSummary {
|
||||
name: string;
|
||||
isDefault: boolean;
|
||||
modelDisplay: string;
|
||||
modelRef: string | null;
|
||||
overrideModelRef: string | null;
|
||||
inheritedModel: boolean;
|
||||
workspace: string;
|
||||
agentDir: string;
|
||||
@@ -91,29 +93,38 @@ export interface AgentSummary {
|
||||
export interface AgentsSnapshot {
|
||||
agents: AgentSummary[];
|
||||
defaultAgentId: string;
|
||||
defaultModelRef: string | null;
|
||||
configuredChannelTypes: string[];
|
||||
channelOwners: Record<string, string>;
|
||||
channelAccountOwners: Record<string, string>;
|
||||
}
|
||||
|
||||
function formatModelLabel(model: unknown): string | null {
|
||||
function resolveModelRef(model: unknown): string | null {
|
||||
if (typeof model === 'string' && model.trim()) {
|
||||
const trimmed = model.trim();
|
||||
const parts = trimmed.split('/');
|
||||
return parts[parts.length - 1] || trimmed;
|
||||
return model.trim();
|
||||
}
|
||||
|
||||
if (model && typeof model === 'object') {
|
||||
const primary = (model as AgentModelConfig).primary;
|
||||
if (typeof primary === 'string' && primary.trim()) {
|
||||
const parts = primary.trim().split('/');
|
||||
return parts[parts.length - 1] || primary.trim();
|
||||
return primary.trim();
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
function formatModelLabel(model: unknown): string | null {
|
||||
const modelRef = resolveModelRef(model);
|
||||
if (modelRef) {
|
||||
const trimmed = modelRef;
|
||||
const parts = trimmed.split('/');
|
||||
return parts[parts.length - 1] || trimmed;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
function normalizeAgentName(name: string): string {
|
||||
return name.trim() || 'Agent';
|
||||
}
|
||||
@@ -487,10 +498,13 @@ async function buildSnapshotFromConfig(config: AgentConfigDocument): Promise<Age
|
||||
channelOwners[channelType] = primaryOwner;
|
||||
}
|
||||
|
||||
const defaultModelLabel = formatModelLabel((config.agents as AgentsConfig | undefined)?.defaults?.model);
|
||||
const defaultModelConfig = (config.agents as AgentsConfig | undefined)?.defaults?.model;
|
||||
const defaultModelLabel = formatModelLabel(defaultModelConfig);
|
||||
const defaultModelRef = resolveModelRef(defaultModelConfig);
|
||||
const agents: AgentSummary[] = entries.map((entry) => {
|
||||
const explicitModelRef = resolveModelRef(entry.model);
|
||||
const modelLabel = formatModelLabel(entry.model) || defaultModelLabel || 'Not configured';
|
||||
const inheritedModel = !formatModelLabel(entry.model) && Boolean(defaultModelLabel);
|
||||
const inheritedModel = !explicitModelRef && Boolean(defaultModelLabel);
|
||||
const entryIdNorm = normalizeAgentIdForBinding(entry.id);
|
||||
const ownedChannels = agentChannelSets.get(entryIdNorm) ?? new Set<string>();
|
||||
return {
|
||||
@@ -498,6 +512,8 @@ async function buildSnapshotFromConfig(config: AgentConfigDocument): Promise<Age
|
||||
name: entry.name || (entry.id === MAIN_AGENT_ID ? MAIN_AGENT_NAME : entry.id),
|
||||
isDefault: entry.id === defaultAgentId,
|
||||
modelDisplay: modelLabel,
|
||||
modelRef: explicitModelRef || defaultModelRef || null,
|
||||
overrideModelRef: explicitModelRef,
|
||||
inheritedModel,
|
||||
workspace: entry.workspace || (entry.id === MAIN_AGENT_ID ? getDefaultWorkspacePath(config) : `~/.openclaw/workspace-${entry.id}`),
|
||||
agentDir: entry.agentDir || getDefaultAgentDirPath(entry.id),
|
||||
@@ -511,6 +527,7 @@ async function buildSnapshotFromConfig(config: AgentConfigDocument): Promise<Age
|
||||
return {
|
||||
agents,
|
||||
defaultAgentId,
|
||||
defaultModelRef,
|
||||
configuredChannelTypes: configuredChannels.map((channelType) => toUiChannelType(channelType)),
|
||||
channelOwners,
|
||||
channelAccountOwners,
|
||||
@@ -598,6 +615,44 @@ export async function updateAgentName(agentId: string, name: string): Promise<Ag
|
||||
});
|
||||
}
|
||||
|
||||
function isValidModelRef(modelRef: string): boolean {
|
||||
const firstSlash = modelRef.indexOf('/');
|
||||
return firstSlash > 0 && firstSlash < modelRef.length - 1;
|
||||
}
|
||||
|
||||
export async function updateAgentModel(agentId: string, modelRef: string | null): Promise<AgentsSnapshot> {
|
||||
return withConfigLock(async () => {
|
||||
const config = await readOpenClawConfig() as AgentConfigDocument;
|
||||
const { agentsConfig, entries } = normalizeAgentsConfig(config);
|
||||
const index = entries.findIndex((entry) => entry.id === agentId);
|
||||
if (index === -1) {
|
||||
throw new Error(`Agent "${agentId}" not found`);
|
||||
}
|
||||
|
||||
const normalizedModelRef = typeof modelRef === 'string' ? modelRef.trim() : '';
|
||||
const nextEntry: AgentListEntry = { ...entries[index] };
|
||||
|
||||
if (!normalizedModelRef) {
|
||||
delete nextEntry.model;
|
||||
} else {
|
||||
if (!isValidModelRef(normalizedModelRef)) {
|
||||
throw new Error('modelRef must be in "provider/model" format');
|
||||
}
|
||||
nextEntry.model = { primary: normalizedModelRef };
|
||||
}
|
||||
|
||||
entries[index] = nextEntry;
|
||||
config.agents = {
|
||||
...agentsConfig,
|
||||
list: entries,
|
||||
};
|
||||
|
||||
await writeOpenClawConfig(config);
|
||||
logger.info('Updated agent model', { agentId, modelRef: normalizedModelRef || null });
|
||||
return buildSnapshotFromConfig(config);
|
||||
});
|
||||
}
|
||||
|
||||
export async function deleteAgentConfig(agentId: string): Promise<{ snapshot: AgentsSnapshot; removedEntry: AgentListEntry }> {
|
||||
return withConfigLock(async () => {
|
||||
if (agentId === MAIN_AGENT_ID) {
|
||||
|
||||
@@ -918,18 +918,20 @@ export async function syncSessionIdleMinutesToOpenClaw(): Promise<void> {
|
||||
/**
|
||||
* Update a provider entry in every discovered agent's models.json.
|
||||
*/
|
||||
export async function updateAgentModelProvider(
|
||||
type AgentModelProviderEntry = {
|
||||
baseUrl?: string;
|
||||
api?: string;
|
||||
models?: Array<{ id: string; name: string }>;
|
||||
apiKey?: string;
|
||||
/** When true, pi-ai sends Authorization: Bearer instead of x-api-key */
|
||||
authHeader?: boolean;
|
||||
};
|
||||
|
||||
async function updateModelsJsonProviderEntriesForAgents(
|
||||
agentIds: string[],
|
||||
providerType: string,
|
||||
entry: {
|
||||
baseUrl?: string;
|
||||
api?: string;
|
||||
models?: Array<{ id: string; name: string }>;
|
||||
apiKey?: string;
|
||||
/** When true, pi-ai sends Authorization: Bearer instead of x-api-key */
|
||||
authHeader?: boolean;
|
||||
}
|
||||
entry: AgentModelProviderEntry,
|
||||
): Promise<void> {
|
||||
const agentIds = await discoverAgentIds();
|
||||
for (const agentId of agentIds) {
|
||||
const modelsPath = join(homedir(), '.openclaw', 'agents', agentId, 'agent', 'models.json');
|
||||
let data: Record<string, unknown> = {};
|
||||
@@ -975,6 +977,26 @@ export async function updateAgentModelProvider(
|
||||
}
|
||||
}
|
||||
|
||||
export async function updateAgentModelProvider(
|
||||
providerType: string,
|
||||
entry: AgentModelProviderEntry,
|
||||
): Promise<void> {
|
||||
const agentIds = await discoverAgentIds();
|
||||
await updateModelsJsonProviderEntriesForAgents(agentIds, providerType, entry);
|
||||
}
|
||||
|
||||
export async function updateSingleAgentModelProvider(
|
||||
agentId: string,
|
||||
providerType: string,
|
||||
entry: AgentModelProviderEntry,
|
||||
): Promise<void> {
|
||||
const normalizedAgentId = agentId.trim();
|
||||
if (!normalizedAgentId) {
|
||||
throw new Error('agentId is required');
|
||||
}
|
||||
await updateModelsJsonProviderEntriesForAgents([normalizedAgentId], providerType, entry);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sanitize ~/.openclaw/openclaw.json before Gateway start.
|
||||
*
|
||||
|
||||
Reference in New Issue
Block a user