Source code for haive.core.schema.multi_agent_state_schema
"""Multi-agent state schema for the Haive framework.

This module provides a specialized StateSchema for multi-agent architectures,
addressing key issues with engine handling, consolidation, and access from
engine nodes. It ensures proper engine access and visibility for sub-agents
in complex agent workflows.
"""

from __future__ import annotations

import logging
from typing import Any, Self

from pydantic import Field, create_model, model_validator

from haive.core.schema.state_schema import StateSchema

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class MultiAgentStateSchema(StateSchema):
    """Enhanced StateSchema for multi-agent architectures.

    This class extends the base StateSchema with features specifically
    designed for multi-agent scenarios, solving common issues with engine
    handling and access in nested agent structures. It ensures that engines
    are accessible to EngineNodeConfig via the ``state.engines`` dictionary.

    Key Features:
        - Automatic engines field creation and population
        - Consolidation of engines from sub-agents
        - Engine visibility for engine nodes
        - Compatibility with EngineNodeConfig._get_engine()

    This schema should be used as the base class for states in multi-agent
    architectures to ensure proper engine access and visibility.
    """

    # Consolidated name -> engine mapping; filled by populate_engines_dict.
    engines: dict[str, Any] = Field(
        default_factory=dict,
        description="Dictionary of engines accessible to nodes",
    )

    @model_validator(mode="after")
    def populate_engines_dict(self) -> Self:
        """Consolidate every reachable engine into ``self.engines``.

        Runs after model validation and merges, in this order:

        1. Engines returned by ``self.get_engines()`` (instance fields).
           These overwrite same-named entries already in ``engines``.
        2. Engines returned by ``get_all_class_engines()`` on the class.
           These are added only under names not already present.
        3. If the state has an ``agents`` dict, for each agent:
           - an agent exposing ``engine_type`` is registered itself under its
             ``name`` attribute, or under its dict key when that name is
             missing or empty, unless the key is already taken;
           - an agent exposing an ``engines`` dict has each engine registered
             under ``"<agent_key>.<engine_name>"`` (always) and under the bare
             engine name (only if that name is still free).

        Returns:
            ``self``, as required for a pydantic "after" model validator.
        """
        # Lazy %-style arguments: the validator runs on every state
        # construction, so skip message formatting when DEBUG is disabled.
        logger.debug("Populating engines dict for %s", self.__class__.__name__)

        # ``engines`` is a declared field with a default factory, so this guard
        # should normally be a no-op; kept as a defensive fallback.
        if not hasattr(self, "engines"):
            self.engines = {}

        # 1. Instance-field engines take precedence over existing entries.
        field_engines = self.get_engines()
        if field_engines:
            logger.debug("Found %d engines in instance fields", len(field_engines))
            self.engines.update(field_engines)

        # 2. Class-level engines only fill gaps.
        class_engines = self.__class__.get_all_class_engines()
        if class_engines:
            logger.debug("Found %d engines at class level", len(class_engines))
            for name, engine in class_engines.items():
                self.engines.setdefault(name, engine)

        # 3. Sub-agents, when this state carries an ``agents`` dict.
        agents = getattr(self, "agents", None)
        if isinstance(agents, dict):
            logger.debug("Found agents dictionary with %d agents", len(agents))
            for agent_name, agent in agents.items():
                if hasattr(agent, "engine_type"):
                    # Fall back to the dict key when ``name`` is absent, None or
                    # empty, so the agent is never keyed under None / "".
                    engine_name = getattr(agent, "name", None) or agent_name
                    if engine_name not in self.engines:
                        logger.debug("Adding agent '%s' to engines dict", engine_name)
                        self.engines[engine_name] = agent

                sub_engines = getattr(agent, "engines", None)
                if isinstance(sub_engines, dict):
                    logger.debug(
                        "Agent '%s' has %d engines", agent_name, len(sub_engines)
                    )
                    for eng_name, engine in sub_engines.items():
                        # The namespaced key is always written, so same-named
                        # engines from different agents stay distinguishable...
                        qualified_name = f"{agent_name}.{eng_name}"
                        logger.debug(
                            "Adding engine '%s' from agent '%s'",
                            qualified_name,
                            agent_name,
                        )
                        self.engines[qualified_name] = engine
                        # ...and the bare name is claimed only if still free.
                        self.engines.setdefault(eng_name, engine)

        logger.debug("Populated engines dict with %d total engines", len(self.engines))
        return self

    @classmethod
    def from_state_schema(
        cls, schema_class: type[StateSchema], name: str | None = None
    ) -> type[MultiAgentStateSchema]:
        """Create a subclass of ``cls`` that mirrors an existing StateSchema.

        The new class is built with ``create_model(..., __base__=cls)``. It
        receives copies of the original's pydantic fields (except ``engines``,
        which ``cls`` already defines, and names starting with ``__``) plus
        copies of the schema metadata attributes listed below. It does NOT
        inherit from ``schema_class``, so methods and validators defined on
        ``schema_class`` itself are not carried over.

        Args:
            schema_class: Original StateSchema class to convert.
            name: Optional name for the new schema (defaults to the original
                class name with a ``Multi`` prefix).

        Returns:
            A new subclass of ``cls`` with the original's fields and metadata.
        """
        if name is None:
            name = f"Multi{schema_class.__name__}"

        field_defs = {
            field_name: (field_info.annotation, field_info)
            for field_name, field_info in schema_class.model_fields.items()
            if field_name != "engines" and not field_name.startswith("__")
        }
        multi_schema = create_model(name, __base__=cls, **field_defs)

        # (attribute, copy function) pairs, applied in this order only when
        # ``schema_class`` has the attribute. Each value is copied (one level
        # deep for the mapping-of-containers entries) so mutating the new
        # schema's metadata does not alter ``schema_class``'s containers.
        metadata_copiers = (
            ("__shared_fields__", list),
            ("__serializable_reducers__", dict),
            ("__reducer_fields__", dict),
            ("__engine_io_mappings__", lambda m: {k: v.copy() for k, v in m.items()}),
            ("__input_fields__", lambda m: {k: list(v) for k, v in m.items()}),
            ("__output_fields__", lambda m: {k: list(v) for k, v in m.items()}),
            ("__structured_models__", dict),
            (
                "__structured_model_fields__",
                lambda m: {k: list(v) for k, v in m.items()},
            ),
            # Class-level ``engines`` attribute, if the source class has one.
            ("engines", dict),
        )
        for attr_name, copier in metadata_copiers:
            if hasattr(schema_class, attr_name):
                setattr(multi_schema, attr_name, copier(getattr(schema_class, attr_name)))

        return multi_schema
