diff --git a/royalnet/campaigns/asynccampaign.py b/royalnet/campaigns/asynccampaign.py index 03736c4e..14ac6a44 100644 --- a/royalnet/campaigns/asynccampaign.py +++ b/royalnet/campaigns/asynccampaign.py @@ -20,7 +20,7 @@ class AsyncCampaign: An AsyncCampaign consists of multiple chained AsyncAdventures, which are AsyncGenerators yielding tuples with an AsyncChallenge and optional data. """ - def __init__(self, start: AsyncAdventure, *args, **kwargs): + def __init__(self, start: AsyncAdventure, challenge: Optional[AsyncChallenge], *args, **kwargs): """ Initialize an AsyncCampaign object. @@ -29,20 +29,21 @@ class AsyncCampaign: :param start: The starting adventure for the AsyncCampaign. """ self.adventure: AsyncAdventure = start - self.challenge: AsyncChallenge = TrueAsyncChallenge() + self.challenge: AsyncChallenge = challenge or TrueAsyncChallenge() self.last_update: datetime.datetime = ... @classmethod - async def create(cls, start: AsyncAdventure, *args, **kwargs) -> Tuple[AsyncCampaign, ...]: + async def create(cls, start: AsyncAdventure, challenge: Optional[AsyncChallenge] = None, *args, **kwargs) -> AsyncCampaign: """ Create a new AsyncCampaign object. :param start: The starting Adventure for the AsyncCampaign. - :return: A tuple containing the created AsyncCampaign and optionally a list of extra output. + :param challenge: The AsyncChallenge the campaign should start with. + :return: The created AsyncCampaign. """ - campaign = cls(start=start, *args, **kwargs) - output = await campaign.next() - return campaign, *output + campaign = cls(start=start, challenge=challenge, *args, **kwargs) + await campaign._asend(None) + return campaign async def _asend(self, data: Any) -> Any: try: @@ -80,6 +81,7 @@ class AsyncCampaign: if inspect.isasyncgen(result): await self._aclose() self.adventure = result + await self._asend(None) return await self.next(data) elif isinstance(result, AsyncChallenge): self.challenge = result diff --git a/royalnet/campaigns/campaign.py b/royalnet/campaigns/campaign.py index 0fc95271..015c8d4f 100644 --- a/royalnet/campaigns/campaign.py +++ b/royalnet/campaigns/campaign.py @@ -21,29 +21,28 @@ class Campaign: optional data. """ - def __init__(self, start: Adventure, *args, **kwargs): + def __init__(self, start: Adventure, challenge: Optional[Challenge] = None, *args, **kwargs): """ Initialize a Campaign object. .. warning:: Do not use this, use the Campaign.create() method instead! - - :param start: The starting adventure for the Campaign. """ self.adventure: Adventure = start - self.challenge: Challenge = TrueChallenge() + self.challenge: Challenge = challenge or TrueChallenge() self.last_update: datetime.datetime = ... @classmethod - def create(cls, start: Adventure, *args, **kwargs) -> Tuple[Campaign, ...]: + def create(cls, start: Adventure, challenge: Optional[Challenge] = None, *args, **kwargs) -> Campaign: """ Create a new Campaign object. :param start: The starting Adventure for the Campaign. - :return: A tuple containing the created Campaign and optionally a list of extra output. + :param challenge: The Challenge the campaign should start with. + :return: The created Campaign. """ - campaign = cls(start=start, *args, **kwargs) - output = campaign.next() - return campaign, *output + campaign = cls(start=start, challenge=challenge, *args, **kwargs) + campaign.adventure.send(None) + return campaign def next(self, data: Any = None) -> List: """ @@ -60,6 +59,7 @@ class Campaign: if inspect.isgenerator(result): self.adventure.close() self.adventure = result + self.adventure.send(None) return self.next(data) elif isinstance(result, Challenge): self.challenge = result diff --git a/royalnet/campaigns/tests/test_asynccampaign.py b/royalnet/campaigns/tests/test_asynccampaign.py index 06be24ee..af277fca 100644 --- a/royalnet/campaigns/tests/test_asynccampaign.py +++ b/royalnet/campaigns/tests/test_asynccampaign.py @@ -7,20 +7,19 @@ from ..exc import * @pytest.mark.asyncio async def test_creation(): async def gen(): - yield None, "Created!" + yield - campaign, data = await AsyncCampaign.create(start=gen()) - assert data == "Created!" + await AsyncCampaign.create(start=gen()) @pytest.mark.asyncio async def test_send_receive(): async def gen(): - ping = yield None + ping = yield assert ping == "Ping!" yield None, "Pong!" - campaign, = await AsyncCampaign.create(start=gen()) + campaign = await AsyncCampaign.create(start=gen()) pong, = await campaign.next("Ping!") assert pong == "Pong!" @@ -33,9 +32,9 @@ class FalseChallenge(AsyncChallenge): @pytest.mark.asyncio async def test_failing_check(): async def gen(): - yield FalseChallenge() + yield - campaign, = await AsyncCampaign.create(start=gen()) + campaign = await AsyncCampaign.create(start=gen(), challenge=FalseChallenge()) with pytest.raises(ChallengeFailedError): await campaign.next() @@ -43,14 +42,14 @@ async def test_failing_check(): @pytest.mark.asyncio async def test_switching(): async def gen_1(): + yield yield gen_2() async def gen_2(): - yield None, "Post-init!" + yield yield None, "Second message!" yield None - campaign, data = await AsyncCampaign.create(start=gen_1()) - assert data == "Post-init!" + campaign = await AsyncCampaign.create(start=gen_1()) data, = await campaign.next() assert data == "Second message!" diff --git a/royalnet/campaigns/tests/test_campaign.py b/royalnet/campaigns/tests/test_campaign.py index c86e371d..66cbaada 100644 --- a/royalnet/campaigns/tests/test_campaign.py +++ b/royalnet/campaigns/tests/test_campaign.py @@ -1,24 +1,23 @@ import pytest from ..campaign import Campaign -from ..challenge import Challenge +from ..challenge import Challenge, TrueChallenge from ..exc import * def test_creation(): def gen(): - yield None, "Created!" + yield - campaign, data = Campaign.create(start=gen()) - assert data == "Created!" + Campaign.create(start=gen()) def test_send_receive(): def gen(): - ping = yield None + ping = yield assert ping == "Ping!" yield None, "Pong!" - campaign, = Campaign.create(start=gen()) + campaign = Campaign.create(start=gen()) pong, = campaign.next("Ping!") assert pong == "Pong!" @@ -30,23 +29,23 @@ class FalseChallenge(Challenge): def test_failing_check(): def gen(): - yield FalseChallenge() + yield - campaign, = Campaign.create(start=gen()) + campaign = Campaign.create(start=gen(), challenge=FalseChallenge()) with pytest.raises(ChallengeFailedError): campaign.next() def test_switching(): def gen_1(): + yield yield gen_2() def gen_2(): - yield None, "Post-init!" + yield yield None, "Second message!" yield None - campaign, data = Campaign.create(start=gen_1()) - assert data == "Post-init!" + campaign = Campaign.create(start=gen_1()) data, = campaign.next() assert data == "Second message!"