diff --git a/JetHerald/Commands/ListCommand.cs b/JetHerald/Commands/ListCommand.cs index 5bcfdf1..871af42 100644 --- a/JetHerald/Commands/ListCommand.cs +++ b/JetHerald/Commands/ListCommand.cs @@ -21,7 +21,8 @@ public class ListCommand : IChatCommand var msg = update.Message; var chatid = msg.Chat.Id; - var topics = await Db.GetTopicsForSub(NamespacedId.Telegram(chatid)); + using var ctx = await Db.GetContext(); + var topics = await ctx.GetTopicsForSub(NamespacedId.Telegram(chatid)); return topics.Any() ? "Topics:\n" + string.Join("\n", topics) diff --git a/JetHerald/Commands/SubscribeCommand.cs b/JetHerald/Commands/SubscribeCommand.cs index fc21366..35c83c3 100644 --- a/JetHerald/Commands/SubscribeCommand.cs +++ b/JetHerald/Commands/SubscribeCommand.cs @@ -25,7 +25,8 @@ public class SubscribeCommand : IChatCommand var chat = NamespacedId.Telegram(args.Message.Chat.Id); var token = cmd.Parameters[0]; - var topic = await Db.GetTopicForSub(token, chat); + using var ctx = await Db.GetContext(); + var topic = await ctx.GetTopicForSub(token, chat); if (topic == null) return "topic not found"; @@ -35,7 +36,8 @@ public class SubscribeCommand : IChatCommand return "token mismatch"; else { - await Db.CreateSubscription(topic.TopicId, chat); + await ctx.CreateSubscription(topic.TopicId, chat); + ctx.Commit(); return $"subscribed to {topic.Name}"; } } diff --git a/JetHerald/Commands/UnsubscribeCommand.cs b/JetHerald/Commands/UnsubscribeCommand.cs index 99f4970..08fccfa 100644 --- a/JetHerald/Commands/UnsubscribeCommand.cs +++ b/JetHerald/Commands/UnsubscribeCommand.cs @@ -26,7 +26,9 @@ public class UnsubscribeCommand : IChatCommand var chat = NamespacedId.Telegram(msg.Chat.Id); var topicName = cmd.Parameters[0]; - int affected = await Db.RemoveSubscription(topicName, chat); + using var ctx = await Db.GetContext(); + int affected = await ctx.RemoveSubscription(topicName, chat); + ctx.Commit(); if (affected >= 1) return $"unsubscribed from {topicName}"; else diff --git a/JetHerald/Controllers/Api/HeartbeatController.cs b/JetHerald/Controllers/Api/HeartbeatController.cs index 6258944..a84f810 100644 --- a/JetHerald/Controllers/Api/HeartbeatController.cs +++ b/JetHerald/Controllers/Api/HeartbeatController.cs @@ -67,7 +67,8 @@ public class HeartbeatController : ControllerBase { var heart = args.Heart ?? "General"; - var t = await Db.GetTopic(args.Topic); + var ctx = await Db.GetContext(); + var t = await ctx.GetTopic(args.Topic); if (t == null) return new NotFoundResult(); else if (!t.WriteToken.Equals(args.WriteToken, StringComparison.Ordinal)) @@ -76,8 +77,8 @@ public class HeartbeatController : ControllerBase if (Timeouts.IsTimedOut(t.TopicId)) return StatusCode(StatusCodes.Status429TooManyRequests); - var wasBeating = await Db.ReportHeartbeat(t.TopicId, heart, args.ExpiryTimeout); - + var wasBeating = await ctx.ReportHeartbeat(t.TopicId, heart, args.ExpiryTimeout); + ctx.Commit(); if (wasBeating == 0) await Herald.BroadcastMessageRaw(t.TopicId, $"!{t.Description}!:\nHeart \"{heart}\" has started beating at {DateTime.UtcNow:O}"); diff --git a/JetHerald/Controllers/Api/ReportController.cs b/JetHerald/Controllers/Api/ReportController.cs index 6c9b24e..341c546 100644 --- a/JetHerald/Controllers/Api/ReportController.cs +++ b/JetHerald/Controllers/Api/ReportController.cs @@ -52,7 +52,12 @@ public class ReportController : ControllerBase private async Task DoReport(ReportArgs args) { - var t = await Db.GetTopic(args.Topic); + Contracts.Topic t; + using (var ctx = await Db.GetContext()) + { + t = await ctx.GetTopic(args.Topic); + } + if (t == null) return new NotFoundResult(); else if (!t.WriteToken.Equals(args.WriteToken, StringComparison.OrdinalIgnoreCase)) diff --git a/JetHerald/Controllers/Ui/AdminToolsController.cs b/JetHerald/Controllers/Ui/AdminToolsController.cs index 517d9a6..c3937f8 100644 --- a/JetHerald/Controllers/Ui/AdminToolsController.cs +++ b/JetHerald/Controllers/Ui/AdminToolsController.cs @@ -31,9 +31,10 @@ public class AdminToolsController : Controller [HttpGet, Route("ui/admintools/invites")] public async Task ViewInvites() { - var invites = await Db.GetInvites(); - var plans = await Db.GetPlans(); - var roles = await Db.GetRoles(); + using var ctx = await Db.GetContext(); + var invites = await ctx.GetInvites(); + var plans = await ctx.GetPlans(); + var roles = await ctx.GetRoles(); return View(new ViewInvitesModel { Invites = invites.ToArray(), @@ -52,7 +53,9 @@ public class AdminToolsController : Controller [HttpPost, Route("ui/admintools/invites/create")] public async Task CreateInvite(CreateInviteRequest req) { - await Db.CreateUserInvite(req.PlanId, req.RoleId, TokenHelper.GetToken(AuthCfg.InviteCodeLength)); + using var ctx = await Db.GetContext(); + await ctx.CreateUserInvite(req.PlanId, req.RoleId, TokenHelper.GetToken(AuthCfg.InviteCodeLength)); + ctx.Commit(); return RedirectToAction(nameof(ViewInvites)); } } diff --git a/JetHerald/Controllers/Ui/DashboardController.cs b/JetHerald/Controllers/Ui/DashboardController.cs index 60ba396..51d7f5c 100644 --- a/JetHerald/Controllers/Ui/DashboardController.cs +++ b/JetHerald/Controllers/Ui/DashboardController.cs @@ -21,9 +21,10 @@ public class DashboardController : Controller public async Task Index() { var login = HttpContext.User.GetUserLogin(); - var user = await Db.GetUser(login); - var topics = await Db.GetTopicsForUser(user.UserId); - var hearts = await Db.GetHeartsForUser(user.UserId); + using var ctx = await Db.GetContext(); + var user = await ctx.GetUser(login); + var topics = await ctx.GetTopicsForUser(user.UserId); + var hearts = await ctx.GetHeartsForUser(user.UserId); var vm = new DashboardViewModel { Topics = topics.ToArray(), diff --git a/JetHerald/Controllers/Ui/LoginController.cs b/JetHerald/Controllers/Ui/LoginController.cs index e6b9284..88a5b7b 100644 --- a/JetHerald/Controllers/Ui/LoginController.cs +++ b/JetHerald/Controllers/Ui/LoginController.cs @@ -51,7 +51,8 @@ public class LoginController : Controller ViewData["RedirectTo"] = PathStringOrDefault(redirect); - var user = await Db.GetUser(req.Username); + using var ctx = await Db.GetContext(); + var user = await ctx.GetUser(req.Username); if (user == null) { ModelState.AddModelError("", "User not found"); diff --git a/JetHerald/Controllers/Ui/ProfileController.cs b/JetHerald/Controllers/Ui/ProfileController.cs index 3a7a59c..a60de68 100644 --- a/JetHerald/Controllers/Ui/ProfileController.cs +++ b/JetHerald/Controllers/Ui/ProfileController.cs @@ -19,7 +19,8 @@ public class ProfileController : Controller public async Task Index() { var login = HttpContext.User.GetUserLogin(); - var user = await Db.GetUser(login); + using var ctx = await Db.GetContext(); + var user = await ctx.GetUser(login); var vm = new ProfileViewModel { diff --git a/JetHerald/Controllers/Ui/RegistrationController.cs b/JetHerald/Controllers/Ui/RegistrationController.cs index 53f7c2a..f5be007 100644 --- a/JetHerald/Controllers/Ui/RegistrationController.cs +++ b/JetHerald/Controllers/Ui/RegistrationController.cs @@ -69,13 +69,14 @@ public class RegistrationController : Controller ViewData["RedirectTo"] = PathStringOrDefault(redirect); - var oldUser = await Db.GetUser(req.Login); + using var ctx = await Db.GetContext(); + var oldUser = await ctx.GetUser(req.Login); if (oldUser != null) { ModelState.AddModelError("", "User already exists"); return View(); } - var invite = await Db.GetInviteByCode(req.InviteCode); + var invite = await ctx.GetInviteByCode(req.InviteCode); if (invite == null || invite.RedeemedBy != default) { ModelState.AddModelError("", "No unredeemed invite with this code found"); @@ -91,8 +92,9 @@ public class RegistrationController : Controller PasswordSalt = RandomNumberGenerator.GetBytes(64) }; user.PasswordHash = AuthUtils.GetHashFor(req.Password, user.PasswordSalt, user.HashType); - user = await Db.RegisterUser(user); - await Db.RedeemInvite(invite.UserInviteId, user.UserId); + user = await ctx.RegisterUser(user); + await ctx.RedeemInvite(invite.UserInviteId, user.UserId); + ctx.Commit(); var userIdentity = AuthUtils.CreateIdentity(user.UserId, user.Login, user.Name, user.Allow); var principal = new ClaimsPrincipal(userIdentity); await HttpContext.SignInAsync(CookieAuthenticationDefaults.AuthenticationScheme, principal); diff --git a/JetHerald/Controllers/Ui/TopicController.cs b/JetHerald/Controllers/Ui/TopicController.cs index 8152389..7b5b0e6 100644 --- a/JetHerald/Controllers/Ui/TopicController.cs +++ b/JetHerald/Controllers/Ui/TopicController.cs @@ -36,7 +36,9 @@ public class TopicController : Controller if (!ModelState.IsValid) return View(); var userId = HttpContext.User.GetUserId(); - var topic = await Db.CreateTopic(userId, req.Name, req.Description); + using var ctx = await Db.GetContext(); + var topic = await ctx.CreateTopic(userId, req.Name, req.Description); + ctx.Commit(); if (topic == null) { ModelState.AddModelError("", "Unknown error"); @@ -50,11 +52,12 @@ public class TopicController : Controller public async Task ViewTopic(string topicName) { var userId = HttpContext.User.GetUserId(); - var topic = await Db.GetTopic(topicName); + using var ctx = await Db.GetContext(); + var topic = await ctx.GetTopic(topicName); if (topic == null || topic.CreatorId != userId) return NotFound(); - var hearts = await Db.GetHeartsForTopic(topic.TopicId); + var hearts = await ctx.GetHeartsForTopic(topic.TopicId); var vm = new TopicViewModel { Topic = topic, diff --git a/JetHerald/Middlewares/AnonymousUserMassagerMiddleware.cs b/JetHerald/Middlewares/AnonymousUserMassagerMiddleware.cs index f3aea83..e2b071f 100644 --- a/JetHerald/Middlewares/AnonymousUserMassagerMiddleware.cs +++ b/JetHerald/Middlewares/AnonymousUserMassagerMiddleware.cs @@ -12,7 +12,8 @@ public class AnonymousUserMassagerMiddleware : IMiddleware { AnonymousPermissions = new Lazy>(async () => { - var anonymousUser = await db.GetUser("Anonymous"); + using var ctx = await db.GetContext(); + var anonymousUser = await ctx.GetUser("Anonymous"); return anonymousUser.Allow; }); } diff --git a/JetHerald/Program.cs b/JetHerald/Program.cs index dbad9fe..8faa8c5 100644 --- a/JetHerald/Program.cs +++ b/JetHerald/Program.cs @@ -91,12 +91,12 @@ try // preflight checks { var db = app.Services.GetService(); - - var adminUser = await db.GetUser("admin"); + using var ctx = await db.GetContext(); + var adminUser = await ctx.GetUser("admin"); if (adminUser == null) { - var adminRole = (await db.GetRoles()).First(r => r.Name == "admin"); - var unlimitedPlan = (await db.GetPlans()).First(p => p.Name == "unlimited"); + var adminRole = (await ctx.GetRoles()).First(r => r.Name == "admin"); + var unlimitedPlan = (await ctx.GetPlans()).First(p => p.Name == "unlimited"); var authCfg = app.Services.GetService>().Value; var password = Convert.ToBase64String(RandomNumberGenerator.GetBytes(48)); @@ -110,7 +110,8 @@ try PlanId = unlimitedPlan.PlanId }; adminUser.PasswordHash = AuthUtils.GetHashFor(password, adminUser.PasswordSalt, adminUser.HashType); - var newUser = await db.RegisterUser(adminUser); + var newUser = await ctx.RegisterUser(adminUser); + ctx.Commit(); log.Warn($"Created administrative account {adminUser.Login}:{password}. Be sure to save these credentials somewhere!"); } } diff --git a/JetHerald/Services/Db.cs b/JetHerald/Services/Db.cs deleted file mode 100644 index 6d34f47..0000000 --- a/JetHerald/Services/Db.cs +++ /dev/null @@ -1,288 +0,0 @@ -using MySql.Data.MySqlClient; -using Dapper; -using JetHerald.Options; -using JetHerald.Contracts; - -namespace JetHerald.Services; -public class Db -{ - public async Task> GetTopicsForUser(uint userId) - { - using var c = GetConnection(); - return await c.QueryAsync( - " SELECT * FROM topic WHERE CreatorId = @userId", - new { userId }); - } - public async Task> GetPlans() - { - using var c = GetConnection(); - return await c.QueryAsync("SELECT * FROM plan"); - } - - public async Task> GetRoles() - { - using var c = GetConnection(); - return await c.QueryAsync("SELECT * FROM role"); - } - public async Task> GetInvites() - { - using var c = GetConnection(); - return await c.QueryAsync("SELECT * FROM userinvite"); - } - - public async Task> GetHeartsForUser(uint userId) - { - using var c = GetConnection(); - return await c.QueryAsync( - " SELECT h.* FROM heart h JOIN topic USING (TopicId) WHERE CreatorId = @userId", - new { userId }); - } - - public async Task CreateUserInvite(uint planId, uint roleId, string inviteCode) - { - using var c = GetConnection(); - await c.ExecuteAsync(@" - INSERT INTO userinvite - ( PlanId, RoleId, InviteCode) - VALUES - (@planId, @roleId, @inviteCode)", - new { planId, roleId, inviteCode }); - } - - public async Task GetTopic(string name) - { - using var c = GetConnection(); - return await c.QuerySingleOrDefaultAsync( - "SELECT * FROM topic WHERE Name = @name", - new { name }); - } - - public async Task DeleteTopic(string name, uint userId) - { - using var c = GetConnection(); - return await c.ExecuteAsync( - " DELETE FROM topic WHERE Name = @name AND CreatorId = @userId", - new { name, userId }); - } - - - public async Task GetTopicForSub(string token, NamespacedId sub) - { - using var c = GetConnection(); - return await c.QuerySingleOrDefaultAsync( - " SELECT t.*, ts.Sub " + - " FROM topic t " + - " LEFT JOIN topic_sub ts ON t.TopicId = ts.TopicId AND ts.Sub = @sub " + - " WHERE ReadToken = @token", - new { token, sub }); - } - - public async Task> GetHeartsForTopic(uint topicId) - { - using var c = GetConnection(); - return await c.QueryAsync( - " SELECT * FROM heart WHERE TopicId = @topicId", - new { topicId }); - } - public async Task GetUser(string login) - { - using var c = GetConnection(); - return await c.QuerySingleOrDefaultAsync(@" - SELECT u.*, up.*, ur.* - FROM user u - JOIN plan up ON u.PlanId = up.PlanId - JOIN role ur ON u.RoleId = ur.RoleId - WHERE u.Login = @login;", - new { login }); - } - - public async Task CreateTopic(uint user, string name, string descr) - { - using var c = GetConnection(); - - await c.OpenAsync(); - - await using var tx = await c.BeginTransactionAsync(); - - var topicsCount = await c.QuerySingleAsync( - " SELECT COUNT(*) " + - " FROM user u " + - " LEFT JOIN topic t ON t.CreatorId = u.UserId " + - " WHERE u.UserId = @user", - new { user }, - transaction: tx - ); - - var planTopicsCount = await c.QuerySingleAsync( - " SELECT p.MaxTopics " + - " FROM user u " + - " LEFT JOIN plan p ON p.PlanId = u.PlanId " + - " WHERE u.UserId = @user", - new { user }, - transaction: tx - ); - - if (topicsCount >= planTopicsCount) return null; - - var topic = await c.QuerySingleOrDefaultAsync( - " INSERT INTO topic " + - " ( CreatorId, Name, Description, ReadToken, WriteToken) " + - " VALUES " + - " (@CreatorId, @Name, @Description, @ReadToken, @WriteToken); " + - " SELECT * FROM topic WHERE TopicId = LAST_INSERT_ID(); ", - new Topic - { - CreatorId = user, - Name = name, - Description = descr, - ReadToken = TokenHelper.GetToken(), - WriteToken = TokenHelper.GetToken() - }, transaction: tx); - - await tx.CommitAsync(); - - return topic; - } - - public async Task RegisterUser(User user) - { - using var c = GetConnection(); - uint userId = await c.QuerySingleOrDefaultAsync(@" - INSERT INTO user - ( Login, Name, PasswordHash, PasswordSalt, HashType, PlanId, RoleId) - VALUES - (@Login, @Name, @PasswordHash, @PasswordSalt, @HashType, @PlanId, @RoleId);", - param:user); - return await GetUser(user.Login); - } - - public async Task RedeemInvite(uint inviteId, uint userId) - { - using var c = GetConnection(); - await c.ExecuteAsync( - @"UPDATE userinvite SET RedeemedBy = @userId WHERE UserInviteId = @inviteId", - new { inviteId, userId }); - } - - public async Task GetInviteByCode(string inviteCode) - { - using var c = GetConnection(); - return await c.QuerySingleOrDefaultAsync( - " SELECT * FROM userinvite " + - " WHERE InviteCode = @inviteCode " + - " AND RedeemedBy IS NULL ", - new { inviteCode }); - } - - public async Task> GetSubsForTopic(uint topicId) - { - using var c = GetConnection(); - return await c.QueryAsync( - " SELECT Sub " + - " FROM topic_sub " + - " WHERE TopicId = @topicid", - new { topicId }); - } - - public async Task> GetTopicsForSub(NamespacedId sub) - { - using var c = GetConnection(); - return await c.QueryAsync( - " SELECT t.*" + - " FROM topic_sub ts" + - " JOIN topic t USING (TopicId)" + - " WHERE ts.Sub = @sub", - new { sub }); - } - - public async Task CreateSubscription(uint topicId, NamespacedId sub) - { - using var c = GetConnection(); - await c.ExecuteAsync( - " INSERT INTO topic_sub" + - " (TopicId, Sub)" + - " VALUES" + - " (@topicId, @sub)", - new { topicId, sub }); - } - - public async Task RemoveSubscription(string topicName, NamespacedId sub) - { - using var c = GetConnection(); - return await c.ExecuteAsync( - " DELETE ts " + - " FROM topic_sub ts" + - " JOIN topic t USING (TopicId) " + - " WHERE t.Name = @topicName AND ts.Sub = @sub;", - new { topicName, sub }); - } - - - public async Task ReportHeartbeat(uint topicId, string heart, int timeoutSeconds) - { - using var c = GetConnection(); - return await c.QueryFirstAsync( - @"CALL report_heartbeat(@topicId, @heart, @timeoutSeconds);", - new { topicId, heart, timeoutSeconds }); - } - - public async Task> ProcessHearts() - { - using var c = GetConnection(); - return await c.QueryAsync("CALL process_hearts();"); - } - - public async Task MarkHeartAttackReported(ulong id) - { - using var c = GetConnection(); - await c.ExecuteAsync("UPDATE heartevent SET Status = 'reported' WHERE HeartEventId = @id", new { id }); - } - - #region authorization - - public async Task RemoveSession(string sessionId) - { - using var c = GetConnection(); - await c.ExecuteAsync("DELETE FROM usersession WHERE SessionId = @sessionId", new {sessionId}); - } - public async Task GetSession(string sessionId) - { - using var c = GetConnection(); - return await c.QuerySingleOrDefaultAsync( - "SELECT * FROM usersession WHERE SessionId = @sessionId", - new { sessionId }); - } - - public async Task UpdateSession(string sessionId, byte[] data, DateTime expiryTs) - { - using var c = GetConnection(); - await c.ExecuteAsync(@" - UPDATE usersession SET - SessionData = @data, - ExpiryTs = @expiryTs - WHERE SessionId = @sessionId;", - new { sessionId, data, expiryTs }); - } - - public async Task CreateSession(string sessionId, byte[] data, DateTime expiryTs) - { - using var c = GetConnection(); - await c.ExecuteAsync(@" - INSERT INTO usersession - (SessionId, SessionData, ExpiryTs) - VALUES - (@sessionId, @data, @expiryTs);", - new { sessionId, data, expiryTs }); - return sessionId; - } - - #endregion - - public Db(IOptionsMonitor cfg) - { - Config = cfg; - } - IOptionsMonitor Config { get; } - public MySqlConnection GetConnection() => new(Config.CurrentValue.DefaultConnection); -} - diff --git a/JetHerald/Services/DbContext.cs b/JetHerald/Services/DbContext.cs new file mode 100644 index 0000000..47e836a --- /dev/null +++ b/JetHerald/Services/DbContext.cs @@ -0,0 +1,238 @@ +using System.Data; +using System.Threading; +using System.ComponentModel; +using MySql.Data.MySqlClient; +using Dapper.Transaction; +using JetHerald.Options; +using JetHerald.Contracts; + +namespace JetHerald.Services; + +public class Db +{ + public Db(IOptionsMonitor cfg) + { + Config = cfg; + } + IOptionsMonitor Config { get; } + MySqlConnection GetConnection() => new(Config.CurrentValue.DefaultConnection); + public async Task GetContext( + IsolationLevel lvl = IsolationLevel.RepeatableRead, + CancellationToken token = default) + { + var conn = GetConnection(); + if (conn.State != ConnectionState.Open) + await conn.OpenAsync(); + + var tran = await conn.BeginTransactionAsync(lvl, token); + return new DbContext(tran); + } +} + +public class DbContext : IDisposable +{ + [EditorBrowsable(EditorBrowsableState.Never)] + public DbContext(IDbTransaction tran) + { + Tran = tran; + Conn = Tran.Connection; + } + + IDbConnection Conn; + IDbTransaction Tran; + + public void Commit() => Tran.Commit(); + public void Dispose() + { + Tran.Dispose(); + Conn.Dispose(); + } + public Task> GetTopicsForUser(uint userId) + => Tran.QueryAsync( + " SELECT * FROM topic WHERE CreatorId = @userId", + new { userId }); + public Task> GetPlans() + => Tran.QueryAsync("SELECT * FROM plan"); + public Task> GetRoles() + => Tran.QueryAsync("SELECT * FROM role"); + public Task> GetInvites() + => Tran.QueryAsync("SELECT * FROM userinvite"); + + public Task> GetHeartsForUser(uint userId) + => Tran.QueryAsync( + " SELECT h.* FROM heart h JOIN topic USING (TopicId) WHERE CreatorId = @userId", + new { userId }); + + public Task CreateUserInvite(uint planId, uint roleId, string inviteCode) + => Tran.ExecuteAsync(@" + INSERT INTO userinvite + ( PlanId, RoleId, InviteCode) + VALUES + (@planId, @roleId, @inviteCode)", + new { planId, roleId, inviteCode }); + + public Task GetTopic(string name) + => Tran.QuerySingleOrDefaultAsync( + "SELECT * FROM topic WHERE Name = @name", + new { name }); + + public Task DeleteTopic(string name, uint userId) + => Tran.ExecuteAsync( + " DELETE FROM topic WHERE Name = @name AND CreatorId = @userId", + new { name, userId }); + + public Task GetTopicForSub(string token, NamespacedId sub) + => Tran.QuerySingleOrDefaultAsync( + " SELECT t.*, ts.Sub " + + " FROM topic t " + + " LEFT JOIN topic_sub ts ON t.TopicId = ts.TopicId AND ts.Sub = @sub " + + " WHERE ReadToken = @token", + new { token, sub }); + + public Task> GetHeartsForTopic(uint topicId) + => Tran.QueryAsync( + " SELECT * FROM heart WHERE TopicId = @topicId", + new { topicId }); + public Task GetUser(string login) + => Tran.QuerySingleOrDefaultAsync(@" + SELECT u.*, up.*, ur.* + FROM user u + JOIN plan up ON u.PlanId = up.PlanId + JOIN role ur ON u.RoleId = ur.RoleId + WHERE u.Login = @login;", + new { login }); + + public async Task CreateTopic(uint user, string name, string descr) + { + var topicsCount = await Tran.QuerySingleAsync( + " SELECT COUNT(*) " + + " FROM user u " + + " LEFT JOIN topic t ON t.CreatorId = u.UserId " + + " WHERE u.UserId = @user", + new { user } + ); + + var planTopicsCount = await Tran.QuerySingleAsync( + " SELECT p.MaxTopics " + + " FROM user u " + + " LEFT JOIN plan p ON p.PlanId = u.PlanId " + + " WHERE u.UserId = @user", + new { user } + ); + + if (topicsCount >= planTopicsCount) return null; + + var topic = await Tran.QuerySingleOrDefaultAsync( + " INSERT INTO topic " + + " ( CreatorId, Name, Description, ReadToken, WriteToken) " + + " VALUES " + + " (@CreatorId, @Name, @Description, @ReadToken, @WriteToken); " + + " SELECT * FROM topic WHERE TopicId = LAST_INSERT_ID(); ", + new Topic + { + CreatorId = user, + Name = name, + Description = descr, + ReadToken = TokenHelper.GetToken(), + WriteToken = TokenHelper.GetToken() + }); + return topic; + } + + public async Task RegisterUser(User user) + { + _ = await Tran.QuerySingleOrDefaultAsync(@" + INSERT INTO user + ( Login, Name, PasswordHash, PasswordSalt, HashType, PlanId, RoleId) + VALUES + (@Login, @Name, @PasswordHash, @PasswordSalt, @HashType, @PlanId, @RoleId);", + param:user); + return await GetUser(user.Login); + } + + public Task RedeemInvite(uint inviteId, uint userId) + => Tran.ExecuteAsync( + @"UPDATE userinvite SET RedeemedBy = @userId WHERE UserInviteId = @inviteId", + new { inviteId, userId }); + + public Task GetInviteByCode(string inviteCode) + => Tran.QuerySingleOrDefaultAsync( + " SELECT * FROM userinvite " + + " WHERE InviteCode = @inviteCode " + + " AND RedeemedBy IS NULL ", + new { inviteCode }); + + public Task> GetSubsForTopic(uint topicId) + => Tran.QueryAsync( + " SELECT Sub " + + " FROM topic_sub " + + " WHERE TopicId = @topicid", + new { topicId }); + + public Task> GetTopicsForSub(NamespacedId sub) + => Tran.QueryAsync( + " SELECT t.*" + + " FROM topic_sub ts" + + " JOIN topic t USING (TopicId)" + + " WHERE ts.Sub = @sub", + new { sub }); + + public Task CreateSubscription(uint topicId, NamespacedId sub) + => Tran.ExecuteAsync( + " INSERT INTO topic_sub" + + " (TopicId, Sub)" + + " VALUES" + + " (@topicId, @sub)", + new { topicId, sub }); + + public Task RemoveSubscription(string topicName, NamespacedId sub) + => Tran.ExecuteAsync( + " DELETE ts " + + " FROM topic_sub ts" + + " JOIN topic t USING (TopicId) " + + " WHERE t.Name = @topicName AND ts.Sub = @sub;", + new { topicName, sub }); + + + public Task ReportHeartbeat(uint topicId, string heart, int timeoutSeconds) + => Tran.QueryFirstAsync( + @"CALL report_heartbeat(@topicId, @heart, @timeoutSeconds);", + new { topicId, heart, timeoutSeconds }); + + public Task> ProcessHearts() + => Tran.QueryAsync("CALL process_hearts();"); + + public Task MarkHeartAttackReported(ulong id) + => Tran.ExecuteAsync("UPDATE heartevent SET Status = 'reported' WHERE HeartEventId = @id", new { id }); + + #region TicketStore + + public Task RemoveSession(string sessionId) + => Tran.ExecuteAsync("DELETE FROM usersession WHERE SessionId = @sessionId", new { sessionId }); + public Task GetSession(string sessionId) + => Tran.QuerySingleOrDefaultAsync( + "SELECT * FROM usersession WHERE SessionId = @sessionId", + new { sessionId }); + + public Task UpdateSession(string sessionId, byte[] data, DateTime expiryTs) + => Tran.ExecuteAsync(@" + UPDATE usersession SET + SessionData = @data, + ExpiryTs = @expiryTs + WHERE SessionId = @sessionId;", + new { sessionId, data, expiryTs }); + + public async Task CreateSession(string sessionId, byte[] data, DateTime expiryTs) + { + await Tran.ExecuteAsync(@" + INSERT INTO usersession + (SessionId, SessionData, ExpiryTs) + VALUES + (@sessionId, @data, @expiryTs);", + new { sessionId, data, expiryTs }); + return sessionId; + } + + #endregion +} + diff --git a/JetHerald/Services/DiscordCommands.cs b/JetHerald/Services/DiscordCommands.cs index 69b0155..c27a007 100644 --- a/JetHerald/Services/DiscordCommands.cs +++ b/JetHerald/Services/DiscordCommands.cs @@ -21,7 +21,8 @@ public class DiscordCommands : BaseCommandModule _ = ctx.TriggerTypingAsync(); var chat = NamespacedId.Discord(ctx.Channel.Id); - var topic = await Db.GetTopicForSub(token, chat); + using var dbctx = await Db.GetContext(); + var topic = await dbctx.GetTopicForSub(token, chat); if (topic == null) await ctx.RespondAsync("topic not found"); @@ -31,7 +32,8 @@ public class DiscordCommands : BaseCommandModule await ctx.RespondAsync("token mismatch"); else { - await Db.CreateSubscription(topic.TopicId, chat); + await dbctx.CreateSubscription(topic.TopicId, chat); + dbctx.Commit(); await ctx.RespondAsync($"subscribed to {topic.Name}"); } } @@ -46,8 +48,8 @@ public class DiscordCommands : BaseCommandModule ) { _ = ctx.TriggerTypingAsync(); - - int affected = await Db.RemoveSubscription(name, NamespacedId.Discord(ctx.Channel.Id)); + using var dbctx = await Db.GetContext(); + int affected = await dbctx.RemoveSubscription(name, NamespacedId.Discord(ctx.Channel.Id)); if (affected >= 1) await ctx.RespondAsync($"unsubscribed from {name}"); else diff --git a/JetHerald/Services/HeartMonitor.cs b/JetHerald/Services/HeartMonitor.cs index 882efce..4a45b9a 100644 --- a/JetHerald/Services/HeartMonitor.cs +++ b/JetHerald/Services/HeartMonitor.cs @@ -24,14 +24,15 @@ public class HeartMonitor : BackgroundService await Task.Delay(1000 * 10, token); try { - var attacks = await Db.ProcessHearts(); + using var ctx = await Db.GetContext(); + var attacks = await ctx.ProcessHearts(); foreach (var a in attacks) { await Herald.BroadcastMessageRaw( a.TopicId, $"!{a.Description}!:\nHeart \"{a.Heart}\" stopped beating at {a.CreateTs:O}"); - await Db.MarkHeartAttackReported(a.HeartEventId); + await ctx.MarkHeartAttackReported(a.HeartEventId); if (token.IsCancellationRequested) return; diff --git a/JetHerald/Services/JetHeraldBot.cs b/JetHerald/Services/JetHeraldBot.cs index 972b89e..33864ac 100644 --- a/JetHerald/Services/JetHeraldBot.cs +++ b/JetHerald/Services/JetHeraldBot.cs @@ -44,7 +44,9 @@ public partial class JetHeraldBot : IHostedService public async Task BroadcastMessageRaw(uint topicId, string formatted) { - var chatIds = await Db.GetSubsForTopic(topicId); + IEnumerable chatIds; + using (var ctx = await Db.GetContext()) + chatIds = await ctx.GetSubsForTopic(topicId); foreach (var c in chatIds) await SendMessageRaw(c, formatted); } diff --git a/JetHerald/Services/TicketStore.cs b/JetHerald/Services/TicketStore.cs index b538c00..a924062 100644 --- a/JetHerald/Services/TicketStore.cs +++ b/JetHerald/Services/TicketStore.cs @@ -13,29 +13,40 @@ public class JetHeraldTicketStore : ITicketStore Db = db; Cfg = cfg; } - public Task RemoveAsync(string key) - => Db.RemoveSession(key); - - public Task RenewAsync(string key, AuthenticationTicket ticket) - => Db.UpdateSession( + public async Task RemoveAsync(string key) + { + using var ctx = await Db.GetContext(); + await ctx.RemoveSession(key); + ctx.Commit(); + } + public async Task RenewAsync(string key, AuthenticationTicket ticket) + { + using var ctx = await Db.GetContext(); + await ctx.UpdateSession( key, TicketSerializer.Default.Serialize(ticket), ticket.Properties.ExpiresUtc.Value.DateTime); + ctx.Commit(); + } public async Task RetrieveAsync(string key) { - var userSession = await Db.GetSession(key); + using var ctx = await Db.GetContext(); + var userSession = await ctx.GetSession(key); return TicketSerializer.Default.Deserialize(userSession.SessionData); } - public Task StoreAsync(AuthenticationTicket ticket) + public async Task StoreAsync(AuthenticationTicket ticket) { var cfg = Cfg.CurrentValue; var bytes = RandomNumberGenerator.GetBytes(cfg.TicketIdLengthBytes); var key = Convert.ToBase64String(bytes); - return Db.CreateSession( + using var ctx = await Db.GetContext(); + await ctx.CreateSession( key, TicketSerializer.Default.Serialize(ticket), ticket.Properties.ExpiresUtc.Value.DateTime); + ctx.Commit(); + return key; } }