Python源码分析6 – 从CST到AST的转化

Introduction

上篇文章解释了Python是如何使用PyParser生成CST的。回顾一下,Python执行代码要经过如下过程:

1. Tokenizer进行词法分析,把源程序分解为Token

2. Parser根据Token创建CST

3. CST被转换为AST

4. AST被编译为字节码

5. 执行字节码

当执行Python代码的时候,以代码存放在文件中的情况为例,Python会调用PyParser_ASTFromFile函数将文件的代码内容转换为AST

mod_ty

PyParser_ASTFromFile(FILE *fp, const char *filename, int start, char *ps1,

char *ps2, PyCompilerFlags *flags, int *errcode,

PyArena *arena)

{

mod_ty mod;

perrdetail err;

node *n = PyParser_ParseFileFlags(fp, filename, &_PyParser_Grammar,

start, ps1, ps2, &err, PARSER_FLAGS(flags));

if (n) {

mod = PyAST_FromNode(n, flags, filename, arena);

PyNode_Free(n);

return mod;

}

else {

err_input(&err);

if (errcode)

*errcode = err.error;

return NULL;

}

}

PyParser_ParseFileFlags把文件转换成CST之后,PyAST_FromNode函数会把CST转换成AST。此函数定义在include\ast.h:

PyAPI_FUNC(mod_ty) PyAST_FromNode(const node *, PyCompilerFlags *flags,

const char *, PyArena *);

在分析此函数之前,我们先来看一下有关AST的一些基本的类型定义。

AST Types

AST所用到的类型均定义在Python_ast.h中,以stmt_ty类型为例:

enum _stmt_kind {FunctionDef_kind=1, ClassDef_kind=2, Return_kind=3,

Delete_kind=4, Assign_kind=5, AugAssign_kind=6, Print_kind=7,

For_kind=8, While_kind=9, If_kind=10, With_kind=11,

Raise_kind=12, TryExcept_kind=13, TryFinally_kind=14,

Assert_kind=15, Import_kind=16, ImportFrom_kind=17,

Exec_kind=18, Global_kind=19, Expr_kind=20, Pass_kind=21,

Break_kind=22, Continue_kind=23};

struct _stmt {

enum _stmt_kind kind;

union {

struct {

identifier name;

arguments_ty args;

asdl_seq *body;

asdl_seq *decorators;

} FunctionDef;

struct {

identifier name;

asdl_seq *bases;

asdl_seq *body;

} ClassDef;

struct {

expr_ty value;

} Return;

// ... 过长,中间从略

struct {

expr_ty value;

} Expr;

} v;

int lineno;

int col_offset;

};

typedef struct _stmt *stmt_ty;

stmt_ty是语句结点类型,实际上是_stmt结构的指针。_stmt结构比较长,但有着很清晰的Pattern

1. 第一个Fieldkind,代表语句的类型。_stmt_kind定义了_stmt的所有可能的语句类型,从函数定义语句,类定义语句直到Continue语句共有23种类型。

2. 接下来是一个union v,每个成员均为一个struct,分别对应_stmt_kind中的一种类型,如_stmt.v.FunctionDef对应了_stmt_kind枚举中的FunctionDef_Kind,也就是说,当_stmt.kind == FunctionDef_Kind时,_stmt.v.FunctionDef中保存的就是对应的函数定义语句的具体内容。

3. 其他数据,如linenocol_offset

大部分AST结点类型均是按照类似的pattern来定义的,不再赘述。除此之外,另外有一种比较简单的AST类型如operator_tyexpr_context_ty等,由于这些类型仍以_ty结尾,因此也可以认为是AST的结点,但实际上,这些类型只是简单的枚举类型,并非指针。因此在以后的文章中,并不把此类AST类型作为结点看待,而是作为简单的枚举处理。

由于每个AST类型会在union中引用其他的AST,这样层层引用,最后便形成了一颗AST树,试举例如下:

Python源码分析6 – 从CST到AST的转化

这颗AST树代表的是单条语句a+1

AST类型对应,在Python_ast.h / .c中定义了大量用于创建AST结点的函数,可以看作是AST结点的构造函数。以BinOp函数为例:

expr_ty

BinOp(expr_ty left, operator_ty op, expr_ty right, int lineno, int col_offset,

PyArena *arena)

{

expr_ty p;

if (!left) {

PyErr_SetString(PyExc_ValueError,

"field left is required for BinOp");

return NULL;

}

if (!op) {

PyErr_SetString(PyExc_ValueError,

"field op is required for BinOp");

return NULL;

}

if (!right) {

PyErr_SetString(PyExc_ValueError,

"field right is required for BinOp");

return NULL;

}

p = (expr_ty)PyArena_Malloc(arena, sizeof(*p));

if (!p) {

PyErr_NoMemory();

return NULL;

}

p->kind = BinOp_kind;

p->v.BinOp.left = left;

p->v.BinOp.op = op;

p->v.BinOp.right = right;

p->lineno = lineno;

p->col_offset = col_offset;

return p;

}

