caddy/modules/caddyhttp/replacer.go

544 lines
16 KiB
Go

// Copyright 2015 Matthew Holt and The Caddy Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package caddyhttp
import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/asn1"
"encoding/base64"
"encoding/pem"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"net/textproto"
"net/url"
"path"
"strconv"
"strings"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/modules/caddytls"
)
// NewTestReplacer builds a fresh replacer wired up with the HTTP
// placeholders for req. It is meant for tests living outside this
// package. The replacer is stored in the request's context, and req
// is updated in place so the caller's pointer observes that context.
func NewTestReplacer(req *http.Request) *caddy.Replacer {
	r := caddy.NewReplacer()

	// WithContext returns a shallow copy; write it back over the
	// caller's request so the same pointer carries the replacer.
	*req = *req.WithContext(context.WithValue(req.Context(), caddy.ReplacerCtxKey, r))

	// no response writer is available in this setting
	addHTTPVarsToReplacer(r, req, nil)
	return r
}
// addHTTPVarsToReplacer registers a lookup function on repl that resolves
// the HTTP placeholders (http.request.*, http.response.header.*,
// http.vars.*, http.shutting_down and http.time_until_shutdown) from req
// and, when non-nil, from w. Before registering, it stores the current time
// under the "start_time" var and a lazily-generated request ID under the
// "uuid" var on the request's context.
//
// The lookup function returns (value, true) when the key is recognized and
// (nil or "", false) when it is not.
func addHTTPVarsToReplacer(repl *caddy.Replacer, req *http.Request, w http.ResponseWriter) {
	SetVar(req.Context(), "start_time", time.Now())
	SetVar(req.Context(), "uuid", new(requestID))

	httpVars := func(key string) (any, bool) {
		if req != nil {
			// query string parameters
			if strings.HasPrefix(key, reqURIQueryReplPrefix) {
				vals := req.URL.Query()[key[len(reqURIQueryReplPrefix):]]
				// always return true, since the query param might
				// be present only in some requests
				return strings.Join(vals, ","), true
			}

			// request header fields
			if strings.HasPrefix(key, reqHeaderReplPrefix) {
				field := key[len(reqHeaderReplPrefix):]
				vals := req.Header[textproto.CanonicalMIMEHeaderKey(field)]
				// always return true, since the header field might
				// be present only in some requests
				return strings.Join(vals, ","), true
			}

			// cookies
			// NOTE(review): when no cookie matches, control falls through
			// and the key ends up unrecognized (nil, false), unlike query
			// params and header fields above — confirm this is intended.
			if strings.HasPrefix(key, reqCookieReplPrefix) {
				name := key[len(reqCookieReplPrefix):]
				for _, cookie := range req.Cookies() {
					if strings.EqualFold(name, cookie.Name) {
						// always return true, since the cookie might
						// be present only in some requests
						return cookie.Value, true
					}
				}
			}

			// http.request.tls.*
			if strings.HasPrefix(key, reqTLSReplPrefix) {
				return getReqTLSReplacement(req, key)
			}

			switch key {
			case "http.request.method":
				return req.Method, true
			case "http.request.scheme":
				if req.TLS != nil {
					return "https", true
				}
				return "http", true
			case "http.request.proto":
				return req.Proto, true
			case "http.request.host":
				host, _, err := net.SplitHostPort(req.Host)
				if err != nil {
					return req.Host, true // OK; there probably was no port
				}
				return host, true
			case "http.request.port":
				_, port, _ := net.SplitHostPort(req.Host)
				if portNum, err := strconv.Atoi(port); err == nil {
					return portNum, true
				}
				return port, true
			case "http.request.hostport":
				return req.Host, true
			case "http.request.local":
				localAddr, _ := req.Context().Value(http.LocalAddrContextKey).(net.Addr)
				if localAddr == nil {
					// no local address in the context; calling String()
					// on a nil interface would panic
					return "", true
				}
				return localAddr.String(), true
			case "http.request.local.host":
				localAddr, _ := req.Context().Value(http.LocalAddrContextKey).(net.Addr)
				if localAddr == nil {
					// no local address in the context; avoid a nil dereference
					return "", true
				}
				host, _, err := net.SplitHostPort(localAddr.String())
				if err != nil {
					// localAddr is host:port for tcp and udp sockets and /unix/socket.path
					// for unix sockets. net.SplitHostPort only operates on tcp and udp sockets,
					// not unix sockets and will fail with the latter.
					// We assume when net.SplitHostPort fails, localAddr is a unix socket and thus
					// already "split" and safe to return.
					return localAddr, true
				}
				return host, true
			case "http.request.local.port":
				localAddr, _ := req.Context().Value(http.LocalAddrContextKey).(net.Addr)
				if localAddr == nil {
					// no local address in the context; avoid a nil dereference
					return "", true
				}
				_, port, _ := net.SplitHostPort(localAddr.String())
				if portNum, err := strconv.Atoi(port); err == nil {
					return portNum, true
				}
				return port, true
			case "http.request.remote":
				return req.RemoteAddr, true
			case "http.request.remote.host":
				host, _, err := net.SplitHostPort(req.RemoteAddr)
				if err != nil {
					// req.RemoteAddr is host:port for tcp and udp sockets and /unix/socket.path
					// for unix sockets. net.SplitHostPort only operates on tcp and udp sockets,
					// not unix sockets and will fail with the latter.
					// We assume when net.SplitHostPort fails, req.RemoteAddr is a unix socket
					// and thus already "split" and safe to return.
					return req.RemoteAddr, true
				}
				return host, true
			case "http.request.remote.port":
				_, port, _ := net.SplitHostPort(req.RemoteAddr)
				if portNum, err := strconv.Atoi(port); err == nil {
					return portNum, true
				}
				return port, true

			// current URI, including any internal rewrites
			case "http.request.uri":
				return req.URL.RequestURI(), true
			case "http.request.uri.path":
				return req.URL.Path, true
			case "http.request.uri.path.file":
				_, file := path.Split(req.URL.Path)
				return file, true
			case "http.request.uri.path.dir":
				dir, _ := path.Split(req.URL.Path)
				return dir, true
			case "http.request.uri.path.file.base":
				return strings.TrimSuffix(path.Base(req.URL.Path), path.Ext(req.URL.Path)), true
			case "http.request.uri.path.file.ext":
				return path.Ext(req.URL.Path), true
			case "http.request.uri.query":
				return req.URL.RawQuery, true
			case "http.request.duration":
				start := GetVar(req.Context(), "start_time").(time.Time)
				return time.Since(start), true
			case "http.request.duration_ms":
				start := GetVar(req.Context(), "start_time").(time.Time)
				return time.Since(start).Seconds() * 1e3, true // multiply seconds to preserve decimal (see #4666)
			case "http.request.uuid":
				// fetch the UUID for this request
				id := GetVar(req.Context(), "uuid").(*requestID)

				// set it to this request's access log
				extra := req.Context().Value(ExtraLogFieldsCtxKey).(*ExtraLogFields)
				extra.Set(zap.String("uuid", id.String()))

				return id.String(), true
			case "http.request.body":
				if req.Body == nil {
					return "", true
				}
				// normally net/http will close the body for us, but since we
				// are replacing it with a fake one, we have to ensure we close
				// the real body ourselves when we're done
				defer req.Body.Close()
				// read the request body into a buffer (can't pool because we
				// don't know its lifetime and would have to make a copy anyway)
				buf := new(bytes.Buffer)
				_, _ = io.Copy(buf, req.Body) // can't handle error, so just ignore it
				req.Body = io.NopCloser(buf)  // replace real body with buffered data
				return buf.String(), true

			// original request, before any internal changes; when no original
			// request is stored in the context, the zero-value request is used
			// and these placeholders resolve to the empty string
			case "http.request.orig_method":
				or, _ := req.Context().Value(OriginalRequestCtxKey).(http.Request)
				return or.Method, true
			case "http.request.orig_uri":
				or, _ := req.Context().Value(OriginalRequestCtxKey).(http.Request)
				return or.RequestURI, true
			case "http.request.orig_uri.path":
				or, _ := req.Context().Value(OriginalRequestCtxKey).(http.Request)
				if or.URL == nil {
					return "", true
				}
				return or.URL.Path, true
			case "http.request.orig_uri.path.file":
				or, _ := req.Context().Value(OriginalRequestCtxKey).(http.Request)
				if or.URL == nil {
					return "", true
				}
				_, file := path.Split(or.URL.Path)
				return file, true
			case "http.request.orig_uri.path.dir":
				or, _ := req.Context().Value(OriginalRequestCtxKey).(http.Request)
				if or.URL == nil {
					return "", true
				}
				dir, _ := path.Split(or.URL.Path)
				return dir, true
			case "http.request.orig_uri.query":
				or, _ := req.Context().Value(OriginalRequestCtxKey).(http.Request)
				if or.URL == nil {
					return "", true
				}
				return or.URL.RawQuery, true
			}

			// remote IP range/prefix (e.g. keep top 24 bits of 1.2.3.4 => "1.2.3.0/24")
			// syntax: "/V4,V6" where V4 = IPv4 bits, and V6 = IPv6 bits; if no comma, then same bit length used for both
			// (EXPERIMENTAL)
			if strings.HasPrefix(key, "http.request.remote.host/") {
				host, _, err := net.SplitHostPort(req.RemoteAddr)
				if err != nil {
					host = req.RemoteAddr // assume no port, I guess?
				}
				addr, err := netip.ParseAddr(host)
				if err != nil {
					return host, true // not an IP address
				}
				// extract the bits from the end of the placeholder (start after "/") then split on ","
				bitsBoth := key[strings.Index(key, "/")+1:]
				ipv4BitsStr, ipv6BitsStr, cutOK := strings.Cut(bitsBoth, ",")
				bitsStr := ipv4BitsStr
				if addr.Is6() && cutOK {
					bitsStr = ipv6BitsStr
				}
				// convert to integer then compute prefix
				bits, err := strconv.Atoi(bitsStr)
				if err != nil {
					return "", true
				}
				prefix, err := addr.Prefix(bits)
				if err != nil {
					return "", true
				}
				return prefix.String(), true
			}

			// hostname labels; index 0 is the right-most label
			if strings.HasPrefix(key, reqHostLabelsReplPrefix) {
				idxStr := key[len(reqHostLabelsReplPrefix):]
				idx, err := strconv.Atoi(idxStr)
				if err != nil || idx < 0 {
					return "", false
				}
				reqHost, _, err := net.SplitHostPort(req.Host)
				if err != nil {
					reqHost = req.Host // OK; assume there was no port
				}
				hostLabels := strings.Split(reqHost, ".")
				if idx >= len(hostLabels) {
					return "", true
				}
				return hostLabels[len(hostLabels)-idx-1], true
			}

			// path parts
			if strings.HasPrefix(key, reqURIPathReplPrefix) {
				idxStr := key[len(reqURIPathReplPrefix):]
				idx, err := strconv.Atoi(idxStr)
				if err != nil {
					return "", false
				}
				pathParts := strings.Split(req.URL.Path, "/")
				// drop the empty element produced by a leading slash
				if len(pathParts) > 0 && pathParts[0] == "" {
					pathParts = pathParts[1:]
				}
				if idx < 0 {
					return "", false
				}
				if idx >= len(pathParts) {
					return "", true
				}
				return pathParts[idx], true
			}

			// orig uri path parts
			if strings.HasPrefix(key, reqOrigURIPathReplPrefix) {
				idxStr := key[len(reqOrigURIPathReplPrefix):]
				idx, err := strconv.Atoi(idxStr)
				if err != nil {
					return "", false
				}
				or, _ := req.Context().Value(OriginalRequestCtxKey).(http.Request)
				// a missing original request has a nil URL; treat its
				// path as empty rather than dereferencing nil
				var origPath string
				if or.URL != nil {
					origPath = or.URL.Path
				}
				pathParts := strings.Split(origPath, "/")
				// drop the empty element produced by a leading slash
				if len(pathParts) > 0 && pathParts[0] == "" {
					pathParts = pathParts[1:]
				}
				if idx < 0 {
					return "", false
				}
				if idx >= len(pathParts) {
					return "", true
				}
				return pathParts[idx], true
			}

			// middleware variables
			if strings.HasPrefix(key, varsReplPrefix) {
				varName := key[len(varsReplPrefix):]
				raw := GetVar(req.Context(), varName)
				// variables can be dynamic, so always return true
				// even when it may not be set; treat as empty then
				return raw, true
			}
		}

		if w != nil {
			// response header fields
			if strings.HasPrefix(key, respHeaderReplPrefix) {
				field := key[len(respHeaderReplPrefix):]
				vals := w.Header()[textproto.CanonicalMIMEHeaderKey(field)]
				// always return true, since the header field might
				// be present only in some responses
				return strings.Join(vals, ","), true
			}
		}

		// server shutdown state, read under the server's read lock
		switch {
		case key == "http.shutting_down":
			server := req.Context().Value(ServerCtxKey).(*Server)
			server.shutdownAtMu.RLock()
			defer server.shutdownAtMu.RUnlock()
			return !server.shutdownAt.IsZero(), true
		case key == "http.time_until_shutdown":
			server := req.Context().Value(ServerCtxKey).(*Server)
			server.shutdownAtMu.RLock()
			defer server.shutdownAtMu.RUnlock()
			if server.shutdownAt.IsZero() {
				return nil, true
			}
			return time.Until(server.shutdownAt), true
		}

		return nil, false
	}

	repl.Map(httpVars)
}
// getReqTLSReplacement resolves a placeholder under the request TLS prefix
// (reqTLSReplPrefix) against the TLS state of req. The part of key after
// the prefix is matched case-insensitively. It reports false when the
// request has no TLS state, when a "client." placeholder is requested but
// no peer certificate is present, or when the placeholder is not known.
func getReqTLSReplacement(req *http.Request, key string) (any, bool) {
	if req == nil || req.TLS == nil {
		return nil, false
	}
	if len(key) < len(reqTLSReplPrefix) {
		return nil, false
	}

	name := strings.ToLower(key[len(reqTLSReplPrefix):])

	// connection-level properties; these need no client certificate
	if !strings.HasPrefix(name, "client.") {
		switch name {
		case "version":
			return caddytls.ProtocolName(req.TLS.Version), true
		case "cipher_suite":
			return tls.CipherSuiteName(req.TLS.CipherSuite), true
		case "resumed":
			return req.TLS.DidResume, true
		case "proto":
			return req.TLS.NegotiatedProtocol, true
		case "proto_mutual":
			// req.TLS.NegotiatedProtocolIsMutual is deprecated - it's always true.
			return true, true
		case "server_name":
			return req.TLS.ServerName, true
		}
		return nil, false
	}

	// everything below describes the client (peer) certificate
	cert := getTLSPeerCert(req.TLS)
	if cert == nil {
		return nil, false
	}

	// subject alternate names (SANs), optionally followed by ".<index>"
	if san, isSAN := strings.CutPrefix(name, "client.san."); isSAN {
		var (
			list any    // the full SAN list of the selected kind
			rest string // whatever follows the kind name
		)
		switch {
		case strings.HasPrefix(san, "dns_names"):
			list, rest = cert.DNSNames, san[len("dns_names"):]
		case strings.HasPrefix(san, "emails"):
			list, rest = cert.EmailAddresses, san[len("emails"):]
		case strings.HasPrefix(san, "ips"):
			list, rest = cert.IPAddresses, san[len("ips"):]
		case strings.HasPrefix(san, "uris"):
			list, rest = cert.URIs, san[len("uris"):]
		default:
			return nil, false
		}

		// without an index, the whole list is the value
		if rest == "" {
			return list, true
		}

		// otherwise a '.' and at least one more character must follow
		if len(rest) < 2 || rest[0] != '.' {
			return nil, false
		}
		idx, err := strconv.Atoi(rest[1:])
		if err != nil || idx < 0 {
			return nil, false
		}

		// an out-of-range index is a known placeholder with no value
		switch items := list.(type) {
		case []string:
			if idx < len(items) {
				return items[idx], true
			}
			return nil, true
		case []net.IP:
			if idx < len(items) {
				return items[idx], true
			}
			return nil, true
		case []*url.URL:
			if idx < len(items) {
				return items[idx], true
			}
			return nil, true
		}
		return nil, false
	}

	switch name {
	case "client.fingerprint":
		return fmt.Sprintf("%x", sha256.Sum256(cert.Raw)), true
	case "client.public_key", "client.public_key_sha256":
		if cert.PublicKey == nil {
			return nil, true
		}
		encoded, err := marshalPublicKey(cert.PublicKey)
		if err != nil {
			return nil, true
		}
		if name == "client.public_key_sha256" {
			return fmt.Sprintf("%x", sha256.Sum256(encoded)), true
		}
		return fmt.Sprintf("%x", encoded), true
	case "client.issuer":
		return cert.Issuer, true
	case "client.serial":
		return cert.SerialNumber, true
	case "client.subject":
		return cert.Subject, true
	case "client.certificate_pem":
		return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}), true
	case "client.certificate_der_base64":
		return base64.StdEncoding.EncodeToString(cert.Raw), true
	default:
		return nil, false
	}
}
// marshalPublicKey returns the byte encoding of pubKey.
//
// Supported key types and their encodings:
//   - *rsa.PublicKey: ASN.1 DER encoding of the key struct (modulus, exponent)
//   - *ecdsa.PublicKey: the bytes of the equivalent ECDH public key
//   - ed25519.PublicKey: the raw key bytes
//
// An error is returned for a nil RSA/ECDSA key, for any other key type,
// or when the underlying encoding fails.
func marshalPublicKey(pubKey any) ([]byte, error) {
	switch key := pubKey.(type) {
	case *rsa.PublicKey:
		if key == nil {
			return nil, fmt.Errorf("nil public key of type %T", pubKey)
		}
		// encoding/asn1 marshals struct values, not pointers to
		// structs, so the key has to be dereferenced first
		return asn1.Marshal(*key)
	case *ecdsa.PublicKey:
		if key == nil {
			return nil, fmt.Errorf("nil public key of type %T", pubKey)
		}
		e, err := key.ECDH()
		if err != nil {
			return nil, err
		}
		return e.Bytes(), nil
	case ed25519.PublicKey:
		return key, nil
	}
	return nil, fmt.Errorf("unrecognized public key type: %T", pubKey)
}
// getTLSPeerCert returns the first certificate in the connection's list
// of peer certificates, or nil when the list is empty.
func getTLSPeerCert(cs *tls.ConnectionState) *x509.Certificate {
	if peers := cs.PeerCertificates; len(peers) > 0 {
		return peers[0]
	}
	return nil
}
// requestID holds a request's UUID, which is generated
// on first use rather than up front.
type requestID struct {
	value string
}

// String returns the cached UUID string, generating it first if it
// has not been generated yet. If generation fails, the empty string
// is returned and generation is attempted again on the next call.
func (rid *requestID) String() string {
	if rid.value != "" {
		return rid.value
	}
	id, err := uuid.NewRandom()
	if err != nil {
		return rid.value
	}
	rid.value = id.String()
	return rid.value
}
// Placeholder key prefixes. A key starting with one of these is
// resolved dynamically from the text that follows the prefix.
const (
	// request: header fields, cookies, and TLS state
	reqHeaderReplPrefix = "http.request.header."
	reqCookieReplPrefix = "http.request.cookie."
	reqTLSReplPrefix    = "http.request.tls."

	// request: hostname labels
	reqHostLabelsReplPrefix = "http.request.host.labels."

	// request: URI query parameters and path parts
	reqURIQueryReplPrefix    = "http.request.uri.query."
	reqURIPathReplPrefix     = "http.request.uri.path."
	reqOrigURIPathReplPrefix = "http.request.orig_uri.path."

	// response: header fields
	respHeaderReplPrefix = "http.response.header."

	// middleware variables
	varsReplPrefix = "http.vars."
)