Source code for beagle.transformers.base_transformer
import multiprocessing as mp
from abc import ABCMeta, abstractmethod
from queue import Queue
from threading import Thread, current_thread
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Any
from beagle.backends.networkx import NetworkX
from beagle.common import logger
from beagle.datasources import DataSource
from beagle.nodes import Node
# Size the worker pool from the host: Transformer.run() launches one producer
# thread plus (_THREAD_COUNT - 1) consumer threads.
_THREAD_COUNT = mp.cpu_count()
# Object to signal termination of processing. Identity-compared (`is`) by the
# consumers, so it can never collide with a real event.
_SENTINEL = object()
if TYPE_CHECKING:
    # Imported for type annotations only — avoids importing the backend
    # (and any circular dependency) at runtime.
    from beagle.backends.base_backend import Backend
class Transformer(object, metaclass=ABCMeta):
    """Base Transformer class. This class implements a producer/consumer queue
    from the datasource to the :py:meth:`transform` method. Producing the list
    of nodes is done via :py:meth:`run`

    Parameters
    ----------
    datasource : DataSource
        The `DataSource` to get events from.
    """

    def __init__(self, datasource: DataSource) -> None:
        self.count = 0  # kept for API compatibility; not updated internally
        self._queue: Queue = Queue()
        self.datasource = datasource
        self.nodes: List[Node] = []
        # Exceptions raised inside each worker thread, keyed by the Thread object.
        self.errors: Dict[Thread, List[Exception]] = {}

    def to_graph(self, backend: "Backend" = NetworkX, *args, **kwargs) -> Any:
        """Graphs the nodes created by :py:meth:`run`. If no backend is specified,
        the default used is NetworkX.

        Parameters
        ----------
        backend : Backend, optional
            Backend class used to build the graph (the default is NetworkX).

        Returns
        -------
        Any
            The value of the backend's ``graph()`` method.
        """
        nodes = self.run()
        backend = backend(nodes=nodes, metadata=self.datasource.metadata(), *args, **kwargs)
        return backend.graph()

    def run(self) -> List[Node]:
        """Generates the list of nodes from the datasource.

        This method kicks off a producer/consumer queue. The producer grabs events
        one by one from the datasource by iterating over the events from the `events`
        generator. Each event is then sent to the :py:meth:`transform` function to be
        transformed into one or more `Node` objects.

        Returns
        -------
        List[Node]
            All Nodes created from the data source.
        """
        logger.debug("Launching transformer")

        threads: List[Thread] = []

        producer_thread = Thread(target=self._producer_thread)
        self.errors[producer_thread] = []
        producer_thread.start()
        threads.append(producer_thread)
        logger.debug("Started producer thread")

        # Leave one CPU for the producer, but always run at least one consumer.
        consumer_count = max(1, _THREAD_COUNT - 1)
        for _ in range(consumer_count):
            t = Thread(target=self._consumer_thread)
            self.errors[t] = []
            t.start()
            threads.append(t)

        # Log the count actually started (the old message used _THREAD_COUNT - 1,
        # which was wrong on single-CPU hosts).
        logger.debug(f"Started {consumer_count} consumer threads")

        # Wait for the producer to enqueue everything, then for the consumers
        # to drain the queue (task_done/join pairing).
        producer_thread.join()
        self._queue.join()

        # One sentinel per consumer so every consumer thread exits.
        for _ in range(consumer_count):
            self._queue.put(_SENTINEL)

        for thread in threads:
            thread.join()

        logger.info(f"Finished processing of events, created {len(self.nodes)} nodes.")

        # Non-empty lists are truthy, so this is "any thread recorded an error".
        if any(self.errors.values()):
            logger.warning("Parsing finished with errors.")
            logger.debug(self.errors)

        return self.nodes

    def _producer_thread(self) -> None:
        """Pulls events off the datasource and feeds them to the queue.

        Exceptions from the datasource are recorded in ``self.errors`` rather
        than silently killing the thread; consumers drain whatever was queued.
        """
        i = 0
        try:
            for element in self.datasource.events():
                self._queue.put(element, block=True)
                i += 1
        except Exception as e:
            logger.warning(f"Producer failed after {i} events, received exception {e}")
            self.errors[current_thread()].append(e)

        logger.debug(f"Producer Thread {current_thread().name} finished after {i} events")
        return

    def _consumer_thread(self) -> None:
        """Pops events off the queue and hands them to :py:meth:`transform`,
        accumulating the resulting nodes until the sentinel is seen.
        """
        processed = 0
        while True:
            event = self._queue.get()
            processed += 1

            if event is _SENTINEL:
                # Sentinels are enqueued after queue.join() returns, so no
                # task_done() is required for them.
                logger.debug(
                    f"Consumer Thread {current_thread().name} finished after processing {processed} events"
                )
                return

            try:
                nodes = self.transform(event)
            except Exception as e:
                logger.warning(f"Error when parsing event, received exception {e}")
                logger.debug(event)
                self.errors[current_thread()].append(e)
                nodes = []

            if nodes:
                # NOTE(review): concurrent `+=` from multiple consumers relies on
                # CPython's GIL making list extension atomic — confirm if ever
                # run on another interpreter.
                self.nodes += nodes

            self._queue.task_done()

    @abstractmethod
    def transform(self, event: dict) -> Optional[Iterable[Node]]:
        raise NotImplementedError("Transformers must implement transform!")