Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Go context to abort gracefully mirror updates #683

Merged
merged 2 commits into from
Nov 30, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions aptly/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package aptly

import (
"context"
"io"
"os"

Expand Down Expand Up @@ -116,9 +117,9 @@ type Progress interface {
// Downloader is parallel HTTP fetcher
type Downloader interface {
// Download starts new download task
Download(url string, destination string) error
Download(ctx context.Context, url string, destination string) error
// DownloadWithChecksum starts new download task with checksum verification
DownloadWithChecksum(url string, destination string, expected *utils.ChecksumInfo, ignoreMismatch bool, maxTries int) error
DownloadWithChecksum(ctx context.Context, url string, destination string, expected *utils.ChecksumInfo, ignoreMismatch bool, maxTries int) error
// GetProgress returns Progress object
GetProgress() Progress
}
Expand Down
55 changes: 22 additions & 33 deletions cmd/mirror_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package cmd

import (
"fmt"
"os"
"os/signal"
"strings"
"sync"

Expand Down Expand Up @@ -113,17 +111,7 @@ func aptlyMirrorUpdate(cmd *commander.Command, args []string) error {
return fmt.Errorf("unable to update: %s", err)
}

// Catch ^C
sigch := make(chan os.Signal)
signal.Notify(sigch, os.Interrupt)
defer signal.Stop(sigch)

abort := make(chan struct{})
go func() {
<-sigch
signal.Stop(sigch)
close(abort)
}()
context.GoContextHandleSignals()

count := len(queue)
context.Progress().Printf("Download queue: %d items (%s)\n", count, utils.HumanBytes(downloadSize))
Expand All @@ -148,7 +136,7 @@ func aptlyMirrorUpdate(cmd *commander.Command, args []string) error {
for idx := range queue {
select {
case downloadQueue <- idx:
case <-abort:
case <-context.Done():
return
}
}
Expand Down Expand Up @@ -181,6 +169,7 @@ func aptlyMirrorUpdate(cmd *commander.Command, args []string) error {

// download file...
e = context.Downloader().DownloadWithChecksum(
context,
repo.PackageURL(task.File.DownloadURL()).String(),
task.TempDownPath,
&task.File.Checksums,
Expand All @@ -190,28 +179,20 @@ func aptlyMirrorUpdate(cmd *commander.Command, args []string) error {
pushError(e)
continue
}
case <-abort:

task.Done = true
case <-context.Done():
return
}
}
}()
}

// Wait for all downloads to finish
// Wait for all download goroutines to finish
wg.Wait()

select {
case <-abort:
return fmt.Errorf("unable to update: interrupted")
default:
}

context.Progress().ShutdownBar()

if len(errors) > 0 {
return fmt.Errorf("unable to update: download errors:\n %s", strings.Join(errors, "\n "))
}

err = context.ReOpenDatabase()
if err != nil {
return fmt.Errorf("unable to update: %s", err)
Expand All @@ -221,11 +202,15 @@ func aptlyMirrorUpdate(cmd *commander.Command, args []string) error {
context.Progress().InitBar(int64(len(queue)), false)

for idx := range queue {

context.Progress().AddBar(1)

task := &queue[idx]

if !task.Done {
// download not finished yet
continue
}

// and import it back to the pool
task.File.PoolPath, err = context.PackagePool().Import(task.TempDownPath, task.File.Filename, &task.File.Checksums, true, context.CollectionFactory().ChecksumCollection())
if err != nil {
Expand All @@ -237,16 +222,20 @@ func aptlyMirrorUpdate(cmd *commander.Command, args []string) error {
additionalTask.File.PoolPath = task.File.PoolPath
additionalTask.File.Checksums = task.File.Checksums
}

select {
case <-abort:
return fmt.Errorf("unable to update: interrupted")
default:
}
}

context.Progress().ShutdownBar()

select {
case <-context.Done():
return fmt.Errorf("unable to update: interrupted")
default:
}

if len(errors) > 0 {
return fmt.Errorf("unable to update: download errors:\n %s", strings.Join(errors, "\n "))
}

repo.FinalizeDownload(context.CollectionFactory(), context.Progress())
err = context.CollectionFactory().RemoteRepoCollection().Update(repo)
if err != nil {
Expand Down
26 changes: 26 additions & 0 deletions context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
package context

import (
gocontext "context"
"fmt"
"math/rand"
"os"
"os/signal"
"path/filepath"
"runtime"
"runtime/pprof"
Expand All @@ -30,6 +32,8 @@ import (
type AptlyContext struct {
sync.Mutex

gocontext.Context

flags, globalFlags *flag.FlagSet
configLoaded bool

Expand Down Expand Up @@ -438,6 +442,27 @@ func (context *AptlyContext) GlobalFlags() *flag.FlagSet {
return context.globalFlags
}

// GoContextHandleSignals upgrades context to handle ^C by aborting context
func (context *AptlyContext) GoContextHandleSignals() {
context.Lock()
defer context.Unlock()

// Catch ^C
sigch := make(chan os.Signal)
signal.Notify(sigch, os.Interrupt)

var cancel gocontext.CancelFunc

context.Context, cancel = gocontext.WithCancel(context.Context)

go func() {
<-sigch
signal.Stop(sigch)
context.Progress().PrintfStdErr("Aborting... press ^C once again to abort immediately\n")
cancel()
}()
}

// Shutdown shuts context down
func (context *AptlyContext) Shutdown() {
context.Lock()
Expand Down Expand Up @@ -494,6 +519,7 @@ func NewContext(flags *flag.FlagSet) (*AptlyContext, error) {
flags: flags,
globalFlags: flags,
dependencyOptions: -1,
Context: gocontext.TODO(),
publishedStorages: map[string]aptly.PublishedStorage{},
}

Expand Down
1 change: 1 addition & 0 deletions deb/package.go
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,7 @@ type PackageDownloadTask struct {
File *PackageFile
Additional []PackageDownloadTask
TempDownPath string
Done bool
}

// DownloadList returns list of missing package files for download in format
Expand Down
11 changes: 6 additions & 5 deletions deb/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package deb

import (
"bytes"
gocontext "context"
"fmt"
"log"
"net/url"
Expand Down Expand Up @@ -258,13 +259,13 @@ func (repo *RemoteRepo) Fetch(d aptly.Downloader, verifier pgp.Verifier) error {

if verifier == nil {
// 0. Just download release file to temporary URL
release, err = http.DownloadTemp(d, repo.ReleaseURL("Release").String())
release, err = http.DownloadTemp(gocontext.TODO(), d, repo.ReleaseURL("Release").String())
if err != nil {
return err
}
} else {
// 1. try InRelease file
inrelease, err = http.DownloadTemp(d, repo.ReleaseURL("InRelease").String())
inrelease, err = http.DownloadTemp(gocontext.TODO(), d, repo.ReleaseURL("InRelease").String())
if err != nil {
goto splitsignature
}
Expand All @@ -286,12 +287,12 @@ func (repo *RemoteRepo) Fetch(d aptly.Downloader, verifier pgp.Verifier) error {

splitsignature:
// 2. try Release + Release.gpg
release, err = http.DownloadTemp(d, repo.ReleaseURL("Release").String())
release, err = http.DownloadTemp(gocontext.TODO(), d, repo.ReleaseURL("Release").String())
if err != nil {
return err
}

releasesig, err = http.DownloadTemp(d, repo.ReleaseURL("Release.gpg").String())
releasesig, err = http.DownloadTemp(gocontext.TODO(), d, repo.ReleaseURL("Release.gpg").String())
if err != nil {
return err
}
Expand Down Expand Up @@ -439,7 +440,7 @@ func (repo *RemoteRepo) DownloadPackageIndexes(progress aptly.Progress, d aptly.

for _, info := range packagesPaths {
path, kind := info[0], info[1]
packagesReader, packagesFile, err := http.DownloadTryCompression(d, repo.IndexesRootURL(), path, repo.ReleaseFiles, ignoreMismatch, maxTries)
packagesReader, packagesFile, err := http.DownloadTryCompression(gocontext.TODO(), d, repo.IndexesRootURL(), path, repo.ReleaseFiles, ignoreMismatch, maxTries)
if err != nil {
return err
}
Expand Down
7 changes: 4 additions & 3 deletions http/compression.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package http
import (
"compress/bzip2"
"compress/gzip"
"context"
"fmt"
"io"
"net/url"
Expand Down Expand Up @@ -39,7 +40,7 @@ var compressionMethods = []struct {

// DownloadTryCompression tries to download from URL .bz2, .gz and raw extension until
// it finds existing file.
func DownloadTryCompression(downloader aptly.Downloader, baseURL *url.URL, path string, expectedChecksums map[string]utils.ChecksumInfo, ignoreMismatch bool, maxTries int) (io.Reader, *os.File, error) {
func DownloadTryCompression(ctx context.Context, downloader aptly.Downloader, baseURL *url.URL, path string, expectedChecksums map[string]utils.ChecksumInfo, ignoreMismatch bool, maxTries int) (io.Reader, *os.File, error) {
var err error

for _, method := range compressionMethods {
Expand All @@ -63,13 +64,13 @@ func DownloadTryCompression(downloader aptly.Downloader, baseURL *url.URL, path

if foundChecksum {
expected := expectedChecksums[bestSuffix]
file, err = DownloadTempWithChecksum(downloader, tryURL.String(), &expected, ignoreMismatch, maxTries)
file, err = DownloadTempWithChecksum(ctx, downloader, tryURL.String(), &expected, ignoreMismatch, maxTries)
} else {
if !ignoreMismatch {
continue
}

file, err = DownloadTemp(downloader, tryURL.String())
file, err = DownloadTemp(ctx, downloader, tryURL.String())
}

if err != nil {
Expand Down
21 changes: 12 additions & 9 deletions http/compression_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package http

import (
"context"
"errors"
"io"
"net/url"
Expand All @@ -12,6 +13,7 @@ import (

type CompressionSuite struct {
baseURL *url.URL
ctx context.Context
}

var _ = Suite(&CompressionSuite{})
Expand All @@ -25,6 +27,7 @@ const (

func (s *CompressionSuite) SetUpTest(c *C) {
s.baseURL, _ = url.Parse("http://example.com/")
s.ctx = context.Background()
}

func (s *CompressionSuite) TestDownloadTryCompression(c *C) {
Expand All @@ -41,7 +44,7 @@ func (s *CompressionSuite) TestDownloadTryCompression(c *C) {
buf = make([]byte, 4)
d := NewFakeDownloader()
d.ExpectResponse("http://example.com/file.bz2", bzipData)
r, file, err := DownloadTryCompression(d, s.baseURL, "file", expectedChecksums, false, 1)
r, file, err := DownloadTryCompression(s.ctx, d, s.baseURL, "file", expectedChecksums, false, 1)
c.Assert(err, IsNil)
defer file.Close()
io.ReadFull(r, buf)
Expand All @@ -53,7 +56,7 @@ func (s *CompressionSuite) TestDownloadTryCompression(c *C) {
d = NewFakeDownloader()
d.ExpectError("http://example.com/file.bz2", &Error{Code: 404})
d.ExpectResponse("http://example.com/file.gz", gzipData)
r, file, err = DownloadTryCompression(d, s.baseURL, "file", expectedChecksums, false, 1)
r, file, err = DownloadTryCompression(s.ctx, d, s.baseURL, "file", expectedChecksums, false, 1)
c.Assert(err, IsNil)
defer file.Close()
io.ReadFull(r, buf)
Expand All @@ -66,7 +69,7 @@ func (s *CompressionSuite) TestDownloadTryCompression(c *C) {
d.ExpectError("http://example.com/file.bz2", &Error{Code: 404})
d.ExpectError("http://example.com/file.gz", &Error{Code: 404})
d.ExpectResponse("http://example.com/file.xz", xzData)
r, file, err = DownloadTryCompression(d, s.baseURL, "file", expectedChecksums, false, 1)
r, file, err = DownloadTryCompression(s.ctx, d, s.baseURL, "file", expectedChecksums, false, 1)
c.Assert(err, IsNil)
defer file.Close()
io.ReadFull(r, buf)
Expand All @@ -80,7 +83,7 @@ func (s *CompressionSuite) TestDownloadTryCompression(c *C) {
d.ExpectError("http://example.com/file.gz", &Error{Code: 404})
d.ExpectError("http://example.com/file.xz", &Error{Code: 404})
d.ExpectResponse("http://example.com/file", rawData)
r, file, err = DownloadTryCompression(d, s.baseURL, "file", expectedChecksums, false, 1)
r, file, err = DownloadTryCompression(s.ctx, d, s.baseURL, "file", expectedChecksums, false, 1)
c.Assert(err, IsNil)
defer file.Close()
io.ReadFull(r, buf)
Expand All @@ -91,7 +94,7 @@ func (s *CompressionSuite) TestDownloadTryCompression(c *C) {
d = NewFakeDownloader()
d.ExpectError("http://example.com/file.bz2", &Error{Code: 404})
d.ExpectResponse("http://example.com/file.gz", "x")
_, _, err = DownloadTryCompression(d, s.baseURL, "file", nil, true, 1)
_, _, err = DownloadTryCompression(s.ctx, d, s.baseURL, "file", nil, true, 1)
c.Assert(err, ErrorMatches, "unexpected EOF")
c.Assert(d.Empty(), Equals, true)
}
Expand All @@ -109,7 +112,7 @@ func (s *CompressionSuite) TestDownloadTryCompressionLongestSuffix(c *C) {
buf = make([]byte, 4)
d := NewFakeDownloader()
d.ExpectResponse("http://example.com/subdir/file.bz2", bzipData)
r, file, err := DownloadTryCompression(d, s.baseURL, "subdir/file", expectedChecksums, false, 1)
r, file, err := DownloadTryCompression(s.ctx, d, s.baseURL, "subdir/file", expectedChecksums, false, 1)
c.Assert(err, IsNil)
defer file.Close()
io.ReadFull(r, buf)
Expand All @@ -119,15 +122,15 @@ func (s *CompressionSuite) TestDownloadTryCompressionLongestSuffix(c *C) {

func (s *CompressionSuite) TestDownloadTryCompressionErrors(c *C) {
d := NewFakeDownloader()
_, _, err := DownloadTryCompression(d, s.baseURL, "file", nil, true, 1)
_, _, err := DownloadTryCompression(s.ctx, d, s.baseURL, "file", nil, true, 1)
c.Assert(err, ErrorMatches, "unexpected request.*")

d = NewFakeDownloader()
d.ExpectError("http://example.com/file.bz2", &Error{Code: 404})
d.ExpectError("http://example.com/file.gz", &Error{Code: 404})
d.ExpectError("http://example.com/file.xz", &Error{Code: 404})
d.ExpectError("http://example.com/file", errors.New("403"))
_, _, err = DownloadTryCompression(d, s.baseURL, "file", nil, true, 1)
_, _, err = DownloadTryCompression(s.ctx, d, s.baseURL, "file", nil, true, 1)
c.Assert(err, ErrorMatches, "403")

d = NewFakeDownloader()
Expand All @@ -141,6 +144,6 @@ func (s *CompressionSuite) TestDownloadTryCompressionErrors(c *C) {
"file.xz": {Size: 7},
"file": {Size: 7},
}
_, _, err = DownloadTryCompression(d, s.baseURL, "file", expectedChecksums, false, 1)
_, _, err = DownloadTryCompression(s.ctx, d, s.baseURL, "file", expectedChecksums, false, 1)
c.Assert(err, ErrorMatches, "checksums don't match.*")
}
Loading