diff --git a/tests/conftest.py b/tests/conftest.py index 89bec2c..d04cfba 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -98,6 +98,5 @@ def game(monkeypatch: pytest.MonkeyPatch) -> Game: # Add a single water tile for code coverage game = Game() game.setup(terrain=Terrain(water=[Vec2(210, 210)])) - game.grid.mines = {} - game.grid.factories = {} + game.grid.buildings = {} return game diff --git a/tests/test_game.py b/tests/test_game.py index 3b62e31..1fd8951 100644 --- a/tests/test_game.py +++ b/tests/test_game.py @@ -35,7 +35,7 @@ def test_draw(game: Game): ) game._create_train(*game.grid.station_from_position.values()) # Mainly for code coverage - game.trains[0].wagons[0].cargo_count = 1 + game.trains[0].wagons[0].cargo_count[CargoType.IRON] == 1 game.on_draw() @@ -644,7 +644,7 @@ def test_iron_is_regularly_added_to_mines(game: Game): game.on_update(1 / 60) - assert game.grid.mines[Vec2(1, 1)].cargo_count == 1 + assert game.grid.buildings[Vec2(1, 1)].cargo_count[CargoType.IRON] == 1 assert ( len(game.drawer.cargo_shape_element_list) == 2 ) # One for the interior, one for the frame @@ -700,20 +700,20 @@ def test_train_picks_up_iron_from_mine(game: Game): """, ) game._create_train(*game.grid.station_from_position.values()) - mine = game.grid.mines[Vec2(1, 1)] + mine = game.grid.buildings[Vec2(1, 1)] train = game.trains[0] - mine.add_cargo() + mine.try_create_cargo() train.x = 1 train.target_x = 1 train._target_station = game.grid.station_from_position[Vec2(1, 0)] - assert mine.cargo_count == 1 - assert train.wagons[0].cargo_count == 0 + assert mine.cargo_count[CargoType.IRON] == 1 + assert train.wagons[0].cargo_count[CargoType.IRON] == 0 - while check(train.wagons[0].cargo_count == 0): + while check(train.wagons[0].cargo_count[CargoType.IRON] == 0): game.on_update(1 / 60) - assert mine.cargo_count == 0 - assert train.wagons[0].cargo_count == 1 + assert mine.cargo_count[CargoType.IRON] == 0 + assert train.wagons[0].cargo_count[CargoType.IRON] == 1 def test_train_delivers_iron_to_factory_gives_score(game: Game): @@ -727,15 +727,15 @@ def test_train_delivers_iron_to_factory_gives_score(game: Game): ) game._create_train(*game.grid.station_from_position.values()) train = game.trains[0] - train.wagons[0].cargo_count = 1 + train.wagons[0].cargo_count[CargoType.IRON] = 1 train.x = 3 train.target_x = 3 train._target_station = game.grid.station_from_position[Vec2(3, 0)] - while check(train.wagons[0].cargo_count): + while check(train.wagons[0].cargo_count[CargoType.IRON]): game.on_update(1 / 60) - assert train.wagons[0].cargo_count == 0 + assert train.wagons[0].cargo_count[CargoType.IRON] == 0 assert game.player.score == 1 diff --git a/tests/test_tests_util.py b/tests/test_tests_util.py index 2e4b460..7793df3 100644 --- a/tests/test_tests_util.py +++ b/tests/test_tests_util.py @@ -1,5 +1,12 @@ from pyglet.math import Vec2 -from trainfinity2.model import Mine, Rail, CargoType, SignalColor, Station, SteelWorks +from trainfinity2.model import ( + IronMine, + Rail, + SignalColor, + Station, + Building, + SteelWorks, +) from trainfinity2.__main__ import Game from tests.util import create_objects @@ -84,8 +91,10 @@ def test_create_objects(game: Game): .-S-.hS-.""", ) - assert game.grid.mines == {Vec2(1, 1): Mine(Vec2(1, 1), cargo_type=CargoType.IRON)} - assert game.grid.factories == {Vec2(3, 1): SteelWorks(Vec2(3, 1))} + assert game.grid.buildings == { + Vec2(1, 1): IronMine(Vec2(1, 1)), + Vec2(3, 1): SteelWorks(Vec2(3, 1)), + } assert game.grid.station_from_position == { Vec2(3, 0): Station((Vec2(3, 0),)), Vec2(1, 0): Station((Vec2(1, 0),)), @@ -105,7 +114,7 @@ def test_create_with_offset(game: Game): . .""", ) - assert game.grid.mines == {Vec2(1, 1): Mine(Vec2(1, 1), cargo_type=CargoType.IRON)} + assert game.grid.buildings == {Vec2(1, 1): IronMine(Vec2(1, 1))} def test_create_with_offset_and_last_line(game: Game): @@ -117,4 +126,4 @@ def test_create_with_offset_and_last_line(game: Game): . . """, ) - assert game.grid.mines == {Vec2(1, 1): Mine(Vec2(1, 1), cargo_type=CargoType.IRON)} + assert game.grid.buildings == {Vec2(1, 1): IronMine(Vec2(1, 1))} diff --git a/trainfinity2/game.py b/trainfinity2/game.py index 0323017..02bd773 100644 --- a/trainfinity2/game.py +++ b/trainfinity2/game.py @@ -97,10 +97,8 @@ def setup(self, terrain: Terrain): def on_update(self, delta_time): self.cargo_counter += delta_time if self.cargo_counter > SECONDS_BETWEEN_CARGO_CREATION: - for mine in self.grid.mines.values(): - self.drawer.handle_events([mine.add_cargo()]) - for factory in self.grid.factories.values(): - self.drawer.handle_events([factory.transform_cargo()]) + for building in self.grid.buildings.values(): + self.drawer.handle_events([building.try_create_cargo()]) self.cargo_counter = 0.0 self._update_gui_figures(delta_time) diff --git a/trainfinity2/graphics/drawer.py b/trainfinity2/graphics/drawer.py index 041f69e..b675f14 100644 --- a/trainfinity2/graphics/drawer.py +++ b/trainfinity2/graphics/drawer.py @@ -27,15 +27,17 @@ StationBeingBuiltEvent, ) from ..model import ( - Factory, + Building, CargoType, CargoAddedEvent, CargoRemovedEvent, - Mine, + CoalMine, + IronMine, Rail, Signal, SignalColor, Station, + SteelWorks, ) from ..events import CreateEvent, DestroyEvent, Event, NullEvent from ..train import Train @@ -109,10 +111,12 @@ def __init__(self): def handle_events(self, events: Iterable[Event]): for event in events: match event: - case CreateEvent(Mine() as mine): + case CreateEvent(CoalMine() as mine): self._create_mine(mine) - case CreateEvent(Factory() as factory): - self._create_factory(factory) + case CreateEvent(IronMine() as mine): + self._create_mine(mine) + case CreateEvent(SteelWorks() as steelworks): + self._create_steelworks(steelworks) case CreateEvent(Station() as station): self._create_station(station) case CreateEvent(Rail() as rail): @@ -186,18 +190,18 @@ def create_grid(self, grid: Grid): ) ) - def _create_factory(self, factory: Factory): + def _create_steelworks(self, steelworks: SteelWorks): sprite = arcade.Sprite( "images/factory.png", 0.75, - center_x=factory.position.x * GRID_BOX_SIZE_PIXELS + center_x=steelworks.position.x * GRID_BOX_SIZE_PIXELS + GRID_BOX_SIZE_PIXELS / 2, - center_y=factory.position.y * GRID_BOX_SIZE_PIXELS + center_y=steelworks.position.y * GRID_BOX_SIZE_PIXELS + GRID_BOX_SIZE_PIXELS / 2, ) - self._add_sprite(sprite, factory) + self._add_sprite(sprite, steelworks) - def _create_mine(self, mine: Mine): + def _create_mine(self, mine: IronMine | CoalMine): sprite = arcade.Sprite( "images/mine.png", 0.75, diff --git a/trainfinity2/graphics/train_drawer.py b/trainfinity2/graphics/train_drawer.py index 6590b40..7c634e2 100644 --- a/trainfinity2/graphics/train_drawer.py +++ b/trainfinity2/graphics/train_drawer.py @@ -115,8 +115,8 @@ def _draw_wagon(self, wagon: Wagon): color=color.REDWOOD, tilt_angle=wagon.angle, ) - if wagon.cargo_count: - for shape in get_cargo_shape( - x, y, wagon.cargo_type, tilt_angle=wagon.angle - ): - shape.draw() + for cargo_type in wagon.cargo_count: + if wagon.cargo_count[cargo_type]: + for shape in get_cargo_shape(x, y, cargo_type, tilt_angle=wagon.angle): + shape.draw() + break diff --git a/trainfinity2/grid.py b/trainfinity2/grid.py index 070926f..57c59a7 100644 --- a/trainfinity2/grid.py +++ b/trainfinity2/grid.py @@ -10,12 +10,14 @@ from .gui import Mode from .model import ( - Factory, - Mine, + Building, + CoalMine, + IronMine, Rail, CargoType, Signal, Station, + Building, SteelWorks, Water, ) @@ -81,8 +83,7 @@ def __init__(self, terrain: Terrain, signal_controller: SignalController) -> Non self._signal_controller = signal_controller self.water: dict[Vec2, Water] = {} - self.mines: dict[Vec2, Mine] = {} - self.factories: dict[Vec2, Factory] = {} + self.buildings: dict[Vec2, Building] = {} self.station_from_position: dict[Vec2, Station] = {} self.signals: dict[tuple[Vec2, Rail], Signal] = {} self.rails_being_built: set[Rail] = set() @@ -102,11 +103,10 @@ def _create_terrain(self, terrain: Terrain): for position in terrain.water: self.water[position] = Water(position) - def _get_random_position_to_build_mine_or_factory(self) -> Vec2: + def _get_random_position_to_build_building(self) -> Vec2: illegal_positions = ( self.water.keys() - | self.mines.keys() - | self.factories.keys() + | self.buildings.keys() | self.station_from_position.keys() | {position for rail in self.rails for position in rail.positions} ) @@ -116,29 +116,27 @@ def _get_random_position_to_build_mine_or_factory(self) -> Vec2: return position def create_mine(self, position: Vec2, cargo: CargoType) -> CreateEvent: - mine = Mine(position, cargo_type=cargo) - self.mines[position] = mine + mine = {CargoType.COAL: CoalMine(position), CargoType.IRON: IronMine(position)}[ + cargo + ] + self.buildings[position] = mine return CreateEvent(mine) def _create_mine_in_random_unoccupied_location( self, cargo: CargoType ) -> CreateEvent: - return self.create_mine( - self._get_random_position_to_build_mine_or_factory(), cargo - ) + return self.create_mine(self._get_random_position_to_build_building(), cargo) def _create_mines(self) -> list[Event]: return [self._create_mine_in_random_unoccupied_location(CargoType.IRON)] def _create_factory(self, position: Vec2) -> CreateEvent: factory = SteelWorks(position) - self.factories[position] = factory + self.buildings[position] = factory return CreateEvent(factory) def _create_factory_in_random_unoccupied_location(self) -> CreateEvent: - return self._create_factory( - self._get_random_position_to_build_mine_or_factory() - ) + return self._create_factory(self._get_random_position_to_build_building()) def _create_factories(self) -> list[Event]: return [self._create_factory_in_random_unoccupied_location()] @@ -186,9 +184,7 @@ def _is_inside_station_in_wrong_direction(self, rail: Rail): return False def _mark_illegal_rail(self, rails: Iterable[Rail]) -> set[Rail]: - illegal_positions = ( - self.water.keys() | self.mines.keys() | self.factories.keys() - ) + illegal_positions = self.water.keys() | self.buildings.keys() return { ( rail.to_illegal() @@ -201,9 +197,7 @@ def _mark_illegal_rail(self, rails: Iterable[Rail]) -> set[Rail]: } def _illegal_station_positions(self, station: Station) -> set[Vec2]: - if not any( - self._adjacent_mine_or_factory(position) for position in station.positions - ): + if not self.adjacent_buildings(station.positions): return set(station.positions) overlapping_positions_with_rail_in_wrong_direction = { @@ -217,8 +211,7 @@ def _illegal_station_positions(self, station: Station) -> set[Vec2]: } illegal_positions = ( self.water.keys() - | self.mines.keys() - | self.factories.keys() + | self.buildings.keys() | self.station_from_position.keys() ) return ( @@ -322,31 +315,14 @@ def _is_adjacent(self, position1: Vec2, position2: Vec2): 3 / 4 < dy < 5 / 4 and dx < 1 / 4 ) - def adjacent_mines(self, positions: Iterable[Vec2]) -> list[Mine]: - return [ - mine - for mine in self.mines.values() - for position in positions - if self._is_adjacent(position, mine.position) - ] - - def adjacent_factories(self, positions: Iterable[Vec2]) -> list[Factory]: + def adjacent_buildings(self, positions: Iterable[Vec2]) -> list[Building]: return [ - factory - for factory in self.factories.values() + building + for building in self.buildings.values() for position in positions - if self._is_adjacent(position, factory.position) + if self._is_adjacent(position, building.position) ] - def _adjacent_mine_or_factory(self, position: Vec2) -> Mine | Factory | None: - for mine in self.mines.values(): - if self._is_adjacent(position, mine.position): - return mine - for factory in self.factories.values(): - if self._is_adjacent(position, factory.position): - return factory - return None - def _create_station(self, station: Station) -> CreateEvent: """Creates a station in a location. Must be next to a mine or a factory, or it raises AssertionError. @@ -403,7 +379,7 @@ def toggle_signals_at_click_position( return self.toggle_signals_at_grid_position(x, y) def toggle_signals_at_grid_position(self, x: float, y: float) -> Sequence[Event]: - events = [] + events: list[Event] = [] if rail := self._closest_rail(x, y): for position in rail.positions: signal = self.signals.get((position, rail)) diff --git a/trainfinity2/model.py b/trainfinity2/model.py index a7dc609..66c01a6 100644 --- a/trainfinity2/model.py +++ b/trainfinity2/model.py @@ -49,7 +49,7 @@ class Water: position: Vec2 -MAX_CARGO_AT_MINE = 8 +MAX_CARGO_AT_BUILDING = 8 class CargoType(Enum): @@ -71,52 +71,60 @@ class CargoRemovedEvent(Event): @dataclass -class Factory(ABC): - position: Vec2 - cargo_count: dict[CargoType, int] = field(default_factory=lambda: defaultdict(int)) - accepts: set[CargoType] = field(init=False) +class Recipe: + output: CargoType + input: set[CargoType] = field(default_factory=set) - def add_cargo(self, type: CargoType): - self.cargo_count[type] += 1 - @abstractmethod - def transform_cargo(self): - raise NotImplementedError +@dataclass +class Building(ABC): + position: Vec2 + cargo_count: dict[CargoType, int] = field(default_factory=lambda: defaultdict(int)) + recipe: Recipe = field(init=False) + @property + def accepts(self) -> set[CargoType]: + return self.recipe.input -@dataclass -class SteelWorks(Factory): - def __post_init__(self): - self.accepts = {CargoType.COAL, CargoType.IRON} + @property + def produces(self) -> set[CargoType]: + return {self.recipe.output} - def transform_cargo(self) -> Event: + def try_create_cargo(self) -> Event: if ( - self.cargo_count[CargoType.COAL] >= 1 - and self.cargo_count[CargoType.IRON] >= 1 + all(self.cargo_count[input_cargo] for input_cargo in self.recipe.input) + and self.cargo_count[self.recipe.output] <= MAX_CARGO_AT_BUILDING ): - self.cargo_count[CargoType.IRON] -= 1 - self.cargo_count[CargoType.COAL] -= 1 - self.cargo_count[CargoType.STEEL] += 1 - return CargoAddedEvent(self.position, CargoType.STEEL) + for input_cargo in self.recipe.input: + self.cargo_count[input_cargo] -= 1 + self.cargo_count[self.recipe.output] += 1 + return CargoAddedEvent(self.position, self.recipe.output) return NullEvent() + def remove_cargo(self, type: CargoType, amount: int) -> CargoRemovedEvent: + assert self.cargo_count[type] >= amount + self.cargo_count[type] -= amount + return CargoRemovedEvent(self.position, amount) + @dataclass -class Mine: - position: Vec2 - cargo_type: CargoType - cargo_count: int = 0 +class SteelWorks(Building): + def __post_init__(self): + self.recipe = Recipe( + input={CargoType.COAL, CargoType.IRON}, output=CargoType.STEEL + ) - def add_cargo(self) -> Event: - if self.cargo_count < MAX_CARGO_AT_MINE: - self.cargo_count += 1 - return CargoAddedEvent(self.position, self.cargo_type) - return NullEvent() - def remove_cargo(self, amount) -> CargoRemovedEvent: - amount_taken = amount if amount <= self.cargo_count else self.cargo_count - self.cargo_count -= amount_taken - return CargoRemovedEvent(self.position, amount_taken) +@dataclass +class CoalMine(Building): + def __post_init__(self): + self.recipe = Recipe(output=CargoType.COAL) + + +@dataclass +class IronMine(Building): + def __post_init__(self): + self.recipe = Recipe(output=CargoType.IRON) @dataclass(frozen=True) @@ -156,9 +164,6 @@ def internal_and_external_rail(self) -> set[Rail]: return {rail1, rail2}.union(self.internal_rail) -Building = Mine | Factory | Station - - def get_level_scores() -> list[int]: """ The number of points required to reach a certain level. diff --git a/trainfinity2/train.py b/trainfinity2/train.py index 524358d..fa00547 100644 --- a/trainfinity2/train.py +++ b/trainfinity2/train.py @@ -10,7 +10,7 @@ from .grid import Grid -from .model import Factory, Player, Rail, CargoType, Station +from .model import Building, Player, Rail, CargoType, Station from .wagon import Wagon from .route_finder import find_route, has_reached_end_of_target_station from .signal_controller import SignalController @@ -173,49 +173,53 @@ def _stop_at_station(self, station: Station): # Check factories before mines, or a the iron will # instantly be transported to the factory self.speed = 0 - is_finished = True - if ( - factories := self.grid.adjacent_factories(station.positions) - ) and self._has_cargo(): - self.wait_timer = 1 - self._run_after_wait = self._create_unload_cargo_method(factories[0]) - is_finished = False - for mine in self.grid.adjacent_mines(station.positions): - if self._has_space() and mine.cargo_count > 0: - mine.remove_cargo(1) - self.wait_timer = 1 - self._run_after_wait = self._create_load_cargo_method(mine.cargo_type) - is_finished = False - break - if is_finished: - self.continue_to_next_station() + for building in self.grid.adjacent_buildings(station.positions): + for cargo_type in building.accepts: + if self._has_cargo(cargo_type): + self.wait_timer = 1 + self._run_after_wait = self._create_unload_cargo_method( + building, cargo_type + ) + return + for cargo_type in building.produces: + if building.cargo_count[cargo_type] > 0 and self._has_space(cargo_type): + building.remove_cargo(cargo_type, 1) + self.wait_timer = 1 + self._run_after_wait = self._create_load_cargo_method(cargo_type) + return + self.continue_to_next_station() # This ensures that the train can immediately reverse at the station # Otherwise it the train would prefer to continue forward and then reverse # self.current_rail = None - def _has_cargo(self): - return any(wagon.cargo_count for wagon in self.wagons) - - def _create_unload_cargo_method(self, factory: Factory): + def _create_unload_cargo_method(self, factory: Building, cargo_type: CargoType): def inner(): for wagon in reversed(self.wagons): - if wagon.cargo_count: - self.player.score += wagon.cargo_count - factory.cargo_count[wagon.cargo_type] += 1 - wagon.cargo_count = 0 + if wagon.cargo_count[cargo_type]: + self.player.score += wagon.cargo_count[cargo_type] + factory.cargo_count[cargo_type] += wagon.cargo_count[cargo_type] + wagon.cargo_count[cargo_type] = 0 return return inner() - def _has_space(self): - return any(not wagon.cargo_count for wagon in self.wagons) + def _has_space(self, cargo_type: CargoType): + for wagon in self.wagons: + if cargo_type in wagon.cargo_types and not wagon.cargo_count[cargo_type]: + return True + return False + + def _has_cargo(self, cargo_type: CargoType): + return any(wagon.cargo_count[cargo_type] for wagon in self.wagons) def _create_load_cargo_method(self, cargo_type: CargoType): def inner(): for wagon in self.wagons: - if not wagon.cargo_count: - wagon.cargo_type = cargo_type - wagon.cargo_count = 1 + if ( + cargo_type in wagon.cargo_types + and not wagon.cargo_count[cargo_type] + ): + wagon.cargo_count[cargo_type] = 1 return return inner diff --git a/trainfinity2/wagon.py b/trainfinity2/wagon.py index 69d51ec..ea82b6e 100644 --- a/trainfinity2/wagon.py +++ b/trainfinity2/wagon.py @@ -1,4 +1,5 @@ -from dataclasses import dataclass +from collections import defaultdict +from dataclasses import dataclass, field from trainfinity2.model import CargoType @@ -7,6 +8,10 @@ class Wagon: x: float y: float - cargo_type: CargoType = CargoType.IRON - cargo_count: int = 0 + cargo_types: set[CargoType] = field(init=False) + cargo_count: dict[CargoType, int] = field(init=False) angle: float = 0 + + def __post_init__(self): + self.cargo_types = {CargoType.IRON, CargoType.COAL} + self.cargo_count = defaultdict(int)