Source code for isek.embedding.openai_embedding

import os
from typing import List, Optional

from openai import OpenAI

from isek.embedding.abstract_embedding import AbstractEmbedding
from isek.util.logger import logger
from isek.util.tools import split_list


[docs]
class OpenAIEmbedding(AbstractEmbedding):
    """
    An implementation of :class:`~isek.embedding.abstract_embedding.AbstractEmbedding`
    that uses OpenAI's API to generate text embeddings.

    This class connects to the OpenAI API (or a compatible endpoint) to convert
    text data into numerical vector representations using the specified OpenAI
    embedding model.
    """

    def __init__(
        self,
        dim: Optional[int] = None,
        model_name: str = "text-embedding-3-small",
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
    ):
        """
        Initializes the OpenAIEmbedding client.

        :param dim: The expected dimensionality of the embeddings. For models in
            the "text-embedding-3" family, this value is passed to the API's
            ``dimensions`` parameter to reduce the output dimensionality. If
            ``None``, the model's default dimensionality is used. This value also
            initializes ``dim`` on the :class:`AbstractEmbedding` superclass.
        :type dim: typing.Optional[int]
        :param model_name: The name of the OpenAI embedding model to use
            (e.g., "text-embedding-3-small", "text-embedding-ada-002").
            Defaults to "text-embedding-3-small".
        :type model_name: str
        :param api_key: The OpenAI API key. If not provided, the
            ``OPENAI_API_KEY`` environment variable is used.
        :type api_key: typing.Optional[str]
        :param base_url: The base URL for the OpenAI API. Useful for proxying
            requests or for OpenAI-compatible endpoints. If ``None``, the
            default OpenAI API URL is used.
        :type base_url: typing.Optional[str]
        """
        super().__init__(dim)
        self.model_name: str = model_name
        self.client: OpenAI = OpenAI(
            base_url=base_url,
            # Fall back to the OPENAI_API_KEY environment variable if no key
            # is passed explicitly.
            api_key=api_key or os.environ.get("OPENAI_API_KEY"),
        )
        logger.info(
            f"OpenAIEmbedding initialized with model: {self.model_name}, dim: {self.dim}"
        )
[docs]
    def embedding(self, datas: List[str]) -> List[List[float]]:
        """
        Generates embeddings for a list of text strings using the configured
        OpenAI model.

        The input is split into chunks so that each request stays within the
        API's batch limit of 2048 input strings. If ``dim`` was set at
        initialization and the model supports it, the ``dimensions`` parameter
        is included in each request.

        :param datas: A list of text strings to be embedded.
        :type datas: typing.List[str]
        :return: A list of embedding vectors, one per input string. Each inner
            list contains the floats representing the embedding of the
            corresponding input string.
        :rtype: typing.List[typing.List[float]]
        :raises openai.APIError: If the OpenAI API returns an error.
        """
        if not datas:
            return []

        # Base request parameters. Only the "text-embedding-3" family accepts
        # the `dimensions` parameter for reducing output dimensionality.
        embedding_params = {"model": self.model_name}
        if self.dim is not None and "text-embedding-3" in self.model_name:
            embedding_params["dimensions"] = self.dim

        all_embeddings: List[List[float]] = []

        # The embeddings endpoint accepts at most 2048 input strings per
        # request, so batch the inputs accordingly.
        data_chunks: List[List[str]] = split_list(datas, 2048)

        for chunk in data_chunks:
            if not chunk:  # Skip empty chunks.
                continue
            try:
                current_params = dict(embedding_params, input=chunk)
                logger.debug(
                    f"Requesting embeddings for {len(chunk)} texts with model "
                    f"{self.model_name}"
                    f"{f' and dim {self.dim}' if 'dimensions' in current_params else ''}."
                )
                response = self.client.embeddings.create(**current_params)
                # Each item in response.data is an `Embedding` object whose
                # `embedding` attribute is the vector as a list of floats.
                chunk_embeddings: List[List[float]] = [
                    item.embedding for item in response.data
                ]
                all_embeddings.extend(chunk_embeddings)
                logger.debug(
                    f"Received {len(chunk_embeddings)} embeddings for the current chunk."
                )
            except Exception as e:
                # Log and re-raise so that a failed chunk surfaces as a failure
                # instead of silently returning partial results.
                logger.error(f"Error during OpenAI embedding generation for a chunk: {e}")
                raise

        if len(all_embeddings) != len(datas):
            # A mismatch would indicate a batching or API response issue.
            logger.warning(
                f"Mismatch between number of embeddings generated "
                f"({len(all_embeddings)}) and number of input texts ({len(datas)})."
            )
        return all_embeddings
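
A minimal usage sketch (not part of the module source), assuming ``OPENAI_API_KEY`` is set in the environment; the input texts and the 256-dimension request are illustrative only:

# Usage sketch (illustrative, not part of the module source). Assumes
# OPENAI_API_KEY is set and that "text-embedding-3-small" accepts the
# `dimensions` parameter, as handled in embedding() above.
from isek.embedding.openai_embedding import OpenAIEmbedding

embedder = OpenAIEmbedding(dim=256, model_name="text-embedding-3-small")
vectors = embedder.embedding(["hello world", "graph databases"])

assert len(vectors) == 2        # one vector per input string
assert len(vectors[0]) == 256   # requested dimensionality was applied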