aptly/http/fake.go

140 lines
3.9 KiB
Go

package http
import (
"context"
"fmt"
"io"
"os"
"path/filepath"
"github.com/aptly-dev/aptly/aptly"
"github.com/aptly-dev/aptly/utils"
)
// expectedRequest is a single stubbed download: the URL it answers and either
// an error to hand back or the response body to serve.
type expectedRequest struct {
// URL is compared for exact equality with the requested URL.
URL string
// Err, when non-nil, is returned to the caller instead of serving Response.
Err error
// Response is the body written to the destination file (and whose byte
// length GetLength reports) when Err is nil.
Response string
}
// FakeDownloader is like Downloader, but it is used in tests
// to stub out results.
//
// It holds two kinds of expectations: an ordered queue, where a request only
// matches the head of the queue, and an unordered set keyed by URL. Each
// expectation satisfies exactly one request. NewFakeDownloader initialises
// both fields. The struct has no internal locking, so concurrent use would
// need external synchronization.
type FakeDownloader struct {
// expected is the ordered queue filled by ExpectResponse/ExpectError and
// consumed from the front.
expected []expectedRequest
// anyExpected holds expectations filled by AnyExpectResponse that may be
// matched in any order, keyed by URL.
anyExpected map[string]expectedRequest
}
// Compile-time guarantee that *FakeDownloader implements aptly.Downloader;
// the build breaks here if a method is missing or has the wrong signature.
var _ aptly.Downloader = (*FakeDownloader)(nil)
// NewFakeDownloader returns a FakeDownloader with no expectations installed:
// an empty (non-nil) ordered queue and an empty unordered set.
func NewFakeDownloader() *FakeDownloader {
	return &FakeDownloader{
		expected:    make([]expectedRequest, 0),
		anyExpected: make(map[string]expectedRequest),
	}
}
// ExpectResponse appends an ordered expectation: once every earlier queued
// expectation has been consumed, a request for url is answered with response.
// It returns f so calls can be chained.
func (f *FakeDownloader) ExpectResponse(url string, response string) *FakeDownloader {
	next := expectedRequest{
		URL:      url,
		Response: response,
	}
	f.expected = append(f.expected, next)
	return f
}
// AnyExpectResponse installs an expectation for a download of url that may be
// satisfied in any order relative to other expectations, answered with
// response. URLs registered this way should be unique: registering the same
// url again replaces the earlier expectation. It returns f so calls can be
// chained.
func (f *FakeDownloader) AnyExpectResponse(url string, response string) *FakeDownloader {
	// Allocate lazily so a zero-value FakeDownloader (one not built with
	// NewFakeDownloader) does not panic on a nil-map write.
	if f.anyExpected == nil {
		f.anyExpected = make(map[string]expectedRequest)
	}
	f.anyExpected[url] = expectedRequest{URL: url, Response: response}
	return f
}
// ExpectError appends an ordered expectation whose matching request for url
// fails with err instead of receiving a response. It returns f so calls can
// be chained.
func (f *FakeDownloader) ExpectError(url string, err error) *FakeDownloader {
	failure := expectedRequest{
		URL: url,
		Err: err,
	}
	f.expected = append(f.expected, failure)
	return f
}
// Empty reports whether every ordered expectation (installed with
// ExpectResponse or ExpectError) has been consumed. Expectations installed
// with AnyExpectResponse are not taken into account.
func (f *FakeDownloader) Empty() bool {
	remaining := len(f.expected)
	return remaining == 0
}
// GetLength reports the byte length of the response stubbed for url. Like a
// download, it consumes the matching expectation. If no expectation matches,
// or the matching one carries an error, it returns -1 together with that
// error. The context is ignored.
func (f *FakeDownloader) GetLength(_ context.Context, url string) (int64, error) {
	length := int64(-1)
	expectation, err := f.getExpectedRequest(url)
	if err == nil {
		length = int64(len(expectation.Response))
	}
	return length, err
}
// getExpectedRequest finds and consumes the expectation registered for url.
//
// The head of the ordered queue is checked first. If it does not match, the
// unordered set is consulted. A matched expectation is removed whether it
// carries a response or an error, so each expectation satisfies exactly one
// request; when url is both at the queue head and in the unordered set, only
// the queue head is consumed.
//
// It returns the expectation's Err if one was registered, or an
// "unexpected request" error when url matches neither the queue head nor any
// unordered expectation (an ordered expectation further down the queue does
// not count).
func (f *FakeDownloader) getExpectedRequest(url string) (*expectedRequest, error) {
	var expectation expectedRequest

	if len(f.expected) > 0 && f.expected[0].URL == url {
		expectation, f.expected = f.expected[0], f.expected[1:]
	} else if anyExpectation, ok := f.anyExpected[url]; ok {
		// Single comma-ok lookup instead of probing the map and indexing again.
		expectation = anyExpectation
		delete(f.anyExpected, url)
	} else {
		return nil, fmt.Errorf("unexpected request for %s", url)
	}

	if expectation.Err != nil {
		return nil, expectation.Err
	}

	return &expectation, nil
}
// DownloadWithChecksum performs a fake download of url into filename.
//
// The request is matched against the head of the ordered expectation queue
// or, failing that, against the unordered expectations; the matched
// expectation is consumed. If it carries an error, that error is returned and
// no file is written. Otherwise the parent directory of filename is created
// if needed and the stubbed response is written to filename.
//
// When expected is non-nil, the written content is verified: Size is always
// compared, MD5/SHA1/SHA256 only when the expected value is non-empty. On a
// mismatch a warning is printed to stdout if ignoreMismatch is set; otherwise
// an error is returned (the written file is left in place). The context is
// ignored.
func (f *FakeDownloader) DownloadWithChecksum(_ context.Context, url string, filename string, expected *utils.ChecksumInfo, ignoreMismatch bool) error {
	expectation, err := f.getExpectedRequest(url)
	if err != nil {
		return err
	}

	if err = os.MkdirAll(filepath.Dir(filename), 0755); err != nil {
		return err
	}

	outfile, err := os.Create(filename)
	if err != nil {
		return err
	}

	cks := utils.NewChecksumWriter()
	_, err = io.MultiWriter(outfile, cks).Write([]byte(expectation.Response))
	// Close explicitly instead of via defer so a failure while closing the
	// written file is reported rather than dropped; a write error wins.
	if closeErr := outfile.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		return err
	}

	if expected == nil {
		return nil
	}

	// Compute the digests once rather than once per comparison.
	actual := cks.Sum()
	mismatch := expected.Size != actual.Size ||
		(expected.MD5 != "" && expected.MD5 != actual.MD5) ||
		(expected.SHA1 != "" && expected.SHA1 != actual.SHA1) ||
		(expected.SHA256 != "" && expected.SHA256 != actual.SHA256)
	if !mismatch {
		return nil
	}

	if ignoreMismatch {
		fmt.Printf("WARNING: checksums don't match: %#v != %#v for %s\n", expected, actual, url)
		return nil
	}
	return fmt.Errorf("checksums don't match: %#v != %#v for %s", expected, actual, url)
}
// Download performs a fake download of url into filename, matching it like
// DownloadWithChecksum does (ordered queue head first, then the unordered
// expectations) but without any checksum verification.
func (f *FakeDownloader) Download(ctx context.Context, url string, filename string) error {
	const ignoreMismatch = false
	return f.DownloadWithChecksum(ctx, url, filename, nil, ignoreMismatch)
}
// GetProgress returns the downloader's Progress. The fake has none, so the
// result is always a nil aptly.Progress, and callers that use it must handle
// a nil value.
func (f *FakeDownloader) GetProgress() aptly.Progress {
	var none aptly.Progress
	return none
}