
API reference

Models

The :mod:`~pyter.models` module provides flexible model classes to represent distinct experimental setups.

The :class:`AbstractModel` base class serves as a template for :class:`Model` subclasses that represent different possible experimental setups with different inferred quantities of interest.

The core of any :class:`Model` subclass is the :meth:`~pyter.models.AbstractModel.model` method, which takes in a data dictionary and makes calls to :func:`numpyro.sample() <numpyro.primitives.sample>` to define the stochastic generative process.

AbstractModel

Abstract base class for Pyter models

get_reparam

get_reparam()

Return self.model wrapped in a reparam handler configured by self.reparam_dict.
Source code in pyter/models.py
def get_reparam(self):
    """
    Return ``self.model`` wrapped in a ``reparam``
    handler configured by ``self.reparam_dict``.
    """
    return reparam(self.model, self.reparam_dict)
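`get_reparam` applies a reparameterization handler to `self.model` using the configuration in `self.reparam_dict`. The handler is presumably `numpyro.handlers.reparam`; the contents of `reparam_dict` are not documented here. For hierarchical loc/scale sites, a common choice is a non-centered reparameterization. The NumPy sketch below does not call pyter or NumPyro and only illustrates the idea. Drawing `z` from a standard normal and computing `loc + scale * z` gives the same distribution as drawing from Normal(loc, scale) directly. The difference is that the sampled variable is decoupled from its parent parameters.

```python
import numpy as np


def centered_draws(loc, scale, n, rng):
    # Centered form: sample the child parameter directly from Normal(loc, scale).
    return rng.normal(loc, scale, size=n)


def noncentered_draws(loc, scale, n, rng):
    # Non-centered form: sample a standard-normal "raw" variable z,
    # then apply the deterministic transform loc + scale * z.
    z = rng.standard_normal(size=n)
    return loc + scale * z


rng = np.random.default_rng(0)
a = centered_draws(2.0, 0.5, 200_000, rng)
b = noncentered_draws(2.0, 0.5, 200_000, rng)

# Both forms should agree in distribution, up to Monte Carlo error.
print(f"centered:     mean={a.mean():.3f}, sd={a.std():.3f}")
print(f"non-centered: mean={b.mean():.3f}, sd={b.std():.3f}")
```

In NumPyro, `numpyro.infer.reparam.LocScaleReparam(centered=0)` applies this transform to a named sample site. Whether pyter's `reparam_dict` uses it is an assumption here.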

model

model(data: dict = None)

Parameters:

Name Type Description Default
data dict

(Default value = None)

None
Source code in pyter/models.py
def model(self, data: dict = None):
    """

    Parameters
    ----------
    data :
         (Default value = None)

    Returns
    -------

    """
    raise NotImplementedError()

validate_data

validate_data(data: AbstractData, run_data: dict)

Parameters:

Name Type Description Default
data AbstractData
required
run_data dict
required
Source code in pyter/models.py
def validate_data(self, data: pdata.AbstractData, run_data: dict):
    """

    Parameters
    ----------
    data :
    run_data :

    Returns
    -------

    """
    raise NotImplementedError()

HalfLifeModel

Bases: AbstractModel

Model to infer virus half-lives from experimental timeseries data.

A timeseries here is any set of titration results taken at different timepoints that represent repeat samples from the same viral stock. We can also handle cases in which non-destructive sampling is impossible (for example, depositing stock onto a surface and retrieving it at :math:`t = 0` h, :math:`t = 1` h, etc.). To do this, we use a hierarchical approach: we infer a shared half-life for the samples jointly with a modal value for the initial titer deposited. Each individual sample's unknown :math:`t = 0` titer may vary about this modal value. This allows the model to use the titers retrieved immediately at :math:`t = 0` h to make inferences about what the unmeasured :math:`t = 0` h titers were for the samples taken at :math:`t = 1` h, :math:`t = 2` h, etc.
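A minimal NumPy sketch of the generative structure described above, using assumed values. There is one shared half-life and a modal initial log10 titer. Each sample's t = 0 log titer scatters normally around that mode, and each sample is retrieved once at its own timepoint. The names and numbers (`half_life_h`, `modal_log_titer`, `intercept_sd`) are illustrative only, not pyter parameters.

```python
import numpy as np

rng = np.random.default_rng(42)

half_life_h = 2.0       # shared half-life in hours (assumed)
modal_log_titer = 5.0   # modal initial log10 titer (assumed)
intercept_sd = 0.3      # spread of per-sample t = 0 log titers (assumed)
log_base = 10.0

# One destructive sample per timepoint: each sample has its own,
# unobserved t = 0 titer drawn around the shared mode.
times = np.array([0.0, 1.0, 2.0, 4.0, 8.0])
intercepts = rng.normal(modal_log_titer, intercept_sd, size=times.size)

# Exponential decay on the log scale: slope = ln(2) / (h * ln(base)).
decay_rate = np.log(2) / (half_life_h * np.log(log_base))
log_titers = intercepts - decay_rate * times

for t, i, y in zip(times, intercepts, log_titers):
    print(f"t = {t:4.1f} h  intercept = {i:.2f}  log10 titer = {y:.2f}")
```

Only the t = 0 sample observes its own intercept directly. Through the shared mode, it informs the unobserved intercepts of the later samples.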

model

model(data: dict | None = None) -> tuple[Array, Array]

Parameters:

Name Type Description Default
data :class:`dict`

Dictionary of data with which to fit the model. Defaults to :py:data:`None`, in which case an empty dictionary is used.

None

Returns:

Type Description
log_titer, wells : :class:`tuple`
( :class:`jax.Array`, :class:`jax.Array` )

Tuple of arrays containing sampled log titer values and sampled well statuses / plaque counts.

Source code in pyter/models.py
def model(self, data: dict | None = None) -> tuple[jax.Array, jax.Array]:
    """

    Parameters
    ----------
    data : :class:`dict`
        Dictionary of data with which to fit the model.
        Defaults to :py:data:`None`, in which case an
        empty dictionary is used.

    Returns
    -------
    log_titer, wells : :class:`tuple`
    ( :class:`jax.Array`, :class:`jax.Array` )
        Tuple of arrays containing sampled log
        titer values and sampled
        well statuses / plaque counts.
    """
    if data is None:
        data = {}

    well_titer_id = data["well_internal_id_values"]["titer"]
    titer_hl_id = data["titer_internal_id_values"]["halflife"]
    titer_intercept_id = data["titer_internal_id_values"]["intercept"]

    log_halflife = self.sample_log_halflife(data=data)
    log_titer_intercept = self.sample_log_titer_intercept(data=data)

    halflife = npro.deterministic("halflife", jnp.exp(log_halflife))
    decay_rate = npro.deterministic(
        "decay_rate", jnp.log(2) / (halflife * jnp.log(data["log_base"]))
    )
    initial_log_titer = npro.deterministic(
        "initial_log_titer", log_titer_intercept[titer_intercept_id]
    )
    predicted_log_titer = npro.deterministic(
        "predicted_log_titer",
        initial_log_titer
        - decay_rate[titer_hl_id] * data["titer_time"]
        + data["log_titer_change_other"],
    )

    log_titer = self.sample_log_titer(predicted_log_titer, data=data)

    wells = npro.sample(
        "well_status",
        well_distribution_factory(
            assay=self.assay,
            log_titer=log_titer[well_titer_id],
            log_dilution=data["well_dilution"],
            log_base=data["log_base"],
            well_volume=data["well_volume"],
            false_hit_rate=data["false_hit_rate"],
            validate_args=True,
        ),
        obs=data["well_status"],
    )

    return (log_titer, wells)
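The `decay_rate` expression converts a half-life into a slope on the log-titer scale. The derivation below assumes simple exponential decay with half-life h, initial titer N_0, and logarithm base b (`data["log_base"]`):

```latex
N(t) = N_0 \, 2^{-t/h}
\;\Longrightarrow\;
\log_b N(t) = \log_b N_0 - \frac{t}{h}\,\log_b 2
            = \log_b N_0 - \underbrace{\frac{\ln 2}{h \,\ln b}}_{\texttt{decay\_rate}}\; t
```

So `predicted_log_titer` is the intercept minus `decay_rate` times `titer_time`, plus the `log_titer_change_other` adjustment. For b = 10 and h = 1 h, the slope is log10(2) ≈ 0.301 log units per hour.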

sample_log_halflife

sample_log_halflife(data: dict = None) -> Array

Sample log half-life values, either from a fixed-parameter prior or hierarchically, as specified by the user.

Parameters:

Name Type Description Default
data :class:`dict`

Dictionary of data with which to fit the model. Defaults to :py:data:`None`.

None

Returns:

Name Type Description
log_halflife :class:`jax.Array`

An array of sampled log half-lives.

Source code in pyter/models.py
def sample_log_halflife(self, data: dict = None) -> jax.Array:
    """
    Sample log half-life values, either
    from a fixed-parameter prior or hierarchically,
    as specified by the user.

    Parameters
    ----------
    data : :class:`dict`
        Dictionary of data with which to fit the model.
        Defaults to :py:data:`None`.


    Returns
    -------
    log_halflife: :class:`jax.Array`
        An array of sampled log half-lives.

    """
    if self.halflives_hier:
        log_halflife = sample_loc_scale_hier(
            "log_halflife",
            data["n_values"]["halflife"],
            data["n_values"]["halflife_loc"],
            data["n_values"]["halflife_scale"],
            self.log_halflife_distribution,
            data["halflife_internal_id_values"]["loc"],
            data["halflife_internal_id_values"]["scale"],
            self.log_halflife_loc_prior,
            self.log_halflife_scale_prior,
        )
    else:
        log_halflife = sample_non_hier(
            "log_halflife",
            data["n_values"]["halflife"],
            self.log_halflife_distribution,
        )

    return log_halflife

sample_log_titer

sample_log_titer(predicted_titer: Array, data: dict = None) -> Array

Sample realized log titer values for the modeled titers, either deterministically predicted from the other parameters, or with an inferred degree of noise, as specified by the user.

Parameters:

Name Type Description Default
predicted_titer :class:`jax.Array`

An array of predicted titer values.

required
data :class:`dict`

Dictionary of data with which to fit the model. Defaults to :py:data:`None`.

None

Returns:

Name Type Description
log_titer :class:`jax.Array`

An array of realized log titer values.
Source code in pyter/models.py
def sample_log_titer(
    self, predicted_titer: jax.Array, data: dict = None
) -> jax.Array:
    """
    Sample realized log titer values for the modeled
    titers, either deterministically predicted from
    the other parameters, or with an inferred degree
    of noise, as specified by the user.

    Parameters
    ----------
    predicted_titer : :class:`jax.Array`
        An array of predicted titer values.

    data : :class:`dict`
        Dictionary of data with which to fit the model.
        Defaults to :py:data:`None`.

    Returns
    -------
    log_titer : :class:`jax.Array`
        An array of realized log titer values.

    """
    if self.titers_overdispersed:
        log_titer_error_scale = sample_non_hier(
            "log_titer_error_scale",
            data["n_values"]["titer_error_scale"],
            self.log_titer_error_scale_prior,
        )
        es_id = data["titer_internal_id_values"]["titer_error_scale"]
        log_titer = npro.sample(
            "log_titer",
            self.log_titer_error_distribution(
                loc=predicted_titer, scale=log_titer_error_scale[es_id]
            ),
        )
    else:
        log_titer = npro.deterministic("log_titer", predicted_titer)
    return log_titer
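The two branches of `sample_log_titer` can be mimicked in NumPy as below. This is a sketch only: it assumes a Normal `log_titer_error_distribution` and uses made-up error scales. Here `es_id` plays the role of `data["titer_internal_id_values"]["titer_error_scale"]`, mapping each titer to one of the shared error scales. The function name `realized_log_titer` is hypothetical.

```python
import numpy as np


def realized_log_titer(predicted, error_scales, es_id, overdispersed, rng):
    """Return realized log titers from predicted ones (illustrative sketch)."""
    predicted = np.asarray(predicted, dtype=float)
    if not overdispersed:
        # Deterministic branch: the realized titer equals the prediction.
        return predicted.copy()
    # Overdispersed branch: each titer gets Normal noise whose scale is
    # looked up from the shared error scale assigned to it by es_id.
    scale = np.asarray(error_scales, dtype=float)[es_id]
    return rng.normal(loc=predicted, scale=scale)


rng = np.random.default_rng(1)
predicted = np.array([5.0, 4.7, 4.4, 3.8])
error_scales = np.array([0.1, 0.4])  # two assumed error-scale groups
es_id = np.array([0, 0, 1, 1])       # titers 0-1 use scale 0, titers 2-3 use scale 1

print(realized_log_titer(predicted, error_scales, es_id, False, rng))
print(realized_log_titer(predicted, error_scales, es_id, True, rng))
```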

sample_log_titer_intercept

sample_log_titer_intercept(data: dict = None) -> Array

Sample log intercept (i.e. t = 0) values for the modeled titers, either from a fixed-parameter prior or hierarchically, as specified by the user.

Parameters:

Name Type Description Default
data :class:`dict`

Dictionary of data with which to fit the model. Defaults to :data:`None`.

None

Returns:

Name Type Description
log_titer_intercept :class:`jax.Array`

An array of sampled intercepts.

Source code in pyter/models.py
def sample_log_titer_intercept(self, data: dict = None) -> jax.Array:
    """
    Sample log intercept (i.e. t = 0) values for the
    modeled titers, either from a fixed-parameter prior
    or hierarchically, as specified by the user.

    Parameters
    ----------
    data : :class:`dict`
        Dictionary of data with which to fit the model.
        Defaults to :data:`None`.

    Returns
    -------
    log_titer_intercept : :class:`jax.Array`
        An array of sampled intercepts.

    """
    if self.intercepts_hier:
        log_titer_intercept = sample_loc_scale_hier(
            "log_titer_intercept",
            data["n_values"]["intercept"],
            data["n_values"]["intercept_loc"],
            data["n_values"]["intercept_scale"],
            self.log_intercept_distribution,
            data["intercept_internal_id_values"]["loc"],
            data["intercept_internal_id_values"]["scale"],
            self.log_intercept_loc_prior,
            self.log_intercept_scale_prior,
        )
    else:
        log_titer_intercept = sample_non_hier(
            "log_titer_intercept",
            data["n_values"]["intercept"],
            self.log_intercept_distribution,
        )
    return log_titer_intercept

validate_data

validate_data(data: AbstractData, run_data: dict)

Parameters:

Name Type Description Default
data AbstractData
required
run_data dict
required
Source code in pyter/models.py
def validate_data(self, data: pdata.AbstractData, run_data: dict):
    """

    Parameters
    ----------
    data :
    run_data :

    Returns
    -------

    """
    if not isinstance(data, pdata.HalfLifeData):
        raise ValueError(
            "Incorrect data type {} for model {}".format(
                type(data), type(self)
            )
        )
    pass

MultiphaseHalfLifeModel

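Bases: HalfLifeModel

Model to infer piecewise (multiphase) exponential decay. Each sample passes through n_phases decay phases, each with its own half-life, separated by n_phases - 1 inferred breakpoint times.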

model

model(data: dict = None)

Parameters:

Name Type Description Default
data dict

(Default value = None)

None
Source code in pyter/models.py
def model(self, data: dict = None):
    """

    Parameters
    ----------
    data :
         (Default value = None)

    Returns
    -------

    """

    log_halflife, break_times = self.sample_log_halflife(data=data)
    log_titer_intercept = self.sample_log_titer_intercept(data=data)

    halflife = npro.deterministic("halflife", jnp.exp(log_halflife))
    decay_rate = npro.deterministic(
        "decay_rate", jnp.log(2) / (halflife * jnp.log(data["log_base"]))
    )
    # breakpoint_deltas and break_times have
    # shape (n_phases - 1, n_halflives)
    # we add a first row of zeros
    start_times = jnp.vstack(
        [jnp.zeros_like(break_times[0, ::]), break_times]
    )

    well_titer_id = data["well_internal_id_values"]["titer"]
    if self.halflives_hier:
        titer_break_id = data["titer_internal_id_values"]["halflife_loc"]
    else:
        titer_break_id = data["titer_internal_id_values"]["halflife"]

    # titer_break_times has shape (n_phases - 1, n_titers)
    titer_break_times = break_times[::, titer_break_id]
    titer_start_times = start_times[::, titer_break_id]
    titer_decay_rates = decay_rate[::, titer_break_id]

    # add a titer_time row at the end
    # of titer_break_times
    possible_end_times = jnp.vstack(
        [titer_break_times, data["titer_time"]]
    )

    # cut off each phase at the
    # end of the phase or the
    # observation time, whichever
    # is smaller
    cutoff_end_times = jnp.where(
        possible_end_times < data["titer_time"],
        possible_end_times,
        data["titer_time"],
    )

    # how much time (possibly 0!)
    # did the sample actually
    # spend in each of the decay
    # phases
    phase_times = npro.deterministic(
        "phase_times",
        jnp.where(
            titer_start_times < cutoff_end_times,
            cutoff_end_times - titer_start_times,
            0,
        ),
    )

    total_decay = npro.deterministic(
        "total_decay", jnp.sum(phase_times * titer_decay_rates, axis=0)
    )

    # predicted titer has length
    # n_titers
    predicted_log_titer = npro.deterministic(
        "predicted_log_titer", log_titer_intercept - total_decay
    )

    log_titer = self.sample_log_titer(predicted_log_titer, data=data)

    wells = npro.sample(
        "well_status",
        well_distribution_factory(
            assay=self.assay,
            log_titer=log_titer[well_titer_id],
            log_dilution=data["well_dilution"],
            log_base=data["log_base"],
            well_volume=data["well_volume"],
            false_hit_rate=data["false_hit_rate"],
            validate_args=True,
        ),
        obs=data["well_status"],
    )

    return (log_titer, wells)
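The phase-time bookkeeping in `model` above can be reproduced in NumPy for a single titer to check the arithmetic. This sketch assumes three phases with made-up break times and per-phase decay rates. It follows the same steps as the listing: prepend a zero start time, append the observation time as the last possible end, clip each phase at the observation time, and set durations of phases that start after the observation to zero. The helper name `phase_times` is hypothetical.

```python
import numpy as np


def phase_times(break_times, obs_time):
    """Time spent in each decay phase by a sample observed at obs_time."""
    break_times = np.asarray(break_times, dtype=float)       # (n_phases - 1,)
    start_times = np.concatenate([[0.0], break_times])       # (n_phases,)
    possible_ends = np.concatenate([break_times, [obs_time]])
    # Each phase ends at its breakpoint or at the observation time,
    # whichever comes first.
    cutoff_ends = np.minimum(possible_ends, obs_time)
    # Phases that begin after the observation contribute zero time.
    return np.where(start_times < cutoff_ends, cutoff_ends - start_times, 0.0)


breaks = [2.0, 5.0]                   # assumed break times (hours)
rates = np.array([0.30, 0.10, 0.02])  # assumed decay rates (log units per hour)

for t in [1.0, 3.0, 8.0]:
    pt = phase_times(breaks, t)
    print(t, pt, "total decay:", float(np.sum(pt * rates)))
```

The phase durations always sum to the observation time, so total decay is a rate-weighted split of the elapsed time across phases.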

sample_log_halflife

sample_log_halflife(data: dict = None)

Parameters:

Name Type Description Default
data dict

(Default value = None)

None
Source code in pyter/models.py
def sample_log_halflife(self, data: dict = None):
    """

    Parameters
    ----------
    data :
         (Default value = None)

    Returns
    -------

    """
    if self.halflives_hier:
        log_halflife_first = sample_loc_scale_hier(
            "log_halflife_first",
            data["n_values"]["halflife"],
            data["n_values"]["halflife_loc"],
            data["n_values"]["halflife_scale"],
            self.log_halflife_distribution,
            data["halflife_internal_id_values"]["loc"],
            data["halflife_internal_id_values"]["scale"],
            self.log_halflife_loc_prior,
            self.log_halflife_scale_prior,
        )
        n_offsets = data["n_values"]["halflife_loc"]
    else:
        log_halflife_first = sample_non_hier(
            "log_halflife_first",
            data["n_values"]["halflife"],
            self.log_halflife_distribution,
        )
        n_offsets = data["n_values"]["halflife"]

    with npro.plate("offsets", n_offsets):
        with npro.plate("phases", self.n_phases - 1):
            breakpoint_deltas = npro.sample(
                "breakpoint_deltas", self.breakpoint_delta_prior
            )
            break_times = npro.deterministic(
                "breakpoint_times",
                # cumsum columns to get break times
                # for each experiment
                jnp.cumsum(breakpoint_deltas, axis=0),
            )

            log_halflife_offsets = npro.sample(
                "log_halflife_offsets", self.log_halflife_offset_prior
            )
            pass
        pass

    # One row of zeros: the first phase has no offset, so the
    # stacked array has n_phases rows.
    offset_stack = jnp.vstack(
        [jnp.zeros((1, n_offsets)), log_halflife_offsets]
    )

    if self.halflives_hier:
        expand_ids = data["halflife_internal_id_values"]["loc"]
        offset_stack = offset_stack[::, expand_ids]

    log_halflife = npro.deterministic(
        "log_halflife", log_halflife_first + offset_stack
    )

    return log_halflife, break_times
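The shape bookkeeping in `sample_log_halflife` can be illustrated with NumPy for the non-hierarchical case. The values below are assumed: two experiments and three phases. Break times are cumulative sums of the breakpoint deltas down the phase axis. Each phase's log half-life is the first-phase value plus an offset, with a single row of zeros for the first phase, so the result has one row per phase.

```python
import numpy as np

n_phases, n_experiments = 3, 2

# Assumed draws: one first-phase log half-life per experiment ...
log_halflife_first = np.log(np.array([1.0, 2.0]))
# ... gaps between successive breakpoints, shape (n_phases - 1, n_experiments) ...
breakpoint_deltas = np.array([[2.0, 1.0],
                              [3.0, 4.0]])
# ... and log half-life offsets for phases 2 .. n_phases.
log_halflife_offsets = np.log(np.array([[3.0, 2.0],
                                        [10.0, 5.0]]))

# Cumulative sum down each column gives each experiment's break times.
break_times = np.cumsum(breakpoint_deltas, axis=0)

# Phase 1 has zero offset: stack one zero row above the offsets.
offset_stack = np.vstack([np.zeros((1, n_experiments)), log_halflife_offsets])
log_halflife = log_halflife_first + offset_stack  # (n_phases, n_experiments)

print(break_times)
print(np.exp(log_halflife))
```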

TiterModel

Bases: AbstractModel

Model to infer individual titers independently

model

model(data: dict = None)

Parameters:

Name Type Description Default
data dict
None
Source code in pyter/models.py
def model(self, data: dict = None):
    """

    Parameters
    ----------
    data :

    Returns
    -------

    """

    log_titer = sample_non_hier(
        "log_titer", data["n_values"]["titer"], self.log_titer_prior
    )

    wells = npro.sample(
        "well_status",
        well_distribution_factory(
            assay=self.assay,
            log_titer=log_titer[data["well_internal_id_values"]["titer"]],
            log_dilution=data["well_dilution"],
            log_base=data["log_base"],
            well_volume=data["well_volume"],
            false_hit_rate=data["false_hit_rate"],
            validate_args=True,
        ),
        obs=data["well_status"],
    )

    return wells

validate_data

validate_data(data: AbstractData, run_data: dict)

Parameters:

Name Type Description Default
data :class:`~pyter.data.TiterData`

Pyter data object to validate.

required
run_data :class:`dict`

Frozen dictionary of data with which to fit the model, generated from a :class:`TiterData` object by the :meth:`~pyter.data.TiterData.freeze` method.

required

Returns:

Type Description
:py:data:`True`

If data is a :class:`~pyter.data.TiterData` object.

Raises:

Type Description
ValueError

If data is not a :class:`~pyter.data.TiterData` object.
Source code in pyter/models.py
def validate_data(self, data: pdata.AbstractData, run_data: dict):
    """

    Parameters
    ----------
    data : :class:`~pyter.data.TiterData`
        Pyter data object to validate.

    run_data : :class:`dict`
        Frozen dictionary of data with which
        to fit the model, generated from
        a :class:`TiterData` object
        by the :meth:`~pyter.data.TiterData.freeze`
        method.

    Returns
    -------
    :py:data:`True`

    Raises
    ------
    ValueError
        If ``data`` is not a :class:`~pyter.data.TiterData` object.

    """
    if not isinstance(data, pdata.TiterData):
        raise ValueError(
            "Incorrect data type {} for model {}".format(
                type(data), type(self)
            )
        )
    return True

loc_scale_factory

loc_scale_factory(distribution: str, loc: ArrayLike = None, scale: ArrayLike = None) -> Distribution

Factory function for distributions with a loc/scale parameterization

Parameters:

Name Type Description Default
distribution :class:`str`

the name of the desired distribution

required
loc :data:`~numpy.typing.ArrayLike`

the location parameter(s) of the desired distribution

None
scale :data:`~numpy.typing.ArrayLike`

the scale parameter(s) of the desired distribution

None

Returns:

Name Type Description
dist :class:`~numpyro.distributions.distribution.Distribution`

The parameterized distribution.

Source code in pyter/models.py
def loc_scale_factory(
    distribution: str, loc: ArrayLike = None, scale: ArrayLike = None
) -> dist.Distribution:
    """Factory function for distributions
    with a loc/scale parameterization

    Parameters
    ----------
    distribution : :class:`str`
        the name of the desired distribution
    loc :  :data:`~numpy.typing.ArrayLike`
        the location parameter(s) of the desired distribution
    scale :  :data:`~numpy.typing.ArrayLike`
        the scale parameter(s) of the desired distribution

    Returns
    -------
    dist : :class:`~numpyro.distributions.distribution.Distribution`
        The parameterized distribution.

    """
    distributions = {
        "normal": dist.Normal,
        "cauchy": dist.Cauchy,
        "studentt": dist.StudentT,
    }
    dist_pick = distributions.get(distribution, None)
    if dist_pick is None:
        raise ValueError(
            "Unknown or unsupported "
            "distribution {}.\n\n"
            "Supported distributions: {}"
            "".format(distribution, [key for key in distributions.keys()])
        )
    return dist_pick(loc=loc, scale=scale)

sample_loc_scale_hier

sample_loc_scale_hier(param_name: str, param_dim: int, n_locs: int, n_scales: int, param_distribution: Distribution, loc_ids: ArrayLike, scale_ids: ArrayLike, loc_prior: Distribution, scale_prior: Distribution) -> Array

Sample a vector of hierarchical inferred parameters alongside their inferred parent parameters.

Convenience wrapper to sample a vectorized parameter in which individual values are "loc/scale" hierarchical. That is, parameter values have a distribution that is determined by two parameters: a location parameter (loc, e.g. the mean/median/mode of a :class:`~numpyro.distributions.continuous.Normal` distribution) and a scale parameter (scale, e.g. the standard deviation of a :class:`~numpyro.distributions.continuous.Normal` distribution). The values of the location and/or scale parameters are unknown and are inferred alongside the child parameters.

Parameters:

Name Type Description Default
param_name :class:`str`

The name of the parameter.

required
param_dim :class:`int`

The length of the parameter vector to sample.

required
n_locs :class:`int`

The number of groups of loc (e.g. mean, mode) values across all the parameters in the vector, e.g. 3 groups of parameters where group members are Normally distributed about unknown means :math:`\mu_1`, :math:`\mu_2`, and :math:`\mu_3`, respectively.

required
n_scales :class:`int`

The number of groups of scale (e.g. standard deviation) values across all the parameters in the vector, e.g. 3 groups of parameters whose members are Normally distributed about their (possibly shared, see n_locs) unknown means with unknown shared standard deviations :math:`\sigma_1`, :math:`\sigma_2`, and :math:`\sigma_3`, respectively.

required
param_distribution :class:`~numpyro.distributions.distribution.Distribution`

A loc / scale parameterizable probability distribution.

required
loc_ids :data:`~numpy.typing.ArrayLike`

Array of ids associating each parameter in the desired vector to one of the n_locs location parameters to be inferred.

required
scale_ids :data:`~numpy.typing.ArrayLike`

Array of ids associating each parameter in the desired vector to one of the n_scales scale parameters to be inferred.

required
loc_prior :class:`~numpyro.distributions.distribution.Distribution`

Prior distribution for the inferred unknown loc parameters.

required
scale_prior :class:`~numpyro.distributions.distribution.Distribution`

Prior distribution for the inferred unknown scale parameters.

required

Returns:

Name Type Description
param :class:`jax.Array`

A sampled vector of parameters.

Source code in pyter/models.py
def sample_loc_scale_hier(
    param_name: str,
    param_dim: int,
    n_locs: int,
    n_scales: int,
    param_distribution: dist.Distribution,
    loc_ids: ArrayLike,
    scale_ids: ArrayLike,
    loc_prior: dist.Distribution,
    scale_prior: dist.Distribution,
) -> jax.Array:
    """
    Sample a vector of hierarchical inferred
    parameters alongside their inferred
    parent parameters.

    Convenience wrapper to sample
    a vectorized parameter in which individual
    values are "loc/scale" hierarchical. That is,
    parameter values have a distribution
    that is determined by two parameters:
    a location parameter
    (``loc``, e.g. the mean/median/mode of a
    :class:`~numpyro.distributions.continuous.Normal`
    distribution) and a scale parameter
    (``scale``, e.g. the standard
    deviation of a
    :class:`~numpyro.distributions.continuous.Normal`
    distribution). The values of the
    location and/or scale parameters
    are unknown and are inferred alongside the
    child parameters.

    Parameters
    ----------
    param_name : :class:`str`
        The name of the parameter.

    param_dim : :class:`int`
        The length of the parameter vector to sample.

    n_locs : :class:`int`
        The number of groups of ``loc``
        (e.g. mean, mode) values across all the
        parameters in the vector, e.g. 3 groups
        of parameters where group members
        are Normally distributed about unknown
        means :math:`\\mu_1`, :math:`\\mu_2`,
        and :math:`\\mu_3`, respectively.

    n_scales : :class:`int`
        The number of groups of ``scale``
        (e.g. standard deviation) values
        across all the parameters in the
        vector, e.g. 3 groups of parameters
        whose members are Normally distributed
        about their (possibly shared,
        see ``n_locs``) unknown means
        with unknown shared standard deviations
        :math:`\\sigma_1`, :math:`\\sigma_2`,
        and :math:`\\sigma_3`, respectively.

    param_distribution : :class:`~numpyro.distributions.distribution.Distribution`
        A loc / scale parameterizable probability distribution.

    loc_ids : :data:`~numpy.typing.ArrayLike`
        Array of ids associating each parameter in the
        desired vector to one of the ``n_locs``
        location parameters to be inferred.

    scale_ids : :data:`~numpy.typing.ArrayLike`
        Array of ids associating each parameter in the
        desired vector to one of the ``n_scales``
        scale parameters to be inferred.

    loc_prior : :class:`~numpyro.distributions.distribution.Distribution`
        Prior distribution for the inferred unknown
        ``loc`` parameters.

    scale_prior : :class:`~numpyro.distributions.distribution.Distribution`
        Prior distribution for the inferred unknown
        ``scale`` parameters.


    Returns
    -------
    param : :class:`jax.Array`
        A sampled vector of parameters.

    """
    param_loc = npro.sample(param_name + "_loc", loc_prior.expand((n_locs,)))
    param_scale = npro.sample(
        param_name + "_scale", scale_prior.expand((n_scales,))
    )

    param = npro.sample(
        param_name,
        param_distribution(
            loc=param_loc[loc_ids], scale=param_scale[scale_ids]
        ),
    )

    return param
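The indexing at the heart of `sample_loc_scale_hier` can be sketched in NumPy. `n_locs` location values and `n_scales` scale values are drawn once, then `loc_ids` and `scale_ids` broadcast them out to the individual child parameters. The priors below (a Normal for locations and a half-normal, built with `abs`, for scales) and the ids are assumptions for illustration. The pyter function instead draws them with `numpyro.sample` from the user-supplied priors.

```python
import numpy as np

rng = np.random.default_rng(0)

param_dim, n_locs, n_scales = 6, 3, 2
loc_ids = np.array([0, 0, 1, 1, 2, 2])    # three location groups
scale_ids = np.array([0, 0, 0, 1, 1, 1])  # two scale groups, crossing the loc groups

# Parent parameters: one value per group (assumed priors).
param_loc = rng.normal(0.0, 1.0, size=n_locs)
param_scale = np.abs(rng.normal(0.0, 0.5, size=n_scales))

# Each child parameter uses the loc and scale of its assigned groups.
child_loc = param_loc[loc_ids]
child_scale = param_scale[scale_ids]
param = rng.normal(child_loc, child_scale)

print(child_loc.shape, child_scale.shape, param.shape)
```

In the listing above, the length of the sampled vector comes from broadcasting `param_loc[loc_ids]` against `param_scale[scale_ids]`. The `param_dim` argument is accepted but not used in the function body.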

sample_non_hier

sample_non_hier(param_name: str, param_dim: int, param_prior: Distribution) -> Array

Sample a vector of inferred parameters whose prior is fixed.

Convenience wrapper for :func:`numpyro.sample() <numpyro.primitives.sample>` to sample a vectorized parameter that is non-hierarchical.

Parameters:

Name Type Description Default
param_name :class:`str`

The name of the parameter.

required
param_dim :class:`int`

The length of the parameter vector.

required
param_prior :class:`~numpyro.distributions.distribution.Distribution`

A prior distribution for the parameter.

required

Returns:

Name Type Description
param :class:`jax.Array`

The sampled parameter vector.

Source code in pyter/models.py
def sample_non_hier(
    param_name: str, param_dim: int, param_prior: dist.Distribution
) -> jax.Array:
    """
    Sample a vector of inferred
    parameters whose prior is fixed

    Convenience wrapper for :func:`numpyro.sample()
    <numpyro.primitives.sample>` to sample
    a vectorized parameter that is non-hierarchical.

    Parameters
    ----------
    param_name : :class:`str`
         The name of the parameter.

    param_dim : :class:`int`
         The length of the parameter vector.

    param_prior : :class:`~numpyro.distributions.distribution.Distribution`
         A prior distribution for the parameter.

    Returns
    -------
    param : :class:`jax.Array`
         The sampled parameter vector.

    """
    param = npro.sample(param_name, param_prior.expand((param_dim,)))
    return param

well_distribution_factory

well_distribution_factory(assay: str, log_titer: ArrayLike, log_dilution: ArrayLike, log_base: ArrayLike, well_volume: ArrayLike, false_hit_rate: ArrayLike, validate_args: bool = True) -> TiterPlate

Get an appropriate distribution for titer wells.

Each entry of the various array inputs represents exactly one titration well.

Parameters:

Name Type Description Default
assay :class:`str` = {'pfu', 'tcid'}

Which titration assay to use. Options are 'pfu' (plaque assay) and 'tcid' (endpoint titration assay).

required
log_titer :data:`~numpy.typing.ArrayLike`

Underlying log titer(s) per unit volume in the undiluted sample(s).

required
log_dilution :data:`~numpy.typing.ArrayLike`

Log dilution(s) relative to the original sample(s) for each well's inoculum.

required
log_base :data:`~numpy.typing.ArrayLike`

Base of the logarithm for logarithmic quantities including titer and dilution (e.g. e, 2, 10, etc.).

required
well_volume :data:`~numpy.typing.ArrayLike`

Volume of the inoculum delivered to each well, in the same units as the per unit volume for the log_titer values. So if log titers are given per mL, this is the volume of inoculum in mL.

required
false_hit_rate :data:`~numpy.typing.ArrayLike`

Rate (mean number per well) of false hits (i.e. rate of apparent infection with a sample containing no infectious material).

required
validate_args :class:`bool`

Passed to the :class:`~numpyro.distributions.distribution.Distribution` constructor to enable / disable parameter validation. Default :py:data:`True`.

True

Returns:

Name Type Description
dist :class:`~pyter.distributions.TiterPlate`

A :class:`~numpyro.distributions.distribution.Distribution` object representing the distribution of the well plaque counts (plaque assay) or positive / negative statuses (endpoint titration assay).
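For intuition about what the factory's distributions describe, the sketch below computes the mean number of infectious units per well. It then derives a plaque-count mean and a positive-well probability under the standard single-hit Poisson assumption. This is an assumption, not a description of the `PlaquePlate` or `EndpointTiterPlate` internals. The sketch takes the mean count per well to be `well_volume * log_base ** (log_titer + log_dilution) + false_hit_rate`, with `log_dilution` negative for a dilution (e.g. -3 for a 1000-fold dilution in base 10). The helper names `expected_units` and `p_positive` are hypothetical.

```python
import math


def expected_units(log_titer, log_dilution, log_base, well_volume, false_hit_rate):
    """Mean infectious units per well (true plus false hits), assumed model."""
    return well_volume * log_base ** (log_titer + log_dilution) + false_hit_rate


def p_positive(mean_units):
    """Single-hit Poisson model: a well is positive if it receives >= 1 unit."""
    return 1.0 - math.exp(-mean_units)


# Sample at 10^5 infectious units per mL, 0.1 mL wells, 1000-fold dilution.
lam = expected_units(log_titer=5.0, log_dilution=-3.0, log_base=10.0,
                     well_volume=0.1, false_hit_rate=0.0)
print(lam)              # mean plaques per well (plaque assay)
print(p_positive(lam))  # probability a well is positive (endpoint assay)
```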

Source code in pyter/models.py
def well_distribution_factory(
    assay: str,
    log_titer: ArrayLike,
    log_dilution: ArrayLike,
    log_base: ArrayLike,
    well_volume: ArrayLike,
    false_hit_rate: ArrayLike,
    validate_args: bool = True,
) -> TiterPlate:
    """
    Get an appropriate distribution for titer wells.

    Each entry of the various array inputs represents
    exactly one titration well.

    Parameters
    ----------
    assay : :class:`str` = {'pfu', 'tcid'}
        Which titration assay to use. Options are
        ``'pfu'`` (plaque assay) and
        ``'tcid'`` (endpoint titration assay).

    log_titer : :data:`~numpy.typing.ArrayLike`
        Underlying log titer(s) per unit volume in
        the undilute sample(s).

    log_dilution : :data:`~numpy.typing.ArrayLike`
        Log dilution(s) relative to the original
        sample(s) for each well's inoculum.

    log_base : :data:`~numpy.typing.ArrayLike`
        Base of the logarithm for logarithmic
        quantities including titer and
        dilution (e.g. e, 2, 10, etc).

    well_volume : :data:`~numpy.typing.ArrayLike`
        Volume of the inoculum delivered to
        each well, in the same units as the
        per unit volume for the ``log_titer``
        values. So if log titers are given per
        mL, this is the volume of inoculum
        in mL.

    false_hit_rate : :data:`~numpy.typing.ArrayLike`
        Rate (mean number per well)
        of false hits (i.e. rate of apparent
        infection with a sample containing
        no infectious material).

    validate_args : :class:`bool`
        Passed to the
        :class:`~numpyro.distributions.distribution.Distribution`
        constructor to enable / disable
        parameter validation.
        Default :py:data:`True`.


    Returns
    -------
    dist : :class:`~pyter.distributions.TiterPlate`
        A :class:`~numpyro.distributions.distribution.Distribution`
        object representing the distribution of the
        well plaque counts (plaque assay) or
        positive / negative statuses
        (endpoint titration assay).

    """

    assays = {"tcid": EndpointTiterPlate, "pfu": PlaquePlate}

    if assay is None:
        raise ValueError("Must specify assay")

    distribution = assays.get(assay, None)

    if distribution is None:
        raise ValueError(
            "Unknown or unsupported "
            "assay {}.\n\n"
            "Supported assays are endpoint titration "
            "(set assay = 'tcid') and "
            "plaque assay (set assay = 'pfu')"
            "".format(assay)
        )
    return distribution(
        log_titer=log_titer,
        log_dilution=log_dilution,
        log_base=log_base,
        well_volume=well_volume,
        false_hit_rate=false_hit_rate,
        validate_args=validate_args,
    )
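The note on ``well_volume`` units can be checked with a little arithmetic. The NumPy sketch below uses made-up example values (not taken from pyter) to compute the mean number of infectious units delivered to each well of a tenfold dilution series. It mirrors the expression used in the ``PlaquePlate`` constructor; ``EndpointTiterPlate`` additionally multiplies the titer term by ln 2.

```python
import numpy as np

# Hypothetical example inputs (not from pyter): log10 titer 4 per mL,
# a tenfold dilution series, and 0.1 mL of inoculum per well.
log_base = 10.0
log_titer = 4.0                                          # log10 units per mL, undiluted sample
log_dilution = np.array([0.0, -1.0, -2.0, -3.0, -4.0])  # log10 dilution of each well's inoculum
well_volume = 0.1                                        # mL per well, same volume unit as the titer
false_hit_rate = 0.0                                     # no background hits in this example

# Mean infectious units per well: diluted titer times inoculum volume,
# plus the background (false hit) rate.
mean_units = false_hit_rate + well_volume * log_base ** (log_titer + log_dilution)
print(mean_units)  # approximately [1000, 100, 10, 1, 0.1]
```

If the titer were instead given per µL, ``well_volume`` would have to be given in µL for the same result.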

Data

AbstractData

Abstract base class for holding data associated to Pyter inferential models.

freeze

freeze()

Validate, fix, and format data for use in inference.

Data is returned as a :class:dict that can be passed to a corresponding :class:Model <AbstractModel> instance.

The actual logic of validation and data preparation is handled by subclass-specific :meth:validate and :meth:_freeze methods; the common :meth:freeze method ensures common data dictionary output formatting across all :class:Data <AbstractData> subclasses.

Returns:

Name Type Description
data_dict :class:`dict`

A dictionary of data to pass to a model.

Source code in pyter/data.py
def freeze(self):
    """
    Validate, fix, and format data
    for use in inference.

    Data is returned as a :class:`dict`
    that can be passed to a
    corresponding :class:`Model
    <AbstractModel>` instance.

    The actual logic of validation
    and data preparation is handled
    by subclass-specific :meth:`validate`
    and :meth:`_freeze` methods; the
    common :meth:`freeze` method ensures
    common data dictionary output
    formatting across all
    :class:`Data <AbstractData>`
    subclasses.

    Returns
    -------
    data_dict : :class:`dict`
        A dictionary of data to pass to a model.

    """
    self.validate()
    self._freeze()
    return {
        s: getattr(self, s) for s in self.__slots__ if hasattr(self, s)
    }
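A minimal sketch of the ``freeze`` pattern, using a hypothetical ``ToyData`` class that is not part of pyter: validate, run subclass-specific preparation, then collect every ``__slots__`` attribute that has actually been set. It shows that declared but unassigned slots are silently left out of the returned dictionary.

```python
class ToyData:
    """Hypothetical stand-in for an AbstractData subclass (not part of pyter)."""

    __slots__ = ("log_titer", "well_volume", "scratch")

    def __init__(self, log_titer, well_volume):
        self.log_titer = log_titer
        self.well_volume = well_volume
        # "scratch" is declared in __slots__ but never assigned.

    def validate(self):
        if self.well_volume <= 0:
            raise ValueError("well_volume must be positive")

    def _freeze(self):
        # Subclass-specific preparation; here, just coerce to float.
        self.log_titer = float(self.log_titer)

    def freeze(self):
        # Same pattern as AbstractData.freeze: validate, prepare,
        # then return every slot that has been set.
        self.validate()
        self._freeze()
        return {s: getattr(self, s) for s in self.__slots__ if hasattr(self, s)}


data_dict = ToyData(log_titer=3, well_volume=0.1).freeze()
print(data_dict)  # {'log_titer': 3.0, 'well_volume': 0.1}
```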

validate

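validate()

Validate the data. Subclasses must override this method; the base implementation raises NotImplementedError.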
Source code in pyter/data.py
def validate(self):
    """ """
    raise NotImplementedError(
        "Abstract class AbstractData has no validate() method"
    )

HalfLifeData

Bases: AbstractData

Data struct for inferring half-life of infectious virus

index_prior_parameters

index_prior_parameters()

Assign prior parameters to appropriate indices

Source code in pyter/data.py
def index_prior_parameters(self):
    """
    Assign prior parameters to
    appropriate indices

    Parameters
    ----------

    Returns
    -------

    """
    pass

update_internal_ids

update_internal_ids()

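Assign internal ids for parameters. Each parameter's per-well external ids (the well_<param>_id attributes) are converted to contiguous integer ids with to_internal_ids; the ids of associated parameters (the halflife, intercept, titer_error_scale, and halflife_loc of each titer, and the loc and scale of each halflife and intercept) are then looked up with get_associated_internal_ids.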

Source code in pyter/data.py
def update_internal_ids(self):
    """
    Assign internal ids
    for parameters
    """

    for param in [
        "titer",
        "titer_error_scale",
        "halflife",
        "halflife_loc",
        "halflife_scale",
        "intercept",
        "intercept_loc",
        "intercept_scale",
    ]:
        (
            self.well_internal_id_values[param],
            self.unique_internal_ids[param],
            self.unique_external_ids[param],
            self.id_representative_rows[param],
            self.n_values[param],
        ) = to_internal_ids(self.__getattribute__("well_" + param + "_id"))

    for param in [
        "halflife",
        "intercept",
        "titer_error_scale",
        "halflife_loc",
    ]:
        self.titer_internal_id_values[param] = get_associated_internal_ids(
            "titer",
            param,
            self.well_internal_id_values,
            self.id_representative_rows,
        )

    for param in ["loc", "scale"]:
        self.halflife_internal_id_values[param] = (
            get_associated_internal_ids(
                "halflife",
                "halflife_" + param,
                self.well_internal_id_values,
                self.id_representative_rows,
            )
        )
        self.intercept_internal_id_values[param] = (
            get_associated_internal_ids(
                "intercept",
                "intercept_" + param,
                self.well_internal_id_values,
                self.id_representative_rows,
            )
        )

    self.titer_time = self.well_time[self.id_representative_rows["titer"]]
    self.log_titer_change_other = np.broadcast_to(
        self.log_well_change_other,
        self.well_internal_id_values["titer"].shape,
    )[self.id_representative_rows["titer"]]

validate

validate() -> bool

Perform no checks; the data are always considered valid.

Returns:

Type Description
bool

Always True.
Source code in pyter/data.py
def validate(self) -> bool:
    """
    Null data is necessarily valid

    Returns
    -------
    :data:`True`
    """
    return True

NullData

Bases: AbstractData

:class:Data <AbstractData> class for models that do not take any user-provided data, and for testing.

validate

validate() -> bool

Null data is necessarily valid

Returns:

Type Description
bool

Always True.
Source code in pyter/data.py
def validate(self) -> bool:
    """
    Null data is necessarily valid

    Returns
    -------
    :data:`True`
    """
    return True

TiterData

Bases: AbstractData

:class:Data <AbstractData> class for inference of individual titers.

validate

validate() -> bool

Perform no checks; the data are always considered valid.

Returns:

Type Description
bool

Always True.
Source code in pyter/data.py
def validate(self) -> bool:
    """
    Null data is necessarily valid

    Returns
    -------
    :data:`True`
    """
    return True

get_associated_internal_ids

get_associated_internal_ids(key_param: str, value_param: str, internal_id_dict: dict, representative_row_dict: dict) -> ndarray

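Look up, for each unique value of key_param, the internal id of value_param at that value's representative row. For example, get the internal intercept id for each internal titer id. Returns an empty array if either the representative rows or the value ids are empty.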

Parameters:

Name Type Description Default
key_param str
required
value_param str
required
internal_id_dict dict
required
representative_row_dict dict
required
Source code in pyter/data.py
def get_associated_internal_ids(
    key_param: str,
    value_param: str,
    internal_id_dict: dict,
    representative_row_dict: dict,
) -> np.ndarray:
    """For example, get internal intercept id
    for each internal titer id

    Parameters
    ----------
    key_param :
    value_param:
    internal_id_dict :
    representative_row_dict:

    Returns
    -------

    """
    rows = representative_row_dict[key_param]
    if rows.size > 0 and internal_id_dict[value_param].size > 0:
        return internal_id_dict[value_param][rows]
    else:
        return np.array([])
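A small worked example of the lookup, using hypothetical per-well id columns (four wells, three titers, two intercepts) and plain NumPy in place of the pyter helpers. Because the lookup reads one representative row per titer (the first row where that titer appears), the result is only meaningful under the assumption that all wells sharing a titer id also share an intercept id.

```python
import numpy as np

# Hypothetical per-well id columns (not from pyter).
well_titer_id = np.array(["t1", "t1", "t2", "t3"])
well_intercept_id = np.array(["s1", "s1", "s1", "s2"])

# Internal ids per well, as to_internal_ids would assign them.
_, titer_internal = np.unique(well_titer_id, return_inverse=True)       # [0 0 1 2]
_, intercept_internal = np.unique(well_intercept_id, return_inverse=True)  # [0 0 0 1]
_, titer_rows = np.unique(titer_internal, return_index=True)            # [0 2 3]

# Same indexing as get_associated_internal_ids("titer", "intercept", ...):
# the intercept id found at each titer's representative row.
intercept_for_each_titer = intercept_internal[titer_rows]
print(intercept_for_each_titer)  # [0 0 1]
```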

to_internal_ids

to_internal_ids(external_ids: ArrayLike) -> tuple[ndarray, ndarray, ndarray, ndarray, int]

Internally index a long tidy data frame.

Parameters:

Name Type Description Default
external_ids :data:`~numpy.typing.ArrayLike`

Array of external ids, which may be strings, numeric values, or another type coercible to a :class:numpy.array.

required

Returns:

Name Type Description
result :class:`tuple`

A tuple containing:

internal_ids : :class:numpy.ndarray An array of the assigned internal id values for each entry ("row") of the provided external ids.

unique_internal_ids : :class:numpy.ndarray An array of the unique values of the internal ids

unique_external_ids : :class:numpy.ndarray An array of the unique values of the provided external ids

representative_rows : :class:numpy.ndarray An array of the indices for the external_ids and internal_ids arrays that return one instance for each unique id value. This can then be used to index any other columns of the original data table from which the external_ids came.

n_values : :class:int The number of unique id values.

Source code in pyter/data.py
def to_internal_ids(
    external_ids: ArrayLike,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int]:
    """
    Internally index a long tidy
    data frame.

    Parameters
    ----------
    external_ids : :data:`~numpy.typing.ArrayLike`
        Array of external ids, which may be strings, numeric
        values, or another type coercible to a
        :class:`numpy.array`.


    Returns
    -------
    result : :class:`tuple`

        A tuple containing:


        **internal_ids** : :class:`numpy.ndarray`
            An array of the assigned internal id values for each entry
            ("row") of the provided external ids.

        **unique_internal_ids** : :class:`numpy.ndarray`
            An array of the unique values of the internal ids

        **unique_external_ids** : :class:`numpy.ndarray`
            An array of the unique values of the provided external ids

        **representative_rows** : :class:`numpy.ndarray`
            An array of the indices for the ``external_ids`` and
            ``internal_ids`` arrays that return one instance
            for each unique id value. This can then be used
            to index any other columns of the original data
            table from which the ``external_ids`` came.

        **n_values** : :class:`int`
            The number of unique id values.

    """
    (unique_external_ids, internal_ids) = np.unique(
        external_ids, return_inverse=True
    )

    (unique_internal_ids, representative_rows) = np.unique(
        internal_ids, return_index=True
    )

    n_values = unique_internal_ids.size

    return (
        internal_ids,
        unique_internal_ids,
        unique_external_ids,
        representative_rows,
        n_values,
    )
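The two ``np.unique`` calls above can be traced on a tiny hypothetical input (four wells drawn from three samples). The sketch repeats the same two steps inline and then uses ``representative_rows`` to pull one value per unique id out of another per-well column, the same way ``HalfLifeData.update_internal_ids`` derives ``titer_time`` from ``well_time``. The ``well_time`` values here are made up.

```python
import numpy as np

# Hypothetical well-level ids: four wells drawn from three samples.
external_ids = np.array(["b", "a", "b", "c"])

# Step 1: map each external id to an integer 0..n-1 (in sorted order).
unique_external_ids, internal_ids = np.unique(external_ids, return_inverse=True)

# Step 2: for each internal id, find the first row where it occurs.
unique_internal_ids, representative_rows = np.unique(internal_ids, return_index=True)
n_values = unique_internal_ids.size

print(unique_external_ids)   # ['a' 'b' 'c']
print(internal_ids)          # [1 0 1 2]
print(representative_rows)   # [1 0 3]
print(n_values)              # 3

# Index another per-well column to get one value per sample.
well_time = np.array([0.0, 1.0, 0.0, 2.0])
sample_time = well_time[representative_rows]
print(sample_time)           # [1. 0. 2.]
```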

validate_internal_ids

validate_internal_ids(internal_ids: ArrayLike, unique_internal_ids: ArrayLike, unique_external_ids: ArrayLike, representative_rows: ArrayLike, n_values: int) -> None

Check that the output of to_internal_ids is internally consistent. Raises ValueError if the numbers of ids disagree or if any internal id between 0 and n_values - 1 is missing.

Parameters:

Name Type Description Default
internal_ids ArrayLike
required
unique_internal_ids ArrayLike
required
unique_external_ids ArrayLike
required
representative_rows ArrayLike
required
n_values int
required
Source code in pyter/data.py
def validate_internal_ids(
    internal_ids: ArrayLike,
    unique_internal_ids: ArrayLike,
    unique_external_ids: ArrayLike,
    representative_rows: ArrayLike,
    n_values: int,
) -> None:
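    """
    Check that the output of :func:`to_internal_ids`
    is internally consistent.

    Parameters
    ----------
    internal_ids : :data:`~numpy.typing.ArrayLike`
    unique_internal_ids : :data:`~numpy.typing.ArrayLike`
    unique_external_ids : :data:`~numpy.typing.ArrayLike`
    representative_rows : :data:`~numpy.typing.ArrayLike`
    n_values : :class:`int`

    Raises
    ------
    ValueError
        If the numbers of ids are inconsistent, or if
        any internal id in ``range(n_values)`` is missing.
    """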

    if not all(
        [
            representative_rows.size == n_values,
            unique_external_ids.size == n_values,
        ]
    ):
        raise ValueError("Inconsistent numbers of ids assigned")
    if n_values > 0:
        if not max(internal_ids) == n_values - 1:
            raise ValueError("Missing internal ids")
        if not np.all(np.sort(unique_internal_ids) == np.arange(n_values)):
            raise ValueError("Missing internal_ids")
        pass
    pass
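The invariants checked above can be illustrated directly on ``np.unique`` output. This standalone sketch (hypothetical input, no pyter import) asserts the same conditions the function tests, then shows a hand-made id set that skips a value and therefore fails the contiguity check.

```python
import numpy as np

# Hypothetical external ids for three wells.
external_ids = np.array(["x", "y", "x"])
unique_external_ids, internal_ids = np.unique(external_ids, return_inverse=True)
unique_internal_ids, representative_rows = np.unique(internal_ids, return_index=True)
n_values = unique_internal_ids.size

# The conditions validate_internal_ids checks hold for np.unique output:
assert representative_rows.size == n_values
assert unique_external_ids.size == n_values
assert internal_ids.max() == n_values - 1
assert np.array_equal(np.sort(unique_internal_ids), np.arange(n_values))

# A hand-made set of ids that skips 1 is not contiguous, so the
# "Missing internal_ids" check would reject it.
gappy_unique_ids = np.array([0, 2])
contiguous = np.array_equal(np.sort(gappy_unique_ids), np.arange(gappy_unique_ids.size))
print(contiguous)  # False
```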

Distributions

Custom probability distributions for quantitative virology.

Distributions are subclasses of :class:numpyro.distributions.Distribution <numpyro.distributions.distribution.Distribution>, which we will refer to simply as class :class:~numpyro.distributions.distribution.Distribution.

EndpointTiterPlate

EndpointTiterPlate(**kwargs)

Bases: TiterPlate

Distribution class to represent a set of titers quantified by endpoint titration.

Source code in pyter/distributions.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

    self.single_hit_rate = (
        self.false_hit_rate
        + self.well_volume
        * jnp.log(2)
        * jnp.exp(  # convert id50 to hit units
            jnp.log(self.log_base) * (self.log_titer + self.log_dilution)
        )
    )

    self.single_hit_ = PoissonSingleHit(rate=self.single_hit_rate)
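The ``jnp.log(2)`` factor converts ID50 units into single-hit (Poisson) units: a well that receives exactly one TCID50 should be positive with probability one half, and 1 - exp(-ln 2) = 0.5. The NumPy sketch below checks this for hypothetical doses of 0.1, 1, and 10 TCID50 per well. Here the dose is assumed to already combine titer, dilution, and well volume.

```python
import numpy as np

# Hypothetical doses: log10 TCID50 delivered per well (0.1, 1, and 10 TCID50).
log_base = 10.0
log_dose_id50 = np.array([-1.0, 0.0, 1.0])
false_hit_rate = 0.0

# Convert ID50 units to single-hit units, as in EndpointTiterPlate.__init__.
single_hit_rate = false_hit_rate + np.log(2) * log_base ** log_dose_id50

# Probability that a well is positive under the Poisson single-hit model.
p_positive = 1 - np.exp(-single_hit_rate)
print(p_positive)  # the middle entry (1 TCID50) is 0.5 up to rounding
```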

log_prob

log_prob(value)

Parameters:

Name Type Description Default
value
required
Source code in pyter/distributions.py
@validate_sample
def log_prob(self, value):
    """

    Parameters
    ----------
    value :

    Returns
    -------

    """
    return self.single_hit_.log_prob(value)

sample

sample(key, sample_shape=())

Parameters:

Name Type Description Default
key
required
sample_shape

(Default value = ())

()
Source code in pyter/distributions.py
def sample(self, key, sample_shape=()):
    """

    Parameters
    ----------
    key :

    sample_shape :
         (Default value = ())

    Returns
    -------

    """
    assert is_prng_key(key)
    return self.single_hit_.sample(key, sample_shape=sample_shape)

PlaquePlate

PlaquePlate(**kwargs)

Bases: TiterPlate

Distribution class to represent a set of titers quantified by a plaque assay.

Source code in pyter/distributions.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

    self.hit_rate = self.false_hit_rate + self.well_volume * jnp.exp(
        jnp.log(self.log_base) * (self.log_titer + self.log_dilution)
    )

    self.poisson_ = npro.distributions.Poisson(rate=self.hit_rate)
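As a sketch of what this distribution scores, the NumPy/``math`` code below computes the Poisson mean plaque count for a hypothetical dilution series (log10 titer 5 per mL, 0.1 mL per well) with the same expression as the constructor, then sums Poisson log-probabilities of made-up observed counts. The ``poisson_log_pmf`` helper is written here for illustration and is not part of pyter.

```python
import math
import numpy as np

log_base = 10.0
log_titer = 5.0                               # hypothetical log10 PFU per mL
log_dilution = np.array([-3.0, -4.0, -5.0])   # one well per dilution
well_volume = 0.1                             # mL per well
false_hit_rate = 0.0

# Poisson mean plaque count per well, as in PlaquePlate.__init__.
hit_rate = false_hit_rate + well_volume * np.exp(
    np.log(log_base) * (log_titer + log_dilution)
)

def poisson_log_pmf(k: int, rate: float) -> float:
    """Hypothetical helper: log P(K = k) for K ~ Poisson(rate), rate > 0."""
    return k * math.log(rate) - rate - math.lgamma(k + 1)

observed_plaques = [12, 1, 0]                 # hypothetical counts
log_lik = sum(poisson_log_pmf(k, r) for k, r in zip(observed_plaques, hit_rate))
print(hit_rate)  # approximately [10, 1, 0.1]
print(log_lik)
```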

log_prob

log_prob(value)

Parameters:

Name Type Description Default
value
required
Source code in pyter/distributions.py
@validate_sample
def log_prob(self, value):
    """

    Parameters
    ----------
    value :

    Returns
    -------

    """
    return self.poisson_.log_prob(value)

sample

sample(key, sample_shape=())

Parameters:

Name Type Description Default
key
required
sample_shape

(Default value = ())

()
Source code in pyter/distributions.py
def sample(self, key, sample_shape=()):
    """

    Parameters
    ----------
    key :

    sample_shape :
         (Default value = ())

    Returns
    -------

    """
    assert is_prng_key(key)
    return self.poisson_.sample(key, sample_shape=sample_shape)

PoissonSingleHit

PoissonSingleHit(rate=0, validate_args=None)

Bases: Distribution

Poisson Single-Hit Distribution

This is a distribution that yields 1 if a Poisson random variable is nonzero and 0 otherwise. It arises in virology because if we expose a set of cells to some quantity of infectious virus particles ("virions"), the number that successfully enter a cell and replicate can be modeled as a Poisson-distributed random variable with a mean related to the initial quantity of virions. The probability of seeing any evidence of cell invasion is then the probability that this Poisson random variable is nonzero (i.e. that at least one virion successfully invaded a cell), which is 1 - exp(-rate).

Parameters:

Name Type Description Default
rate :py:class:`float`

The rate of the Poisson random variable.

0
Source code in pyter/distributions.py
def __init__(self, rate=0, validate_args=None):
    self.rate = rate
    batch_shape = jnp.shape(self.rate)

    self.bernoulli_ = npro.distributions.Bernoulli(
        probs=1 - jnp.exp(-self.rate), validate_args=True
    )

    super().__init__(batch_shape=batch_shape, validate_args=validate_args)
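The constructor encodes the single-hit outcome as a Bernoulli with success probability 1 - exp(-rate). A quick Monte Carlo check with NumPy (hypothetical rate and seed, no pyter import) simulates the underlying Poisson counts, keeps only whether each count is nonzero, and compares the positive fraction with that closed form.

```python
import numpy as np

rng = np.random.default_rng(0)
rate = 0.7          # hypothetical mean number of successful virions per well
n_wells = 200_000

# Simulate how many virions invade a cell in each well, then keep only
# the single-hit outcome: was there at least one?
virions = rng.poisson(rate, size=n_wells)
positive = virions > 0

print(positive.mean())       # Monte Carlo estimate
print(1 - np.exp(-rate))     # exact P(Poisson(rate) > 0)
```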

log_prob

log_prob(value)

Parameters:

Name Type Description Default
value
required
Source code in pyter/distributions.py
@validate_sample
def log_prob(self, value):
    """

    Parameters
    ----------
    value :

    Returns
    -------

    """
    return self.bernoulli_.log_prob(value)

sample

sample(key, sample_shape=())

Parameters:

Name Type Description Default
key
required
sample_shape

(Default value = ())

()
Source code in pyter/distributions.py
def sample(self, key, sample_shape=()):
    """
    Parameters
    ----------
    key :
    sample_shape :
         (Default value = ())

    Returns
    -------

    """
    assert is_prng_key(key)
    return self.bernoulli_.sample(key, sample_shape=sample_shape)

TiterPlate

TiterPlate(log_titer=None, log_dilution=None, log_base=10, well_volume=1, false_hit_rate=0, validate_args=None)

Bases: Distribution

Base distribution to represent a set of titers

Subclasses represent different assays: :class:PlaquePlate for plaque assays, and :class:EndpointTiterPlate for endpoint titration assays.

Source code in pyter/distributions.py
def __init__(
    self,
    log_titer=None,
    log_dilution=None,
    log_base=10,
    well_volume=1,
    false_hit_rate=0,
    validate_args=None,
):
    (
        self.log_titer,
        self.log_dilution,
        self.log_base,
        self.well_volume,
        self.false_hit_rate,
    ) = promote_shapes(
        log_titer, log_dilution, log_base, well_volume, false_hit_rate
    )

    batch_shape = lax.broadcast_shapes(
        jnp.shape(self.log_titer),
        jnp.shape(self.log_dilution),
        jnp.shape(self.log_base),
        jnp.shape(self.well_volume),
        jnp.shape(self.false_hit_rate),
    )

    super().__init__(validate_args=validate_args, batch_shape=batch_shape)
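The batch shape is the broadcast of all parameter shapes, so a column of titers combined with a row of dilutions describes a full plate grid. The sketch below uses ``np.broadcast_shapes``, which follows the same NumPy broadcasting rules as the ``lax.broadcast_shapes`` call above, with hypothetical shapes (three samples by four dilutions, scalar remaining parameters).

```python
import numpy as np

# Hypothetical parameter shapes: 3 samples (rows) x 4 dilutions (columns).
log_titer = np.array([[5.0], [4.0], [3.0]])           # shape (3, 1)
log_dilution = np.array([-1.0, -2.0, -3.0, -4.0])     # shape (4,)
log_base, well_volume, false_hit_rate = 10.0, 0.1, 0.0  # scalars, shape ()

batch_shape = np.broadcast_shapes(
    np.shape(log_titer),
    np.shape(log_dilution),
    np.shape(log_base),
    np.shape(well_volume),
    np.shape(false_hit_rate),
)
print(batch_shape)  # (3, 4)
```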

Inference

Inference

infer

infer(model: AbstractModel = None, data: AbstractData = None, random_seed: int = None, num_warmup: int = 1000, num_samples: int = 1000, validate_data: bool = True, **kwargs)

Conduct inference.

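Draw posterior samples from the given model conditioned on the given data. Nothing is returned; the model, NUTS kernel, MCMC runner, PRNG key, and frozen data used for the run are stored on the instance as run_model, kernel, mcmc_runner, run_rng_key, and run_data.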

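Parameters:

Name Type Description Default
model AbstractModel

The model to draw posterior samples from. Although the default is None, a model must be supplied: its reparameterized form (from get_reparam()) is used to build the NUTS kernel.

None
data AbstractData

The data to condition on. It is frozen with freeze() and the resulting dictionary is passed to the model. Must be supplied.

None
random_seed int

Seed for the JAX PRNG key used in the run. If None, a seed is drawn with np.random.randint(0, 100000).

None
num_warmup int

Number of MCMC warmup iterations.

1000
num_samples int

Number of posterior samples to draw after warmup.

1000
validate_data bool

If True, call the model's validate_data() method on the data and the frozen data dictionary before running.

True
**kwargs

Additional keyword arguments passed to the numpyro MCMC constructor via new_runner().

{}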
Source code in pyter/infer.py
def infer(
    self,
    model: AbstractModel = None,
    data: AbstractData = None,
    random_seed: int = None,
    num_warmup: int = 1000,
    num_samples: int = 1000,
    validate_data: bool = True,
    **kwargs,
):
    """Conduct inference.

    Draw posterior samples from the
    given model with the given
    data

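    Parameters
    ----------
    model : AbstractModel
        Model to draw posterior samples from.
    data : AbstractData
        Data to condition on; frozen with
        :meth:`~AbstractData.freeze` before the run.
    random_seed : int
        Seed for the JAX PRNG key. If ``None``
        (default), one is drawn at random.
    num_warmup : int
        Number of warmup iterations (default 1000).
    num_samples : int
        Number of posterior samples (default 1000).
    validate_data : bool
        Whether to call the model's ``validate_data``
        method before running (default ``True``).
    **kwargs
        Passed to the numpyro ``MCMC`` constructor.

    Returns
    -------
    None
        Results are stored on the instance
        (e.g. ``self.mcmc_runner``).
    """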
    self.run_model = model
    self.kernel = self.new_kernel(self.run_model.get_reparam())
    self.mcmc_runner = self.new_runner(
        self.kernel, num_warmup, num_samples, **kwargs
    )

    if random_seed is None:
        random_seed = np.random.randint(0, 100000)

    # this saves state so that we know
    # exactly what we used for the run
    self.run_rng_key = jax.random.PRNGKey(random_seed)
    self.run_data = data.freeze()

    if validate_data:
        self.run_model.validate_data(data, self.run_data)

    self.mcmc_runner.run(self.run_rng_key, data=self.run_data)

new_kernel

new_kernel(model)

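Construct a NUTS kernel for model, using this object's target_accept_prob, max_tree_depth, and forward_mode_differentiation settings.

Parameters: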

Name Type Description Default
model
required
Source code in pyter/infer.py
def new_kernel(self, model):
    """

    Parameters
    ----------
    model :

    Returns
    -------

    """
    return NUTS(
        model,
        target_accept_prob=self.target_accept_prob,
        max_tree_depth=self.max_tree_depth,
        forward_mode_differentiation=self.forward_mode_differentiation,
    )

new_runner

new_runner(kernel, num_warmup, num_samples, **kwargs)

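Construct a numpyro MCMC runner for kernel with the given numbers of warmup and sampling iterations; any extra keyword arguments are passed to the MCMC constructor.

Parameters: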

Name Type Description Default
kernel
required
num_warmup
required
num_samples
required
**kwargs
{}
Source code in pyter/infer.py
def new_runner(self, kernel, num_warmup, num_samples, **kwargs):
    """

    Parameters
    ----------
    kernel :
    num_warmup :
    num_samples :
    **kwargs :

    Returns
    -------

    """
    return MCMC(
        kernel, num_warmup=num_warmup, num_samples=num_samples, **kwargs
    )

Constraints