261 lines
5.8 KiB
Go
261 lines
5.8 KiB
Go
package middleware
|
||
|
||
import (
|
||
"crypto/md5"
|
||
"crypto/sha256"
|
||
"encoding/hex"
|
||
"fmt"
|
||
"github.com/gin-gonic/gin"
|
||
"net/http"
|
||
"psi/config"
|
||
"regexp"
|
||
"sort"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
)
|
||
|
||
// APISign API签名验证中间件
|
||
func APISign() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
// 获取签名相关参数
|
||
appKey := c.Query("app_key")
|
||
clientId := c.Query("client_id")
|
||
sign := c.Query("sign")
|
||
signMethod := c.Query("sign_method")
|
||
timestamp := c.Query("timestamp")
|
||
if appKey == "" {
|
||
appKey = c.PostForm("app_key")
|
||
}
|
||
if clientId == "" {
|
||
clientId = c.PostForm("client_id")
|
||
}
|
||
if sign == "" {
|
||
sign = c.PostForm("sign")
|
||
}
|
||
if signMethod == "" {
|
||
signMethod = c.PostForm("sign_method")
|
||
}
|
||
if timestamp == "" {
|
||
timestamp = c.PostForm("timestamp")
|
||
}
|
||
// 验证必填参数
|
||
if appKey == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{
|
||
"error": "缺少app_key参数",
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
if clientId == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{
|
||
"error": "缺少client_id参数",
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
if sign == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{
|
||
"error": "缺少sign参数",
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
if timestamp == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{
|
||
"error": "缺少timestamp参数",
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 验证app_key
|
||
if appKey != config.AppConfig.APISign.AppKey {
|
||
c.JSON(http.StatusUnauthorized, gin.H{
|
||
"error": "无效的app_key",
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 验证client_id
|
||
if clientId != config.AppConfig.APISign.ClientId {
|
||
c.JSON(http.StatusUnauthorized, gin.H{
|
||
"error": "无效的client_id",
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 验证时间戳(防止重放攻击)
|
||
ts, err := strconv.ParseInt(timestamp, 10, 64)
|
||
if err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{
|
||
"error": "时间戳格式错误",
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
currentTime := time.Now().Unix()
|
||
timeDiff := currentTime - ts
|
||
|
||
// 检查时间戳是否在允许范围内(默认300秒)
|
||
tolerance := int64(config.AppConfig.APISign.TimestampTolerance)
|
||
if tolerance <= 0 {
|
||
tolerance = 300 // 默认5分钟
|
||
}
|
||
|
||
if timeDiff > tolerance || timeDiff < -tolerance {
|
||
c.JSON(http.StatusUnauthorized, gin.H{
|
||
"error": "请求已过期,请检查系统时间",
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 验证签名方法
|
||
if signMethod == "" {
|
||
signMethod = config.AppConfig.APISign.SignMethod
|
||
}
|
||
if signMethod != "md5" && signMethod != "sha256" {
|
||
c.JSON(http.StatusBadRequest, gin.H{
|
||
"error": "不支持的签名方法",
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 计算签名
|
||
calculatedSign, err := calculateSign(c, signMethod)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"error": "签名计算失败",
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
// 验证签名
|
||
if calculatedSign != sign {
|
||
c.JSON(http.StatusUnauthorized, gin.H{
|
||
"error": "签名验证失败",
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// calculateSign 计算签名
|
||
func calculateSign(c *gin.Context, signMethod string) (string, error) {
|
||
// 收集所有参数
|
||
params := make(map[string]string)
|
||
|
||
// 获取URL查询参数
|
||
for key, values := range c.Request.URL.Query() {
|
||
// 跳过sign参数本身
|
||
if key == "sign" {
|
||
continue
|
||
}
|
||
if len(values) > 0 {
|
||
params[key] = values[0]
|
||
}
|
||
}
|
||
|
||
// 如果是POST、PUT请求,获取表单参数
|
||
if c.Request.Method == "POST" || c.Request.Method == "PUT" || c.Request.Method == "DELETE" {
|
||
c.Request.ParseForm()
|
||
for key, values := range c.Request.PostForm {
|
||
// 跳过sign参数本身
|
||
if key == "sign" {
|
||
continue
|
||
}
|
||
if len(values) > 0 {
|
||
// 如果是数组参数,用逗号拼接所有值
|
||
if len(values) > 1 {
|
||
params[key] = strings.Join(values, ",")
|
||
} else {
|
||
params[key] = values[0]
|
||
}
|
||
}
|
||
}
|
||
}
|
||
// 提取所有键并排序
|
||
keys := make([]string, 0, len(params))
|
||
for key := range params {
|
||
keys = append(keys, key)
|
||
}
|
||
sortKeysWithIndex(keys)
|
||
|
||
// 构建签名字符串:key1=value1&key2=value2...
|
||
var signStrings []string
|
||
for _, key := range keys {
|
||
signStrings = append(signStrings, fmt.Sprintf("%s=%s", key, params[key]))
|
||
}
|
||
signStr := strings.Join(signStrings, "&")
|
||
//fmt.Println("signStr: %s", signStr)
|
||
// 在字符串前后添加secret
|
||
appSecret := config.AppConfig.APISign.AppSecret
|
||
signStr = appSecret + signStr + appSecret
|
||
// 根据签名方法计算哈希
|
||
var hash []byte
|
||
switch signMethod {
|
||
case "md5":
|
||
md5Hash := md5.Sum([]byte(signStr))
|
||
hash = md5Hash[:]
|
||
case "sha256":
|
||
sha256Hash := sha256.Sum256([]byte(signStr))
|
||
hash = sha256Hash[:]
|
||
default:
|
||
return "", fmt.Errorf("不支持的签名方法: %s", signMethod)
|
||
}
|
||
|
||
// 转换为大写十六进制字符串
|
||
return strings.ToUpper(hex.EncodeToString(hash)), nil
|
||
}
|
||
|
||
// sortKeysWithIndex 对参数键进行智能排序,处理带索引的数组参数
|
||
func sortKeysWithIndex(keys []string) {
|
||
// 正则表达式匹配带索引的键,如 items[0][product_id]
|
||
indexPattern := regexp.MustCompile(`^(.+)\[(\d+)\](.*)$`)
|
||
|
||
sort.Slice(keys, func(i, j int) bool {
|
||
keyI := keys[i]
|
||
keyJ := keys[j]
|
||
|
||
matchI := indexPattern.FindStringSubmatch(keyI)
|
||
matchJ := indexPattern.FindStringSubmatch(keyJ)
|
||
|
||
// 如果两个都是带索引的键
|
||
if matchI != nil && matchJ != nil {
|
||
prefixI, idxIStr, suffixI := matchI[1], matchI[2], matchI[3]
|
||
prefixJ, idxJStr, suffixJ := matchJ[1], matchJ[2], matchJ[3]
|
||
|
||
// 先比较前缀(如 items)
|
||
if prefixI != prefixJ {
|
||
return prefixI < prefixJ
|
||
}
|
||
|
||
// 前缀相同,比较索引(数值比较)
|
||
idxI, errI := strconv.Atoi(idxIStr)
|
||
idxJ, errJ := strconv.Atoi(idxJStr)
|
||
|
||
if errI == nil && errJ == nil {
|
||
if idxI != idxJ {
|
||
return idxI < idxJ
|
||
}
|
||
// 索引相同,比较后缀(如 [product_id])
|
||
return suffixI < suffixJ
|
||
}
|
||
}
|
||
|
||
// 至少有一个不是带索引的键,使用默认字典序
|
||
return keyI < keyJ
|
||
})
|
||
}
|