aind_behavior_curriculum.trainer.Trainer
- class aind_behavior_curriculum.trainer.Trainer(curriculum: TCurriculum)
Bases: Generic[TCurriculum]
Trainer class for managing and evaluating curriculum stages and policy transitions, and updating the task parameters based on the active policies and provided metrics. The entry point is the “evaluate” method.
- attribute curriculum:
The curriculum used by the trainer.
- type:
TCurriculum
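The Generic[TCurriculum] base means a trainer is parameterized by its curriculum type, so a type checker can carry that concrete type through the curriculum attribute. A minimal sketch of the pattern, using a hypothetical ToyCurriculum dataclass and ToyTrainer class in place of the library's own models (an illustration of the typing idea, not the library's implementation):

```python
from dataclasses import dataclass, field
from typing import Generic, List, TypeVar


# Hypothetical stand-in for a curriculum model (not the library's Curriculum class).
@dataclass
class ToyCurriculum:
    name: str
    stages: List[str] = field(default_factory=list)


# Type variable bound to the stand-in curriculum, playing the role of TCurriculum.
TCurriculum = TypeVar("TCurriculum", bound=ToyCurriculum)


class ToyTrainer(Generic[TCurriculum]):
    """Stores a curriculum and exposes it through a read-only property."""

    def __init__(self, curriculum: TCurriculum) -> None:
        self._curriculum = curriculum

    @property
    def curriculum(self) -> TCurriculum:
        # No setter is defined, so assigning to trainer.curriculum raises AttributeError.
        return self._curriculum


trainer = ToyTrainer(ToyCurriculum(name="demo", stages=["stage_a", "stage_b"]))
print(trainer.curriculum.name)  # demo
```

With a static type checker such as mypy, trainer above is inferred as ToyTrainer[ToyCurriculum], so trainer.curriculum is typed as ToyCurriculum rather than a generic base type.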
- __init__(self, curriculum: TCurriculum)
- _evaluate_stage_transition(curriculum: Curriculum, current_stage: Stage, metrics: TMetrics) -> Optional[Stage]
- _evaluate_policy_transitions(cls, current_stage: Stage, active_policies: Iterable[Policy], metrics: TMetrics) -> List[Policy]: Evaluates policy transitions for the given current stage and currently active policies, based on the provided metrics.
- evaluate(self, trainer_state: TrainerState, metrics: TMetrics) -> TrainerState
- get_net_parameter_update(stage_parameters: TaskParameters, stage_policies: Iterable[Policy], curr_metrics: Metrics) -> TaskParameters: Aggregates parameter updates of the input stage_policies, given the current stage_parameters and current metrics.
- _get_unique_policies(policies: List[Policy]) -> List[Policy]: Filters unique policies based on their rule functions and reassembles the Policy objects.
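_get_unique_policies is described only as filtering policies by their rule functions. A minimal sketch of that deduplication step, assuming each policy wraps a single rule callable and that the first occurrence of each rule is kept in order; ToyPolicy, get_unique_policies, widen, and speed_up are hypothetical stand-ins, not the library's code, and the "reassembles the Policy objects" part is not modeled:

```python
from dataclasses import dataclass
from typing import Callable, Dict, List

Params = Dict[str, float]
Metrics = Dict[str, float]


@dataclass(frozen=True)
class ToyPolicy:
    # Assumption: each policy wraps a single rule function.
    rule: Callable[[Metrics, Params], Params]


def get_unique_policies(policies: List[ToyPolicy]) -> List[ToyPolicy]:
    """Keep the first policy for each distinct rule function, preserving order."""
    seen = set()
    unique: List[ToyPolicy] = []
    for policy in policies:
        if policy.rule not in seen:  # functions hash and compare by identity
            seen.add(policy.rule)
            unique.append(policy)
    return unique


def widen(metrics: Metrics, params: Params) -> Params:
    return {**params, "width": params["width"] + 1.0}


def speed_up(metrics: Metrics, params: Params) -> Params:
    return {**params, "speed": params["speed"] * 2.0}


result = get_unique_policies([ToyPolicy(widen), ToyPolicy(speed_up), ToyPolicy(widen)])
print([p.rule.__name__ for p in result])  # ['widen', 'speed_up']
```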
- __init__(curriculum: TCurriculum)
Initializes the Trainer with the given curriculum.
- Parameters:
curriculum (TCurriculum) – The curriculum to be used by the trainer.
Methods
- __init__(curriculum) – Initializes the Trainer with the given curriculum.
- create_trainer_state(*, stage[, ...]) – Returns a type-aware TrainerState initialized with the given stage, on-curriculum flag, and active policies.
- evaluate(trainer_state, metrics) – Evaluates the current state of the trainer and updates the stage and policies based on the provided metrics.
- get_net_parameter_update(stage_parameters, ...) – Aggregates the parameter updates of the input stage_policies, given the current stage_parameters and current metrics.
Attributes
- curriculum – Property that returns the current curriculum.
- create_trainer_state(*, stage: Stage | None, is_on_curriculum: bool = True, active_policies: Iterable[Policy] | None = None) -> TrainerState
Returns a type-aware TrainerState initialized with the given stage, on-curriculum flag, and active policies. All arguments are keyword-only.
- Returns:
A type-aware TrainerState instance.
- Return type:
TrainerState
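The bare * in the signature makes every argument keyword-only, and the result is a state object rather than a class. A minimal sketch of a method with that shape, assuming (not confirmed by this page) that the state also carries the trainer's curriculum; ToyTrainerState and ToyTrainer are hypothetical stand-ins, and stages and policies are plain strings for brevity:

```python
from dataclasses import dataclass, field
from typing import Generic, Iterable, List, Optional, TypeVar

TCurriculum = TypeVar("TCurriculum")


# Hypothetical stand-in for TrainerState; it is assumed here to carry the curriculum
# alongside the fields named in create_trainer_state's signature.
@dataclass
class ToyTrainerState(Generic[TCurriculum]):
    curriculum: TCurriculum
    stage: Optional[str]
    is_on_curriculum: bool = True
    active_policies: List[str] = field(default_factory=list)


class ToyTrainer(Generic[TCurriculum]):
    def __init__(self, curriculum: TCurriculum) -> None:
        self._curriculum = curriculum

    def create_trainer_state(
        self,
        *,
        stage: Optional[str],
        is_on_curriculum: bool = True,
        active_policies: Optional[Iterable[str]] = None,
    ) -> ToyTrainerState[TCurriculum]:
        # The bare * makes every argument keyword-only, as in the documented signature.
        return ToyTrainerState(
            curriculum=self._curriculum,
            stage=stage,
            is_on_curriculum=is_on_curriculum,
            active_policies=list(active_policies or []),
        )


trainer = ToyTrainer("toy_curriculum")
state = trainer.create_trainer_state(stage="stage_a", active_policies=["easy"])
print(state.stage, state.active_policies)  # stage_a ['easy']
```

Calling trainer.create_trainer_state("stage_a") positionally raises TypeError, which keeps call sites explicit about which field each value sets.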
- property curriculum: TCurriculum
Property that returns the current curriculum.
- Returns:
The current curriculum instance.
- Return type:
TCurriculum
- evaluate(trainer_state: TrainerState, metrics: TMetrics) -> TrainerState
Evaluates the current state of the trainer and updates the stage and policies based on the provided metrics.
- Parameters:
trainer_state (TrainerState) – The current state of the trainer, including the current stage and active policies.
metrics (TMetrics) – The metrics used to evaluate the current state and determine transitions.
- Returns:
The updated state of the trainer, including the new stage and active policies.
- Return type:
TrainerState
- Raises:
ValueError – If the current stage or active policies are not set in the trainer state.
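This page lists separate helpers for stage transitions (returning Optional[Stage]) and policy transitions (returning List[Policy]), but not how evaluate combines them. A minimal sketch of one plausible flow, under loudly stated assumptions: a firing stage transition takes precedence and resets the policies to the new stage's starting set; otherwise each active policy advances along the first of its transitions that fires. ToyStage, ToyState, and the free function evaluate (which takes an explicit stages lookup instead of a trainer) are hypothetical stand-ins, not the library's API:

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple

Metrics = Dict[str, float]
Condition = Callable[[Metrics], bool]


@dataclass
class ToyStage:
    # Hypothetical stage model; policies are represented by plain names.
    name: str
    start_policies: List[str] = field(default_factory=list)
    # Ordered (condition, next_stage_name) pairs; the first one that fires wins.
    stage_transitions: List[Tuple[Condition, str]] = field(default_factory=list)
    # (from_policy, condition, to_policy) triples evaluated within the stage.
    policy_transitions: List[Tuple[str, Condition, str]] = field(default_factory=list)


@dataclass
class ToyState:
    stage: Optional[ToyStage]
    active_policies: Optional[List[str]]


def evaluate(state: ToyState, metrics: Metrics, stages: Dict[str, ToyStage]) -> ToyState:
    if state.stage is None or state.active_policies is None:
        # Mirrors the documented ValueError for an incomplete trainer state.
        raise ValueError("Trainer state must define a stage and active policies.")

    # Assumption: a stage transition takes precedence and resets the policies
    # to the new stage's starting set.
    for condition, target in state.stage.stage_transitions:
        if condition(metrics):
            new_stage = stages[target]
            return ToyState(stage=new_stage, active_policies=list(new_stage.start_policies))

    # Otherwise, advance each active policy whose outgoing transition fires.
    next_policies: List[str] = []
    for policy in state.active_policies:
        for source, condition, dest in state.stage.policy_transitions:
            if source == policy and condition(metrics):
                next_policies.append(dest)
                break
        else:
            next_policies.append(policy)  # no transition fired; keep the policy
    # Drop duplicates while preserving order.
    return ToyState(stage=state.stage, active_policies=list(dict.fromkeys(next_policies)))


stage_b = ToyStage(name="stage_b", start_policies=["easy"])
stage_a = ToyStage(
    name="stage_a",
    start_policies=["easy"],
    stage_transitions=[(lambda m: m["accuracy"] > 0.9, "stage_b")],
    policy_transitions=[("easy", lambda m: m["accuracy"] > 0.7, "hard")],
)
stages = {"stage_a": stage_a, "stage_b": stage_b}
state = ToyState(stage=stage_a, active_policies=["easy"])

print(evaluate(state, {"accuracy": 0.8}, stages).active_policies)  # ['hard']
print(evaluate(state, {"accuracy": 0.95}, stages).stage.name)  # stage_b
```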
- static get_net_parameter_update(stage_parameters: TaskParameters, stage_policies: Iterable[Policy], curr_metrics: Metrics) -> TaskParameters
Aggregates the parameter updates of the input stage_policies, given the current stage_parameters and current metrics.
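The docstring says only that updates are aggregated. One plausible reading, sketched here as an assumption rather than the library's confirmed behavior, is that each policy carries a rule mapping (metrics, parameters) to new parameters and the rules are applied in sequence, each receiving the previous rule's output. ToyPolicy and the dict-based parameters are hypothetical stand-ins for the library's Policy and TaskParameters:

```python
from copy import deepcopy
from dataclasses import dataclass
from typing import Callable, Dict, Iterable

Params = Dict[str, float]
Metrics = Dict[str, float]


@dataclass(frozen=True)
class ToyPolicy:
    # Assumption: a policy's rule maps (metrics, parameters) to updated parameters.
    rule: Callable[[Metrics, Params], Params]


def get_net_parameter_update(
    stage_parameters: Params,
    stage_policies: Iterable[ToyPolicy],
    curr_metrics: Metrics,
) -> Params:
    """Apply each policy's rule in turn, feeding one rule's output into the next."""
    params = deepcopy(stage_parameters)  # leave the caller's parameters untouched
    for policy in stage_policies:
        params = policy.rule(curr_metrics, params)
    return params


def widen(metrics: Metrics, params: Params) -> Params:
    return {**params, "width": params["width"] + 1.0}


def speed_up(metrics: Metrics, params: Params) -> Params:
    # Only speed up once accuracy is high enough.
    if metrics["accuracy"] > 0.8:
        return {**params, "speed": params["speed"] * 2.0}
    return params


base = {"width": 1.0, "speed": 1.0}
updated = get_net_parameter_update(base, [ToyPolicy(widen), ToyPolicy(speed_up)], {"accuracy": 0.9})
print(updated)  # {'width': 2.0, 'speed': 2.0}
```

Because the rules compose in order, two policies that modify the same parameter can produce different results depending on their order in stage_policies; a scheme that merged independent updates instead would need an explicit conflict rule.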