Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions mypyc/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,7 @@ def get_mypy_config(
mypyc_sources = all_sources

if compiler_options.separate:
mypyc_sources = [
src for src in mypyc_sources if src.path and not src.path.endswith("__init__.py")
]
mypyc_sources = [src for src in mypyc_sources if src.path]

if not mypyc_sources:
return mypyc_sources, all_sources, options
Expand Down
33 changes: 31 additions & 2 deletions mypyc/codegen/emitmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
)
from mypyc.codegen.literals import Literals
from mypyc.common import (
EXT_SUFFIX,
IS_FREE_THREADED,
MODULE_PREFIX,
PREFIX,
Expand Down Expand Up @@ -1286,11 +1287,39 @@ def emit_module_init_func(
f"if (unlikely({module_static} == NULL))",
" goto fail;",
)

emitter.emit_line(f'modname = PyUnicode_FromString("{module_name}");')
emitter.emit_line("if (modname == NULL) CPyError_OutOfMemory();")
if self.group_name:
shared_lib_mod_name = shared_lib_name(self.group_name)
emitter.emit_line("PyObject *mod_dict = PyImport_GetModuleDict();")
emitter.emit_line(
f'PyObject *shared_lib = PyDict_GetItemString(mod_dict, "{shared_lib_mod_name}");'
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PyDict_GetItemStringRef would be a bit more correct in free-threading builds.

)
emitter.emit_line("if (shared_lib == NULL) goto fail;")
emitter.emit_line(
'PyObject *shared_lib_file = PyObject_GetAttrString(shared_lib, "__file__");'
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use PyDict_GetItemStringRef instead, as this returns a borrowed reference, and we decref it below.

)
emitter.emit_line("if (shared_lib_file == NULL) goto fail;")
else:
emitter.emit_line(
f'PyObject *shared_lib_file = PyUnicode_FromString("{module_name + EXT_SUFFIX}");'
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No NULL check for shared_lib_file (it's only checked in the branch above).

)
emitter.emit_line(f'PyObject *ext_suffix = PyUnicode_FromString("{EXT_SUFFIX}");')
emitter.emit_line("if (ext_suffix == NULL) CPyError_OutOfMemory();")
is_pkg = int(self.source_paths[module_name].endswith("__init__.py"))
emitter.emit_line(f"Py_ssize_t is_pkg = {is_pkg};")

emitter.emit_line(
f"int rv = CPyImport_SetDunderAttrs({module_static}, modname, shared_lib_file, ext_suffix, is_pkg);"
)
emitter.emit_line("Py_DECREF(ext_suffix);")
emitter.emit_line("Py_DECREF(shared_lib_file);")
emitter.emit_line("if (rv < 0) goto fail;")

# Register in sys.modules early so that circular imports via
# CPyImport_ImportNative can detect that this module is already
# being initialized and avoid re-executing the module body.
emitter.emit_line(f'modname = PyUnicode_FromString("{module_name}");')
emitter.emit_line("if (modname == NULL) CPyError_OutOfMemory();")
emitter.emit_line(
f"if (PyObject_SetItem(PyImport_GetModuleDict(), modname, {module_static}) < 0)"
)
Expand Down
2 changes: 2 additions & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,8 @@ PyObject *CPyImport_ImportNative(PyObject *module_name,
CPyModule **module_static,
PyObject *shared_lib_file, PyObject *ext_suffix,
Py_ssize_t is_package);
int CPyImport_SetDunderAttrs(PyObject *module, PyObject *module_name, PyObject *shared_lib_file,
PyObject *ext_suffix, Py_ssize_t is_package);

PyObject *CPySingledispatch_RegisterFunction(PyObject *singledispatch_func, PyObject *cls,
PyObject *func);
Expand Down
109 changes: 67 additions & 42 deletions mypyc/lib-rt/misc_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -1225,6 +1225,47 @@ static int CPyImport_InitSpecClasses(void) {
return 0;
}

// Set __package__ before executing the module body so it is available
// during module initialization. For a package, __package__ is the module
// name itself. For a non-package submodule "a.b.c", it is "a.b". For a
// top-level non-package module, it is "".
static int CPyImport_SetModulePackage(PyObject *modobj, PyObject *module_name,
Py_ssize_t is_package) {
PyObject *pkg = NULL;
int rc = PyObject_GetOptionalAttrString(modobj, "__package__", &pkg);
if (rc < 0) {
return -1;
}
if (pkg != NULL && pkg != Py_None) {
Py_DECREF(pkg);
return 0;
}
Py_XDECREF(pkg);

PyObject *package_name = NULL;
if (is_package) {
package_name = module_name;
Py_INCREF(package_name);
} else {
Py_ssize_t name_len = PyUnicode_GetLength(module_name);
if (name_len < 0) {
return -1;
}
Py_ssize_t dot = PyUnicode_FindChar(module_name, '.', 0, name_len, -1);
if (dot >= 0) {
package_name = PyUnicode_Substring(module_name, 0, dot);
} else {
package_name = PyUnicode_FromString("");
}
}
if (package_name == NULL) {
return -1;
}
rc = PyObject_SetAttrString(modobj, "__package__", package_name);
Py_DECREF(package_name);
return rc;
}

// Derive and set __file__ on modobj from the shared library path, module name,
// and extension suffix. Returns 0 on success, -1 on error.
static int CPyImport_SetModuleFile(PyObject *modobj, PyObject *module_name,
Expand Down Expand Up @@ -1509,47 +1550,7 @@ PyObject *CPyImport_ImportNative(PyObject *module_name,
goto fail;
}

// Set __package__ before executing the module body so it is available
// during module initialization. For a package, __package__ is the module
// name itself. For a non-package submodule "a.b.c", it is "a.b". For a
// top-level non-package module, it is "".
{
PyObject *pkg = NULL;
if (PyObject_GetOptionalAttrString(modobj, "__package__", &pkg) < 0) {
goto fail;
}
if (pkg == NULL || pkg == Py_None) {
Py_XDECREF(pkg);
PyObject *package_name;
if (is_package) {
package_name = module_name;
Py_INCREF(package_name);
} else if (dot >= 0) {
package_name = PyUnicode_Substring(module_name, 0, dot);
} else {
package_name = PyUnicode_FromString("");
if (package_name == NULL) {
CPyError_OutOfMemory();
}
}
if (PyObject_SetAttrString(modobj, "__package__", package_name) < 0) {
Py_DECREF(package_name);
goto fail;
}
Py_DECREF(package_name);
} else {
Py_DECREF(pkg);
}
}

if (CPyImport_SetModuleFile(modobj, module_name, shared_lib_file, ext_suffix,
is_package) < 0) {
goto fail;
}
if (is_package && CPyImport_SetModulePath(modobj) < 0) {
goto fail;
}
if (CPyImport_SetModuleSpec(modobj, module_name, is_package) < 0) {
if (CPyImport_SetDunderAttrs(modobj, module_name, shared_lib_file, ext_suffix, is_package) < 0) {
goto fail;
}

Expand Down Expand Up @@ -1577,10 +1578,34 @@ PyObject *CPyImport_ImportNative(PyObject *module_name,
PyErr_Restore(exc_type, exc_val, exc_tb);
Py_XDECREF(parent_module);
Py_XDECREF(child_name);
Py_DECREF(modobj);
Py_CLEAR(*module_static);
return NULL;
}

int CPyImport_SetDunderAttrs(PyObject *module, PyObject *module_name, PyObject *shared_lib_file,
PyObject *ext_suffix, Py_ssize_t is_package)
{
int res = CPyImport_SetModulePackage(module, module_name, is_package);
if (res < 0) {
return res;
}

res = CPyImport_SetModuleFile(module, module_name, shared_lib_file, ext_suffix,
is_package);
if (res < 0) {
return res;
}

if (is_package) {
res = CPyImport_SetModulePath(module);
if (res < 0) {
return res;
}
}

return CPyImport_SetModuleSpec(module, module_name, is_package);
}

#if CPY_3_14_FEATURES

#include "internal/pycore_object.h"
Expand Down
32 changes: 32 additions & 0 deletions mypyc/test-data/run-multimodule.test
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,38 @@ globals()['A'] = None
[file driver.py]
import other_main

[case testNonNativeImportInPackageFile]
# The import is really non-native only in separate compilation mode where __init__.py and
# other_cache.py are in different libraries and the import uses the standard Python procedure.
# Python imports are resolved using __path__ and __spec__ from the package file so this checks
# that they are set up correctly.
[file other/__init__.py]
from other.other_cache import Cache

x = 1
[file other/other_cache.py]
class Cache:
pass

[file driver.py]
import other

[case testRelativeImportInPackageFile]
# Relative imports from a compiled package __init__ depend on package metadata being
# available while the package module body is executing.
[file other/__init__.py]
assert __package__ == "other"
from .other_cache import Cache

x = 1
[file other/other_cache.py]
class Cache:
pass

[file driver.py]
import other
assert other.Cache.__name__ == "Cache"

[case testMultiModuleSameNames]
# Use same names in both modules
import other
Expand Down
Loading