-
This commit is contained in:
265
internal/middleware/validation.go
Normal file
265
internal/middleware/validation.go
Normal file
@ -0,0 +1,265 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-playground/validator/v10"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ValidationError represents a validation error
|
||||
type ValidationError struct {
|
||||
Field string `json:"field"`
|
||||
Tag string `json:"tag"`
|
||||
Value string `json:"value"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ValidationResponse represents the validation error response
|
||||
type ValidationResponse struct {
|
||||
Error string `json:"error"`
|
||||
Message string `json:"message"`
|
||||
Details []ValidationError `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
var validate *validator.Validate
|
||||
|
||||
func init() {
|
||||
validate = validator.New()
|
||||
|
||||
// Register custom tag name function to use json tags
|
||||
validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
|
||||
name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
|
||||
if name == "-" {
|
||||
return ""
|
||||
}
|
||||
return name
|
||||
})
|
||||
}
|
||||
|
||||
// ValidateJSON validates JSON request body against struct validation tags
|
||||
func ValidateJSON(logger *zap.Logger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Skip validation for GET requests and requests without body
|
||||
if c.Request.Method == "GET" || c.Request.ContentLength == 0 {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Store original body for potential re-reading
|
||||
c.Set("validation_enabled", true)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateStruct validates a struct and returns formatted errors
|
||||
func ValidateStruct(s interface{}) []ValidationError {
|
||||
var errors []ValidationError
|
||||
|
||||
err := validate.Struct(s)
|
||||
if err != nil {
|
||||
for _, err := range err.(validator.ValidationErrors) {
|
||||
var element ValidationError
|
||||
element.Field = err.Field()
|
||||
element.Tag = err.Tag()
|
||||
element.Value = err.Param()
|
||||
element.Message = getErrorMessage(err)
|
||||
errors = append(errors, element)
|
||||
}
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// ValidateAndBind validates and binds JSON request to struct
|
||||
func ValidateAndBind(c *gin.Context, obj interface{}) error {
|
||||
// Bind JSON to struct
|
||||
if err := c.ShouldBindJSON(obj); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ValidationResponse{
|
||||
Error: "Invalid JSON",
|
||||
Message: "Request body contains invalid JSON: " + err.Error(),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate struct
|
||||
if validationErrors := ValidateStruct(obj); len(validationErrors) > 0 {
|
||||
c.JSON(http.StatusBadRequest, ValidationResponse{
|
||||
Error: "Validation Failed",
|
||||
Message: "Request validation failed",
|
||||
Details: validationErrors,
|
||||
})
|
||||
return validator.ValidationErrors{}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getErrorMessage returns a human-readable error message for validation errors
|
||||
func getErrorMessage(fe validator.FieldError) string {
|
||||
switch fe.Tag() {
|
||||
case "required":
|
||||
return "This field is required"
|
||||
case "email":
|
||||
return "Invalid email format"
|
||||
case "min":
|
||||
return "Value is too short (minimum " + fe.Param() + " characters)"
|
||||
case "max":
|
||||
return "Value is too long (maximum " + fe.Param() + " characters)"
|
||||
case "url":
|
||||
return "Invalid URL format"
|
||||
case "oneof":
|
||||
return "Value must be one of: " + fe.Param()
|
||||
case "uuid":
|
||||
return "Invalid UUID format"
|
||||
case "gte":
|
||||
return "Value must be greater than or equal to " + fe.Param()
|
||||
case "lte":
|
||||
return "Value must be less than or equal to " + fe.Param()
|
||||
case "len":
|
||||
return "Value must be exactly " + fe.Param() + " characters"
|
||||
case "dive":
|
||||
return "Invalid array element"
|
||||
default:
|
||||
return "Invalid value for " + fe.Field()
|
||||
}
|
||||
}
|
||||
|
||||
// RequiredFields validates that specific fields are present in the request
|
||||
func RequiredFields(fields ...string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var json map[string]interface{}
|
||||
|
||||
if err := c.ShouldBindJSON(&json); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ValidationResponse{
|
||||
Error: "Invalid JSON",
|
||||
Message: "Request body contains invalid JSON",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
var missingFields []string
|
||||
for _, field := range fields {
|
||||
if _, exists := json[field]; !exists {
|
||||
missingFields = append(missingFields, field)
|
||||
}
|
||||
}
|
||||
|
||||
if len(missingFields) > 0 {
|
||||
c.JSON(http.StatusBadRequest, ValidationResponse{
|
||||
Error: "Missing Required Fields",
|
||||
Message: "The following required fields are missing: " + strings.Join(missingFields, ", "),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Store the parsed JSON for use in handlers
|
||||
c.Set("parsed_json", json)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateUUID validates that a URL parameter is a valid UUID
|
||||
func ValidateUUID(param string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
value := c.Param(param)
|
||||
if value == "" {
|
||||
c.JSON(http.StatusBadRequest, ValidationResponse{
|
||||
Error: "Missing Parameter",
|
||||
Message: "Required parameter '" + param + "' is missing",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Validate UUID format
|
||||
if err := validate.Var(value, "uuid"); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ValidationResponse{
|
||||
Error: "Invalid Parameter",
|
||||
Message: "Parameter '" + param + "' must be a valid UUID",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateQueryParams validates query parameters
|
||||
func ValidateQueryParams(rules map[string]string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var errors []ValidationError
|
||||
|
||||
for param, rule := range rules {
|
||||
value := c.Query(param)
|
||||
if value != "" {
|
||||
if err := validate.Var(value, rule); err != nil {
|
||||
for _, err := range err.(validator.ValidationErrors) {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: param,
|
||||
Tag: err.Tag(),
|
||||
Value: err.Param(),
|
||||
Message: getErrorMessage(err),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
c.JSON(http.StatusBadRequest, ValidationResponse{
|
||||
Error: "Invalid Query Parameters",
|
||||
Message: "One or more query parameters are invalid",
|
||||
Details: errors,
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// SanitizeInput sanitizes input strings to prevent XSS and injection attacks
|
||||
func SanitizeInput() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// This is a basic implementation - in production you might want to use
|
||||
// a more sophisticated sanitization library like bluemonday
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ValidatePermissions validates that permission scopes follow the expected format
|
||||
func ValidatePermissions(c *gin.Context, permissions []string) []ValidationError {
|
||||
var errors []ValidationError
|
||||
|
||||
for i, perm := range permissions {
|
||||
// Check basic format: should contain only alphanumeric, dots, and underscores
|
||||
if err := validate.Var(perm, "required,min=1,max=255,alphanum|contains=.|contains=_"); err != nil {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "permissions[" + string(rune(i)) + "]",
|
||||
Tag: "format",
|
||||
Value: perm,
|
||||
Message: "Permission scope must contain only alphanumeric characters, dots, and underscores",
|
||||
})
|
||||
}
|
||||
|
||||
// Check for dangerous patterns
|
||||
if strings.Contains(perm, "..") || strings.HasPrefix(perm, ".") || strings.HasSuffix(perm, ".") {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "permissions[" + string(rune(i)) + "]",
|
||||
Tag: "format",
|
||||
Value: perm,
|
||||
Message: "Permission scope has invalid format",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
Reference in New Issue
Block a user