This commit is contained in:
vipg
2026-02-09 16:32:14 +08:00
parent 134ece2bcc
commit 2efc23cac7
2 changed files with 31 additions and 26 deletions

View File

@@ -0,0 +1,6 @@
package types
type Response struct {
Status bool `json:"status"`
Message string `json:"message"`
}

View File

@@ -17,6 +17,7 @@ import (
"common/db" "common/db"
"common/logger" "common/logger"
"common/utils" "common/utils"
"common/types"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"github.com/jackc/pgconn" "github.com/jackc/pgconn"
) )
@@ -95,28 +96,26 @@ type loginResp struct {
func registerHandler(w http.ResponseWriter, r *http.Request) { func registerHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
w.WriteHeader(http.StatusMethodNotAllowed) writeJSON(w, http.StatusMethodNotAllowed, false, "method not allowed")
return return
} }
var req registerReq var req registerReq
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
w.WriteHeader(http.StatusBadRequest) writeJSON(w, http.StatusBadRequest, false, "invalid json")
w.Write([]byte("invalid json"))
return return
} }
if !validAccount(req.Account) || !validPassword(req.Password) { if !validAccount(req.Account) || !validPassword(req.Password) {
w.WriteHeader(http.StatusBadRequest) writeJSON(w, http.StatusBadRequest, false, "invalid account or password")
w.Write([]byte("invalid account or password"))
return return
} }
hashed, err := bcrypt.GenerateFromPassword([]byte(req.Password), 12) hashed, err := bcrypt.GenerateFromPassword([]byte(req.Password), 12)
if err != nil { if err != nil {
w.WriteHeader(http.StatusInternalServerError) writeJSON(w, http.StatusInternalServerError, false, "internal error")
return return
} }
tx, err := pg.Begin() tx, err := pg.Begin()
if err != nil { if err != nil {
w.WriteHeader(http.StatusInternalServerError) writeJSON(w, http.StatusInternalServerError, false, "internal error")
return return
} }
defer func() { defer func() {
@@ -124,63 +123,57 @@ func registerHandler(w http.ResponseWriter, r *http.Request) {
}() }()
var userID string var userID string
if err := tx.QueryRow(`INSERT INTO users DEFAULT VALUES RETURNING user_id`).Scan(&userID); err != nil { if err := tx.QueryRow(`INSERT INTO users DEFAULT VALUES RETURNING user_id`).Scan(&userID); err != nil {
w.WriteHeader(http.StatusInternalServerError) writeJSON(w, http.StatusInternalServerError, false, "internal error")
return return
} }
if _, err := tx.Exec(`INSERT INTO user_login_accounts (user_id, value, deleted) VALUES ($1, $2, false)`, userID, req.Account); err != nil { if _, err := tx.Exec(`INSERT INTO user_login_accounts (user_id, value, deleted) VALUES ($1, $2, false)`, userID, req.Account); err != nil {
if isUniqueViolation(err) { if isUniqueViolation(err) {
w.WriteHeader(http.StatusConflict) writeJSON(w, http.StatusConflict, false, "account exists")
w.Write([]byte("account exists"))
return return
} }
w.WriteHeader(http.StatusInternalServerError) writeJSON(w, http.StatusInternalServerError, false, "internal error")
return return
} }
if _, err := tx.Exec(`INSERT INTO user_login_passwords (user_id, value, deleted) VALUES ($1, $2, false)`, userID, string(hashed)); err != nil { if _, err := tx.Exec(`INSERT INTO user_login_passwords (user_id, value, deleted) VALUES ($1, $2, false)`, userID, string(hashed)); err != nil {
w.WriteHeader(http.StatusInternalServerError) writeJSON(w, http.StatusInternalServerError, false, "internal error")
return return
} }
if err := tx.Commit(); err != nil { if err := tx.Commit(); err != nil {
w.WriteHeader(http.StatusInternalServerError) writeJSON(w, http.StatusInternalServerError, false, "internal error")
return return
} }
w.WriteHeader(http.StatusCreated) writeJSON(w, http.StatusCreated, true, userID)
io.WriteString(w, userID)
} }
func loginHandler(w http.ResponseWriter, r *http.Request) { func loginHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
w.WriteHeader(http.StatusMethodNotAllowed) writeJSON(w, http.StatusMethodNotAllowed, false, "method not allowed")
return return
} }
var req loginReq var req loginReq
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
w.WriteHeader(http.StatusBadRequest) writeJSON(w, http.StatusBadRequest, false, "invalid json")
w.Write([]byte("invalid json"))
return return
} }
if !validAccount(req.Account) || !validPassword(req.Password) { if !validAccount(req.Account) || !validPassword(req.Password) {
w.WriteHeader(http.StatusBadRequest) writeJSON(w, http.StatusBadRequest, false, "invalid account or password")
w.Write([]byte("invalid account or password"))
return return
} }
var userID string var userID string
if err := pg.QueryRow(`SELECT user_id FROM user_login_accounts WHERE value = $1 AND deleted = false`, req.Account).Scan(&userID); err != nil { if err := pg.QueryRow(`SELECT user_id FROM user_login_accounts WHERE value = $1 AND deleted = false`, req.Account).Scan(&userID); err != nil {
w.WriteHeader(http.StatusUnauthorized) writeJSON(w, http.StatusUnauthorized, false, "unauthorized")
return return
} }
var hashed string var hashed string
if err := pg.QueryRow(`SELECT value FROM user_login_passwords WHERE user_id = $1 AND deleted = false`, userID).Scan(&hashed); err != nil { if err := pg.QueryRow(`SELECT value FROM user_login_passwords WHERE user_id = $1 AND deleted = false`, userID).Scan(&hashed); err != nil {
w.WriteHeader(http.StatusUnauthorized) writeJSON(w, http.StatusUnauthorized, false, "unauthorized")
return return
} }
if bcrypt.CompareHashAndPassword([]byte(hashed), []byte(req.Password)) != nil { if bcrypt.CompareHashAndPassword([]byte(hashed), []byte(req.Password)) != nil {
w.WriteHeader(http.StatusUnauthorized) writeJSON(w, http.StatusUnauthorized, false, "unauthorized")
return return
} }
w.Header().Set("Content-Type", "application/json") writeJSON(w, http.StatusOK, true, userID)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(loginResp{UserID: userID})
} }
func validAccount(a string) bool { func validAccount(a string) bool {
@@ -211,3 +204,9 @@ func applySchema(path string) error {
_, err = pg.Exec(string(b)) _, err = pg.Exec(string(b))
return err return err
} }
func writeJSON(w http.ResponseWriter, code int, status bool, msg string) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
json.NewEncoder(w).Encode(types.Response{Status: status, Message: msg})
}