From 15618c8ea8bc71c95f2f94b865c408182af176ea Mon Sep 17 00:00:00 2001 From: Andrey Smirnov Date: Thu, 30 Nov 2017 00:49:37 +0300 Subject: [PATCH 1/2] Use Go context to abort gracefully mirror updates There are two fixes here: 1. Abort package download immediately as ^C is pressed. 2. Import all the already downloaded files into package pool, so that next time mirror is updated, aptly won't download them once again. --- aptly/interfaces.go | 5 ++-- cmd/mirror_update.go | 55 ++++++++++++++++------------------------ context/context.go | 34 +++++++++++++++++++++++++ deb/package.go | 1 + deb/remote.go | 11 ++++---- http/compression.go | 7 ++--- http/compression_test.go | 21 ++++++++------- http/download.go | 26 +++++++++++++++---- http/download_test.go | 27 +++++++++++--------- http/fake.go | 7 ++--- http/temp.go | 9 ++++--- http/temp_test.go | 8 +++--- 12 files changed, 131 insertions(+), 80 deletions(-) diff --git a/aptly/interfaces.go b/aptly/interfaces.go index 3eb5102eb..310b3f4bf 100644 --- a/aptly/interfaces.go +++ b/aptly/interfaces.go @@ -3,6 +3,7 @@ package aptly import ( + "context" "io" "os" @@ -116,9 +117,9 @@ type Progress interface { // Downloader is parallel HTTP fetcher type Downloader interface { // Download starts new download task - Download(url string, destination string) error + Download(ctx context.Context, url string, destination string) error // DownloadWithChecksum starts new download task with checksum verification - DownloadWithChecksum(url string, destination string, expected *utils.ChecksumInfo, ignoreMismatch bool, maxTries int) error + DownloadWithChecksum(ctx context.Context, url string, destination string, expected *utils.ChecksumInfo, ignoreMismatch bool, maxTries int) error // GetProgress returns Progress object GetProgress() Progress } diff --git a/cmd/mirror_update.go b/cmd/mirror_update.go index 48dcaa22d..75d518a3e 100644 --- a/cmd/mirror_update.go +++ b/cmd/mirror_update.go @@ -2,8 +2,6 @@ package cmd import ( "fmt" - "os" - "os/signal" "strings" "sync" @@ -113,17 +111,7 @@ func aptlyMirrorUpdate(cmd *commander.Command, args []string) error { return fmt.Errorf("unable to update: %s", err) } - // Catch ^C - sigch := make(chan os.Signal) - signal.Notify(sigch, os.Interrupt) - defer signal.Stop(sigch) - - abort := make(chan struct{}) - go func() { - <-sigch - signal.Stop(sigch) - close(abort) - }() + context.GoContextHandleSignals() count := len(queue) context.Progress().Printf("Download queue: %d items (%s)\n", count, utils.HumanBytes(downloadSize)) @@ -148,7 +136,7 @@ func aptlyMirrorUpdate(cmd *commander.Command, args []string) error { for idx := range queue { select { case downloadQueue <- idx: - case <-abort: + case <-context.GoContext().Done(): return } } @@ -181,6 +169,7 @@ func aptlyMirrorUpdate(cmd *commander.Command, args []string) error { // download file... e = context.Downloader().DownloadWithChecksum( + context.GoContext(), repo.PackageURL(task.File.DownloadURL()).String(), task.TempDownPath, &task.File.Checksums, @@ -190,28 +179,20 @@ func aptlyMirrorUpdate(cmd *commander.Command, args []string) error { pushError(e) continue } - case <-abort: + + task.Done = true + case <-context.GoContext().Done(): return } } }() } - // Wait for all downloads to finish + // Wait for all download goroutines to finish wg.Wait() - select { - case <-abort: - return fmt.Errorf("unable to update: interrupted") - default: - } - context.Progress().ShutdownBar() - if len(errors) > 0 { - return fmt.Errorf("unable to update: download errors:\n %s", strings.Join(errors, "\n ")) - } - err = context.ReOpenDatabase() if err != nil { return fmt.Errorf("unable to update: %s", err) @@ -221,11 +202,15 @@ func aptlyMirrorUpdate(cmd *commander.Command, args []string) error { context.Progress().InitBar(int64(len(queue)), false) for idx := range queue { - context.Progress().AddBar(1) task := &queue[idx] + if !task.Done { + // download not finished yet + continue + } + // and import it back to the pool task.File.PoolPath, err = context.PackagePool().Import(task.TempDownPath, task.File.Filename, &task.File.Checksums, true, context.CollectionFactory().ChecksumCollection()) if err != nil { @@ -237,16 +222,20 @@ func aptlyMirrorUpdate(cmd *commander.Command, args []string) error { additionalTask.File.PoolPath = task.File.PoolPath additionalTask.File.Checksums = task.File.Checksums } - - select { - case <-abort: - return fmt.Errorf("unable to update: interrupted") - default: - } } context.Progress().ShutdownBar() + select { + case <-context.GoContext().Done(): + return fmt.Errorf("unable to update: interrupted") + default: + } + + if len(errors) > 0 { + return fmt.Errorf("unable to update: download errors:\n %s", strings.Join(errors, "\n ")) + } + repo.FinalizeDownload(context.CollectionFactory(), context.Progress()) err = context.CollectionFactory().RemoteRepoCollection().Update(repo) if err != nil { diff --git a/context/context.go b/context/context.go index 3adeabe33..6af1a0540 100644 --- a/context/context.go +++ b/context/context.go @@ -2,9 +2,11 @@ package context import ( + gocontext "context" "fmt" "math/rand" "os" + "os/signal" "path/filepath" "runtime" "runtime/pprof" @@ -30,6 +32,8 @@ import ( type AptlyContext struct { sync.Mutex + ctx gocontext.Context + flags, globalFlags *flag.FlagSet configLoaded bool @@ -438,6 +442,35 @@ func (context *AptlyContext) GlobalFlags() *flag.FlagSet { return context.globalFlags } +// GoContext returns instance of Go context.Context for the current session +func (context *AptlyContext) GoContext() gocontext.Context { + context.Lock() + defer context.Unlock() + + return context.ctx +} + +// GoContextHandleSignals upgrades context to handle ^C by aborting context +func (context *AptlyContext) GoContextHandleSignals() { + context.Lock() + defer context.Unlock() + + // Catch ^C + sigch := make(chan os.Signal) + signal.Notify(sigch, os.Interrupt) + + var cancel gocontext.CancelFunc + + context.ctx, cancel = gocontext.WithCancel(context.ctx) + + go func() { + <-sigch + signal.Stop(sigch) + context.Progress().PrintfStdErr("Aborting... press ^C once again to abort immediately\n") + cancel() + }() +} + // Shutdown shuts context down func (context *AptlyContext) Shutdown() { context.Lock() @@ -494,6 +527,7 @@ func NewContext(flags *flag.FlagSet) (*AptlyContext, error) { flags: flags, globalFlags: flags, dependencyOptions: -1, + ctx: gocontext.TODO(), publishedStorages: map[string]aptly.PublishedStorage{}, } diff --git a/deb/package.go b/deb/package.go index 996aef457..1034bdb34 100644 --- a/deb/package.go +++ b/deb/package.go @@ -622,6 +622,7 @@ type PackageDownloadTask struct { File *PackageFile Additional []PackageDownloadTask TempDownPath string + Done bool } // DownloadList returns list of missing package files for download in format diff --git a/deb/remote.go b/deb/remote.go index 530d02c73..bddeba3a9 100644 --- a/deb/remote.go +++ b/deb/remote.go @@ -2,6 +2,7 @@ package deb import ( "bytes" + gocontext "context" "fmt" "log" "net/url" @@ -258,13 +259,13 @@ func (repo *RemoteRepo) Fetch(d aptly.Downloader, verifier pgp.Verifier) error { if verifier == nil { // 0. Just download release file to temporary URL - release, err = http.DownloadTemp(d, repo.ReleaseURL("Release").String()) + release, err = http.DownloadTemp(gocontext.TODO(), d, repo.ReleaseURL("Release").String()) if err != nil { return err } } else { // 1. try InRelease file - inrelease, err = http.DownloadTemp(d, repo.ReleaseURL("InRelease").String()) + inrelease, err = http.DownloadTemp(gocontext.TODO(), d, repo.ReleaseURL("InRelease").String()) if err != nil { goto splitsignature } @@ -286,12 +287,12 @@ func (repo *RemoteRepo) Fetch(d aptly.Downloader, verifier pgp.Verifier) error { splitsignature: // 2. try Release + Release.gpg - release, err = http.DownloadTemp(d, repo.ReleaseURL("Release").String()) + release, err = http.DownloadTemp(gocontext.TODO(), d, repo.ReleaseURL("Release").String()) if err != nil { return err } - releasesig, err = http.DownloadTemp(d, repo.ReleaseURL("Release.gpg").String()) + releasesig, err = http.DownloadTemp(gocontext.TODO(), d, repo.ReleaseURL("Release.gpg").String()) if err != nil { return err } @@ -439,7 +440,7 @@ func (repo *RemoteRepo) DownloadPackageIndexes(progress aptly.Progress, d aptly. for _, info := range packagesPaths { path, kind := info[0], info[1] - packagesReader, packagesFile, err := http.DownloadTryCompression(d, repo.IndexesRootURL(), path, repo.ReleaseFiles, ignoreMismatch, maxTries) + packagesReader, packagesFile, err := http.DownloadTryCompression(gocontext.TODO(), d, repo.IndexesRootURL(), path, repo.ReleaseFiles, ignoreMismatch, maxTries) if err != nil { return err } diff --git a/http/compression.go b/http/compression.go index b0eab5565..1a68f1c75 100644 --- a/http/compression.go +++ b/http/compression.go @@ -3,6 +3,7 @@ package http import ( "compress/bzip2" "compress/gzip" + "context" "fmt" "io" "net/url" @@ -39,7 +40,7 @@ var compressionMethods = []struct { // DownloadTryCompression tries to download from URL .bz2, .gz and raw extension until // it finds existing file. -func DownloadTryCompression(downloader aptly.Downloader, baseURL *url.URL, path string, expectedChecksums map[string]utils.ChecksumInfo, ignoreMismatch bool, maxTries int) (io.Reader, *os.File, error) { +func DownloadTryCompression(ctx context.Context, downloader aptly.Downloader, baseURL *url.URL, path string, expectedChecksums map[string]utils.ChecksumInfo, ignoreMismatch bool, maxTries int) (io.Reader, *os.File, error) { var err error for _, method := range compressionMethods { @@ -63,13 +64,13 @@ func DownloadTryCompression(downloader aptly.Downloader, baseURL *url.URL, path if foundChecksum { expected := expectedChecksums[bestSuffix] - file, err = DownloadTempWithChecksum(downloader, tryURL.String(), &expected, ignoreMismatch, maxTries) + file, err = DownloadTempWithChecksum(ctx, downloader, tryURL.String(), &expected, ignoreMismatch, maxTries) } else { if !ignoreMismatch { continue } - file, err = DownloadTemp(downloader, tryURL.String()) + file, err = DownloadTemp(ctx, downloader, tryURL.String()) } if err != nil { diff --git a/http/compression_test.go b/http/compression_test.go index 5dacc0dcf..9f1135f5b 100644 --- a/http/compression_test.go +++ b/http/compression_test.go @@ -1,6 +1,7 @@ package http import ( + "context" "errors" "io" "net/url" @@ -12,6 +13,7 @@ import ( type CompressionSuite struct { baseURL *url.URL + ctx context.Context } var _ = Suite(&CompressionSuite{}) @@ -25,6 +27,7 @@ const ( func (s *CompressionSuite) SetUpTest(c *C) { s.baseURL, _ = url.Parse("http://example.com/") + s.ctx = context.Background() } func (s *CompressionSuite) TestDownloadTryCompression(c *C) { @@ -41,7 +44,7 @@ func (s *CompressionSuite) TestDownloadTryCompression(c *C) { buf = make([]byte, 4) d := NewFakeDownloader() d.ExpectResponse("http://example.com/file.bz2", bzipData) - r, file, err := DownloadTryCompression(d, s.baseURL, "file", expectedChecksums, false, 1) + r, file, err := DownloadTryCompression(s.ctx, d, s.baseURL, "file", expectedChecksums, false, 1) c.Assert(err, IsNil) defer file.Close() io.ReadFull(r, buf) @@ -53,7 +56,7 @@ func (s *CompressionSuite) TestDownloadTryCompression(c *C) { d = NewFakeDownloader() d.ExpectError("http://example.com/file.bz2", &Error{Code: 404}) d.ExpectResponse("http://example.com/file.gz", gzipData) - r, file, err = DownloadTryCompression(d, s.baseURL, "file", expectedChecksums, false, 1) + r, file, err = DownloadTryCompression(s.ctx, d, s.baseURL, "file", expectedChecksums, false, 1) c.Assert(err, IsNil) defer file.Close() io.ReadFull(r, buf) @@ -66,7 +69,7 @@ func (s *CompressionSuite) TestDownloadTryCompression(c *C) { d.ExpectError("http://example.com/file.bz2", &Error{Code: 404}) d.ExpectError("http://example.com/file.gz", &Error{Code: 404}) d.ExpectResponse("http://example.com/file.xz", xzData) - r, file, err = DownloadTryCompression(d, s.baseURL, "file", expectedChecksums, false, 1) + r, file, err = DownloadTryCompression(s.ctx, d, s.baseURL, "file", expectedChecksums, false, 1) c.Assert(err, IsNil) defer file.Close() io.ReadFull(r, buf) @@ -80,7 +83,7 @@ func (s *CompressionSuite) TestDownloadTryCompression(c *C) { d.ExpectError("http://example.com/file.gz", &Error{Code: 404}) d.ExpectError("http://example.com/file.xz", &Error{Code: 404}) d.ExpectResponse("http://example.com/file", rawData) - r, file, err = DownloadTryCompression(d, s.baseURL, "file", expectedChecksums, false, 1) + r, file, err = DownloadTryCompression(s.ctx, d, s.baseURL, "file", expectedChecksums, false, 1) c.Assert(err, IsNil) defer file.Close() io.ReadFull(r, buf) @@ -91,7 +94,7 @@ func (s *CompressionSuite) TestDownloadTryCompression(c *C) { d = NewFakeDownloader() d.ExpectError("http://example.com/file.bz2", &Error{Code: 404}) d.ExpectResponse("http://example.com/file.gz", "x") - _, _, err = DownloadTryCompression(d, s.baseURL, "file", nil, true, 1) + _, _, err = DownloadTryCompression(s.ctx, d, s.baseURL, "file", nil, true, 1) c.Assert(err, ErrorMatches, "unexpected EOF") c.Assert(d.Empty(), Equals, true) } @@ -109,7 +112,7 @@ func (s *CompressionSuite) TestDownloadTryCompressionLongestSuffix(c *C) { buf = make([]byte, 4) d := NewFakeDownloader() d.ExpectResponse("http://example.com/subdir/file.bz2", bzipData) - r, file, err := DownloadTryCompression(d, s.baseURL, "subdir/file", expectedChecksums, false, 1) + r, file, err := DownloadTryCompression(s.ctx, d, s.baseURL, "subdir/file", expectedChecksums, false, 1) c.Assert(err, IsNil) defer file.Close() io.ReadFull(r, buf) @@ -119,7 +122,7 @@ func (s *CompressionSuite) TestDownloadTryCompressionLongestSuffix(c *C) { func (s *CompressionSuite) TestDownloadTryCompressionErrors(c *C) { d := NewFakeDownloader() - _, _, err := DownloadTryCompression(d, s.baseURL, "file", nil, true, 1) + _, _, err := DownloadTryCompression(s.ctx, d, s.baseURL, "file", nil, true, 1) c.Assert(err, ErrorMatches, "unexpected request.*") d = NewFakeDownloader() @@ -127,7 +130,7 @@ func (s *CompressionSuite) TestDownloadTryCompressionErrors(c *C) { d.ExpectError("http://example.com/file.gz", &Error{Code: 404}) d.ExpectError("http://example.com/file.xz", &Error{Code: 404}) d.ExpectError("http://example.com/file", errors.New("403")) - _, _, err = DownloadTryCompression(d, s.baseURL, "file", nil, true, 1) + _, _, err = DownloadTryCompression(s.ctx, d, s.baseURL, "file", nil, true, 1) c.Assert(err, ErrorMatches, "403") d = NewFakeDownloader() @@ -141,6 +144,6 @@ func (s *CompressionSuite) TestDownloadTryCompressionErrors(c *C) { "file.xz": {Size: 7}, "file": {Size: 7}, } - _, _, err = DownloadTryCompression(d, s.baseURL, "file", expectedChecksums, false, 1) + _, _, err = DownloadTryCompression(s.ctx, d, s.baseURL, "file", expectedChecksums, false, 1) c.Assert(err, ErrorMatches, "checksums don't match.*") } diff --git a/http/download.go b/http/download.go index 07a51009e..d41dfab67 100644 --- a/http/download.go +++ b/http/download.go @@ -1,12 +1,15 @@ package http import ( + "context" "fmt" "io" + "net" "net/http" "os" "path/filepath" "strings" + "syscall" "time" "github.com/mxk/go-flowrate/flowrate" @@ -62,12 +65,24 @@ func (downloader *downloaderImpl) GetProgress() aptly.Progress { } // Download starts new download task -func (downloader *downloaderImpl) Download(url string, destination string) error { - return downloader.DownloadWithChecksum(url, destination, nil, false, 1) +func (downloader *downloaderImpl) Download(ctx context.Context, url string, destination string) error { + return downloader.DownloadWithChecksum(ctx, url, destination, nil, false, 1) +} + +func retryableError(err error) bool { + switch err.(type) { + case net.Error: + return true + case *net.OpError: + return true + case syscall.Errno: + return true + } + return false } // DownloadWithChecksum starts new download task with checksum verification -func (downloader *downloaderImpl) DownloadWithChecksum(url string, destination string, +func (downloader *downloaderImpl) DownloadWithChecksum(ctx context.Context, url string, destination string, expected *utils.ChecksumInfo, ignoreMismatch bool, maxTries int) error { downloader.progress.Printf("Downloading %s...\n", url) @@ -77,6 +92,7 @@ func (downloader *downloaderImpl) DownloadWithChecksum(url string, destination s return errors.Wrap(err, url) } req.Close = true + req = req.WithContext(ctx) proxyURL, _ := downloader.client.Transport.(*http.Transport).Proxy(req) if proxyURL == nil && (req.URL.Scheme == "http" || req.URL.Scheme == "https") { @@ -88,10 +104,10 @@ func (downloader *downloaderImpl) DownloadWithChecksum(url string, destination s for maxTries > 0 { temppath, err = downloader.download(req, url, destination, expected, ignoreMismatch) - if err != nil { + if err != nil && retryableError(err) { maxTries-- } else { - // successful download + // get out of the loop break } } diff --git a/http/download_test.go b/http/download_test.go index e11d9ea2d..3e6166250 100644 --- a/http/download_test.go +++ b/http/download_test.go @@ -1,6 +1,7 @@ package http import ( + "context" "fmt" "io/ioutil" "net" @@ -21,6 +22,7 @@ type DownloaderSuiteBase struct { ch chan struct{} progress aptly.Progress d aptly.Downloader + ctx context.Context } func (s *DownloaderSuiteBase) SetUpTest(c *C) { @@ -44,6 +46,7 @@ func (s *DownloaderSuiteBase) SetUpTest(c *C) { s.progress.Start() s.d = NewDownloader(0, s.progress) + s.ctx = context.Background() } func (s *DownloaderSuiteBase) TearDownTest(c *C) { @@ -71,52 +74,52 @@ func (s *DownloaderSuite) TearDownTest(c *C) { } func (s *DownloaderSuite) TestDownloadOK(c *C) { - c.Assert(s.d.Download(s.url+"/test", s.tempfile.Name()), IsNil) + c.Assert(s.d.Download(s.ctx, s.url+"/test", s.tempfile.Name()), IsNil) } func (s *DownloaderSuite) TestDownloadWithChecksum(c *C) { - c.Assert(s.d.DownloadWithChecksum(s.url+"/test", s.tempfile.Name(), &utils.ChecksumInfo{}, false, 1), + c.Assert(s.d.DownloadWithChecksum(s.ctx, s.url+"/test", s.tempfile.Name(), &utils.ChecksumInfo{}, false, 1), ErrorMatches, ".*size check mismatch 12 != 0") - c.Assert(s.d.DownloadWithChecksum(s.url+"/test", s.tempfile.Name(), &utils.ChecksumInfo{Size: 12, MD5: "abcdef"}, false, 1), + c.Assert(s.d.DownloadWithChecksum(s.ctx, s.url+"/test", s.tempfile.Name(), &utils.ChecksumInfo{Size: 12, MD5: "abcdef"}, false, 1), ErrorMatches, ".*md5 hash mismatch \"a1acb0fe91c7db45ec4d775192ec5738\" != \"abcdef\"") - c.Assert(s.d.DownloadWithChecksum(s.url+"/test", s.tempfile.Name(), &utils.ChecksumInfo{Size: 12, MD5: "abcdef"}, true, 1), + c.Assert(s.d.DownloadWithChecksum(s.ctx, s.url+"/test", s.tempfile.Name(), &utils.ChecksumInfo{Size: 12, MD5: "abcdef"}, true, 1), IsNil) - c.Assert(s.d.DownloadWithChecksum(s.url+"/test", s.tempfile.Name(), &utils.ChecksumInfo{Size: 12, MD5: "a1acb0fe91c7db45ec4d775192ec5738"}, false, 1), + c.Assert(s.d.DownloadWithChecksum(s.ctx, s.url+"/test", s.tempfile.Name(), &utils.ChecksumInfo{Size: 12, MD5: "a1acb0fe91c7db45ec4d775192ec5738"}, false, 1), IsNil) - c.Assert(s.d.DownloadWithChecksum(s.url+"/test", s.tempfile.Name(), &utils.ChecksumInfo{Size: 12, MD5: "a1acb0fe91c7db45ec4d775192ec5738", SHA1: "abcdef"}, false, 1), + c.Assert(s.d.DownloadWithChecksum(s.ctx, s.url+"/test", s.tempfile.Name(), &utils.ChecksumInfo{Size: 12, MD5: "a1acb0fe91c7db45ec4d775192ec5738", SHA1: "abcdef"}, false, 1), ErrorMatches, ".*sha1 hash mismatch \"921893bae6ad6fd818401875d6779254ef0ff0ec\" != \"abcdef\"") - c.Assert(s.d.DownloadWithChecksum(s.url+"/test", s.tempfile.Name(), &utils.ChecksumInfo{Size: 12, MD5: "a1acb0fe91c7db45ec4d775192ec5738", + c.Assert(s.d.DownloadWithChecksum(s.ctx, s.url+"/test", s.tempfile.Name(), &utils.ChecksumInfo{Size: 12, MD5: "a1acb0fe91c7db45ec4d775192ec5738", SHA1: "921893bae6ad6fd818401875d6779254ef0ff0ec"}, false, 1), IsNil) - c.Assert(s.d.DownloadWithChecksum(s.url+"/test", s.tempfile.Name(), &utils.ChecksumInfo{Size: 12, MD5: "a1acb0fe91c7db45ec4d775192ec5738", + c.Assert(s.d.DownloadWithChecksum(s.ctx, s.url+"/test", s.tempfile.Name(), &utils.ChecksumInfo{Size: 12, MD5: "a1acb0fe91c7db45ec4d775192ec5738", SHA1: "921893bae6ad6fd818401875d6779254ef0ff0ec", SHA256: "abcdef"}, false, 1), ErrorMatches, ".*sha256 hash mismatch \"b3c92ee1246176ed35f6e8463cd49074f29442f5bbffc3f8591cde1dcc849dac\" != \"abcdef\"") checksums := utils.ChecksumInfo{Size: 12, MD5: "a1acb0fe91c7db45ec4d775192ec5738", SHA1: "921893bae6ad6fd818401875d6779254ef0ff0ec", SHA256: "b3c92ee1246176ed35f6e8463cd49074f29442f5bbffc3f8591cde1dcc849dac"} - c.Assert(s.d.DownloadWithChecksum(s.url+"/test", s.tempfile.Name(), &checksums, false, 1), + c.Assert(s.d.DownloadWithChecksum(s.ctx, s.url+"/test", s.tempfile.Name(), &checksums, false, 1), IsNil) // download backfills missing checksums c.Check(checksums.SHA512, Equals, "bac18bf4e564856369acc2ed57300fecba3a2c1af5ae8304021e4252488678feb18118466382ee4e1210fe1f065080210e453a80cfb37ccb8752af3269df160e") } func (s *DownloaderSuite) TestDownload404(c *C) { - c.Assert(s.d.Download(s.url+"/doesntexist", s.tempfile.Name()), + c.Assert(s.d.Download(s.ctx, s.url+"/doesntexist", s.tempfile.Name()), ErrorMatches, "HTTP code 404.*") } func (s *DownloaderSuite) TestDownloadConnectError(c *C) { - c.Assert(s.d.Download("http://nosuch.localhost/", s.tempfile.Name()), + c.Assert(s.d.Download(s.ctx, "http://nosuch.localhost/", s.tempfile.Name()), ErrorMatches, ".*no such host") } func (s *DownloaderSuite) TestDownloadFileError(c *C) { - c.Assert(s.d.Download(s.url+"/test", "/"), + c.Assert(s.d.Download(s.ctx, s.url+"/test", "/"), ErrorMatches, ".*permission denied") } diff --git a/http/fake.go b/http/fake.go index 19bf0af49..ec713cc6a 100644 --- a/http/fake.go +++ b/http/fake.go @@ -1,6 +1,7 @@ package http import ( + "context" "fmt" "io" "os" @@ -60,7 +61,7 @@ func (f *FakeDownloader) Empty() bool { } // DownloadWithChecksum performs fake download by matching against first expectation in the queue or any expectation, with cheksum verification -func (f *FakeDownloader) DownloadWithChecksum(url string, filename string, expected *utils.ChecksumInfo, ignoreMismatch bool, maxTries int) error { +func (f *FakeDownloader) DownloadWithChecksum(ctx context.Context, url string, filename string, expected *utils.ChecksumInfo, ignoreMismatch bool, maxTries int) error { var expectation expectedRequest if len(f.expected) > 0 && f.expected[0].URL == url { expectation, f.expected = f.expected[0], f.expected[1:] @@ -109,8 +110,8 @@ func (f *FakeDownloader) DownloadWithChecksum(url string, filename string, expec } // Download performs fake download by matching against first expectation in the queue -func (f *FakeDownloader) Download(url string, filename string) error { - return f.DownloadWithChecksum(url, filename, nil, false, 1) +func (f *FakeDownloader) Download(ctx context.Context, url string, filename string) error { + return f.DownloadWithChecksum(ctx, url, filename, nil, false, 1) } // GetProgress returns Progress object diff --git a/http/temp.go b/http/temp.go index bcd1367ac..13aff164b 100644 --- a/http/temp.go +++ b/http/temp.go @@ -1,6 +1,7 @@ package http import ( + "context" "io/ioutil" "os" "path/filepath" @@ -12,14 +13,14 @@ import ( // DownloadTemp starts new download to temporary file and returns File // // Temporary file would be already removed, so no need to cleanup -func DownloadTemp(downloader aptly.Downloader, url string) (*os.File, error) { - return DownloadTempWithChecksum(downloader, url, nil, false, 1) +func DownloadTemp(ctx context.Context, downloader aptly.Downloader, url string) (*os.File, error) { + return DownloadTempWithChecksum(ctx, downloader, url, nil, false, 1) } // DownloadTempWithChecksum is a DownloadTemp with checksum verification // // Temporary file would be already removed, so no need to cleanup -func DownloadTempWithChecksum(downloader aptly.Downloader, url string, expected *utils.ChecksumInfo, ignoreMismatch bool, maxTries int) (*os.File, error) { +func DownloadTempWithChecksum(ctx context.Context, downloader aptly.Downloader, url string, expected *utils.ChecksumInfo, ignoreMismatch bool, maxTries int) (*os.File, error) { tempdir, err := ioutil.TempDir(os.TempDir(), "aptly") if err != nil { return nil, err @@ -33,7 +34,7 @@ func DownloadTempWithChecksum(downloader aptly.Downloader, url string, expected defer downloader.GetProgress().ShutdownBar() } - err = downloader.DownloadWithChecksum(url, tempfile, expected, ignoreMismatch, maxTries) + err = downloader.DownloadWithChecksum(ctx, url, tempfile, expected, ignoreMismatch, maxTries) if err != nil { return nil, err } diff --git a/http/temp_test.go b/http/temp_test.go index c4723e2e4..20d822220 100644 --- a/http/temp_test.go +++ b/http/temp_test.go @@ -23,7 +23,7 @@ func (s *TempSuite) TearDownTest(c *C) { } func (s *TempSuite) TestDownloadTemp(c *C) { - f, err := DownloadTemp(s.d, s.url+"/test") + f, err := DownloadTemp(s.ctx, s.d, s.url+"/test") c.Assert(err, IsNil) defer f.Close() @@ -37,18 +37,18 @@ func (s *TempSuite) TestDownloadTemp(c *C) { } func (s *TempSuite) TestDownloadTempWithChecksum(c *C) { - f, err := DownloadTempWithChecksum(s.d, s.url+"/test", &utils.ChecksumInfo{Size: 12, MD5: "a1acb0fe91c7db45ec4d775192ec5738", + f, err := DownloadTempWithChecksum(s.ctx, s.d, s.url+"/test", &utils.ChecksumInfo{Size: 12, MD5: "a1acb0fe91c7db45ec4d775192ec5738", SHA1: "921893bae6ad6fd818401875d6779254ef0ff0ec", SHA256: "b3c92ee1246176ed35f6e8463cd49074f29442f5bbffc3f8591cde1dcc849dac"}, false, 1) c.Assert(err, IsNil) c.Assert(f.Close(), IsNil) - _, err = DownloadTempWithChecksum(s.d, s.url+"/test", &utils.ChecksumInfo{Size: 13}, false, 1) + _, err = DownloadTempWithChecksum(s.ctx, s.d, s.url+"/test", &utils.ChecksumInfo{Size: 13}, false, 1) c.Assert(err, ErrorMatches, ".*size check mismatch 12 != 13") } func (s *TempSuite) TestDownloadTempError(c *C) { - f, err := DownloadTemp(s.d, s.url+"/doesntexist") + f, err := DownloadTemp(s.ctx, s.d, s.url+"/doesntexist") c.Assert(err, NotNil) c.Assert(f, IsNil) c.Assert(err, ErrorMatches, "HTTP code 404.*") From b7490fe909fef1233dbe9618b11c254dfdc9cf23 Mon Sep 17 00:00:00 2001 From: Andrey Smirnov Date: Thu, 30 Nov 2017 23:44:04 +0300 Subject: [PATCH 2/2] Refactor to embed `gocontext.Context` into aptly `context` --- cmd/mirror_update.go | 8 ++++---- context/context.go | 14 +++----------- 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/cmd/mirror_update.go b/cmd/mirror_update.go index 75d518a3e..003926075 100644 --- a/cmd/mirror_update.go +++ b/cmd/mirror_update.go @@ -136,7 +136,7 @@ func aptlyMirrorUpdate(cmd *commander.Command, args []string) error { for idx := range queue { select { case downloadQueue <- idx: - case <-context.GoContext().Done(): + case <-context.Done(): return } } @@ -169,7 +169,7 @@ func aptlyMirrorUpdate(cmd *commander.Command, args []string) error { // download file... e = context.Downloader().DownloadWithChecksum( - context.GoContext(), + context, repo.PackageURL(task.File.DownloadURL()).String(), task.TempDownPath, &task.File.Checksums, @@ -181,7 +181,7 @@ func aptlyMirrorUpdate(cmd *commander.Command, args []string) error { } task.Done = true - case <-context.GoContext().Done(): + case <-context.Done(): return } } @@ -227,7 +227,7 @@ func aptlyMirrorUpdate(cmd *commander.Command, args []string) error { context.Progress().ShutdownBar() select { - case <-context.GoContext().Done(): + case <-context.Done(): return fmt.Errorf("unable to update: interrupted") default: } diff --git a/context/context.go b/context/context.go index 6af1a0540..b59d5b511 100644 --- a/context/context.go +++ b/context/context.go @@ -32,7 +32,7 @@ import ( type AptlyContext struct { sync.Mutex - ctx gocontext.Context + gocontext.Context flags, globalFlags *flag.FlagSet configLoaded bool @@ -442,14 +442,6 @@ func (context *AptlyContext) GlobalFlags() *flag.FlagSet { return context.globalFlags } -// GoContext returns instance of Go context.Context for the current session -func (context *AptlyContext) GoContext() gocontext.Context { - context.Lock() - defer context.Unlock() - - return context.ctx -} - // GoContextHandleSignals upgrades context to handle ^C by aborting context func (context *AptlyContext) GoContextHandleSignals() { context.Lock() @@ -461,7 +453,7 @@ func (context *AptlyContext) GoContextHandleSignals() { var cancel gocontext.CancelFunc - context.ctx, cancel = gocontext.WithCancel(context.ctx) + context.Context, cancel = gocontext.WithCancel(context.Context) go func() { <-sigch @@ -527,7 +519,7 @@ func NewContext(flags *flag.FlagSet) (*AptlyContext, error) { flags: flags, globalFlags: flags, dependencyOptions: -1, - ctx: gocontext.TODO(), + Context: gocontext.TODO(), publishedStorages: map[string]aptly.PublishedStorage{}, }