AST(抽象语法树)-代码补丁

文章目录
  1. 1. AST基础
  2. 2. 一个案例
    1. 2.1. 转换代码
    2. 2.2. 结果

Pandas1.x 弃用了 Pandas 0.25.x 中的某些函数。随着库的升级,这些代码都要修改。但是公司的代码有上百万行!怎么应付?

ast在某些场合中可以是开发的得力工具,就比如在代码移植及代码质量评估中。在python2.x 到 python3.x的代码转换中,亦或是Pandas 0.25.xPandas 1.x的代码转化。Pandas1.x 弃用了 Pandas 0.25.x 中的某些函数,因此在代码转换过程中需要修改这部分的函数调用方式。ast使得我们可以忽视代码中的注释、空格等信息,而直接关注于代码本身。

AST基础

Python 有库ast支持生成代码的AST:

1
2
3
4
5
6
7
8
9
10
import ast

code = """
a = 1
print(a)
"""

head = ast.parse(code)
print(head)
# <_ast.Module object at 0x00xxxx>

ast库提供了一个dump方法返回以节点为根的整颗树(格式化之后)

1
2
3
4
5
print(ast.dump(head))
## outputs(上述代码输出如下):
Module(
body=[Assign(targets=[Name(id='a')], value=Num(n=1)),
Expr(value=Call(func=Name(id='print'), args=[Name(id='a')], keywords=[]))])

可以看到head节点是Module类型,它有一个属性body,其值为一个包含两个节点的list。一个代表a = 2 ,另一个代表 print(a)。第一个节点有一个targets属性表示左侧(LHS)的a 以及一个value属性代表右侧(RHS)的1。

试试将右侧的value的n属性替换为2

1
2
3
4
5
6
head.body[0].value.n = 2
print(ast.dump(head))
## outputs(上述代码输出如下):
Module(
body=[Assign(targets=[Name(id='a')], value=Num(n=2)),
Expr(value=Call(func=Name(id='print'), args=[Name(id='a')], keywords=[]))])

可以看到其值被更改为了2,现在将AST转换回代码

1
2
3
4
5
6
import astunparse

print(astunparse.unparse(head))
## outputs(上述代码输出如下):
a = 2
print(a)

可以看到代码中的1被替换为了2.

一个案例

Pandas1.0.0中,多维索引MultiIndexname属性不再支持以=的方式更新,改为使用index.set_names()进行更新。

1
2
3
4
5
6
7
8
9
10
11
12
#pandas=0.25.x
import pandas as pd

mi = pd.MultiIndex.from_product([[1, 2], ['a', 'b']], names=['x', 'y'])
print(mi.levels[0].name)

mi.levels[0].name = "new name"
print(mi.levels[0].name)

## outputs(上述代码输出如下):
x
new name

上面的第7行代码在Pandas=1.0.0中应该改为

1
mi = mi.set_names("new name", level=0)

该怎么使用程序完成mi.levels[0].name = "new name"mi = mi.set_names("new name", level=0)的转换呢?

借助AST可以达到目的。转换算法如下

  1. 构造源代码的AST,并且遍历树
  2. 识别该节点是否代表如下形式的代码:<var>.levels[<idx>].name=<val>.
  3. 如果代表,则对该节点进行替换,替换为如下形式:<var>=<var>.set_names(<val>, level=<idx>).

转换代码

首先来看看源代码和目标代码的AST形式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# src code
mi.levels[0].name = "new name"

Module(
body=[
Assign(
targets=[
Attribute(
value=Subscript(value=Attribute(value=Name(id='mi'), attr='levels'), slice=Index(value=Num(n=0))),
attr='name')],
value=Str(s='new name'))])

# dst code
mi = mi.set_names("new name", level=0)

Module(
body=[
Assign(targets=[Name(id='mi')],
value=Call(func=Attribute(value=Name(id='mi'), attr='set_names'),
args=[Str(s='new name')],
keywords=[keyword(arg='level', value=Num(n=0))]))])

经过前述分析,结合转换算法,有如下的转换代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import ast

def is_multi_index_rename_node(node):
"""
Checks if the given node represents the code: <var>.levels[<idx>].name = <val>
and returns the corresponding var, idx and val if it does.
"""
try:
if (
isinstance(node, ast.Assign)
and node.targets[0].attr == "name"
and node.targets[0].value.value.attr == "levels"
):
var = node.targets[0].value.value.value.id
idx = node.targets[0].value.slice.value.n
val = node.value

return True, var, idx, val
except:
pass
return False, None, None, None

def get_new_multi_index_rename_node(var, idx, val):
"""
Returns AST node that represents the code: <var> = <var>.set_names(<val>, level=<idx>)
for the given var, idx and val.
"""
return ast.Assign(
targets=[ast.Name(id=var)],
value=ast.Call(
func=ast.Attribute(value=ast.Name(id=var), attr="set_names"),
args=[val],
keywords=[ast.keyword(arg="level", value=ast.Num(n=idx))],
),
)

def patch(node):
"""
Takes an AST rooted at the give node and patches it.
"""
# If it is a leaf node, then no patching needed.
if not hasattr(node, "_fields"):
return node

# For every child of the node, modify it if needed and recursively call patch on it.
for (name, field) in ast.iter_fields(node):
if isinstance(field, list):
for i in range(len(field)):
check, var, idx, val = is_multi_index_rename_node(field[i])
if check:
field[i] = get_new_multi_index_rename_node(var, idx, val)
else:
patch(field[i])
else:
check, var, idx, val = is_multi_index_rename_node(field)
if check:
setattr(node, name, get_new_multi_index_rename_node(var, idx, val))
else:
patch(field)

结果

1
2
3
4
5
6
7
8
9
10
11
print(f"previous code: {prev_code}")
prev_ast = ast.parse(prev_code)
patch(prev_ast)
print("after transformation: ", astunparse.unparse(prev_ast))

## outputs(上述代码输出如下):
previous code:
mi.levels[0].name = "new name"

after transformation: 
mi = mi.set_names('new name', level=0)

可以看到,借由AST我们可以成功地将之前版本的代码

mi.levels[0].name = "new name"

转换到了最新版本的代码

mi = mi.set_names('new name', level=0)