Parsing Abstract Syntax Trees (ASTs)

Posted on , 5 min read

From the sample.py file below, we’ll extract all the function calls as a list.

# sample.py
import foo

data = open('file')                     # a function call
foo.bar(arg=data)                       # a function call
foo.bar(arg=foo.meow(foo.z(arg=data)))  # three function calls
foo.woof(foo.x.y(arg=data))             # two function calls

Expected output

>>> funcs_that_have_been_called
['open', 'foo.bar', 'foo.bar', 'foo.meow', 'foo.z', 'foo.woof', 'foo.x.y']

(Note: the items won’t necessarily appear in this order.)

We’ll make use of the ast module to inspect the Abstract Syntax Tree of sample.py to achieve this.

>>> import ast
>>> tree = ast.parse(open('sample.py').read())
>>> ast.dump(tree)
"Module(body=[Import(names=[alias(name='foo', asname=None)]), Assign(targets=[Name(id='data', ctx=Store())], value=Call(func=Name(id='open', ctx=Load()), args=[Str(s='file')], keywords=[], starargs=None, kwargs=None)), Expr(value=Call(func=Attribute(value=Name(id='foo', ctx=Load()), attr='bar', ctx=Load()), args=[], keywords=[keyword(arg='arg', value=Name(id='data', ctx=Load()))], starargs=None, kwargs=None)), Expr(value=Call(func=Attribute(value=Name(id='foo', ctx=Load()), attr='bar', ctx=Load()), args=[], keywords=[keyword(arg='arg', value=Call(func=Attribute(value=Name(id='foo', ctx=Load()), attr='meow', ctx=Load()), args=[Call(func=Attribute(value=Name(id='foo', ctx=Load()), attr='z', ctx=Load()), args=[], keywords=[keyword(arg='arg', value=Name(id='data', ctx=Load()))], starargs=None, kwargs=None)], keywords=[], starargs=None, kwargs=None))], starargs=None, kwargs=None)), Expr(value=Call(func=Attribute(value=Name(id='foo', ctx=Load()), attr='woof', ctx=Load()), args=[Call(func=Attribute(value=Attribute(value=Name(id='foo', ctx=Load()), attr='x', ctx=Load()), attr='y', ctx=Load()), args=[], keywords=[keyword(arg='arg', value=Name(id='data', ctx=Load()))], starargs=None, kwargs=None)], keywords=[], starargs=None, kwargs=None))])"

The root node of the tree above is a Module object, and in Module(body=[...]) the body is a list of child nodes.

At first glance, the string above is overwhelming, and for good reason: it is the complete AST of the entire sample.py file. Let’s break it down step by step.
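If you’re on Python 3.9 or later (an assumption: the dumps in this post come from Python 2, whose ast.dump has no such option), ast.dump accepts an indent argument that puts each nested node on its own line, which makes a tree like this much easier to read. A minimal sketch, with sample.py inlined as a string:

```python
import ast

# Contents of sample.py, inlined so the example is self-contained
SOURCE = """\
import foo

data = open('file')
foo.bar(arg=data)
foo.bar(arg=foo.meow(foo.z(arg=data)))
foo.woof(foo.x.y(arg=data))
"""

tree = ast.parse(SOURCE)

# indent= needs Python 3.9+; every nested node gets its own indented line
pretty = ast.dump(tree, indent=2)
print(pretty)
```

The exact fields printed vary between Python versions; for example, newer versions show string literals as Constant rather than Str.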


Breaking it down step by step
>>> tree
<_ast.Module object at 0x031DF970>

Note that ast.dump walks the entire AST recursively and returns it as one string, whereas just entering tree in the REPL shows only the root node, the Module object. (The Module object contains the child nodes we already saw in the big AST string above.)

Since we know (from ast.dump) that the Module object has a body in it, let’s check that out in the REPL.

>>> tree.body
[<_ast.Import object at 0x031DF6F0>, <_ast.Assign object at 0x031DFF30>, <_ast.Expr object at 0x031DFDD0>, <_ast.Expr object at 0x031DFED0>, <_ast.Expr object at 0x0321F790>]

Each item of the list above represents one statement from sample.py. We can confirm this: sample.py has 5 statements, and len(tree.body) is also 5.
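The one-node-per-statement mapping is also easy to check programmatically. This sketch (assuming Python 3, with sample.py inlined as a string) lists each top-level node’s class name and the line it starts on; the comments and the blank line produce no nodes at all:

```python
import ast

# Contents of sample.py, comments included, inlined as a string
SOURCE = """\
# sample.py
import foo

data = open('file')                     # a function call
foo.bar(arg=data)                       # a function call
foo.bar(arg=foo.meow(foo.z(arg=data)))  # three function calls
foo.woof(foo.x.y(arg=data))             # two function calls
"""

tree = ast.parse(SOURCE)

# One node per top-level statement, each remembering its starting line
statement_types = [type(stmt).__name__ for stmt in tree.body]
line_numbers = [stmt.lineno for stmt in tree.body]
print(statement_types)  # ['Import', 'Assign', 'Expr', 'Expr', 'Expr']
print(line_numbers)     # [2, 4, 5, 6, 7]
```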

Let’s iterate and inspect further.

>>> for statement in tree.body:
...     print(ast.dump(statement) + '\n')
...
Import(names=[alias(name='foo', asname=None)])

Assign(targets=[Name(id='data', ctx=Store())], value=Call(func=Name(id='open', ctx=Load()), args=[Str(s='file')], keywords=[], starargs=None, kwargs=None))

Expr(value=Call(func=Attribute(value=Name(id='foo', ctx=Load()), attr='bar', ctx=Load()), args=[], keywords=[keyword(arg='arg', value=Name(id='data', ctx=Load()))], starargs=None, kwargs=None))

Expr(value=Call(func=Attribute(value=Name(id='foo', ctx=Load()), attr='bar', ctx=Load()), args=[], keywords=[keyword(arg='arg', value=Call(func=Attribute(value=Name(id='foo', ctx=Load()), attr='meow', ctx=Load()), args=[Call(func=Attribute(value=Name(id='foo', ctx=Load()), attr='z', ctx=Load()), args=[], keywords=[keyword(arg='arg', value=Name(id='data', ctx=Load()))], starargs=None, kwargs=None)], keywords=[], starargs=None, kwargs=None))], starargs=None, kwargs=None))

Expr(value=Call(func=Attribute(value=Name(id='foo', ctx=Load()), attr='woof', ctx=Load()), args=[Call(func=Attribute(value=Attribute(value=Name(id='foo', ctx=Load()), attr='x', ctx=Load()), attr='y', ctx=Load()), args=[], keywords=[keyword(arg='arg', value=Name(id='data', ctx=Load()))], starargs=None, kwargs=None)], keywords=[], starargs=None, kwargs=None))

By this point, it’s clear that every part of a Python statement is represented in the AST by an instance of one of the ast node classes.
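One way to see this is to collect the class name of every node that ast.walk yields for a single statement; even the Store and Load contexts from the dumps are nodes. A small sketch, assuming Python 3 (on 3.8+ the string literal shows up as Constant rather than Str):

```python
import ast

tree = ast.parse("data = open('file')")

# ast.walk yields every node in the tree, contexts included;
# keep just the distinct class names
node_types = sorted({type(node).__name__ for node in ast.walk(tree)})
print(node_types)
```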

The node we want to explore is Call, which in our case appears as a child of Expr or Assign nodes.

From the documentation:

class Call(func, args, keywords, starargs, kwargs)

A function call. func is the function, which will often be a Name or Attribute object.

Of the arguments:
- args holds a list of the arguments passed by position.
- keywords holds a list of keyword objects representing arguments passed by keyword.
- starargs and kwargs each hold a single node, for arguments passed as *args and **kwargs.
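The field list above is the Python 2 one. On Python 3.5 and later, Call has only func, args and keywords: a *args argument appears in args as a Starred node, and **kwargs appears in keywords as a keyword whose arg is None. A quick sketch for inspecting those fields, assuming Python 3 (mode="eval" makes ast.parse return an Expression whose body is the call itself):

```python
import ast

# mode="eval" parses a single expression; .body is the Call node itself
call = ast.parse("foo.bar(1, *rest, arg=data, **extra)", mode="eval").body

print(type(call.func).__name__)               # Attribute (foo.bar)
print([type(a).__name__ for a in call.args])  # positional args, incl. Starred
print([kw.arg for kw in call.keywords])       # ['arg', None]; None is **extra
```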

What we’ve learned so far:

  • Every statement is represented as an object of one of the ast classes
  • Function calls are ast.Call objects

What we need to be careful of:

  • The recursive nature of the AST: nested calls like foo(bar()) are represented by an ast.Call object within an ast.Call object.
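For instance, parsing foo(bar()) gives an outer Call whose first positional argument is itself a Call. A short sketch, assuming Python 3:

```python
import ast

# The outer call foo(...) holds the inner call bar() in its args list
outer = ast.parse("foo(bar())", mode="eval").body
inner = outer.args[0]

print(isinstance(outer, ast.Call), outer.func.id)  # True foo
print(isinstance(inner, ast.Call), inner.func.id)  # True bar
```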

Traversing the AST

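We can do a depth-first traversal of the tree by subclassing ast.NodeVisitor and implementing visit_* methods, where * is the name of an ast node class: visit_Call, visit_Assign, visit_Expr, and so on. Calling visit(node) dispatches to the matching method; for node types without one, it falls back to generic_visit, which visits each of the node’s children.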
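To see the dispatch in action, here is a small visitor that defines only visit_Name (NameCollector is a hypothetical name, not part of ast). It still reaches every Name in the tree, because the nodes above them have no visit_* method and fall through to generic_visit. A sketch assuming Python 3:

```python
import ast


class NameCollector(ast.NodeVisitor):
    """Hypothetical visitor that records the id of every Name node."""

    def __init__(self):
        self.names = []

    def visit_Name(self, node):
        # Dispatched here by NodeVisitor.visit() for each ast.Name
        self.names.append(node.id)


name_collector = NameCollector()
# Module, Assign, Call, Attribute, ... have no visit_* method here, so
# generic_visit walks into their children until the Names are reached
name_collector.visit(ast.parse("data = open('file')\nfoo.bar(arg=data)"))
print(name_collector.names)  # ['data', 'open', 'foo', 'data']
```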

Let’s visit every function call.

>>> class FunctionCallVisitor(ast.NodeVisitor):
...     def visit_Call(self, node):
...         print(ast.dump(node))
...
>>> FunctionCallVisitor().visit(tree)
Call(func=Name(id='open', ctx=Load()), args=[Str(s='file')], keywords=[], starargs=None, kwargs=None)
Call(func=Attribute(value=Name(id='foo', ctx=Load()), attr='bar', ctx=Load()), args=[], keywords=[keyword(arg='arg', value=Name(id='data', ctx=Load()))], starargs=None, kwargs=None)
Call(func=Attribute(value=Name(id='foo', ctx=Load()), attr='bar', ctx=Load()), args=[], keywords=[keyword(arg='arg', value=Call(func=Attribute(value=Name(id='foo', ctx=Load()), attr='meow', ctx=Load()), args=[Call(func=Attribute(value=Name(id='foo', ctx=Load()), attr='z', ctx=Load()), args=[], keywords=[keyword(arg='arg', value=Name(id='data', ctx=Load()))], starargs=None, kwargs=None)], keywords=[], starargs=None, kwargs=None))], starargs=None, kwargs=None)
Call(func=Attribute(value=Name(id='foo', ctx=Load()), attr='woof', ctx=Load()), args=[Call(func=Attribute(value=Attribute(value=Name(id='foo', ctx=Load()), attr='x', ctx=Load()), attr='y', ctx=Load()), args=[], keywords=[keyword(arg='arg', value=Name(id='data', ctx=Load()))], starargs=None, kwargs=None)], keywords=[], starargs=None, kwargs=None)
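A NodeVisitor that should also reach nested calls has to call self.generic_visit(node) from visit_Call; otherwise the traversal stops at each outermost call. A minimal sketch (CallCollector is a hypothetical name; Python 3, sample.py inlined):

```python
import ast

# Contents of sample.py, inlined so the example is self-contained
SOURCE = """\
import foo

data = open('file')
foo.bar(arg=data)
foo.bar(arg=foo.meow(foo.z(arg=data)))
foo.woof(foo.x.y(arg=data))
"""


class CallCollector(ast.NodeVisitor):
    """Hypothetical visitor that records every Call, nested ones included."""

    def __init__(self):
        self.calls = []

    def visit_Call(self, node):
        self.calls.append(node)
        # Hand the node back to generic_visit so its func, args and keywords
        # are visited too; without this the nested calls are skipped
        self.generic_visit(node)


collector = CallCollector()
collector.visit(ast.parse(SOURCE))
print(len(collector.calls))  # 7
```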

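Notice that only four calls are printed: visit_Call does not call self.generic_visit(node), so the visitor never descends into a call’s arguments, and the nested calls (foo.meow, foo.z and foo.x.y) appear only inside their parents’ dumps.

Now we need the function name for each call. Looking closely, for a straightforward call like open(), func is a Name and the function name is its id. For a call on a module attribute like foo.bar(), func is an Attribute instead: the name is built from the Name inside the Attribute (foo) followed by the Attribute’s attr (bar). For foo.x.y(), the Attributes are nested, so the chain has to be followed all the way down to the Name.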
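That rule can also be written as a small recursive function. This is a minimal sketch, not the implementation used in this post: dotted_name is a hypothetical helper (not part of ast), and it returns None when the chain doesn’t start with a plain Name, as in foo().bar(). Assumes Python 3:

```python
import ast


def dotted_name(node):
    """Hypothetical helper: 'a.b.c' for a Name/Attribute chain, else None."""
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Attribute):
        base = dotted_name(node.value)
        # e.g. foo().bar: the chain starts at a Call, so there is no plain name
        return None if base is None else base + "." + node.attr
    return None


for src in ["open('file')", "foo.bar(arg=data)", "foo.x.y(arg=data)", "foo().bar()"]:
    call = ast.parse(src, mode="eval").body
    print(dotted_name(call.func))  # open, foo.bar, foo.x.y, None
```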

We’re almost there.

All we need to do now is visit the func of every Call, nested ones included, and build the dotted name from the Name id and the Attribute attrs. Below is an implementation of this.

'''
Get all function calls from a python file

The MIT License (MIT)
Copyright (c) 2016 Suhas S G
'''
import ast
from collections import deque


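class FuncCallVisitor(ast.NodeVisitor):
    """Collect the dotted name of a call's func node, e.g. 'foo.x.y'."""

    def __init__(self):
        # Parts are prepended, because the walk starts at the rightmost
        # attr and ends at the leftmost Name
        self._name = deque()

    @property
    def name(self):
        return '.'.join(self._name)

    @name.deleter
    def name(self):
        self._name.clear()

    def visit_Name(self, node):
        # A plain call such as open(): the Name id is the whole name
        self._name.appendleft(node.id)

    def visit_Attribute(self, node):
        try:
            # foo.bar: record 'bar', then 'foo' from the Name on its left
            self._name.appendleft(node.attr)
            self._name.appendleft(node.value.id)
        except AttributeError:
            # node.value is not a Name (e.g. the foo.x in foo.x.y) and has
            # no .id, so keep walking into it to collect the other parts
            self.generic_visit(node)


def get_func_calls(tree):
    func_calls = []
    # ast.walk yields every node, including Calls nested in other Calls
    for node in ast.walk(tree):
        if isinstance(node, ast.Call):
            # A fresh visitor per call, so names don't run together
            callvisitor = FuncCallVisitor()
            callvisitor.visit(node.func)
            func_calls.append(callvisitor.name)

    return func_calls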

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', help='Input .py file', required=True)
    args = parser.parse_args()
    # Close the input file once it has been read
    with open(args.input) as source:
        tree = ast.parse(source.read())
    print(get_func_calls(tree))

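Saving the above as function_calls_ast.py and running it on sample.py gives the calls we initially aimed for. Only the order differs, because ast.walk makes no guarantee about the order in which it yields nodes.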

$ python function_calls_ast.py -i sample.py
['open', 'foo.bar', 'foo.bar', 'foo.woof', 'foo.x.y', 'foo.meow', 'foo.z']
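If source order matters, one option is to sort the Call nodes by lineno and col_offset and read each func straight from the source text. This is a swapped-in approach, not the one used above, and it assumes Python 3.8+ for ast.get_source_segment:

```python
import ast

# Contents of sample.py, inlined so the example is self-contained
SOURCE = """\
import foo

data = open('file')
foo.bar(arg=data)
foo.bar(arg=foo.meow(foo.z(arg=data)))
foo.woof(foo.x.y(arg=data))
"""

tree = ast.parse(SOURCE)

# Collect every Call node; ast.walk's own order is not guaranteed
calls = [node for node in ast.walk(tree) if isinstance(node, ast.Call)]

# Sort by where each call starts in the source: line first, then column
calls.sort(key=lambda call: (call.lineno, call.col_offset))

# Take the literal source text of each call's func node
ordered = [ast.get_source_segment(SOURCE, call.func) for call in calls]
print(ordered)
```

On sample.py this produces the expected list from the start of the post, in source order. One difference from FuncCallVisitor: the name is the literal source text, so a call like foo().bar() would come back as 'foo().bar'.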

Now that we’ve comfortably parsed a .py file and extracted its function calls, the possibilities ahead are exciting and endless!