Skip to content

Commit feb84c6

Browse files
authored
Allow users to specify their own client implementation used by the library (#73)
* Allow users to specify their own client implementation used by the library * fix typos * add tests
1 parent 0488917 commit feb84c6

File tree

2 files changed

+135
-9
lines changed

2 files changed

+135
-9
lines changed

common_client.go

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,20 @@ func sendRequest(c *http.Client, req *http.Request) (*http.Response, error) {
7575
}
7676

7777
type httpClientOption struct {
78-
logger *slog.Logger
79-
retryMax int
80-
retryWaitMax time.Duration
78+
logger *slog.Logger
79+
80+
// Options for built-in retryable HTTP client.
81+
// Ignored if a custom retryable HTTP client is provided via WithRetryableHTTPClint.
82+
retryMax int
83+
retryWaitMax time.Duration
84+
85+
// fields added to the transport if specified
8186
rootCAs *x509.CertPool
8287
tlsInsecureSkipVerify bool
8388
proxyFunc ProxyFunc
89+
timeout time.Duration
90+
91+
retryableHTTPClient *retryablehttp.Client
8492
}
8593

8694
func (o *httpClientOption) defaults() {
@@ -93,14 +101,28 @@ func (o *httpClientOption) defaults() {
93101
if o.retryWaitMax == 0 {
94102
o.retryWaitMax = 30 * time.Second
95103
}
104+
if o.timeout == 0 {
105+
o.timeout = 5 * time.Minute
106+
}
96107
}
97108

98109
func (o *httpClientOption) newRetryableHTTPClient() (*retryablehttp.Client, error) {
99-
retryClient := retryablehttp.NewClient()
100-
retryClient.Logger = o.logger
101-
retryClient.RetryMax = o.retryMax
102-
retryClient.RetryWaitMax = o.retryWaitMax
103-
retryClient.HTTPClient.Timeout = 5 * time.Minute // timeout must be > 1m to accomodate long polling
110+
var retryClient *retryablehttp.Client
111+
if o.retryableHTTPClient != nil {
112+
retryClient = o.retryableHTTPClient
113+
} else {
114+
retryClient = retryablehttp.NewClient()
115+
retryClient.RetryMax = o.retryMax
116+
retryClient.RetryWaitMax = o.retryWaitMax
117+
}
118+
119+
if retryClient.HTTPClient.Timeout == 0 {
120+
retryClient.HTTPClient.Timeout = o.timeout
121+
}
122+
123+
if retryClient.Logger == nil {
124+
retryClient.Logger = o.logger
125+
}
104126

105127
transport, ok := retryClient.HTTPClient.Transport.(*http.Transport)
106128
if !ok {
@@ -120,7 +142,9 @@ func (o *httpClientOption) newRetryableHTTPClient() (*retryablehttp.Client, erro
120142
transport.TLSClientConfig.InsecureSkipVerify = true
121143
}
122144

123-
transport.Proxy = o.proxyFunc
145+
if o.proxyFunc != nil {
146+
transport.Proxy = o.proxyFunc
147+
}
124148

125149
retryClient.HTTPClient.Transport = transport
126150

@@ -145,6 +169,14 @@ func (c *commonClient) setUserAgent() {
145169
// HTTPOption defines a functional option for configuring the Client.
146170
type HTTPOption func(*httpClientOption)
147171

172+
// WithRetryableHTTPClint allows users to provide a custom retryable HTTP client.
173+
// If not set, a default client will be used with the specified retry and timeout settings.
174+
func WithRetryableHTTPClint(client *retryablehttp.Client) HTTPOption {
175+
return func(c *httpClientOption) {
176+
c.retryableHTTPClient = client
177+
}
178+
}
179+
148180
// WithLogger sets a custom logger for the Client.
149181
// If nil is passed, a discard logger will be used.
150182
func WithLogger(logger *slog.Logger) HTTPOption {
@@ -190,3 +222,10 @@ func WithProxy(proxyFunc ProxyFunc) HTTPOption {
190222
c.proxyFunc = proxyFunc
191223
}
192224
}
225+
226+
// WithTimeout sets a timeout for the Client.
227+
func WithTimeout(duration time.Duration) HTTPOption {
228+
return func(c *httpClientOption) {
229+
c.timeout = duration
230+
}
231+
}

common_client_test.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@ import (
77
"net/http/httptest"
88
"net/url"
99
"testing"
10+
"time"
1011

1112
"github.com/actions/scaleset/internal/testserver"
13+
"github.com/hashicorp/go-retryablehttp"
1214
"github.com/stretchr/testify/assert"
1315
"github.com/stretchr/testify/require"
1416
"golang.org/x/net/http/httpproxy"
@@ -153,3 +155,88 @@ func TestUserAgent(t *testing.T) {
153155

154156
assert.Equal(t, want, got)
155157
}
158+
159+
// TestWithRetryableHTTPClient verifies that a custom retryable HTTP client
160+
// provided via WithRetryableHTTPClient is actually used instead of the built-in one
161+
func TestWithRetryableHTTPClient(t *testing.T) {
162+
t.Run("uses custom retryable client instead of built-in", func(t *testing.T) {
163+
attemptCount := 0
164+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
165+
attemptCount++
166+
if attemptCount == 1 {
167+
w.WriteHeader(http.StatusServiceUnavailable)
168+
return
169+
}
170+
w.WriteHeader(http.StatusOK)
171+
w.Write([]byte(`{"result": "success"}`))
172+
}))
173+
defer server.Close()
174+
175+
// Create a custom retryable HTTP client with specific retry configuration
176+
customRetryClient := retryablehttp.NewClient()
177+
customRetryClient.RetryMax = 3
178+
customRetryClient.RetryWaitMax = 10 * time.Millisecond
179+
180+
// Create options with the custom retryable client
181+
opts := defaultHTTPClientOption()
182+
WithRetryableHTTPClint(customRetryClient)(&opts)
183+
184+
// Verify that the custom client is set in options
185+
assert.NotNil(t, opts.retryableHTTPClient)
186+
assert.Equal(t, customRetryClient, opts.retryableHTTPClient)
187+
188+
// Create the common client with custom retryable client
189+
client := newCommonClient(testSystemInfo, opts)
190+
191+
// Make a request that will trigger a retry
192+
req, err := http.NewRequest("GET", server.URL, nil)
193+
require.NoError(t, err)
194+
195+
resp, err := client.do(req)
196+
require.NoError(t, err)
197+
198+
// Should succeed after retry
199+
assert.Equal(t, http.StatusOK, resp.StatusCode)
200+
assert.Equal(t, 2, attemptCount)
201+
202+
// Verify that the client used is the custom one by checking newRetryableHTTPClient
203+
retrievedRetryClient, err := client.newRetryableHTTPClient()
204+
require.NoError(t, err)
205+
assert.Equal(t, customRetryClient, retrievedRetryClient, "should return the custom retryable client")
206+
})
207+
208+
t.Run("respects custom client's retry configuration over built-in defaults", func(t *testing.T) {
209+
attemptCount := 0
210+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
211+
attemptCount++
212+
w.WriteHeader(http.StatusServiceUnavailable)
213+
}))
214+
defer server.Close()
215+
216+
// Create custom client with limited retries
217+
customRetryClient := retryablehttp.NewClient()
218+
customRetryClient.RetryMax = 1 // Only 1 retry (2 total attempts)
219+
customRetryClient.RetryWaitMax = 5 * time.Millisecond
220+
221+
opts := defaultHTTPClientOption()
222+
WithRetryableHTTPClint(customRetryClient)(&opts)
223+
224+
client := newCommonClient(testSystemInfo, opts)
225+
226+
req, err := http.NewRequest("GET", server.URL, nil)
227+
require.NoError(t, err)
228+
229+
resp, err := client.do(req)
230+
// When all retries are exhausted with a retryable error, the client gives up
231+
// and an error is returned
232+
if err != nil {
233+
// Expected: request failed after exhausting retries
234+
assert.Contains(t, err.Error(), "giving up after 2 attempt(s)")
235+
} else {
236+
// Or the final response is returned
237+
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
238+
}
239+
// Should have tried 1 initial + 1 retry = 2 times total
240+
assert.Equal(t, 2, attemptCount)
241+
})
242+
}

0 commit comments

Comments
 (0)