此函数只是根据传入的参数做一些简单的错误检查,分配内存,初始化对应的expr_ty类型,并返回指针。

adsl_seq & adsl_int_seq

在上面的stmt_ty定义中,如果稍微注意的话,可以发现其中大量用到了adsl_seq类型。类似在python_ast.h中其他AST类型中还会用到adsl_int_seq类型。adsl_seq & adsl_int_seq简单来说,是一个动态构造出的定长数组。Adsl_seqvoid *的数组:

typedef struct {

int size;

void *elements[1];

} asdl_seq;

adsl_int_seq则是int类型的数组:

typedef struct {

int size;

int elements[1];

} asdl_int_seq;

Size是数组长度,elements则是数组的元素。注意这些类型在定义elements时使用了一点技巧,定义的elements数组长度为1,而在动态分配内存的时候则是按照实际长度sizeof(adsl_seq) + size - 1来分配:

asdl_seq *

asdl_seq_new(int size, PyArena *arena)

{

asdl_seq *seq = NULL;

size_t n = sizeof(asdl_seq) +

(size ? (sizeof(void *) * (size - 1)) : 0);

seq = (asdl_seq *)PyArena_Malloc(arena, n);

if (!seq) {

PyErr_NoMemory();

return NULL;

}

memset(seq, 0, n);

seq->size = size;

return seq;

}

这样既可以动态分配数组元素,也可以很方便的用elements来访问数组元素。

用如下的宏和函数可以操作adsl_seq / adsl_int_seq :

asdl_seq *asdl_seq_new(int size, PyArena *arena);

asdl_int_seq *asdl_int_seq_new(int size, PyArena *arena);

#define asdl_seq_GET(S, I) (S)->elements[(I)]

#define asdl_seq_LEN(S) ((S) == NULL ? 0 : (S)->size)

#ifdef Py_DEBUG

#define asdl_seq_SET(S, I, V) { \

int _asdl_i = (I); \

assert((S) && _asdl_i < (S)->size); \

(S)->elements[_asdl_i] = (V); \

}

#else

#define asdl_seq_SET(S, I, V) (S)->elements[I] = (V)

#endif

需要说明的是adsl_seq / adsl_int_seq均是从PyArena中分配出,PyArena会在以后的文章中详细分析,目前我们可以暂时把PyArena简单看作一个分配内存用的堆。

From CST to AST

如前所述,PyAST_FromNode负责从CSTAST的转换。简单来说,此函数会深度遍历整棵CST,过滤掉CST中的多余信息,只是将有意义的CST子树转换成AST结点构造出AST树。

PyAst_FromNode函数的大致代码如下:

mod_ty

PyAST_FromNode(const node *n, PyCompilerFlags *flags, const char *filename,

PyArena *arena)

