480 lines
15 KiB
Python
480 lines
15 KiB
Python
#!/usr/bin/env python
|
|
# -*- encoding:utf8 -*-
|
|
# protoc-gen-erl
|
|
# Google's Protocol Buffers project, ported to lua.
|
|
# https://code.google.com/p/protoc-gen-lua/
|
|
#
|
|
# Copyright (c) 2010 , 林卓毅 (Zhuoyi Lin) netsnail@gmail.com
|
|
# All rights reserved.
|
|
#
|
|
# Use, modification and distribution are subject to the "New BSD License"
|
|
# as listed at <url: http://www.opensource.org/licenses/bsd-license.php >.
|
|
|
|
import sys
|
|
import os.path as path
|
|
from cStringIO import StringIO
|
|
|
|
import plugin_pb2
|
|
import google.protobuf.descriptor_pb2 as descriptor_pb2
|
|
|
|
_packages = {}
|
|
_files = {}
|
|
_message = {}
|
|
|
|
FDP = plugin_pb2.descriptor_pb2.FieldDescriptorProto
|
|
|
|
if sys.platform == "win32":
|
|
import msvcrt, os
|
|
msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)
|
|
msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)
|
|
|
|
class CppType:
|
|
CPPTYPE_INT32 = 1
|
|
CPPTYPE_INT64 = 2
|
|
CPPTYPE_UINT32 = 3
|
|
CPPTYPE_UINT64 = 4
|
|
CPPTYPE_DOUBLE = 5
|
|
CPPTYPE_FLOAT = 6
|
|
CPPTYPE_BOOL = 7
|
|
CPPTYPE_ENUM = 8
|
|
CPPTYPE_STRING = 9
|
|
CPPTYPE_MESSAGE = 10
|
|
|
|
CPP_TYPE ={
|
|
FDP.TYPE_DOUBLE : CppType.CPPTYPE_DOUBLE,
|
|
FDP.TYPE_FLOAT : CppType.CPPTYPE_FLOAT,
|
|
FDP.TYPE_INT64 : CppType.CPPTYPE_INT64,
|
|
FDP.TYPE_UINT64 : CppType.CPPTYPE_UINT64,
|
|
FDP.TYPE_INT32 : CppType.CPPTYPE_INT32,
|
|
FDP.TYPE_FIXED64 : CppType.CPPTYPE_UINT64,
|
|
FDP.TYPE_FIXED32 : CppType.CPPTYPE_UINT32,
|
|
FDP.TYPE_BOOL : CppType.CPPTYPE_BOOL,
|
|
FDP.TYPE_STRING : CppType.CPPTYPE_STRING,
|
|
FDP.TYPE_MESSAGE : CppType.CPPTYPE_MESSAGE,
|
|
FDP.TYPE_BYTES : CppType.CPPTYPE_STRING,
|
|
FDP.TYPE_UINT32 : CppType.CPPTYPE_UINT32,
|
|
FDP.TYPE_ENUM : CppType.CPPTYPE_ENUM,
|
|
FDP.TYPE_SFIXED32 : CppType.CPPTYPE_INT32,
|
|
FDP.TYPE_SFIXED64 : CppType.CPPTYPE_INT64,
|
|
FDP.TYPE_SINT32 : CppType.CPPTYPE_INT32,
|
|
FDP.TYPE_SINT64 : CppType.CPPTYPE_INT64
|
|
}
|
|
|
|
def printerr(*args):
|
|
sys.stderr.write(" ".join(args))
|
|
sys.stderr.write("\n")
|
|
sys.stderr.flush()
|
|
|
|
class TreeNode(object):
|
|
def __init__(self, name, parent=None, filename=None, package=None):
|
|
super(TreeNode, self).__init__()
|
|
self.child = []
|
|
self.parent = parent
|
|
self.filename = filename
|
|
self.package = package
|
|
if parent:
|
|
self.parent.add_child(self)
|
|
self.name = name
|
|
|
|
def add_child(self, child):
|
|
self.child.append(child)
|
|
|
|
def find_child(self, child_names):
|
|
if child_names:
|
|
for i in self.child:
|
|
if i.name == child_names[0]:
|
|
return i.find_child(child_names[1:])
|
|
raise StandardError
|
|
else:
|
|
return self
|
|
|
|
def get_child(self, child_name):
|
|
for i in self.child:
|
|
if i.name == child_name:
|
|
return i
|
|
return None
|
|
|
|
def get_path(self, end = None):
|
|
pos = self
|
|
out = []
|
|
while pos and pos != end:
|
|
out.append(pos.name)
|
|
pos = pos.parent
|
|
out.reverse()
|
|
return '.'.join(out)
|
|
|
|
def get_global_name(self):
|
|
return self.get_path()
|
|
|
|
def get_local_name(self):
|
|
pos = self
|
|
while pos.parent:
|
|
pos = pos.parent
|
|
if self.package and pos.name == self.package[-1]:
|
|
break
|
|
return self.get_path(pos)
|
|
|
|
def __str__(self):
|
|
return self.to_string(0)
|
|
|
|
def __repr__(self):
|
|
return str(self)
|
|
|
|
def to_string(self, indent = 0):
|
|
return ' '*indent + '<TreeNode ' + self.name + '(\n' + \
|
|
','.join([i.to_string(indent + 4) for i in self.child]) + \
|
|
' '*indent +')>\n'
|
|
|
|
class Env(object):
|
|
filename = None
|
|
package = None
|
|
extend = None
|
|
descriptor = None
|
|
message = None
|
|
context = None
|
|
register = None
|
|
def __init__(self):
|
|
self.message_tree = TreeNode('')
|
|
self.scope = self.message_tree
|
|
|
|
def get_global_name(self):
|
|
return self.scope.get_global_name()
|
|
|
|
def get_local_name(self):
|
|
return self.scope.get_local_name()
|
|
|
|
def get_ref_name(self, type_name):
|
|
try:
|
|
node = self.lookup_name(type_name)
|
|
except:
|
|
# if the child doesn't be founded, it must be in this file
|
|
return type_name[len('.'.join(self.package)) + 2:]
|
|
if node.filename != self.filename:
|
|
return node.filename + '_pb.' + node.get_local_name()
|
|
return node.get_local_name()
|
|
|
|
def lookup_name(self, name):
|
|
names = name.split('.')
|
|
if names[0] == '':
|
|
return self.message_tree.find_child(names[1:])
|
|
else:
|
|
return self.scope.parent.find_child(names)
|
|
|
|
def enter_package(self, package):
|
|
if not package:
|
|
return self.message_tree
|
|
names = package.split('.')
|
|
pos = self.message_tree
|
|
for i, name in enumerate(names):
|
|
new_pos = pos.get_child(name)
|
|
if new_pos:
|
|
pos = new_pos
|
|
else:
|
|
return self._build_nodes(pos, names[i:])
|
|
return pos
|
|
|
|
def enter_file(self, filename, package):
|
|
self.filename = filename
|
|
self.package = package.split('.')
|
|
self._init_field()
|
|
self.scope = self.enter_package(package)
|
|
|
|
def exit_file(self):
|
|
self._init_field()
|
|
self.filename = None
|
|
self.package = []
|
|
self.scope = self.scope.parent
|
|
|
|
def enter(self, message_name):
|
|
self.scope = TreeNode(message_name, self.scope, self.filename,
|
|
self.package)
|
|
|
|
def exit(self):
|
|
self.scope = self.scope.parent
|
|
|
|
def _init_field(self):
|
|
self.descriptor = []
|
|
self.context = []
|
|
self.message = []
|
|
self.register = []
|
|
|
|
def _build_nodes(self, node, names):
|
|
parent = node
|
|
for i in names:
|
|
parent = TreeNode(i, parent, self.filename, self.package)
|
|
return parent
|
|
|
|
class Writer(object):
|
|
def __init__(self, prefix=None):
|
|
self.io = StringIO()
|
|
self.__indent = ''
|
|
self.__prefix = prefix
|
|
|
|
def getvalue(self):
|
|
return self.io.getvalue()
|
|
|
|
def __enter__(self):
|
|
self.__indent += ' '
|
|
return self
|
|
|
|
def __exit__(self, type, value, trackback):
|
|
self.__indent = self.__indent[:-4]
|
|
|
|
def __call__(self, data):
|
|
self.io.write(self.__indent)
|
|
if self.__prefix:
|
|
self.io.write(self.__prefix)
|
|
self.io.write(data)
|
|
|
|
DEFAULT_VALUE = {
|
|
FDP.TYPE_DOUBLE : '0.0',
|
|
FDP.TYPE_FLOAT : '0.0',
|
|
FDP.TYPE_INT64 : '0',
|
|
FDP.TYPE_UINT64 : '0',
|
|
FDP.TYPE_INT32 : '0',
|
|
FDP.TYPE_FIXED64 : '0',
|
|
FDP.TYPE_FIXED32 : '0',
|
|
FDP.TYPE_BOOL : 'false',
|
|
FDP.TYPE_STRING : '""',
|
|
FDP.TYPE_MESSAGE : 'nil',
|
|
FDP.TYPE_BYTES : '""',
|
|
FDP.TYPE_UINT32 : '0',
|
|
FDP.TYPE_ENUM : '1',
|
|
FDP.TYPE_SFIXED32 : '0',
|
|
FDP.TYPE_SFIXED64 : '0',
|
|
FDP.TYPE_SINT32 : '0',
|
|
FDP.TYPE_SINT64 : '0',
|
|
}
|
|
|
|
def code_gen_enum_item(index, enum_value, env):
|
|
full_name = env.get_local_name() + '.' + enum_value.name
|
|
obj_name = 'pb_desc.' + full_name.upper().replace('.', '_') + '_ENUM'
|
|
env.descriptor.append(
|
|
"%s = protobuf.EnumValueDescriptor();\n"% obj_name
|
|
)
|
|
|
|
context = Writer(obj_name)
|
|
context('.name = "%s"\n' % enum_value.name)
|
|
context('.index = %d\n' % index)
|
|
context('.number = %d\n' % enum_value.number)
|
|
|
|
env.context.append(context.getvalue())
|
|
return obj_name
|
|
|
|
def code_gen_enum(enum_desc, env):
|
|
env.enter(enum_desc.name)
|
|
full_name = env.get_local_name()
|
|
obj_name = 'pb_desc.' + full_name.upper().replace('.', '_')
|
|
env.descriptor.append(
|
|
"%s = protobuf.EnumDescriptor();\n"% obj_name
|
|
)
|
|
|
|
context = Writer(obj_name)
|
|
context('.name = "%s"\n' % enum_desc.name)
|
|
context('.full_name = "%s"\n' % env.get_global_name())
|
|
|
|
values = []
|
|
for i, enum_value in enumerate(enum_desc.value):
|
|
values.append(code_gen_enum_item(i, enum_value, env))
|
|
context('.values = {%s}\n' % ','.join(values))
|
|
|
|
env.context.append(context.getvalue())
|
|
env.exit()
|
|
return obj_name
|
|
|
|
def code_gen_field(index, field_desc, env):
|
|
full_name = env.get_local_name() + '.' + field_desc.name
|
|
obj_name = 'pb_desc.' + full_name.upper().replace('.', '_') + '_FIELD'
|
|
env.descriptor.append(
|
|
"%s = protobuf.FieldDescriptor();\n"% obj_name
|
|
)
|
|
|
|
context = Writer(obj_name)
|
|
|
|
context('.name = "%s"\n' % field_desc.name)
|
|
context('.full_name = "%s"\n' % (
|
|
env.get_global_name() + '.' + field_desc.name))
|
|
context('.number = %d\n' % field_desc.number)
|
|
context('.index = %d\n' % index)
|
|
context('.label = %d\n' % field_desc.label)
|
|
|
|
if field_desc.HasField("default_value"):
|
|
context('.has_default_value = true\n')
|
|
value = field_desc.default_value
|
|
if field_desc.type == FDP.TYPE_STRING:
|
|
context('.default_value = "%s"\n'%value)
|
|
else:
|
|
context('.default_value = %s\n'%value)
|
|
else:
|
|
context('.has_default_value = false\n')
|
|
if field_desc.label == FDP.LABEL_REPEATED:
|
|
default_value = "{}"
|
|
elif field_desc.HasField('type_name'):
|
|
default_value = "nil"
|
|
else:
|
|
default_value = DEFAULT_VALUE[field_desc.type]
|
|
context('.default_value = %s\n' % default_value)
|
|
|
|
if field_desc.HasField('type_name'):
|
|
# context("wing type:%s\n"%field_desc.type_name)
|
|
# context("wing ref type:%s\n"%env.get_ref_name(field_desc.type_name))
|
|
type_name = env.get_ref_name(field_desc.type_name).upper()
|
|
type_names = type_name.split(".")
|
|
if len(type_names) > 1:
|
|
type_name = type_names[0].replace('/', '_').lower() + ".GLOBAL_" + type_names[1]
|
|
# type_name = env.get_ref_name(field_desc.type_name).upper().replace('.', '_')
|
|
if field_desc.type == FDP.TYPE_MESSAGE:
|
|
context('.message_type = pb_desc.%s\n' % type_name)
|
|
else:
|
|
context('.enum_type = pb_desc.%s\n' % type_name)
|
|
|
|
if field_desc.HasField('extendee'):
|
|
type_name = env.get_ref_name(field_desc.extendee)
|
|
env.register.append(
|
|
"%s.RegisterExtension(%s)\n" % (type_name, obj_name)
|
|
)
|
|
|
|
context('.type = %d\n' % field_desc.type)
|
|
context('.cpp_type = %d\n\n' % CPP_TYPE[field_desc.type])
|
|
env.context.append(context.getvalue())
|
|
return obj_name
|
|
|
|
def code_gen_message(message_descriptor, env, containing_type = None):
|
|
env.enter(message_descriptor.name)
|
|
full_name = env.get_local_name()
|
|
global_obj_name = full_name.upper().replace('.', '_')
|
|
obj_name = 'pb_desc.' + global_obj_name
|
|
env.descriptor.append(
|
|
"%s = protobuf.Descriptor();\n"% obj_name
|
|
)
|
|
env.descriptor.append("GLOBAL_%s = %s;\n" % (global_obj_name, obj_name))
|
|
# for i, field_desc in enumerate(message_descriptor.field):
|
|
# if field_desc.HasField("type_name"):
|
|
# env.descriptor.append("test tttt %s = %s;\n" % (i, field_desc.type_name))
|
|
|
|
context = Writer(obj_name)
|
|
context('.name = "%s"\n' % message_descriptor.name)
|
|
context('.full_name = "%s"\n' % env.get_global_name())
|
|
|
|
nested_types = []
|
|
for msg_desc in message_descriptor.nested_type:
|
|
msg_name = code_gen_message(msg_desc, env, obj_name)
|
|
nested_types.append(msg_name)
|
|
context('.nested_types = {%s}\n' % ', '.join(nested_types))
|
|
|
|
enums = []
|
|
for enum_desc in message_descriptor.enum_type:
|
|
enums.append(code_gen_enum(enum_desc, env))
|
|
context('.enum_types = {%s}\n' % ', '.join(enums))
|
|
|
|
fields = []
|
|
for i, field_desc in enumerate(message_descriptor.field):
|
|
fields.append(code_gen_field(i, field_desc, env))
|
|
|
|
context('.fields = {%s}\n' % ', '.join(fields))
|
|
if len(message_descriptor.extension_range) > 0:
|
|
context('.is_extendable = true\n')
|
|
else:
|
|
context('.is_extendable = false\n')
|
|
|
|
extensions = []
|
|
for i, field_desc in enumerate(message_descriptor.extension):
|
|
extensions.append(code_gen_field(i, field_desc, env))
|
|
context('.extensions = {%s}\n' % ', '.join(extensions))
|
|
|
|
if containing_type:
|
|
context('.containing_type = %s\n' % containing_type)
|
|
|
|
env.message.append('%s = protobuf.Message(%s),\n' % (full_name,
|
|
obj_name))
|
|
|
|
env.context.append(context.getvalue())
|
|
env.exit()
|
|
return obj_name
|
|
|
|
def write_header(writer):
|
|
writer("""-- Generated By protoc-gen-lua Do not Edit
|
|
""")
|
|
|
|
|
|
def code_gen_file(proto_file, env, is_gen):
|
|
filename = path.splitext(proto_file.name)[0]
|
|
env.enter_file(filename, proto_file.package)
|
|
|
|
includes = []
|
|
for f in proto_file.dependency:
|
|
inc_file = path.splitext(f)[0]
|
|
includes.append(inc_file)
|
|
|
|
# for field_desc in proto_file.extension:
|
|
# code_gen_extensions(field_desc, field_desc.name, env)
|
|
|
|
for enum_desc in proto_file.enum_type:
|
|
code_gen_enum(enum_desc, env)
|
|
for enum_value in enum_desc.value:
|
|
env.message.append('%s = %d,\n' % (enum_value.name,
|
|
enum_value.number))
|
|
|
|
for msg_desc in proto_file.message_type:
|
|
code_gen_message(msg_desc, env)
|
|
|
|
if is_gen:
|
|
lua = Writer()
|
|
|
|
write_header(lua)
|
|
lua('local protobuf = require "protobuf"\n')
|
|
lua('\n')
|
|
lua('local pb_desc = {}')
|
|
lua('\n\n')
|
|
for i in includes:
|
|
lua('pb_desc.%s_pb = require("%s_pb")\n' % (i.replace('/', '_'), i))
|
|
#lua("module('%s_pb')\n" % env.filename)
|
|
|
|
|
|
lua('\n\n')
|
|
map(lua, env.descriptor)
|
|
lua('\n')
|
|
|
|
map(lua, env.context)
|
|
lua('\n')
|
|
env.message.sort()
|
|
|
|
lua('local _M = {')
|
|
lua('\n')
|
|
map(lua, env.message)
|
|
lua('\n')
|
|
map(lua, env.register)
|
|
lua('\n')
|
|
lua('}')
|
|
|
|
lua('\n')
|
|
lua('local modname = "%s_pb"' % env.filename)
|
|
lua('\n')
|
|
lua('_G[modname] = _M')
|
|
lua('\n')
|
|
lua('return _M')
|
|
_files[env.filename+ '_pb.lua'] = lua.getvalue()
|
|
env.exit_file()
|
|
|
|
def main():
|
|
plugin_require_bin = sys.stdin.read()
|
|
code_gen_req = plugin_pb2.CodeGeneratorRequest()
|
|
code_gen_req.ParseFromString(plugin_require_bin)
|
|
|
|
env = Env()
|
|
for proto_file in code_gen_req.proto_file:
|
|
code_gen_file(proto_file, env,
|
|
proto_file.name in code_gen_req.file_to_generate)
|
|
|
|
code_generated = plugin_pb2.CodeGeneratorResponse()
|
|
for k in _files:
|
|
file_desc = code_generated.file.add()
|
|
file_desc.name = k
|
|
file_desc.content = _files[k]
|
|
|
|
sys.stdout.write(code_generated.SerializeToString())
|
|
|
|
if __name__ == "__main__":
|
|
main()
|
|
|