Skip to content

Commit

Permalink
Add override
Browse files Browse the repository at this point in the history
  • Loading branch information
phongchen committed Oct 3, 2023
1 parent 283d880 commit db12a14
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/blade/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ def _builtin_tools_prefix(self):
with open(builtin_tools_file, 'w') as f:
print('@echo off', file=f)
print('set "PYTHONPATH={};%PYTHONPATH%"'.format(self.blade_path), file=f)
print('{} %*'.format(python), file=f)
print('"{}" %*'.format(python), file=f)
self.__builtin_tools_prefix = builtin_tools_file
else:
# On posix system, a simply environment prefix is enough to do it.
Expand Down
49 changes: 27 additions & 22 deletions src/blade/toolchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from blade import config
from blade import console
from blade.util import var_to_list, iteritems, run_command
from blade.util import var_to_list, iteritems, override, run_command

# example: Cuda compilation tools, release 11.0, V11.0.194
_nvcc_version_re = re.compile(r'V(\d+\.\d+\.\d+)')
Expand Down Expand Up @@ -166,12 +166,23 @@ def object_file_of(self, source_file):
"""
raise NotImplementedError

def library_file_name(self, name):
def static_library_name(self, name):
"""
Get the static library file name from the name.
"""
raise NotImplementedError

def dynamic_library_name(self, name):
"""
Get the library file name from the name.
"""
raise NotImplementedError

def executable_file_name(self, name):
"""
Get the executable file name from the name.
"""
raise NotImplementedError

# To verify whether a header file is included without depends on the library it belongs to,
# we use the gcc's `-H` option to generate the inclusion stack information, see
Expand Down Expand Up @@ -227,16 +238,20 @@ def get_cc_target_arch():
return stdout.strip()
return ''

@override
def is_kind_of(self, vendor):
"""Is cc is used for C/C++ compilation match vendor."""
return vendor in ('gcc', 'clang', 'gcc')

@override
def object_file_of(self, source_file):
"""
Get the object file name from the source file.
"""
return source_file + '.o'

@override
def executable_file_name(self, name):
if os.name == 'nt':
return name + '.exe'
return name

def _cc_compile_command_wrapper_template(self, inclusion_stack_file, cuda=False):
"""Calculate the cc compile command wrapper template."""
print_header_option = '-H'
Expand Down Expand Up @@ -380,28 +395,17 @@ def get_cc_target_arch():
return stdout.strip()
return ''

def get_cc_commands(self):
return self.cc, self.cxx, self.ld

def get_cc(self):
return self.cc

def get_cc_version(self):
return self.cc_version

def get_ar(self):
return self.ar

@override
def is_kind_of(self, vendor):
"""Is cc is used for C/C++ compilation match vendor."""
return vendor in ('msvc')

@override
def object_file_of(self, source_file):
"""
Get the object file name from the source file.
"""
return source_file + '.obj'

@override
def executable_file_name(self, name):
return name + '.exe'

def filter_cc_flags(self, flag_list, language='c'):
"""Filter out the unrecognized compilation flags."""
Expand Down Expand Up @@ -494,6 +498,7 @@ def get_shared_link_command(self):
def _get_link_args(self):
return ' /nologo /OUT:${out} ${intrinsic_linkflags} ${linkflags} ${target_linkflags} @${out}.rsp ${extra_linkflags}'


def default(bits):
if os.name == 'nt':
return CcToolChainMsvc()
Expand Down
48 changes: 48 additions & 0 deletions src/blade/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import inspect
import json
import os
import re
import signal
import string
import subprocess
Expand Down Expand Up @@ -379,3 +380,50 @@ def which(cmd):
if returncode != 0:
return None
return stdout.strip()


def override(method):
"""
Check method override.
https://stackoverflow.com/a/14631397
"""

stack = inspect.stack()
base_classes = _get_bass_classes(stack)

# TODO: check signature
# sig = inspect.signature(method)
error = "methid '%s' doesn't override any base class method" % method.__name__
assert( any( hasattr(cls, method.__name__) for cls in base_classes ) ), error
return method


def _get_bass_classes(stack):
base_classes = re.search(r'class.+\((.+)\)\s*\:', stack[2][4][0]).group(1)

# handle multiple inheritance
base_classes = [s.strip() for s in base_classes.split(',')]
if not base_classes:
raise ValueError('override decorator: unable to determine base class')

# stack[0]=override, stack[1]=inside class def'n, stack[2]=outside class def'n
derived_class_locals = stack[2][0].f_locals

# replace each class name in base_classes with the actual class type
for i, base_class in enumerate(base_classes):

if '.' not in base_class:
base_classes[i] = derived_class_locals[base_class]
else:
components = base_class.split('.')

# obj is either a module or a class
obj = derived_class_locals[components[0]]

for c in components[1:]:
assert(inspect.ismodule(obj) or inspect.isclass(obj))
obj = getattr(obj, c)

base_classes[i] = obj

return base_classes

0 comments on commit db12a14

Please sign in to comment.