gotosocial/internal/middleware/headerfilter.go

252 lines
5.8 KiB
Go

// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package middleware
import (
	"errors"
	"sync"

	"github.com/gin-gonic/gin"
	"github.com/superseriousbusiness/gotosocial/internal/config"
	"github.com/superseriousbusiness/gotosocial/internal/gtserror"
	"github.com/superseriousbusiness/gotosocial/internal/headerfilter"
	"github.com/superseriousbusiness/gotosocial/internal/log"
	"github.com/superseriousbusiness/gotosocial/internal/state"
)
// Package-level counters for header filter matches. Each is
// constructed with an allocated map so it can be used right away.
var (
	// allowMatches tallies matches recorded by the
	// allow-filter checks in this file.
	allowMatches = matchstats{m: map[string]uint64{}}

	// blockMatches tallies matches recorded by the
	// block-filter check in this file.
	blockMatches = matchstats{m: map[string]uint64{}}
)
// matchstats is a simple statistics
// counter for header filter matches.
//
// All access to the underlying map goes through the
// mutex, so Add may be called from multiple goroutines.
// As it embeds a sync.Mutex by value, a matchstats
// must not be copied after first use.
//
// TODO: replace with otel.
type matchstats struct {
	// m maps a "header:regex" key to the number
	// of times that pairing has been recorded.
	m map[string]uint64

	// l guards m.
	l sync.Mutex
}

