Skip to content

Commit

Permalink
fix: merging
Browse files Browse the repository at this point in the history
  • Loading branch information
agoose77 committed Nov 29, 2023
1 parent 573b64d commit 7df3645
Show file tree
Hide file tree
Showing 12 changed files with 42 additions and 45 deletions.
13 changes: 10 additions & 3 deletions src/awkward/_do/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,16 @@
import awkward as ak
from awkward._backends.backend import Backend
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._typing import Any, AxisMaybeNone, Literal
from awkward.contents.content import ActionType, Content
from awkward._typing import TYPE_CHECKING, Any, AxisMaybeNone, Literal
from awkward.errors import AxisError
from awkward.forms import form
from awkward.record import Record

np = NumpyMetadata.instance()

if TYPE_CHECKING:
from awkward.contents.content import ActionType, Content
from awkward.record import Record


def recursively_apply(
layout: Content | Record,
Expand All @@ -32,6 +34,9 @@ def recursively_apply(
function_name: str | None = None,
regular_to_jagged=False,
) -> Content | Record | None:
from awkward.contents.content import Content
from awkward.record import Record

if isinstance(layout, Content):
return layout._recursively_apply(
action,
Expand Down Expand Up @@ -201,6 +206,8 @@ def remove_structure(
allow_records: bool = False,
list_to_regular: bool = False,
):
from awkward.record import Record

if isinstance(layout, Record):
return remove_structure(
layout._array[layout._at : layout._at + 1],
Expand Down
22 changes: 6 additions & 16 deletions src/awkward/_do/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,11 @@

def is_option(
meta: Meta
) -> TypeGuard[
IndexedOptionMeta | BitMaskedMeta | ByteMaskedMeta | UnmaskedMeta
]:
) -> TypeGuard[IndexedOptionMeta | BitMaskedMeta | ByteMaskedMeta | UnmaskedMeta]:
return meta.is_option


def is_list(
meta: Meta
) -> TypeGuard[RegularMeta | ListOffsetMeta | ListMeta]:
def is_list(meta: Meta) -> TypeGuard[RegularMeta | ListOffsetMeta | ListMeta]:
return meta.is_list


Expand All @@ -59,19 +55,13 @@ def is_indexed(meta: Meta) -> TypeGuard[IndexedOptionMeta, IndexedMeta]:
return meta.is_indexed


class ImplementsTuple(RecordMeta): # Intersection
_fields: None


def is_record_tuple(meta: Meta) -> TypeGuard[ImplementsTuple]:
# FIXME: narrow this to have `is_tuple` be a const True
def is_record_tuple(meta: Meta) -> TypeGuard[RecordMeta]:
return meta.is_record and meta.is_tuple


class ImplementsRecord(RecordMeta):
_fields: list[str]


def is_record_record(meta: Meta) -> TypeGuard[ImplementsRecord]:
# FIXME: narrow this to have `is_tuple` be a const False
def is_record_record(meta: Meta) -> TypeGuard[RecordMeta]:
return meta.is_record and not meta.is_tuple


Expand Down
14 changes: 10 additions & 4 deletions src/awkward/_meta/numpymeta.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@

from awkward._do.meta import is_indexed, is_numpy, is_option
from awkward._meta.meta import Meta
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._nplikes.shape import ShapeItem
from awkward._parameters import type_parameters_equal
from awkward._typing import TYPE_CHECKING, JSONSerializable
from awkward._typing import TYPE_CHECKING, DType, JSONSerializable

np = NumpyMetadata.instance()
if TYPE_CHECKING:
from awkward._meta.regularmeta import RegularMeta

Expand All @@ -17,6 +19,10 @@ class NumpyMeta(Meta):
is_leaf = True
inner_shape: tuple[ShapeItem, ...]

@property
def dtype(self) -> DType:
raise NotImplementedError

def purelist_parameters(self, *keys: str) -> JSONSerializable:
if self._parameters is not None:
for key in keys:
Expand Down Expand Up @@ -104,9 +110,9 @@ def _mergeable_next(self, other: Meta, mergebool: bool) -> bool:

# Default merging (can we cast one to the other)
else:
return self.backend.nplike.can_cast(
self.dtype, other.dtype
) or self.backend.nplike.can_cast(other.dtype, self.dtype)
return np.can_cast(
self.dtype, other.dtype, casting="same_kind"
) or np.can_cast(other.dtype, self.dtype, casting="same_kind")

else:
return False
4 changes: 2 additions & 2 deletions src/awkward/_meta/recordmeta.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,10 @@ def _mergeable_next(self, other: Meta, mergebool: bool) -> bool:
return False

elif is_record_record(self) and is_record_record(other):
if set(self._fields) != set(other._fields):
if set(self._fields) != set(other._fields): # type: ignore[arg-type]
return False

for i, field in enumerate(self._fields):
for i, field in enumerate(self._fields): # type: ignore[arg-type]
x = self._contents[i]
y = other.contents[other.field_to_index(field)]
if not x._mergeable_next(y, mergebool):
Expand Down
5 changes: 0 additions & 5 deletions src/awkward/_nplikes/array_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,11 +653,6 @@ def astype(
assert not isinstance(x, PlaceholderArray)
return x.astype(dtype, copy=copy) # type: ignore[attr-defined]

def can_cast(
self, from_: DTypeLike | ArrayLikeT, to: DTypeLike | ArrayLikeT
) -> bool:
return self._module.can_cast(from_, to, casting="same_kind")

@classmethod
def is_own_array(cls, obj) -> bool:
return cls.is_own_array_type(type(obj))
7 changes: 2 additions & 5 deletions src/awkward/_nplikes/numpy_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ class NumpyMetadata(PublicSingleton):
datetime_data = staticmethod(numpy.datetime_data)
issubdtype = staticmethod(numpy.issubdtype)

AxisError = numpy.AxisError
AxisError = staticmethod(numpy.AxisError)
can_cast = staticmethod(numpy.can_cast)


if hasattr(numpy, "float16"):
Expand Down Expand Up @@ -537,10 +538,6 @@ def astype(
) -> ArrayLikeT:
...

@abstractmethod
def can_cast(self, from_: DType | ArrayLikeT, to: DType | ArrayLikeT) -> bool:
...

@abstractmethod
def is_c_contiguous(self, x: ArrayLikeT | PlaceholderArray) -> bool:
...
Expand Down
5 changes: 0 additions & 5 deletions src/awkward/_nplikes/typetracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1512,11 +1512,6 @@ def astype(
x.touch_data()
return TypeTracerArray._new(np.dtype(dtype), x.shape)

def can_cast(
self, from_: DTypeLike | TypeTracerArray, to: DTypeLike | TypeTracerArray
) -> bool:
return numpy.can_cast(from_, to, casting="same_kind")

@classmethod
def is_own_array_type(cls, type_: type) -> bool:
return issubclass(type_, TypeTracerArray)
Expand Down
6 changes: 6 additions & 0 deletions src/awkward/forms/numpyform.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,9 @@ def _expected_from_buffers(

def _to_regular_primitive(self) -> RegularForm | NumpyForm:
return self.to_RegularForm()

@property
def dtype(self) -> DType:
from awkward.types.numpytype import primitive_to_dtype

return primitive_to_dtype(self.primitive)
4 changes: 2 additions & 2 deletions src/awkward/operations/ak_strings_astype.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import awkward as ak
from awkward._dispatch import high_level_function
from awkward._do.content import recursively_apply
from awkward._do.content import pad_none, recursively_apply
from awkward._layout import HighLevelContext
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpy_like import NumpyMetadata
Expand Down Expand Up @@ -68,7 +68,7 @@ def action(layout, **kwargs):
layout, highlevel=False, behavior=behavior
)
max_length = ak.operations.max(ak.operations.num(layout, behavior=behavior))
regulararray = ak._do.pad_none(layout, max_length, 1)
regulararray = pad_none(layout, max_length, 1)
maskedarray = ak.operations.to_numpy(regulararray, allow_missing=True)
npstrings = maskedarray.data
if maskedarray.mask is not False:
Expand Down
3 changes: 2 additions & 1 deletion src/awkward/operations/ak_values_astype.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import awkward as ak
from awkward._dispatch import high_level_function
from awkward._do.content import numbers_to_type
from awkward._layout import HighLevelContext
from awkward._nplikes.numpy_like import NumpyMetadata

Expand Down Expand Up @@ -73,5 +74,5 @@ def _impl(array, to, including_unknown, highlevel, behavior, attrs):
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
to_str = ak.types.numpytype.dtype_to_primitive(np.dtype(to))
out = ak._do.numbers_to_type(layout, to_str, including_unknown)
out = numbers_to_type(layout, to_str, including_unknown)
return ctx.wrap(out, highlevel=highlevel)
2 changes: 1 addition & 1 deletion src/awkward/operations/str/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@


def _drop_option_preserving_form(layout, ensure_empty_mask: bool = False):
from awkward._do import recursively_apply
from awkward._do.content import recursively_apply
from awkward.contents import UnmaskedArray, IndexedOptionArray, IndexedArray

def action(_, continuation, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/typetracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import awkward.forms
from awkward._backends.typetracer import TypeTracerBackend
from awkward._do import touch_data as _touch_data
from awkward._do.content import touch_data as _touch_data
from awkward._layout import HighLevelContext, wrap_layout
from awkward._nplikes.numpy import NumpyMetadata
from awkward._nplikes.placeholder import PlaceholderArray
Expand Down

0 comments on commit 7df3645

Please sign in to comment.