feat(agent-model): add per-agent model override with default-reset UX and runtime sync (#651)
This commit is contained in:
@@ -12,8 +12,10 @@ import {
|
||||
setOpenClawDefaultModelWithOverride,
|
||||
syncProviderConfigToOpenClaw,
|
||||
updateAgentModelProvider,
|
||||
updateSingleAgentModelProvider,
|
||||
} from '../../utils/openclaw-auth';
|
||||
import { logger } from '../../utils/logger';
|
||||
import { listAgentsSnapshot } from '../../utils/agent-config';
|
||||
|
||||
const GOOGLE_OAUTH_RUNTIME_PROVIDER = 'google-gemini-cli';
|
||||
const GOOGLE_OAUTH_DEFAULT_MODEL_REF = `${GOOGLE_OAUTH_RUNTIME_PROVIDER}/gemini-3-pro-preview`;
|
||||
@@ -336,6 +338,116 @@ async function syncProviderToRuntime(
|
||||
return context;
|
||||
}
|
||||
|
||||
function parseModelRef(modelRef: string): { providerKey: string; modelId: string } | null {
|
||||
const trimmed = modelRef.trim();
|
||||
const separatorIndex = trimmed.indexOf('/');
|
||||
if (separatorIndex <= 0 || separatorIndex >= trimmed.length - 1) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
providerKey: trimmed.slice(0, separatorIndex),
|
||||
modelId: trimmed.slice(separatorIndex + 1),
|
||||
};
|
||||
}
|
||||
|
||||
async function buildRuntimeProviderConfigMap(): Promise<Map<string, ProviderConfig>> {
|
||||
const configs = await getAllProviders();
|
||||
const runtimeMap = new Map<string, ProviderConfig>();
|
||||
|
||||
for (const config of configs) {
|
||||
const runtimeKey = await resolveRuntimeProviderKey(config);
|
||||
runtimeMap.set(runtimeKey, config);
|
||||
}
|
||||
|
||||
return runtimeMap;
|
||||
}
|
||||
|
||||
async function buildAgentModelProviderEntry(
|
||||
config: ProviderConfig,
|
||||
modelId: string,
|
||||
): Promise<{
|
||||
baseUrl?: string;
|
||||
api?: string;
|
||||
models?: Array<{ id: string; name: string }>;
|
||||
apiKey?: string;
|
||||
authHeader?: boolean;
|
||||
} | null> {
|
||||
const meta = getProviderConfig(config.type);
|
||||
const api = config.apiProtocol || (config.type === 'custom' ? 'openai-completions' : meta?.api);
|
||||
const baseUrl = normalizeProviderBaseUrl(config, config.baseUrl || meta?.baseUrl, api);
|
||||
if (!api || !baseUrl) {
|
||||
return null;
|
||||
}
|
||||
|
||||
let apiKey: string | undefined;
|
||||
let authHeader: boolean | undefined;
|
||||
|
||||
if (config.type === 'custom') {
|
||||
apiKey = (await getApiKey(config.id)) || undefined;
|
||||
} else if (config.type === 'minimax-portal' || config.type === 'minimax-portal-cn') {
|
||||
const accountApiKey = await getApiKey(config.id);
|
||||
if (accountApiKey) {
|
||||
apiKey = accountApiKey;
|
||||
} else {
|
||||
authHeader = true;
|
||||
apiKey = 'minimax-oauth';
|
||||
}
|
||||
} else if (config.type === 'qwen-portal') {
|
||||
const accountApiKey = await getApiKey(config.id);
|
||||
if (accountApiKey) {
|
||||
apiKey = accountApiKey;
|
||||
} else {
|
||||
apiKey = 'qwen-oauth';
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
baseUrl,
|
||||
api,
|
||||
models: [{ id: modelId, name: modelId }],
|
||||
apiKey,
|
||||
authHeader,
|
||||
};
|
||||
}
|
||||
|
||||
async function syncAgentModelsToRuntime(agentIds?: Set<string>): Promise<void> {
|
||||
const snapshot = await listAgentsSnapshot();
|
||||
const runtimeProviderConfigs = await buildRuntimeProviderConfigMap();
|
||||
|
||||
const targets = snapshot.agents.filter((agent) => {
|
||||
if (!agent.modelRef) return false;
|
||||
if (!agentIds) return true;
|
||||
return agentIds.has(agent.id);
|
||||
});
|
||||
|
||||
for (const agent of targets) {
|
||||
const parsed = parseModelRef(agent.modelRef || '');
|
||||
if (!parsed) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const providerConfig = runtimeProviderConfigs.get(parsed.providerKey);
|
||||
if (!providerConfig) {
|
||||
logger.warn(
|
||||
`[provider-runtime] No provider account mapped to runtime key "${parsed.providerKey}" for agent "${agent.id}"`,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
const entry = await buildAgentModelProviderEntry(providerConfig, parsed.modelId);
|
||||
if (!entry) {
|
||||
continue;
|
||||
}
|
||||
|
||||
await updateSingleAgentModelProvider(agent.id, parsed.providerKey, entry);
|
||||
}
|
||||
}
|
||||
|
||||
export async function syncAgentModelOverrideToRuntime(agentId: string): Promise<void> {
|
||||
await syncAgentModelsToRuntime(new Set([agentId]));
|
||||
}
|
||||
|
||||
export async function syncSavedProviderToRuntime(
|
||||
config: ProviderConfig,
|
||||
apiKey: string | undefined,
|
||||
@@ -346,6 +458,12 @@ export async function syncSavedProviderToRuntime(
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
await syncAgentModelsToRuntime();
|
||||
} catch (err) {
|
||||
logger.warn('[provider-runtime] Failed to sync per-agent model registries after provider save:', err);
|
||||
}
|
||||
|
||||
scheduleGatewayRefresh(
|
||||
gatewayManager,
|
||||
`Scheduling Gateway reload after saving provider "${context.runtimeProviderKey}" config`,
|
||||
@@ -388,6 +506,12 @@ export async function syncUpdatedProviderToRuntime(
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
await syncAgentModelsToRuntime();
|
||||
} catch (err) {
|
||||
logger.warn('[provider-runtime] Failed to sync per-agent model registries after provider update:', err);
|
||||
}
|
||||
|
||||
scheduleGatewayRefresh(
|
||||
gatewayManager,
|
||||
`Scheduling Gateway reload after updating provider "${ock}" config`,
|
||||
@@ -496,6 +620,11 @@ export async function syncDefaultProviderToRuntime(
|
||||
|
||||
await setOpenClawDefaultModel(browserOAuthRuntimeProvider, modelOverride, fallbackModels);
|
||||
logger.info(`Configured openclaw.json for browser OAuth provider "${provider.id}"`);
|
||||
try {
|
||||
await syncAgentModelsToRuntime();
|
||||
} catch (err) {
|
||||
logger.warn('[provider-runtime] Failed to sync per-agent model registries after browser OAuth switch:', err);
|
||||
}
|
||||
scheduleGatewayRefresh(
|
||||
gatewayManager,
|
||||
`Scheduling Gateway reload after provider switch to "${browserOAuthRuntimeProvider}"`,
|
||||
@@ -557,6 +686,12 @@ export async function syncDefaultProviderToRuntime(
|
||||
});
|
||||
}
|
||||
|
||||
try {
|
||||
await syncAgentModelsToRuntime();
|
||||
} catch (err) {
|
||||
logger.warn('[provider-runtime] Failed to sync per-agent model registries after default provider switch:', err);
|
||||
}
|
||||
|
||||
scheduleGatewayRefresh(
|
||||
gatewayManager,
|
||||
`Scheduling Gateway reload after provider switch to "${ock}"`,
|
||||
|
||||
Reference in New Issue
Block a user