Unverified Commit ee4cf370 authored by Birte Kristina Friesel's avatar Birte Kristina Friesel
Browse files

changepoint detection: add --pelt-stretch option

parent 60d71b8a
Loading
Loading
Loading
Loading
Loading
+43 −5
Original line number Diff line number Diff line
@@ -161,9 +161,7 @@ class PELT:
            changepoints = changepoints[:-1]

        if self.stretch != 1:
            changepoints = np.array(
                np.around(changepoints / self.stretch), dtype=np.int
            )
            changepoints = np.array(np.around(changepoints / self.stretch), dtype=int)

        return changepoints

@@ -312,9 +310,38 @@ def export_pgf(filename, data, power, smooth_power):
            )


def detect_changepoints(timestamps, trace, num_samples):
def detect_changepoints(timestamps, trace, num_samples, stretch=1):

    if stretch > 1:
        trace = np.interp(
            np.linspace(0, len(trace) - 1, (len(trace) - 1) * stretch + 1),
            np.arange(len(trace)),
            trace,
        )
    elif stretch < -1:
        ds_factor = -stretch
        trace = (
            np.array(
                list(trace)
                + [trace[-1] for i in range(ds_factor - (trace.shape[0] % ds_factor))]
            )
            .reshape(-1, ds_factor)
            .mean(axis=1)
        )

    pelt = PELT(trace, num_samples=num_samples)
    changepoints = pelt.get_changepoints()

    if stretch > 1:
        changepoints = list(
            np.array(np.around(np.array(changepoints) / stretch), dtype=int)
        )
    elif stretch < -1:
        ds_factor = -stretch
        changepoints = list(
            np.array(np.around(np.array(changepoints) * ds_factor), dtype=int)
        )

    prev = 0
    ret = list()
    for cp in changepoints:
@@ -407,6 +434,13 @@ def main():
        type=int,
        help="Perform changepoint detection with FREQ Hz",
    )
    parser.add_argument(
        "--pelt-stretch",
        metavar="MULTIPLIER",
        type=int,
        default=1,
        help="Stretch data for changepoint detection",
    )
    parser.add_argument(
        "--threshold",
        metavar="WATTS",
@@ -658,13 +692,17 @@ def main():

    if args.pelt is not None:
        power_changepoints = detect_changepoints(
            data[1:, 0] * 1e-6, power_from_energy, num_samples=args.pelt
            data[1:, 0] * 1e-6,
            power_from_energy,
            num_samples=args.pelt,
            stretch=args.pelt_stretch,
        )
        print(f"Found {len(power_changepoints)} changepoints for power")
        current_changepoints = detect_changepoints(
            data[1:, 0] * 1e-6,
            power_from_energy / (data[1:, 2] * 1e-3),
            num_samples=args.pelt,
            stretch=args.pelt_stretch,
        )
        print(f"Found {len(current_changepoints)} changepoints for current")