package utils import ( "database/sql" "errors" "fmt" "strings" "unicode" ) // PageQuery 分页查询实体类 type PageQuery struct { PageSize *int `json:"pageSize" form:"pageSize"` // 分页大小 PageNum *int `json:"pageNum" form:"pageNum"` // 当前页数 OrderByColumn string `json:"orderByColumn" form:"orderByColumn"` // 排序列 IsAsc string `json:"isAsc" form:"isAsc"` // 排序的方向desc或者asc // 常量定义 DefaultPageNum int // 当前记录起始索引 默认值 DefaultPageSize int // 每页显示记录数 默认值 MaxPageSize int // 最大分页大小限制 MinSize int // 最小分页大小 } // NewPageQuery 创建新的PageQuery实例 func NewPageQuery() *PageQuery { return &PageQuery{ DefaultPageNum: 1, DefaultPageSize: 20, MaxPageSize: 1000, MinSize: 0, } } // WithParams 设置分页参数 func (p *PageQuery) WithParams(pageNum, pageSize *int) *PageQuery { p.PageNum = pageNum p.PageSize = pageSize return p } // WithOrder 设置排序参数 func (p *PageQuery) WithOrder(orderByColumn, isAsc string) *PageQuery { p.OrderByColumn = orderByColumn p.IsAsc = isAsc return p } // BuildSQL 构建分页查询的SQL片段 // 返回: ORDER BY 子句, LIMIT OFFSET 子句, 参数值, 错误 func (p *PageQuery) BuildSQL() (orderBy string, limitOffset string, args []interface{}, err error) { // 设置分页参数 pageNum := p.getPageNum() pageSize := p.getPageSize() // 验证分页参数 if pageNum <= 0 { pageNum = p.DefaultPageNum } if pageSize <= 0 { pageSize = p.DefaultPageSize } // 限制最大分页大小 if pageSize > p.MaxPageSize { return "", "", nil, errors.New(fmt.Sprintf("分页大小不能超过 %d", p.MaxPageSize)) } // 计算偏移量 offset := (pageNum - 1) * pageSize // 构建排序条件 if p.OrderByColumn != "" && p.IsAsc != "" { orderBy, err = p.buildOrder() if err != nil { return "", "", nil, err } if orderBy != "" { orderBy = "ORDER BY " + orderBy } } // 构建分页条件 if pageSize > 0 { limitOffset = "LIMIT ? OFFSET ?" args = []interface{}{pageSize, offset} } return orderBy, limitOffset, args, nil } // BuildCountSQL 构建统计总数的SQL func (p *PageQuery) BuildCountSQL(baseSQL string) (string, []interface{}) { // 将SELECT语句转换为COUNT语句 lowerSQL := strings.ToLower(baseSQL) // 移除ORDER BY子句(对于COUNT不需要排序) if idx := strings.Index(lowerSQL, "order by"); idx != -1 { baseSQL = baseSQL[:idx] } // 将SELECT ... FROM 替换为 COUNT(*) selectIdx := strings.Index(lowerSQL, "select") fromIdx := strings.Index(lowerSQL, "from") if selectIdx != -1 && fromIdx != -1 { countSQL := fmt.Sprintf("SELECT COUNT(*) %s", baseSQL[fromIdx:]) return countSQL, nil } // 如果无法处理,返回原样 return baseSQL, nil } // BuildPage 构建分页对象(返回通用的分页结构) func (p *PageQuery) BuildPage() (*PageParam, error) { // 设置合理的默认值 if p.DefaultPageNum <= 0 { p.DefaultPageNum = 1 } if p.DefaultPageSize <= 0 { p.DefaultPageSize = 20 } if p.MaxPageSize <= 0 { p.MaxPageSize = 1000 // 设置默认最大值 } pageNum := p.getPageNum() pageSize := p.getPageSize() // 验证 if pageNum <= 0 { pageNum = p.DefaultPageNum } if pageSize <= 0 { pageSize = p.DefaultPageSize } // 添加最小分页大小限制 if pageSize < 1 { pageSize = p.DefaultPageSize } if pageSize > p.MaxPageSize { // 提供更友好的错误信息 return nil, fmt.Errorf("分页大小 %d 超过最大限制 %d", pageSize, p.MaxPageSize) } offset := (pageNum - 1) * pageSize return &PageParam{ PageNum: pageNum, PageSize: pageSize, Offset: offset, OrderBy: p.OrderByColumn, IsAsc: p.IsAsc, }, nil } // GetFirstNum 获取起始记录索引 func (p *PageQuery) GetFirstNum() int { pageNum := p.getPageNum() pageSize := p.getPageSize() return (pageNum - 1) * pageSize } // GetPageNum 获取页码 func (p *PageQuery) GetPageNum() int { return p.getPageNum() } // GetPageSize 获取分页大小 func (p *PageQuery) GetPageSize() int { return p.getPageSize() } // getPageNum 获取页码(内部方法) func (p *PageQuery) getPageNum() int { if p.PageNum == nil || *p.PageNum <= 0 { return p.DefaultPageNum } return *p.PageNum } // getPageSize 获取分页大小(内部方法) func (p *PageQuery) getPageSize() int { if p.PageSize == nil || *p.PageSize <= 0 { return p.DefaultPageSize } return *p.PageSize } // buildOrder 构建排序条件 func (p *PageQuery) buildOrder() (string, error) { orderBy := p.escapeOrderBySql(p.OrderByColumn) orderBy = p.toUnderScoreCase(orderBy) // 兼容前端排序类型 p.IsAsc = strings.ReplaceAll(p.IsAsc, "ascending", "asc") p.IsAsc = strings.ReplaceAll(p.IsAsc, "descending", "desc") // 分割排序字段和排序方式 orderByArr := strings.Split(orderBy, ",") isAscArr := strings.Split(p.IsAsc, ",") // 验证参数 if len(isAscArr) != 1 && len(isAscArr) != len(orderByArr) { return "", errors.New("排序参数有误") } // 构建排序SQL var orderList []string for i, orderByStr := range orderByArr { var isAscStr string if len(isAscArr) == 1 { isAscStr = isAscArr[0] } else { isAscStr = isAscArr[i] } // 验证排序方向 if strings.ToLower(isAscStr) == "asc" { orderList = append(orderList, orderByStr+" ASC") } else if strings.ToLower(isAscStr) == "desc" { orderList = append(orderList, orderByStr+" DESC") } else { return "", errors.New("排序参数有误") } } return strings.Join(orderList, ", "), nil } // escapeOrderBySql 转义排序SQL防止SQL注入 func (p *PageQuery) escapeOrderBySql(orderBy string) string { // 移除危险的SQL关键字 dangerousKeywords := []string{ "select", "insert", "update", "delete", "drop", "truncate", "union", "join", "or", "and", "--", "#", "/*", "*/", ";", "'", "\"", "`", } orderBy = strings.ToLower(orderBy) for _, keyword := range dangerousKeywords { orderBy = strings.ReplaceAll(orderBy, keyword, "") } // 只允许字母、数字、下划线、点、逗号和空格 var safeStr strings.Builder for _, r := range orderBy { if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '.' || r == ',' || r == ' ' { safeStr.WriteRune(r) } } return strings.TrimSpace(safeStr.String()) } // toUnderScoreCase 驼峰转下划线 func (p *PageQuery) toUnderScoreCase(str string) string { if str == "" { return "" } var result strings.Builder runes := []rune(str) for i, r := range runes { if unicode.IsUpper(r) { if i > 0 && unicode.IsLower(runes[i-1]) { result.WriteRune('_') } result.WriteRune(unicode.ToLower(r)) } else { result.WriteRune(r) } } return result.String() } // PageParam 分页参数结构 type PageParam struct { PageNum int `json:"pageNum"` PageSize int `json:"pageSize"` Offset int `json:"offset"` OrderBy string `json:"orderBy"` IsAsc string `json:"isAsc"` } // PageResult 分页结果结构 type PageResult[T any] struct { PageNum int `json:"pageNum"` PageSize int `json:"pageSize"` Total int64 `json:"total"` TotalPage int `json:"totalPage"` List []T `json:"list"` } // BuildPageResult 构建分页结果 func BuildPageResult[T any](list []T, total int64, pageNum, pageSize int) *PageResult[T] { totalPage := 0 if pageSize > 0 { totalPage = int((total + int64(pageSize) - 1) / int64(pageSize)) } return &PageResult[T]{ PageNum: pageNum, PageSize: pageSize, Total: total, TotalPage: totalPage, List: list, } } // QueryPage 执行分页查询(通用方法) func QueryPage[T any](db *sql.DB, baseQuery string, args []interface{}, pageQuery *PageQuery) (*PageResult[T], error) { // 构建分页SQL orderBy, limitOffset, pageArgs, err := pageQuery.BuildSQL() if err != nil { return nil, err } // 构建完整查询SQL finalSQL := baseQuery if orderBy != "" { finalSQL += " " + orderBy } if limitOffset != "" { finalSQL += " " + limitOffset } // 合并参数 finalArgs := args if pageArgs != nil { finalArgs = append(finalArgs, pageArgs...) } // 执行查询 rows, err := db.Query(finalSQL, finalArgs...) if err != nil { return nil, err } defer rows.Close() // 解析结果集 var list []T for rows.Next() { var item T // 注意:这里需要根据实际结构使用rows.Scan // 这里是一个通用占位,实际使用需要具体实现 list = append(list, item) } // 查询总数 countSQL, countArgs := pageQuery.BuildCountSQL(baseQuery) var total int64 if countArgs == nil { countArgs = args } err = db.QueryRow(countSQL, countArgs...).Scan(&total) if err != nil { return nil, err } // 构建分页结果 pageNum := pageQuery.GetPageNum() pageSize := pageQuery.GetPageSize() return BuildPageResult(list, total, pageNum, pageSize), nil } // 使用示例 func ExampleUsage() { // 1. 创建PageQuery pageQuery := NewPageQuery() pageSize := 10 pageNum := 1 pageQuery.WithParams(&pageNum, &pageSize). WithOrder("createTime,id", "desc,asc") // 2. 获取数据库连接 var db *sql.DB // 假设已初始化 // 3. 构建基础查询SQL baseSQL := "SELECT id, name, price, create_time FROM goods WHERE status = ?" args := []interface{}{1} // 4. 获取分页SQL片段 orderBy, limitOffset, pageArgs, err := pageQuery.BuildSQL() if err != nil { panic(err) } // 5. 构建完整SQL finalSQL := baseSQL if orderBy != "" { finalSQL += " " + orderBy } if limitOffset != "" { finalSQL += " " + limitOffset } // 6. 合并参数 finalArgs := append(args, pageArgs...) // 7. 执行查询 rows, err := db.Query(finalSQL, finalArgs...) if err != nil { panic(err) } defer rows.Close() // 8. 查询总数 countSQL, countArgs := pageQuery.BuildCountSQL(baseSQL) if countArgs == nil { countArgs = args } var total int64 err = db.QueryRow(countSQL, countArgs...).Scan(&total) if err != nil { panic(err) } // 9. 解析结果 var goods []Goods for rows.Next() { var g Goods err := rows.Scan(&g.ID, &g.Name, &g.Price, &g.CreateTime) if err != nil { panic(err) } goods = append(goods, g) } // 10. 构建分页结果 result := BuildPageResult(goods, total, pageNum, pageSize) _ = result } // Goods 示例结构体 type Goods struct { ID int64 `db:"id"` Name string `db:"name"` Price float64 `db:"price"` CreateTime string `db:"create_time"` }