Skip to content

Commit 60df06f

Browse files
committed
internal/mcp: add pagination for tools
This CL adds paginating functionality for tools. It uses the gob encoder for creating opaque cursors. More CLs to follow for paginating other features. Change-Id: I1443c0213ceb6238d844a8eee9a0be52934f5cab Reviewed-on: https://go-review.googlesource.com/c/tools/+/675055 Reviewed-by: Robert Findley <rfindley@google.com> Reviewed-by: Jonathan Amsterdam <jba@google.com> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
1 parent f5ea575 commit 60df06f

File tree

2 files changed

+253
-1
lines changed

2 files changed

+253
-1
lines changed

internal/mcp/server.go

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
package mcp
66

77
import (
8+
"bytes"
89
"context"
10+
"encoding/base64"
11+
"encoding/gob"
912
"fmt"
1013
"iter"
1114
"net/url"
@@ -16,6 +19,8 @@ import (
1619
jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2"
1720
)
1821

22+
const DefaultPageSize = 1000
23+
1924
// A Server is an instance of an MCP server.
2025
//
2126
// Servers expose server-side MCP features, which can serve one or more MCP
@@ -40,6 +45,9 @@ type ServerOptions struct {
4045
Instructions string
4146
// If non-nil, called when "notifications/intialized" is received.
4247
InitializedHandler func(context.Context, *ServerSession, *InitializedParams)
48+
// PageSize is the maximum number of items to return in a single page for
49+
// list methods (e.g. ListTools).
50+
PageSize int
4351
// If non-nil, called when "notifications/roots/list_changed" is received.
4452
RootsListChangedHandler func(context.Context, *ServerSession, *RootsListChangedParams)
4553
}
@@ -55,6 +63,12 @@ func NewServer(name, version string, opts *ServerOptions) *Server {
5563
if opts == nil {
5664
opts = new(ServerOptions)
5765
}
66+
if opts.PageSize < 0 {
67+
panic(fmt.Errorf("invalid page size %d", opts.PageSize))
68+
}
69+
if opts.PageSize == 0 {
70+
opts.PageSize = DefaultPageSize
71+
}
5872
return &Server{
5973
name: name,
6074
version: version,
@@ -186,8 +200,17 @@ func (s *Server) getPrompt(ctx context.Context, cc *ServerSession, params *GetPr
186200
func (s *Server) listTools(_ context.Context, _ *ServerSession, params *ListToolsParams) (*ListToolsResult, error) {
187201
s.mu.Lock()
188202
defer s.mu.Unlock()
203+
var cursor string
204+
if params != nil {
205+
cursor = params.Cursor
206+
}
207+
tools, nextCursor, err := paginateList(s.tools, cursor, s.opts.PageSize)
208+
if err != nil {
209+
return nil, err
210+
}
189211
res := new(ListToolsResult)
190-
for t := range s.tools.all() {
212+
res.NextCursor = nextCursor
213+
for _, t := range tools {
191214
res.Tools = append(res.Tools, t.Tool)
192215
}
193216
return res, nil
@@ -509,3 +532,71 @@ func (ss *ServerSession) Close() error {
509532
func (ss *ServerSession) Wait() error {
510533
return ss.conn.Wait()
511534
}
535+
536+
// pageToken is the internal structure for the opaque pagination cursor.
537+
// It will be Gob-encoded and then Base64-encoded for use as a string token.
538+
type pageToken struct {
539+
LastUID string // The unique ID of the last resource seen.
540+
}
541+
542+
// paginateList returns a slice of features from the given featureSet, based on
543+
// the provided cursor and page size. It also returns a new cursor for the next
544+
// page, or an empty string if there are no more pages.
545+
func paginateList[T any](fs *featureSet[T], cursor string, pageSize int) (features []T, nextCursor string, err error) {
546+
encodeCursor := func(uid string) (string, error) {
547+
var buf bytes.Buffer
548+
token := pageToken{LastUID: uid}
549+
encoder := gob.NewEncoder(&buf)
550+
if err := encoder.Encode(token); err != nil {
551+
return "", fmt.Errorf("failed to encode page token: %w", err)
552+
}
553+
return base64.URLEncoding.EncodeToString(buf.Bytes()), nil
554+
}
555+
556+
decodeCursor := func(cursor string) (*pageToken, error) {
557+
decodedBytes, err := base64.URLEncoding.DecodeString(cursor)
558+
if err != nil {
559+
return nil, fmt.Errorf("failed to decode cursor: %w", err)
560+
}
561+
562+
var token pageToken
563+
buf := bytes.NewBuffer(decodedBytes)
564+
decoder := gob.NewDecoder(buf)
565+
if err := decoder.Decode(&token); err != nil {
566+
return nil, fmt.Errorf("failed to decode page token: %w, cursor: %v", err, cursor)
567+
}
568+
return &token, nil
569+
}
570+
571+
var seq iter.Seq[T]
572+
if cursor == "" {
573+
seq = fs.all()
574+
} else {
575+
pageToken, err := decodeCursor(cursor)
576+
// According to the spec, invalid cursors should return Invalid params.
577+
if err != nil {
578+
return nil, "", jsonrpc2.ErrInvalidParams
579+
}
580+
seq = fs.above(pageToken.LastUID)
581+
}
582+
var count int
583+
for f := range seq {
584+
count++
585+
// If we've seen pageSize + 1 elements, we've gathered enough info to determine
586+
// if there's a next page. Stop processing the sequence.
587+
if count == pageSize+1 {
588+
break
589+
}
590+
features = append(features, f)
591+
}
592+
// No remaining pages.
593+
if count < pageSize+1 {
594+
return features, "", nil
595+
}
596+
// Trim the extra element from the result.
597+
nextCursor, err = encodeCursor(fs.uniqueID(features[len(features)-1]))
598+
if err != nil {
599+
return nil, "", err
600+
}
601+
return features, nextCursor, nil
602+
}

internal/mcp/server_example_test.go

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,13 @@ import (
88
"context"
99
"fmt"
1010
"log"
11+
"slices"
12+
"testing"
1113

14+
"github.com/google/go-cmp/cmp"
15+
"github.com/google/go-cmp/cmp/cmpopts"
1216
"golang.org/x/tools/internal/mcp"
17+
"golang.org/x/tools/internal/mcp/jsonschema"
1318
)
1419

1520
type SayHiParams struct {
@@ -51,3 +56,159 @@ func ExampleServer() {
5156

5257
// Output: Hi user
5358
}
59+
60+
func TestListTool(t *testing.T) {
61+
toolA := mcp.NewTool("apple", "apple tool", SayHi)
62+
toolB := mcp.NewTool("banana", "banana tool", SayHi)
63+
toolC := mcp.NewTool("cherry", "cherry tool", SayHi)
64+
testCases := []struct {
65+
tools []*mcp.ServerTool
66+
want []*mcp.Tool
67+
pageSize int
68+
}{
69+
{
70+
// Simple test.
71+
[]*mcp.ServerTool{toolA, toolB, toolC},
72+
[]*mcp.Tool{toolA.Tool, toolB.Tool, toolC.Tool},
73+
mcp.DefaultPageSize,
74+
},
75+
{
76+
// Tools should be ordered by tool name.
77+
[]*mcp.ServerTool{toolC, toolA, toolB},
78+
[]*mcp.Tool{toolA.Tool, toolB.Tool, toolC.Tool},
79+
mcp.DefaultPageSize,
80+
},
81+
{
82+
// Page size of 1 should yield the first tool only.
83+
[]*mcp.ServerTool{toolC, toolA, toolB},
84+
[]*mcp.Tool{toolA.Tool},
85+
1,
86+
},
87+
{
88+
// Page size of 2 should yield the first 2 tools only.
89+
[]*mcp.ServerTool{toolC, toolA, toolB},
90+
[]*mcp.Tool{toolA.Tool, toolB.Tool},
91+
2,
92+
},
93+
{
94+
// Page size of 3 should yield all tools.
95+
[]*mcp.ServerTool{toolC, toolA, toolB},
96+
[]*mcp.Tool{toolA.Tool, toolB.Tool, toolC.Tool},
97+
3,
98+
},
99+
{
100+
[]*mcp.ServerTool{},
101+
nil,
102+
1,
103+
},
104+
}
105+
ctx := context.Background()
106+
for _, tc := range testCases {
107+
server := mcp.NewServer("server", "v0.0.1", &mcp.ServerOptions{PageSize: tc.pageSize})
108+
server.AddTools(tc.tools...)
109+
clientTransport, serverTransport := mcp.NewInMemoryTransports()
110+
serverSession, err := server.Connect(ctx, serverTransport)
111+
if err != nil {
112+
log.Fatal(err)
113+
}
114+
client := mcp.NewClient("client", "v0.0.1", nil)
115+
clientSession, err := client.Connect(ctx, clientTransport)
116+
if err != nil {
117+
log.Fatal(err)
118+
}
119+
res, err := clientSession.ListTools(ctx, nil)
120+
serverSession.Close()
121+
clientSession.Close()
122+
if err != nil {
123+
log.Fatal(err)
124+
}
125+
if len(res.Tools) != len(tc.want) {
126+
t.Fatalf("expected %d tools, got %d", len(tc.want), len(res.Tools))
127+
}
128+
if diff := cmp.Diff(res.Tools, tc.want, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" {
129+
t.Fatalf("expected tools %+v, got %+v", tc.want, res.Tools)
130+
}
131+
if tc.pageSize < len(tc.tools) && res.NextCursor == "" {
132+
t.Fatalf("expected next cursor, got none")
133+
}
134+
}
135+
}
136+
137+
func TestListToolPaginateInvalidCursor(t *testing.T) {
138+
toolA := mcp.NewTool("apple", "apple tool", SayHi)
139+
ctx := context.Background()
140+
server := mcp.NewServer("server", "v0.0.1", nil)
141+
server.AddTools(toolA)
142+
clientTransport, serverTransport := mcp.NewInMemoryTransports()
143+
serverSession, err := server.Connect(ctx, serverTransport)
144+
if err != nil {
145+
log.Fatal(err)
146+
}
147+
client := mcp.NewClient("client", "v0.0.1", nil)
148+
clientSession, err := client.Connect(ctx, clientTransport)
149+
if err != nil {
150+
log.Fatal(err)
151+
}
152+
_, err = clientSession.ListTools(ctx, &mcp.ListToolsParams{Cursor: "invalid"})
153+
if err == nil {
154+
t.Fatalf("expected error, got none")
155+
}
156+
serverSession.Close()
157+
clientSession.Close()
158+
}
159+
160+
func TestListToolPaginate(t *testing.T) {
161+
serverTools := []*mcp.ServerTool{
162+
mcp.NewTool("apple", "apple tool", SayHi),
163+
mcp.NewTool("banana", "banana tool", SayHi),
164+
mcp.NewTool("cherry", "cherry tool", SayHi),
165+
mcp.NewTool("durian", "durian tool", SayHi),
166+
mcp.NewTool("elderberry", "elderberry tool", SayHi),
167+
}
168+
var wantTools []*mcp.Tool
169+
for _, tool := range serverTools {
170+
wantTools = append(wantTools, tool.Tool)
171+
}
172+
ctx := context.Background()
173+
// Try all possible page sizes, ensuring we get the correct list of tools.
174+
for pageSize := 1; pageSize < len(serverTools)+1; pageSize++ {
175+
server := mcp.NewServer("server", "v0.0.1", &mcp.ServerOptions{PageSize: pageSize})
176+
server.AddTools(serverTools...)
177+
clientTransport, serverTransport := mcp.NewInMemoryTransports()
178+
serverSession, err := server.Connect(ctx, serverTransport)
179+
if err != nil {
180+
log.Fatal(err)
181+
}
182+
client := mcp.NewClient("client", "v0.0.1", nil)
183+
clientSession, err := client.Connect(ctx, clientTransport)
184+
if err != nil {
185+
log.Fatal(err)
186+
}
187+
var gotTools []*mcp.Tool
188+
var nextCursor string
189+
wantChunks := slices.Collect(slices.Chunk(wantTools, pageSize))
190+
index := 0
191+
// Iterate through all pages, comparing sub-slices to the paginated list.
192+
for {
193+
res, err := clientSession.ListTools(ctx, &mcp.ListToolsParams{Cursor: nextCursor})
194+
if err != nil {
195+
log.Fatal(err)
196+
}
197+
gotTools = append(gotTools, res.Tools...)
198+
nextCursor = res.NextCursor
199+
if diff := cmp.Diff(res.Tools, wantChunks[index], cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" {
200+
t.Errorf("expected %v, got %v, (-want +got):\n%s", wantChunks[index], res.Tools, diff)
201+
}
202+
if res.NextCursor == "" {
203+
break
204+
}
205+
index++
206+
}
207+
serverSession.Close()
208+
clientSession.Close()
209+
210+
if len(gotTools) != len(wantTools) {
211+
t.Fatalf("expected %d tools, got %d", len(wantTools), len(gotTools))
212+
}
213+
}
214+
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy