SpringAi自定义Mcp客户端选择工具调用

众所周知，在 Spring AI 中，MCP Client 默认的接入方式是在 application.yml 里配置 MCP 服务列表，之后每次对话都会把所有 MCP 服务的工具一并加载进来，这种方式既慢又不够灵活。有没有办法在每次对话时自己选择要用哪些 MCP 服务呢？有的兄弟，有的！

首先删掉原来在 application.yml 中配置的 MCP Client，然后新建一个 MCP 列表配置类：

/**
 * Declares one {@code List<ToolCallback>} bean per MCP server so that each conversation can pick
 * the tool set it needs, instead of every configured MCP server being wired in globally.
 */
@Configuration
public class McpListConfig {

    /**
     * Tools exposed by the Baidu Map MCP server, reached over SSE.
     *
     * <p>Marked {@code @Primary} because every bean method in this class returns
     * {@code List<ToolCallback>}; without it a by-type injection of that list would be ambiguous.
     *
     * @return the tool callbacks advertised by the server after the MCP handshake
     */
    @Bean
    @Primary
    public List<ToolCallback> baiduMap() {
        McpClientTransport transport = HttpClientSseClientTransport
                .builder("https://mcp.map.baidu.com")
                // The "ak" query parameter is left empty here; fill in your own Baidu Map key.
                .sseEndpoint("/sse?ak=")
                .objectMapper(new ObjectMapper())
                .build();
        return connectAndListTools(transport);
    }

    /**
     * Tools exposed by a locally launched MCP server, reached over stdio.
     *
     * @return the tool callbacks advertised by the server after the MCP handshake
     */
    @Bean
    public List<ToolCallback> zhiPuAi() {
        ServerParameters serverParameters = new ServerParameters.Builder("cmd")
                // Placeholder launch arguments: replace with the real command line of your server.
                .arg("/")
                // An empty map is ignored by ServerParameters.Builder#env; add entries here if the
                // server needs environment variables (e.g. an API key).
                .env(new HashMap<>())
                .build();
        McpClientTransport transport = new StdioClientTransport(serverParameters);
        return connectAndListTools(transport);
    }

    /**
     * Builds an async MCP client on {@code transport}, performs the initialize handshake (bounded by
     * a 30 second initialization timeout) and converts the server's tools into Spring AI callbacks.
     *
     * <p>If the handshake or the tool conversion fails, the client is closed before the exception is
     * rethrown so that its transport is not left open by a bean that never finished creating.
     *
     * <p>NOTE(review): on success the client stays open for the lifetime of the application and
     * nothing here closes it on shutdown; exposing the client as its own bean (so Spring can call
     * {@code close()} on it) would fix that — confirm it does not interfere with any MCP
     * auto-configuration still on the classpath.
     *
     * @param transport the SSE or stdio transport to connect through
     * @return the server's tools as {@link ToolCallback}s
     */
    private static List<ToolCallback> connectAndListTools(McpClientTransport transport) {
        McpAsyncClient client = McpClient.async(transport)
                .initializationTimeout(Duration.ofSeconds(30))
                .build();
        try {
            client.initialize().block();
            return McpToolUtils.getToolCallbacksFromAsyncClients(client);
        } catch (RuntimeException e) {
            client.close();
            throw e;
        }
    }
}

这里演示了两种引入方式，分别是 SSE 模式和 stdio 模式；具体的构建参数可以参考文章末尾的源码。其实核心就是调用工具类 McpToolUtils，把 McpAsyncClient 转换为 List<ToolCallback>。

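为什么要在第一个方法上加 @Primary 注解呢？因为这两个 Bean 的返回类型都是 List<ToolCallback>，按类型注入时 Spring 无法判断该用哪一个。加上 @Primary 后，按类型注入会优先选用 baiduMap 这个 Bean。注意，@Primary 的作用是指定首选 Bean，而不是"默认加载第一个"。详细请阅读《Spring中强大的@Bean注解》。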

然后再在对话中导入你需要用到的mcp:

private final AiConfig aiConfig;
// The MCP tool-set configuration must be injected too; the call below goes through it.
private final McpListConfig mcpListConfig;

// Choose the MCP server(s) for this conversation by calling the matching bean method.
// McpListConfig is a @Configuration class, so calling a @Bean method on the injected instance
// returns the existing singleton bean instead of opening a new MCP connection per call.
List<ToolCallback> toolCallbacksFromAsyncClients = mcpListConfig.baiduMap();

// The resulting stream is lazy: return it (or subscribe to it) from the enclosing method,
// otherwise the request is never sent.
ChatClient.builder(model).defaultSystem(role)
        .defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).build())
        // Register only the tools selected above for this conversation.
        .defaultToolCallbacks(toolCallbacksFromAsyncClients)
        .build()
        .prompt()
        .user(messageDto.getMessage())
        .advisors(advisorSpec -> advisorSpec.param(CONVERSATION_ID, messageDto.getChatId()))
        .stream()
        .content();

HttpClientSseClientTransport 的构建器（Builder）源码：

/**
 * Fluent builder for {@code HttpClientSseClientTransport}.
 *
 * <p>Defaults: SSE endpoint {@code "/sse"}, an HTTP/1.1 {@link HttpClient} with a 10 second connect
 * timeout, a fresh {@code ObjectMapper}, and a request builder pre-set with a JSON
 * {@code Content-Type} header.
 */
public static class Builder {
    private String baseUri;
    private String sseEndpoint = "/sse";
    private HttpClient.Builder clientBuilder =
            HttpClient.newBuilder().version(Version.HTTP_1_1).connectTimeout(Duration.ofSeconds(10L));
    private ObjectMapper objectMapper = new ObjectMapper();
    private HttpRequest.Builder requestBuilder =
            HttpRequest.newBuilder().header("Content-Type", "application/json");

    /** Creates a builder holding only the defaults; the base URI is supplied via {@code baseUri}. */
    Builder() {
        // Every default is provided by the field initializers above.
    }

    /**
     * Creates a builder with the defaults and the given base URI.
     *
     * @param baseUri server base URI; must contain text
     * @deprecated prefer the static {@code HttpClientSseClientTransport.builder(String)} factory.
     */
    @Deprecated(
            forRemoval = true
    )
    public Builder(String baseUri) {
        Assert.hasText(baseUri, "baseUri must not be empty");
        this.baseUri = baseUri;
    }

    /** Sets the server base URI; it must contain text. */
    Builder baseUri(String baseUri) {
        Assert.hasText(baseUri, "baseUri must not be empty");
        this.baseUri = baseUri;
        return this;
    }

    /** Sets the SSE endpoint path appended to the base URI; it must contain text. */
    public Builder sseEndpoint(String sseEndpoint) {
        Assert.hasText(sseEndpoint, "sseEndpoint must not be empty");
        this.sseEndpoint = sseEndpoint;
        return this;
    }

    /** Replaces the whole {@link HttpClient.Builder}, discarding the defaults. */
    public Builder clientBuilder(HttpClient.Builder clientBuilder) {
        Assert.notNull(clientBuilder, "clientBuilder must not be null");
        this.clientBuilder = clientBuilder;
        return this;
    }

    /** Applies {@code clientCustomizer} to the current {@link HttpClient.Builder} in place. */
    public Builder customizeClient(final Consumer<HttpClient.Builder> clientCustomizer) {
        Assert.notNull(clientCustomizer, "clientCustomizer must not be null");
        clientCustomizer.accept(this.clientBuilder);
        return this;
    }

    /** Replaces the whole {@link HttpRequest.Builder}, discarding the default JSON header. */
    public Builder requestBuilder(HttpRequest.Builder requestBuilder) {
        Assert.notNull(requestBuilder, "requestBuilder must not be null");
        this.requestBuilder = requestBuilder;
        return this;
    }

    /** Applies {@code requestCustomizer} to the current {@link HttpRequest.Builder} in place. */
    public Builder customizeRequest(final Consumer<HttpRequest.Builder> requestCustomizer) {
        Assert.notNull(requestCustomizer, "requestCustomizer must not be null");
        requestCustomizer.accept(this.requestBuilder);
        return this;
    }

    /** Sets the JSON mapper used by the transport; must not be null. */
    public Builder objectMapper(ObjectMapper objectMapper) {
        Assert.notNull(objectMapper, "objectMapper must not be null");
        this.objectMapper = objectMapper;
        return this;
    }

    /** Builds the transport from the current settings, creating the {@link HttpClient} now. */
    public HttpClientSseClientTransport build() {
        HttpClient httpClient = this.clientBuilder.build();
        return new HttpClientSseClientTransport(
                httpClient, this.requestBuilder, this.baseUri, this.sseEndpoint, this.objectMapper);
    }
}

ServerParameters 的构建器（Builder）源码：

/**
 * Fluent builder for {@code ServerParameters}: the command, arguments and environment used to
 * launch an MCP server process.
 */
public static class Builder {
    private String command;
    private List<String> args = new ArrayList<>();
    private Map<String, String> env = new HashMap<>();

    /**
     * Starts a builder for the given executable.
     *
     * @param command the command to run; must not be null
     */
    public Builder(String command) {
        Assert.notNull(command, "The command can not be null");
        this.command = command;
    }

    /**
     * Replaces the argument list with {@code args}.
     *
     * <p>The values are copied into a new mutable list. {@code Arrays.asList} alone would return a
     * fixed-size view of the caller's array, so a later {@link #arg(String)} would throw
     * {@code UnsupportedOperationException}, and later writes to that array would leak into the builder.
     *
     * @param args the arguments; the array must not be null
     * @return this builder
     */
    public Builder args(String... args) {
        Assert.notNull(args, "The args can not be null");
        this.args = new ArrayList<>(Arrays.asList(args));
        return this;
    }

    /**
     * Replaces the argument list with a copy of {@code args}.
     *
     * @param args the arguments; must not be null
     * @return this builder
     */
    public Builder args(List<String> args) {
        Assert.notNull(args, "The args can not be null");
        this.args = new ArrayList<>(args);
        return this;
    }

    /**
     * Appends a single argument.
     *
     * @param arg the argument; must not be null
     * @return this builder
     */
    public Builder arg(String arg) {
        Assert.notNull(arg, "The arg can not be null");
        this.args.add(arg);
        return this;
    }

    /**
     * Merges {@code env} into the environment; a null or empty map is ignored.
     *
     * @param env variables to add
     * @return this builder
     */
    public Builder env(Map<String, String> env) {
        if (env != null && !env.isEmpty()) {
            this.env.putAll(env);
        }

        return this;
    }

    /**
     * Adds or replaces one environment variable.
     *
     * @param key variable name; must not be null
     * @param value variable value; must not be null
     * @return this builder
     */
    public Builder addEnvVar(String key, String value) {
        Assert.notNull(key, "The key can not be null");
        Assert.notNull(value, "The value can not be null");
        this.env.put(key, value);
        return this;
    }

    /** Creates the {@code ServerParameters} from the current command, arguments and environment. */
    public ServerParameters build() {
        return new ServerParameters(this.command, this.args, this.env);
    }
}