强曰为道
与天地相似,故不违。知周乎万物,而道济天下,故不过。旁行而不流,乐天知命,故不忧.
文档目录

Go 语言完全指南 / 13 - 泛型:类型参数、约束、泛型函数、泛型数据结构

13 - 泛型(Generics)

13.1 泛型简介

Go 1.18 引入了泛型,允许编写适用于多种类型的通用代码。

// Go 1.18 之前:需要为每种类型写重复代码
func sumInts(nums []int) int {
    total := 0
    for _, n := range nums {
        total += n
    }
    return total
}

func sumFloats(nums []float64) float64 {
    total := 0.0
    for _, n := range nums {
        total += n
    }
    return total
}

// Go 1.18+:用泛型统一处理
func sum[T Number](nums []T) T {
    var total T
    for _, n := range nums {
        total += n
    }
    return total
}

13.2 类型参数和约束

package main

import (
    "fmt"
)

// 内置约束
// any           - 任意类型
// comparable    - 可比较类型(支持 == 和 !=)
// ~int          - 底层类型为 int 的所有类型
// int | string  - int 或 string

// 自定义约束
type Number interface {
    ~int | ~int8 | ~int16 | ~int32 | ~int64 |
    ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
    ~float32 | ~float64
}

type Signed interface {
    ~int | ~int8 | ~int16 | ~int32 | ~int64
}

type Unsigned interface {
    ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64
}

// 约束中嵌入接口
type Ordered interface {
    Number | ~string
}

func main() {
    fmt.Println(sum([]int{1, 2, 3, 4, 5}))           // 15
    fmt.Println(sum([]float64{1.1, 2.2, 3.3}))       // 6.6
    fmt.Println(sum([]int64{100, 200, 300}))          // 600
}

13.3 泛型函数

package main

import (
    "fmt"
    "strings"
)

// 基本泛型函数
func max[T Ordered](a, b T) T {
    if a > b {
        return a
    }
    return b
}

func min[T Ordered](a, b T) T {
    if a < b {
        return a
    }
    return b
}

// 多类型参数
func contains[T comparable](slice []T, target T) bool {
    for _, v := range slice {
        if v == target {
            return true
        }
    }
    return false
}

// Map 函数
func Map[T, U any](slice []T, fn func(T) U) []U {
    result := make([]U, len(slice))
    for i, v := range slice {
        result[i] = fn(v)
    }
    return result
}

// Filter 函数
func Filter[T any](slice []T, predicate func(T) bool) []T {
    var result []T
    for _, v := range slice {
        if predicate(v) {
            result = append(result, v)
        }
    }
    return result
}

// Reduce 函数
func Reduce[T, U any](slice []T, initial U, fn func(U, T) U) U {
    result := initial
    for _, v := range slice {
        result = fn(result, v)
    }
    return result
}

func main() {
    fmt.Println(max(3, 5))          // 5
    fmt.Println(max(3.14, 2.71))    // 3.14
    fmt.Println(max("apple", "banana")) // banana

    nums := []int{1, 2, 3, 4, 5}
    fmt.Println(contains(nums, 3))   // true
    fmt.Println(contains(nums, 6))   // false

    // Map
    doubled := Map(nums, func(n int) int { return n * 2 })
    fmt.Println(doubled) // [2 4 6 8 10]

    strings := Map(nums, func(n int) string { return fmt.Sprintf("#%d", n) })
    fmt.Println(strings) // [#1 #2 #3 #4 #5]

    // Filter
    evens := Filter(nums, func(n int) bool { return n%2 == 0 })
    fmt.Println(evens) // [2 4]

    // Reduce
    sum := Reduce(nums, 0, func(acc, n int) int { return acc + n })
    fmt.Println(sum) // 15
}

13.4 泛型类型

package main

import "fmt"

// 泛型栈
type Stack[T any] struct {
    items []T
}

func NewStack[T any]() *Stack[T] {
    return &Stack[T]{}
}

func (s *Stack[T]) Push(item T) {
    s.items = append(s.items, item)
}

func (s *Stack[T]) Pop() (T, bool) {
    var zero T
    if len(s.items) == 0 {
        return zero, false
    }
    item := s.items[len(s.items)-1]
    s.items = s.items[:len(s.items)-1]
    return item, true
}

func (s *Stack[T]) Peek() (T, bool) {
    var zero T
    if len(s.items) == 0 {
        return zero, false
    }
    return s.items[len(s.items)-1], true
}

func (s *Stack[T]) Len() int {
    return len(s.items)
}

// 泛型队列
type Queue[T any] struct {
    items []T
}

func (q *Queue[T]) Enqueue(item T) {
    q.items = append(q.items, item)
}

func (q *Queue[T]) Dequeue() (T, bool) {
    var zero T
    if len(q.items) == 0 {
        return zero, false
    }
    item := q.items[0]
    q.items = q.items[1:]
    return item, true
}

func main() {
    // 字符串栈
    stack := NewStack[string]()
    stack.Push("a")
    stack.Push("b")
    stack.Push("c")
    
    for v, ok := stack.Pop(); ok; v, ok = stack.Pop() {
        fmt.Println(v) // c, b, a
    }

    // 整数队列
    queue := &Queue[int]{}
    queue.Enqueue(1)
    queue.Enqueue(2)
    queue.Enqueue(3)
    
    for v, ok := queue.Dequeue(); ok; v, ok = queue.Dequeue() {
        fmt.Println(v) // 1, 2, 3
    }
}

