Source code for isek.llm.openai_model

"""encoding=utf-8"""

import time
import os
import json
from isek.util.logger import logger # Assuming logger is configured
from isek.llm.abstract_model import AbstractModel
from isek.util.tools import function_to_schema, load_json_from_chat_response # Assuming these utilities exist
from typing import Union, List, Optional, Dict, Callable, Any # Added Any
from openai import OpenAI
from openai.types.chat import ChatCompletion # For type hinting the response of `create`

# Define a type alias for message dictionaries for clarity
ChatMessage = Dict[str, str] # e.g., {"role": "user", "content": "Hello"}
ToolSchema = Dict[str, Any] # e.g., the schema for a tool/function

[docs] class OpenAIModel(AbstractModel): """ An implementation of :class:`~isek.llm.abstract_model.AbstractModel` that uses OpenAI's API for chat completions. This class provides methods to interact with OpenAI's chat models (like GPT-3.5, GPT-4) to generate text, JSON objects, and handle tool/function calling. It handles API key and endpoint configuration, request formatting, and retries. """ def __init__( self, model_name: Optional[str] = None, api_key: Optional[str] = None, base_url: Optional[str] = None ): """ Initializes the OpenAIModel client. Configuration (model name, API key, base URL) is sourced from parameters, falling back to environment variables (`OPENAI_MODEL_NAME`, `OPENAI_API_KEY`, `OPENAI_BASE_URL`) if parameters are not provided. :param model_name: The name of the OpenAI chat model to use (e.g., "gpt-3.5-turbo", "gpt-4"). If `None`, defaults to the value of `OPENAI_MODEL_NAME` environment variable, or remains `None` if the environment variable is also not set (which might lead to errors if not specified before making API calls). :type model_name: typing.Optional[str] :param api_key: The OpenAI API key. If `None`, defaults to `OPENAI_API_KEY` environment variable. :type api_key: typing.Optional[str] :param base_url: The base URL for the OpenAI API. Useful for proxying or using compatible non-OpenAI endpoints. If `None`, defaults to `OPENAI_BASE_URL` environment variable, or the default OpenAI API URL if the environment variable is also not set. :type base_url: typing.Optional[str] """ super().__init__() self.model_name: Optional[str] = model_name or os.environ.get("OPENAI_MODEL_NAME") # Ensure model_name is set, otherwise API calls will fail. if not self.model_name: logger.warning("OpenAIModel initialized without a model_name. API calls may fail. " "Set it via parameter or OPENAI_MODEL_NAME environment variable.") # Consider raising an error here if a model_name is strictly required at init. # For now, allowing it to be None and potentially fail later. _base_url: Optional[str] = base_url or os.environ.get("OPENAI_BASE_URL") _api_key: Optional[str] = api_key or os.environ.get("OPENAI_API_KEY") self.client: OpenAI = OpenAI(base_url=_base_url, api_key=_api_key) logger.info(f"OpenAIModel initialized with model: {self.model_name}, base_url: {_base_url if _base_url else 'default'}")
[docs] def generate_json( self, prompt: str, system_messages: Optional[List[ChatMessage]] = None, retry: int = 3, check_json_def: Optional[Callable[[Dict[str, Any]], None]] = None ) -> Dict[str, Any]: """ Generates a JSON object from the model based on a prompt. It attempts to parse the model's response content as JSON. Includes retry logic and an optional validation function for the parsed JSON. .. note:: To reliably get JSON output, ensure your prompt explicitly instructs the model to generate JSON. Newer OpenAI models support a `response_format` parameter (e.g., `{"type": "json_object"}`) in the `create` method, which could be integrated here for more robust JSON generation. :param prompt: The user prompt instructing the model to generate JSON. :type prompt: str :param system_messages: An optional list of system messages to prepend to the conversation. :type system_messages: typing.Optional[typing.List[ChatMessage]] :param retry: The number of times to retry the API call if it fails or JSON parsing fails. Defaults to 3. :type retry: int :param check_json_def: An optional callable that takes the parsed JSON dictionary as input and should raise an exception if the JSON is invalid. If `None`, no custom validation is performed beyond basic parsing. :type check_json_def: typing.Optional[typing.Callable[[typing.Dict[str, typing.Any]], None]] :return: A dictionary parsed from the model's JSON response. :rtype: typing.Dict[str, typing.Any] :raises RuntimeError: If the API call or JSON processing fails after all retries. :raises ValueError: If JSON parsing fails and `load_json_from_chat_response` also fails. :raises Exception: If `check_json_def` raises an exception. """ if not self.model_name: raise ValueError("OpenAIModel model_name is not set. Cannot make API calls.") for i in range(retry): try: # For robust JSON, consider adding response_format={"type": "json_object"} # to the `create` call if the model supports it. # This would require modifying the `create` method or passing it as a kwarg. response: ChatCompletion = self.create( messages=[{'role': 'user', 'content': prompt}], systems=system_messages # Example for future: response_format={"type": "json_object"} ) response_content = response.choices[0].message.content if response_content is None: raise ValueError("Model response content is None.") json_result: Dict[str, Any] try: json_result = json.loads(response_content) except json.JSONDecodeError: # Fallback to a custom JSON extraction logic if direct parsing fails json_result = load_json_from_chat_response(response_content) # This must return a dict or raise if check_json_def: check_json_def(json_result) # This function should raise if validation fails return json_result except Exception as e: logger.warning(f"Model [{self.model_name}] generate_json attempt {i+1}/{retry} failed: {e}") if i == retry - 1: # Last attempt logger.error(f"generate_json failed after {retry} retries for model [{self.model_name}].") raise RuntimeError(f"Failed to generate valid JSON after {retry} retries.") from e time.sleep(1 * (i + 1)) # Exponential backoff basic # Should not be reached if retry > 0 raise RuntimeError("generate_json failed after all retries (unexpectedly reached end of loop).")
[docs] def generate_text( self, prompt: str, system_messages: Optional[List[ChatMessage]] = None, retry: int = 3 ) -> str: """ Generates plain text from the model based on a prompt. Includes retry logic for API calls. :param prompt: The user prompt for text generation. :type prompt: str :param system_messages: An optional list of system messages to prepend. :type system_messages: typing.Optional[typing.List[ChatMessage]] :param retry: The number of times to retry the API call if it fails. Defaults to 3. :type retry: int :return: The text content generated by the model. :rtype: str :raises RuntimeError: If the API call fails after all retries. """ if not self.model_name: raise ValueError("OpenAIModel model_name is not set. Cannot make API calls.") for i in range(retry): try: response: ChatCompletion = self.create( messages=[{'role': 'user', 'content': prompt}], systems=system_messages ) response_content = response.choices[0].message.content if response_content is None: # This might happen if the model's generation is stopped early or filters trigger. logger.warning(f"Model [{self.model_name}] generate_text attempt {i+1}/{retry} returned None content.") # Depending on requirements, either treat as error or return empty string. # For now, let's try again. If consistently None, the loop will exhaust. if i == retry - 1: raise ValueError("Model consistently returned None content.") raise InterruptedError("Model returned None content, retrying.") # Custom signal to retry return response_content except InterruptedError: # Catch the signal to retry specifically for None content if i < retry -1: time.sleep(1 * (i + 1)) continue # If it's the last retry and still None, it will fall through to the general exception. except Exception as e: logger.warning(f"Model [{self.model_name}] generate_text attempt {i+1}/{retry} failed: {e}") if i == retry - 1: # Last attempt logger.error(f"generate_text failed after {retry} retries for model [{self.model_name}].") raise RuntimeError(f"Failed to generate text after {retry} retries.") from e time.sleep(1 * (i + 1)) # Exponential backoff basic # Should not be reached if retry > 0 raise RuntimeError("generate_text failed after all retries (unexpectedly reached end of loop).")
[docs] def create( self, messages: List[ChatMessage], systems: Optional[List[ChatMessage]] = None, tool_schemas: Optional[List[ToolSchema]] = None, **kwargs: Any # Allow passing other ChatCompletion.create parameters ) -> ChatCompletion: """ Creates a chat completion using the OpenAI API. This is the core method for interacting with the chat model. It can handle system messages, user/assistant messages, and tool/function schemas. :param messages: A list of message objects, where each object has a "role" (e.g., "user", "assistant", "tool") and "content". :type messages: typing.List[ChatMessage] :param systems: An optional list of system message objects. These are typically prepended to the `messages` list. :type systems: typing.Optional[typing.List[ChatMessage]] :param tool_schemas: An optional list of tool schemas that the model can choose to call. :type tool_schemas: typing.Optional[typing.List[ToolSchema]] :param kwargs: Additional keyword arguments to pass directly to the `openai.chat.completions.create` method (e.g., `temperature`, `max_tokens`, `response_format`). :type kwargs: typing.Any :return: The raw ChatCompletion object from the OpenAI API. :rtype: openai.types.chat.ChatCompletion :raises openai.APIError: If the OpenAI API returns an error. :raises ValueError: If `model_name` is not set. """ if not self.model_name: raise ValueError("OpenAIModel model_name is not set. Cannot make API calls.") # Prepend system messages if provided final_messages: List[ChatMessage] = (systems if systems else []) + messages # Filter out None for tool_schemas if it's explicitly passed as None api_tools = tool_schemas if tool_schemas is not None else None # Pass None if empty, not an empty list request_params = { "model": self.model_name, "messages": final_messages, } if api_tools: request_params["tools"] = api_tools # Merge any additional kwargs request_params.update(kwargs) logger.debug(f"Request to model [{self.model_name}]: {json.dumps(request_params, indent=2, default=str)}") start_time = time.time() try: response: ChatCompletion = self.client.chat.completions.create(**request_params) cost_seconds = time.time() - start_time # Be cautious logging the full response if it's very large or contains sensitive data. # Log relevant parts like usage and finish_reason. response_summary = { "id": response.id, "model": response.model, "finish_reason": response.choices[0].finish_reason if response.choices else "N/A", "usage": response.usage.model_dump() if response.usage else "N/A" } logger.debug(f"Response from model [{self.model_name}] received in {cost_seconds:.2f}s. Summary: {json.dumps(response_summary)}") return response except Exception as e: logger.error(f"Request to model [{self.model_name}] failed: {e}", exc_info=True) # exc_info for stack trace raise # Re-raise the original exception to be handled by caller or retry logic