Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[submodule "solidity-antlr4"]
path = solidity-antlr4
url = https://github.com/solidityj/solidity-antlr4.git
url = https://github.com/solidity-parser/antlr.git
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# python-solidity-parser
A Solidity parser for Python built on top of a robust ANTLR4 grammar
An experimental Solidity parser for Python built on top of a robust ANTLR4 grammar.

**ⓘ** This is a **python3** port of the [javascript antlr parser](https://github.com/federicobond/solidity-parser-antlr) maintained by [@federicobond](https://github.com/federicobond/). Interfaces are intentionally following the javascript implementation and are therefore not pep8 compliant.



## Install

```
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
antlr4-python3-runtime
antlr4-python3-runtime==4.9.3
7 changes: 4 additions & 3 deletions scripts/antlr4.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
set -o errexit

antlr -Dlanguage=Python3 solidity-antlr4/Solidity.g4 -o src -visitor
mv src/solidity-antlr4/* src/solidity_antlr4

mv src/solidity-antlr4/* solidity_parser/solidity_antlr4
rm -rf src/solidity-antlr4

touch src/solidity_antlr4/__init__.py
touch src/solidity_antlr4/__AUTOGENERATED__
touch solidity_parser/solidity_antlr4/__init__.py
touch solidity_parser/solidity_antlr4/__AUTOGENERATED__
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def read(fname):
long_description=read("README.md") if os.path.isfile("README.md") else "",
long_description_content_type='text/markdown',
#python setup.py register -r https://testpypi.python.org/pypi
install_requires=["antlr4-python3-runtime"],
install_requires=["antlr4-python3-runtime==4.9.3"],
#test_suite="nose.collector",
#tests_require=["nose"],
)
2 changes: 1 addition & 1 deletion solidity-antlr4
Submodule solidity-antlr4 updated 6 files
+6 −0 .gitignore
+15 −6 README.md
+125 −44 Solidity.g4
+19 −0 build.sh
+17 −4 run-tests.sh
+343 −2 test.sol
2 changes: 1 addition & 1 deletion solidity_parser/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
print("=== contract: " + contract_name)
level +=1

print(("\t" * level) + "=== Inherited Contrracts: " + ','.join([bc.baseName.namePath for bc in contract_object._node.baseContracts]))
print(("\t" * level) + "=== Inherited Contracts: " + ','.join([bc.baseName.namePath for bc in contract_object._node.baseContracts]))
## statevars
print(("\t" * level) + "=== Enums")
level += 2
Expand Down
180 changes: 98 additions & 82 deletions solidity_parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,27 @@ def visitEnumValue(self, ctx):
type="EnumValue",
name=ctx.identifier().getText())

def visitTypeDefinition(self, ctx):
return Node(ctx=ctx,
type="TypeDefinition",
typeKeyword=ctx.TypeKeyword().getText(),
elementaryTypeName=self.visit(ctx.elementaryTypeName()))


def visitCustomErrorDefinition(self, ctx):
return Node(ctx=ctx,
type="CustomErrorDefinition",
name=self.visit(ctx.identifier()),
parameterList=self.visit(ctx.parameterList()))

def visitFileLevelConstant(self, ctx):
return Node(ctx=ctx,
type="FileLevelConstant",
name=self.visit(ctx.identifier()),
typeName=self.visit(ctx.typeName()),
ConstantKeyword=self.visit(ctx.ConstantKeyword()))


def visitUsingForDeclaration(self, ctx: SolidityParser.UsingForDeclarationContext):
typename = None
if ctx.getChild(3) != '*':
Expand All @@ -138,45 +159,29 @@ def visitInheritanceSpecifier(self, ctx: SolidityParser.InheritanceSpecifierCont
return Node(ctx=ctx,
type="InheritanceSpecifier",
baseName=self.visit(ctx.userDefinedTypeName()),
arguments=self.visit(ctx.expression()))
arguments=self.visit(ctx.expressionList()))

def visitContractPart(self, ctx: SolidityParser.ContractPartContext):
return self.visit(ctx.children[0])

def visitConstructorDefinition(self, ctx: SolidityParser.ConstructorDefinitionContext):
parameters = self.visit(ctx.parameterList())
block = self.visit(ctx.block()) if ctx.block() else []
modifiers = [self.visit(i) for i in ctx.modifierList().modifierInvocation()]

if ctx.modifierList().ExternalKeyword(0):
visibility = "external"
elif ctx.modifierList().InternalKeyword(0):
visibility = "internal"
elif ctx.modifierList().PublicKeyword(0):
visibility = "public"
elif ctx.modifierList().PrivateKeyword(0):
visibility = "private"
else:
visibility = 'default'

if ctx.modifierList().stateMutability(0):
stateMutability = ctx.modifierList().stateMutability(0).getText()
def visitFunctionDefinition(self, ctx: SolidityParser.FunctionDefinitionContext):
isConstructor = isFallback =isReceive = False

fd = ctx.functionDescriptor()
if fd.ConstructorKeyword():
name = fd.ConstructorKeyword().getText()
isConstructor = True
elif fd.FallbackKeyword():
name = fd.FallbackKeyword().getText()
isFallback = True
elif fd.ReceiveKeyword():
name = fd.ReceiveKeyword().getText()
isReceive = True
elif fd.identifier():
name = fd.identifier().getText()
else:
stateMutability = None

return Node(ctx=ctx,
type="FunctionDefinition",
name=None,
parameters=parameters,
returnParameters=None,
body=block,
visibility=visibility,
modifiers=modifiers,
isConstructor=True,
stateMutability=stateMutability)

def visitFunctionDefinition(self, ctx: SolidityParser.ConstructorDefinitionContext):
name = ctx.identifier().getText() if ctx.identifier() else ""
raise Exception("unexpected function descriptor")

parameters = self.visit(ctx.parameterList())
returnParameters = self.visit(ctx.returnParameters()) if ctx.returnParameters() else []
Expand Down Expand Up @@ -207,7 +212,9 @@ def visitFunctionDefinition(self, ctx: SolidityParser.ConstructorDefinitionConte
body=block,
visibility=visibility,
modifiers=modifiers,
isConstructor=name == self._currentContract,
isConstructor=isConstructor,
isFallback=isFallback,
isReceive=isReceive,
stateMutability=stateMutability)

def visitReturnParameters(self, ctx: SolidityParser.ReturnParametersContext):
Expand Down Expand Up @@ -319,6 +326,10 @@ def visitEmitStatement(self, ctx):
type='EmitStatement',
eventCall=self.visit(ctx.getChild(1)))

def visitThrowStatement(self, ctx):
return Node(ctx=ctx,
type='ThrowStatement')

def visitStructDefinition(self, ctx):
return Node(ctx=ctx,
type='StructDefinition',
Expand Down Expand Up @@ -393,6 +404,21 @@ def visitIfStatement(self, ctx):
TrueBody=TrueBody,
FalseBody=FalseBody)

def visitTryStatement(self, ctx):
return Node(ctx=ctx,
type='TryStatement',
expression=self.visit(ctx.expression()),
block=self.visit(ctx.block()),
returnParameters=self.visit(ctx.returnParameters()),
catchClause=self.visit(ctx.catchClause()))

def visitCatchClause(self, ctx):
return Node(ctx=ctx,
type='CatchClause',
identifier=self.visit(ctx.identifier()),
parameterList=self.visit(ctx.parameterList()),
block=self.visit(ctx.block()))

def visitUserDefinedTypeName(self, ctx):
return Node(ctx=ctx,
type='UserDefinedTypeName',
Expand Down Expand Up @@ -428,7 +454,7 @@ def visitNumberLiteral(self, ctx):
def visitMapping(self, ctx):
return Node(ctx=ctx,
type='Mapping',
keyType=self.visit(ctx.elementaryTypeName()),
keyType=self.visit(ctx.mappingKey()),
valueType=self.visit(ctx.typeName()))

def visitModifierDefinition(self, ctx):
Expand All @@ -449,6 +475,16 @@ def visitStatement(self, ctx):
def visitSimpleStatement(self, ctx):
return self.visit(ctx.getChild(0))

def visitUncheckedStatement(self, ctx):
return Node(ctx=ctx,
type='UncheckedStatement',
body=self.visit(ctx.block()))

def visitRevertStatement(self, ctx):
return Node(ctx=ctx,
type='RevertStatement',
functionCall=self.visit(ctx.functionCall()))

def visitExpression(self, ctx):

children_length = len(ctx.children)
Expand Down Expand Up @@ -641,16 +677,15 @@ def visitPrimaryExpression(self, ctx):
type='BooleanLiteral',
value=ctx.BooleanLiteral().getText() == 'true')

if ctx.HexLiteral():
if ctx.hexLiteral():
return Node(ctx=ctx,
type='HexLiteral',
value=ctx.HexLiteral().getText())
type='hexLiteral',
value=ctx.hexLiteral().getText())

if ctx.StringLiteral():
if ctx.stringLiteral():
text = ctx.getText()

return Node(ctx=ctx,
type='StringLiteral',
type='stringLiteral',
value=text[1: len(text) - 1])

if len(ctx.children) == 3 and ctx.getChild(1).getText() == '[' and ctx.getChild(2).getText() == ']':
Expand Down Expand Up @@ -737,32 +772,6 @@ def visitVariableDeclarationStatement(self, ctx):
variables=variables,
initialValue=initialValue)

def visitImportDirective(self, ctx):
pathString = ctx.StringLiteral().getText()
unitAlias = None
symbolAliases = None

impDecLen = len(ctx.importDeclaration())
if impDecLen > 0:
symbolAliases = []
for decl in ctx.importDeclaration():
symbol = decl.identifier(0).getText()
alias = None
if decl.identifier(1):
alias = decl.identifier(1).getText()

symbolAliases.append([symbol, alias])
elif impDecLen == 7:
unitAlias = ctx.getChild(3).getText()
elif impDecLen == 5:
unitAlias = ctx.getChild(3).getText()

return Node(ctx=ctx,
type='ImportDirective',
path=pathString[1: len(pathString) - 1],
unitAlias=unitAlias,
symbolAliases=symbolAliases)

def visitEventDefinition(self, ctx):
return Node(ctx=ctx,
type='EventDefinition',
Expand Down Expand Up @@ -792,8 +801,8 @@ def visitEventParameterList(self, ctx):
def visitInlineAssemblyStatement(self, ctx):
language = None

if ctx.StringLiteral():
language = ctx.StringLiteral().getText()
if ctx.StringLiteralFragment():
language = ctx.StringLiteralFragment().getText()
language = language[1: len(language) - 1]

return Node(ctx=ctx,
Expand All @@ -810,13 +819,13 @@ def visitAssemblyBlock(self, ctx):

def visitAssemblyItem(self, ctx):

if ctx.HexLiteral():
if ctx.hexLiteral():
return Node(ctx=ctx,
type='HexLiteral',
value=ctx.HexLiteral().getText())
value=ctx.hexLiteral().getText())

if ctx.StringLiteral():
text = ctx.StringLiteral().getText()
if ctx.stringLiteral():
text = ctx.stringLiteral().getText()
return Node(ctx=ctx,
type='StringLiteral',
value=text[1: len(text) - 1])
Expand All @@ -834,6 +843,11 @@ def visitAssemblyItem(self, ctx):
def visitAssemblyExpression(self, ctx):
return self.visit(ctx.getChild(0))

def visitAssemblyMember(self, ctx):
return Node(ctx=ctx,
type='AssemblyMember',
name=ctx.identifier().getText())

def visitAssemblyCall(self, ctx):
functionName = ctx.getChild(0).getText()
args = [self.visit(arg) for arg in ctx.assemblyExpression()]
Expand All @@ -845,7 +859,7 @@ def visitAssemblyCall(self, ctx):

def visitAssemblyLiteral(self, ctx):

if ctx.StringLiteral():
if ctx.stringLiteral():
text = ctx.getText()
return Node(ctx=ctx,
type='StringLiteral',
Expand All @@ -861,7 +875,7 @@ def visitAssemblyLiteral(self, ctx):
type='HexNumber',
value=ctx.getText())

if ctx.HexLiteral():
if ctx.hexLiteral():
return Node(ctx=ctx,
type='HexLiteral',
value=ctx.getText())
Expand Down Expand Up @@ -981,7 +995,7 @@ def visitImportDirective(self, ctx):

return Node(ctx=ctx,
type="ImportDirective",
path=ctx.StringLiteral().getText().strip('"'),
path=ctx.importPath().getText().strip('"'),
symbolAliases=symbol_aliases,
unitAlias=unit_alias
)
Expand Down Expand Up @@ -1024,7 +1038,7 @@ def parse(text, start="sourceUnit", loc=False, strict=False):


def parse_file(path, start="sourceUnit", loc=False, strict=False):
with open(path, 'r') as f:
with open(path, 'r', encoding="utf-8") as f:
return parse(f.read(), start=start, loc=loc, strict=strict)


Expand Down Expand Up @@ -1106,10 +1120,6 @@ def visitStructDefinition(self, _node):
self.structs[_node.name]=_node
self.names[_node.name]=_node

def visitConstructorDefinition(self, _node):
self.constructor = _node


def visitStateVariableDeclaration(self, _node):

class VarDecVisitor(object):
Expand Down Expand Up @@ -1150,10 +1160,15 @@ def __init__(self, node):
if(node.type=="FunctionDefinition"):
self.visibility = node.visibility
self.stateMutability = node.stateMutability
self.isConstructor = node.isConstructor
self.isFallback = node.isFallback
self.isReceive = node.isReceive
self.arguments = {}
self.returns = {}
self.declarations = {}
self.identifiers = []



class FunctionArgumentVisitor(object):

Expand Down Expand Up @@ -1182,13 +1197,14 @@ def visitIdentifier(self, __node):
def visitAssemblyCall(self, __node):
self.idents.append(__node)


current_function = FunctionObject(_node)
self.names[_node.name] = current_function
if _definition_type=="ModifierDefinition":
self.modifiers[_node.name] = current_function
else:
self.functions[_node.name] = current_function
if current_function.isConstructor:
self.constructor = current_function

## get parameters
funcargvisitor = FunctionArgumentVisitor()
Expand Down
Loading