class MultiAgentSchemaComposer:
    """Factory helpers that produce MultiAgentStateSchema classes.

    Both helpers delegate the conversion itself to
    ``MultiAgentStateSchema.from_state_schema``. They differ only in where the
    source schema comes from: an existing class, or a set of components.
    """

    @staticmethod
    def from_schema(
        schema_class: type[StateSchema], name: str | None = None
    ) -> type[MultiAgentStateSchema]:
        """Convert an existing StateSchema class into a MultiAgentStateSchema.

        Args:
            schema_class: The StateSchema class to convert.
            name: Optional class name for the result; when omitted,
                ``from_state_schema`` chooses its default.

        Returns:
            The newly created MultiAgentStateSchema subclass.
        """
        convert = MultiAgentStateSchema.from_state_schema
        return convert(schema_class, name)

    @staticmethod
    def from_components(
        components: list[Any], name: str = "MultiAgentSchema"
    ) -> type[MultiAgentStateSchema]:
        """Build a MultiAgentStateSchema from a list of components.

        A ``SchemaComposer`` first gathers the components' fields into an
        intermediate schema, which is then converted. The same ``name`` is
        used both for the intermediate schema and for the returned class.

        Args:
            components: Components whose fields should be collected.
            name: Class name for the resulting schema.

        Returns:
            The newly created MultiAgentStateSchema subclass.
        """
        # Function-scope import; presumably avoids a circular import between
        # the schema modules — TODO confirm.
        from haive.core.schema.schema_composer import SchemaComposer

        builder = SchemaComposer(name=name)
        builder.add_fields_from_components(components)
        intermediate_schema = builder.build()
        return MultiAgentStateSchema.from_state_schema(intermediate_schema, name)