{

...

switch (TYPE(n)) {

case file_input:

stmts = asdl_seq_new(num_stmts(n), arena);

if (!stmts)

return NULL;

for (i = 0; i < NCH(n) - 1; i++) {

ch = CHILD(n, i);

if (TYPE(ch) == NEWLINE)

continue;

REQ(ch, stmt);

num = num_stmts(ch);

if (num == 1) {

s = ast_for_stmt(&c, ch);

if (!s)

goto error;

asdl_seq_SET(stmts, k++, s);

}

else {

ch = CHILD(ch, 0);

REQ(ch, simple_stmt);

for (j = 0; j < num; j++) {

s = ast_for_stmt(&c, CHILD(ch, j * 2));

if (!s)

goto error;

asdl_seq_SET(stmts, k++, s);

}

}

}

return Module(stmts, arena);

case eval_input: {

...

}

case single_input: {

...

}

default:

goto error;

}

可以看到PyAst_FromNode根据N的类型作了不同处理,以file_input为例,file_input的产生式(在Grammar文件中定义)如下:File_input : (NEWLINE | stmt)* ENDMARKER,对应的PyAst_FromNode的代码作了如下事情:

1. 调用num_stmts(n)计算出所有顶层语句的个数,并创建出合适大小的adsl_seq结构以存放这些语句

2. 对于file_input结点的所有子结点作如下处理: file_input: ( NEW_LINE | stmt )* ENDMARKER

a. 忽略掉NEW_LINE,换行无需处理

b. REQ(ch, stmt)断言ch的类型必定为stmt,从产生式可以得出此结论

c. 计算出子结点stmt的语句条数n

i. N == 1,说明stmt对应单条语句,调用ast_for_stmt遍历stmt对应得CST子树,生成对应的AST子树,并调用adsl_seq_SET设置到数组之中。这样AST的根结点mod_ty便可以知道有哪些顶层的语句(stmt),这些语句结点便是根结点mod_ty的子结点。

ii. N > 1,说明stmt对应多条语句。根据Grammar文件中定义的如下产生式可以推知此时ch的子结点必然为simple_stmt

stmt: simple_stmt | compound_stmt

simple_stmt: small_stmt (';' small_stmt)* [';'] NEWLINE

small_stmt: (expr_stmt | print_stmt | del_stmt | pass_stmt | flow_stmt |

import_stmt | global_stmt | exec_stmt | assert_stmt)

compound_stmt: if_stmt | while_stmt | for_stmt | try_stmt | with_stmt | funcdef | classdef

由于simple_stmt的定义中small_stmt’;’总是成对出现,因此index为偶数的CST结点便是所需的单条顶层语句的结点,对于每个这样的结点调用adsl_seq_SET设置到数组之中

3. 最后,调用Module函数从stmts数组生成mod_ty结点,也就是AST的根结点

上面的过程中用到了两个关键函数:num_stmtsast_for_stmt。先来看num_stmts函数:

static int

num_stmts(const node *n)

{

int i, l;

node *ch;

switch (TYPE(n)) {

case single_input:

if (TYPE(CHILD(n, 0)) == NEWLINE)

return 0;

else

return num_stmts(CHILD(n, 0));

case file_input:

l = 0;

for (i = 0; i < NCH(n); i++) {

ch = CHILD(n, i);

if (TYPE(ch) == stmt)

l += num_stmts(ch);

}

return l;

case stmt:

return num_stmts(CHILD(n, 0));

case compound_stmt:

return 1;

case simple_stmt:

return NCH(n) / 2; /* Divide by 2 to remove count of semi-colons */

case suite:

if (NCH(n) == 1)

return num_stmts(CHILD(n, 0));

else {

l = 0;

for (i = 2; i < (NCH(n) - 1); i++)

l += num_stmts(CHILD(n, i));

return l;

}

default: {

char buf[128];

sprintf(buf, "Non-statement found: %d %d\n",

TYPE(n), NCH(n));

Py_FatalError(buf);

}

}

assert(0);

return 0;

}

此函数比较简单,根据结点类型和产生式递归计算顶层语句的个数。所谓顶层语句,也就是把复合语句(compound_stmt)看作单条语句,复合语句中的内部的语句不做计算,当然普通的简单语句(small_stmt) 也是算1条语句。下面根据不同结点类型分析此函数:

1. Single_input

代表单条交互语句,对应的产生式:single_input: NEWLINE | simple_stmt | compound_stmt NEWLINE
如果single_input的第一个子结点为NEW_LINE,说明无语句,返回0,否则说明是simple_stmt或者compound_stmt NEWLINE,可以直接递归调用num_stmts处理

2. File_input

代表整个代码文件,对应的产生式:file_input: (NEWLINE | stmt)* ENDMARKER
只需要反复对每个子结点调用num_stmts既可。

3. Stmt

代表语句,对应的产生式:stmt: simple_stmt | compound_stmt
对第一个子结点调用num_stmts既可。

4. Compound_stmt

代表复合语句,对应的产生式:compound_stmt: if_stmt | while_stmt | for_stmt | try_stmt | with_stmt | funcdef | classdef
compound_stmt只可能有单个子结点,而且必然代表单条顶层的语句,因此无需继续遍历,直接返回1既可。

5. Simple_stmt

代表简单语句(非复合语句)的集合,对应的产生式:simple_stmt: small_stmt (';' small_stmt)* [';'] NEWLINE

可以看到顶层语句数=子结点数/2 (去掉多余的分号和NEWLINE

6. Suite

代表复合语句中的语句块,也就是冒号之后的部分(如:classdef: 'class' NAME ['(' [testlist] ')'] ':' suite),类似于C/C++大括号中的内容,对应的产生式如下:suite: simple_stmt | NEWLINE INDENT stmt+ DEDENT

子结点数为1,说明必然是simple_stmt,可以直接调用num_stmts处理,否则,说明是多个stmt的集合,遍历所有子结点调用num_stmts并累加既可

可以看到,num_stmts基本上是和语句有关的产生式是一一对应的。

接下来分析ast_for_stmts的内容:

static stmt_ty

ast_for_stmt(struct compiling *c, const node *n)

{

if (TYPE(n) == stmt) {

assert(NCH(n) == 1);

n = CHILD(n, 0);

}