diff --git a/github/github.py b/github/github.py index 2ce1d088..38a1db4f 100644 --- a/github/github.py +++ b/github/github.py @@ -23,10 +23,12 @@ SOFTWARE. """ +import asyncio +import logging import re import typing from datetime import datetime, timezone -from typing import Final, List +from typing import Final, List, Optional from urllib.parse import urlparse import aiohttp @@ -40,17 +42,19 @@ from .converters import ExplicitNone +log = logging.getLogger("red.github") + # Constants COLOR = 0x7289DA TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ" +EMBED_DESC_LIMIT = 4000 # Regular expressions TOKEN_REGEX: re.Pattern = re.compile(r"token=(.*)") COMMIT_REGEX: re.Pattern = re.compile(r"https://github\.com/.*?/.*?/commit/(.*?)") USER_REPO_BRANCH_REGEX: re.Pattern = re.compile(r"/(.*?)/(.*?)/?(commits)?/(.*?(?=\.atom))?") -LONG_COMMIT_REGEX: re.Pattern = re.compile(r"^|$|(?<=\n)\n+") -LONG_RELEASE_REGEX: re.Pattern = re.compile(r"^|$|(?<=\n)\n+") +LONG_CONTENT_REGEX: re.Pattern = re.compile(r"^|$|(?<=\n)\n+") # Error messages NO_ROLE = "You do not have the required role!" @@ -88,6 +92,8 @@ def __init__(self, bot): self.config.register_guild(**default_guild) self.config.register_member(**default_member) + self._session: Optional[aiohttp.ClientSession] = None + self._github_rss.start() def format_help_for_context(self, ctx: commands.Context) -> str: @@ -97,8 +103,12 @@ def format_help_for_context(self, ctx: commands.Context) -> str: def cog_unload(self): self._github_rss.cancel() + if self._session and not self._session.closed: + self.bot.loop.create_task(self._session.close()) async def initialize(self): + self._session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=15)) + global_conf = await self.config.all() # Change loop interval if necessary @@ -121,7 +131,7 @@ async def initialize(self): ).feeds() as member_feeds: # Loop through each feed for feed_name, feed_data in member_data.items(): - user, repo, branch, token = await self._parse_url(feed_data["url"]) + user, repo, branch, token = self._parse_url(feed_data["url"]) member_feeds[feed_name] = { "user": user, "repo": repo, @@ -144,15 +154,15 @@ def _escape(text: str): return escape(text, formatting=True) @staticmethod - async def _repo_url(**user_and_repo): + def _repo_url(**user_and_repo): return f"https://github.com/{user_and_repo['user']}/{user_and_repo['repo']}/" @staticmethod - async def _invalid_url(ctx: commands.Context): + def _invalid_url(ctx: commands.Context): return f"Invalid GitHub URL. Try doing `{ctx.clean_prefix}github whatlinks` to see the accepted formats." @staticmethod - async def _url_from_config(feed_config: dict): + def _url_from_config(feed_config: dict): final_url = f"https://github.com/{feed_config['user']}/{feed_config['repo']}" if feed_config["branch"]: @@ -164,19 +174,21 @@ async def _url_from_config(feed_config: dict): return final_url + f"/commits/{feed_config['branch']}.atom{token}" else: - return final_url + f"/commits.atom" + return final_url + "/commits.atom" - @staticmethod - async def _fetch(url: str, valid_statuses: list): - async with aiohttp.ClientSession() as session: - async with session.get(url) as resp: + async def _fetch(self, url: str, valid_statuses: list): + try: + async with self._session.get(url) as resp: html = await resp.read() if resp.status not in valid_statuses: return False + except (aiohttp.ClientError, asyncio.TimeoutError) as exc: + log.debug("HTTP error fetching %s: %s", url, exc) + return False return feedparser.parse(html) @staticmethod - async def new_entries(entries, last_time): + def new_entries(entries, last_time): entries_new = [] for e in entries: e_time = ( @@ -189,7 +201,7 @@ async def new_entries(entries, last_time): return entries_new, datetime.now(tz=timezone.utc) @staticmethod - async def _parse_url(url: str): + def _parse_url(url: str): # Strip if url[0] == "<" and url[-1] == ">": url = url[1:-1] @@ -227,9 +239,9 @@ async def _parse_url(url: str): return user, repo, branch, token - async def _parse_url_input(self, url: str, branch: str): - user, repo, parsed_branch, token = await self._parse_url(url) - if not any([user, repo, parsed_branch, token]): + def _parse_url_input(self, url: str, branch: Optional[str]) -> Optional[dict]: + user, repo, parsed_branch, token = self._parse_url(url) + if not user or not repo: return None return { @@ -253,13 +265,13 @@ async def _get_feed_channel(self, bot: discord.Member, guild_channel: int, feed_ channel = None return channel - async def _commit_embeds( + def _commit_embeds( self, entries: list, feed_link: str, color: int, timestamp: bool, short: bool ): if not entries: return None - user, repo, branch, __ = await self._parse_url(feed_link + ".atom") + user, repo, branch, __ = self._parse_url(feed_link + ".atom") if branch == "releases": embed = discord.Embed( @@ -268,19 +280,34 @@ async def _commit_embeds( url=entries[0].link, ) if not short: - embed.description = html2text.html2text(entries[0].content[0].value) + content_list = getattr(entries[0], "content", []) + if content_list: + raw = html2text.html2text(content_list[0].value) + embed.description = ( + raw[:EMBED_DESC_LIMIT] + "…" if len(raw) > EMBED_DESC_LIMIT else raw + ) else: num = min(len(entries), 10) desc = "" for e in entries[:num]: + commit_match = COMMIT_REGEX.fullmatch(e.link) + sha = commit_match.group(1)[:7] if commit_match else e.link[-7:] + if short: - desc += f"[`{COMMIT_REGEX.fullmatch(e.link).group(1)[:7]}`]({e.link}) {self._escape(e.title)} – {self._escape(e.author)}\n" + desc += ( + f"[`{sha}`]({e.link}) {self._escape(e.title)} – {self._escape(e.author)}\n" + ) else: - desc += f"[`{COMMIT_REGEX.fullmatch(e.link).group(1)[:7]}`]({e.link}) – {self._escape(e.author)}\n{LONG_COMMIT_REGEX.sub('', e.content[0].value)}\n\n" + content = getattr(e, "content", []) + body = LONG_CONTENT_REGEX.sub("", content[0].value) if content else "" + desc += f"[`{sha}`]({e.link}) – {self._escape(e.author)}\n{body}\n\n" + + if len(desc) > EMBED_DESC_LIMIT: + desc = desc[:EMBED_DESC_LIMIT] + "…" embed = discord.Embed( - title=f"[{repo}:{branch}] {num} new commit{'s' if num > 1 else ''}", + title=f"[{repo}:{branch or 'default'}] {num} new commit{'s' if num > 1 else ''}", color=color if color is not None else COLOR, description=desc, url=feed_link if num > 1 else entries[0].link, @@ -291,10 +318,12 @@ async def _commit_embeds( tzinfo=timezone.utc ) + thumbnails = getattr(entries[0], "media_thumbnail", []) + icon_url = thumbnails[0]["url"] if thumbnails else None embed.set_author( name=entries[0].author, url=f"https://github.com/{entries[0].author}", - icon_url=entries[0].media_thumbnail[0]["url"], + icon_url=icon_url, ) return embed @@ -372,7 +401,7 @@ async def _set_role(self, ctx: commands.Context, role: discord.Role = None): """ if not role: await self.config.guild(ctx.guild).role.set(None) - return await ctx.send(f"The GitHub RSS feed role requirement has been removed.") + return await ctx.send("The GitHub RSS feed role requirement has been removed.") else: await self.config.guild(ctx.guild).role.set(role.id) return await ctx.send(f"The GitHub RSS feed role has been set to {role.mention}.") @@ -401,9 +430,9 @@ async def _force(self, ctx: commands.Context, user: discord.Member, name: str): if not (feed_config := feeds.get(name)): return await ctx.send(NOT_FOUND) - url = await self._url_from_config(feed_config) + url = self._url_from_config(feed_config) if not (parsed := await self._fetch(url, [200])): - return await ctx.send(await self._invalid_url(ctx)) + return await ctx.send(self._invalid_url(ctx)) if feed_config["channel"]: channel = ctx.guild.get_channel(feed_config["channel"]) @@ -414,7 +443,7 @@ async def _force(self, ctx: commands.Context, user: discord.Member, name: str): if channel and channel.permissions_for(ctx.guild.me).embed_links: return await channel.send( - embed=await self._commit_embeds( + embed=self._commit_embeds( entries=[parsed.entries[0]], feed_link=parsed.feed.link, color=guild_config["color"], @@ -428,10 +457,10 @@ async def _force(self, ctx: commands.Context, user: discord.Member, name: str): ) @_github_set.command(name="forceall") - async def _force_all(self, ctx: commands.context): + async def _force_all(self, ctx: commands.Context): """Force a run of the GitHub feed fetching coroutine.""" async with ctx.typing(): - await self._github_rss.coro(self, guild_to_check=ctx.guild.id) + await self._do_rss_check(guild_to_check=ctx.guild.id) return await ctx.tick() @_github_set.command(name="rename") @@ -488,7 +517,7 @@ async def _list_all(self, ctx: commands.Context): continue feeds_string += f"{(await self.bot.get_or_fetch_user(member_id)).mention}: `{len(member_data['feeds'])}` feed(s) \n" for name, feed in member_data["feeds"].items(): - feeds_string += f"- `{name}`: <{await self._repo_url(**feed)}>\n" + feeds_string += f"- `{name}`: <{self._repo_url(**feed)}>\n" feeds_string += "\n" if not feeds_string: @@ -560,18 +589,18 @@ async def _get( ): """Test out fetching a GitHub repository url.""" - if not (user_repo_branch_token := await self._parse_url_input(url, branch)): - return await ctx.send(await self._invalid_url(ctx)) + if not (user_repo_branch_token := self._parse_url_input(url, branch)): + return await ctx.send(self._invalid_url(ctx)) - url = await self._url_from_config(user_repo_branch_token) + url = self._url_from_config(user_repo_branch_token) if not (parsed := await self._fetch(url, [200])): - return await ctx.send(await self._invalid_url(ctx)) + return await ctx.send(self._invalid_url(ctx)) guild_config = await self.config.guild(ctx.guild).all() return await ctx.send( - embed=await self._commit_embeds( + embed=self._commit_embeds( entries=parsed.entries[:entries] if entries else [parsed.entries[0]], feed_link=parsed.feed.link, color=guild_config["color"], @@ -581,7 +610,7 @@ async def _get( ) @_github.command(name="add") - async def _add(self, ctx: commands.Context, name: str, url: str, branch: str = ""): + async def _add(self, ctx: commands.Context, name: str, url: str, branch: Optional[str] = None): """ Add a GitHub RSS feed to the server. @@ -605,13 +634,13 @@ async def _add(self, ctx: commands.Context, name: str, url: str, branch: str = " return await ctx.send("The mods have not set up a GitHub RSS feed channel yet.") # Get RSS feed url - if not (user_repo_branch_token := await self._parse_url_input(url, branch)): - return await ctx.send(await self._invalid_url(ctx)) - url = await self._url_from_config(user_repo_branch_token) + if not (user_repo_branch_token := self._parse_url_input(url, branch)): + return await ctx.send(self._invalid_url(ctx)) + url = self._url_from_config(user_repo_branch_token) # Fetch and parse if not (parsed := await self._fetch(url, [200, 304])): - return await ctx.send(await self._invalid_url(ctx)) + return await ctx.send(self._invalid_url(ctx)) # Set user config async with self.config.member(ctx.author).feeds() as feeds: @@ -647,13 +676,13 @@ async def _add(self, ctx: commands.Context, name: str, url: str, branch: str = " await channel.send( embed=discord.Embed( color=discord.Color.green(), - description=f"[[{user_repo_branch_token['repo']}:{(await self._parse_url(parsed.feed.link+'.atom'))[2]}]]({await self._repo_url(**user_repo_branch_token)}) has been added by {ctx.author.mention}", + description=f"[[{user_repo_branch_token['repo']}:{self._parse_url(parsed.feed.link + '.atom')[2]}]]({self._repo_url(**user_repo_branch_token)}) has been added by {ctx.author.mention}", ) ) # Send last feed entry await channel.send( - embed=await self._commit_embeds( + embed=self._commit_embeds( entries=[parsed.entries[0]], feed_link=parsed.feed.link, color=guild_config["color"], @@ -687,12 +716,13 @@ async def _remove(self, ctx: commands.Context, name: str): channel = await self._get_feed_channel( ctx.guild.me, guild_config["channel"], to_remove["channel"] ) - await channel.send( - embed=discord.Embed( - color=discord.Color.red(), - description=f"[[{to_remove['repo']}:{to_remove['branch'] or 'main'}]]({await self._repo_url(**to_remove)}) has been removed by {ctx.author.mention}", + if channel: + await channel.send( + embed=discord.Embed( + color=discord.Color.red(), + description=f"[[{to_remove['repo']}:{to_remove['branch'] or 'default'}]]({self._repo_url(**to_remove)}) has been removed by {ctx.author.mention}", + ) ) - ) return await ctx.send("Feed successfully removed.") @@ -705,9 +735,9 @@ async def _list(self, ctx: commands.Context): return await ctx.send(NO_ROLE) feeds_string = "" - async with self.config.member(ctx.author).feeds() as feeds: - for name, feed in feeds.items(): - feeds_string += f"`{name}`: <{await self._repo_url(**feed)}>\n" + feeds = await self.config.member(ctx.author).feeds() + for name, feed in feeds.items(): + feeds_string += f"`{name}`: <{self._repo_url(**feed)}>\n" if not feeds_string: return await ctx.send( @@ -722,8 +752,7 @@ async def _list(self, ctx: commands.Context): for embed in embeds: await ctx.send(embed=embed) - @tasks.loop(minutes=3) - async def _github_rss(self, guild_to_check=None): + async def _do_rss_check(self, guild_to_check: Optional[int] = None) -> None: # Loop through each guild for guild_id, guild_config in (await self.config.all_guilds()).items(): @@ -752,43 +781,76 @@ async def _github_rss(self, guild_to_check=None): # Loop through each feed for name, feed in member_data["feeds"].items(): - url = await self._url_from_config(feed) - - # Fetch & parse feed - if not (parsed := await self._fetch(url, [200])): - continue - - # Find new entries - new_entries, new_time = await self.new_entries(parsed.entries, feed["time"]) - - # Create feed embed - if e := await self._commit_embeds( - entries=new_entries, - feed_link=parsed.feed.link, - color=guild_config["color"], - timestamp=guild_config["timestamp"], - short=guild_config["short"], - ): - - # Get channel (guild vs feed override) - ch = channel - if feed["channel"]: - if not ( - (ch := guild.get_channel(feed["channel"])) - and ch.permissions_for(guild.me).send_messages - and ch.permissions_for(guild.me).embed_links - ): - ch = None - - # Send feed embed - if ch: - await ch.send(embed=e) - - # Set time to feed config - async with self.config.member_from_ids( - guild_id, member_id - ).feeds() as member_feeds: - member_feeds[name]["time"] = new_time.timestamp() + try: + url = self._url_from_config(feed) + + # Fetch & parse feed + if not (parsed := await self._fetch(url, [200])): + continue + + # Find new entries + new_entries, new_time = self.new_entries(parsed.entries, feed["time"]) + + # Create feed embed + if e := self._commit_embeds( + entries=new_entries, + feed_link=parsed.feed.link, + color=guild_config["color"], + timestamp=guild_config["timestamp"], + short=guild_config["short"], + ): + + # Get channel (guild vs feed override) + ch = channel + if feed["channel"]: + if not ( + (ch := guild.get_channel(feed["channel"])) + and ch.permissions_for(guild.me).send_messages + and ch.permissions_for(guild.me).embed_links + ): + ch = None + + # Send feed embed + if ch: + await ch.send(embed=e) + + # Set time to feed config + async with self.config.member_from_ids( + guild_id, member_id + ).feeds() as member_feeds: + member_feeds[name]["time"] = new_time.timestamp() + + except discord.Forbidden: + log.warning( + "Missing permissions for feed '%s' (guild %s, member %s).", + name, + guild_id, + member_id, + ) + except discord.HTTPException as exc: + log.error( + "HTTP error posting feed '%s' (guild %s, member %s): %s", + name, + guild_id, + member_id, + exc, + ) + except Exception: + log.exception( + "Unexpected error processing feed '%s' (guild %s, member %s).", + name, + guild_id, + member_id, + ) + + @tasks.loop(minutes=3) + async def _github_rss(self): + await self._do_rss_check() + + @_github_rss.error + async def _github_rss_error(self, error: Exception): + log.error("GitHub RSS task loop encountered an unhandled error.", exc_info=error) + self._github_rss.restart() @_github_rss.before_loop async def _before_github_rss(self):