448 lines
10 KiB
Go
448 lines
10 KiB
Go
package utils
|
||
|
||
import (
|
||
"database/sql"
|
||
"errors"
|
||
"fmt"
|
||
"strings"
|
||
"unicode"
|
||
)
|
||
|
||
// PageQuery 分页查询实体类
|
||
type PageQuery struct {
|
||
PageSize *int `json:"pageSize" form:"pageSize"` // 分页大小
|
||
PageNum *int `json:"pageNum" form:"pageNum"` // 当前页数
|
||
OrderByColumn string `json:"orderByColumn" form:"orderByColumn"` // 排序列
|
||
IsAsc string `json:"isAsc" form:"isAsc"` // 排序的方向desc或者asc
|
||
|
||
// 常量定义
|
||
DefaultPageNum int // 当前记录起始索引 默认值
|
||
DefaultPageSize int // 每页显示记录数 默认值
|
||
MaxPageSize int // 最大分页大小限制
|
||
MinSize int // 最小分页大小
|
||
}
|
||
|
||
// NewPageQuery 创建新的PageQuery实例
|
||
func NewPageQuery() *PageQuery {
|
||
return &PageQuery{
|
||
DefaultPageNum: 1,
|
||
DefaultPageSize: 20,
|
||
MaxPageSize: 1000,
|
||
MinSize: 0,
|
||
}
|
||
}
|
||
|
||
// WithParams 设置分页参数
|
||
func (p *PageQuery) WithParams(pageNum, pageSize *int) *PageQuery {
|
||
p.PageNum = pageNum
|
||
p.PageSize = pageSize
|
||
return p
|
||
}
|
||
|
||
// WithOrder 设置排序参数
|
||
func (p *PageQuery) WithOrder(orderByColumn, isAsc string) *PageQuery {
|
||
p.OrderByColumn = orderByColumn
|
||
p.IsAsc = isAsc
|
||
return p
|
||
}
|
||
|
||
// BuildSQL 构建分页查询的SQL片段
|
||
// 返回: ORDER BY 子句, LIMIT OFFSET 子句, 参数值, 错误
|
||
func (p *PageQuery) BuildSQL() (orderBy string, limitOffset string, args []interface{}, err error) {
|
||
// 设置分页参数
|
||
pageNum := p.getPageNum()
|
||
pageSize := p.getPageSize()
|
||
|
||
// 验证分页参数
|
||
if pageNum <= 0 {
|
||
pageNum = p.DefaultPageNum
|
||
}
|
||
if pageSize <= 0 {
|
||
pageSize = p.DefaultPageSize
|
||
}
|
||
// 限制最大分页大小
|
||
if pageSize > p.MaxPageSize {
|
||
return "", "", nil, errors.New(fmt.Sprintf("分页大小不能超过 %d", p.MaxPageSize))
|
||
}
|
||
|
||
// 计算偏移量
|
||
offset := (pageNum - 1) * pageSize
|
||
|
||
// 构建排序条件
|
||
if p.OrderByColumn != "" && p.IsAsc != "" {
|
||
orderBy, err = p.buildOrder()
|
||
if err != nil {
|
||
return "", "", nil, err
|
||
}
|
||
if orderBy != "" {
|
||
orderBy = "ORDER BY " + orderBy
|
||
}
|
||
}
|
||
|
||
// 构建分页条件
|
||
if pageSize > 0 {
|
||
limitOffset = "LIMIT ? OFFSET ?"
|
||
args = []interface{}{pageSize, offset}
|
||
}
|
||
|
||
return orderBy, limitOffset, args, nil
|
||
}
|
||
|
||
// BuildCountSQL 构建统计总数的SQL
|
||
func (p *PageQuery) BuildCountSQL(baseSQL string) (string, []interface{}) {
|
||
// 将SELECT语句转换为COUNT语句
|
||
lowerSQL := strings.ToLower(baseSQL)
|
||
|
||
// 移除ORDER BY子句(对于COUNT不需要排序)
|
||
if idx := strings.Index(lowerSQL, "order by"); idx != -1 {
|
||
baseSQL = baseSQL[:idx]
|
||
}
|
||
|
||
// 将SELECT ... FROM 替换为 COUNT(*)
|
||
selectIdx := strings.Index(lowerSQL, "select")
|
||
fromIdx := strings.Index(lowerSQL, "from")
|
||
|
||
if selectIdx != -1 && fromIdx != -1 {
|
||
countSQL := fmt.Sprintf("SELECT COUNT(*) %s", baseSQL[fromIdx:])
|
||
return countSQL, nil
|
||
}
|
||
|
||
// 如果无法处理,返回原样
|
||
return baseSQL, nil
|
||
}
|
||
|
||
// BuildPage 构建分页对象(返回通用的分页结构)
|
||
func (p *PageQuery) BuildPage() (*PageParam, error) {
|
||
// 设置合理的默认值
|
||
if p.DefaultPageNum <= 0 {
|
||
p.DefaultPageNum = 1
|
||
}
|
||
if p.DefaultPageSize <= 0 {
|
||
p.DefaultPageSize = 20
|
||
}
|
||
if p.MaxPageSize <= 0 {
|
||
p.MaxPageSize = 1000 // 设置默认最大值
|
||
}
|
||
|
||
pageNum := p.getPageNum()
|
||
pageSize := p.getPageSize()
|
||
|
||
// 验证
|
||
if pageNum <= 0 {
|
||
pageNum = p.DefaultPageNum
|
||
}
|
||
if pageSize <= 0 {
|
||
pageSize = p.DefaultPageSize
|
||
}
|
||
|
||
// 添加最小分页大小限制
|
||
if pageSize < 1 {
|
||
pageSize = p.DefaultPageSize
|
||
}
|
||
|
||
if pageSize > p.MaxPageSize {
|
||
// 提供更友好的错误信息
|
||
return nil, fmt.Errorf("分页大小 %d 超过最大限制 %d", pageSize, p.MaxPageSize)
|
||
}
|
||
|
||
offset := (pageNum - 1) * pageSize
|
||
|
||
return &PageParam{
|
||
PageNum: pageNum,
|
||
PageSize: pageSize,
|
||
Offset: offset,
|
||
OrderBy: p.OrderByColumn,
|
||
IsAsc: p.IsAsc,
|
||
}, nil
|
||
}
|
||
|
||
// GetFirstNum 获取起始记录索引
|
||
func (p *PageQuery) GetFirstNum() int {
|
||
pageNum := p.getPageNum()
|
||
pageSize := p.getPageSize()
|
||
return (pageNum - 1) * pageSize
|
||
}
|
||
|
||
// GetPageNum 获取页码
|
||
func (p *PageQuery) GetPageNum() int {
|
||
return p.getPageNum()
|
||
}
|
||
|
||
// GetPageSize 获取分页大小
|
||
func (p *PageQuery) GetPageSize() int {
|
||
return p.getPageSize()
|
||
}
|
||
|
||
// getPageNum 获取页码(内部方法)
|
||
func (p *PageQuery) getPageNum() int {
|
||
if p.PageNum == nil || *p.PageNum <= 0 {
|
||
return p.DefaultPageNum
|
||
}
|
||
return *p.PageNum
|
||
}
|
||
|
||
// getPageSize 获取分页大小(内部方法)
|
||
func (p *PageQuery) getPageSize() int {
|
||
if p.PageSize == nil || *p.PageSize <= 0 {
|
||
return p.DefaultPageSize
|
||
}
|
||
return *p.PageSize
|
||
}
|
||
|
||
// buildOrder 构建排序条件
|
||
func (p *PageQuery) buildOrder() (string, error) {
|
||
orderBy := p.escapeOrderBySql(p.OrderByColumn)
|
||
orderBy = p.toUnderScoreCase(orderBy)
|
||
|
||
// 兼容前端排序类型
|
||
p.IsAsc = strings.ReplaceAll(p.IsAsc, "ascending", "asc")
|
||
p.IsAsc = strings.ReplaceAll(p.IsAsc, "descending", "desc")
|
||
|
||
// 分割排序字段和排序方式
|
||
orderByArr := strings.Split(orderBy, ",")
|
||
isAscArr := strings.Split(p.IsAsc, ",")
|
||
|
||
// 验证参数
|
||
if len(isAscArr) != 1 && len(isAscArr) != len(orderByArr) {
|
||
return "", errors.New("排序参数有误")
|
||
}
|
||
|
||
// 构建排序SQL
|
||
var orderList []string
|
||
for i, orderByStr := range orderByArr {
|
||
var isAscStr string
|
||
if len(isAscArr) == 1 {
|
||
isAscStr = isAscArr[0]
|
||
} else {
|
||
isAscStr = isAscArr[i]
|
||
}
|
||
|
||
// 验证排序方向
|
||
if strings.ToLower(isAscStr) == "asc" {
|
||
orderList = append(orderList, orderByStr+" ASC")
|
||
} else if strings.ToLower(isAscStr) == "desc" {
|
||
orderList = append(orderList, orderByStr+" DESC")
|
||
} else {
|
||
return "", errors.New("排序参数有误")
|
||
}
|
||
}
|
||
|
||
return strings.Join(orderList, ", "), nil
|
||
}
|
||
|
||
// escapeOrderBySql 转义排序SQL防止SQL注入
|
||
func (p *PageQuery) escapeOrderBySql(orderBy string) string {
|
||
// 移除危险的SQL关键字
|
||
dangerousKeywords := []string{
|
||
"select", "insert", "update", "delete", "drop", "truncate",
|
||
"union", "join", "or", "and", "--", "#", "/*", "*/",
|
||
";", "'", "\"", "`",
|
||
}
|
||
|
||
orderBy = strings.ToLower(orderBy)
|
||
for _, keyword := range dangerousKeywords {
|
||
orderBy = strings.ReplaceAll(orderBy, keyword, "")
|
||
}
|
||
|
||
// 只允许字母、数字、下划线、点、逗号和空格
|
||
var safeStr strings.Builder
|
||
for _, r := range orderBy {
|
||
if unicode.IsLetter(r) || unicode.IsDigit(r) ||
|
||
r == '_' || r == '.' || r == ',' || r == ' ' {
|
||
safeStr.WriteRune(r)
|
||
}
|
||
}
|
||
|
||
return strings.TrimSpace(safeStr.String())
|
||
}
|
||
|
||
// toUnderScoreCase 驼峰转下划线
|
||
func (p *PageQuery) toUnderScoreCase(str string) string {
|
||
if str == "" {
|
||
return ""
|
||
}
|
||
|
||
var result strings.Builder
|
||
runes := []rune(str)
|
||
|
||
for i, r := range runes {
|
||
if unicode.IsUpper(r) {
|
||
if i > 0 && unicode.IsLower(runes[i-1]) {
|
||
result.WriteRune('_')
|
||
}
|
||
result.WriteRune(unicode.ToLower(r))
|
||
} else {
|
||
result.WriteRune(r)
|
||
}
|
||
}
|
||
|
||
return result.String()
|
||
}
|
||
|
||
// PageParam 分页参数结构
|
||
type PageParam struct {
|
||
PageNum int `json:"pageNum"`
|
||
PageSize int `json:"pageSize"`
|
||
Offset int `json:"offset"`
|
||
OrderBy string `json:"orderBy"`
|
||
IsAsc string `json:"isAsc"`
|
||
}
|
||
|
||
// PageResult 分页结果结构
|
||
type PageResult[T any] struct {
|
||
PageNum int `json:"pageNum"`
|
||
PageSize int `json:"pageSize"`
|
||
Total int64 `json:"total"`
|
||
TotalPage int `json:"totalPage"`
|
||
List []T `json:"list"`
|
||
}
|
||
|
||
// BuildPageResult 构建分页结果
|
||
func BuildPageResult[T any](list []T, total int64, pageNum, pageSize int) *PageResult[T] {
|
||
totalPage := 0
|
||
if pageSize > 0 {
|
||
totalPage = int((total + int64(pageSize) - 1) / int64(pageSize))
|
||
}
|
||
|
||
return &PageResult[T]{
|
||
PageNum: pageNum,
|
||
PageSize: pageSize,
|
||
Total: total,
|
||
TotalPage: totalPage,
|
||
List: list,
|
||
}
|
||
}
|
||
|
||
// QueryPage 执行分页查询(通用方法)
|
||
func QueryPage[T any](db *sql.DB, baseQuery string, args []interface{}, pageQuery *PageQuery) (*PageResult[T], error) {
|
||
// 构建分页SQL
|
||
orderBy, limitOffset, pageArgs, err := pageQuery.BuildSQL()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 构建完整查询SQL
|
||
finalSQL := baseQuery
|
||
if orderBy != "" {
|
||
finalSQL += " " + orderBy
|
||
}
|
||
if limitOffset != "" {
|
||
finalSQL += " " + limitOffset
|
||
}
|
||
|
||
// 合并参数
|
||
finalArgs := args
|
||
if pageArgs != nil {
|
||
finalArgs = append(finalArgs, pageArgs...)
|
||
}
|
||
|
||
// 执行查询
|
||
rows, err := db.Query(finalSQL, finalArgs...)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
// 解析结果集
|
||
var list []T
|
||
for rows.Next() {
|
||
var item T
|
||
// 注意:这里需要根据实际结构使用rows.Scan
|
||
// 这里是一个通用占位,实际使用需要具体实现
|
||
list = append(list, item)
|
||
}
|
||
|
||
// 查询总数
|
||
countSQL, countArgs := pageQuery.BuildCountSQL(baseQuery)
|
||
var total int64
|
||
if countArgs == nil {
|
||
countArgs = args
|
||
}
|
||
err = db.QueryRow(countSQL, countArgs...).Scan(&total)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 构建分页结果
|
||
pageNum := pageQuery.GetPageNum()
|
||
pageSize := pageQuery.GetPageSize()
|
||
return BuildPageResult(list, total, pageNum, pageSize), nil
|
||
}
|
||
|
||
// 使用示例
|
||
func ExampleUsage() {
|
||
// 1. 创建PageQuery
|
||
pageQuery := NewPageQuery()
|
||
pageSize := 10
|
||
pageNum := 1
|
||
pageQuery.WithParams(&pageNum, &pageSize).
|
||
WithOrder("createTime,id", "desc,asc")
|
||
|
||
// 2. 获取数据库连接
|
||
var db *sql.DB // 假设已初始化
|
||
|
||
// 3. 构建基础查询SQL
|
||
baseSQL := "SELECT id, name, price, create_time FROM goods WHERE status = ?"
|
||
args := []interface{}{1}
|
||
|
||
// 4. 获取分页SQL片段
|
||
orderBy, limitOffset, pageArgs, err := pageQuery.BuildSQL()
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
|
||
// 5. 构建完整SQL
|
||
finalSQL := baseSQL
|
||
if orderBy != "" {
|
||
finalSQL += " " + orderBy
|
||
}
|
||
if limitOffset != "" {
|
||
finalSQL += " " + limitOffset
|
||
}
|
||
|
||
// 6. 合并参数
|
||
finalArgs := append(args, pageArgs...)
|
||
|
||
// 7. 执行查询
|
||
rows, err := db.Query(finalSQL, finalArgs...)
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
// 8. 查询总数
|
||
countSQL, countArgs := pageQuery.BuildCountSQL(baseSQL)
|
||
if countArgs == nil {
|
||
countArgs = args
|
||
}
|
||
|
||
var total int64
|
||
err = db.QueryRow(countSQL, countArgs...).Scan(&total)
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
|
||
// 9. 解析结果
|
||
var goods []Goods
|
||
for rows.Next() {
|
||
var g Goods
|
||
err := rows.Scan(&g.ID, &g.Name, &g.Price, &g.CreateTime)
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
goods = append(goods, g)
|
||
}
|
||
|
||
// 10. 构建分页结果
|
||
result := BuildPageResult(goods, total, pageNum, pageSize)
|
||
_ = result
|
||
}
|
||
|
||
// Goods 示例结构体
|
||
type Goods struct {
|
||
ID int64 `db:"id"`
|
||
Name string `db:"name"`
|
||
Price float64 `db:"price"`
|
||
CreateTime string `db:"create_time"`
|
||
}
|