277 lines
7.2 KiB
Go
277 lines
7.2 KiB
Go
package tokenConsumerUtil
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"strconv"
|
||
"time"
|
||
|
||
"github.com/redis/go-redis/v9"
|
||
)
|
||
|
||
type RedisTokenConsumerService struct {
|
||
redisClient *redis.Client
|
||
}
|
||
|
||
// TOKEN_BUCKET_KEY is the Redis key holding the number of available tokens.
const TOKEN_BUCKET_KEY = "pdd:goods:add"

// COUNTER_KEY_PREFIX is prepended to every per-second counter key
// (see buildCounterKey for the full key layout).
const COUNTER_KEY_PREFIX = "counter:"
|
||
|
||
// consumeTokenScript is a Lua script that consumes one token from the bucket
// stored at KEYS[1].
//
// It reads the current count (a missing key counts as 0). If the count is
// positive, it decrements it and returns 1; otherwise it returns 0 and leaves
// the key untouched. Redis runs a script as a single atomic step, so the read
// and the DECR cannot interleave with another client taking the last token.
// If the stored value is not numeric, tonumber yields nil and the comparison
// raises a script error.
var consumeTokenScript = `
local currentTokens = tonumber(redis.call('get', KEYS[1]) or 0)
if currentTokens > 0 then
    redis.call('decr', KEYS[1])
    return 1
else
    return 0
end`
|
||
|
||
func NewRedisTokenConsumerService(redisClient *redis.Client) *RedisTokenConsumerService {
|
||
return &RedisTokenConsumerService{
|
||
redisClient: redisClient,
|
||
}
|
||
}
|
||
|
||
// TryAcquireToken 尝试获取令牌
|
||
func (s *RedisTokenConsumerService) TryAcquireToken() bool {
|
||
ctx := context.Background()
|
||
|
||
script := redis.NewScript(consumeTokenScript)
|
||
|
||
result, err := script.Run(ctx, s.redisClient, []string{TOKEN_BUCKET_KEY}).Int64()
|
||
if err != nil {
|
||
// Redis操作异常,视为获取失败
|
||
return false
|
||
}
|
||
|
||
return result == 1
|
||
}
|
||
|
||
// AcquireTokenWithRetry 带重试的获取令牌
|
||
func (s *RedisTokenConsumerService) AcquireTokenWithRetry(maxWaitSeconds int) bool {
|
||
for i := 0; i < maxWaitSeconds; i++ {
|
||
if s.TryAcquireToken() {
|
||
return true
|
||
}
|
||
|
||
// 等待一段时间后重试(10ms)
|
||
time.Sleep(10 * time.Millisecond)
|
||
}
|
||
return false
|
||
}
|
||
|
||
// IncrementCounter 原子性地递增计数器
|
||
func (s *RedisTokenConsumerService) IncrementCounter(fixedParam string) int64 {
|
||
return s.IncrementCounterWithExpire(fixedParam, 120) // 默认过期时间120秒
|
||
}
|
||
|
||
// IncrementCounterWithExpire 原子性地递增计数器(带过期时间)
|
||
func (s *RedisTokenConsumerService) IncrementCounterWithExpire(fixedParam string, expireSeconds int) int64 {
|
||
ctx := context.Background()
|
||
|
||
// 获取当前秒级时间戳
|
||
currentSecond := time.Now().Unix()
|
||
counterKey := s.buildCounterKey(currentSecond, fixedParam)
|
||
|
||
// 使用管道确保原子性
|
||
pipe := s.redisClient.Pipeline()
|
||
incr := pipe.Incr(ctx, counterKey)
|
||
pipe.Expire(ctx, counterKey, time.Duration(expireSeconds)*time.Second)
|
||
|
||
_, err := pipe.Exec(ctx)
|
||
if err != nil {
|
||
// Redis操作异常,返回0
|
||
fmt.Printf("Redis error: %v\n", err)
|
||
return 0
|
||
}
|
||
|
||
return incr.Val()
|
||
}
|
||
|
||
// IncrementCounters 批量递增计数器
|
||
func (s *RedisTokenConsumerService) IncrementCounters(fixedParams []string, expireSeconds int) []int64 {
|
||
results := make([]int64, len(fixedParams))
|
||
for i, param := range fixedParams {
|
||
results[i] = s.IncrementCounterWithExpire(param, expireSeconds)
|
||
}
|
||
return results
|
||
}
|
||
|
||
// GetCurrentCounterValue 获取当前计数器的值(不递增)
|
||
func (s *RedisTokenConsumerService) GetCurrentCounterValue(fixedParam string) int64 {
|
||
ctx := context.Background()
|
||
|
||
currentSecond := time.Now().Unix()
|
||
counterKey := s.buildCounterKey(currentSecond, fixedParam)
|
||
|
||
val, err := s.redisClient.Get(ctx, counterKey).Result()
|
||
if err != nil {
|
||
return 0
|
||
}
|
||
|
||
return s.convertToLong(val)
|
||
}
|
||
|
||
// buildCounterKey 构建计数器Redis key
|
||
func (s *RedisTokenConsumerService) buildCounterKey(timestamp int64, fixedParam string) string {
|
||
return COUNTER_KEY_PREFIX + fixedParam + ":" + strconv.FormatInt(timestamp, 10)
|
||
}
|
||
|
||
// GetCounterValueByTimestamp 获取指定时间戳和参数的计数器值
|
||
func (s *RedisTokenConsumerService) GetCounterValueByTimestamp(timestamp int64, fixedParam string) int64 {
|
||
ctx := context.Background()
|
||
|
||
counterKey := s.buildCounterKey(timestamp, fixedParam)
|
||
val, err := s.redisClient.Get(ctx, counterKey).Result()
|
||
if err != nil {
|
||
return 0
|
||
}
|
||
|
||
return s.convertToLong(val)
|
||
}
|
||
|
||
// DeleteCounter 删除指定时间戳和参数的计数器
|
||
func (s *RedisTokenConsumerService) DeleteCounter(timestamp int64, fixedParam string) bool {
|
||
ctx := context.Background()
|
||
|
||
counterKey := s.buildCounterKey(timestamp, fixedParam)
|
||
|
||
result, err := s.redisClient.Del(ctx, counterKey).Result()
|
||
if err != nil {
|
||
return false
|
||
}
|
||
|
||
return result > 0
|
||
}
|
||
|
||
// DeleteCurrentCounter 删除当前时间的计数器
|
||
func (s *RedisTokenConsumerService) DeleteCurrentCounter(fixedParam string) bool {
|
||
currentSecond := time.Now().Unix()
|
||
return s.DeleteCounter(currentSecond, fixedParam)
|
||
}
|
||
|
||
// convertToLong 将字符串转换为int64类型
|
||
func (s *RedisTokenConsumerService) convertToLong(value string) int64 {
|
||
if value == "" {
|
||
return 0
|
||
}
|
||
|
||
result, err := strconv.ParseInt(value, 10, 64)
|
||
if err != nil {
|
||
return 0
|
||
}
|
||
|
||
return result
|
||
}
|
||
|
||
// SetCounterExpire 设置计数器的过期时间
|
||
func (s *RedisTokenConsumerService) SetCounterExpire(fixedParam string, expireSeconds int) bool {
|
||
ctx := context.Background()
|
||
|
||
currentSecond := time.Now().Unix()
|
||
counterKey := s.buildCounterKey(currentSecond, fixedParam)
|
||
|
||
result, err := s.redisClient.Expire(ctx, counterKey, time.Duration(expireSeconds)*time.Second).Result()
|
||
if err != nil {
|
||
return false
|
||
}
|
||
|
||
return result
|
||
}
|
||
|
||
// GetCounterTtl 获取计数器的剩余过期时间
|
||
func (s *RedisTokenConsumerService) GetCounterTtl(fixedParam string) time.Duration {
|
||
ctx := context.Background()
|
||
|
||
currentSecond := time.Now().Unix()
|
||
counterKey := s.buildCounterKey(currentSecond, fixedParam)
|
||
|
||
ttl, err := s.redisClient.TTL(ctx, counterKey).Result()
|
||
if err != nil {
|
||
return -2
|
||
}
|
||
|
||
return ttl
|
||
}
|
||
|
||
// DebugCounter 调试方法:检查key的详细信息
|
||
func (s *RedisTokenConsumerService) DebugCounter(fixedParam string) {
|
||
ctx := context.Background()
|
||
|
||
currentSecond := time.Now().Unix()
|
||
counterKey := s.buildCounterKey(currentSecond, fixedParam)
|
||
|
||
ttl, err := s.redisClient.TTL(ctx, counterKey).Result()
|
||
value, err2 := s.redisClient.Get(ctx, counterKey).Result()
|
||
exists, err3 := s.redisClient.Exists(ctx, counterKey).Result()
|
||
|
||
fmt.Println("=== Counter Debug Info ===")
|
||
fmt.Println("Key:", counterKey)
|
||
fmt.Println("Exists:", exists > 0)
|
||
fmt.Println("Value:", value)
|
||
fmt.Println("TTL:", ttl)
|
||
fmt.Println("Is permanent:", ttl == -1)
|
||
|
||
if err != nil || err2 != nil || err3 != nil {
|
||
fmt.Println("Debug error occurred")
|
||
}
|
||
fmt.Println("==========================")
|
||
}
|
||
|
||
// CounterExists 检查计数器是否存在
|
||
func (s *RedisTokenConsumerService) CounterExists(fixedParam string) bool {
|
||
ctx := context.Background()
|
||
|
||
currentSecond := time.Now().Unix()
|
||
counterKey := s.buildCounterKey(currentSecond, fixedParam)
|
||
|
||
result, err := s.redisClient.Exists(ctx, counterKey).Result()
|
||
if err != nil {
|
||
return false
|
||
}
|
||
|
||
return result > 0
|
||
}
|
||
|
||
// ResetCounter 重置计数器(删除并重新创建)
|
||
func (s *RedisTokenConsumerService) ResetCounter(fixedParam string, expireSeconds int) int64 {
|
||
ctx := context.Background()
|
||
|
||
currentSecond := time.Now().Unix()
|
||
counterKey := s.buildCounterKey(currentSecond, fixedParam)
|
||
|
||
// 先删除
|
||
s.redisClient.Del(ctx, counterKey)
|
||
|
||
// 重新创建并设置过期时间
|
||
err := s.redisClient.Set(ctx, counterKey, 1, time.Duration(expireSeconds)*time.Second).Err()
|
||
if err != nil {
|
||
return 0
|
||
}
|
||
|
||
return 1
|
||
}
|
||
|
||
// 使用示例
|
||
func main() {
|
||
// 初始化Redis客户端
|
||
rdb := redis.NewClient(&redis.Options{
|
||
Addr: "localhost:6379",
|
||
Password: "", // 无密码
|
||
DB: 0, // 使用默认DB
|
||
})
|
||
|
||
// 创建服务实例
|
||
service := NewRedisTokenConsumerService(rdb)
|
||
|
||
// 测试令牌获取
|
||
success := service.TryAcquireToken()
|
||
fmt.Printf("Token acquired: %v\n", success)
|
||
|
||
// 测试计数器
|
||
count := service.IncrementCounter("test_param")
|
||
fmt.Printf("Counter value: %d\n", count)
|
||
}
|