From ddaf3ee7f2691e7696ab18d6e6b83bdcd997d904 Mon Sep 17 00:00:00 2001 From: Alex Page Date: Thu, 25 Jan 2024 18:28:28 -0500 Subject: [PATCH] Switch from openai to async-openai crate --- Cargo.lock | 119 +++++++++++++++++++++++++++++++++++++++++++++ Cargo.toml | 1 + src/main.rs | 4 ++ src/personality.rs | 86 +++++++++++++++++--------------- 4 files changed, 171 insertions(+), 39 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5e485af..1dc9c91 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -66,6 +66,40 @@ dependencies = [ "serde", ] +[[package]] +name = "async-convert" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d416feee97712e43152cd42874de162b8f9b77295b1c85e5d92725cc8310bae" +dependencies = [ + "async-trait", +] + +[[package]] +name = "async-openai" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b85a7e8b74ef2a2f93f6c360db1778ee86cd62b273407e70f908f477dc93436" +dependencies = [ + "async-convert", + "backoff", + "base64 0.21.7", + "bytes", + "derive_builder", + "futures", + "rand", + "reqwest", + "reqwest-eventsource", + "secrecy", + "serde", + "serde_json", + "thiserror", + "tokio", + "tokio-stream", + "tokio-util", + "tracing", +] + [[package]] name = "async-trait" version = "0.1.77" @@ -103,6 +137,20 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "backoff" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1" +dependencies = [ + "futures-core", + "getrandom", + "instant", + "pin-project-lite", + "rand", + "tokio", +] + [[package]] name = "backtrace" version = "0.3.69" @@ -502,6 +550,7 @@ name = "dj_kitty_cat" version = "0.2.0" dependencies = [ "anyhow", + "async-openai", "openai", "parking_lot", "poise", @@ -550,6 +599,17 @@ dependencies = [ "version_check", ] +[[package]] +name = "eventsource-stream" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab" +dependencies = [ + "futures-core", + "nom", + "pin-project-lite", +] + [[package]] name = "fastrand" version = "2.0.1" @@ -679,6 +739,12 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +[[package]] +name = "futures-timer" +version = "3.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c" + [[package]] name = "futures-util" version = "0.3.30" @@ -929,6 +995,15 @@ dependencies = [ "generic-array", ] +[[package]] +name = "instant" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" +dependencies = [ + "cfg-if", +] + [[package]] name = "ipnet" version = "2.9.0" @@ -1045,6 +1120,12 @@ dependencies = [ "triomphe", ] +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "miniz_oxide" version = "0.7.1" @@ -1104,6 +1185,16 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2bf50223579dc7cdcfb3bfcacf7069ff68243f8c363f62ffa99cf000a6b9c451" +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -1550,6 +1641,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "rustls 0.21.10", + "rustls-native-certs", "rustls-pemfile", "serde", "serde_json", @@ -1569,6 +1661,22 @@ dependencies = [ "winreg", ] +[[package]] +name = "reqwest-eventsource" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f03f570355882dd8d15acc3a313841e6e90eddbc76a93c748fd82cc13ba9f51" +dependencies = [ + "eventsource-stream", + "futures-core", + "futures-timer", + "mime", + "nom", + "pin-project-lite", + "reqwest", + "thiserror", +] + [[package]] name = "ring" version = "0.16.20" @@ -2429,6 +2537,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-stream" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-tungstenite" version = "0.18.0" diff --git a/Cargo.toml b/Cargo.toml index 7a3f32a..76e1b74 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ rand = "0.8.5" reqwest = "0.11.23" songbird = "0.4.0" thiserror = "1.0.39" +async-openai = "0.18.1" [dependencies.symphonia] version = "0.5.2" diff --git a/src/main.rs b/src/main.rs index adfe770..59e3614 100644 --- a/src/main.rs +++ b/src/main.rs @@ -51,6 +51,10 @@ async fn main() -> Result<()> { .with(EnvFilter::from_default_env()) .init(); + env::var("OPENAI_API_KEY") + .expect("Expected an OpenAI API key in the environment: OPENAI_API_KEY"); + + // OLD set_key(env::var("OPENAI_KEY").expect("Expected an OpenAI key in the environment: OPENAI_KEY")); let token = diff --git a/src/personality.rs b/src/personality.rs index 8f2203c..1295326 100644 --- a/src/personality.rs +++ b/src/personality.rs @@ -1,5 +1,9 @@ use anyhow::{Context, Result}; -use openai::chat::{ChatCompletion, ChatCompletionMessage}; +use async_openai::types::{ + ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, + ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestUserMessageArgs, + CreateChatCompletionRequestArgs, +}; use rand::seq::SliceRandom; const LOADING_MESSAGES: [&str; 20] = [ @@ -53,48 +57,52 @@ pub async fn get_sassy_commentary(title: &str) -> Result { let prompt = format!("Play \"{title}\""); - let completion = ChatCompletion::builder( - "gpt-4", - [ - system - .into_iter() - .map(|s| ChatCompletionMessage { - role: openai::chat::ChatCompletionMessageRole::System, - content: String::from(s), - name: None, - }) - .collect::>(), - vec![ - ChatCompletionMessage { - role: openai::chat::ChatCompletionMessageRole::User, - content: String::from(example_prompt), - name: None, - }, - ChatCompletionMessage { - role: openai::chat::ChatCompletionMessageRole::Assistant, - content: String::from(example_response), - name: None, - }, - ChatCompletionMessage { - role: openai::chat::ChatCompletionMessageRole::User, - content: prompt, - name: None, - }, - ], - ] - .into_iter() - .flatten() - .collect::>(), - ) - .max_tokens(2048_u64) - .create() - .await??; + let client = async_openai::Client::new(); - Ok(completion + let request = CreateChatCompletionRequestArgs::default() + .model("gpt-4") + .messages( + [ + system + .into_iter() + .map(|s| { + ChatCompletionRequestSystemMessageArgs::default() + .content(s) + .build() + .unwrap() + .into() + }) + .collect::>(), + vec![ + ChatCompletionRequestUserMessageArgs::default() + .content(example_prompt) + .build()? + .into(), + ChatCompletionRequestAssistantMessageArgs::default() + .content(example_response) + .build()? + .into(), + ChatCompletionRequestUserMessageArgs::default() + .content(prompt) + .build()? + .into(), + ], + ] + .into_iter() + .flatten() + .collect::>(), + ) + .max_tokens(2048_u16) + .build()?; + + let response = client.chat().create(request).await?; + + response .choices .first() .context("No choices")? .message .content - .clone()) + .clone() + .context("No content") }