From db12a14d959e038cae4fc9246c594b3c22ae6337 Mon Sep 17 00:00:00 2001 From: phongchen Date: Wed, 4 Oct 2023 01:17:50 +0800 Subject: [PATCH] Add override --- src/blade/backend.py | 2 +- src/blade/toolchain.py | 49 +++++++++++++++++++++++------------------- src/blade/util.py | 48 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 76 insertions(+), 23 deletions(-) diff --git a/src/blade/backend.py b/src/blade/backend.py index 6f2fc67f..6c26f44c 100644 --- a/src/blade/backend.py +++ b/src/blade/backend.py @@ -666,7 +666,7 @@ def _builtin_tools_prefix(self): with open(builtin_tools_file, 'w') as f: print('@echo off', file=f) print('set "PYTHONPATH={};%PYTHONPATH%"'.format(self.blade_path), file=f) - print('{} %*'.format(python), file=f) + print('"{}" %*'.format(python), file=f) self.__builtin_tools_prefix = builtin_tools_file else: # On posix system, a simply environment prefix is enough to do it. diff --git a/src/blade/toolchain.py b/src/blade/toolchain.py index c9f3873f..e4f4786b 100644 --- a/src/blade/toolchain.py +++ b/src/blade/toolchain.py @@ -20,7 +20,7 @@ from blade import config from blade import console -from blade.util import var_to_list, iteritems, run_command +from blade.util import var_to_list, iteritems, override, run_command # example: Cuda compilation tools, release 11.0, V11.0.194 _nvcc_version_re = re.compile(r'V(\d+\.\d+\.\d+)') @@ -166,12 +166,23 @@ def object_file_of(self, source_file): """ raise NotImplementedError - def library_file_name(self, name): + def static_library_name(self, name): + """ + Get the static library file name from the name. + """ + raise NotImplementedError + + def dynamic_library_name(self, name): """ Get the library file name from the name. """ raise NotImplementedError + def executable_file_name(self, name): + """ + Get the executable file name from the name. + """ + raise NotImplementedError # To verify whether a header file is included without depends on the library it belongs to, # we use the gcc's `-H` option to generate the inclusion stack information, see @@ -227,16 +238,20 @@ def get_cc_target_arch(): return stdout.strip() return '' + @override def is_kind_of(self, vendor): - """Is cc is used for C/C++ compilation match vendor.""" return vendor in ('gcc', 'clang', 'gcc') + @override def object_file_of(self, source_file): - """ - Get the object file name from the source file. - """ return source_file + '.o' + @override + def executable_file_name(self, name): + if os.name == 'nt': + return name + '.exe' + return name + def _cc_compile_command_wrapper_template(self, inclusion_stack_file, cuda=False): """Calculate the cc compile command wrapper template.""" print_header_option = '-H' @@ -380,28 +395,17 @@ def get_cc_target_arch(): return stdout.strip() return '' - def get_cc_commands(self): - return self.cc, self.cxx, self.ld - - def get_cc(self): - return self.cc - - def get_cc_version(self): - return self.cc_version - - def get_ar(self): - return self.ar - + @override def is_kind_of(self, vendor): - """Is cc is used for C/C++ compilation match vendor.""" return vendor in ('msvc') + @override def object_file_of(self, source_file): - """ - Get the object file name from the source file. - """ return source_file + '.obj' + @override + def executable_file_name(self, name): + return name + '.exe' def filter_cc_flags(self, flag_list, language='c'): """Filter out the unrecognized compilation flags.""" @@ -494,6 +498,7 @@ def get_shared_link_command(self): def _get_link_args(self): return ' /nologo /OUT:${out} ${intrinsic_linkflags} ${linkflags} ${target_linkflags} @${out}.rsp ${extra_linkflags}' + def default(bits): if os.name == 'nt': return CcToolChainMsvc() diff --git a/src/blade/util.py b/src/blade/util.py index a618cb61..466cc873 100644 --- a/src/blade/util.py +++ b/src/blade/util.py @@ -21,6 +21,7 @@ import inspect import json import os +import re import signal import string import subprocess @@ -379,3 +380,50 @@ def which(cmd): if returncode != 0: return None return stdout.strip() + + +def override(method): + """ + Check method override. + https://stackoverflow.com/a/14631397 + """ + + stack = inspect.stack() + base_classes = _get_bass_classes(stack) + + # TODO: check signature + # sig = inspect.signature(method) + error = "methid '%s' doesn't override any base class method" % method.__name__ + assert( any( hasattr(cls, method.__name__) for cls in base_classes ) ), error + return method + + +def _get_bass_classes(stack): + base_classes = re.search(r'class.+\((.+)\)\s*\:', stack[2][4][0]).group(1) + + # handle multiple inheritance + base_classes = [s.strip() for s in base_classes.split(',')] + if not base_classes: + raise ValueError('override decorator: unable to determine base class') + + # stack[0]=override, stack[1]=inside class def'n, stack[2]=outside class def'n + derived_class_locals = stack[2][0].f_locals + + # replace each class name in base_classes with the actual class type + for i, base_class in enumerate(base_classes): + + if '.' not in base_class: + base_classes[i] = derived_class_locals[base_class] + else: + components = base_class.split('.') + + # obj is either a module or a class + obj = derived_class_locals[components[0]] + + for c in components[1:]: + assert(inspect.ismodule(obj) or inspect.isclass(obj)) + obj = getattr(obj, c) + + base_classes[i] = obj + + return base_classes