"""Replicate Provider Module.

This module implements the Replicate language model provider for the Haive framework,
supporting a wide variety of open-source models hosted on Replicate's platform.

The provider handles API key management, model configuration, and safe imports of
the langchain-community package dependencies for Replicate integration.

Examples:
    Basic usage:

    .. code-block:: python

        from haive.core.models.llm.providers.replicate import ReplicateProvider

        provider = ReplicateProvider(
            model="meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3",
            temperature=0.7,
            max_tokens=1000
        )
        llm = provider.instantiate()

    With custom parameters::

        provider = ReplicateProvider(
            model="mistralai/mixtral-8x7b-instruct-v0.1",
            temperature=0.1,
            top_p=0.9,
            top_k=50
        )

.. autosummary::
   :toctree: generated/

   ReplicateProvider
"""
from typing import Any
from pydantic import Field, field_validator
from haive.core.models.llm.provider_types import LLMProvider
from haive.core.models.llm.providers.base import BaseLLMProvider, ProviderImportError
class ReplicateProvider(BaseLLMProvider):
    """Replicate language model provider configuration.

    This provider supports a wide variety of open-source models hosted on Replicate,
    including Llama, Mixtral, CodeLlama, and many others with flexible versioning.

    Sampling fields left as ``None`` are omitted from the generated
    ``model_kwargs`` rather than being sent as explicit nulls.

    Attributes:
        provider (LLMProvider): Always LLMProvider.REPLICATE
        model (str): The Replicate model to use (owner/name:version format)
        temperature (float): Sampling temperature (0.0-5.0)
        max_tokens (int): Maximum tokens in response
        top_p (float): Nucleus sampling parameter
        top_k (int): Top-k sampling parameter
        repetition_penalty (float): Repetition penalty parameter
        stop_sequences (list): Stop sequences for generation

    Examples:
        Llama 2 70B model:

        .. code-block:: python

            provider = ReplicateProvider(
                model="meta/llama-2-70b-chat",
                temperature=0.7,
                max_tokens=2000
            )

        Mixtral with specific version::

            provider = ReplicateProvider(
                model="mistralai/mixtral-8x7b-instruct-v0.1:7b3212fbaf88310047672c7764d9f2cce7493d0d80666d899b72af8c0662df7a",
                temperature=0.1,
                top_p=0.9
            )
    """

    provider: LLMProvider = Field(
        default=LLMProvider.REPLICATE, description="Provider identifier"
    )

    # Replicate model parameters. Each one is optional; ``None`` means
    # "do not send this parameter".
    temperature: float | None = Field(
        default=None,
        ge=0,
        le=5,
        description="Sampling temperature (0.0-5.0 for Replicate)",
    )
    max_tokens: int | None = Field(
        default=None, ge=1, description="Maximum tokens in response"
    )
    top_p: float | None = Field(
        default=None, ge=0, le=1, description="Nucleus sampling parameter"
    )
    top_k: int | None = Field(
        default=None, ge=1, description="Top-k sampling parameter"
    )
    repetition_penalty: float | None = Field(
        default=None, ge=0, le=2, description="Repetition penalty parameter"
    )
    stop_sequences: list[str] | None = Field(
        default=None, description="Stop sequences for generation"
    )

    def _get_chat_class(self) -> type[Any]:
        """Import and return the Replicate chat model class.

        The import is deferred to call time so that this module can be
        imported without ``langchain-community`` being installed.

        Returns:
            The ``ChatReplicate`` class from ``langchain_community.chat_models``.

        Raises:
            ProviderImportError: If the import fails; the original
                ``ImportError`` is chained as the cause.
        """
        # NOTE(review): confirm that langchain_community.chat_models really
        # exports ``ChatReplicate``. The parameters built in
        # ``_get_initialization_params`` (model, model_kwargs,
        # replicate_api_token) look like those of
        # ``langchain_community.llms.Replicate`` - TODO verify.
        try:
            from langchain_community.chat_models import ChatReplicate
        except ImportError as e:
            raise ProviderImportError(
                provider=self.provider.value,
                package=self._get_import_package(),
                message=(
                    "Replicate requires langchain-community. "
                    "Install with: pip install langchain-community replicate"
                ),
            ) from e
        return ChatReplicate

    def _get_default_model(self) -> str:
        """Return the model identifier used as the Replicate default."""
        return "meta/llama-2-70b-chat"

    def _get_import_package(self) -> str:
        """Return the name of the package that provides the integration."""
        return "langchain-community"

    def _get_initialization_params(self, **kwargs: Any) -> dict[str, Any]:
        """Build the keyword arguments used to construct the chat class.

        Sampling settings are nested under ``model_kwargs``. A
        ``model_kwargs`` dict passed by the caller is merged with the
        configured fields instead of being overwritten; on a key conflict
        the configured field wins.

        Args:
            **kwargs: Extra constructor arguments. They are applied after
                ``model`` and can therefore override it.

        Returns:
            A dict containing ``model``, the caller's ``kwargs``, optionally
            ``model_kwargs`` and ``replicate_api_token``, and finally the
            entries of ``self.extra_params``, which override earlier keys.
        """
        params: dict[str, Any] = {"model": self.model, **kwargs}

        # Map provider fields to the parameter names used in model_kwargs
        # (note max_tokens -> max_new_tokens and stop_sequences -> stop).
        candidates = {
            "temperature": self.temperature,
            "max_new_tokens": self.max_tokens,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "repetition_penalty": self.repetition_penalty,
            "stop": self.stop_sequences,
        }
        configured = {
            name: value for name, value in candidates.items() if value is not None
        }

        if configured:
            # Keep any model_kwargs supplied by the caller. A non-dict value
            # is replaced outright, which is what happened before merging
            # was introduced.
            supplied = params.get("model_kwargs")
            base = supplied if isinstance(supplied, dict) else {}
            params["model_kwargs"] = {**base, **configured}

        # Only pass the token when one is available (truthy).
        api_key = self.get_api_key()
        if api_key:
            params["replicate_api_token"] = api_key

        # extra_params are applied last and override everything above.
        params.update(self.extra_params or {})
        return params

    def _get_env_key_name(self) -> str:
        """Return the environment variable name that holds the API token."""
        return "REPLICATE_API_TOKEN"

    @classmethod
    def get_models(cls) -> list[str]:
        """Return a fixed list of popular Replicate model identifiers.

        The list is hard-coded; it is not fetched from Replicate.
        """
        return [
            "meta/llama-2-70b-chat",
            "meta/llama-2-13b-chat",
            "meta/llama-2-7b-chat",
            "mistralai/mixtral-8x7b-instruct-v0.1",
            "meta/codellama-70b-instruct",
            "togethercomputer/redpajama-incite-7b-chat",
            "replicate/vicuna-13b",
        ]