Add a new AI provider ¶
Conatus is built to support many AI providers. For now, we support the following:
- OpenAI, through OpenAIModel
- Google, through GoogleAIModel
- Anthropic, through AnthropicAIModel
You can add support for a new AI provider by extending the common interface for
AI models, which is BaseAIModel. Note that
"AI provider" does not necessarily mean a company with an API. It can also be a
local LLM that you want to use.
This short document will guide you through the main steps of extending this base class.
More information in the API docs
This documentation is meant to be an easy overview of the steps involved in
adding a new AI provider. For more detailed information, please refer to the
reference documentation of the
BaseAIModel class.
Use OpenAIModel as an example
OpenAIModel was the first AI
model to be added to Conatus. It is a good example of how to extend the
base class. Don't hesitate to look at the source code of
OpenAIModel to see how it was
implemented.
Configuration ¶
The configuration for an AI model is stored in a ModelConfig
instance.
This class is designed to be easy to use in conjunction with dictionaries:
- The user should be able to pass a dictionary to the constructor of the model, and the model should be able to convert that dictionary to a ModelConfig instance. We handle this in the BaseAIModel class.
- ModelConfig has a to_kwargs method that returns a dictionary with the configuration. This is useful when you deal with API calls that require specific arguments. (More on that below.) It also has an apply_config method, which enables users to take a hierarchical approach to configuration values.
If you think that your AI provider requires some configuration that is not
present in the ModelConfig class, you can extend the class to add your own
configuration.
You need to implement the default_config
method, so that the
constructor of your model can start with a default configuration.
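For example, a minimal sketch of such an extension could look like this (extra_headers is a made-up, provider-specific option used purely for illustration; note the @dataclass decorator, which is also discussed in the full example below):

from dataclasses import dataclass, field

from conatus.models.config import ModelConfig


@dataclass
class MyProviderConfig(ModelConfig):
    """ModelConfig extended with a hypothetical provider-specific option."""

    # Invented option, purely for illustration.
    extra_headers: dict[str, str] = field(default_factory=dict)


# Your model's `default_config` method would then simply return `MyProviderConfig()`.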
Handling missing arguments ¶
Sometimes, you want to distinguish between a missing argument and an argument
that is set to None.
Why would you want to do this?
For instance, some AI providers have a timeout argument that enables
the user to control how long to wait for a response before giving up.
So, if you have an AI provider that has a default timeout of 60 seconds,
you could pass timeout=30 to wait for only half that time.
But some AI providers interpret different timeout values specially.
In this case, if you don't pass timeout, the request would use the default
value (60 seconds). But if you passed timeout=None, the request would
wait indefinitely with no timeout.
To handle this, we use what is called a "not given sentinel". This is an object that is used to represent a missing argument. This is widely used in AI provider SDKs.
To signal that an argument is optional, but can take the value None, we use
the OptionalArg type. This is a type that can be
either a value or the NotGiven sentinel object.
In the case of the timeout argument, we can define it as follows:
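Here is a minimal sketch; the import paths for OptionalArg and CTUS_NOT_GIVEN, and the exact spelling of the annotation, are assumptions, so check the reference documentation:

from dataclasses import dataclass

# Assumed import paths; see the reference documentation for the actual ones.
from conatus._types import OptionalArg
from conatus.models.config import CTUS_NOT_GIVEN, ModelConfig


@dataclass
class ConfigWithTimeout(ModelConfig):
    # `timeout` may be a float, None (wait forever), or simply not given,
    # in which case it is never forwarded to the provider.
    timeout: OptionalArg[float | None] = CTUS_NOT_GIVEN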
If your AI provider has an equivalent of a "not given" sentinel, you can use it
instead of CTUS_NOT_GIVEN.
from openai import NOT_GIVEN as OPENAI_NOT_GIVEN
from openai import NotGiven as OpenAINotGiven

timeout: float | OpenAINotGiven | None = OPENAI_NOT_GIVEN
Converting the configuration to a dictionary ¶
The ModelConfig class has a
to_kwargs method that converts the
configuration to a dictionary. This is important, since you will most likely call
the AI provider through a function that accepts options as keyword arguments.
In order to ensure that you're only passing the required arguments to the AI provider, you will need to ensure:
1. That the arguments in the dictionary are allowed in the API call.
2. That you are only passing arguments with a value, and not the NotGiven sentinel object.
To accomplish task (1), you can provide a specification to the to_kwargs
method. This specification is a TypedDict that describes
the arguments that are allowed in the API call. You can find the OpenAI
specification in the OpenAIModelCCSpec
type (for chat completions) and the
OpenAIModelResponseSpec type
(for the Responses API).
To accomplish task (2), you will need to provide the "not given" sentinel that
you're using in your ModelConfig
implementation. This will ensure that if an argument has no value, it will not
be passed to the API call.
Initializing the client ¶
We assume that your AI model will need a client to communicate with the AI
provider. The __init__ method can optionally take a client argument. If the
user does not provide one, you will need to return a default client from the
default_client method.
This method is abstract in the BaseAIModel
class, so you will need to implement it in your subclass.
What if I don't need a client?
If you don't need a client, you can just return None from the
default_client method.
You should also set the api_key_env_variable
attribute, so that the
model can retrieve the API key from the environment variables. If you don't set
this, and the user does not provide an API key, we will raise an error.
You should override __del__
If you have a client, you should override the __del__ method to close the
client when the model is deleted. Otherwise, you might run into a variety of
nasty memory issues.
Setting default model (or models) ¶
By default, the model will use the model name set in the ModelConfig
instance. This means that the default value
for the model_name attribute
will be used.
Nevertheless, you might want to set different default models for different model
types. The type alias ModelType shows the
different model types that you can use: reasoning, execution,
computer_use, and chat.
You can override the default model name for a specific model type by
implementing the default_model_name
method.
Calling the AI provider ¶
You will need to implement wrappers around the AI provider's SDK.
We expect four methods to be implemented:
- call, which is a synchronous wrapper around the AI provider's SDK's generation capabilities.
- acall, which is the asynchronous version of the call method.
- call_stream, which is a synchronous wrapper around the AI provider's SDK's streaming capabilities.
- acall_stream, which is the asynchronous version of the call_stream method.
In practice, you might want to implement only the asynchronous versions of the
methods, and derive the synchronous versions from the asynchronous ones through
asyncio.run.
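If you go that route, the pattern is just a thin wrapper around asyncio.run, as in this self-contained sketch (the class and method bodies here are illustrative, not part of Conatus):

import asyncio


class SyncFromAsyncSketch:
    """Illustrative pattern only; in your model these are `acall` and `call`."""

    async def acall(self, prompt: str) -> str:
        # Pretend this awaits the provider's SDK.
        await asyncio.sleep(0)
        return f"echo: {prompt}"

    def call(self, prompt: str) -> str:
        # The synchronous version simply drives the asynchronous one.
        return asyncio.run(self.acall(prompt))


print(SyncFromAsyncSketch().call("hello"))  # echo: hello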
Note that these functions expect an
AIPrompt instance as input,
and return an AIResponse
instance. While you're free to implement the logic to convert the
AIPrompt instance to the AI
provider's format however you want, we recommend that you use the scaffolding
defined through the
prepare_call_args method
(see immediately below).
Converting the prompt to the AI provider's format ¶
The prepare_call_args
method is a helper method that converts an
AIPrompt instance to the AI
provider's format. It does so by calling four other helper methods:
- convert_messages_to_ai_model_format, which converts the messages in the AIPrompt instance to the AI provider's format.
- convert_tools_to_ai_model_format, which converts the tools in the AIPrompt instance to the AI provider's format.
- convert_system_message_to_ai_model_format, which converts a SystemAIMessage to the AI provider's format.
- convert_output_schema_to_ai_model_format, which converts the output schema in the AIPrompt instance to the AI provider's format. Note that this method returns a tuple with the output schema and a boolean indicating whether the schema has been converted from a non-object type to an object type. This is necessary because the AI provider's SDK may require an object type, but the output schema is a string or a list. In this case, the boolean should be True, because it will indicate further modifications. If there's no output schema, return None and False.
It's probably best for you to implement these methods in your subclass. You can
look at the implementation of
OpenAIModel for inspiration.
The prepare_call_args method
returns a dataclass named
AIModelCallArgs with the following
attributes:
- model_config: The model configuration.
- system_message: The system message.
- messages: The messages.
- tools: The tools.
- output_schema: The output schema.
- output_schema_was_converted_to_item_object: Whether the output schema was converted from a non-object type to an object type.
Logging ¶
You will note that the BaseAIModel's
methods related to calling have two callback arguments: prompt_log_callback
and response_log_callback.
These callbacks are useful for debugging purposes. Because we want to be able to
inspect what was actually sent to the AI provider and what was received, and
because a lot of errors happen just before or just after the API call, you
should implement these callbacks in your code as close to the API call as
possible. Once again, OpenAIModel is a
good example to follow.
How does it look in practice? The idea is that another class (say, BaseAIInterface
) has an object that writes
logs to a file system. That interface defines a method with all the settings
properly configured. All you need to do is call that method with the
JSON representation of the prompt and the response.
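For example, the wiring on the caller's side could look like the sketch below (FileLogger and its file layout are purely hypothetical stand-ins for that interface class):

import json
from pathlib import Path


class FileLogger:
    """Hypothetical stand-in for the class that owns the log destination."""

    def __init__(self, directory: Path) -> None:
        self.directory = directory
        self.directory.mkdir(parents=True, exist_ok=True)
        self.counter = 0

    def write_log(self, payload_as_json: str) -> None:
        # The model only calls this with the JSON payload; the destination and
        # naming scheme are decided here, once, with the proper settings.
        path = self.directory / f"exchange_{self.counter}.json"
        path.write_text(json.dumps(json.loads(payload_as_json), indent=2))
        self.counter += 1


# logger = FileLogger(Path("logs"))
# model.call(prompt, prompt_log_callback=logger.write_log,
#            response_log_callback=logger.write_log)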
Printing to the terminal ¶
The AIModelPrintingMixin class
is a mixin that handles the printing to the terminal. It is used by the
BaseAIModel class to print the prompt and
the response. This is mostly useful when streaming responses (see
below).
There are two methods that are useful to you:
- print_before_sending, which will print the first message to the user (e.g. "Waiting for response from OpenAI...").
- clean_after_receiving, which will be called when the response is complete, and cleans up the terminal.
Putting everything together ¶
Conceptually, putting the pieces together looks like this:
Highly simplified, but working
This is a highly simplified example, but it can be run as is. (We test it in our CI/CD.)
import asyncio
import json
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from typing import Any, Literal, TypedDict
from conatus import AIPrompt
from conatus._types import ParamType, SimpleDict
from conatus.models.base import BaseAIModel
from conatus.models.config import ModelConfig
from conatus.models.inputs_outputs.messages import (
AssistantAIMessage,
IncompleteAssistantAIMessage,
SystemAIMessage,
)
from conatus.models.inputs_outputs.response import (
AIResponse,
IncompleteAIResponse,
)
from conatus.models.inputs_outputs.usage import CompletionUsage
from conatus.models.printing import AIModelPrintingMixin
from conatus.models.names import ModelType, ModelName
################################################################################
# LOGGING FUNCTIONS
# This is a dummy implementation; don't worry about it. If log callbacks are
# passed at all, it will be by another class.
################################################################################
def write_logs(prompt_as_json: str) -> None:
json_obj = json.loads(prompt_as_json)
print(json.dumps(json_obj, indent=2))
################################################################################
# SENTINELS
# These are sentinel objects that are used to indicate that an argument is not
# given. They are used to avoid having to pass `None` to the AI provider.
################################################################################
class MyNotGiven:
pass
MY_NOT_GIVEN = MyNotGiven()
################################################################################
# AI PROVIDER CALLS
# This is a dummy implementation.
# Note that the provider, here, expects the following arguments:
# - `system_prompt`: The system prompt / instructions
# - `messages`: The messages to send to the AI provider.
# - `model`: The model to use.
# - `thinking_effort`: The thinking effort to use.
# - `max_output_tokens`: The maximum number of tokens to generate.
# - `response_format`: The response format / output schema, if any.
################################################################################
async def ai_provider_chat_completions_create(
system_prompt: str,
messages: list[dict[str, Any]],
model: str | MyNotGiven = MY_NOT_GIVEN,
thinking_effort: Literal["low", "medium", "high"] | MyNotGiven = MY_NOT_GIVEN,
max_output_tokens: int | MyNotGiven = MY_NOT_GIVEN,
response_format: dict[str, Any] | MyNotGiven = MY_NOT_GIVEN,
) -> dict[str, Any]:
return {
"output": [
{"type": "message", "role": "assistant", "content": "Hi friend!"}
],
"usage": {"prompt_tokens": 10, "output_tokens": 10, "total_tokens": 20},
}
################################################################################
# OUR CLASSES
# * `MyClient` is the client for the AI provider's client.
# * `MyAIModelConfig` is the configuration for the AI model.
# * `MyAIModelFuncSpec` is the specification for the
# `ai_provider_chat_completions_create` function.
# * `MyAIModel` is the AI model that you will use.
# * `MyAIModelCallArgs` is the data structure with the main call arguments
# for the AI model. The types should match your conversion functions
################################################################################
@dataclass
class MyClient:
api_key: str
def close(self):
pass
@dataclass # (1)!
class MyAIModelConfig(ModelConfig): # (2)!
"""Additional configuration for the AI model."""
model_name: str = "default_ai_model"
thinking_effort: Literal["low", "medium", "high"] | MyNotGiven = (
MY_NOT_GIVEN
)
class MyAIModelFuncSpec(TypedDict, total=False):
"""The arguments expected by the `chat_completions.create` function."""
model: str | MyNotGiven
thinking_effort: Literal["low", "medium", "high"] | MyNotGiven
max_output_tokens: int | MyNotGiven
response_format: dict[str, Any] | MyNotGiven
@dataclass
class MyAIModelCallArgs:
model_config: ModelConfig
system_message: str | None
messages: Iterable[dict[str, Any]]
tools: Iterable[dict[str, Any]] | None
output_schema: dict[str, Any] | None
output_schema_was_converted_to_item_object: bool = False
class MyAIModel(BaseAIModel):
api_key_env_variable: str = "MY_AI_MODEL_API_KEY" # (3)!
"""The environment variable that contains the API key."""
provider: str = "my_ai_model"
"""The provider of the AI model."""
#############################################################################
# CLIENT AND CONFIGURATION HANDLING
#############################################################################
def default_client(
self,
model_config: ModelConfig,
api_key: str | None,
**kwargs: ParamType,
) -> MyClient:
# Note that you can retrieve the API key from the model config.
return MyClient(api_key=model_config.api_key)
def default_config(self) -> MyAIModelConfig:
return MyAIModelConfig()
def default_model_name(self, model_type: ModelType | None) -> ModelName | None:
# Let's say we only deviate from the default model name for the
# computer_use model.
if model_type == "computer_use":
return "my_ai_model:default_computer_use_model"
return None
def __del__(self):
# Close the client when the model is deleted.
if self.client is not None:
self.client.close()
#############################################################################
# DUMMY FUNCTIONS
# You can fill them if you want, but here's a minimal implementation that
# should work for most cases.
#############################################################################
async def acall_stream(
self,
prompt: AIPrompt,
model_config: ModelConfig | SimpleDict | None = None,
*,
printing_mixin_cls: type[AIModelPrintingMixin] = AIModelPrintingMixin,
prompt_log_callback: Callable[[str], None] | None = None,
response_log_callback_stream: Callable[[str], None] | None = None,
**kwargs: ParamType,
) -> AIResponse:
raise NotImplementedError
def call_stream(
self,
prompt: AIPrompt,
model_config: ModelConfig | SimpleDict | None = None,
*,
printing_mixin_cls: type[AIModelPrintingMixin] = AIModelPrintingMixin,
prompt_log_callback: Callable[[str], None] | None = None,
response_log_callback_stream: Callable[[str], None] | None = None,
**kwargs: ParamType,
) -> AIResponse:
return asyncio.run(
self.acall_stream(
prompt=prompt,
model_config=model_config,
printing_mixin_cls=printing_mixin_cls,
prompt_log_callback=prompt_log_callback,
response_log_callback_stream=response_log_callback_stream,
**kwargs,
)
)
def call(
self,
prompt: AIPrompt,
model_config: ModelConfig | SimpleDict | None = None,
*,
printing_mixin_cls: type[AIModelPrintingMixin] = AIModelPrintingMixin,
prompt_log_callback: Callable[[str], None] | None = None,
response_log_callback: Callable[[str], None] | None = None,
**kwargs: ParamType,
) -> AIResponse:
return asyncio.run(
self.acall(
prompt=prompt,
model_config=model_config,
printing_mixin_cls=printing_mixin_cls,
prompt_log_callback=prompt_log_callback,
response_log_callback=response_log_callback,
**kwargs,
)
)
#############################################################################
# CONVERSION FUNCTIONS
# These functions do not need to be called by the `call` method. They are
# called by the `prepare_call_args` method.
#############################################################################
def convert_messages_to_ai_model_format(
self,
prompt: AIPrompt,
config: ModelConfig,
only_new_messages: bool = False,
) -> list[dict[str, Any]]:
return [
{"type": "message", "role": "user", "content": message.all_text}
for message in prompt.messages
]
def convert_tools_to_ai_model_format(
self, prompt: AIPrompt, config: ModelConfig
) -> list[dict[str, Any]] | None:
return [{"type": "tool", "name": "my_tool", "description": "My tool"}]
def convert_system_message_to_ai_model_format(
self,
system_message: SystemAIMessage,
config: ModelConfig,
) -> str:
return "You are a helpful assistant."
def provider_response_to_ai_response(
self,
response: dict[str, Any],
prompt: AIPrompt,
model_name: str,
*,
return_incomplete: bool = False,
output_schema_was_converted_to_item_object: bool = False,
) -> AIResponse | IncompleteAIResponse:
if return_incomplete:
return IncompleteAIResponse(
prompt=prompt,
message_received=IncompleteAssistantAIMessage(
content=response["output"][0]["content"]
),
usage=CompletionUsage(
model_name=model_name,
prompt_tokens=response["usage"]["prompt_tokens"],
completion_tokens=response["usage"]["output_tokens"],
total_tokens=response["usage"]["total_tokens"],
usage_was_never_given=False,
),
)
return AIResponse(
prompt=prompt,
message_received=AssistantAIMessage.from_text(
response["output"][0]["content"]
),
usage=CompletionUsage(
model_name=model_name,
prompt_tokens=response["usage"]["prompt_tokens"],
completion_tokens=response["usage"]["output_tokens"],
total_tokens=response["usage"]["total_tokens"],
usage_was_never_given=False,
),
output_schema_was_converted_to_item_object=output_schema_was_converted_to_item_object,
)
def convert_output_schema_to_ai_model_format(
self, prompt: AIPrompt, config: ModelConfig
) -> tuple[dict[str, Any] | None, bool]:
return None, False
#############################################################################
# CALL METHODS
#############################################################################
async def acall(
self,
prompt: AIPrompt,
model_config: ModelConfig | SimpleDict | None = None,
*,
printing_mixin_cls: type[AIModelPrintingMixin] = AIModelPrintingMixin,
prompt_log_callback: Callable[[str], None] | None = None,
response_log_callback: Callable[[str], None] | None = None,
**kwargs: ParamType,
) -> AIResponse:
# Step 1: Convert the prompt and tools to the AI provider's format
# You only need to implement the conversion functions, and the
# `prepare_call_args` method will take care of the rest.
args = self.prepare_call_args(
prompt, model_config
)
# Step 2: Configure the printing mixin
printing_mixin = printing_mixin_cls(config=args.model_config)
printing_mixin.print_before_sending()
# Step 3: Convert the model config to the AI provider's format
kwargs = args.model_config.to_kwargs( # (4)!
specification=MyAIModelFuncSpec,
not_given_sentinel=MY_NOT_GIVEN,
argument_mapping={
"max_tokens": "max_output_tokens",
"model_name": "model",
},
)
# Step 4: Log the prompt
if prompt_log_callback is not None:
prompt_json = json.dumps(
{
"system": args.system_message,
"messages": args.messages,
"tools": args.tools,
"kwargs": kwargs,
"output_schema": args.output_schema,
}
)
prompt_log_callback(prompt_json) # (5)!
# Step 5: Call the AI provider
response = await ai_provider_chat_completions_create(
system_prompt=args.system_message,
messages=args.messages,
response_format=args.output_schema,
**kwargs,
)
# Step 6: Log the response
if response_log_callback is not None:
response_json = json.dumps(response)
response_log_callback(response_json) # (6)!
# Step 7: Clean up the terminal
printing_mixin.clean_after_receiving()
# Step 8: Convert the response to an AIResponse instance
return self.provider_response_to_ai_response(
response,
prompt,
args.model_config.model_name,
output_schema_was_converted_to_item_object=args.output_schema_was_converted_to_item_object,
)
################################################################################
# MAIN
################################################################################
prompt = AIPrompt("Hello, world!")
model = MyAIModel(api_key="my_api_key")
ai_response = model.call(
prompt, prompt_log_callback=write_logs, response_log_callback=write_logs
)
1. It is very important to use the @dataclass decorator, otherwise the attributes that you define will not be properly initialized.
2. Note that by subclassing ModelConfig, you get the following attributes: api_key (the API key for the AI provider), model_name (the name of the model to use), max_tokens (the maximum number of tokens to generate), temperature (the temperature for the model), computer_use_mode (whether to use the computer use mode), use_mock (whether to use a mock client), and stdout_mode (the mode to use when using the AIModelPrintingMixin).
3. For now, we require all subclasses to define an api_key_env_variable attribute, so that we can retrieve the API key in case the user does not provide one.
4. This is where you can pass the spec of the AI provider's API, as well as the argument mapping. Here, this argument mapping means that, because your AI provider expects a max_output_tokens argument, and the max_tokens argument is already taken by the ModelConfig class, we map the max_tokens argument to the max_output_tokens argument.
5. This should print the JSON representation of the prompt (pretty-printed by write_logs in this example).
6. This should print the JSON representation of the response.
Extra features ¶
Handling incomplete AI responses / Streaming ¶
If you stream the response from the AI provider, and you want to handle the
incomplete response (e.g. by displaying it to the user), you can do so by
leveraging the IncompleteAIResponse
class.
This class has a few interesting properties:
- It can be initialized with the messages sent to the AI provider.
- It can be added to another IncompleteAIResponse instance.
- When finished, it can be converted to an AIResponse instance through its complete method.
This means that you can do the following:
from conatus import AIPrompt
from conatus.models.inputs_outputs.response import IncompleteAIResponse
from conatus.models.inputs_outputs.messages import IncompleteAssistantAIMessage
# First, you define a function that takes the chunk of the AI response
# that you receive from the AI provider, and converts it to an
# `IncompleteAIResponse` instance.
def chunk_to_incomplete_ai_response(chunk: str) -> IncompleteAIResponse: # (1)!
return IncompleteAIResponse(
message_received=IncompleteAssistantAIMessage(content=chunk)
)
# Second, you initialize the `IncompleteAIResponse` instance with the messages
# sent to the AI provider.
prompt = AIPrompt("Hello, world!")
incomplete_response = IncompleteAIResponse(prompt=prompt)
# You can now do something like this:
# response_stream = my_ai_provider.chat.completions.create(stream=True, ...)
# for chunk in response_stream:
# incomplete_response += chunk_to_incomplete_ai_response(chunk)
# And finally, you can get the complete AI response through the
# `complete` method.
ai_response = incomplete_response.complete()
- Here, the chunk is a string, but it could be a more complex object that contains the number of tokens, the finish reason, etc.
Printing incomplete AI responses ¶
In practice, this is really useful to display messages to the user as they are
generated. For this, you will want to use the
AIModelPrintingMixin class, which handles it for
you.
You don't need to subclass this class at all. There are three methods that will be useful to you:
- print_before_sending, which will print the first message to the user (e.g. "Waiting for response from OpenAI...").
- write_preview_response, which will be called when a new chunk of the response is received. Note that this method is called with the IncompleteAIResponse instance, so you can access the current state of the response.
- clean_after_receiving, which will be called when the response is complete, and cleans up the terminal (the three are put together in the sketch below).
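Putting the printing mixin and IncompleteAIResponse together, the streaming loop inside acall_stream can look roughly like this sketch. The provider's streaming call is replaced by a placeholder that yields plain string chunks; that a bare ModelConfig() is constructible, and the exact signature of write_preview_response, are assumptions here.

from conatus import AIPrompt
from conatus.models.config import ModelConfig
from conatus.models.inputs_outputs.messages import IncompleteAssistantAIMessage
from conatus.models.inputs_outputs.response import IncompleteAIResponse
from conatus.models.printing import AIModelPrintingMixin


def fake_provider_stream() -> list[str]:
    """Placeholder for the provider SDK's streaming response."""
    return ["Hello", ", ", "world", "!"]


prompt = AIPrompt("Hello, world!")
printing_mixin = AIModelPrintingMixin(config=ModelConfig())
printing_mixin.print_before_sending()

incomplete_response = IncompleteAIResponse(prompt=prompt)
for chunk in fake_provider_stream():
    incomplete_response += IncompleteAIResponse(
        message_received=IncompleteAssistantAIMessage(content=chunk)
    )
    # Show the partial response to the user as it grows.
    printing_mixin.write_preview_response(incomplete_response)

printing_mixin.clean_after_receiving()
ai_response = incomplete_response.complete()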
Dealing with stateful APIs ¶
The previous example assumed that your AI provider requires you to pass every message in the conversation every time you call it. In other words, the AI provider makes it your job to keep track of the message history.
Some AI providers (most importantly, OpenAI through its Responses API, but
also Google's Gemini) offer stateful APIs that keep track of the conversation
history for you, and you can subclass BaseAIModel
to support these capabilities.
If you want to support an AI provider with that capability, you will need to do the following:
- In your call methods, you should pass the prompt's previous_messages_id to prepare_call_args. This will set the value of ModelConfig.only_pass_new_messages to True and the value of ModelConfig.previous_messages_id to the identifier you passed.
- In your override of convert_messages_to_ai_model_format, you should ensure that you support the only_new_messages parameter. That parameter will be set to True by prepare_call_args. In this case, you should ensure that you process only the AIPrompt.new_messages attribute (see the sketch below).
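Here is a sketch of those two pieces. Method bodies are simplified, dict-shaped messages stand in for your provider's format, and the keyword name passed to prepare_call_args is an assumption, so check the reference documentation:

from typing import Any

from conatus import AIPrompt
from conatus.models.config import ModelConfig
from conatus.models.inputs_outputs.response import AIResponse


class StatefulCallSketch:
    """Illustrative fragment; these methods belong on your BaseAIModel subclass."""

    def convert_messages_to_ai_model_format(
        self,
        prompt: AIPrompt,
        config: ModelConfig,
        only_new_messages: bool = False,
    ) -> list[dict[str, Any]]:
        # With a stateful API, the provider already has the earlier messages,
        # so we only serialize the ones it has not seen yet.
        messages = prompt.new_messages if only_new_messages else prompt.messages
        return [
            {"type": "message", "role": "user", "content": message.all_text}
            for message in messages
        ]

    async def acall(self, prompt: AIPrompt, model_config=None, **kwargs) -> AIResponse:
        # Forwarding the previous exchange's identifier lets `prepare_call_args`
        # set `only_pass_new_messages` and `previous_messages_id` for us.
        # (The keyword name is an assumption; check the reference documentation.)
        args = self.prepare_call_args(
            prompt,
            model_config,
            previous_messages_id=prompt.previous_messages_id,
        )
        ...  # the rest is identical to the non-stateful `acall` shown earlier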