money/vendor/go.mau.fi/util/exhttp/handleerrors.go

99 lines
2.2 KiB
Go

// Copyright (c) 2024 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package exhttp
import (
"bufio"
"encoding/json"
"fmt"
"net"
"net/http"
)
// ErrorBodies holds the JSON documents that HandleErrors uses in place of the
// plain-text bodies of 404 Not Found and 405 Method Not Allowed responses.
type ErrorBodies struct {
	// NotFound is the JSON body substituted into 404 Not Found responses.
	NotFound json.RawMessage

	// MethodNotAllowed is the JSON body substituted into 405 Method Not
	// Allowed responses.
	MethodNotAllowed json.RawMessage
}
// HandleErrors returns middleware that hands next a wrapped ResponseWriter for
// every request. The wrapper lets plain-text 404 Not Found and 405 Method Not
// Allowed responses have their bodies replaced with the JSON documents in gen.
// See bodyOverrider.WriteHeader for the exact conditions. Each request gets its
// own wrapper.
func HandleErrors(gen ErrorBodies) func(http.Handler) http.Handler {
	middleware := func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, r *http.Request) {
			wrapped := &bodyOverrider{
				ResponseWriter:             w,
				statusNotFoundBody:         gen.NotFound,
				statusMethodNotAllowedBody: gen.MethodNotAllowed,
			}
			next.ServeHTTP(wrapped, r)
		}
		return http.HandlerFunc(serve)
	}
	return middleware
}
// bodyOverrider is the per-request http.ResponseWriter installed by
// HandleErrors. It embeds the real ResponseWriter and forwards to it, except
// that WriteHeader and Write may swap the body of a 404/405 response for one
// of the configured JSON documents.
type bodyOverrider struct {
	http.ResponseWriter

	code     int  // status recorded by WriteHeader; 0 until it is called
	override bool // set by WriteHeader when Write must emit a replacement body
	written  bool // set once an overridden Write has run; later ones are dropped
	hijacked bool // set by Hijack; WriteHeader does not override while true

	// Replacement bodies for 404 and 405 responses, copied from ErrorBodies.
	statusNotFoundBody         json.RawMessage
	statusMethodNotAllowedBody json.RawMessage
}
// Compile-time assertions that *bodyOverrider implements http.ResponseWriter
// as well as the optional Flusher and Hijacker interfaces it forwards to the
// wrapped writer.
var _ http.ResponseWriter = (*bodyOverrider)(nil)
var _ http.Flusher = (*bodyOverrider)(nil)
var _ http.Hijacker = (*bodyOverrider)(nil)
// WriteHeader records the status code and forwards it to the wrapped
// ResponseWriter.
//
// The override is enabled only when all of these hold:
//   - the status is 404 Not Found or 405 Method Not Allowed;
//   - the Content-Type header is exactly "text/plain; charset=utf-8" (the
//     value http.Error sets);
//   - the connection has not been hijacked;
//   - a non-empty replacement body is configured for that status.
//
// When enabled, the Content-Type is switched to application/json, any
// Content-Length is dropped, and Write is told to emit the replacement body
// instead of the handler's.
//
// Calls made after a final (non-1xx) status has been recorded are forwarded
// without touching the override state. net/http ignores such superfluous
// calls, so the body must keep matching the status that was actually sent.
func (b *bodyOverrider) WriteHeader(code int) {
	if b.code >= 200 {
		b.ResponseWriter.WriteHeader(code)
		return
	}
	if !b.hijacked && b.Header().Get("Content-Type") == "text/plain; charset=utf-8" {
		var replacement []byte
		switch code {
		case http.StatusNotFound:
			replacement = b.statusNotFoundBody
		case http.StatusMethodNotAllowed:
			replacement = b.statusMethodNotAllowedBody
		}
		// Without a configured body the override would produce an empty
		// document labelled application/json, so keep the handler's own
		// plain-text body instead.
		if len(replacement) > 0 {
			b.Header().Set("Content-Type", "application/json")
			// The handler's Content-Length describes its own body, not the
			// replacement, so it must not be sent.
			b.Header().Del("Content-Length")
			b.override = true
		}
	}
	b.code = code
	b.ResponseWriter.WriteHeader(code)
}
// Write passes data straight through to the wrapped ResponseWriter unless
// WriteHeader enabled the override.
//
// In override mode the handler's bytes are discarded. The first call writes
// the configured replacement body for the recorded 404/405 status, or nothing
// for any other status. Every call reports len(body) bytes as written, so
// callers do not see a short write. Only that first call can return an error,
// and it comes from writing the replacement body.
func (b *bodyOverrider) Write(body []byte) (int, error) {
	if !b.override {
		return b.ResponseWriter.Write(body)
	}
	consumed := len(body)
	if b.written {
		// The replacement has already been sent; drop the rest of the
		// handler's body.
		return consumed, nil
	}
	var writeErr error
	if code := b.code; code == http.StatusNotFound || code == http.StatusMethodNotAllowed {
		replacement := b.statusNotFoundBody
		if code == http.StatusMethodNotAllowed {
			replacement = b.statusMethodNotAllowedBody
		}
		_, writeErr = b.ResponseWriter.Write(replacement)
	}
	b.written = true
	return consumed, writeErr
}
// Hijack lets the caller take over the underlying connection when the wrapped
// ResponseWriter implements http.Hijacker, and returns an error otherwise.
//
// The connection is marked as hijacked only after the underlying Hijack
// succeeds; that mark disables body overriding in WriteHeader. If the attempt
// fails, the override still applies to whatever response the handler writes
// next.
func (b *bodyOverrider) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	hijacker, ok := b.ResponseWriter.(http.Hijacker)
	if !ok {
		return nil, nil, fmt.Errorf("HandleErrors: %T does not implement http.Hijacker", b.ResponseWriter)
	}
	conn, brw, err := hijacker.Hijack()
	if err == nil {
		b.hijacked = true
	}
	return conn, brw, err
}
// Flush sends any buffered data to the client when the wrapped ResponseWriter
// implements http.Flusher. For writers without flushing support it does
// nothing.
func (b *bodyOverrider) Flush() {
	if f, ok := b.ResponseWriter.(http.Flusher); ok {
		f.Flush()
	}
}