13.5 泛型接口和约束

package main

import (
    "fmt"
    "sort"
)

// 容器约束
type Container[T any] interface {
    Add(item T)
    Get(index int) (T, bool)
    Len() int
}

// 可排序约束
type Sortable[T any] interface {
    Len() int
    Less(i, j int) bool
    Swap(i, j int)
}

// 泛型排序
func Sort[T Ordered](slice []T) {
    sort.Slice(slice, func(i, j int) bool {
        return slice[i] < slice[j]
    })
}

// 泛型集合
type Set[T comparable] struct {
    items map[T]struct{}
}

func NewSet[T comparable](items ...T) *Set[T] {
    s := &Set[T]{items: make(map[T]struct{})}
    for _, item := range items {
        s.Add(item)
    }
    return s
}

func (s *Set[T]) Add(item T) {
    s.items[item] = struct{}{}
}

func (s *Set[T]) Remove(item T) {
    delete(s.items, item)
}

func (s *Set[T]) Contains(item T) bool {
    _, ok := s.items[item]
    return ok
}

func (s *Set[T]) Len() int {
    return len(s.items)
}

func (s *Set[T]) Slice() []T {
    result := make([]T, 0, len(s.items))
    for item := range s.items {
        result = append(result, item)
    }
    return result
}

// 集合运算
func Union[T comparable](a, b *Set[T]) *Set[T] {
    result := NewSet[T]()
    for item := range a.items {
        result.Add(item)
    }
    for item := range b.items {
        result.Add(item)
    }
    return result
}

func Intersect[T comparable](a, b *Set[T]) *Set[T] {
    result := NewSet[T]()
    for item := range a.items {
        if b.Contains(item) {
            result.Add(item)
        }
    }
    return result
}

func main() {
    s1 := NewSet(1, 2, 3, 4)
    s2 := NewSet(3, 4, 5, 6)
    
    union := Union(s1, s2)
    fmt.Println("并集:", union.Slice())
    
    intersect := Intersect(s1, s2)
    fmt.Println("交集:", intersect.Slice())

    // 字符串集合
    tags := NewSet("go", "rust", "python")
    fmt.Println("包含 go:", tags.Contains("go"))
    
    nums := []int{5, 3, 1, 4, 2}
    Sort(nums)
    fmt.Println("排序:", nums) // [1 2 3 4 5]
}

13.6 类型推断

func main() {
    // 编译器可以推断类型参数
    fmt.Println(max(3, 5))       // T 推断为 int
    fmt.Println(max(3.14, 2.71)) // T 推断为 float64
    
    // 显式指定
    fmt.Println(max[int](3, 5))
    fmt.Println(max[float64](3.14, 2.71))
    
    // 无法推断时必须显式指定
    stack := NewStack[int]() // 必须指定
    stack.Push(1)
}

13.7 泛型限制

// ❌ 不支持的特性

// 1. 方法不能有额外的类型参数
// func (s Stack[T]) Map[U any](fn func(T) U) Stack[U] { } // 编译错误

// 2. 不能用类型参数做类型断言
// func foo[T any](x any) T { return x.(T) } // 编译错误

// 3. 不能用类型参数创建复合字面量
// func foo[T any]() []T { return []T{0} } // 如果 T 不是数值类型会出错

// 4. 不能用类型参数做指针操作
// func foo[T any](x T) *T { return &x } // OK,但有限制

// ✅ 工作方案
type SliceFuncs[T any] struct {
    data []T
}

func NewSliceFuncs[T any](data []T) SliceFuncs[T] {
    return SliceFuncs[T]{data: data}
}

func (sf SliceFuncs[T]) Map(fn func(T) T) SliceFuncs[T] {
    result := make([]T, len(sf.data))
    for i, v := range sf.data {
        result[i] = fn(v)
    }
    return SliceFuncs[T]{data: result}
}

13.8 性能考虑

// 泛型在编译时实例化,运行时无额外开销
// 但会导致二进制文件增大(每种类型生成一份代码)

import "testing"

func BenchmarkSumGeneric(b *testing.B) {
    nums := make([]int, 10000)
    for i := range nums {
        nums[i] = i
    }
    b.ResetTimer()
    for i := 0; i < b.N; i++ {
        sum(nums)
    }
}

func BenchmarkSumSpecific(b *testing.B) {
    nums := make([]int, 10000)
    for i := range nums {
        nums[i] = i
    }
    b.ResetTimer()
    for i := 0; i < b.N; i++ {
        sumInts(nums) // 专用函数
    }
}

// 泛型 vs interface{} 的性能对比
func BenchmarkContainsGeneric(b *testing.B) {
    nums := make([]int, 10000)
    for i := range nums { nums[i] = i }
    b.ResetTimer()
    for i := 0; i < b.N; i++ {
        contains(nums, 9999)
    }
}

func BenchmarkContainsInterface(b *testing.B) {
    nums := make([]any, 10000)
    for i := range nums { nums[i] = i }
    b.ResetTimer()
    for i := 0; i < b.N; i++ {
        for _, v := range nums {
            if v.(int) == 9999 { break }
        }
    }
}

🏢 业务场景

  1. 通用数据结构:泛型栈、队列、集合、链表
  2. 工具函数:Map/Filter/Reduce 等函数式操作
  3. API 响应:泛型 Response 包装不同类型数据
  4. 缓存系统:泛型 Cache 支持任意键值类型
  5. 仓储模式:泛型 Repository 接口

📖 扩展阅读