diff --git a/internal/base/middleware/recovery.go b/internal/base/middleware/recovery.go new file mode 100644 index 000000000..02b1cbde7 --- /dev/null +++ b/internal/base/middleware/recovery.go @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package middleware + +import ( + "net/http" + "runtime/debug" + "strings" + + "github.com/apache/answer/internal/base/handler" + "github.com/apache/answer/internal/base/reason" + "github.com/gin-gonic/gin" + "github.com/segmentfault/pacman/log" +) + +func Recovery(apiPrefixes ...string) gin.HandlerFunc { + return func(ctx *gin.Context) { + defer func() { + if err := recover(); err != nil { + log.Errorf("panic recovered: %v\n%s", err, debug.Stack()) + + // Headers/body already flushed (SSE or any streamed response). + // We can no longer rewrite the response cleanly; just stop the chain. + if ctx.Writer.Written() { + ctx.Abort() + return + } + + path := ctx.Request.URL.Path + for _, p := range apiPrefixes { + if strings.HasPrefix(path, p) { + ctx.AbortWithStatusJSON(http.StatusInternalServerError, + handler.NewRespBody(http.StatusInternalServerError, reason.UnknownError). + TrMsg(handler.GetLangByCtx(ctx)), + ) + return + } + } + + ctx.AbortWithStatus(http.StatusInternalServerError) + } + }() + ctx.Next() + } +} diff --git a/internal/base/middleware/recovery_test.go b/internal/base/middleware/recovery_test.go new file mode 100644 index 000000000..58dbfdb0b --- /dev/null +++ b/internal/base/middleware/recovery_test.go @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package middleware + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" +) + +// Panic on an API path returns the project's unified JSON 500. +func TestRecovery_APIPathPanic(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(Recovery("/api")) + r.GET("/api/panic", func(ctx *gin.Context) { + panic("test panic") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/api/panic", nil) + r.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("expected 500, got %d", w.Code) + } + + var body map[string]any + if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil { + t.Fatalf("response is not valid JSON: %v", err) + } + if body["reason"] != "base.unknown" { + t.Errorf("unexpected reason: %v", body["reason"]) + } +} + +// Panic on a non-API path returns a bare 500 with no body, so the browser can +// render its own error page instead of showing raw JSON. +func TestRecovery_NonAPIPathPanic(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(Recovery("/api")) + r.GET("/page", func(ctx *gin.Context) { + panic("test panic") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/page", nil) + r.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("expected 500, got %d", w.Code) + } + if w.Body.Len() != 0 { + t.Errorf("expected empty body for non-API path, got: %q", w.Body.String()) + } +} + +// Panic after the response has already started writing (SSE / streamed +// responses). The middleware must not touch the response — status and body +// already on the wire stay untouched, no JSON gets appended. +func TestRecovery_PanicAfterResponseStarted(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(Recovery("/api")) + r.GET("/api/stream", func(ctx *gin.Context) { + ctx.Writer.WriteHeader(http.StatusOK) + _, _ = ctx.Writer.Write([]byte("partial data")) + panic("test panic after write") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/api/stream", nil) + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status to remain 200 (already flushed), got %d", w.Code) + } + if w.Body.String() != "partial data" { + t.Errorf("expected body to remain 'partial data' (no error JSON appended), got: %q", w.Body.String()) + } +} + +// Normal requests pass through unaffected. +func TestRecovery_NoPanic(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(Recovery("/api")) + r.GET("/api/ok", func(ctx *gin.Context) { + ctx.String(http.StatusOK, "ok") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/api/ok", nil) + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d", w.Code) + } +} diff --git a/internal/base/server/http.go b/internal/base/server/http.go index 765cbf6be..8db557440 100644 --- a/internal/base/server/http.go +++ b/internal/base/server/http.go @@ -52,6 +52,10 @@ func NewHTTPServer(debug bool, gin.SetMode(gin.ReleaseMode) } r := gin.New() + r.Use(middleware.Recovery( + uiConf.APIBaseURL+"/answer/api/v1", + uiConf.APIBaseURL+"/answer/admin/api", + )) r.Use(func(ctx *gin.Context) { if strings.Contains(ctx.Request.URL.Path, "/chat/completions") { return