mirror of https://go.googlesource.com/go
825 lines
20 KiB
Go
825 lines
20 KiB
Go
// Copyright 2020 The Go Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package os_test
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"internal/poll"
|
|
"internal/testpty"
|
|
"io"
|
|
"math/rand"
|
|
"net"
|
|
. "os"
|
|
"path/filepath"
|
|
"runtime"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"syscall"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/net/nettest"
|
|
)
|
|
|
|
// TestCopyFileRange checks that io.Copy between *File values goes through
// poll.CopyFileRange for regular files, with and without an io.LimitedReader
// around the source. It also checks that:
//   - the call is not attempted for O_APPEND destinations;
//   - copying a file onto itself still yields the right contents;
//   - pipes on either or both ends are copied correctly;
//   - nil *File arguments and receivers yield ErrInvalid.
func TestCopyFileRange(t *testing.T) {
	sizes := []int{
		1,
		42,
		1025,
		syscall.Getpagesize() + 1,
		32769,
	}
	t.Run("Basic", func(t *testing.T) {
		for _, size := range sizes {
			t.Run(strconv.Itoa(size), func(t *testing.T) {
				testCopyFileRange(t, int64(size), -1)
			})
		}
	})
	t.Run("Limited", func(t *testing.T) {
		// Each entry derives the io.LimitedReader limit from the file size.
		limits := []struct {
			name  string
			limit func(size int64) int64
		}{
			{"OneLess", func(size int64) int64 { return size - 1 }},
			{"Half", func(size int64) int64 { return size / 2 }},
			{"More", func(size int64) int64 { return size + 7 }},
		}
		for _, l := range limits {
			t.Run(l.name, func(t *testing.T) {
				for _, size := range sizes {
					t.Run(strconv.Itoa(size), func(t *testing.T) {
						testCopyFileRange(t, int64(size), l.limit(int64(size)))
					})
				}
			})
		}
	})
	t.Run("DoesntTryInAppendMode", func(t *testing.T) {
		dst, src, data, hook := newCopyFileRangeTest(t, 42)

		dst2, err := OpenFile(dst.Name(), O_RDWR|O_APPEND, 0755)
		if err != nil {
			t.Fatal(err)
		}
		defer dst2.Close()

		if _, err := io.Copy(dst2, src); err != nil {
			t.Fatal(err)
		}
		if hook.called {
			t.Fatal("called poll.CopyFileRange for destination in O_APPEND mode")
		}
		mustSeekStart(t, dst2)
		mustContainData(t, dst2, data) // through traditional means
	})
	t.Run("CopyFileItself", func(t *testing.T) {
		hook := hookCopyFileRange(t)

		f, err := CreateTemp("", "file-readfrom-itself-test")
		if err != nil {
			t.Fatalf("failed to create tmp file: %v", err)
		}
		t.Cleanup(func() {
			f.Close()
			Remove(f.Name())
		})

		data := []byte("hello world!")
		if _, err := f.Write(data); err != nil {
			t.Fatalf("failed to create and feed the file: %v", err)
		}

		if err := f.Sync(); err != nil {
			t.Fatalf("failed to save the file: %v", err)
		}

		// Rewind it.
		if _, err := f.Seek(0, io.SeekStart); err != nil {
			t.Fatalf("failed to rewind the file: %v", err)
		}

		// Read data from the file itself.
		if _, err := io.Copy(f, f); err != nil {
			t.Fatalf("failed to read from the file: %v", err)
		}

		// copy_file_range(2) refuses to copy a file onto itself (EINVAL,
		// presumably because the ranges overlap; see the man page). The check
		// below requires that poll.CopyFileRange was reached, copied nothing,
		// and reported the copy as not handled with a nil error. io.Copy then
		// has to fall back to a generic copy, which the data check verifies.
		if !hook.called || hook.written != 0 || hook.handled || hook.err != nil {
			t.Fatalf("poll.CopyFileRange should be called and decline the copy (EINVAL) without an error, but got hook.called=%t, hook.written=%d, hook.handled=%t, hook.err=%v",
				hook.called, hook.written, hook.handled, hook.err)
		}

		// Rewind it.
		if _, err := f.Seek(0, io.SeekStart); err != nil {
			t.Fatalf("failed to rewind the file: %v", err)
		}

		data2, err := io.ReadAll(f)
		if err != nil {
			t.Fatalf("failed to read from the file: %v", err)
		}

		// It should wind up a double of the original data.
		if want := strings.Repeat(string(data), 2); want != string(data2) {
			t.Fatalf("data mismatch: got %q, want %q", data2, want)
		}
	})
	t.Run("NotRegular", func(t *testing.T) {
		t.Run("BothPipes", func(t *testing.T) {
			hook := hookCopyFileRange(t)

			pr1, pw1, err := Pipe()
			if err != nil {
				t.Fatal(err)
			}
			defer pr1.Close()
			defer pw1.Close()

			pr2, pw2, err := Pipe()
			if err != nil {
				t.Fatal(err)
			}
			defer pr2.Close()
			defer pw2.Close()

			// The pipe is empty, and PIPE_BUF is large enough
			// for this, by (POSIX) definition, so there is no
			// need for an additional goroutine.
			data := []byte("hello")
			if _, err := pw1.Write(data); err != nil {
				t.Fatal(err)
			}
			pw1.Close()

			n, err := io.Copy(pw2, pr1)
			if err != nil {
				t.Fatal(err)
			}
			if n != int64(len(data)) {
				t.Fatalf("transferred %d, want %d", n, len(data))
			}
			if !hook.called {
				t.Fatalf("should have called poll.CopyFileRange")
			}
			pw2.Close()
			mustContainData(t, pr2, data)
		})
		t.Run("DstPipe", func(t *testing.T) {
			dst, src, data, hook := newCopyFileRangeTest(t, 255)
			dst.Close()

			pr, pw, err := Pipe()
			if err != nil {
				t.Fatal(err)
			}
			defer pr.Close()
			defer pw.Close()

			n, err := io.Copy(pw, src)
			if err != nil {
				t.Fatal(err)
			}
			if n != int64(len(data)) {
				t.Fatalf("transferred %d, want %d", n, len(data))
			}
			if !hook.called {
				t.Fatalf("should have called poll.CopyFileRange")
			}
			pw.Close()
			mustContainData(t, pr, data)
		})
		t.Run("SrcPipe", func(t *testing.T) {
			dst, src, data, hook := newCopyFileRangeTest(t, 255)
			src.Close()

			pr, pw, err := Pipe()
			if err != nil {
				t.Fatal(err)
			}
			defer pr.Close()
			defer pw.Close()

			// The pipe is empty, and PIPE_BUF is large enough
			// for this, by (POSIX) definition, so there is no
			// need for an additional goroutine.
			if _, err := pw.Write(data); err != nil {
				t.Fatal(err)
			}
			pw.Close()

			n, err := io.Copy(dst, pr)
			if err != nil {
				t.Fatal(err)
			}
			if n != int64(len(data)) {
				t.Fatalf("transferred %d, want %d", n, len(data))
			}
			if !hook.called {
				t.Fatalf("should have called poll.CopyFileRange")
			}
			mustSeekStart(t, dst)
			mustContainData(t, dst, data)
		})
	})
	t.Run("Nil", func(t *testing.T) {
		var nilFile *File
		anyFile, err := CreateTemp("", "")
		if err != nil {
			t.Fatal(err)
		}
		defer Remove(anyFile.Name())
		defer anyFile.Close()

		if _, err := io.Copy(nilFile, nilFile); err != ErrInvalid {
			t.Errorf("io.Copy(nilFile, nilFile) = %v, want %v", err, ErrInvalid)
		}
		if _, err := io.Copy(anyFile, nilFile); err != ErrInvalid {
			t.Errorf("io.Copy(anyFile, nilFile) = %v, want %v", err, ErrInvalid)
		}
		if _, err := io.Copy(nilFile, anyFile); err != ErrInvalid {
			t.Errorf("io.Copy(nilFile, anyFile) = %v, want %v", err, ErrInvalid)
		}

		if _, err := nilFile.ReadFrom(nilFile); err != ErrInvalid {
			t.Errorf("nilFile.ReadFrom(nilFile) = %v, want %v", err, ErrInvalid)
		}
		if _, err := anyFile.ReadFrom(nilFile); err != ErrInvalid {
			t.Errorf("anyFile.ReadFrom(nilFile) = %v, want %v", err, ErrInvalid)
		}
		if _, err := nilFile.ReadFrom(anyFile); err != ErrInvalid {
			t.Errorf("nilFile.ReadFrom(anyFile) = %v, want %v", err, ErrInvalid)
		}
	})
}
|
|
|
|
// TestSpliceFile checks that io.Copy from a TCP or Unix stream socket into a
// regular file is routed through poll.Splice, both with and without an
// io.LimitedReader around the socket. It also checks splicing a socket into a
// TTY (issue #59041).
func TestSpliceFile(t *testing.T) {
	sizes := []int{1, 42, 1025, syscall.Getpagesize() + 1, 32769}

	// forEachSize runs testSpliceFile over proto once per entry of sizes, each
	// in a subtest named after the size. limitOf maps the size to the limit
	// argument; -1 means no limit.
	forEachSize := func(t *testing.T, proto string, limitOf func(size int64) int64) {
		for _, sz := range sizes {
			n := int64(sz)
			t.Run(strconv.Itoa(sz), func(t *testing.T) {
				testSpliceFile(t, proto, n, limitOf(n))
			})
		}
	}
	unlimited := func(int64) int64 { return -1 }

	t.Run("Basic-TCP", func(t *testing.T) { forEachSize(t, "tcp", unlimited) })
	t.Run("Basic-Unix", func(t *testing.T) { forEachSize(t, "unix", unlimited) })
	t.Run("TCP-To-TTY", func(t *testing.T) { testSpliceToTTY(t, "tcp", 32768) })
	t.Run("Unix-To-TTY", func(t *testing.T) { testSpliceToTTY(t, "unix", 32768) })
	t.Run("Limited", func(t *testing.T) {
		kinds := []struct {
			name    string
			limitOf func(size int64) int64
		}{
			{"OneLess", func(size int64) int64 { return size - 1 }},
			{"Half", func(size int64) int64 { return size / 2 }},
			{"More", func(size int64) int64 { return size + 1 }},
		}
		protos := []struct{ label, network string }{
			{"TCP", "tcp"},
			{"Unix", "unix"},
		}
		for _, k := range kinds {
			for _, p := range protos {
				t.Run(k.name+"-"+p.label, func(t *testing.T) {
					forEachSize(t, p.network, k.limitOf)
				})
			}
		}
	})
}
|
|
|
|
// testSpliceFile copies size bytes of random data from a connected socket of
// the given protocol into a temporary file with io.Copy. When limit >= 0 the
// socket is wrapped in an io.LimitedReader. It then checks that:
//   - poll.Splice was used with the right descriptors;
//   - the reported count and the destination offset match the expected length;
//   - the destination holds exactly the expected bytes;
//   - the LimitedReader's N was decremented by the amount copied.
func testSpliceFile(t *testing.T, proto string, size, limit int64) {
	dst, src, data, hook, cleanup := newSpliceFileTest(t, proto, size)
	defer cleanup()

	// Wrap the socket when a limit is requested, and trim the expected data
	// down to what the limit lets through.
	var lr *io.LimitedReader
	var r io.Reader = src
	if limit >= 0 {
		lr = &io.LimitedReader{N: limit, R: src}
		r = lr
		if int64(len(data)) > limit {
			data = data[:limit]
		}
	}

	// io.Copy dispatches to (*File).ReadFrom, which should end up in poll.Splice.
	n, err := io.Copy(dst, r)
	if err != nil {
		t.Fatal(err)
	}

	// The hook must have run if anything was copied, and it must have seen
	// the destination file's descriptor.
	switch {
	case n > 0 && !hook.called:
		t.Fatal("expected to called poll.Splice")
	case hook.called && hook.dstfd != int(dst.Fd()):
		t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, dst.Fd())
	}

	// The source descriptor is only reachable through the raw connection.
	sc, ok := src.(syscall.Conn)
	if !ok {
		t.Fatalf("server Conn is not a syscall.Conn")
	}
	rc, err := sc.SyscallConn()
	if err != nil {
		t.Fatalf("server Conn SyscallConn error: %v", err)
	}
	checkSrc := func(fd uintptr) {
		if hook.called && hook.srcfd != int(fd) {
			t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, int(fd))
		}
	}
	if err := rc.Control(checkSrc); err != nil {
		t.Fatalf("server Conn Control error: %v", err)
	}

	// The destination offset and the reported count must both equal the
	// expected length, and the file must hold exactly the expected bytes.
	want := int64(len(data))
	dstoff, err := dst.Seek(0, io.SeekCurrent)
	if err != nil {
		t.Fatal(err)
	}
	if dstoff != want {
		t.Errorf("dstoff = %d, want %d", dstoff, want)
	}
	if n != want {
		t.Errorf("short ReadFrom: wrote %d bytes, want %d", n, want)
	}
	mustSeekStart(t, dst)
	mustContainData(t, dst, data)

	// With a limit, the LimitedReader must have been charged for what was copied.
	if lr == nil {
		return
	}
	if remaining := limit - n; lr.N != remaining {
		t.Fatalf("didn't update limit correctly: got %d, want %d", lr.N, remaining)
	}
}
|
|
|
|
// testSpliceToTTY is a regression test for issue #59041. It uses io.Copy to
// copy from a connected socket of the given protocol into a TTY that was
// opened directly with syscall.Open, bypassing os's non-blocking support.
// One goroutine writes size bytes to the socket in blocks of at most 1024
// bytes. Another goroutine drains the pty side.
func testSpliceToTTY(t *testing.T, proto string, size int64) {
	var wg sync.WaitGroup

	// Call wg.Wait as the final deferred function,
	// because the goroutines may block until some of
	// the deferred Close calls.
	defer wg.Wait()

	pty, ttyName, err := testpty.Open()
	if err != nil {
		t.Skipf("skipping test because pty open failed: %v", err)
	}
	defer pty.Close()

	// Open the tty directly, rather than via OpenFile.
	// This bypasses the non-blocking support and is required
	// to recreate the problem in the issue (#59041).
	ttyFD, err := syscall.Open(ttyName, syscall.O_RDWR, 0)
	if err != nil {
		t.Skipf("skipping test because failed to open tty: %v", err)
	}

	// tty takes ownership of ttyFD, so tty.Close is the only close. Closing
	// ttyFD again with syscall.Close would double-close the descriptor
	// number, which may by then belong to something else.
	tty := NewFile(uintptr(ttyFD), "tty")
	defer tty.Close()

	client, server := createSocketPair(t, proto)

	data := bytes.Repeat([]byte{'a'}, int(size))

	// blockSize is the largest write issued to the socket in one call.
	const blockSize = 1024

	wg.Add(1)
	go func() {
		defer wg.Done()
		// The problem (issue #59041) occurs when writing
		// a series of blocks of data. It does not occur
		// when all the data is written at once.
		for i := 0; i < len(data); i += blockSize {
			end := i + blockSize
			if end > len(data) {
				// The final block may be short when size is not a
				// multiple of blockSize.
				end = len(data)
			}
			if _, err := client.Write(data[i:end]); err != nil {
				// If we get here because the client was
				// closed, skip the error.
				if !errors.Is(err, net.ErrClosed) {
					t.Errorf("error writing to socket: %v", err)
				}
				return
			}
		}
		client.Close()
	}()

	wg.Add(1)
	go func() {
		defer wg.Done()
		buf := make([]byte, 32)
		for {
			if _, err := pty.Read(buf); err != nil {
				if err != io.EOF && !errors.Is(err, ErrClosed) {
					// An error here doesn't matter for
					// our test.
					t.Logf("error reading from pty: %v", err)
				}
				return
			}
		}
	}()

	// Close Client to wake up the writing goroutine if necessary.
	defer client.Close()

	_, err = io.Copy(tty, server)
	if err != nil {
		t.Fatal(err)
	}
}
|
|
|
|
// testCopyFileRange copies size bytes of random data between two temporary
// files with io.Copy. When limit >= 0 the source is wrapped in an
// io.LimitedReader. It then checks that:
//   - poll.CopyFileRange was called with the right descriptors;
//   - both file offsets and the reported count match the expected length;
//   - the destination holds exactly the expected bytes;
//   - the LimitedReader's N was decremented by the amount copied.
func testCopyFileRange(t *testing.T, size int64, limit int64) {
	dst, src, data, hook := newCopyFileRangeTest(t, size)

	// If we have a limit, wrap the reader.
	var (
		realsrc io.Reader
		lr      *io.LimitedReader
	)
	if limit >= 0 {
		lr = &io.LimitedReader{N: limit, R: src}
		realsrc = lr
		if limit < int64(len(data)) {
			data = data[:limit]
		}
	} else {
		realsrc = src
	}

	// Now call ReadFrom (through io.Copy), which will hopefully call
	// poll.CopyFileRange.
	n, err := io.Copy(dst, realsrc)
	if err != nil {
		t.Fatal(err)
	}

	// Both the unlimited case (limit < 0) and any positive limit must reach
	// poll.CopyFileRange with the right file descriptor arguments. Only a
	// zero limit, where there is nothing to copy, may skip the call.
	if limit != 0 && !hook.called {
		t.Fatal("never called poll.CopyFileRange")
	}
	if hook.called && hook.dstfd != int(dst.Fd()) {
		t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, dst.Fd())
	}
	if hook.called && hook.srcfd != int(src.Fd()) {
		t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, src.Fd())
	}

	// Check that the offsets after the transfer make sense, that the size
	// of the transfer was reported correctly, and that the destination
	// file contains exactly the bytes we expect it to contain.
	dstoff, err := dst.Seek(0, io.SeekCurrent)
	if err != nil {
		t.Fatal(err)
	}
	srcoff, err := src.Seek(0, io.SeekCurrent)
	if err != nil {
		t.Fatal(err)
	}
	if dstoff != srcoff {
		t.Errorf("offsets differ: dstoff = %d, srcoff = %d", dstoff, srcoff)
	}
	if dstoff != int64(len(data)) {
		t.Errorf("dstoff = %d, want %d", dstoff, len(data))
	}
	if n != int64(len(data)) {
		t.Errorf("short ReadFrom: wrote %d bytes, want %d", n, len(data))
	}
	mustSeekStart(t, dst)
	mustContainData(t, dst, data)

	// If we had a limit, check that it was updated.
	if lr != nil {
		if want := limit - n; lr.N != want {
			t.Fatalf("didn't update limit correctly: got %d, want %d", lr.N, want)
		}
	}
}
|
|
|
|
// newCopyFileRangeTest initializes a new test for copy_file_range.
//
// It hooks package os' call to poll.CopyFileRange and creates "src" and "dst"
// files in a per-test temporary directory; both are closed when t finishes.
// It fills src with size bytes of random data and rewinds it. It returns both
// files, the data written, and the hook so it can be inspected. The random
// seed is logged so the data of a failing run can be regenerated.
func newCopyFileRangeTest(t *testing.T, size int64) (dst, src *File, data []byte, hook *copyFileRangeHook) {
	t.Helper()

	hook = hookCopyFileRange(t)
	tmp := t.TempDir()

	src, err := Create(filepath.Join(tmp, "src"))
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(func() { src.Close() })

	dst, err = Create(filepath.Join(tmp, "dst"))
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(func() { dst.Close() })

	// Populate the source file with data, then rewind it, so it can be
	// consumed by copy_file_range(2).
	randSeed := time.Now().Unix()
	t.Logf("random data seed: %d\n", randSeed)
	prng := rand.New(rand.NewSource(randSeed))
	data = make([]byte, size)
	prng.Read(data)
	if _, err := src.Write(data); err != nil {
		t.Fatal(err)
	}
	if _, err := src.Seek(0, io.SeekStart); err != nil {
		t.Fatal(err)
	}

	return dst, src, data, hook
}
|
|
|
|
// newSpliceFileTest sets up a splice(2) test. It:
//   - hooks package os' call to poll.Splice;
//   - connects a client/server socket pair over proto;
//   - creates an empty destination file in a per-test temporary directory;
//   - starts a goroutine that writes size bytes of seeded random data to the
//     client side and then closes it.
//
// It returns the destination file, the server side of the connection (the
// source to copy from), the data being written, the installed hook, and a
// function that blocks until the writer goroutine has finished.
func newSpliceFileTest(t *testing.T, proto string, size int64) (*File, net.Conn, []byte, *spliceFileHook, func()) {
	t.Helper()

	hook := hookSpliceFile(t)
	client, server := createSocketPair(t, proto)

	dst, err := CreateTemp(t.TempDir(), "dst-splice-file-test")
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(func() { dst.Close() })

	// The seed is logged so the data of a failing run can be regenerated.
	seed := time.Now().Unix()
	t.Logf("random data seed: %d\n", seed)
	data := make([]byte, size)
	rand.New(rand.NewSource(seed)).Read(data)

	// The Write error is not checked here. Presumably a failed write shows up
	// as a short or mismatched copy in the caller's checks.
	writerDone := make(chan struct{})
	go func() {
		client.Write(data)
		client.Close()
		close(writerDone)
	}()
	waitWriter := func() { <-writerDone }

	return dst, server, data, hook, waitWriter
}
|
|
|
|
// mustContainData fails the test unless reading f from its current position
// yields exactly data, followed immediately by io.EOF.
func mustContainData(t *testing.T, f *File, data []byte) {
	t.Helper()

	got := make([]byte, len(data))
	if _, err := io.ReadFull(f, got); err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(got, data) {
		t.Fatalf("didn't get the same data back from %s", f.Name())
	}
	// Anything other than io.EOF means f holds more than len(data) bytes, or
	// the read failed. Report which one happened.
	if n, err := f.Read(make([]byte, 1)); err != io.EOF {
		t.Fatalf("not at EOF after %d bytes of %s: Read returned n=%d, err=%v", len(data), f.Name(), n, err)
	}
}
|
|
|
|
// mustSeekStart rewinds f to offset 0 and fails the test if the seek fails.
func mustSeekStart(t *testing.T, f *File) {
	// Report failures at the caller's line, as the other must* helpers do.
	t.Helper()

	if _, err := f.Seek(0, io.SeekStart); err != nil {
		t.Fatal(err)
	}
}
|
|
|
|
// hookCopyFileRange installs a recording hook around os's call to
// poll.CopyFileRange and returns it for inspection. The previous function is
// restored when t finishes.
func hookCopyFileRange(t *testing.T) *copyFileRangeHook {
	hook := &copyFileRangeHook{}
	hook.install()
	t.Cleanup(hook.uninstall)
	return hook
}
|
|
|
|
type copyFileRangeHook struct {
|
|
called bool
|
|
dstfd int
|
|
srcfd int
|
|
remain int64
|
|
|
|
written int64
|
|
handled bool
|
|
err error
|
|
|
|
original func(dst, src *poll.FD, remain int64) (int64, bool, error)
|
|
}
|
|
|
|
func (h *copyFileRangeHook) install() {
|
|
h.original = *PollCopyFileRangeP
|
|
*PollCopyFileRangeP = func(dst, src *poll.FD, remain int64) (int64, bool, error) {
|
|
h.called = true
|
|
h.dstfd = dst.Sysfd
|
|
h.srcfd = src.Sysfd
|
|
h.remain = remain
|
|
h.written, h.handled, h.err = h.original(dst, src, remain)
|
|
return h.written, h.handled, h.err
|
|
}
|
|
}
|
|
|
|
func (h *copyFileRangeHook) uninstall() {
|
|
*PollCopyFileRangeP = h.original
|
|
}
|
|
|
|
// hookSpliceFile installs a recording hook around os's call to poll.Splice
// and returns it for inspection. The previous function is restored when t
// finishes.
func hookSpliceFile(t *testing.T) *spliceFileHook {
	hook := &spliceFileHook{}
	hook.install()
	t.Cleanup(hook.uninstall)
	return hook
}
|
|
|
|
type spliceFileHook struct {
|
|
called bool
|
|
dstfd int
|
|
srcfd int
|
|
remain int64
|
|
|
|
written int64
|
|
handled bool
|
|
err error
|
|
|
|
original func(dst, src *poll.FD, remain int64) (int64, bool, error)
|
|
}
|
|
|
|
func (h *spliceFileHook) install() {
|
|
h.original = *PollSpliceFile
|
|
*PollSpliceFile = func(dst, src *poll.FD, remain int64) (int64, bool, error) {
|
|
h.called = true
|
|
h.dstfd = dst.Sysfd
|
|
h.srcfd = src.Sysfd
|
|
h.remain = remain
|
|
h.written, h.handled, h.err = h.original(dst, src, remain)
|
|
return h.written, h.handled, h.err
|
|
}
|
|
}
|
|
|
|
func (h *spliceFileHook) uninstall() {
|
|
*PollSpliceFile = h.original
|
|
}
|
|
|
|
// TestProcCopy checks that io.Copy from /proc/self/cmdline into a regular
// file reproduces its contents exactly. On some kernels copy_file_range fails
// on files in /proc, and the copy must still come out complete.
func TestProcCopy(t *testing.T) {
	t.Parallel()

	const cmdlineFile = "/proc/self/cmdline"
	cmdline, err := ReadFile(cmdlineFile)
	if err != nil {
		t.Skipf("can't read /proc file: %v", err)
	}
	in, err := Open(cmdlineFile)
	if err != nil {
		t.Fatal(err)
	}
	defer in.Close()
	outFile := filepath.Join(t.TempDir(), "cmdline")
	out, err := Create(outFile)
	if err != nil {
		t.Fatal(err)
	}
	// Ensure out is released even if the copy below fails. On success the
	// explicit, error-checked Close runs first, and this second Close's error
	// is deliberately ignored.
	defer out.Close()
	if _, err := io.Copy(out, in); err != nil {
		t.Fatal(err)
	}
	if err := out.Close(); err != nil {
		t.Fatal(err)
	}
	// Named "got" rather than "copy" to avoid shadowing the builtin.
	got, err := ReadFile(outFile)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(cmdline, got) {
		t.Errorf("copy of %q got %q want %q\n", cmdlineFile, got, cmdline)
	}
}
|
|
|
|
// TestGetPollFDAndNetwork runs testGetPollFDAndNetwork over a "tcp4" and a
// "unix" connection, each as its own subtest.
func TestGetPollFDAndNetwork(t *testing.T) {
	for _, proto := range []string{"tcp4", "unix"} {
		t.Run(proto, func(t *testing.T) { testGetPollFDAndNetwork(t, proto) })
	}
}
|
|
|
|
// testGetPollFDAndNetwork connects a socket pair over proto and applies
// GetPollFDAndNetwork to the server side. It checks that the result is:
//   - a non-nil poll.FD whose Sysfd equals the connection's raw descriptor;
//   - marked IsStream;
//   - paired with a network name equal to proto;
//   - rejected by Init, since it is already initialized.
func testGetPollFDAndNetwork(t *testing.T, proto string) {
	_, server := createSocketPair(t, proto)

	sc, ok := server.(syscall.Conn)
	if !ok {
		t.Fatalf("server Conn is not a syscall.Conn")
	}
	rc, err := sc.SyscallConn()
	if err != nil {
		t.Fatalf("server SyscallConn error: %v", err)
	}

	verify := func(fd uintptr) {
		pfd, network := GetPollFDAndNetwork(server)
		switch {
		case pfd == nil:
			t.Fatalf("GetPollFDAndNetwork didn't return poll.FD")
		case string(network) != proto:
			t.Fatalf("GetPollFDAndNetwork returned wrong network, got: %s, want: %s", network, proto)
		case pfd.Sysfd != int(fd):
			t.Fatalf("GetPollFDAndNetwork returned wrong poll.FD, got: %d, want: %d", pfd.Sysfd, int(fd))
		case !pfd.IsStream:
			t.Fatalf("expected IsStream to be true")
		}
		// The returned FD is already live, so a second Init must be refused.
		if err := pfd.Init(proto, true); err == nil {
			t.Fatalf("Init should have failed with the initialized poll.FD and return EEXIST error")
		}
	}
	if err := rc.Control(verify); err != nil {
		t.Fatalf("server Control error: %v", err)
	}
}
|
|
|
|
// createSocketPair returns both ends of a connection over proto on a local
// listener. The test is skipped if nettest reports that proto is not testable
// on this platform. The listener and both connections are closed when t
// finishes. Any Listen, Dial or Accept failure is fatal. The helper never
// returns a nil connection, and it never leaves the accepting goroutine
// blocked behind it.
func createSocketPair(t *testing.T, proto string) (client, server net.Conn) {
	t.Helper()
	if !nettest.TestableNetwork(proto) {
		t.Skipf("%s does not support %q", runtime.GOOS, proto)
	}

	ln, err := nettest.NewLocalListener(proto)
	if err != nil {
		t.Fatalf("NewLocalListener error: %v", err)
	}
	t.Cleanup(func() {
		if ln != nil {
			ln.Close()
		}
		if client != nil {
			client.Close()
		}
		if server != nil {
			server.Close()
		}
	})

	// acceptErr carries Accept's result back to this goroutine. Receiving
	// from it also orders the goroutine's write to server before any read of
	// server here or in the cleanup above.
	acceptErr := make(chan error, 1)
	go func() {
		var err error
		server, err = ln.Accept()
		acceptErr <- err
	}()

	client, err = net.Dial(proto, ln.Addr().String())
	if err != nil {
		// Nothing will connect, so Accept would block forever. Closing the
		// listener unblocks it, and the goroutine is then reaped before
		// failing.
		ln.Close()
		<-acceptErr
		t.Fatalf("Dial new connection error: %v", err)
	}
	if err := <-acceptErr; err != nil {
		t.Fatalf("Accept new connection error: %v", err)
	}
	return client, server
}
|