// Mirror of https://go.googlesource.com/go (163 lines, 4.0 KiB, Go).
// Copyright 2023 The Go Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package os_test
|
|
|
|
import (
|
|
"bytes"
|
|
"internal/poll"
|
|
"io"
|
|
"math/rand"
|
|
"net"
|
|
. "os"
|
|
"strconv"
|
|
"syscall"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestSendFile(t *testing.T) {
|
|
sizes := []int{
|
|
1,
|
|
42,
|
|
1025,
|
|
syscall.Getpagesize() + 1,
|
|
32769,
|
|
}
|
|
t.Run("sendfile-to-unix", func(t *testing.T) {
|
|
for _, size := range sizes {
|
|
t.Run(strconv.Itoa(size), func(t *testing.T) {
|
|
testSendFile(t, "unix", int64(size))
|
|
})
|
|
}
|
|
})
|
|
t.Run("sendfile-to-tcp", func(t *testing.T) {
|
|
for _, size := range sizes {
|
|
t.Run(strconv.Itoa(size), func(t *testing.T) {
|
|
testSendFile(t, "tcp", int64(size))
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
func testSendFile(t *testing.T, proto string, size int64) {
|
|
dst, src, recv, data, hook := newSendFileTest(t, proto, size)
|
|
|
|
// Now call WriteTo (through io.Copy), which will hopefully call poll.SendFile
|
|
n, err := io.Copy(dst, src)
|
|
if err != nil {
|
|
t.Fatalf("io.Copy error: %v", err)
|
|
}
|
|
|
|
// We should have called poll.Splice with the right file descriptor arguments.
|
|
if n > 0 && !hook.called {
|
|
t.Fatal("expected to called poll.SendFile")
|
|
}
|
|
if hook.called && hook.srcfd != int(src.Fd()) {
|
|
t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, src.Fd())
|
|
}
|
|
sc, ok := dst.(syscall.Conn)
|
|
if !ok {
|
|
t.Fatalf("destination is not a syscall.Conn")
|
|
}
|
|
rc, err := sc.SyscallConn()
|
|
if err != nil {
|
|
t.Fatalf("destination SyscallConn error: %v", err)
|
|
}
|
|
if err = rc.Control(func(fd uintptr) {
|
|
if hook.called && hook.dstfd != int(fd) {
|
|
t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, int(fd))
|
|
}
|
|
}); err != nil {
|
|
t.Fatalf("destination Conn Control error: %v", err)
|
|
}
|
|
|
|
// Verify the data size and content.
|
|
dataSize := len(data)
|
|
dstData := make([]byte, dataSize)
|
|
m, err := io.ReadFull(recv, dstData)
|
|
if err != nil {
|
|
t.Fatalf("server Conn Read error: %v", err)
|
|
}
|
|
if n != int64(dataSize) {
|
|
t.Fatalf("data length mismatch for io.Copy, got %d, want %d", n, dataSize)
|
|
}
|
|
if m != dataSize {
|
|
t.Fatalf("data length mismatch for net.Conn.Read, got %d, want %d", m, dataSize)
|
|
}
|
|
if !bytes.Equal(dstData, data) {
|
|
t.Errorf("data mismatch, got %s, want %s", dstData, data)
|
|
}
|
|
}
|
|
|
|
// newSendFileTest initializes a new test for sendfile.
|
|
//
|
|
// It creates source file and destination sockets, and populates the source file
|
|
// with random data of the specified size. It also hooks package os' call
|
|
// to poll.Sendfile and returns the hook so it can be inspected.
|
|
func newSendFileTest(t *testing.T, proto string, size int64) (net.Conn, *File, net.Conn, []byte, *sendFileHook) {
|
|
t.Helper()
|
|
|
|
hook := hookSendFile(t)
|
|
|
|
client, server := createSocketPair(t, proto)
|
|
tempFile, data := createTempFile(t, size)
|
|
|
|
return client, tempFile, server, data, hook
|
|
}
|
|
|
|
func hookSendFile(t *testing.T) *sendFileHook {
|
|
h := new(sendFileHook)
|
|
orig := poll.TestHookDidSendFile
|
|
t.Cleanup(func() {
|
|
poll.TestHookDidSendFile = orig
|
|
})
|
|
poll.TestHookDidSendFile = func(dstFD *poll.FD, src int, written int64, err error, handled bool) {
|
|
h.called = true
|
|
h.dstfd = dstFD.Sysfd
|
|
h.srcfd = src
|
|
h.written = written
|
|
h.err = err
|
|
h.handled = handled
|
|
}
|
|
return h
|
|
}
|
|
|
|
// sendFileHook holds the arguments observed by the poll.TestHookDidSendFile
// replacement that hookSendFile installs.
type sendFileHook struct {
	called bool // true once the hook has been invoked

	dstfd, srcfd int // destination FD's Sysfd and the source descriptor

	written int64 // the written argument passed to the hook
	handled bool  // the handled argument passed to the hook
	err     error // the err argument passed to the hook
}
|
|
|
|
// createTempFile creates a file in t.TempDir, writes size pseudo-random bytes
// to it, syncs it, and rewinds it to offset 0. It returns the open file
// together with the bytes that were written. The file is closed during test
// cleanup. The PRNG seed is logged so a failing run's data can be regenerated.
func createTempFile(t *testing.T, size int64) (*File, []byte) {
	file, err := CreateTemp(t.TempDir(), "writeto-sendfile-to-socket")
	if err != nil {
		t.Fatalf("failed to create temporary file: %v", err)
	}
	t.Cleanup(func() { file.Close() })

	seed := time.Now().Unix()
	t.Logf("random data seed: %d\n", seed)
	payload := make([]byte, size)
	rand.New(rand.NewSource(seed)).Read(payload)

	if _, err = file.Write(payload); err != nil {
		t.Fatalf("failed to create and feed the file: %v", err)
	}
	if err = file.Sync(); err != nil {
		t.Fatalf("failed to save the file: %v", err)
	}
	if _, err = file.Seek(0, io.SeekStart); err != nil {
		t.Fatalf("failed to rewind the file: %v", err)
	}
	return file, payload
}
|