
Support for variable referencing #49

Open
alexanderswerdlow opened this issue May 4, 2023 · 12 comments

Comments

@alexanderswerdlow

I've been looking for an alternative to Hydra for config management, specifically one that allows for defining configs in Python, and I stumbled across Tyro, which, after some experimentation, seems like a great library for my use case.

However, one thing that doesn't appear to be possible is referencing a single variable from multiple places in a nested config. As for why this might be needed, it is very common in an ML codebase to require the same parameter in many different places. For example, the number of classification classes might be used in the model construction, visualization, etc.

We might want this value to be dependent on a config group such as the dataset (i.e., each dataset might have a different number of classes). Instead of manually defining each combination of model + dataset, it would be a lot easier to have the model parameters simply reference the dataset parameter, or have them both reference some top-level variable. Hydra supports this via value interpolation.

Since we can define Tyro configs directly in Python, it seems like this could be made much more powerful with support for arbitrary expressions, allowing small pieces of logic to be defined in a configuration (e.g., for a specific top-level config, we can have a model parameter be 4 * num_classes). Clearly, we could simply make the 4 a new parameter, but there are good reasons we might want it in the config instead.

From what I can tell, this type of variable referencing, even without any expressions, is not currently possible with Tyro.

@brentyi
Owner

brentyi commented May 4, 2023

Thanks @alexanderswerdlow!

Just to concretize things a bit: could you suggest some syntax for what you're describing here?

My weak prior on this is that these sort of "driven" config parameters result in unnecessary complexity, and there are often simple workarounds like defining an @property that computes and returns 4 * num_classes or conditioning your model instantiation on the dataset config. But with a concrete example I could be convinced.
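A minimal sketch of the @property workaround, assuming a hypothetical ModelConfig with a num_classes field and a derived hidden_dim (the name is illustrative only):

```python
import dataclasses


@dataclasses.dataclass
class ModelConfig:
    num_classes: int

    @property
    def hidden_dim(self) -> int:
        # Computed on access rather than stored, so it can never drift out of
        # sync with num_classes. Properties are not dataclass fields.
        return 4 * self.num_classes


cfg = ModelConfig(num_classes=10)
print(cfg.hidden_dim)  # 40
```

Because hidden_dim is not a dataclass field, a dataclass-driven CLI would not be expected to expose it as a flag; only num_classes stays user-settable.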

@alexanderswerdlow
Author

Of course! I think I should clarify that my initial comment mentions two related but distinct features. The first is just plain variable referencing.

For example, I've recently been working with an architecture that can turn some input image into a set of latent vectors or "slots." This is a hyperparameter (num_slots) to the model (e.g., nn.Module) but I also need this parameter to be passed to some separate visualization code that visualizes the effect of each slot. This hyperparameter is dataset dependent, so not only would it be silly to have to define this multiple times (for each module that uses it), but it makes configuring an experiment more difficult. Now, instead of having model configurations and dataset configurations that I can mix-and-match, I have to create each combination manually.

Conditioning the model init on the dataset config does work, although there are a couple of issues with that:

  1. It's cumbersome if you instantiate most or all of your objects directly from your config. You could get around this by passing a config object to a function later on and initializing there, but then you're defeating the purpose of instantiating from config.

Hydra's interpolation works for objects as well as primitives, so in the example below, num_slots could be an entire object (e.g., a dataset object/config) that gets passed to the model. A lot of my code currently relies on this, with a single config dataclass that gets passed around.

  2. Even with this strategy, there are limitations. Namely, it encourages consolidating to a single large config object. If two places in code reference a single variable that you later want to separate, you now have to refactor this. With variable referencing, you just replace the reference with a different default value.

Furthermore, if you do go with the alternative and pass several individual objects, it becomes difficult when you have a messy dependency graph. The model/viz is often dependent on the dataset, the viz is dependent on the model, and some parts of the model are dependent on others. Keeping this modular for experimentation requires referencing each other.

As for syntax, I'd likely need to spend more time thinking about it but my current hydra config looks something like this, using the interpolation syntax:

num_slots: 10
dataset:
    num_slots: ${num_slots}
model:
    encoder_stage:
        output_dim: 256
    decoder_stage:
        input_dim: ${model.encoder_stage.output_dim}
        num_slots: ${num_slots}
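For comparison, here is a minimal sketch of the same sharing in plain Python. Binding each shared value to a function parameter plays the role of ${num_slots}. The dataclass and function names (EncoderStage, make_config, etc.) are hypothetical and were chosen to mirror the YAML above.

```python
import dataclasses


@dataclasses.dataclass
class EncoderStage:
    output_dim: int


@dataclasses.dataclass
class DecoderStage:
    input_dim: int
    num_slots: int


@dataclasses.dataclass
class ModelConfig:
    encoder_stage: EncoderStage
    decoder_stage: DecoderStage


@dataclasses.dataclass
class DatasetConfig:
    num_slots: int


@dataclasses.dataclass
class Config:
    num_slots: int
    dataset: DatasetConfig
    model: ModelConfig


def make_config(num_slots: int = 10, encoder_dim: int = 256) -> Config:
    # Each shared value is written once and passed to every place that uses it,
    # like ${num_slots} and ${model.encoder_stage.output_dim} in the YAML.
    return Config(
        num_slots=num_slots,
        dataset=DatasetConfig(num_slots=num_slots),
        model=ModelConfig(
            encoder_stage=EncoderStage(output_dim=encoder_dim),
            decoder_stage=DecoderStage(input_dim=encoder_dim, num_slots=num_slots),
        ),
    )
```

The values are copied at construction time. Changing one nested field afterwards (e.g., from the CLI) does not propagate to the others, and that is the gap interpolation fills.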

Here is an example in Tyro (not one-to-one with the example above, for conciseness):

@dataclasses.dataclass
class DatasetConfig:
    dataset: Dataset
    num_slots: int
    
@dataclasses.dataclass
class ExperimentConfig:
    model: nn.Module
    dataset: DatasetConfig

main_config = tyro.extras.subcommand_type_from_defaults({
    "small": ExperimentConfig(
            dataset=ExampleDatasetConfig,
            model=ExampleModelConfig,
        ),
})

ExampleDatasetConfig = Annotated[
    DatasetConfig,
    tyro.conf.subcommand(
        name="dataset_b",
        default=DatasetConfig(
            num_slots=8,
            dataset=DatasetB(),
        ),
    ),
]

ExampleModelConfig = Annotated[
    ClassificationModule, # name of an nn.Module
    tyro.conf.subcommand(
        name="dataset_a",
        default=ClassificationModule(
            num_slots=ExperimentConfig.dataset.num_slots, # Somehow reference the encapsulating container
        ),
    ),
]

Now obviously this wouldn't work exactly as-is because there is a circular dependency in definitions here. In my Tyro example above, I go up (so to speak) to experiment config and then back down to dataset, but a single overarching namespace could be simpler to implement (e.g., referencing only works for a set of pre-defined keys, not from any arbitrary container).

The second feature, much smaller in both impact and difficulty, is the ability to evaluate expressions over referenced variables. You are absolutely right that declaring an @property could work, and if Tyro supports passing functions (e.g., a lambda defined in code), you could achieve the same thing.

However, say you have an input image that is downscaled by n (e.g., n=4), and a separate module (e.g., visualization code) needs to know that downscaled size during initialization. In this case, it'd be a lot cleaner to pass image_size / n as the input rather than passing both values into the visualization code. The desire for these sorts of expressions comes up naturally in a lot of data pipelines.
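Here is a sketch of the image_size / n case as a derived field, assuming a hypothetical PipelineConfig and build_visualizer. dataclasses.field(init=False) keeps the derived value out of the constructor, and __post_init__ computes it. Whether a particular CLI library hides init=False fields is worth checking in its docs.

```python
import dataclasses


@dataclasses.dataclass
class PipelineConfig:
    image_size: int = 224
    downscale: int = 4
    # Derived value: init=False keeps it out of __init__, so a caller cannot
    # set it inconsistently with image_size and downscale.
    feature_size: int = dataclasses.field(init=False)

    def __post_init__(self) -> None:
        # Integer division, since sizes are whole pixels.
        self.feature_size = self.image_size // self.downscale


def build_visualizer(feature_size: int) -> str:
    # Hypothetical stand-in for visualization code that only needs the
    # downscaled size, not image_size and downscale separately.
    return f"Visualizer(size={feature_size})"


cfg = PipelineConfig(image_size=256)
print(build_visualizer(cfg.feature_size))  # Visualizer(size=64)
```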

Hope that makes sense and I'm happy to explain further! Also totally understand if this is out of scope.

@brentyi
Owner

brentyi commented May 5, 2023

Yes, that makes sense!

For variable references, I'm curious about your thoughts on a few options.

One is adapting __post_init__:

import dataclasses

import tyro


@dataclasses.dataclass
class ModelConfig:
    num_slots: int = -1


@dataclasses.dataclass
class TrainConfig:
    num_slots: int
    model: ModelConfig

    def __post_init__(self) -> None:
        if self.model.num_slots == -1:
            self.model.num_slots = self.num_slots


print(tyro.cli(TrainConfig))

In this case --num-slots 3 would set num_slots for both the parent TrainConfig and the inner ModelConfig.

Of course you can add more complex logic in your __post_init__, so this might fulfill your second feature request too.

A potential downside of this is that you won't be able to use an nn.Module directly in your config object as your snippet hints at; config objects need to be mutated after they're instantiated. IMO this is an OK tradeoff since directly dropping in the module has its own drawbacks, like difficulty of serialization.
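Here is a variant of the __post_init__ approach, sketched under two assumptions. First, None (rather than -1) marks "inherit from the parent", so no legitimate value is reserved as a sentinel. Second, the nested config uses default_factory, because __post_init__ mutates it and each TrainConfig should own its copy. How a CLI renders Optional[int] is library-specific, so this only demonstrates plain-dataclass behavior.

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass
class ModelConfig:
    # None means "not set here; inherit from the parent config".
    num_slots: Optional[int] = None


@dataclasses.dataclass
class TrainConfig:
    num_slots: int
    # default_factory gives each TrainConfig its own ModelConfig instance.
    model: ModelConfig = dataclasses.field(default_factory=ModelConfig)

    def __post_init__(self) -> None:
        if self.model.num_slots is None:
            # Any expression works here, e.g. 2 * self.num_slots.
            self.model.num_slots = self.num_slots


print(TrainConfig(num_slots=3))  # TrainConfig(num_slots=3, model=ModelConfig(num_slots=3))
```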

An alternative that might circumvent this downside (it's a bit hacky and I wouldn't recommend it, but it should work and is in the tests for the 0.5.0 release) is to map both the model's num_slots and the train config's num_slots to --num-slots. To do this, we can just omit the prefix from --model.num-slots:

import dataclasses

from typing import Annotated
import tyro


@dataclasses.dataclass
class ModelConfig:
    num_slots: Annotated[int, tyro.conf.arg(prefix_name=False)]


@dataclasses.dataclass
class TrainConfig:
    num_slots: int
    model: ModelConfig

    # edit: next few lines were unintentionally included
    # def __post_init__(self) -> None:
    #     if self.model.num_slots == -1:
    #         self.model.num_slots = self.num_slots


print(tyro.cli(TrainConfig))

Again, --num-slots 3 would set num_slots for both the parent TrainConfig and the inner ModelConfig.

@alexanderswerdlow
Author

Sorry for the delay, and clever idea!

Before I go on, I assume you didn't intend to include the __post_init__ in the 2nd example. I ran it myself and it works with just the annotation, which makes sense.

Some thoughts:

  1. The two options seem to trade off customizability for convenience.

The first option allows arbitrary configuration with expressions but the syntax is a little unwieldy. The second option on the other hand (from what I can tell) essentially gives you a single global namespace (simply without a prefix) to perform referencing.

Most use cases are probably fine with a global namespace, but I think a core issue remains.

  2. The bigger issue I see here (for my use case at least) is that either approach couples the configuration interface with the config itself, making hierarchical and modular configuration difficult.

In other words, say I have an MLP class (dataclass config or actual class); I might want different experiments to use that same MLP in different ways (likely multiple times within the same experiment). That rules out the 2nd approach, but even the 1st approach is difficult. From what I can tell, the user would need to make two distinct higher-level configs (to allow for a different __post_init__).

Now I certainly see that this might not be an issue for many and this approach might make a lot of sense for them! I happen to need things particularly modular for experimentation, which is also why I gravitate towards instantiating things directly. Doing so removes an intermediate step that needs to be constantly updated.

@brentyi
Owner

brentyi commented May 9, 2023

Thanks for clarifying! Two followup questions:

The bigger issue I see here (for my use case at least) is that either approach couples the configuration interface with the config itself, making hierarchical and modular configuration difficult.

(1) So to re-state: it would be nice to be able to define a config schema via dataclasses, and then define relationships between values in it when you instantiate configs?

I happen to need things particularly modular for experimentation, which is also why I gravitate towards instantiating things directly. Doing so removes an intermediate step that needs to be constantly updated.

(2) I'm not totally following what "instantiating things directly" is referring to. Is this referencing the __post_init__() as an intermediate step?


To try and resolve (1), what about creating some subcommands? When you instantiate each subcommand the default values for each field can be computed from whatever logic you want.

import dataclasses
from typing import Dict

import tyro


@dataclasses.dataclass
class ModelConfig:
    num_slots: int


@dataclasses.dataclass
class TrainConfig:
    num_slots: int
    model: ModelConfig


subcommands: Dict[str, TrainConfig] = {}

# First experiment.
subcommands["exp1"] = TrainConfig(
    num_slots=2,
    model=ModelConfig(num_slots=2),
)

# Second experiment.
num_slots = 4
subcommands["exp2"] = TrainConfig(
    num_slots=num_slots,
    model=ModelConfig(num_slots=num_slots * 2),
)

config = tyro.cli(
    tyro.extras.subcommand_type_from_defaults(subcommands)
)
print(config)

Of course, since everything is Python, you can also generate this dictionary programmatically. Perhaps the downside here is that python example.py exp2 --num-slots N now won't also update model.num_slots? Is that a dealbreaker?

In general, I think there's still a disconnect: I don't fully follow what limitation makes modularity/hierarchy harder than in Hydra. When I read the specializing configs docs in Hydra, nothing stands out to me: both the ${dataset}_${model} pattern and the CIFAR vs. ImageNet num_layers default seem easy enough to replicate in Python. If you have any links to examples in the wild (your own or from others) that you'd want to replicate, it might be helpful for my understanding.
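Here is a sketch of generating the subcommand dictionary programmatically for the ${dataset}_${model} pattern. It assumes hypothetical dataset names (clevr, coco) and model sizes, and each model config inherits num_slots from the dataset it is paired with.

```python
import dataclasses
import itertools
from typing import Dict


@dataclasses.dataclass
class DatasetConfig:
    name: str
    num_slots: int


@dataclasses.dataclass
class ModelConfig:
    hidden_dim: int
    num_slots: int


@dataclasses.dataclass
class TrainConfig:
    dataset: DatasetConfig
    model: ModelConfig


# Hypothetical settings: dataset name -> num_slots, model size -> hidden_dim.
DATASETS = {"clevr": 7, "coco": 12}
MODELS = {"small": 64, "large": 256}


def make_subcommands() -> Dict[str, TrainConfig]:
    subcommands: Dict[str, TrainConfig] = {}
    for (ds_name, num_slots), (model_name, hidden_dim) in itertools.product(
        DATASETS.items(), MODELS.items()
    ):
        subcommands[f"{ds_name}_{model_name}"] = TrainConfig(
            dataset=DatasetConfig(name=ds_name, num_slots=num_slots),
            # The model "references" the dataset's num_slots.
            model=ModelConfig(hidden_dim=hidden_dim, num_slots=num_slots),
        )
    return subcommands
```

The resulting dictionary could then be handed to tyro.extras.subcommand_type_from_defaults, as in the exp1/exp2 snippet above.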

@brentyi
Owner

brentyi commented May 9, 2023

As an FYI, I'm also going to raise an error in this case:

import dataclasses

from typing import Annotated
import tyro


@dataclasses.dataclass
class ModelConfig:
    num_slots: Annotated[int, tyro.conf.arg(prefix_name=False)]


@dataclasses.dataclass
class TrainConfig:
    num_slots: int
    model: ModelConfig

(just feels too hacky)

@mirceamironenco
Contributor

mirceamironenco commented Aug 17, 2024

Hi! Just wanted to ask if there is a canonical/recommended way of doing this. My use case is quite simple e.g.:

from dataclasses import dataclass


@dataclass
class MambaCfg:
    embed_dim: int

    # Other params
    y: float
    z: str


@dataclass
class AttentionCfg:
    embed_dim: int

    # Other params
    g: float
    f: str


@dataclass
class BlockCfg:
    embed_dim: int

    # e.g. Union of types with 'embed_dim' attribute
    # can we set it automatically from the blockcfg embed_dim?
    layer: AttentionCfg | MambaCfg

In this case, tyro.cli(BlockCfg) would require me to specify embed_dim twice.
I was wondering if there is a recommended/quick way of constructing the dependency without a full-fledged interpolation mechanism as in OmegaConf, etc. A simple (if hacky) solution would of course be to customize the constructor and build some partial type, e.g.:

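from dataclasses import asdict, fields, make_dataclass
from typing import Annotated, Any, Type, TypeAlias, TypeVar, Union

import tyro

T = TypeVar("T")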

# Just an example; would need to be updated to preserve helptext, etc.
def create_partial_type(cls: Type[T], committed_param: str) -> Type:
    class_fields = [
        (f.name, f.type, f) for f in fields(cls) if f.name != committed_param
    ]

    def __call__(self, committed_value: Any) -> T:
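        # Shallow copy: asdict() would recursively turn nested dataclasses into dicts.
        all_args = {**{f.name: getattr(self, f.name) for f in fields(self)}, committed_param: committed_value}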
        return cls(**all_args)

    partial_cls = make_dataclass(
        f"Partial{cls.__name__}",
        fields=class_fields,
        namespace={"__call__": __call__},
    )

    return partial_cls
    
# for the previous use-case (unpacking inside Union[...] requires Python 3.11+):
PartialLayerCfg: TypeAlias = Union[
    *(
        Annotated[obj, tyro.conf.arg(constructor=create_partial_type(obj, "embed_dim"))]
        for obj in (AttentionCfg, MambaCfg)
    )
]

class BlockCfg:
    def __init__(self, embed_dim: int, layer: PartialLayerCfg) -> None:
        self.embed_dim = embed_dim
        self.layer = layer(embed_dim)

edit: Another approach I was considering was making the outermost callable/type's namespace accessible, so that one could specify something like --embed_dim INT layer:attention-cfg --layer.embed_dim=embed_dim (for the initial example), possibly with some way to specify a default in a similar fashion.

@brentyi
Owner

brentyi commented Aug 18, 2024

Hi @mirceamironenco! Unfortunately I don't have a tyro-specific recommendation. I've thought about APIs in the direction of variable interpolation a few times but haven't come up with anything I'm happy with.

Usually in these situations I just think about how I would structure things if I were building a pure Python API, for example asking a downstream user to instantiate these config objects in a Jupyter notebook, and then a tyro solution falls out of that. This basically reduces to one of:

  1. Implement the code in a way that removes the dependency / redundancy
  2. If too difficult, live with the redundancy and add some assert statements
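For the second option, here is a minimal sketch of living with the redundancy while catching mismatches. It uses the MambaCfg/AttentionCfg/BlockCfg names from the question, and the y and g defaults are placeholders. It raises explicitly instead of using a bare assert, because asserts are stripped under python -O.

```python
import dataclasses
from typing import Union


@dataclasses.dataclass
class MambaCfg:
    embed_dim: int
    y: float = 0.0


@dataclasses.dataclass
class AttentionCfg:
    embed_dim: int
    g: float = 0.0


@dataclasses.dataclass
class BlockCfg:
    embed_dim: int
    layer: Union[AttentionCfg, MambaCfg]

    def __post_init__(self) -> None:
        # Keep the redundancy, but fail fast if the two copies disagree.
        if self.layer.embed_dim != self.embed_dim:
            raise ValueError(
                f"layer.embed_dim={self.layer.embed_dim} does not match "
                f"embed_dim={self.embed_dim}"
            )
```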

For the first option, would it be possible to remove embed_dim from MambaCfg and AttentionCfg? Instead, you could either pass (embed_dim, mamba_cfg) as arguments to your Mamba constructor, or (embed_dim, attention_cfg) as arguments to your Attention constructor. Or, depending on your project goals, it could be fine to just pass the whole BlockCfg to every constructor. Here's a fancier example using Python 3.12 generics syntax:

from dataclasses import dataclass
from torch import nn
import tyro

@dataclass
class MambaCfg:
    # Other params
    y: float
    z: str

@dataclass
class AttentionCfg:
    # Other params
    g: float
    f: str

@dataclass
class BlockCfg[LayerCfg: (MambaCfg, AttentionCfg)]:
    embed_dim: int
    layer: LayerCfg

class Mamba(nn.Module):
    def __init__(self, cfg: BlockCfg[MambaCfg]):  # Takes BlockCfg with isinstance(cfg.layer, MambaCfg)
        print(cfg.embed_dim)
        print(cfg.layer.y)
        ...

tyro.cli(BlockCfg)

For syntax that looks like --embed_dim INT layer:attention-cfg --layer.embed_dim=embed_dim, could this be implemented outside of tyro? For example: by preprocessing sys.argv and passing in tyro.cli(..., args=processed_argv).
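Here is a minimal sketch of that preprocessing idea. It assumes a reference is written as a flag whose value is the name of an earlier flag (as in --layer.embed_dim=embed_dim). resolve_references is a hypothetical helper, not part of tyro.

```python
from typing import Dict, List


def _norm(name: str) -> str:
    # Treat --embed_dim and --embed-dim as the same flag.
    return name.replace("_", "-")


def resolve_references(argv: List[str]) -> List[str]:
    """Rewrite --a.b=NAME to --a.b=VALUE when --NAME was given VALUE earlier."""
    seen: Dict[str, str] = {}
    out: List[str] = []
    i = 0
    while i < len(argv):
        tok = argv[i]
        if not tok.startswith("--"):
            # Positional tokens (e.g. subcommand names) and short flags pass through.
            out.append(tok)
            i += 1
            continue
        if "=" in tok:
            key, value = tok[2:].split("=", 1)
            i += 1
        elif i + 1 < len(argv) and not argv[i + 1].startswith("-"):
            key, value = tok[2:], argv[i + 1]
            i += 2
        else:
            # Value-less flags (e.g. booleans) pass through unchanged.
            out.append(tok)
            i += 1
            continue
        # If the value names an earlier flag, substitute that flag's value.
        value = seen.get(_norm(value), value)
        seen[_norm(key)] = value
        out.append(f"--{key}={value}")
    return out
```

This could then be wired up as tyro.cli(BlockCfg, args=resolve_references(sys.argv[1:])). The naive version has two caveats. A value-less boolean flag followed directly by a subcommand name would wrongly consume that name as its value. Any string value that happens to equal an earlier flag's name also gets substituted.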

@mirceamironenco
Contributor

mirceamironenco commented Aug 26, 2024

Hi @brentyi, thank you for the suggestions! I agree with your thought process and 'promoting' the shared parameters is also the solution I would opt for. I think the main frustration/use-case for variable referencing came from (my understanding of) the type system (and its limitations), i.e. for a setting such as:

from __future__ import annotations

from dataclasses import asdict, dataclass
from typing import Optional

import torch.nn as nn


@dataclass
class AttentionCfg:
    num_heads: int = 8
    attn_drop: float = 0.0
    window_size: Optional[int] = None

    def build(self, embed_dim: int) -> Attention:
        return Attention(embed_dim, **asdict(self))


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        *,
        num_heads: int = 8,
        attn_drop: float = 0.0,
        window_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        ...

Suppose the idea here is that the keyword arguments of Attention.__init__ should match the attributes of the AttentionCfg dataclass (and I don't want to modify the signature to take in a Cfg instance). It isn't clear to me how to force a static type checker to throw an error if, e.g., another keyword argument is added:

class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        *,
        foo: float,
        num_heads: int = 8,
        attn_drop: float = 0.0,
        window_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        ...

Now build would throw an error. Of course, runtime checks/tests will immediately detect this, however I was hoping there would be a way (either with ParamSpec or a dynamic TypedDict) to enforce the signature and dataclass attributes to be consistent, and any inconsistency to be highlighted by mypy/pyright/etc.
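Static enforcement aside, the runtime test mentioned above can be made generic. Here is a minimal sketch, assuming a hypothetical helper signature_mismatch. A plain class stands in for the nn.Module, since only its signature matters.

```python
import dataclasses
import inspect
from typing import Optional, Set, Tuple


@dataclasses.dataclass
class AttentionCfg:
    num_heads: int = 8
    attn_drop: float = 0.0
    window_size: Optional[int] = None


class Attention:  # stand-in for the nn.Module; only the signature matters
    def __init__(self, dim: int, *, num_heads: int = 8, attn_drop: float = 0.0,
                 window_size: Optional[int] = None) -> None:
        self.dim = dim


def signature_mismatch(
    cfg_cls: type, target: type, skip: Tuple[str, ...] = ()
) -> Tuple[Set[str], Set[str]]:
    """Return (params missing from the config, config fields target rejects)."""
    kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    # inspect.signature on a class reports __init__ without `self`.
    params = {
        name
        for name, p in inspect.signature(target).parameters.items()
        if p.kind in kinds and name not in skip
    }
    cfg_fields = {f.name for f in dataclasses.fields(cfg_cls)}
    return params - cfg_fields, cfg_fields - params
```

In a unit test, asserting signature_mismatch(AttentionCfg, Attention, skip=("dim",)) == (set(), set()) would flag the added foo argument above as soon as tests run. This is still a runtime check, not the mypy/pyright enforcement being asked about.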

@brentyi
Owner

brentyi commented Aug 28, 2024

Thanks for clarifying! Yeah, I can see why it's tough if you don't want to take the config instance as input to your constructor.

I'm not fully following why variable referencing would make type safety easier, though. If you had a pattern without the extra embed_dim argument:

@dataclass
class AttentionCfg:
    num_heads: int = 8
    attn_drop: float = 0.0
    window_size: Optional[int] = None

    def build(self) -> Attention:
        return Attention(**asdict(self))


class Attention(nn.Module):
    def __init__(
        self,
        *,
        num_heads: int = 8,
        attn_drop: float = 0.0,
        window_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        ...

wouldn't the same signature/dataclass attribute consistency problem still exist?

@mirceamironenco
Contributor

mirceamironenco commented Aug 28, 2024

You are right, that's unclear from the previous message.

With variable referencing I would drop the {Layer}Cfg classes altogether, and expose the layer constructors to tyro.cli. The idea is to avoid a scenario where the user has to specify the embed_dim twice, while still allowing them to see all of the other options possible for a layer. Something like:

@dataclass
class BlockCfg:
    embed_dim: int

    # No layer cfgs, just the nn.Module subclasses themselves.
    # Presumably we would have some syntax that's missing here which would
    # realize the referencing, like Annotated[*_mixer, ...=BlockCfg.embed_dim]
    sequence_mixer: Attention | Mamba
    state_mixer: Mlp | MoE

Assume all layers have an initial positional argument embed_dim. With referencing, running python foo.py --embed_dim 128 sequence-mixer:attention state-mixer:mlp -h would show all the other layer options, and it should be clear that embed_dim is set from the block attribute.

Paying the DRY price and using the layer + cfg-with-same-kwargs option is fine, since the user gets a nicer experience; hence my question about type-hinting borrowed signatures.

In any case, thanks for taking a look at this! Looking a bit at other referencing implementations, it seems to add quite a layer of complexity, since you have to topologically sort the dependency graph and instantiate things in order, which might lead to a lot of corner cases given tyro's other (nice!) features (but you would know better).
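To make the ordering problem concrete, here is a minimal sketch of Hydra-style ${...} resolution over a flat dict of dotted keys. The keys mirror the Hydra example earlier in the thread, and this is not OmegaConf's actual implementation. It uses the standard library's graphlib (Python 3.9+).

```python
import re
from graphlib import CycleError, TopologicalSorter
from typing import Dict, Union

Value = Union[int, float, str]
_REF = re.compile(r"\$\{([^}]+)\}")


def resolve(config: Dict[str, Value]) -> Dict[str, Value]:
    # Dependency graph: key -> set of keys its value references.
    graph = {
        key: (set(_REF.findall(val)) if isinstance(val, str) else set())
        for key, val in config.items()
    }
    resolved: Dict[str, Value] = {}
    # static_order() yields dependencies before dependents and raises
    # CycleError on circular references.
    for key in TopologicalSorter(graph).static_order():
        val = config[key]  # KeyError if a reference names an undefined key
        if isinstance(val, str):
            whole = _REF.fullmatch(val)
            if whole:
                # A bare "${x}" keeps the referenced value's type.
                val = resolved[whole.group(1)]
            else:
                # Embedded references are interpolated as strings.
                val = _REF.sub(lambda m: str(resolved[m.group(1)]), val)
        resolved[key] = val
    return resolved


cfg = {
    "num_slots": 10,
    "dataset.num_slots": "${num_slots}",
    "model.encoder_stage.output_dim": 256,
    "model.decoder_stage.input_dim": "${model.encoder_stage.output_dim}",
    "model.decoder_stage.num_slots": "${num_slots}",
    "run_name": "slots${num_slots}",
}
```

Even this toy version has to decide on type preservation, missing keys, and cycles. Layering it onto nested dataclasses, subcommands, and CLI overrides is where the corner cases multiply.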

@brentyi
Owner

brentyi commented Aug 28, 2024

Makes sense, thanks for clarifying! It does seem nice, but yeah, the naive implementations I can think of all seem pretty complex in terms of both implementation and user experience. I think the use case of wanting to avoid an explicit config object is also lower priority for me personally; the overhead of the extra class is annoying, but that feels outweighed by the usefulness of being able to instantiate/save/restore the config object independently of the module itself.

That said if any new ideas for implementation/syntax occur to you please feel free to share! I'd be interested.
