Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Remove StringRef from python binding functions #3775

Open
wants to merge 14 commits into
base: nightly
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions stdlib/src/collections/string/string_slice.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -358,10 +358,13 @@ struct StringSlice[mut: Bool, //, origin: Origin[mut]](
value: The string value.
"""

debug_assert(
_is_valid_utf8(value.as_bytes()), "value is not valid utf8"
# FIXME(#3706): problems at comp time
# debug_assert(
# _is_valid_utf8(value.as_bytes()), "value is not valid utf8"
# )
self = StringSlice[O](
ptr=value.unsafe_ptr(), length=value.byte_length()
)
self = StringSlice[O](unsafe_from_utf8=value.as_bytes())

# ===-------------------------------------------------------------------===#
# Factory methods
Expand Down
113 changes: 57 additions & 56 deletions stdlib/src/python/_cpython.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ struct PythonVersion:
"""The patch version number."""

@implicit
fn __init__(out self, version: StringRef):
fn __init__(out self, version: StringSlice[mut=False]):
"""Initialize a PythonVersion object from a version string.

Args:
Expand Down Expand Up @@ -287,8 +287,11 @@ struct PythonVersion:
self = PythonVersion(components[0], components[1], components[2])


fn _py_get_version(lib: DLHandle) -> StringRef:
return StringRef(ptr=lib.call["Py_GetVersion", UnsafePointer[c_char]]())
fn _py_get_version(lib: DLHandle) -> StringSlice[ImmutableAnyOrigin]:
var ptr = lib.call["Py_GetVersion", UnsafePointer[c_char]]()
return StringSlice[ImmutableAnyOrigin](
unsafe_from_utf8_strref=StringRef(ptr=ptr)
)
martinvuyk marked this conversation as resolved.
Show resolved Hide resolved


fn _py_finalize(lib: DLHandle):
Expand Down Expand Up @@ -741,7 +744,7 @@ struct CPython:
"""The version of the Python runtime."""
var total_ref_count: UnsafePointer[Int]
"""The total reference count of all Python objects."""
var init_error: StringRef
var init_error: StringSlice[ImmutableAnyOrigin]
"""An error message if initialization failed."""

# ===-------------------------------------------------------------------===#
Expand Down Expand Up @@ -771,9 +774,12 @@ struct CPython:

# TODO(MOCO-772) Allow raises to propagate through function pointers
# and make this initialization a raising function.
self.init_error = external_call[
var ptr = external_call[
"KGEN_CompilerRT_Python_SetPythonPath", UnsafePointer[c_char]
]()
self.init_error = StringSlice[ImmutableAnyOrigin](
unsafe_from_utf8_strref=StringRef(ptr=ptr)
)
martinvuyk marked this conversation as resolved.
Show resolved Hide resolved

var python_lib = getenv("MOJO_PYTHON_LIBRARY")

Expand All @@ -787,7 +793,9 @@ struct CPython:
self.logging_enabled = logging_enabled
if not self.init_error:
if not self.lib.check_symbol("Py_Initialize"):
self.init_error = "compatible Python library not found"
self.init_error = rebind[__type_of(self.init_error)](
StringSlice("compatible Python library not found")
)
self.lib.call["Py_Initialize"]()
self.version = PythonVersion(_py_get_version(self.lib))
else:
Expand Down Expand Up @@ -1061,14 +1069,15 @@ struct CPython:
# ===-------------------------------------------------------------------===#

fn PyImport_ImportModule(
mut self,
name: StringRef,
mut self, name: StringSlice[mut=False]
) -> PyObjectPtr:
"""[Reference](
https://docs.python.org/3/c-api/import.html#c.PyImport_ImportModule).
"""

var r = self.lib.call["PyImport_ImportModule", PyObjectPtr](name.data)
var r = self.lib.call["PyImport_ImportModule", PyObjectPtr](
name.unsafe_ptr().bitcast[c_char]()
)

self.log(
r._get_ptr_as_int(),
Expand All @@ -1081,7 +1090,9 @@ struct CPython:
self._inc_total_rc()
return r

fn PyImport_AddModule(mut self, name: StringRef) -> PyObjectPtr:
fn PyImport_AddModule(
mut self, name: StringSlice[mut=False]
) -> PyObjectPtr:
"""[Reference](
https://docs.python.org/3/c-api/import.html#c.PyImport_AddModule).
"""
Expand Down Expand Up @@ -1188,7 +1199,7 @@ struct CPython:
# Python Evaluation
# ===-------------------------------------------------------------------===#

fn PyRun_SimpleString(mut self, strref: StringRef) -> Bool:
fn PyRun_SimpleString(mut self, strref: StringSlice[mut=False]) -> Bool:
"""Executes the given Python code.

Args:
Expand All @@ -1203,12 +1214,15 @@ struct CPython:
https://docs.python.org/3/c-api/veryhigh.html#c.PyRun_SimpleString).
"""
return (
self.lib.call["PyRun_SimpleString", c_int](strref.unsafe_ptr()) == 0
self.lib.call["PyRun_SimpleString", c_int](
strref.unsafe_ptr().bitcast[c_char]()
)
== 0
)

fn PyRun_String(
mut self,
strref: StringRef,
strref: StringSlice[mut=False],
globals: PyObjectPtr,
locals: PyObjectPtr,
run_mode: Int,
Expand All @@ -1217,7 +1231,10 @@ struct CPython:
https://docs.python.org/3/c-api/veryhigh.html#c.PyRun_String).
"""
var result = self.lib.call["PyRun_String", PyObjectPtr](
strref.unsafe_ptr(), Int32(run_mode), globals, locals
strref.unsafe_ptr().bitcast[c_char](),
Int32(run_mode),
globals,
locals,
)

self.log(
Expand Down Expand Up @@ -1256,16 +1273,18 @@ struct CPython:

fn Py_CompileString(
mut self,
strref: StringRef,
filename: StringRef,
strref: StringSlice[mut=False],
filename: StringSlice[mut=False],
compile_mode: Int,
) -> PyObjectPtr:
"""[Reference](
https://docs.python.org/3/c-api/veryhigh.html#c.Py_CompileString).
"""

var r = self.lib.call["Py_CompileString", PyObjectPtr](
strref.unsafe_ptr(), filename.unsafe_ptr(), Int32(compile_mode)
strref.unsafe_ptr().bitcast[c_char](),
filename.unsafe_ptr().bitcast[c_char](),
Int32(compile_mode),
)
self._inc_total_rc()
return r
Expand Down Expand Up @@ -1361,16 +1380,14 @@ struct CPython:
return r

fn PyObject_GetAttrString(
mut self,
obj: PyObjectPtr,
name: StringRef,
mut self, obj: PyObjectPtr, name: StringSlice[mut=False]
) -> PyObjectPtr:
"""[Reference](
https://docs.python.org/3/c-api/object.html#c.PyObject_GetAttrString).
"""

var r = self.lib.call["PyObject_GetAttrString", PyObjectPtr](
obj, name.data
obj, name.unsafe_ptr().bitcast[c_char]()
)

self.log(
Expand All @@ -1387,14 +1404,17 @@ struct CPython:
return r

fn PyObject_SetAttrString(
mut self, obj: PyObjectPtr, name: StringRef, new_value: PyObjectPtr
mut self,
obj: PyObjectPtr,
name: StringSlice[mut=False],
new_value: PyObjectPtr,
) -> c_int:
"""[Reference](
https://docs.python.org/3/c-api/object.html#c.PyObject_SetAttrString).
"""

var r = self.lib.call["PyObject_SetAttrString", c_int](
obj, name.data, new_value
obj, name.unsafe_ptr().bitcast[c_char](), new_value
)

self.log(
Expand All @@ -1411,9 +1431,7 @@ struct CPython:
return r

fn PyObject_CallObject(
mut self,
callable_obj: PyObjectPtr,
args: PyObjectPtr,
mut self, callable_obj: PyObjectPtr, args: PyObjectPtr
) -> PyObjectPtr:
"""[Reference](
https://docs.python.org/3/c-api/call.html#c.PyObject_CallObject).
Expand Down Expand Up @@ -1711,15 +1729,17 @@ struct CPython:
# Unicode Objects
# ===-------------------------------------------------------------------===#

fn PyUnicode_DecodeUTF8(mut self, strref: StringRef) -> PyObjectPtr:
fn PyUnicode_DecodeUTF8(
mut self, strref: StringSlice[mut=False]
) -> PyObjectPtr:
"""[Reference](
https://docs.python.org/3/c-api/unicode.html#c.PyUnicode_DecodeUTF8).
"""

var r = self.lib.call["PyUnicode_DecodeUTF8", PyObjectPtr](
strref.unsafe_ptr().bitcast[Int8](),
strref.length,
"strict".unsafe_cstr_ptr(),
strref.unsafe_ptr().bitcast[c_char](),
strref.byte_length(),
"strict".unsafe_ptr().bitcast[c_char](),
)

self.log(
Expand All @@ -1733,27 +1753,6 @@ struct CPython:
self._inc_total_rc()
return r

fn PyUnicode_DecodeUTF8(mut self, strslice: StringSlice) -> PyObjectPtr:
"""[Reference](
https://docs.python.org/3/c-api/unicode.html#c.PyUnicode_DecodeUTF8).
"""
var r = self.lib.call["PyUnicode_DecodeUTF8", PyObjectPtr](
strslice.unsafe_ptr().bitcast[Int8](),
strslice.byte_length(),
"strict".unsafe_cstr_ptr(),
)

self.log(
r._get_ptr_as_int(),
" NEWREF PyUnicode_DecodeUTF8, refcnt:",
self._Py_REFCNT(r),
", str:",
strslice,
)

self._inc_total_rc()
return r

fn PySlice_FromSlice(mut self, slice: Slice) -> PyObjectPtr:
# Convert Mojo Slice to Python slice parameters
# Note: Deliberately avoid using `span.indices()` here and instead pass
Expand Down Expand Up @@ -1781,16 +1780,18 @@ struct CPython:

return py_slice

fn PyUnicode_AsUTF8AndSize(mut self, py_object: PyObjectPtr) -> StringRef:
fn PyUnicode_AsUTF8AndSize(
mut self, py_object: PyObjectPtr
) -> StringSlice[MutableAnyOrigin]:
"""[Reference](
https://docs.python.org/3/c-api/unicode.html#c.PyUnicode_AsUTF8AndSize).
"""

var s = StringRef()
s.data = self.lib.call[
var length = 0
var ptr = self.lib.call[
"PyUnicode_AsUTF8AndSize", UnsafePointer[c_char]
](py_object, UnsafePointer.address_of(s.length)).bitcast[UInt8]()
return s
](py_object, UnsafePointer.address_of(length)).bitcast[Byte]()
return StringSlice[MutableAnyOrigin](ptr=ptr, length=length)

# ===-------------------------------------------------------------------===#
# Python Error operations
Expand Down
18 changes: 12 additions & 6 deletions stdlib/src/python/python.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ from python import Python
```
"""

from collections import Dict
from collections import Dict, Optional
from os import abort, getenv
from sys import external_call, sizeof
from sys.ffi import _Global

from memory import UnsafePointer

from utils import StringRef
from utils import StringSlice, StaticString

from ._cpython import (
CPython,
Expand Down Expand Up @@ -99,7 +99,7 @@ struct Python:
"""
self.impl = existing.impl

fn eval(mut self, code: StringRef) -> Bool:
fn eval(mut self, code: StringSlice[mut=False]) -> Bool:
"""Executes the given Python code.

Args:
Expand All @@ -114,7 +114,9 @@ struct Python:

@staticmethod
fn evaluate(
expr: StringRef, file: Bool = False, name: StringRef = "__main__"
expr: StringSlice[mut=False],
file: Bool = False,
name: StaticString = "__main__",
) raises -> PythonObject:
"""Executes the given Python code.

Expand Down Expand Up @@ -202,7 +204,9 @@ struct Python:

# TODO(MSTDL-880): Change this to return `TypedPythonObject["Module"]`
@staticmethod
fn import_module(module: StringRef) raises -> PythonObject:
fn import_module(
module: StringSlice[mut=False],
) raises -> PythonObject:
"""Imports a Python module.

This provides you with a module object you can use just like you would
Expand Down Expand Up @@ -366,7 +370,9 @@ struct Python:
return PythonObject([])

@no_inline
fn __str__(mut self, str_obj: PythonObject) -> StringRef:
fn __str__(
mut self, str_obj: PythonObject
) -> StringSlice[MutableAnyOrigin]:
"""Return a string representing the given Python object.

Args:
Expand Down
Loading
Loading