diff --git a/src/commands/ai.ts b/src/commands/ai.ts index 04624a8..507f6a5 100644 --- a/src/commands/ai.ts +++ b/src/commands/ai.ts @@ -160,47 +160,40 @@ export async function preChecks() { "ollamaApi", "flashModel", "thinkingModel", - ] + ]; - let checked = 0; for (const env of envs) { if (!process.env[env]) { - console.error(`[✨ AI | !] ❌ ${env} not set!`) - return false + console.error(`[✨ AI | !] ❌ ${env} not set!`); + return false; } - checked++; } - const ollamaApi = process.env.ollamaApi - if (!ollamaApi) { - console.error("[✨ AI | !] ❌ ollamaApi not set!") - return false - } - let ollamaOk = false + const ollamaApi = process.env.ollamaApi!; + let ollamaOk = false; for (let i = 0; i < 10; i++) { try { - const res = await axios.get(ollamaApi, { timeout: 2000 }) - if (res && res.data && typeof res.data === 'object' && 'ollama' in res.data) { - ollamaOk = true - break - } - if (res && res.status === 200) { - ollamaOk = true - break + const res = await axios.get(ollamaApi, { timeout: 2000 }); + if (res.status === 200) { + ollamaOk = true; + break; } } catch (err) { - await new Promise(resolve => setTimeout(resolve, 1000)) + if (i < 9) { + await new Promise((resolve) => setTimeout(resolve, 1000)); + } } } + if (!ollamaOk) { - console.error("[✨ AI | !] ❌ Ollama API is not responding at ", ollamaApi) - return false + console.error(`[✨ AI | !] ❌ Ollama API is not responding at ${ollamaApi}`); + return false; } - checked++; - console.log(`[✨ AI] Pre-checks passed [${checked}/${envs.length + 1}]`) + + console.log(`[✨ AI] Pre-checks passed.`); const modelCount = models.reduce((acc, model) => acc + model.models.length, 0); console.log(`[✨ AI] Found ${modelCount} models.`); - return true + return true; } function isAxiosError(error: unknown): error is { response?: { data?: { error?: string }, status?: number, statusText?: string }, request?: unknown, message?: string } { @@ -236,18 +229,18 @@ function containsUrls(text: string): boolean { return text.includes('http://') || text.includes('https://') || text.includes('.com') || text.includes('.net') || text.includes('.org') || text.includes('.io') || text.includes('.ai') || text.includes('.dev') } -async function getResponse(prompt: string, ctx: TextContext, replyGenerating: Message, model: string, aiTemperature: number, originalMessage: string, db: NodePgDatabase, userId: string, Strings: ReturnType, showThinking: boolean): Promise<{ success: boolean; response?: string; error?: string }> { +async function getResponse(prompt: string, ctx: TextContext, replyGenerating: Message, model: string, aiTemperature: number, originalMessage: string, db: NodePgDatabase, userId: string, Strings: ReturnType, showThinking: boolean): Promise<{ success: boolean; response?: string; error?: string, messageType?: 'generation' | 'system' }> { if (!ctx.chat) { return { success: false, error: Strings.unexpectedErr.replace("{error}", Strings.ai.noChatFound), }; } - const cleanedModelName = model.replace('hf.co/', ''); + const cleanedModelName = model.includes('/') ? model.split('/').pop()! : model; let status = Strings.ai.statusWaitingRender; let modelHeader = Strings.ai.modelHeader - .replace("{model}", cleanedModelName) - .replace("{temperature}", aiTemperature) + .replace("{model}", `\`${cleanedModelName}\``) + .replace("{temperature}", String(aiTemperature)) .replace("{status}", status) + "\n\n"; const promptCharCount = originalMessage.length; @@ -274,13 +267,10 @@ async function getResponse(prompt: string, ctx: TextContext, replyGenerating: Me } ); let fullResponse = ""; - let thoughts = ""; let lastUpdateCharCount = 0; let sentHeader = false; let firstChunk = true; const stream: NodeJS.ReadableStream = aiResponse.data as any; - let thinkingMessageSent = false; - let finalResponseText = ''; const formatThinkingMessage = (text: string) => { const withPlaceholders = text @@ -319,7 +309,7 @@ async function getResponse(prompt: string, ctx: TextContext, replyGenerating: Me if (firstChunk) { status = Strings.ai.statusWaitingRender; modelHeader = Strings.ai.modelHeader - .replace("{model}", cleanedModelName) + .replace("{model}", `\`${cleanedModelName}\``) .replace("{temperature}", aiTemperature) .replace("{status}", status) + "\n\n"; await rateLimiter.editMessageWithRetry( @@ -353,7 +343,7 @@ async function getResponse(prompt: string, ctx: TextContext, replyGenerating: Me status = Strings.ai.statusRendering; modelHeader = Strings.ai.modelHeader - .replace("{model}", cleanedModelName) + .replace("{model}", `\`${cleanedModelName}\``) .replace("{temperature}", aiTemperature) .replace("{status}", status) + "\n\n"; @@ -382,6 +372,7 @@ async function getResponse(prompt: string, ctx: TextContext, replyGenerating: Me return { success: true, response: patchedResponse, + messageType: 'generation' }; } catch (error: unknown) { const errorMsg = extractAxiosErrorMessage(error); @@ -395,7 +386,7 @@ async function getResponse(prompt: string, ctx: TextContext, replyGenerating: Me ctx.chat!.id, replyGenerating.message_id, undefined, - Strings.ai.pulling.replace("{model}", model), + Strings.ai.pulling.replace("{model}", `\`${cleanedModelName}\``), { parse_mode: 'Markdown' } ); console.log(`[✨ AI] Pulling ${model} from ollama...`); @@ -413,13 +404,15 @@ async function getResponse(prompt: string, ctx: TextContext, replyGenerating: Me console.error("[✨ AI | !] Pull error:", pullMsg); return { success: false, - error: `❌ Something went wrong while pulling ${model}: ${pullMsg}`, + error: `❌ Something went wrong while pulling \`${model}\`: ${pullMsg}`, + messageType: 'system' }; } console.log(`[✨ AI] ${model} pulled successfully`); return { success: true, - response: Strings.ai.pulled.replace("{model}", model), + response: Strings.ai.pulled.replace("{model}", `\`${cleanedModelName}\``), + messageType: 'system' }; } } @@ -435,11 +428,22 @@ async function handleAiReply(ctx: TextContext, model: string, prompt: string, re if (!aiResponse) return; if (!ctx.chat) return; if (aiResponse.success && aiResponse.response) { - const cleanedModelName = model.replace('hf.co/', ''); + if (aiResponse.messageType === 'system') { + await rateLimiter.editMessageWithRetry( + ctx, + ctx.chat.id, + replyGenerating.message_id, + aiResponse.response, + { parse_mode: 'Markdown' } + ); + return; + } + + const cleanedModelName = model.includes('/') ? model.split('/').pop()! : model; const status = Strings.ai.statusComplete; const modelHeader = Strings.ai.modelHeader - .replace("{model}", cleanedModelName) - .replace("{temperature}", aiTemperature) + .replace("{model}", `\`${cleanedModelName}\``) + .replace("{temperature}", String(aiTemperature)) .replace("{status}", status) + "\n\n"; const urlWarning = containsUrls(originalMessage) ? Strings.ai.urlWarning : ''; let finalResponse = aiResponse.response; @@ -541,18 +545,12 @@ export default (bot: Telegraf, db: NodePgDatabase) => { const message = ctx.message.text; const author = ("@" + ctx.from?.username) || ctx.from?.first_name || "Unknown"; - let model: string; - let fixedMsg: string; + const model = command === 'ai' + ? (customAiModel || flash_model) + : (command === 'ask' ? flash_model : thinking_model); - if (command === 'ai') { - model = customAiModel || flash_model; - fixedMsg = message.replace(/^\/ai(@\w+)?\s*/, "").trim(); - logger.logCmdStart(author, command, model); - } else { - model = command === 'ask' ? flash_model : thinking_model; - fixedMsg = message.replace(/^\/(ask|think)(@\w+)?\s*/, "").trim(); - logger.logCmdStart(author, command, model); - } + const fixedMsg = message.replace(new RegExp(`^/${command}(@\\w+)?\\s*`), "").trim(); + logger.logCmdStart(author, command, model); if (!process.env.ollamaApi) { await ctx.reply(Strings.ai.disabled, { parse_mode: 'Markdown', ...(reply_to_message_id && { reply_parameters: { message_id: reply_to_message_id } }) }); @@ -571,7 +569,7 @@ export default (bot: Telegraf, db: NodePgDatabase) => { const task = async () => { const modelLabel = getModelLabelByName(model); - const replyGenerating = await ctx.reply(Strings.ai.askGenerating.replace("{model}", modelLabel), { + const replyGenerating = await ctx.reply(Strings.ai.askGenerating.replace("{model}", `\`${modelLabel}\``), { parse_mode: 'Markdown', ...(reply_to_message_id && { reply_parameters: { message_id: reply_to_message_id } }) });