From 8339301726cc9144b0a1d8a54d32de993c6af03b Mon Sep 17 00:00:00 2001 From: Liam Galvin Date: Thu, 1 Aug 2024 20:14:17 +0100 Subject: [PATCH] add tests --- get_s3.go | 4 ++-- get_s3_test.go | 65 ++++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 60 insertions(+), 9 deletions(-) diff --git a/get_s3.go b/get_s3.go index 346b98a0b..b478bde4e 100644 --- a/get_s3.go +++ b/get_s3.go @@ -276,7 +276,7 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c path = pathParts[2] // vhost-style, dash region indication case 4: - // Parse the region out of the first part of the host + // Parse the region out of the second part of the host region = strings.TrimPrefix(strings.TrimPrefix(hostParts[1], "s3-"), "s3") if region == "" { err = fmt.Errorf("URL is not a valid S3 URL") @@ -293,7 +293,7 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c case 5: region = hostParts[2] pathParts := strings.SplitN(u.Path, "/", 2) - if len(pathParts) < 3 { + if len(pathParts) < 2 { err = fmt.Errorf("URL is not a valid S3 URL") return } diff --git a/get_s3_test.go b/get_s3_test.go index 7b2425404..18187f149 100644 --- a/get_s3_test.go +++ b/get_s3_test.go @@ -165,12 +165,13 @@ func TestS3Getter_ClientMode_collision(t *testing.T) { func TestS3Getter_Url(t *testing.T) { var s3tests = []struct { - name string - url string - region string - bucket string - path string - version string + name string + url string + region string + bucket string + path string + version string + expectedErr string }{ { name: "AWSv1234", @@ -220,6 +221,11 @@ func TestS3Getter_Url(t *testing.T) { path: "hello.txt", version: "", }, + { + name: "malformed s3 url", + url: "s3::https://s3.amazonaws.com/bucket", + expectedErr: "URL is not a valid S3 URL", + }, } for i, pt := range s3tests { @@ -238,7 +244,15 @@ func TestS3Getter_Url(t *testing.T) { region, bucket, path, version, creds, err := g.parseUrl(u) if err != nil { - t.Fatalf("err: %s", err) + if pt.expectedErr == "" { + t.Fatalf("err: %s", err) + } + if err.Error() != pt.expectedErr { + t.Fatalf("expected %s, got %s", pt.expectedErr, err.Error()) + } + return + } else if pt.expectedErr != "" { + t.Fatalf("expected error, got none") } if region != pt.region { t.Fatalf("expected %s, got %s", pt.region, region) @@ -258,3 +272,40 @@ func TestS3Getter_Url(t *testing.T) { }) } } + +func Test_S3Getter_ParseUrl_Malformed(t *testing.T) { + tests := []struct { + name string + url string + }{ + { + name: "path style", + url: "https://s3.amazonaws.com/bucket", + }, + { + name: "vhost-style, dash region indication", + url: "https://bucket.s3-us-east-1.amazonaws.com", + }, + { + name: "vhost-style, dot region indication", + url: "https://bucket.s3.us-east-1.amazonaws.com", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := new(S3Getter) + u, err := url.Parse(tt.url) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + _, _, _, _, _, err = g.parseUrl(u) + if err == nil { + t.Fatalf("expected error, got none") + } + if err.Error() != "URL is not a valid S3 URL" { + t.Fatalf("expected error 'URL is not a valid S3 URL', got %s", err.Error()) + } + }) + } + +}