Add a new AI provider

Conatus is built to support many AI providers, and it already ships with support for several of them.

You can add support for a new AI provider by extending the common interface for AI models, which is BaseAIModel. Note that "AI provider" does not necessarily mean a company with an API. It can also be a local LLM that you want to use.

This short document will guide you through the main steps of extending this base class.

More information in the API docs

This documentation is meant to be an easy overview of the steps involved in adding a new AI provider. For more detailed information, please refer to the reference documentation of the BaseAIModel class.

Use OpenAIModel as an example

OpenAIModel was the first AI model to be added to Conatus. It is a good example of how to extend the base class. Don't hesitate to look at the source code of OpenAIModel to see how it was implemented.

Configuration

The configuration for an AI model is stored in a ModelConfig instance.

This class is designed to be easy to use in conjunction with dictionaries:

  • the user should be able to pass a dictionary to the constructor of the model, and the model should be able to convert that dictionary to a ModelConfig instance. We handle this in the BaseAIModel class.
  • ModelConfig has a to_kwargs method that returns a dictionary with the configuration. This is useful when you deal with API calls that require specific keyword arguments (more on that below). It also has an apply_config method, which lets users layer configuration values hierarchically.

If you think that your AI provider requires some configuration that is not present in the ModelConfig class, you can extend the class to add your own configuration.

You need to implement the default_config method, so that the constructor of your model can start with a default configuration.
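
For instance, a minimal sketch of an extended configuration might look like this (the organization_id field is purely illustrative; only the ModelConfig subclass and the default_config method come from Conatus):

from dataclasses import dataclass

from conatus.models.config import ModelConfig


@dataclass
class MyProviderConfig(ModelConfig):
    """Illustrative extra configuration for a hypothetical provider."""

    organization_id: str | None = None  # hypothetical provider-specific option


# In your BaseAIModel subclass (see the full example further down):
#
#     def default_config(self) -> MyProviderConfig:
#         return MyProviderConfig()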

Handling missing arguments

Sometimes, you want to distinguish between a missing argument and an argument that is set to None.

Why would you want to do this?

For instance, some AI providers have a timeout argument that enables the user to control how long to wait for a response before giving up.

So, if an AI provider has a default timeout of 60 seconds, you could pass timeout=30 to wait for only half that time. But some AI providers treat the absence of the argument and an explicit None differently: if you don't pass timeout, the request uses the default value (60 seconds), whereas if you pass timeout=None, the request waits indefinitely, with no timeout at all.

To handle this, we use what is called a "not given" sentinel: an object that represents a missing argument. This pattern is widely used in AI providers' SDKs.

To signal that an argument is optional, but can take the value None, we use the OptionalArg type. This is a type that can be either a value or the NotGiven sentinel object.

In the case of the timeout argument, we can define it as follows:

from conatus._types import OptionalArg, CTUS_NOT_GIVEN

timeout: OptionalArg[float] = CTUS_NOT_GIVEN

If your AI provider has an equivalent of a "not given" sentinel, you can use it instead of CTUS_NOT_GIVEN.

from openai import NOT_GIVEN as OPENAI_NOT_GIVEN
from openai import NotGiven as OpenAINotGiven

timeout: float | OpenAINotGiven | None = OPENAI_NOT_GIVEN

Converting the configuration to a dictionary

The ModelConfig class has a to_kwargs method that converts the configuration to a dictionary. This is important, since you will most likely use a function to call the AI model that will accept options as keyword arguments.

To ensure that you only pass valid arguments to the AI provider, you will need to ensure:

  1. That the arguments in the dictionary are allowed in the API call.
  2. That you are only passing arguments with a value, and not the NotGiven sentinel object.

To accomplish task (1), you can provide a specification to the to_kwargs method. This specification is a TypedDict that describes the arguments that are allowed in the API call. You can find the OpenAI specification in the OpenAIModelCCSpec type (for chat completions) and the OpenAIModelResponseSpec type (for the Responses API).

To accomplish task (2), you will need to provide the "not given" sentinel that you're using in your ModelConfig implementation. This will ensure that if an argument has no value, it will not be passed to the API call.
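
As a minimal sketch, mirroring Step 3 of the full example further down (MyAIModelConfig, MyAIModelFuncSpec, and MY_NOT_GIVEN are the dummy classes and sentinel defined in that example):

config = MyAIModelConfig(model_name="default_ai_model")

api_kwargs = config.to_kwargs(
    specification=MyAIModelFuncSpec,           # task (1): only keys the API accepts
    not_given_sentinel=MY_NOT_GIVEN,           # task (2): drop arguments never set
    argument_mapping={"model_name": "model"},  # rename config fields if needed
)
# `api_kwargs` now only contains arguments that actually carry a value.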

Initializing the client

We assume that your AI model will need a client to communicate with the AI provider. The __init__ method can optionally take a client argument. If the user does not provide one, you will need to return a default client from the default_client method.

This method is abstract in the BaseAIModel class, so you will need to implement it in your subclass.

What if I don't need a client?

If you don't need a client, you can just return None from the default_client method.

You should also set the api_key_env_variable attribute, so that the model can retrieve the API key from the environment variables. If you don't set this, and the user does not provide an API key, we will raise an error.

You should override __del__

If you have a client, you should override the __del__ method so that the client is closed when the model is deleted. Otherwise, you might leak connections or run into other nasty resource issues.

Setting default model (or models)

By default, the model will use the model name set in the ModelConfig instance, i.e. the default value of its model_name attribute.

Nevertheless, you might want to set different default models for different model types. The type alias ModelType shows the different model types that you can use: reasoning, execution, computer_use, and chat.

You can override the default model name for a specific model type by implementing the default_model_name method.

Calling the AI provider

You will need to implement wrappers around the AI provider's SDK.

We expect four methods to be implemented:

  1. call, which is a synchronous wrapper around the AI provider's SDK's generation capabilities.
  2. acall, which is the asynchronous version of the call method.
  3. call_stream, which is a synchronous wrapper around the AI provider's SDK's streaming capabilities.
  4. acall_stream, which is the asynchronous version of the call_stream method.

In practice, you might want to implement only the asynchronous versions of the methods, and derive the synchronous versions from the asynchronous ones through asyncio.run.

Note that these functions expect an AIPrompt instance as input, and return an AIResponse instance. While you're free to implement the logic that converts the AIPrompt instance to the AI provider's format however you want, we recommend using the scaffolding defined through the prepare_call_args method (see immediately below).

Converting the prompt to the AI provider's format

The prepare_call_args method is a helper that converts an AIPrompt instance into the AI provider's format. It does so by calling other helper methods (the convert_*_to_ai_model_format methods shown in the full example below).

It's probably best for you to implement these methods in your subclass. You can look at the implementation of OpenAIModel for inspiration.

The prepare_call_args method returns a dataclass named AIModelCallArgs with the following attributes: model_config, system_message, messages, tools, output_schema, and output_schema_was_converted_to_item_object (see MyAIModelCallArgs in the example below for the matching types).

Logging

You will note that the BaseAIModel methods related to calling have two callback arguments: prompt_log_callback and response_log_callback.

These callbacks are useful for debugging purposes. Because we want to be able to inspect exactly what was sent to the AI provider and what was received, and because many errors happen just before or just after the API call, you should invoke these callbacks as close to the API call as possible. Once again, OpenAIModel is a good example to follow.

How does this look in practice? The idea is that another class (say, BaseAIInterface) owns an object that writes logs to the file system. That class defines a method with all the settings properly configured; all you need to do is call that method with the JSON representation of the prompt and of the response.

Printing to the terminal

The AIModelPrintingMixin class is a mixin that handles the printing to the terminal. It is used by the BaseAIModel class to print the prompt and the response. This is mostly useful when streaming responses (see below).

There are two methods that are useful to you:

  • print_before_sending, which will print the first message to the user (e.g. "Waiting for response from OpenAI...").
  • clean_after_receiving, which will be called when the response is complete, and cleans up the terminal.

Putting everything together

Conceptually, putting the pieces together looks like this:

Highly simplified, but working

This is a highly simplified example, but it can be run as is. (We test it in our CI/CD.)

import asyncio
import json
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from typing import Any, Literal, TypedDict

from conatus import AIPrompt
from conatus._types import ParamType, SimpleDict
from conatus.models.base import BaseAIModel
from conatus.models.config import ModelConfig
from conatus.models.inputs_outputs.messages import (
    AssistantAIMessage,
    IncompleteAssistantAIMessage,
    SystemAIMessage,
)
from conatus.models.inputs_outputs.response import (
    AIResponse,
    IncompleteAIResponse,
)
from conatus.models.inputs_outputs.usage import CompletionUsage
from conatus.models.printing import AIModelPrintingMixin
from conatus.models.names import ModelType, ModelName

################################################################################
# LOGGING FUNCTIONS
# This is a dummy implementation; don't worry about it. In practice, these
# callbacks are passed in by another class.
################################################################################


def write_logs(prompt_as_json: str) -> None:
    json_obj = json.loads(prompt_as_json)
    print(json.dumps(json_obj, indent=2))

################################################################################
# SENTINELS
# These are sentinel objects used to indicate that an argument was not given.
# They let us distinguish an argument that was never set from one that is
# explicitly set to `None`.
################################################################################

class MyNotGiven:
    pass

MY_NOT_GIVEN = MyNotGiven()

################################################################################
# AI PROVIDER CALLS
# This is a dummy implementation.
# Note that the provider, here, expects the following arguments:
# - `system_prompt`: The system prompt / instructions
# - `messages`: The messages to send to the AI provider.
# - `model`: The model to use.
# - `thinking_effort`: The thinking effort to use.
# - `max_output_tokens`: The maximum number of tokens to generate.
################################################################################


async def ai_provider_chat_completions_create(
    system_prompt: str,
    messages: list[dict[str, Any]],
    model: str | MyNotGiven = MY_NOT_GIVEN,
    thinking_effort: Literal["low", "medium", "high"] | MyNotGiven = MY_NOT_GIVEN,
    max_output_tokens: int | MyNotGiven = MY_NOT_GIVEN,
    response_format: dict[str, Any] | MyNotGiven = MY_NOT_GIVEN,
) -> dict[str, Any]:
    return {
        "output": [
            {"type": "message", "role": "assistant", "content": "Hi friend!"}
        ],
        "usage": {"prompt_tokens": 10, "output_tokens": 10, "total_tokens": 20},
    }


################################################################################
# OUR CLASSES
# * `MyClient` is the client for the AI provider.
# * `MyAIModelConfig` is the configuration for the AI model.
# * `MyAIModelFuncSpec` is the specification for the
#    `ai_provider_chat_completions_create` function.
# * `MyAIModel` is the AI model that you will use.
# * `MyAIModelCallArgs` is the data structure with the main call arguments
#   for the AI model. The types should match your conversion functions
################################################################################


@dataclass
class MyClient:
    api_key: str

    def close(self):
        pass


@dataclass  # (1)!
class MyAIModelConfig(ModelConfig):  # (2)!
    """Additional configuration for the AI model."""

    model_name: str = "default_ai_model"
    thinking_effort: Literal["low", "medium", "high"] | MyNotGiven = (
        MY_NOT_GIVEN
    )


class MyAIModelFuncSpec(TypedDict, total=False):
    """The arguments expected by the `chat_completions.create` function."""

    model: str | MyNotGiven
    thinking_effort: Literal["low", "medium", "high"] | MyNotGiven
    max_output_tokens: int | MyNotGiven
    response_format: dict[str, Any] | MyNotGiven


@dataclass
class MyAIModelCallArgs:
    model_config: ModelConfig
    system_message: str | None
    messages: Iterable[dict[str, Any]]
    tools: Iterable[dict[str, Any]] | None
    output_schema: dict[str, Any] | None
    output_schema_was_converted_to_item_object: bool = False


class MyAIModel(BaseAIModel):
    api_key_env_variable: str = "MY_AI_MODEL_API_KEY"  # (3)!
    """The environment variable that contains the API key."""

    provider: str = "my_ai_model"
    """The provider of the AI model."""

    #############################################################################
    # CLIENT AND CONFIGURATION HANDLING
    #############################################################################

    def default_client(
        self,
        model_config: ModelConfig,
        api_key: str | None,
        **kwargs: ParamType,
    ) -> MyClient:
        # Note that you can retrieve the API key from the model config.
        return MyClient(api_key=model_config.api_key)

    def default_config(self) -> MyAIModelConfig:
        return MyAIModelConfig()

    def default_model_name(self, model_type: ModelType | None) -> ModelName | None:
        # Let's say we only deviate from the default model name for the
        # computer_use model.
        if model_type == "computer_use":
            return "my_ai_model:default_computer_use_model"
        return None

    def __del__(self):
        # Close the client when the model is deleted.
        if self.client is not None:
            self.client.close()

    #############################################################################
    # DUMMY FUNCTIONS
    # You can fill them if you want, but here's a minimal implementation that
    # should work for most cases.
    #############################################################################

    async def acall_stream(
        self,
        prompt: AIPrompt,
        model_config: ModelConfig | SimpleDict | None = None,
        *,
        printing_mixin_cls: type[AIModelPrintingMixin] = AIModelPrintingMixin,
        prompt_log_callback: Callable[[str], None] | None = None,
        response_log_callback_stream: Callable[[str], None] | None = None,
        **kwargs: ParamType,
    ) -> AIResponse:
        raise NotImplementedError

    def call_stream(
        self,
        prompt: AIPrompt,
        model_config: ModelConfig | SimpleDict | None = None,
        *,
        printing_mixin_cls: type[AIModelPrintingMixin] = AIModelPrintingMixin,
        prompt_log_callback: Callable[[str], None] | None = None,
        response_log_callback_stream: Callable[[str], None] | None = None,
        **kwargs: ParamType,
    ) -> AIResponse:
        return asyncio.run(
            self.acall_stream(
                prompt=prompt,
                model_config=model_config,
                printing_mixin_cls=printing_mixin_cls,
                prompt_log_callback=prompt_log_callback,
                response_log_callback_stream=response_log_callback_stream,
                **kwargs,
            )
        )

    def call(
        self,
        prompt: AIPrompt,
        model_config: ModelConfig | SimpleDict | None = None,
        *,
        printing_mixin_cls: type[AIModelPrintingMixin] = AIModelPrintingMixin,
        prompt_log_callback: Callable[[str], None] | None = None,
        response_log_callback: Callable[[str], None] | None = None,
        **kwargs: ParamType,
    ) -> AIResponse:
        return asyncio.run(
            self.acall(
                prompt=prompt,
                model_config=model_config,
                printing_mixin_cls=printing_mixin_cls,
                prompt_log_callback=prompt_log_callback,
                response_log_callback=response_log_callback,
                **kwargs,
            )
        )

    #############################################################################
    # CONVERSION FUNCTIONS
    # These functions do not need to be called by the `call` method. They are
    # called by the `prepare_call_args` method.
    #############################################################################

    def convert_messages_to_ai_model_format(
        self,
        prompt: AIPrompt,
        config: ModelConfig,
        only_new_messages: bool = False,
    ) -> list[dict[str, Any]]:
        return [
            {"type": "message", "role": "user", "content": message.all_text}
            for message in prompt.messages
        ]

    def convert_tools_to_ai_model_format(
        self, prompt: AIPrompt, config: ModelConfig
    ) -> list[dict[str, Any]] | None:
        return [{"type": "tool", "name": "my_tool", "description": "My tool"}]

    def convert_system_message_to_ai_model_format(
        self,
        system_message: SystemAIMessage,
        config: ModelConfig,
    ) -> str:
        return "You are a helpful assistant."

    def provider_response_to_ai_response(
        self,
        response: dict[str, Any],
        prompt: AIPrompt,
        model_name: str,
        *,
        return_incomplete: bool = False,
        output_schema_was_converted_to_item_object: bool = False,
    ) -> AIResponse | IncompleteAIResponse:
        if return_incomplete:
            return IncompleteAIResponse(
                prompt=prompt,
                message_received=IncompleteAssistantAIMessage(
                    content=response["output"][0]["content"]
                ),
                usage=CompletionUsage(
                    model_name=model_name,
                    prompt_tokens=response["usage"]["prompt_tokens"],
                    completion_tokens=response["usage"]["output_tokens"],
                    total_tokens=response["usage"]["total_tokens"],
                    usage_was_never_given=False,
                ),
            )
        return AIResponse(
            prompt=prompt,
            message_received=AssistantAIMessage.from_text(
                response["output"][0]["content"]
            ),
            usage=CompletionUsage(
                model_name=model_name,
                prompt_tokens=response["usage"]["prompt_tokens"],
                completion_tokens=response["usage"]["output_tokens"],
                total_tokens=response["usage"]["total_tokens"],
                usage_was_never_given=False,
            ),
            output_schema_was_converted_to_item_object=output_schema_was_converted_to_item_object,
        )

    def convert_output_schema_to_ai_model_format(
        self, prompt: AIPrompt, config: ModelConfig
    ) -> tuple[dict[str, Any] | None, bool]:
        return None, False

    #############################################################################
    # CALL METHODS
    #############################################################################

    async def acall(
        self,
        prompt: AIPrompt,
        model_config: ModelConfig | SimpleDict | None = None,
        *,
        printing_mixin_cls: type[AIModelPrintingMixin] = AIModelPrintingMixin,
        prompt_log_callback: Callable[[str], None] | None = None,
        response_log_callback: Callable[[str], None] | None = None,
        **kwargs: ParamType,
    ) -> AIResponse:
        # Step 1: Convert the prompt and tools to the AI provider's format
        # You only need to implement the conversion functions, and the
        # `prepare_call_args` method will take care of the rest.
        args = self.prepare_call_args(
            prompt, model_config
        )

        # Step 2: Configure the printing mixin
        printing_mixin = printing_mixin_cls(config=args.model_config)
        printing_mixin.print_before_sending()

        # Step 3: Convert the model config to the AI provider's format
        kwargs = args.model_config.to_kwargs(  # (4)!
            specification=MyAIModelFuncSpec,
            not_given_sentinel=MY_NOT_GIVEN,
            argument_mapping={
                "max_tokens": "max_output_tokens",
                "model_name": "model",
            },
        )

        # Step 4: Log the prompt
        if prompt_log_callback is not None:
            prompt_json = json.dumps(
                {
                    "system": args.system_message,
                    "messages": args.messages,
                    "tools": args.tools,
                    "kwargs": kwargs,
                    "output_schema": args.output_schema,
                }
            )
            prompt_log_callback(prompt_json)  # (5)!

        # Step 5: Call the AI provider
        response = await ai_provider_chat_completions_create(
            system_prompt=args.system_message,
            messages=args.messages,
            response_format=args.output_schema,
            **kwargs,
        )

        # Step 6: Log the response
        if response_log_callback is not None:
            response_json = json.dumps(response)
            response_log_callback(response_json)  # (6)!

        # Step 7: Clean up the terminal
        printing_mixin.clean_after_receiving()

        # Step 8: Convert the response to an AIResponse instance
        return self.provider_response_to_ai_response(
            response,
            prompt,
            args.model_config.model_name,
            output_schema_was_converted_to_item_object=args.output_schema_was_converted_to_item_object,
        )


################################################################################
# MAIN
################################################################################

prompt = AIPrompt("Hello, world!")
model = MyAIModel(api_key="my_api_key")
ai_response = model.call(
    prompt, prompt_log_callback=write_logs, response_log_callback=write_logs
)
  1. It is very important to use the @dataclass decorator, otherwise the attributes that you define will not be properly initialized.
  2. Note that by subclassing ModelConfig, you inherit its standard attributes (for example model_name, max_tokens, and api_key).
  3. For now, we require all subclasses to define a api_key_env_variable attribute, so that we can retrieve it in case the user does not provide an API key.
  4. This is where you can pass the spec of the AI provider's API, as well as the argument mapping. Here, the mapping means that, because your AI provider expects a max_output_tokens argument while the name max_tokens is already used by the ModelConfig class, we map max_tokens to max_output_tokens.
  5. This should print the following:

    {
      "system": null,
      "messages": [
        {
          "type": "message",
          "role": "user",
          "content": "Hello, world!"
        }
      ],
      "tools": [
        {
          "type": "tool",
          "name": "my_tool",
          "description": "My tool"
        }
      ],
      "kwargs": {
        "model": "default_ai_model"
      },
      "output_schema": null
    }
    
  6. This should print the following:

    {
      "output": [
        {
          "type": "message",
          "role": "assistant",
          "content": "Hi friend!"
        }
      ],
      "usage": {
        "prompt_tokens": 10,
        "output_tokens": 10,
        "total_tokens": 20
      }
    }
    

Extra features

Handling incomplete AI responses / Streaming

If you stream the response from the AI provider, and you want to handle the incomplete response (e.g. by displaying it to the user), you can do so by leveraging the IncompleteAIResponse class.

This class has a few interesting properties:

  1. It can be initialized with the messages sent to the AI provider.
  2. It can be added to another IncompleteAIResponse instance.
  3. When finished, it can be converted to an AIResponse instance through its complete method.

This means that you can do the following:

from conatus import AIPrompt
from conatus.models.inputs_outputs.response import IncompleteAIResponse
from conatus.models.inputs_outputs.messages import IncompleteAssistantAIMessage

# First, you define a function that takes the chunk of the AI response
# that you receive from the AI provider, and converts it to an
# `IncompleteAIResponse` instance.
def chunk_to_incomplete_ai_response(chunk: str) -> IncompleteAIResponse: # (1)!
    return IncompleteAIResponse(
        message_received=IncompleteAssistantAIMessage(content=chunk)
    )

# Second, you initialize the `IncompleteAIResponse` instance with the messages
# sent to the AI provider.

prompt = AIPrompt("Hello, world!")
incomplete_response = IncompleteAIResponse(prompt=prompt)

# You can now do something like this:

# response_stream = my_ai_provider.chat.completions.create(stream=True, ...)
# for chunk in response_stream:
#     incomplete_response += chunk_to_incomplete_ai_response(chunk)

# And finally, you can get the complete AI response through the
# `complete` method.
ai_response = incomplete_response.complete()
  1. Here, the chunk is a string, but it could be a more complex object that contains the number of tokens, the finish reason, etc.

Printing incomplete AI responses

In practice, this is really useful to display messages to the user as they are generated. For this, you will want to use the AIModelPrintingMixin class, which handles it for you.

You don't need to subclass this class at all. There are three methods that will be useful to you:

  • print_before_sending, which will print the first message to the user (e.g. "Waiting for response from OpenAI...").
  • write_preview_response, which will be called when a new chunk of the response is received. Note that this method is called with the IncompleteAIResponse instance, so you can access the current state of the response.
  • clean_after_receiving, which will be called when the response is complete, and cleans up the terminal.
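
Put together, a streaming loop could look roughly like this (the provider stream is an assumption and chunk_to_incomplete_ai_response is the hypothetical helper from the previous snippet; the mixin methods are the ones listed above):

printing_mixin = AIModelPrintingMixin(config=model_config)  # `model_config` is your ModelConfig instance
printing_mixin.print_before_sending()

incomplete_response = IncompleteAIResponse(prompt=prompt)

# response_stream = my_ai_provider.chat.completions.create(stream=True, ...)
# for chunk in response_stream:
#     incomplete_response += chunk_to_incomplete_ai_response(chunk)
#     printing_mixin.write_preview_response(incomplete_response)

printing_mixin.clean_after_receiving()
ai_response = incomplete_response.complete()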

Dealing with stateful APIs

The previous example has assumed that your AI provider requires you to pass every message in the conversation every time you call it. In other words, the AI provider makes it your job to take care of keeping track of the message history.

Some AI providers (most notably OpenAI, through their Responses API, but also Google's Gemini) offer stateful APIs that keep track of the conversation history for you, and you can subclass BaseAIModel to support these capabilities.

If you want to support an AI provider with that capability, you will need to do the following:

  1. In your call methods, you should pass the prompt's previous_messages_id to prepare_call_args. This will set the value of ModelConfig.only_pass_new_messages to True and the value of ModelConfig.previous_messages_id to the identifier you passed.

  2. In your override of convert_messages_to_ai_model_format, make sure you support the only_new_messages parameter, which prepare_call_args will set to True in this case. When it is True, process only the AIPrompt.new_messages attribute, as sketched below.
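
A minimal sketch of such an override, reusing the message format from the full example above (the only_new_messages flag is the one set by prepare_call_args; the dictionary shape is specific to this dummy provider):

# Inside your BaseAIModel subclass:
def convert_messages_to_ai_model_format(
    self,
    prompt: AIPrompt,
    config: ModelConfig,
    only_new_messages: bool = False,
) -> list[dict[str, Any]]:
    # When the provider keeps the history server-side, only send the
    # messages added since the last call.
    messages = prompt.new_messages if only_new_messages else prompt.messages
    return [
        {"type": "message", "role": "user", "content": message.all_text}
        for message in messages
    ]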