gotosocial/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/handler.go

248 lines
7.3 KiB
Go

package runtime
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"net/textproto"
"strconv"
"strings"
"google.golang.org/genproto/googleapis/api/httpbody"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)
// ForwardResponseStream forwards the stream from gRPC server to REST client.
func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, recv func() (proto.Message, error), opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
rc := http.NewResponseController(w)
md, ok := ServerMetadataFromContext(ctx)
if !ok {
grpclog.Error("Failed to extract ServerMetadata from context")
http.Error(w, "unexpected error", http.StatusInternalServerError)
return
}
handleForwardResponseServerMetadata(w, mux, md)
w.Header().Set("Transfer-Encoding", "chunked")
if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil {
HTTPError(ctx, mux, marshaler, w, req, err)
return
}
var delimiter []byte
if d, ok := marshaler.(Delimited); ok {
delimiter = d.Delimiter()
} else {
delimiter = []byte("\n")
}
var wroteHeader bool
for {
resp, err := recv()
if errors.Is(err, io.EOF) {
return
}
if err != nil {
handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err, delimiter)
return
}
if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err, delimiter)
return
}
respRw, err := mux.forwardResponseRewriter(ctx, resp)
if err != nil {
grpclog.Errorf("Rewrite error: %v", err)
handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err, delimiter)
return
}
if !wroteHeader {
w.Header().Set("Content-Type", marshaler.ContentType(respRw))
}
var buf []byte
httpBody, isHTTPBody := respRw.(*httpbody.HttpBody)
switch {
case respRw == nil:
buf, err = marshaler.Marshal(errorChunk(status.New(codes.Internal, "empty response")))
case isHTTPBody:
buf = httpBody.GetData()
default:
result := map[string]interface{}{"result": respRw}
if rb, ok := respRw.(responseBody); ok {
result["result"] = rb.XXX_ResponseBody()
}
buf, err = marshaler.Marshal(result)
}
if err != nil {
grpclog.Errorf("Failed to marshal response chunk: %v", err)
handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err, delimiter)
return
}
if _, err := w.Write(buf); err != nil {
grpclog.Errorf("Failed to send response chunk: %v", err)
return
}
wroteHeader = true
if _, err := w.Write(delimiter); err != nil {
grpclog.Errorf("Failed to send delimiter chunk: %v", err)
return
}
err = rc.Flush()
if err != nil {
if errors.Is(err, http.ErrNotSupported) {
grpclog.Errorf("Flush not supported in %T", w)
http.Error(w, "unexpected type of web server", http.StatusInternalServerError)
return
}
grpclog.Errorf("Failed to flush response to client: %v", err)
return
}
}
}
func handleForwardResponseServerMetadata(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
for k, vs := range md.HeaderMD {
if h, ok := mux.outgoingHeaderMatcher(k); ok {
for _, v := range vs {
w.Header().Add(h, v)
}
}
}
}
func handleForwardResponseTrailerHeader(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
for k := range md.TrailerMD {
if h, ok := mux.outgoingTrailerMatcher(k); ok {
w.Header().Add("Trailer", textproto.CanonicalMIMEHeaderKey(h))
}
}
}
func handleForwardResponseTrailer(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
for k, vs := range md.TrailerMD {
if h, ok := mux.outgoingTrailerMatcher(k); ok {
for _, v := range vs {
w.Header().Add(h, v)
}
}
}
}
// responseBody interface contains method for getting field for marshaling to the response body
// this method is generated for response struct from the value of `response_body` in the `google.api.HttpRule`
type responseBody interface {
XXX_ResponseBody() interface{}
}
// ForwardResponseMessage forwards the message "resp" from gRPC server to REST client.
func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
md, ok := ServerMetadataFromContext(ctx)
if !ok {
grpclog.Error("Failed to extract ServerMetadata from context")
}
handleForwardResponseServerMetadata(w, mux, md)
// RFC 7230 https://tools.ietf.org/html/rfc7230#section-4.1.2
// Unless the request includes a TE header field indicating "trailers"
// is acceptable, as described in Section 4.3, a server SHOULD NOT
// generate trailer fields that it believes are necessary for the user
// agent to receive.
doForwardTrailers := requestAcceptsTrailers(req)
if doForwardTrailers {
handleForwardResponseTrailerHeader(w, mux, md)
w.Header().Set("Transfer-Encoding", "chunked")
}
contentType := marshaler.ContentType(resp)
w.Header().Set("Content-Type", contentType)
if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
HTTPError(ctx, mux, marshaler, w, req, err)
return
}
respRw, err := mux.forwardResponseRewriter(ctx, resp)
if err != nil {
grpclog.Errorf("Rewrite error: %v", err)
HTTPError(ctx, mux, marshaler, w, req, err)
return
}
var buf []byte
if rb, ok := respRw.(responseBody); ok {
buf, err = marshaler.Marshal(rb.XXX_ResponseBody())
} else {
buf, err = marshaler.Marshal(respRw)
}
if err != nil {
grpclog.Errorf("Marshal error: %v", err)
HTTPError(ctx, mux, marshaler, w, req, err)
return
}
if !doForwardTrailers {
w.Header().Set("Content-Length", strconv.Itoa(len(buf)))
}
if _, err = w.Write(buf); err != nil {
grpclog.Errorf("Failed to write response: %v", err)
}
if doForwardTrailers {
handleForwardResponseTrailer(w, mux, md)
}
}
func requestAcceptsTrailers(req *http.Request) bool {
te := req.Header.Get("TE")
return strings.Contains(strings.ToLower(te), "trailers")
}
func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, resp proto.Message, opts []func(context.Context, http.ResponseWriter, proto.Message) error) error {
if len(opts) == 0 {
return nil
}
for _, opt := range opts {
if err := opt(ctx, w, resp); err != nil {
return fmt.Errorf("error handling ForwardResponseOptions: %w", err)
}
}
return nil
}
func handleForwardResponseStreamError(ctx context.Context, wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, req *http.Request, mux *ServeMux, err error, delimiter []byte) {
st := mux.streamErrorHandler(ctx, err)
msg := errorChunk(st)
if !wroteHeader {
w.Header().Set("Content-Type", marshaler.ContentType(msg))
w.WriteHeader(HTTPStatusFromCode(st.Code()))
}
buf, err := marshaler.Marshal(msg)
if err != nil {
grpclog.Errorf("Failed to marshal an error: %v", err)
return
}
if _, err := w.Write(buf); err != nil {
grpclog.Errorf("Failed to notify error to client: %v", err)
return
}
if _, err := w.Write(delimiter); err != nil {
grpclog.Errorf("Failed to send delimiter chunk: %v", err)
return
}
}
func errorChunk(st *status.Status) map[string]proto.Message {
return map[string]proto.Message{"error": st.Proto()}
}