From 848098ffec7ccdb43f7c10eab9ed2bd7b09cf2ff Mon Sep 17 00:00:00 2001 From: Henrique Dias Date: Mon, 15 May 2023 12:09:42 +0200 Subject: [PATCH 1/2] feat: indicate if response will be streamable on routing.FindProviders --- routing/http/client/client_test.go | 6 +++--- routing/http/server/server.go | 10 ++++++++-- routing/http/server/server_test.go | 16 ++++++++-------- 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/routing/http/client/client_test.go b/routing/http/client/client_test.go index 05ad997af..7551350d3 100644 --- a/routing/http/client/client_test.go +++ b/routing/http/client/client_test.go @@ -27,8 +27,8 @@ import ( type mockContentRouter struct{ mock.Mock } -func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.ProviderResponse], error) { - args := m.Called(ctx, key) +func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid, stream bool) (iter.ResultIter[types.ProviderResponse], error) { + args := m.Called(ctx, key, stream) return args.Get(0).(iter.ResultIter[types.ProviderResponse]), args.Error(1) } func (m *mockContentRouter) ProvideBitswap(ctx context.Context, req *server.BitswapWriteProvideRequest) (time.Duration, error) { @@ -302,7 +302,7 @@ func TestClient_FindProviders(t *testing.T) { findProvsIter := iter.FromSlice(c.routerProvs) - router.On("FindProviders", mock.Anything, cid). + router.On("FindProviders", mock.Anything, cid, c.expStreamingResponse). Return(findProvsIter, c.routerErr) provsIter, err := client.FindProviders(ctx, cid) diff --git a/routing/http/server/server.go b/routing/http/server/server.go index d2c7f5221..1dd29277a 100644 --- a/routing/http/server/server.go +++ b/routing/http/server/server.go @@ -41,7 +41,9 @@ type FindProvidersAsyncResponse struct { } type ContentRouter interface { - FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.ProviderResponse], error) + // FindProviders searches for peers who are able to provide a given key. Stream + // indicates whether or not this request will be responded as a stream. + FindProviders(ctx context.Context, key cid.Cid, stream bool) (iter.ResultIter[types.ProviderResponse], error) ProvideBitswap(ctx context.Context, req *BitswapWriteProvideRequest) (time.Duration, error) Provide(ctx context.Context, req *WriteProvideRequest) (types.ProviderResponse, error) } @@ -170,9 +172,11 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { var supportsNDJSON bool var supportsJSON bool + var streaming bool acceptHeaders := httpReq.Header.Values("Accept") if len(acceptHeaders) == 0 { handlerFunc = s.findProvidersJSON + streaming = false } else { for _, acceptHeader := range acceptHeaders { for _, accept := range strings.Split(acceptHeader, ",") { @@ -193,15 +197,17 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { if supportsNDJSON && !s.disableNDJSON { handlerFunc = s.findProvidersNDJSON + streaming = true } else if supportsJSON { handlerFunc = s.findProvidersJSON + streaming = false } else { writeErr(w, "FindProviders", http.StatusBadRequest, errors.New("no supported content types")) return } } - provIter, err := s.svc.FindProviders(httpReq.Context(), cid) + provIter, err := s.svc.FindProviders(httpReq.Context(), cid, streaming) if err != nil { writeErr(w, "FindProviders", http.StatusInternalServerError, fmt.Errorf("delegate error: %w", err)) return diff --git a/routing/http/server/server_test.go b/routing/http/server/server_test.go index 6e7d4ba9f..acab26f18 100644 --- a/routing/http/server/server_test.go +++ b/routing/http/server/server_test.go @@ -33,7 +33,7 @@ func TestHeaders(t *testing.T) { cb, err := cid.Decode(c) require.NoError(t, err) - router.On("FindProviders", mock.Anything, cb). + router.On("FindProviders", mock.Anything, cb, false). Return(results, nil) resp, err := http.Get(serverAddr + ProvidePath + c) @@ -63,7 +63,7 @@ func TestResponse(t *testing.T) { cid, err := cid.Decode(cidStr) require.NoError(t, err) - runTest := func(t *testing.T, contentType string, expected string) { + runTest := func(t *testing.T, contentType string, expectedStream bool, expectedBody string) { t.Parallel() results := iter.FromSlice([]iter.Result[types.ProviderResponse]{ @@ -85,7 +85,7 @@ func TestResponse(t *testing.T) { server := httptest.NewServer(Handler(router)) t.Cleanup(server.Close) serverAddr := "http://" + server.Listener.Addr().String() - router.On("FindProviders", mock.Anything, cid).Return(results, nil) + router.On("FindProviders", mock.Anything, cid, expectedStream).Return(results, nil) urlStr := serverAddr + ProvidePath + cidStr req, err := http.NewRequest(http.MethodGet, urlStr, nil) @@ -101,22 +101,22 @@ func TestResponse(t *testing.T) { body, err := io.ReadAll(resp.Body) require.NoError(t, err) - require.Equal(t, string(body), expected) + require.Equal(t, string(body), expectedBody) } t.Run("JSON Response", func(t *testing.T) { - runTest(t, mediaTypeJSON, `{"Providers":[{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Addrs":[]},{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz","Addrs":[]}]}`) + runTest(t, mediaTypeJSON, false, `{"Providers":[{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Addrs":[]},{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz","Addrs":[]}]}`) }) t.Run("NDJSON Response", func(t *testing.T) { - runTest(t, mediaTypeNDJSON, `{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Addrs":[]}`+"\n"+`{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz","Addrs":[]}`+"\n") + runTest(t, mediaTypeNDJSON, true, `{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Addrs":[]}`+"\n"+`{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz","Addrs":[]}`+"\n") }) } type mockContentRouter struct{ mock.Mock } -func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.ProviderResponse], error) { - args := m.Called(ctx, key) +func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid, stream bool) (iter.ResultIter[types.ProviderResponse], error) { + args := m.Called(ctx, key, stream) return args.Get(0).(iter.ResultIter[types.ProviderResponse]), args.Error(1) } func (m *mockContentRouter) ProvideBitswap(ctx context.Context, req *BitswapWriteProvideRequest) (time.Duration, error) { From 06f2d96c75c8f1c86a8db46112fb2af10d97bf53 Mon Sep 17 00:00:00 2001 From: Henrique Dias Date: Wed, 24 May 2023 12:12:41 +0200 Subject: [PATCH 2/2] refactor: change FindProviders to use "limit int" instead of "stream bool" --- routing/http/client/client_test.go | 11 +++++--- routing/http/server/server.go | 45 ++++++++++++++++++++++-------- routing/http/server/server_test.go | 12 +++++--- 3 files changed, 49 insertions(+), 19 deletions(-) diff --git a/routing/http/client/client_test.go b/routing/http/client/client_test.go index 7551350d3..880fa33e1 100644 --- a/routing/http/client/client_test.go +++ b/routing/http/client/client_test.go @@ -27,8 +27,8 @@ import ( type mockContentRouter struct{ mock.Mock } -func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid, stream bool) (iter.ResultIter[types.ProviderResponse], error) { - args := m.Called(ctx, key, stream) +func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid, limit int) (iter.ResultIter[types.ProviderResponse], error) { + args := m.Called(ctx, key, limit) return args.Get(0).(iter.ResultIter[types.ProviderResponse]), args.Error(1) } func (m *mockContentRouter) ProvideBitswap(ctx context.Context, req *server.BitswapWriteProvideRequest) (time.Duration, error) { @@ -302,8 +302,11 @@ func TestClient_FindProviders(t *testing.T) { findProvsIter := iter.FromSlice(c.routerProvs) - router.On("FindProviders", mock.Anything, cid, c.expStreamingResponse). - Return(findProvsIter, c.routerErr) + if c.expStreamingResponse { + router.On("FindProviders", mock.Anything, cid, 0).Return(findProvsIter, c.routerErr) + } else { + router.On("FindProviders", mock.Anything, cid, 20).Return(findProvsIter, c.routerErr) + } provsIter, err := client.FindProviders(ctx, cid) diff --git a/routing/http/server/server.go b/routing/http/server/server.go index 1dd29277a..47c075f0a 100644 --- a/routing/http/server/server.go +++ b/routing/http/server/server.go @@ -28,6 +28,9 @@ const ( mediaTypeJSON = "application/json" mediaTypeNDJSON = "application/x-ndjson" mediaTypeWildcard = "*/*" + + DefaultRecordsLimit = 20 + DefaultStreamingRecordsLimit = 0 ) var logger = logging.Logger("service/server/delegatedrouting") @@ -41,9 +44,9 @@ type FindProvidersAsyncResponse struct { } type ContentRouter interface { - // FindProviders searches for peers who are able to provide a given key. Stream - // indicates whether or not this request will be responded as a stream. - FindProviders(ctx context.Context, key cid.Cid, stream bool) (iter.ResultIter[types.ProviderResponse], error) + // FindProviders searches for peers who are able to provide a given key. Limit + // indicates the maximum amount of results to return. 0 means unbounded. + FindProviders(ctx context.Context, key cid.Cid, limit int) (iter.ResultIter[types.ProviderResponse], error) ProvideBitswap(ctx context.Context, req *BitswapWriteProvideRequest) (time.Duration, error) Provide(ctx context.Context, req *WriteProvideRequest) (types.ProviderResponse, error) } @@ -71,9 +74,27 @@ func WithStreamingResultsDisabled() Option { } } +// WithRecordsLimit sets a limit that will be passed to ContentRouter.FindProviders +// for non-streaming requests (application/json). Default is DefaultRecordsLimit. +func WithRecordsLimit(limit int) Option { + return func(s *server) { + s.recordsLimit = limit + } +} + +// WithStreamingRecordsLimit sets a limit that will be passed to ContentRouter.FindProviders +// for streaming requests (application/x-ndjson). Default is DefaultStreamingRecordsLimit. +func WithStreamingRecordsLimit(limit int) Option { + return func(s *server) { + s.streamingRecordsLimit = limit + } +} + func Handler(svc ContentRouter, opts ...Option) http.Handler { server := &server{ - svc: svc, + svc: svc, + recordsLimit: DefaultRecordsLimit, + streamingRecordsLimit: DefaultStreamingRecordsLimit, } for _, opt := range opts { @@ -88,8 +109,10 @@ func Handler(svc ContentRouter, opts ...Option) http.Handler { } type server struct { - svc ContentRouter - disableNDJSON bool + svc ContentRouter + disableNDJSON bool + recordsLimit int + streamingRecordsLimit int } func (s *server) provide(w http.ResponseWriter, httpReq *http.Request) { @@ -172,11 +195,11 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { var supportsNDJSON bool var supportsJSON bool - var streaming bool + var recordsLimit int acceptHeaders := httpReq.Header.Values("Accept") if len(acceptHeaders) == 0 { handlerFunc = s.findProvidersJSON - streaming = false + recordsLimit = s.recordsLimit } else { for _, acceptHeader := range acceptHeaders { for _, accept := range strings.Split(acceptHeader, ",") { @@ -197,17 +220,17 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { if supportsNDJSON && !s.disableNDJSON { handlerFunc = s.findProvidersNDJSON - streaming = true + recordsLimit = s.streamingRecordsLimit } else if supportsJSON { handlerFunc = s.findProvidersJSON - streaming = false + recordsLimit = s.recordsLimit } else { writeErr(w, "FindProviders", http.StatusBadRequest, errors.New("no supported content types")) return } } - provIter, err := s.svc.FindProviders(httpReq.Context(), cid, streaming) + provIter, err := s.svc.FindProviders(httpReq.Context(), cid, recordsLimit) if err != nil { writeErr(w, "FindProviders", http.StatusInternalServerError, fmt.Errorf("delegate error: %w", err)) return diff --git a/routing/http/server/server_test.go b/routing/http/server/server_test.go index acab26f18..69db7d556 100644 --- a/routing/http/server/server_test.go +++ b/routing/http/server/server_test.go @@ -33,7 +33,7 @@ func TestHeaders(t *testing.T) { cb, err := cid.Decode(c) require.NoError(t, err) - router.On("FindProviders", mock.Anything, cb, false). + router.On("FindProviders", mock.Anything, cb, DefaultRecordsLimit). Return(results, nil) resp, err := http.Get(serverAddr + ProvidePath + c) @@ -85,7 +85,11 @@ func TestResponse(t *testing.T) { server := httptest.NewServer(Handler(router)) t.Cleanup(server.Close) serverAddr := "http://" + server.Listener.Addr().String() - router.On("FindProviders", mock.Anything, cid, expectedStream).Return(results, nil) + limit := DefaultRecordsLimit + if expectedStream { + limit = DefaultStreamingRecordsLimit + } + router.On("FindProviders", mock.Anything, cid, limit).Return(results, nil) urlStr := serverAddr + ProvidePath + cidStr req, err := http.NewRequest(http.MethodGet, urlStr, nil) @@ -115,8 +119,8 @@ func TestResponse(t *testing.T) { type mockContentRouter struct{ mock.Mock } -func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid, stream bool) (iter.ResultIter[types.ProviderResponse], error) { - args := m.Called(ctx, key, stream) +func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid, limit int) (iter.ResultIter[types.ProviderResponse], error) { + args := m.Called(ctx, key, limit) return args.Get(0).(iter.ResultIter[types.ProviderResponse]), args.Error(1) } func (m *mockContentRouter) ProvideBitswap(ctx context.Context, req *BitswapWriteProvideRequest) (time.Duration, error) {