diff --git a/trading_assistant_api/common/httpx/auth.go b/trading_assistant_api/common/httpx/auth.go index 784f464..1770a6b 100644 --- a/trading_assistant_api/common/httpx/auth.go +++ b/trading_assistant_api/common/httpx/auth.go @@ -6,6 +6,7 @@ import ( "strings" "common/auth" + "common/logger" ) type userIDKey struct{} @@ -15,13 +16,15 @@ func AuthRequired() func(http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ah := r.Header.Get("Authorization") if ah == "" || !strings.HasPrefix(ah, "Bearer ") { - Unauthorized(w, "unauthorized") + logger.WithPrefix("rid="+RequestIDFromContext(r)).Printf("auth missing header path=%s", r.URL.Path) + Unauthorized(w, r, "unauthorized") return } token := strings.TrimSpace(strings.TrimPrefix(ah, "Bearer ")) sub, err := auth.ParseToken(token) if err != nil || sub == "" { - Unauthorized(w, "unauthorized") + logger.WithPrefix("rid="+RequestIDFromContext(r)).Printf("auth invalid token path=%s", r.URL.Path) + Unauthorized(w, r, "unauthorized") return } ctx := context.WithValue(r.Context(), userIDKey{}, sub) diff --git a/trading_assistant_api/common/httpx/httpx.go b/trading_assistant_api/common/httpx/httpx.go index 915608a..3bb1357 100644 --- a/trading_assistant_api/common/httpx/httpx.go +++ b/trading_assistant_api/common/httpx/httpx.go @@ -14,30 +14,42 @@ func WriteJSON(w http.ResponseWriter, status int, ok bool, msg string, data inte json.NewEncoder(w).Encode(types.Response{Status: ok, Message: msg, Data: data}) } -func OK(w http.ResponseWriter, data interface{}) { - WriteJSON(w, http.StatusOK, true, string(codes.OK), data) +func OK(w http.ResponseWriter, r *http.Request, data interface{}) { + WriteJSON(w, http.StatusOK, true, string(codes.OK), addRequestID(r, data)) } -func Created(w http.ResponseWriter, data interface{}) { - WriteJSON(w, http.StatusCreated, true, string(codes.OK), data) +func Created(w http.ResponseWriter, r *http.Request, data interface{}) { + WriteJSON(w, http.StatusCreated, true, string(codes.OK), addRequestID(r, data)) } -func BadRequest(w http.ResponseWriter, msg string) { - WriteJSON(w, http.StatusBadRequest, false, msg, map[string]string{"code": string(codes.InvalidInput)}) +func BadRequest(w http.ResponseWriter, r *http.Request, msg string) { + WriteJSON(w, http.StatusBadRequest, false, msg, addRequestID(r, map[string]string{"code": string(codes.InvalidInput)})) } -func Unauthorized(w http.ResponseWriter, msg string) { - WriteJSON(w, http.StatusUnauthorized, false, msg, map[string]string{"code": string(codes.Unauthorized)}) +func Unauthorized(w http.ResponseWriter, r *http.Request, msg string) { + WriteJSON(w, http.StatusUnauthorized, false, msg, addRequestID(r, map[string]string{"code": string(codes.Unauthorized)})) } -func Conflict(w http.ResponseWriter, msg string) { - WriteJSON(w, http.StatusConflict, false, msg, map[string]string{"code": string(codes.Conflict)}) +func Conflict(w http.ResponseWriter, r *http.Request, msg string) { + WriteJSON(w, http.StatusConflict, false, msg, addRequestID(r, map[string]string{"code": string(codes.Conflict)})) } -func MethodNotAllowed(w http.ResponseWriter, msg string) { - WriteJSON(w, http.StatusMethodNotAllowed, false, msg, map[string]string{"code": string(codes.MethodNotAllowed)}) +func MethodNotAllowed(w http.ResponseWriter, r *http.Request, msg string) { + WriteJSON(w, http.StatusMethodNotAllowed, false, msg, addRequestID(r, map[string]string{"code": string(codes.MethodNotAllowed)})) } -func InternalError(w http.ResponseWriter) { - WriteJSON(w, http.StatusInternalServerError, false, string(codes.InternalError), map[string]string{"code": string(codes.InternalError)}) +func InternalError(w http.ResponseWriter, r *http.Request) { + WriteJSON(w, http.StatusInternalServerError, false, string(codes.InternalError), addRequestID(r, map[string]string{"code": string(codes.InternalError)})) +} + +func addRequestID(r *http.Request, data interface{}) interface{} { + rid := RequestIDFromContext(r) + if m, ok := data.(map[string]string); ok { + m["request_id"] = rid + return m + } + if data == nil { + return map[string]string{"request_id": rid} + } + return map[string]interface{}{"request_id": rid, "data": data} } diff --git a/trading_assistant_api/common/logger/logger.go b/trading_assistant_api/common/logger/logger.go index 371da7e..53e4854 100644 --- a/trading_assistant_api/common/logger/logger.go +++ b/trading_assistant_api/common/logger/logger.go @@ -3,6 +3,7 @@ package logger import ( "log" "os" + "fmt" ) type Logger interface { @@ -30,3 +31,18 @@ func SetLogger(l Logger) { defaultLogger = l } } + +type prefLogger struct { + prefix string +} + +func (p *prefLogger) Printf(format string, v ...any) { + defaultLogger.Printf("%s %s", p.prefix, fmt.Sprintf(format, v...)) +} +func (p *prefLogger) Fatalf(format string, v ...any) { + defaultLogger.Fatalf("%s %s", p.prefix, fmt.Sprintf(format, v...)) +} + +func WithPrefix(prefix string) Logger { + return &prefLogger{prefix: prefix} +}