// Copyright (c) 2024 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
package exhttp
|
|
|
|
import (
|
|
"bufio"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
)
|
|
|
|
// ErrorBodies carries the JSON payloads that HandleErrors passes on to the
// response wrapper it installs around each request.
type ErrorBodies struct {
	// NotFound is the body used for overridden 404 (Not Found) responses;
	// MethodNotAllowed is the body used for overridden 405 (Method Not
	// Allowed) responses.
	NotFound, MethodNotAllowed json.RawMessage
}
|
|
|
|
func HandleErrors(gen ErrorBodies) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
next.ServeHTTP(&bodyOverrider{
|
|
ResponseWriter: w,
|
|
statusNotFoundBody: gen.NotFound,
|
|
statusMethodNotAllowedBody: gen.MethodNotAllowed,
|
|
}, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
// bodyOverrider decorates an http.ResponseWriter so that the body of certain
// plain-text 404/405 responses can be swapped for a preconfigured JSON payload.
type bodyOverrider struct {
	http.ResponseWriter

	// code is the status most recently handed to WriteHeader.
	code int
	// override is switched on by WriteHeader when the body must be replaced;
	// written is switched on by Write once an overridden write was handled.
	override, written bool

	// hijacked is set by Hijack; once set, WriteHeader no longer enables
	// the override.
	hijacked bool

	// Replacement bodies emitted for 404 and 405 responses respectively.
	statusNotFoundBody, statusMethodNotAllowedBody json.RawMessage
}
|
|
|
|
var (
|
|
_ http.ResponseWriter = (*bodyOverrider)(nil)
|
|
_ http.Flusher = (*bodyOverrider)(nil)
|
|
_ http.Hijacker = (*bodyOverrider)(nil)
|
|
)
|
|
|
|
func (b *bodyOverrider) WriteHeader(code int) {
|
|
if !b.hijacked &&
|
|
b.Header().Get("Content-Type") == "text/plain; charset=utf-8" &&
|
|
(code == http.StatusNotFound || code == http.StatusMethodNotAllowed) {
|
|
|
|
b.Header().Set("Content-Type", "application/json")
|
|
b.override = true
|
|
}
|
|
|
|
b.code = code
|
|
b.ResponseWriter.WriteHeader(code)
|
|
}
|
|
|
|
func (b *bodyOverrider) Write(body []byte) (n int, err error) {
|
|
if b.override {
|
|
n = len(body)
|
|
if !b.written {
|
|
switch b.code {
|
|
case http.StatusNotFound:
|
|
_, err = b.ResponseWriter.Write(b.statusNotFoundBody)
|
|
case http.StatusMethodNotAllowed:
|
|
_, err = b.ResponseWriter.Write(b.statusMethodNotAllowedBody)
|
|
}
|
|
}
|
|
b.written = true
|
|
return
|
|
}
|
|
|
|
return b.ResponseWriter.Write(body)
|
|
}
|
|
|
|
func (b *bodyOverrider) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
|
hijacker, ok := b.ResponseWriter.(http.Hijacker)
|
|
if !ok {
|
|
return nil, nil, fmt.Errorf("HandleErrors: %T does not implement http.Hijacker", b.ResponseWriter)
|
|
}
|
|
b.hijacked = true
|
|
return hijacker.Hijack()
|
|
}
|
|
|
|
func (b *bodyOverrider) Flush() {
|
|
flusher, ok := b.ResponseWriter.(http.Flusher)
|
|
if ok {
|
|
flusher.Flush()
|
|
}
|
|
}
|