The Neuraxle documentation includes an example of using a repository to lazy-load data inside a pipeline. See the following code:
    from neuraxle.pipeline import Pipeline, MiniBatchSequentialPipeline
    from neuraxle.base import ExecutionContext
    from neuraxle.steps.column_transformer import ColumnTransformer
    from neuraxle.steps.flow import TrainOnlyWrapper

    training_data_ids = training_data_repository.get_all_ids()
    context = ExecutionContext('caching_folder').set_service_locator({
        BaseRepository: training_data_repository
    })

    pipeline = Pipeline([
        ConvertIDsToLoadedData().assert_has_services(BaseRepository),
        ColumnTransformer([
            (range(0, 2), DateToCosineEncoder()),
            (3, CategoricalEnum(categeories_count=5, starts_at_zero=True)),
        ]),
        Normalizer(),
        TrainOnlyWrapper(DataShuffler()),
        MiniBatchSequentialPipeline([
            Model()
        ], batch_size=128)
    ]).with_context(context)
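However, the documentation does not show how to implement the BaseRepository and ConvertIDsToLoadedData classes. What is the best way to implement them? Could anyone give an example?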
Answer:
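I have not checked whether the following code runs, but it should look something like this. If anyone tries to run it and finds something that needs changing, please edit this answer: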
    from abc import ABC, abstractmethod
    from typing import Dict, List

    from neuraxle.base import BaseStep, ExecutionContext
    from neuraxle.data_container import DataContainer


    class BaseDataRepository(ABC):

        @abstractmethod
        def get_all_ids(self) -> List[int]:
            pass

        @abstractmethod
        def get_data_from_id(self, _id: int) -> object:
            pass


    class InMemoryDataRepository(BaseDataRepository):

        def __init__(self, ids, data):
            self.ids: List[int] = ids
            self.data: Dict[int, object] = data

        def get_all_ids(self) -> List[int]:
            return list(self.ids)

        def get_data_from_id(self, _id: int) -> object:
            return self.data[_id]


    class ConvertIDsToLoadedData(BaseStep):

        def _transform_data_container(self, data_container: DataContainer, context: ExecutionContext):
            repo: BaseDataRepository = context.get_service(BaseDataRepository)
            ids = data_container.data_inputs

            # Replace the data IDs with the loaded objects:
            data_container.data_inputs = [repo.get_data_from_id(_id) for _id in ids]

            return data_container, context


    # `ids` and `data` are assumed to be defined beforehand.
    context = ExecutionContext('caching_folder').set_service_locator({
        BaseDataRepository: InMemoryDataRepository(ids, data)
        # Or insert here any other replacement class that inherits from
        # `BaseDataRepository`, for when you switch to a real database
        # (e.g. SQL) instead of the simple "InMemory" stub.
    })
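To see the lazy-loading idea in isolation, here is a minimal sketch in plain Python that does not import Neuraxle at all. `RecordingRepository`, `FakeContext`, and `load_ids` are hypothetical names made up for this illustration; `FakeContext` only mimics the `get_service` lookup used above and is not Neuraxle's `ExecutionContext`.

```python
from abc import ABC, abstractmethod
from typing import Dict, List


class BaseDataRepository(ABC):
    """Same interface as in the answer above."""

    @abstractmethod
    def get_all_ids(self) -> List[int]:
        pass

    @abstractmethod
    def get_data_from_id(self, _id: int) -> object:
        pass


class RecordingRepository(BaseDataRepository):
    """Hypothetical repository that records which IDs were actually loaded."""

    def __init__(self, data: Dict[int, object]):
        self.data = data
        self.loaded_ids: List[int] = []

    def get_all_ids(self) -> List[int]:
        # Listing IDs is cheap and does not load any data.
        return list(self.data.keys())

    def get_data_from_id(self, _id: int) -> object:
        self.loaded_ids.append(_id)
        return self.data[_id]


class FakeContext:
    """Hypothetical stand-in for a context with a service locator."""

    def __init__(self, services: dict):
        self._services = dict(services)

    def get_service(self, service_type: type) -> object:
        return self._services[service_type]


def load_ids(ids: List[int], context: FakeContext) -> List[object]:
    """Resolve the repository from the context and load only the given IDs."""
    repo = context.get_service(BaseDataRepository)
    return [repo.get_data_from_id(_id) for _id in ids]


repo = RecordingRepository({1: "row-1", 2: "row-2", 3: "row-3"})
context = FakeContext({BaseDataRepository: repo})

all_ids = repo.get_all_ids()          # only IDs travel through the pipeline
batch = load_ids(all_ids[:2], context)  # data is fetched for this batch only
```

After this runs, `repo.loaded_ids` contains only the two IDs in the batch, which is the point of passing IDs through the pipeline and resolving them late: the third row is never touched.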
For updates, see the issue I opened about this here: https://github.com/Neuraxio/Neuraxle/issues/421