您现在的位置是:网站首页 > 博客日记 >

python高阶函数模块-functools

作者:YXN-python 阅读量:10 发布日期:2025-09-04

functools 模块是 Python 标准库中一个用于高阶函数的工具模块。

所谓高阶函数,就是指那些操作或返回其他函数的函数。

functools 提供了一系列强大的工具,用于装饰器、缓存、比较、偏函数等功能,能帮助你编写更简洁、高效和可复用的代码。

 

1. @functools.lru_cache - 缓存装饰器

当使用相同的参数再次调用该函数时,它会直接返回缓存的结果,而无需重新计算。

应用场景:适用于计算昂贵、纯函数(输出只由输入决定)的函数,如递归计算(斐波那契数列)、数据查询、API调用等。

主要参数:

  • maxsize: 指定缓存的最大项数。设为 None 表示缓存大小无限制(但可能导致内存问题),默认值为 128
  • typed: 如果设置为 True,则不同类型的参数(如 33.0)会被区别对待并分别缓存。默认为 False

示例:

import functools
import time

# 没有缓存的版本,效率极低
def fib(n):
    if n < 2:
        return n
    return fib(n-1) + fib(n-2)

# 使用 lru_cache 的版本,效率极高
@functools.lru_cache(maxsize=None)
def fib_cached(n):
    if n < 2:
        return n
    return fib_cached(n-1) + fib_cached(n-2)

# 测试
start = time.time()
result = fib(35)
end = time.time()
print(f"没有缓存: {result} 花了 {end - start:.4f}s")

start = time.time()
result = fib_cached(35)
end = time.time()
print(f"带缓存: {result} 花了 {end - start:.4f}s")

# 后续调用相同参数会瞬间返回结果
start = time.time()
result = fib_cached(35)
end = time.time()
print(f"使用缓存: {result} 花了 {end - start:.8f}s") # 时间几乎为0

输出:

没有缓存: 9227465 花了 2.2199s
带缓存: 9227465 花了 0.0000s
使用缓存: 9227465 花了 0.00000000s

 

2. functools.partial - 创建偏函数

用于“冻结”函数的部分参数(和/或关键字参数),从而创建一个具有预设参数的新函数。

这在你需要频繁调用某个函数,但某些参数又总是相同的情况下非常有用。

应用场景:简化 API 调用,固定回调函数的参数,创建更简洁的函数别名。

示例:

import functools

def power(base, exponent):
    return base ** exponent

# 计算平方,固定 exponent=2
square = functools.partial(power, exponent=2)
# 计算立方,固定 exponent=3
cube = functools.partial(power, exponent=3)

print(square(4))  # 4²=16
print(cube(3))    # 3³=27

# 二进制转换(固定 base=2)
int2 = functools.partial(int, base=2)
print(int2('1010'))  # 输出 10

 

3. functools.wraps - 装饰器的得力助手

这是一个装饰器工厂,用于编写自己的装饰器时保留原函数的元数据(如函数名 __name__、文档字符串 __doc__ 等)。

如果不使用 wraps,被装饰后的函数会“丢失”自己的身份,变得难以调试。

示例:对比使用和不使用 wraps 的区别

import functools

# 一个不规范的装饰器(没有使用 @wraps)
def my_decorator_bad(func):
    def wrapper(*args, **kwargs):
        """Wrapper docstring."""
        print("在函数被调用之前发生了一些事情.")
        return func(*args, **kwargs)
    return wrapper

# 一个规范的装饰器(使用了 @wraps)
def my_decorator_good(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        """Wrapper docstring."""
        print("在函数被调用之前发生了一些事情。")
        return func(*args, **kwargs)
    return wrapper

@my_decorator_bad
def say_hello_bad():
    """Says hello badly."""
    print("Hello!")

@my_decorator_good
def say_hello_good():
    """Says hello well."""
    print("Hello!")

print(say_hello_bad.__name__)  # 输出 'wrapper' (错了!)
print(say_hello_bad.__doc__)   # 输出 'Wrapper docstring.' (错了!)

print(say_hello_good.__name__) # 输出 'say_hello_good' (正确!)
print(say_hello_good.__doc__)  # 输出 'Says hello well.' (正确!)

 

4. functools.total_ordering - 简化富比较操作

这是一个类装饰器。它允许你只定义 __lt__(), __le__(), __gt__(), 或 __ge__() 中的一个,以及 __eq__() 方法,它会自动为你填充其余的比较方法。

应用场景:当你定义一个需要支持所有比较操作(<, <=, ==, !=, >, >=)的类时,可以大大减少模板代码。

示例:

import functools

@functools.total_ordering
class Student:
    def __init__(self, name, grade):
        self.name = name
        self.grade = grade

    # 我们只定义 ‘小于‘ 和 ‘等于‘ 的逻辑
    def __eq__(self, other):
        return self.grade == other.grade

    def __lt__(self, other):
        return self.grade < other.grade

# 现在这个类自动拥有了所有比较方法
s1 = Student("Alice", 85)
s2 = Student("Bob", 90)

print(s1 < s2)   # True
print(s1 <= s2)  # True
print(s1 > s2)   # False
print(s1 >= s2)  # False
print(s1 == s2)  # False

 

5. functools.reduce - 累积计算

(注意:在 Python 3 中,reduce() 被移到了 functools 模块中)

将一个可迭代对象中的所有元素通过一个带有两个参数的函数进行累积计算,最终得到一个单一的结果。

工作原理:reduce(function, iterable[, initializer])

  • function: 一个有两个参数的函数。
  • iterable: 可迭代对象。
  • initializer (可选): 初始值。

示例:

import functools
import operator

# 计算列表元素的乘积
numbers = [1, 2, 3, 4, 5]
product = functools.reduce(operator.mul, numbers)  # (1*2*3*4*5)
print(product)  # 输出 120 

# 拼接字符串列表
words = ['Hello', ' ', 'World', '!']
sentence = functools.reduce(lambda x, y: x + y, words)
print(sentence)  # 输出 'Hello World!'

# 使用初始值
result = functools.reduce(operator.add, numbers, 100)  # (100 + 1 + 2 + 3 + 4 + 5)
print(result)  # 输出 115 

 

YXN-python

2025-09-04