mirror of https://go.googlesource.com/go
542 lines
14 KiB
Go
542 lines
14 KiB
Go
// Copyright 2018 The Go Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
//go:build linux
|
|
|
|
package net
|
|
|
|
import (
|
|
"internal/poll"
|
|
"io"
|
|
"os"
|
|
"strconv"
|
|
"sync"
|
|
"syscall"
|
|
"testing"
|
|
)
|
|
|
|
func TestSplice(t *testing.T) {
|
|
t.Run("tcp-to-tcp", func(t *testing.T) { testSplice(t, "tcp", "tcp") })
|
|
if !testableNetwork("unixgram") {
|
|
t.Skip("skipping unix-to-tcp tests")
|
|
}
|
|
t.Run("unix-to-tcp", func(t *testing.T) { testSplice(t, "unix", "tcp") })
|
|
t.Run("tcp-to-unix", func(t *testing.T) { testSplice(t, "tcp", "unix") })
|
|
t.Run("tcp-to-file", func(t *testing.T) { testSpliceToFile(t, "tcp", "file") })
|
|
t.Run("unix-to-file", func(t *testing.T) { testSpliceToFile(t, "unix", "file") })
|
|
t.Run("no-unixpacket", testSpliceNoUnixpacket)
|
|
t.Run("no-unixgram", testSpliceNoUnixgram)
|
|
}
|
|
|
|
func testSpliceToFile(t *testing.T, upNet, downNet string) {
|
|
t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.testFile)
|
|
t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.testFile)
|
|
t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.testFile)
|
|
t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.testFile)
|
|
t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.testFile)
|
|
t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.testFile)
|
|
}
|
|
|
|
func testSplice(t *testing.T, upNet, downNet string) {
|
|
t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.test)
|
|
t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.test)
|
|
t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.test)
|
|
t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.test)
|
|
t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.test)
|
|
t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.test)
|
|
t.Run("readerAtEOF", func(t *testing.T) { testSpliceReaderAtEOF(t, upNet, downNet) })
|
|
t.Run("issue25985", func(t *testing.T) { testSpliceIssue25985(t, upNet, downNet) })
|
|
}
|
|
|
|
// spliceTestCase describes one splice scenario. Keep the field order stable:
// callers may build it with positional composite literals.
type spliceTestCase struct {
	// upNet and downNet are the networks of the source and destination
	// socket pairs.
	upNet   string
	downNet string

	// chunkSize and totalSize are handed to startTestSocketPeer as the
	// peer's chunk size and transfer size.
	chunkSize int
	totalSize int

	// limitReadSize, when positive, wraps the source in an io.LimitedReader
	// with that N; zero means the source is used directly.
	limitReadSize int
}
|
|
|
|
func (tc spliceTestCase) test(t *testing.T) {
|
|
hook := hookSplice(t)
|
|
|
|
// We need to use the actual size for startTestSocketPeer when testing with LimitedReader,
|
|
// otherwise the child process created in startTestSocketPeer will hang infinitely because of
|
|
// the mismatch of data size to transfer.
|
|
size := tc.totalSize
|
|
if tc.limitReadSize > 0 {
|
|
if tc.limitReadSize < size {
|
|
size = tc.limitReadSize
|
|
}
|
|
}
|
|
|
|
clientUp, serverUp := spawnTestSocketPair(t, tc.upNet)
|
|
defer serverUp.Close()
|
|
cleanup, err := startTestSocketPeer(t, clientUp, "w", tc.chunkSize, size)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer cleanup(t)
|
|
clientDown, serverDown := spawnTestSocketPair(t, tc.downNet)
|
|
defer serverDown.Close()
|
|
cleanup, err = startTestSocketPeer(t, clientDown, "r", tc.chunkSize, size)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer cleanup(t)
|
|
|
|
var r io.Reader = serverUp
|
|
if tc.limitReadSize > 0 {
|
|
r = &io.LimitedReader{
|
|
N: int64(tc.limitReadSize),
|
|
R: serverUp,
|
|
}
|
|
defer serverUp.Close()
|
|
}
|
|
n, err := io.Copy(serverDown, r)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if want := int64(size); want != n {
|
|
t.Errorf("want %d bytes spliced, got %d", want, n)
|
|
}
|
|
|
|
if tc.limitReadSize > 0 {
|
|
wantN := 0
|
|
if tc.limitReadSize > size {
|
|
wantN = tc.limitReadSize - size
|
|
}
|
|
|
|
if n := r.(*io.LimitedReader).N; n != int64(wantN) {
|
|
t.Errorf("r.N = %d, want %d", n, wantN)
|
|
}
|
|
}
|
|
|
|
// poll.Splice is expected to be called when the source is not
|
|
// a wrapper or the destination is TCPConn.
|
|
if tc.limitReadSize == 0 || tc.downNet == "tcp" {
|
|
// We should have called poll.Splice with the right file descriptor arguments.
|
|
if n > 0 && !hook.called {
|
|
t.Fatal("expected poll.Splice to be called")
|
|
}
|
|
|
|
verifySpliceFds(t, serverDown, hook, "dst")
|
|
verifySpliceFds(t, serverUp, hook, "src")
|
|
|
|
// poll.Splice is expected to handle the data transmission successfully.
|
|
if !hook.handled || hook.written != int64(size) || hook.err != nil {
|
|
t.Errorf("expected handled = true, written = %d, err = nil, but got handled = %t, written = %d, err = %v",
|
|
size, hook.handled, hook.written, hook.err)
|
|
}
|
|
} else if hook.called {
|
|
// poll.Splice will certainly not be called when the source
|
|
// is a wrapper and the destination is not TCPConn.
|
|
t.Errorf("expected poll.Splice not be called")
|
|
}
|
|
}
|
|
|
|
func verifySpliceFds(t *testing.T, c Conn, hook *spliceHook, fdType string) {
|
|
t.Helper()
|
|
|
|
sc, ok := c.(syscall.Conn)
|
|
if !ok {
|
|
t.Fatalf("expected syscall.Conn")
|
|
}
|
|
rc, err := sc.SyscallConn()
|
|
if err != nil {
|
|
t.Fatalf("syscall.Conn.SyscallConn error: %v", err)
|
|
}
|
|
var hookFd int
|
|
switch fdType {
|
|
case "src":
|
|
hookFd = hook.srcfd
|
|
case "dst":
|
|
hookFd = hook.dstfd
|
|
default:
|
|
t.Fatalf("unknown fdType %q", fdType)
|
|
}
|
|
if err := rc.Control(func(fd uintptr) {
|
|
if hook.called && hookFd != int(fd) {
|
|
t.Fatalf("wrong %s file descriptor: got %d, want %d", fdType, hook.dstfd, int(fd))
|
|
}
|
|
}); err != nil {
|
|
t.Fatalf("syscall.RawConn.Control error: %v", err)
|
|
}
|
|
}
|
|
|
|
func (tc spliceTestCase) testFile(t *testing.T) {
|
|
hook := hookSplice(t)
|
|
|
|
// We need to use the actual size for startTestSocketPeer when testing with LimitedReader,
|
|
// otherwise the child process created in startTestSocketPeer will hang infinitely because of
|
|
// the mismatch of data size to transfer.
|
|
actualSize := tc.totalSize
|
|
if tc.limitReadSize > 0 {
|
|
if tc.limitReadSize < actualSize {
|
|
actualSize = tc.limitReadSize
|
|
}
|
|
}
|
|
|
|
f, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer f.Close()
|
|
|
|
client, server := spawnTestSocketPair(t, tc.upNet)
|
|
defer server.Close()
|
|
|
|
cleanup, err := startTestSocketPeer(t, client, "w", tc.chunkSize, actualSize)
|
|
if err != nil {
|
|
client.Close()
|
|
t.Fatal("failed to start splice client:", err)
|
|
}
|
|
defer cleanup(t)
|
|
|
|
var r io.Reader = server
|
|
if tc.limitReadSize > 0 {
|
|
r = &io.LimitedReader{
|
|
N: int64(tc.limitReadSize),
|
|
R: r,
|
|
}
|
|
}
|
|
|
|
got, err := io.Copy(f, r)
|
|
if err != nil {
|
|
t.Fatalf("failed to ReadFrom with error: %v", err)
|
|
}
|
|
|
|
// We shouldn't have called poll.Splice in TCPConn.WriteTo,
|
|
// it's supposed to be called from File.ReadFrom.
|
|
if got > 0 && hook.called {
|
|
t.Error("expected not poll.Splice to be called")
|
|
}
|
|
|
|
if want := int64(actualSize); got != want {
|
|
t.Errorf("got %d bytes, want %d", got, want)
|
|
}
|
|
if tc.limitReadSize > 0 {
|
|
wantN := 0
|
|
if tc.limitReadSize > actualSize {
|
|
wantN = tc.limitReadSize - actualSize
|
|
}
|
|
|
|
if gotN := r.(*io.LimitedReader).N; gotN != int64(wantN) {
|
|
t.Errorf("r.N = %d, want %d", gotN, wantN)
|
|
}
|
|
}
|
|
}
|
|
|
|
func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
|
|
// UnixConn doesn't implement io.ReaderFrom, which will fail
|
|
// the following test in asserting a UnixConn to be an io.ReaderFrom,
|
|
// so skip this test.
|
|
if downNet == "unix" {
|
|
t.Skip("skipping test on unix socket")
|
|
}
|
|
|
|
hook := hookSplice(t)
|
|
|
|
clientUp, serverUp := spawnTestSocketPair(t, upNet)
|
|
defer clientUp.Close()
|
|
clientDown, serverDown := spawnTestSocketPair(t, downNet)
|
|
defer clientDown.Close()
|
|
defer serverDown.Close()
|
|
|
|
serverUp.Close()
|
|
|
|
// We'd like to call net.spliceFrom here and check the handled return
|
|
// value, but we disable splice on old Linux kernels.
|
|
//
|
|
// In that case, poll.Splice and net.spliceFrom return a non-nil error
|
|
// and handled == false. We'd ideally like to see handled == true
|
|
// because the source reader is at EOF, but if we're running on an old
|
|
// kernel, and splice is disabled, we won't see EOF from net.spliceFrom,
|
|
// because we won't touch the reader at all.
|
|
//
|
|
// Trying to untangle the errors from net.spliceFrom and match them
|
|
// against the errors created by the poll package would be brittle,
|
|
// so this is a higher level test.
|
|
//
|
|
// The following ReadFrom should return immediately, regardless of
|
|
// whether splice is disabled or not. The other side should then
|
|
// get a goodbye signal. Test for the goodbye signal.
|
|
msg := "bye"
|
|
go func() {
|
|
serverDown.(io.ReaderFrom).ReadFrom(serverUp)
|
|
io.WriteString(serverDown, msg)
|
|
}()
|
|
|
|
buf := make([]byte, 3)
|
|
n, err := io.ReadFull(clientDown, buf)
|
|
if err != nil {
|
|
t.Errorf("clientDown: %v", err)
|
|
}
|
|
if string(buf) != msg {
|
|
t.Errorf("clientDown got %q, want %q", buf, msg)
|
|
}
|
|
|
|
// We should have called poll.Splice with the right file descriptor arguments.
|
|
if n > 0 && !hook.called {
|
|
t.Fatal("expected poll.Splice to be called")
|
|
}
|
|
|
|
verifySpliceFds(t, serverDown, hook, "dst")
|
|
|
|
// poll.Splice is expected to handle the data transmission but fail
|
|
// when working with a closed endpoint, return an error.
|
|
if !hook.handled || hook.written > 0 || hook.err == nil {
|
|
t.Errorf("expected handled = true, written = 0, err != nil, but got handled = %t, written = %d, err = %v",
|
|
hook.handled, hook.written, hook.err)
|
|
}
|
|
}
|
|
|
|
func testSpliceIssue25985(t *testing.T, upNet, downNet string) {
|
|
front := newLocalListener(t, upNet)
|
|
defer front.Close()
|
|
back := newLocalListener(t, downNet)
|
|
defer back.Close()
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
|
|
proxy := func() {
|
|
src, err := front.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
dst, err := Dial(downNet, back.Addr().String())
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer dst.Close()
|
|
defer src.Close()
|
|
go func() {
|
|
io.Copy(src, dst)
|
|
wg.Done()
|
|
}()
|
|
go func() {
|
|
io.Copy(dst, src)
|
|
wg.Done()
|
|
}()
|
|
}
|
|
|
|
go proxy()
|
|
|
|
toFront, err := Dial(upNet, front.Addr().String())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
io.WriteString(toFront, "foo")
|
|
toFront.Close()
|
|
|
|
fromProxy, err := back.Accept()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer fromProxy.Close()
|
|
|
|
_, err = io.ReadAll(fromProxy)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
wg.Wait()
|
|
}
|
|
|
|
func testSpliceNoUnixpacket(t *testing.T) {
|
|
clientUp, serverUp := spawnTestSocketPair(t, "unixpacket")
|
|
defer clientUp.Close()
|
|
defer serverUp.Close()
|
|
clientDown, serverDown := spawnTestSocketPair(t, "tcp")
|
|
defer clientDown.Close()
|
|
defer serverDown.Close()
|
|
// If splice called poll.Splice here, we'd get err == syscall.EINVAL
|
|
// and handled == false. If poll.Splice gets an EINVAL on the first
|
|
// try, it assumes the kernel it's running on doesn't support splice
|
|
// for unix sockets and returns handled == false. This works for our
|
|
// purposes by somewhat of an accident, but is not entirely correct.
|
|
//
|
|
// What we want is err == nil and handled == false, i.e. we never
|
|
// called poll.Splice, because we know the unix socket's network.
|
|
_, err, handled := spliceFrom(serverDown.(*TCPConn).fd, serverUp)
|
|
if err != nil || handled != false {
|
|
t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
|
|
}
|
|
}
|
|
|
|
func testSpliceNoUnixgram(t *testing.T) {
|
|
addr, err := ResolveUnixAddr("unixgram", testUnixAddr(t))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer os.Remove(addr.Name)
|
|
up, err := ListenUnixgram("unixgram", addr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer up.Close()
|
|
clientDown, serverDown := spawnTestSocketPair(t, "tcp")
|
|
defer clientDown.Close()
|
|
defer serverDown.Close()
|
|
// Analogous to testSpliceNoUnixpacket.
|
|
_, err, handled := spliceFrom(serverDown.(*TCPConn).fd, up)
|
|
if err != nil || handled != false {
|
|
t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
|
|
}
|
|
}
|
|
|
|
func BenchmarkSplice(b *testing.B) {
|
|
testHookUninstaller.Do(uninstallTestHooks)
|
|
|
|
b.Run("tcp-to-tcp", func(b *testing.B) { benchSplice(b, "tcp", "tcp") })
|
|
b.Run("unix-to-tcp", func(b *testing.B) { benchSplice(b, "unix", "tcp") })
|
|
b.Run("tcp-to-unix", func(b *testing.B) { benchSplice(b, "tcp", "unix") })
|
|
}
|
|
|
|
func benchSplice(b *testing.B, upNet, downNet string) {
|
|
for i := 0; i <= 10; i++ {
|
|
chunkSize := 1 << uint(i+10)
|
|
tc := spliceTestCase{
|
|
upNet: upNet,
|
|
downNet: downNet,
|
|
chunkSize: chunkSize,
|
|
}
|
|
|
|
b.Run(strconv.Itoa(chunkSize), tc.bench)
|
|
}
|
|
}
|
|
|
|
func (tc spliceTestCase) bench(b *testing.B) {
|
|
// To benchmark the genericReadFrom code path, set this to false.
|
|
useSplice := true
|
|
|
|
clientUp, serverUp := spawnTestSocketPair(b, tc.upNet)
|
|
defer serverUp.Close()
|
|
|
|
cleanup, err := startTestSocketPeer(b, clientUp, "w", tc.chunkSize, tc.chunkSize*b.N)
|
|
if err != nil {
|
|
b.Fatal(err)
|
|
}
|
|
defer cleanup(b)
|
|
|
|
clientDown, serverDown := spawnTestSocketPair(b, tc.downNet)
|
|
defer serverDown.Close()
|
|
|
|
cleanup, err = startTestSocketPeer(b, clientDown, "r", tc.chunkSize, tc.chunkSize*b.N)
|
|
if err != nil {
|
|
b.Fatal(err)
|
|
}
|
|
defer cleanup(b)
|
|
|
|
b.SetBytes(int64(tc.chunkSize))
|
|
b.ResetTimer()
|
|
|
|
if useSplice {
|
|
_, err := io.Copy(serverDown, serverUp)
|
|
if err != nil {
|
|
b.Fatal(err)
|
|
}
|
|
} else {
|
|
type onlyReader struct {
|
|
io.Reader
|
|
}
|
|
_, err := io.Copy(serverDown, onlyReader{serverUp})
|
|
if err != nil {
|
|
b.Fatal(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func BenchmarkSpliceFile(b *testing.B) {
|
|
b.Run("tcp-to-file", func(b *testing.B) { benchmarkSpliceFile(b, "tcp") })
|
|
b.Run("unix-to-file", func(b *testing.B) { benchmarkSpliceFile(b, "unix") })
|
|
}
|
|
|
|
func benchmarkSpliceFile(b *testing.B, proto string) {
|
|
for i := 0; i <= 10; i++ {
|
|
size := 1 << (i + 10)
|
|
bench := spliceFileBench{
|
|
proto: proto,
|
|
chunkSize: size,
|
|
}
|
|
b.Run(strconv.Itoa(size), bench.benchSpliceFile)
|
|
}
|
|
}
|
|
|
|
// spliceFileBench parameterizes one socket-to-file benchmark run.
type spliceFileBench struct {
	// chunkSize is the per-write size handed to the writer peer and the
	// byte count reported per benchmark iteration.
	chunkSize int

	// proto is the network used for the source socket pair.
	proto string
}
|
|
|
|
func (bench spliceFileBench) benchSpliceFile(b *testing.B) {
|
|
f, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
|
|
if err != nil {
|
|
b.Fatal(err)
|
|
}
|
|
defer f.Close()
|
|
|
|
totalSize := b.N * bench.chunkSize
|
|
|
|
client, server := spawnTestSocketPair(b, bench.proto)
|
|
defer server.Close()
|
|
|
|
cleanup, err := startTestSocketPeer(b, client, "w", bench.chunkSize, totalSize)
|
|
if err != nil {
|
|
client.Close()
|
|
b.Fatalf("failed to start splice client: %v", err)
|
|
}
|
|
defer cleanup(b)
|
|
|
|
b.ReportAllocs()
|
|
b.SetBytes(int64(bench.chunkSize))
|
|
b.ResetTimer()
|
|
|
|
got, err := io.Copy(f, server)
|
|
if err != nil {
|
|
b.Fatalf("failed to ReadFrom with error: %v", err)
|
|
}
|
|
if want := int64(totalSize); got != want {
|
|
b.Errorf("bytes sent mismatch, got: %d, want: %d", got, want)
|
|
}
|
|
}
|
|
|
|
func hookSplice(t *testing.T) *spliceHook {
|
|
t.Helper()
|
|
|
|
h := new(spliceHook)
|
|
h.install()
|
|
t.Cleanup(h.uninstall)
|
|
return h
|
|
}
|
|
|
|
type spliceHook struct {
|
|
called bool
|
|
dstfd int
|
|
srcfd int
|
|
remain int64
|
|
|
|
written int64
|
|
handled bool
|
|
err error
|
|
|
|
original func(dst, src *poll.FD, remain int64) (int64, bool, error)
|
|
}
|
|
|
|
func (h *spliceHook) install() {
|
|
h.original = pollSplice
|
|
pollSplice = func(dst, src *poll.FD, remain int64) (int64, bool, error) {
|
|
h.called = true
|
|
h.dstfd = dst.Sysfd
|
|
h.srcfd = src.Sysfd
|
|
h.remain = remain
|
|
h.written, h.handled, h.err = h.original(dst, src, remain)
|
|
return h.written, h.handled, h.err
|
|
}
|
|
}
|
|
|
|
func (h *spliceHook) uninstall() {
|
|
pollSplice = h.original
|
|
}
|