// Add increments the counter for the given header name
// and regular expression, stored under the key "hdr:regex".
//
// The zero value of matchstats is ready to use:
// the map is allocated on the first call if needed.
func (m *matchstats) Add(hdr, regex string) {
	m.l.Lock()

	// Deferred so the lock is always released,
	// even if the code below were to panic.
	defer m.l.Unlock()

	if m.m == nil {
		// Writing to a nil map panics; allocate lazily
		// so a zero-value matchstats works as expected.
		m.m = make(map[string]uint64)
	}

	m.m[hdr+":"+regex]++
}
// HeaderFilter returns a gin middleware handler that provides HTTP
// request blocking (filtering) based on database allow / block filters.
//
// The handler is selected once, at construction time, from the
// configured advanced header filter mode:
//
//   - allow mode:    the handler built by headerFilterAllowMode.
//   - block mode:    the handler built by headerFilterBlockMode.
//   - disabled mode: a handler with an empty body (no filtering).
//
// Any other mode value causes a panic.
func HeaderFilter(state *state.State) gin.HandlerFunc {
	var handler gin.HandlerFunc

	mode := config.GetAdvancedHeaderFilterMode()
	switch mode {
	case config.RequestHeaderFilterModeAllow:
		handler = headerFilterAllowMode(state)

	case config.RequestHeaderFilterModeBlock:
		handler = headerFilterBlockMode(state)

	case config.RequestHeaderFilterModeDisabled:
		// Filtering switched off: nothing to check.
		handler = func(*gin.Context) {}

	default:
		panic("unrecognized filter mode: " + mode)
	}

	return handler
}
// headerFilterAllowMode returns the request handler used in allowlist mode.
//
// In this mode an explicit block takes precedence over an explicit
// allow, and a request reported as "not allowed" by the allow
// filters is blocked as well. Only requests that pass both checks
// are handed on to the next handler in the chain.
func headerFilterAllowMode(state *state.State) func(c *gin.Context) {
	// Dereference now so that a nil state shows up when the
	// middleware is built rather than while serving a request.
	_ = *state //nolint

	return func(c *gin.Context) {
		// First check: explicit block entries.
		blocked, err := isHeaderBlocked(state, c)
		switch {
		case err != nil:
			respondInternalServerError(c, err)
			return
		case blocked:
			respondBlocked(c)
			return
		}

		// Second check: is an explicit allow missing?
		denied, err := isHeaderNotAllowed(state, c)
		switch {
		case err != nil:
			respondInternalServerError(c, err)
			return
		case denied:
			respondBlocked(c)
			return
		}

		// Passed both checks, continue the chain.
		c.Next()
	}
}
// headerFilterBlockMode returns the request handler used in
// blocklist (default) mode.
//
// In this mode an explicit allow takes precedence over an explicit
// block: the block filters are only consulted when the request was
// not explicitly allowed. Requests matching neither an allow nor a
// block entry are passed on to the next handler in the chain.
func headerFilterBlockMode(state *state.State) func(c *gin.Context) {
	// Dereference now so that a nil state shows up when the
	// middleware is built rather than while serving a request.
	_ = *state //nolint

	return func(c *gin.Context) {
		// An explicit allow short-circuits everything else.
		allowed, err := isHeaderAllowed(state, c)
		if err != nil {
			respondInternalServerError(c, err)
			return
		}
		if allowed {
			c.Next()
			return
		}

		// Not explicitly allowed, so
		// look for an explicit block.
		blocked, err := isHeaderBlocked(state, c)
		if err != nil {
			respondInternalServerError(c, err)
			return
		}
		if blocked {
			respondBlocked(c)
			return
		}

		// Neither allowed nor blocked: let it through.
		c.Next()
	}
}
// isHeaderBlocked reports whether the request held by c is
// explicitly blocked by the block header filters in state.DB.
//
// It returns true when the database reports a matching header
// key, or when the match fails with ErrLargeHeaderValue (oversized
// header values are always treated as blocked). Any other error
// from the database is wrapped and returned with a false result.
//
// When a match carries a non-empty expression, the blockMatches
// counter for that header / expression pair is incremented.
func isHeaderBlocked(state *state.State, c *gin.Context) (bool, error) {
	var (
		ctx = c.Request.Context()
		hdr = c.Request.Header
	)

	// Perform an explicit is-blocked check on request header.
	key, expr, err := state.DB.BlockHeaderRegularMatch(ctx, hdr)
	switch {
	case err == nil:
		// Use key / expr as returned.

	// errors.Is (rather than ==) so the sentinel is
	// still recognised if a layer in between wraps it.
	case errors.Is(err, headerfilter.ErrLargeHeaderValue):
		log.Warn(ctx, "large header value")
		key = "*" // block large headers

	default:
		return false, gtserror.Newf("error checking header: %w", err)
	}

	if key == "" {
		// Nothing matched.
		return false, nil
	}

	if expr != "" {
		// Increment block matches stat.
		// TODO: replace expvar with build
		// taggable metrics types in State{}.
		blockMatches.Add(key, expr)
	}

	// A header was matched against!
	// i.e. this request is blocked.
	return true, nil
}
// isHeaderAllowed reports whether the request held by c is
// explicitly allowed by the allow header filters in state.DB.
//
// It returns true only when the database reports a matching
// header key. If the match fails with ErrLargeHeaderValue the
// request is reported as not explicitly allowed (false, nil).
// Any other error from the database is wrapped and returned
// with a false result.
//
// When a match carries a non-empty expression, the allowMatches
// counter for that header / expression pair is incremented.
func isHeaderAllowed(state *state.State, c *gin.Context) (bool, error) {
	var (
		ctx = c.Request.Context()
		hdr = c.Request.Header
	)

	// Perform an explicit is-allowed check on request header.
	key, expr, err := state.DB.AllowHeaderRegularMatch(ctx, hdr)
	switch {
	case err == nil:
		// Use key / expr as returned.

	// errors.Is (rather than ==) so the sentinel is
	// still recognised if a layer in between wraps it.
	case errors.Is(err, headerfilter.ErrLargeHeaderValue):
		log.Warn(ctx, "large header value")
		key = "" // never treat large headers as explicitly allowed

	default:
		return false, gtserror.Newf("error checking header: %w", err)
	}

	if key == "" {
		// Nothing matched.
		return false, nil
	}

	if expr != "" {
		// Increment allow matches stat.
		// TODO: replace expvar with build
		// taggable metrics types in State{}.
		allowMatches.Add(key, expr)
	}

	// A header was matched against!
	// i.e. this request is allowed.
	return true, nil
}
// isHeaderNotAllowed reports whether the request held by c is
// NOT allowed, according to the inverse match of the allow header
// filters in state.DB.
//
// It returns true when the database reports a header key from
// the inverse match, or when the match fails with
// ErrLargeHeaderValue (oversized header values are always treated
// as not allowed). Any other error from the database is wrapped
// and returned with a false result.
//
// When a match carries a non-empty expression, the allowMatches
// counter for that header / expression pair is incremented.
func isHeaderNotAllowed(state *state.State, c *gin.Context) (bool, error) {
	var (
		ctx = c.Request.Context()
		hdr = c.Request.Header
	)

	// Perform an explicit is-NOT-allowed check on request header.
	key, expr, err := state.DB.AllowHeaderInverseMatch(ctx, hdr)
	switch {
	case err == nil:
		// Use key / expr as returned.

	// errors.Is (rather than ==) so the sentinel is
	// still recognised if a layer in between wraps it.
	case errors.Is(err, headerfilter.ErrLargeHeaderValue):
		log.Warn(ctx, "large header value")
		key = "*" // block large headers

	default:
		return false, gtserror.Newf("error checking header: %w", err)
	}

	if key == "" {
		// Nothing matched.
		return false, nil
	}

	if expr != "" {
		// Increment allow matches stat.
		// TODO: replace expvar with build
		// taggable metrics types in State{}.
		allowMatches.Add(key, expr)
	}

	// A header was matched against!
	// i.e. request is NOT allowed.
	return true, nil
}